Source code for bitorch.layers.qactivation

import typing
from typing import Optional, Union, Tuple
import torch
from torch import nn
from torch.autograd.function import Function

from bitorch.quantizations import Quantization
from .config import config


class GradientCancellation(Function):
    """Autograd function that forwards its input unchanged and cancels gradients
    outside a given threshold in the backward pass (straight-through estimator
    with gradient cancellation)."""

    @staticmethod
    @typing.no_type_check
    def forward(
        ctx: torch.autograd.function.BackwardCFunction,  # type: ignore
        input_tensor: torch.Tensor,
        threshold: float,
    ) -> torch.Tensor:
        """Pass the input tensor through unchanged and store it together with the
        threshold for gradient cancellation in the backward pass.

        Args:
            ctx (gradient context): context
            input_tensor (torch.Tensor): the input values
            threshold (float): the gradient cancellation threshold

        Returns:
            torch.Tensor: the unchanged input tensor
        """
        ctx.save_for_backward(input_tensor, torch.tensor(threshold, device=input_tensor.device))
        return input_tensor

    @staticmethod
    @typing.no_type_check
    def backward(
        ctx: torch.autograd.function.BackwardCFunction,  # type: ignore
        output_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Apply the straight through estimator. The output gradient is passed
        towards the input only where the absolute input does not exceed the threshold.

        Args:
            ctx (gradient context): context
            output_grad (torch.Tensor): the tensor containing the output gradient

        Returns:
            Tuple[torch.Tensor, None]: the input gradient (the masked output gradient)
                and None for the threshold argument
        """
        input_tensor, threshold = ctx.saved_tensors
        cancelled = torch.where(
            torch.abs(input_tensor) <= threshold, output_grad, torch.tensor(0.0, device=output_grad.device)
        )
        return cancelled, None
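
A minimal sketch (not part of the module source) of how GradientCancellation behaves, assuming the class definition above: the forward pass is the identity, and the backward pass zeroes the gradient wherever the absolute input exceeds the threshold.

    x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
    y = GradientCancellation.apply(x, 1.0)  # forward: identity
    y.sum().backward()
    print(x.grad)  # expected: tensor([0., 1., 1., 0.]) - gradients outside [-1, 1] are cancelled
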
class QActivation(nn.Module):
    """Activation layer for quantization"""

    def __init__(
        self,
        activation: Optional[Union[str, Quantization]] = None,
        gradient_cancellation_threshold: Optional[float] = 0.0,
    ) -> None:
        """Initialization function for fetching a suitable activation function.

        Args:
            activation (Union[str, Quantization], optional): quantization module or name of quantization
                function. Defaults to None.
            gradient_cancellation_threshold (Optional[float], optional): threshold for input gradient
                cancellation. Disabled if the threshold is 0. Defaults to 0.0.
        """
        super(QActivation, self).__init__()
        self.activation_function = config.get_quantization_function(activation or config.input_quantization)
        self.gradient_cancellation_threshold = (
            gradient_cancellation_threshold or config.gradient_cancellation_threshold
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Forwards the input tensor through the activation function.

        Args:
            input_tensor (torch.Tensor): input tensor

        Returns:
            torch.Tensor: quantized input tensor
        """
        if self.gradient_cancellation_threshold > 0:
            input_tensor = GradientCancellation.apply(input_tensor, self.gradient_cancellation_threshold)
        return self.activation_function(input_tensor)
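
A minimal usage sketch (not part of the module source), assuming a quantization named "sign" is registered in the bitorch configuration; the exact names available depend on the installed bitorch version.

    activation = QActivation(activation="sign", gradient_cancellation_threshold=1.3)
    x = torch.randn(4, 8, requires_grad=True)
    out = activation(x)   # input quantized by the configured quantization function
    out.sum().backward()  # gradients towards x are cancelled where |x| > 1.3
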