# Source code for bitorch.layers.register

from typing import List, Iterable, Any, Optional

import torch

from bitorch import runtime_mode_type, RuntimeMode
from bitorch.layers.extensions import LayerImplementation, LayerRegistry

# One LayerRegistry per quantized layer type, each constructed with that layer type's name.
# The registries are created in the order QLinear, QConv1d, QConv2d, QConv3d.
(
    q_linear_registry,
    q_conv1d_registry,
    q_conv2d_registry,
    q_conv3d_registry,
) = [LayerRegistry(layer_name) for layer_name in ("QLinear", "QConv1d", "QConv2d", "QConv3d")]


def all_layer_registries() -> List[LayerRegistry]:
    """
    Collect the layer registries, one per quantized layer type.

    Returns:
        A freshly built list with the QConv1d, QConv2d, QConv3d and QLinear registries, in that order.
    """
    conv_registries = [q_conv1d_registry, q_conv2d_registry, q_conv3d_registry]
    return conv_registries + [q_linear_registry]
def convert_layers_to(
    new_mode: RuntimeMode,
    only: Optional[Iterable[Any]] = None,
    device: Optional[torch.device] = None,
    verbose: bool = False,
) -> None:
    """
    Convert all wrapped layers (or a given subset of them) to a new mode.

    Args:
        new_mode: the new RuntimeMode
        only: optional white"list" (Iterable) of layers or wrapped layers which should be converted.
            One-shot iterators (e.g. generators) are copied into a list first, so the same
            selection is applied to every layer registry and not just the first one.
        device: the new device for the layers
        verbose: whether to print which layers are being converted
    """
    from collections.abc import Iterator

    # ``only`` is handed to every registry in turn. An iterator would be exhausted by the
    # first registry and the remaining registries would silently convert nothing, so
    # materialize it once. Reusable containers (lists, sets, ...) are passed through unchanged.
    if only is not None and isinstance(only, Iterator):
        only = list(only)
    for registry in all_layer_registries():
        registry.convert_layers_to(new_mode, only, device, verbose)
class QLinearImplementation(LayerImplementation):
    """Decorator for :class:`QLinear` implementations."""

    def __init__(self, supports_modes: runtime_mode_type) -> None:
        """
        Tie this decorator to ``q_linear_registry``.

        Args:
            supports_modes: RuntimeMode(s) that is/are supported by an implementation
        """
        target_registry = q_linear_registry
        super().__init__(target_registry, supports_modes)
class QConv1dImplementation(LayerImplementation):
    """Decorator for :class:`QConv1d` implementations."""

    def __init__(self, supports_modes: runtime_mode_type) -> None:
        """
        Tie this decorator to ``q_conv1d_registry``.

        Args:
            supports_modes: RuntimeMode(s) that is/are supported by an implementation
        """
        target_registry = q_conv1d_registry
        super().__init__(target_registry, supports_modes)
class QConv2dImplementation(LayerImplementation):
    """Decorator for :class:`QConv2d` implementations."""

    def __init__(self, supports_modes: runtime_mode_type) -> None:
        """
        Tie this decorator to ``q_conv2d_registry``.

        Args:
            supports_modes: RuntimeMode(s) that is/are supported by an implementation
        """
        target_registry = q_conv2d_registry
        super().__init__(target_registry, supports_modes)
class QConv3dImplementation(LayerImplementation):
    """Decorator for :class:`QConv3d` implementations."""

    def __init__(self, supports_modes: runtime_mode_type) -> None:
        """
        Tie this decorator to ``q_conv3d_registry``.

        Args:
            supports_modes: RuntimeMode(s) that is/are supported by an implementation
        """
        target_registry = q_conv3d_registry
        super().__init__(target_registry, supports_modes)