Source code for bitorch.layers.config
"""Config class for quantization layers. This file should be imported before the other layers."""
from typing import Union
from bitorch.config import Config
from bitorch.quantizations import quantization_from_name, Quantization
[docs]class LayerConfig(Config):
"""Class to provide layer configurations."""
name = "layer_config"
[docs] def get_quantization_function(self, quantization: Union[str, Quantization]) -> Quantization:
"""Returns the quantization module specified by the given name or object.
Args:
quantization: quantization module or name of quantization function.
Returns:
the quantization module
"""
if isinstance(quantization, Quantization):
return quantization
elif isinstance(quantization, str):
return quantization_from_name(quantization)()
else:
raise ValueError(f"Invalid quantization: {quantization}")
# default quantization to be used in layers for inputs
input_quantization = "sign"
# default quantization to be used in layers for inputs
weight_quantization = "sign"
# toggles print / matplotlib output in debug layers
debug_activated = False
# default padding value used in convolution layers
padding_value = -1.0
# threshold used by qactivation for gradient cancellation
gradient_cancellation_threshold = 1.0
# bits for pact activation function
pact_bits = 4
# config object, global referencable
config = LayerConfig()