Source code for bitorch.quantizations.base

"""Quantization superclass implementation"""

import typing
from typing import Any
from warnings import warn

import torch
from torch import nn
from torch.autograd.function import Function


[docs]class STE(Function): """Straight Through estimator for backward pass"""
[docs] @staticmethod @typing.no_type_check def forward( ctx: torch.autograd.function.BackwardCFunction, # type: ignore input_tensor: torch.Tensor, ) -> torch.Tensor: raise NotImplementedError("Forwards pass of STE should be implemented by subclass.")
[docs] @staticmethod @typing.no_type_check def backward(ctx: Any, output_gradient: torch.Tensor) -> torch.Tensor: """just passes the unchanged output gradient as input gradient. Args: ctx (Any): autograd context output_gradient (torch.Tensor): output gradient Returns: torch.Tensor: the unchanged output gradient """ return output_gradient
[docs]class Quantization(nn.Module): """superclass for quantization modules""" name: str = "None" bit_width: int = -1 @property def bitwidth(self) -> int: warn("Attribute 'bitwidth' is deprecated, use 'bit_width' instead.", DeprecationWarning, stacklevel=2) return self.bit_width
[docs] def quantize(self, x: torch.Tensor) -> torch.Tensor: """Apply the quantization function to the input tensor. It is recommended to use a torch.Function to also manipulate backwards behavior. See the implementations of sign or dorefa quantization functions for more examples. Args: x (torch.Tensor): the input to be quantized Raises: NotImplementedError: raised if quantize function of superclass is called. Returns: torch.Tensor: the quantized tensor """ raise NotImplementedError()
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Quantizes the tensor using this classes quantize-method. Subclasses shall add some semantic there. Args: x (torch.Tensor): tensor to be forwarded. Returns: torch.Tensor: quantized tensor x """ return self.quantize(x)