Source code for bitorch.quantizations.sign
"""Sign Function Implementation"""
import typing
import torch
from .base import Quantization, STE
class SignFunction(STE):
    """Autograd function that binarizes a tensor to -1 / +1 in the forward pass.

    Only ``forward`` is defined here; the backward pass comes from the ``STE``
    base class (the straight-through estimator, per the ``Sign`` module docs).
    """

    @staticmethod
    @typing.no_type_check
    def forward(
        ctx: torch.autograd.function.BackwardCFunction,  # type: ignore
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Binarize the input tensor using the sign function.

        Unlike plain :func:`torch.sign`, exact zeros are mapped to ``+1`` so
        that finite inputs only produce ``-1`` or ``+1``. NaN inputs stay NaN,
        because ``NaN == 0`` is false.

        Args:
            ctx (Any): autograd context (not used by this forward pass).
            input_tensor (torch.Tensor): input tensor.

        Returns:
            torch.Tensor: the sign tensor, with the same shape, dtype and
            device as ``input_tensor``.
        """
        sign_tensor = torch.sign(input_tensor)
        # Fill zeros with a Python scalar instead of a freshly built tensor.
        # This preserves the input dtype (fp16/fp64/int) without relying on
        # torch.where type promotion. It also avoids allocating (and, on
        # accelerators, copying) a new scalar tensor on every call.
        return sign_tensor.masked_fill(sign_tensor == 0, 1)
class Sign(Quantization):
    """Binary sign quantization.

    Each element is mapped to -1 or +1 via ``SignFunction``. Gradients are
    passed through with a straight-through estimator in the backward pass.
    """

    name = "sign"
    bit_width = 1

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Quantize ``x`` to its sign.

        Args:
            x (torch.Tensor): tensor to quantize.

        Returns:
            torch.Tensor: the output of ``SignFunction`` applied to ``x``,
            i.e. the sign of each element with zeros mapped to +1.
        """
        binarized = SignFunction.apply(x)
        return binarized