Source code for bitorch.layers.extensions.layer_implementation
from abc import ABC
from typing import Optional, Any, Tuple, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from . import LayerRecipe
[docs]class BaseImplementation:
"""Defines the class interface of a custom layer implementation of a certain layer type."""
[docs] def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
[docs] @classmethod
def is_default_implementation(cls) -> bool:
"""
Returns:
bool: whether this implementation is the default implementation of the current layer type
"""
raise NotImplementedError("Should be implemented by subclass.")
[docs] @classmethod
def can_clone(cls, recipe: "LayerRecipe") -> Tuple[bool, str]:
"""
Returns whether this layer class supports the implementation of a given layer recipe.
Args:
recipe (LayerRecipe): the layer which should be checked for cloning
Returns:
Whether the layer can be cloned or not and an info message if it can not be cloned
"""
raise NotImplementedError("A custom layer should implement their own compatibility check.")
[docs] @classmethod
def create_clone_from(cls, recipe: "LayerRecipe", device: Optional[torch.device] = None) -> Any:
"""
Create a new layer based on a given layer recipe (can be expected to be from the default category).
Args:
recipe: the layer which should be cloned
device: the device on which the layer is going to be run
Returns:
A clone of the LayerRecipe in the current class implementation
"""
raise NotImplementedError("A custom layer should implement a method to create a cloned layer.")
[docs]class DefaultImplementationMixin(BaseImplementation, ABC):
"""Defines the class interface of a default layer implementation of a certain layer type."""
[docs] @classmethod
def is_default_implementation(cls) -> bool:
return True
[docs] @classmethod
def can_clone(cls, recipe: "LayerRecipe") -> Tuple[bool, str]:
return True, ""
[docs] @classmethod
def create_clone_from(cls, recipe: "LayerRecipe", device: Optional[torch.device] = None) -> Any:
return cls(*recipe.args, **recipe.kwargs)
[docs]class CustomImplementationMixin(BaseImplementation, ABC):
"""Defines the class interface of a custom layer implementation of a certain layer type."""
[docs] @classmethod
def is_default_implementation(cls) -> bool:
return False
[docs] @classmethod
def can_clone(cls, recipe: "LayerRecipe") -> Tuple[bool, str]:
raise NotImplementedError("A custom layer should implement their own compatibility check.")
[docs] @classmethod
def create_clone_from(cls, recipe: "LayerRecipe", device: Optional[torch.device] = None) -> Any:
raise NotImplementedError("A custom layer should implement a method to create a cloned layer.")