Source code for bitorch.layers.extensions.layer_registry
import warnings
from typing import Set, Any, Optional, Iterable
import bitorch
import torch
from bitorch import RuntimeMode
from .layer_container import LayerContainer
from .layer_recipe import LayerRecipe
from .layer_registration import LayerImplementation
class LayerRegistry:
"""
Stores all available implementations (and their supported modes) for a certain type of layer.
It also wraps these implementations and stores references to them, so they can be replaced easily.
Needs to be subclassed for each type of layer.
"""
def __init__(self, name: str) -> None:
self.name = name
self._class = None
self.layer_implementations: Set[LayerImplementation] = set()
self._instance_recipes: Set[LayerRecipe] = set()
self.is_replacing = False
@property
def layer_instances(self) -> Set["LayerContainer"]:
return set(x.layer for x in self._instance_recipes)
def get_recipe_for(self, layer: Any) -> Optional["LayerRecipe"]:
return next((recipe for recipe in self._instance_recipes if recipe.layer == layer), None)
def get_replacement(self, mode: RuntimeMode, recipe: LayerRecipe, device: Optional[torch.device] = None) -> Any:
layer = self.get_layer(mode, recipe)
return layer.get_replacement(recipe, device)
def add_recipe(self, new_recipe: LayerRecipe) -> None:
if self.is_replacing:
return
self._instance_recipes.add(new_recipe)
def __contains__(self, item: Any) -> bool:
return item.__class__ in map(lambda x: x.class_, self.layer_implementations)
def register(self, layer: LayerImplementation) -> None:
"""
Register a layer implementation in this registry.
Args:
layer: the layer to be registered
"""
self.layer_implementations.add(layer)
def get_layer(
self, mode: Optional[RuntimeMode] = None, recipe: Optional[LayerRecipe] = None
) -> LayerImplementation:
"""
Get a layer implementation compatible to the given mode and recipe.
If no recipe is given, only compatibility with the mode is checked.
If no mode is given, the current bitorch mode is used.
Args:
mode: mode that the layer implementation should support
recipe: recipe that the layer implementation should be able to copy
Returns:
a LayerImplementation compatible with the given mode and recipe (if available)
"""
if mode is None:
mode = bitorch.mode
available_layers = []
unavailable_layers = []
for implementation in self.layer_implementations:
if not implementation.supports_mode(mode):
continue
if recipe:
return_tuple = implementation.can_create_clone_from(recipe)
if not (isinstance(return_tuple, tuple) and len(return_tuple) == 2):
raise RuntimeError(f"{implementation.__class__} returned non-tuple on 'can_create_clone_from'.")
can_be_used, message = return_tuple
if not can_be_used:
unavailable_layers.append(f" {implementation.__class__} unavailable because: {message}")
continue
available_layers.append(implementation)
if len(available_layers) > 1:
warnings.warn(f"Multiple layer implementations available for '{self.name}' (mode='{mode}').", RuntimeWarning)
if len(available_layers) == 0:
base_error = f"No implementations for '{self.name}' available (mode='{mode}')."
if len(unavailable_layers) > 0:
raise RuntimeError("\n".join([base_error] + unavailable_layers))
else:
raise RuntimeError(base_error)
return available_layers[0]
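# Illustrative note (not part of bitorch): the resolution above relies on only two hooks of each
# registered implementation. A hypothetical implementation could satisfy them like this, where
# ``self.supported_modes`` is an assumed attribute of that implementation:
#
#     def supports_mode(self, mode: RuntimeMode) -> bool:
#         return mode in self.supported_modes  # whichever modes this backend handles
#
#     def can_create_clone_from(self, recipe: LayerRecipe) -> Tuple[bool, str]:
#         return True, ""  # (usable, human-readable reason if not usable)
#
# One implementation passing both checks is returned (with a warning if several qualify);
# if none passes, all collected rejection reasons are raised together as a RuntimeError.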
def clear(self) -> None:
self._instance_recipes.clear()
def unregister_custom_implementations(self) -> None:
to_remove = list(filter(lambda x: not x.is_default(), self.layer_implementations))
for i in to_remove:
self.layer_implementations.remove(i)
def convert_layers_to(
self,
new_mode: RuntimeMode,
only: Optional[Iterable[Any]] = None,
device: Optional[torch.device] = None,
verbose: bool = False,
) -> None:
"""Convert all tracked layer instances (optionally restricted via `only`) to an implementation supporting `new_mode`."""
for recipe in list(self._instance_recipes):
module = recipe.layer
if only is not None and module.layer_implementation not in only and module not in only:
continue
assert isinstance(module, LayerContainer)
if verbose:
print("| Replacing layer in", module)
replacement_module = self.get_replacement(new_mode, recipe, device)
replacement_module.to(device)
if verbose:
print("- with:", replacement_module)
module.replace_layer_implementation(replacement_module)
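# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of bitorch): one registry is subclassed per layer
# type, implementations are registered against it, and get_layer resolves the one matching the
# active runtime mode. ``ExampleRegistry`` and ``_DummyImplementation`` are hypothetical
# stand-ins; a real entry would be a LayerImplementation providing the hooks used above
# (``class_``, ``supports_mode``, ``can_create_clone_from``, ``is_default``).
if __name__ == "__main__":

    class ExampleRegistry(LayerRegistry):
        """One registry per layer type, as the class docstring suggests."""

    class _DummyImplementation:
        class_ = torch.nn.Linear  # layer class this implementation handles (see __contains__)

        def supports_mode(self, mode: RuntimeMode) -> bool:
            return True  # pretend to support every runtime mode

        def can_create_clone_from(self, recipe: LayerRecipe) -> tuple:
            return True, ""  # (usable, reason if not usable)

        def is_default(self) -> bool:
            return True

    registry = ExampleRegistry("example_linear")
    registry.register(_DummyImplementation())  # duck-typed stand-in for a LayerImplementation
    print(registry.get_layer())  # resolves via bitorch.mode, since no explicit mode is given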