Source code for bitorch.quantizations.approx_sign

"""Sign Function Implementation"""
import torch
from torch.autograd.function import Function
import typing

from .base import Quantization


class ApproxSignFunction(Function):
    """ApproxSign Function for input binarization.

    Forward pass: ``sign(x)`` with exact zeros mapped to ``+1``.
    Backward pass: the ApproxSign gradient ``2 - 2 * |x|`` inside ``|x| <= 1``
    and ``0`` outside, multiplied by the incoming gradient.
    """

    @staticmethod
    @typing.no_type_check
    def forward(
        ctx: torch.autograd.function.BackwardCFunction, input_tensor: torch.Tensor  # type: ignore
    ) -> torch.Tensor:
        """Binarize input tensor using the sign function.

        Args:
            ctx (gradient context): context used to save the input for the backward pass
            input_tensor (tensor): the input values to the Sign function

        Returns:
            tensor: binarized input tensor (zeros are mapped to +1)
        """
        ctx.save_for_backward(input_tensor)
        sign_tensor = torch.sign(input_tensor)
        # torch.sign yields 0 for 0; replace those entries with +1 so the output is binary.
        # ones_like matches sign_tensor's dtype/device, avoiding a freshly built float32
        # scalar tensor (and a host-to-device copy on accelerators) on every forward call.
        return torch.where(sign_tensor == 0, torch.ones_like(sign_tensor), sign_tensor)

    @staticmethod
    @typing.no_type_check
    def backward(
        ctx: torch.autograd.function.BackwardCFunction, output_grad: torch.Tensor  # type: ignore
    ) -> torch.Tensor:
        """Apply the approx sign gradient, used e.g. for birealnet.

        Args:
            ctx (gradient context): context holding the input saved in ``forward``
            output_grad (torch.Tensor): the tensor containing the output gradient

        Returns:
            torch.Tensor: the input gradient
        """
        input_tensor = ctx.saved_tensors[0]
        abs_input = torch.abs(input_tensor)
        # boolean mask: True where |x| <= 1; the gradient is zeroed wherever |x| > 1
        inside_threshold = abs_input <= 1
        # piecewise-linear ApproxSign derivative: 2 - 2|x| inside the threshold, 0 outside
        approx_sign = (2.0 - 2.0 * abs_input) * inside_threshold
        return approx_sign * output_grad
class ApproxSign(Quantization):
    """Quantization module that binarizes with sign and backpropagates via ApproxSign.

    The actual computation is delegated to ``ApproxSignFunction``.
    """

    # identifier string for this quantization (presumably used for lookup by the base
    # class or a registry — confirm against ``Quantization``)
    name = "approxsign"
    # number of bits this quantization is declared to produce
    bit_width = 1

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the approx sign autograd function.

        Args:
            x (torch.Tensor): tensor to be binarized.

        Returns:
            torch.Tensor: sign of tensor x
        """
        binarized = ApproxSignFunction.apply(x)  # type: ignore
        return binarized