Source code for bitorch.quantizations.dorefa

"""Dorefa Function Implementation"""
import torch
import typing
from typing import Any, Tuple, Union
from torch.autograd.function import Function

from .base import Quantization
from .config import config

class WeightDoReFaFunction(Function):
    """Autograd function implementing DoReFa weight quantization with a straight-through gradient."""

    @staticmethod
    @typing.no_type_check
    def forward(
        ctx: torch.autograd.function.BackwardCFunction, input_tensor: torch.Tensor, maximum_bit_value: int
    ) -> torch.Tensor:
        """Quantizes the input tensor and forwards it.

        The tensor is squashed with ``tanh``, rescaled by its maximum absolute value into [0, 1],
        rounded onto ``maximum_bit_value`` uniform steps and mapped back to [-1, 1].

        Args:
            ctx (Any): autograd context (nothing is saved; the backward pass needs no tensors)
            input_tensor (torch.Tensor): input tensor
            maximum_bit_value (int): number of quantization steps, i.e. ``2**bits - 1``

        Returns:
            torch.Tensor: the quantized input tensor, with values in [-1, 1]
        """
        squashed_values = torch.tanh(input_tensor)
        max_val = torch.max(torch.abs(squashed_values)).detach()
        # An all-zero input would give max_val == 0 and turn every output into NaN (0 / 0);
        # clamping to the smallest positive normal value keeps the division well defined.
        max_val = torch.clamp(max_val, min=torch.finfo(max_val.dtype).tiny)
        adjusted_values = squashed_values / (2.0 * max_val) + 0.5
        return 2.0 * (torch.round(adjusted_values * maximum_bit_value) / maximum_bit_value) - 1.0

    @staticmethod
    @typing.no_type_check
    def backward(ctx: Any, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Passes the unchanged output gradient as input gradient (straight-through estimator).

        Args:
            ctx (Any): autograd context
            output_gradient (torch.Tensor): output gradient

        Returns:
            Tuple[torch.Tensor, None]: the unchanged output gradient for ``input_tensor`` and
            ``None`` for the non-differentiable ``maximum_bit_value``
        """
        return output_gradient, None
class WeightDoReFa(Quantization):
    """Module for applying the DoReFa function on weights.

    Reference: "DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients"
    Zhou et al. 2016,
    """

    name = "weightdorefa"
    bit_width = config.dorefa_bits

    def __init__(self, bits: Union[int, None] = None) -> None:
        """Initializes the quantization bit width.

        Args:
            bits (int, optional): number of bits to quantize into. Defaults to None, in which case
                ``config.dorefa_bits`` is used.
        """
        super(WeightDoReFa, self).__init__()
        self.bit_width = bits or config.dorefa_bits

    @property
    def _max_value(self) -> int:
        """Number of quantization steps for the current ``bit_width`` (``2**bit_width - 1``).

        Derived on every access instead of being cached in ``__init__`` so that it cannot go
        stale if ``bit_width`` is changed after construction.
        """
        return 2**self.bit_width - 1

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """DoReFas the tensor to the desired bit resolution using weight DoReFa.

        Args:
            x (torch.Tensor): tensor to be forwarded.

        Returns:
            torch.Tensor: DoReFaed tensor x
        """
        return WeightDoReFaFunction.apply(x, self._max_value)
class InputDoReFaFunction(Function):
    """Autograd function that rounds activations onto a uniform grid in [0, 1].

    The backward pass is a straight-through estimator: gradients are returned unchanged.
    """

    @staticmethod
    @typing.no_type_check
    def forward(
        ctx: torch.autograd.function.BackwardCFunction,
        input_tensor: torch.Tensor,
        bits: int,  # type: ignore
    ) -> torch.Tensor:
        """Clips the input to [0, 1] and quantizes it to ``bits`` bits.

        Args:
            ctx (Any): autograd context (nothing is saved on it)
            input_tensor (torch.Tensor): input tensor
            bits (int): number of bits to round the input tensor to

        Returns:
            torch.Tensor: the quantized tensor; every value is k / (2**bits - 1) for an integer k
        """
        levels = 2**bits - 1
        clipped = input_tensor.clamp(0, 1)
        return clipped.mul(levels).round().div(levels)

    @staticmethod
    @typing.no_type_check
    def backward(ctx: Any, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Straight-through estimator: hands the output gradient back unchanged.

        Gradients are not masked for inputs that fell outside the clipping range [0, 1].

        Args:
            ctx (Any): autograd context
            output_gradient (torch.Tensor): output gradient

        Returns:
            Tuple[torch.Tensor, None]: the unchanged gradient for ``input_tensor`` and ``None``
            for the non-differentiable ``bits`` argument
        """
        passthrough = output_gradient
        return (passthrough, None)
class InputDoReFa(Quantization):
    """Quantization module that applies DoReFa quantization to layer inputs (activations).

    Reference: "DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients"
    Zhou et al. 2016,
    """

    name = "inputdorefa"
    bit_width = config.dorefa_bits

    def __init__(self, bits: Union[int, None] = None) -> None:
        """Stores the bit width used for quantization.

        Args:
            bits (int, optional): number of bits to quantize into. When None (or any falsy value),
                ``config.dorefa_bits`` is used instead. Defaults to None.
        """
        super().__init__()
        self.bit_width = config.dorefa_bits if not bits else bits

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Quantizes ``x`` to ``self.bit_width`` bits with input DoReFa.

        Args:
            x (torch.Tensor): tensor to be forwarded.

        Returns:
            torch.Tensor: DoReFaed tensor x
        """
        bits = self.bit_width
        return InputDoReFaFunction.apply(x, bits)