# Source code for bitorch.layers.extensions.layer_registration
from abc import ABC
from typing import Optional, Any, Type, Union, Tuple, TYPE_CHECKING
import torch
import bitorch
from bitorch import runtime_mode_type, RuntimeMode
from .layer_container import LayerContainer
from .layer_implementation import DefaultImplementationMixin, BaseImplementation, CustomImplementationMixin
from .layer_recipe import LayerRecipe
if TYPE_CHECKING:
from .layer_registry import LayerRegistry
class LayerImplementation(ABC):
"""
    Superclass for storing different implementations of a common layer type.
    It registers all decorated classes in the given registry. When an instance of a decorated class
    is created, it wraps the new object in a LayerContainer and stores the arguments used to create it.
    It also records which RuntimeMode(s) an implementation supports.
"""
registry: "LayerRegistry"
class_: Type[BaseImplementation]
class_name: str
_supported_modes: runtime_mode_type
__initialized: bool
    def __init__(self, registry: "LayerRegistry", supported_modes: runtime_mode_type) -> None:
"""
Define an implementation decorator for a certain type of layer. All implementations and objects of this type of
layer are stored in the given registry.
Args:
registry: the registry which should store the implementation and objects of this layer type
            supported_modes: the mode(s) supported by the registered implementation
"""
self.registry = registry
assert RuntimeMode.is_combined_mode(supported_modes), f"invalid mode {supported_modes} given"
self._supported_modes = supported_modes
self.__initialized = False
self.class_ = None # type: ignore
self.class_name = ""
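    # Note: supported_modes may combine several modes (e.g. something like
    # RuntimeMode.CPU | RuntimeMode.GPU, assuming such members exist in RuntimeMode);
    # RuntimeMode.is_combined_mode above validates that the value is a single mode
    # or a valid combination of modes.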
    def __call__(
self, *args: Any, **kwargs: Any
) -> Union["LayerImplementation", Type[BaseImplementation], LayerContainer]:
if not self.__initialized:
            # the first call happens when this object is applied as a @decorator, so initialize with the decorated class
return self._initialize(*args, **kwargs)
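        # in RAW mode, instantiate the registered class directly, without wrapping
        # it in a LayerContainer or recording a recipe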
if bitorch.mode == RuntimeMode.RAW:
return self.class_(*args, **kwargs) # type: ignore
# on later calls we need to provide the correct layer implementation
return self._provide_layer_implementation(*args, **kwargs)
def _initialize(self, class_: Type[BaseImplementation]) -> Union["LayerImplementation", Type[BaseImplementation]]:
self.__initialized = True
self.class_ = class_
self.class_name = self.class_.__name__
self.registry.register(self)
if self._supported_modes == RuntimeMode.DEFAULT:
assert issubclass(
self.class_, DefaultImplementationMixin
), f"{self.class_name} should be a subclass of DefaultLayerImplementation."
# provide this wrapper
return self
else:
            assert issubclass(self.class_, CustomImplementationMixin), (
                f"{self.class_name} should be a subclass of CustomImplementationMixin (and it should "
                f"implement the corresponding class methods)."
            )
# after we have registered custom implementations, we do not interfere anymore
return self.class_
def _provide_layer_implementation(self, *args: Any, **kwargs: Any) -> LayerContainer:
correct_layer_implementation = self.registry.get_layer()
if self == correct_layer_implementation:
# this class provides the correct implementation for the current mode (recursion stop)
layer_container = LayerContainer(self.class_, *args, **kwargs)
self.registry.add_recipe(layer_container.recipe)
return layer_container
# call this method again but on the correct base class
return correct_layer_implementation._provide_layer_implementation(*args, **kwargs)
    def supports_mode(self, mode: RuntimeMode) -> bool:
"""
Check whether this layer implementation supports a given RuntimeMode.
Args:
mode: the runtime mode that should be supported
Returns:
True if the given mode is supported, False otherwise
"""
return mode.is_supported_by(self._supported_modes)
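    # delegates to the implementation class: checks whether an equivalent layer
    # can be cloned from the given recipe, returning a success flag and a message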
def can_create_clone_from(self, recipe: LayerRecipe) -> Tuple[bool, str]:
return self.class_.can_clone(recipe)
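    # delegates to the implementation class: builds a replacement layer from the
    # stored recipe, optionally on the given device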
def get_replacement(self, recipe: LayerRecipe, device: Optional[torch.device] = None) -> Any:
return self.class_.create_clone_from(recipe, device)
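    # True if the wrapped class is the default implementation of this layer type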
def is_default(self) -> bool:
return self.class_.is_default_implementation()
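# Note on recipes (inferred from the methods above): every LayerContainer created
# in _provide_layer_implementation is recorded in the registry together with its
# construction arguments as a LayerRecipe. When switching modes, a registry can
# ask each implementation via can_create_clone_from(recipe) whether it can rebuild
# the layer, and then call get_replacement(recipe, device) to obtain the substitute.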