Source code for bitorch.quantizations.ste_heaviside

"""Sign Function Implementation"""
import torch
import typing
from typing import Any
from .base import STE, Quantization


class SteHeavisideFunction(STE):
    @staticmethod
    @typing.no_type_check
    def forward(
        ctx: torch.autograd.function.BackwardCFunction,  # type: ignore
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        """quantizes input tensor and forwards it.

        Args:
            ctx (Any): autograd context
            input_tensor (torch.Tensor): input tensor

        Returns:
            torch.Tensor: the quantized input tensor
        """
        ctx.save_for_backward(input_tensor)
        quantized_tensor = torch.where(
            input_tensor > 0,
            torch.tensor(1.0, device=input_tensor.device),
            torch.tensor(-1.0, device=input_tensor.device),
        )
        return quantized_tensor

    @staticmethod
    @typing.no_type_check
    def backward(ctx: Any, output_gradient: torch.Tensor) -> torch.Tensor:
        """passes the output gradient through where the saved input lies inside
        the clipping threshold, and zeroes it elsewhere.

        Args:
            ctx (Any): autograd context
            output_gradient (torch.Tensor): output gradient

        Returns:
            torch.Tensor: the output gradient, masked to inputs with
                absolute value <= 1
        """
        input_tensor = ctx.saved_tensors[0]
        inside_threshold = torch.abs(input_tensor) <= 1
        return output_gradient * inside_threshold
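
# A minimal sketch (added for illustration, not part of the original module):
# calling SteHeavisideFunction directly shows the binarized forward values and
# the clipped straight-through gradient implemented above.
def _demo_ste_heaviside_function() -> None:
    # Inputs straddle the |x| <= 1 threshold used in backward.
    x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
    y = SteHeavisideFunction.apply(x)
    # Forward binarizes to {-1, +1}: tensor([-1., -1., 1., 1.])
    y.sum().backward()
    # Gradient passes through only inside the threshold: tensor([0., 1., 1., 0.])
    print(y, x.grad)
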
class SteHeaviside(Quantization):
    """Module for applying the SteHeaviside quantization, using an STE in the backward pass"""

    name = "steheaviside"
    bit_width = 1

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Forwards the tensor through the ste heaviside function.

        Args:
            x (torch.Tensor): tensor to be forwarded.

        Returns:
            torch.Tensor: the quantized tensor
        """
        return SteHeavisideFunction.apply(x)
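
# Usage sketch (an assumption for illustration, not part of the original file):
# SteHeaviside is meant to be used like any other quantization module, here
# binarizing activations while keeping them trainable through the STE.
if __name__ == "__main__":
    quantizer = SteHeaviside()
    activations = torch.randn(2, 4, requires_grad=True)
    binarized = quantizer.quantize(activations)  # entries are -1.0 or 1.0
    binarized.sum().backward()
    print(activations.grad)  # zero wherever |activation| > 1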