Source code for bitorch.quantizations.swish_sign

"""Sign Function Implementation"""

import torch
from torch.autograd.function import Function
import typing
from typing import Tuple, Union

from .base import Quantization
from .config import config


class SwishSignFunction(Function):
    """SwishSign Function for input binarization."""
    @staticmethod
    @typing.no_type_check
    def forward(
        ctx: torch.autograd.function.BackwardCFunction,  # type: ignore
        input_tensor: torch.Tensor,
        beta: float = 1.0,
    ) -> torch.Tensor:
        """Binarize the input tensor using the sign function.

        Args:
            input_tensor (torch.Tensor): the input values to the SwishSign function
            beta (float, optional): scaling factor for the swish gradient. Defaults to 1.0.

        Returns:
            torch.Tensor: the binarized input tensor
        """
        ctx.save_for_backward(input_tensor, torch.tensor(beta, device=input_tensor.device))
        sign_tensor = torch.sign(input_tensor)
        # map zeros to 1 so the output contains only the values -1 and 1
        sign_tensor = torch.where(sign_tensor == 0, torch.tensor(1.0, device=input_tensor.device), sign_tensor)
        return sign_tensor
    @staticmethod
    @typing.no_type_check
    def backward(
        ctx: torch.autograd.function.BackwardCFunction, output_grad: torch.Tensor  # type: ignore
    ) -> Tuple[torch.Tensor, None]:
        """Apply the SwishSign gradient estimator.

        Scales the output gradient with the derivative of the swish function,
        beta * (2 - beta * x * tanh(beta * x / 2)) / (1 + cosh(beta * x)),
        instead of passing it straight through.

        Args:
            ctx (gradient context): context with the saved input tensor and beta
            output_grad (torch.Tensor): the tensor containing the output gradient

        Returns:
            Tuple[torch.Tensor, None]: the input gradient (the scaled output gradient)
            and None for the beta argument
        """
        input_tensor, beta = ctx.saved_tensors
        # swish derivative: largest around zero, decays towards zero for large |x|
        swish = (beta * (2 - beta * input_tensor * torch.tanh(beta * input_tensor / 2))) / (
            1 + torch.cosh(beta * input_tensor)
        )
        return swish * output_grad, None
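
As a quick illustration of the estimator (a minimal sketch, not part of the bitorch source): the forward pass yields values in {-1, 1}, while the backward pass produces gradients that peak at beta around x = 0 and vanish for large |x|.

import torch

x = torch.linspace(-4, 4, 9, requires_grad=True)
out = SwishSignFunction.apply(x, 5.0)
out.backward(torch.ones_like(out))
print(out.detach())  # tensor of -1.0 and 1.0 values
print(x.grad)  # about 5.0 at x = 0, near 0.0 at |x| = 4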
class SwishSign(Quantization):
    """Module for applying the SwishSign function."""

    name = "swishsign"
    bit_width = 1
    def __init__(self, beta: Union[float, None] = None) -> None:
        """Initializes the beta scaling factor.

        Args:
            beta (float, optional): scaling factor for the swish gradient. Defaults to
                the beta value from the global config.
        """
        super(SwishSign, self).__init__()
        self.beta = beta or config.beta
    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Forwards the tensor through the SwishSign function.

        Args:
            x (torch.Tensor): tensor to be forwarded.

        Returns:
            torch.Tensor: sign of tensor x
        """
        return SwishSignFunction.apply(x, self.beta)
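
A brief usage sketch (hypothetical, not part of the module), showing SwishSign as a drop-in quantizer; when no beta is given, the value from the global config is used.

import torch

swish_sign = SwishSign(beta=5.0)
x = torch.randn(2, 3, requires_grad=True)
y = swish_sign.quantize(x)  # binarized to {-1, 1}
y.sum().backward()  # gradients flow through the swish estimator
print(x.grad)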