Source code for bitorch.layers.extensions.layer_container

from typing import Any, TypeVar, Type, Generic

from bitorch.layers.extensions.layer_recipe import LayerRecipe

T = TypeVar("T")


class LayerContainer(Generic[T]):
    """This class wraps another layer - but the internally contained class can be swapped out during runtime.

    Attribute reads and writes that do not target one of ``internal_variable_names`` are forwarded to the
    wrapped object, calls are forwarded to it, and ``__class__`` reports the wrapped object's class, so the
    container can stand in for the layer it holds (``isinstance`` checks against the wrapped class succeed).
    """

    # Attribute names stored on the container itself (in its own ``__dict__``) instead of being
    # forwarded to the wrapped layer.
    internal_variable_names = [
        "_layer_implementation",
        "_recipe",
    ]
    # Names of callable attributes of the wrapped layer whose return value is replaced by this
    # container whenever the call returns the wrapped layer itself.
    patch = [
        "to",
    ]

    def __init__(self, impl_class: Type[T], *args: Any, **kwargs: Any) -> None:
        """
        Wrap a new object based on the given class, positional arguments, and keyword arguments.

        Args:
            impl_class: class of the new object
            *args: positional arguments of the new object
            **kwargs: keyword arguments of the new object
        """
        self._layer_implementation = impl_class(*args, **kwargs)
        self._recipe = LayerRecipe(layer=self, args=args, kwargs=kwargs)

    def replace_layer_implementation(self, new_implementation: T) -> None:
        """
        Replace the internally stored layer object with the given one.

        Note that the stored recipe is not changed by this method.

        Args:
            new_implementation: new layer object which should replace the previous implementation.
        """
        self._layer_implementation = new_implementation

    def __getattr__(self, item: Any) -> Any:
        """
        Forward attribute lookups that fail on the container itself to the wrapped layer.

        Python only calls this for attributes not found through normal lookup. If the wrapped layer's
        attribute is the wrapped layer itself, the container is returned instead. Callables named in
        ``patch`` are wrapped so that a call returning the wrapped layer returns the container instead.

        Args:
            item: name of the requested attribute

        Returns:
            the (possibly patched) attribute value

        Raises:
            AttributeError: if an internal attribute has not been set yet (e.g. on an instance created
                without ``__init__``), or if the wrapped layer has no attribute ``item``.
        """
        if item in self.internal_variable_names:
            try:
                return self.__dict__[item]
            except KeyError:
                # __getattr__ must raise AttributeError (not KeyError) so that hasattr() and
                # getattr(obj, name, default) keep working on instances whose internals are not set
                # yet, e.g. the fresh instance copy.copy/deepcopy create before restoring state.
                raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") from None
        attr_value = getattr(self._layer_implementation, item)
        # identity, not equality: ``==`` can be overloaded (e.g. element-wise comparison) and may
        # raise or report a distinct-but-equal object as "the wrapped layer"
        if attr_value is self._layer_implementation:
            return self
        if callable(attr_value) and item in self.patch:
            # patch return values of all functions/classes defined in self.patch
            # they should return this LayerContainer instead of themselves
            # required for e.g. pytorch's .to(device) function
            other = self

            class Patch:
                def __call__(self, *args: Any, **kwargs: Any) -> Any:
                    fn_return_val = attr_value(*args, **kwargs)
                    if fn_return_val is other._layer_implementation:
                        return other
                    return fn_return_val

                def __getattr__(self, item_: Any) -> Any:
                    return getattr(attr_value, item_)

                # needed for tests:
                @property  # type: ignore[misc]
                def __class__(self) -> Any:
                    return attr_value.__class__

            return Patch()
        return attr_value

    def __repr__(self) -> "str":
        return f"LayerContainer (at {hex(id(self))}), contains: {self._layer_implementation}"

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped layer with the given arguments and return its result."""
        return self._layer_implementation(*args, **kwargs)  # type:ignore[operator]

    def __setattr__(self, key: Any, value: Any) -> None:
        """Store internal attributes on the container; forward every other assignment to the wrapped layer."""
        if key in self.internal_variable_names:
            self.__dict__[key] = value
            return
        setattr(self._layer_implementation, key, value)

    @property  # type: ignore[misc]
    def __class__(self) -> Type[T]:  # type: ignore
        """Report the wrapped layer's class, so ``isinstance`` checks see the wrapped type."""
        return self._layer_implementation.__class__

    @property
    def layer_implementation(self) -> T:
        """
        Access the internally wrapped layer object directly.

        Returns:
            the internal layer object
        """
        return self._layer_implementation

    @property
    def recipe(self) -> "LayerRecipe":
        """
        Access the recipe created in ``__init__`` from the constructor's positional and keyword arguments.

        Returns:
            the stored LayerRecipe
        """
        return self._recipe