# Source code for bitorch.quantizations
"""
This submodule contains several quantization methods that can be used with our quantized layers to
build quantized models.
To implement a new quantization method, use the :code:`Quantization` base class as its superclass.
"""
from typing import List, Type, Dict
from .base import Quantization
from .approx_sign import ApproxSign
from .dorefa import WeightDoReFa, InputDoReFa
from .identity import Identity
from .sign import Sign
from .ste_heaviside import SteHeaviside
from .swish_sign import SwishSign
from .progressive_sign import ProgressiveSign
from .quantization_scheduler import Quantization_Scheduler, ScheduledQuantizer
from ..util import build_lookup_dictionary
# Public API of this subpackage. The order of the entries is kept deliberately: the list is passed to
# build_lookup_dictionary below and may therefore determine the order of the registry (and hence of
# quantization_names()).
__all__ = [
"Quantization",
"quantization_from_name",
"quantization_names",
"register_custom_quantization",
"ApproxSign",
"InputDoReFa",
"WeightDoReFa",
"Identity",
"ProgressiveSign",
"Sign",
"SteHeaviside",
"SwishSign",
"Quantization_Scheduler",
"ScheduledQuantizer",
]
# Registry of available quantizations, keyed by each quantization's `name` attribute. It is read by
# quantization_from_name / quantization_names and extended at runtime by register_custom_quantization.
# NOTE(review): build_lookup_dictionary lives in ..util and is not visible here. Presumably it collects
# the members of __all__ that are subclasses of Quantization and keys them by their `name`; confirm.
# __all__ also lists the functions defined further below, which do not exist yet when this line runs,
# so the helper must tolerate (skip) them.
quantizations_by_name: Dict[str, Type[Quantization]] = build_lookup_dictionary(__name__, __all__, Quantization)
def quantization_from_name(name: str) -> Type[Quantization]:
    """Return the quantization class registered under the given name.

    The name has to be the value of the quantization's ``name`` attribute, which is the key used
    by :func:`register_custom_quantization` when adding entries to the registry.

    Args:
        name (str): name of the quantization

    Raises:
        ValueError: raised if no quantization under that name was found

    Returns:
        Type[Quantization]: the registered quantization class
    """
    try:
        # Single dictionary lookup instead of a membership test followed by indexing.
        return quantizations_by_name[name]
    except KeyError:
        # Keep the documented ValueError contract; `from None` hides the internal KeyError
        # from the traceback, since it is an implementation detail.
        raise ValueError(f"{name} quantization not found!") from None
def quantization_names() -> List:
    """Collect the names of all registered quantizations, e.g. for argparse ``choices``.

    Returns:
        List: the quantization names, in the iteration order of the registry
    """
    # Iterating a dict yields its keys, which are the registered quantization names.
    return list(quantizations_by_name)
def register_custom_quantization(custom_quantization: Type[Quantization]) -> None:
    """
    Register a custom (external) quantization in bitorch.

    After registration, the quantization can be retrieved with :func:`quantization_from_name`
    and is included in :func:`quantization_names`. If a quantization with the same ``name`` is
    already registered, it is replaced by the new one.

    Args:
        custom_quantization: the custom quantization class which should be added to bitorch;
            it is registered under the value of its ``name`` attribute
    """
    quantizations_by_name[custom_quantization.name] = custom_quantization