Source code for bitorch.layers.pact
from typing import Optional, Tuple
from torch.autograd import Function
from torch.nn import Module
import torch
from .config import config
# Taken from:
# https://github.com/KwangHoonAn/PACT
class PactActFn(Function):
    """Autograd function implementing the PACT forward pass (clip and quantize)
    and the corresponding straight-through backward pass."""

    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor, alpha: torch.nn.Parameter, bits: int) -> torch.Tensor:  # type: ignore
        ctx.save_for_backward(input_tensor, alpha)
        # y_1 = 0.5 * (torch.abs(x).detach() - torch.abs(x - alpha).detach() + alpha.item())
        # clip the input to [0, alpha], then round onto 2**bits - 1 uniform steps
        clamped = torch.clamp(input_tensor, min=0, max=alpha.item())
        scale = (2**bits - 1) / alpha
        quantized = torch.round(clamped * scale) / scale
        return quantized
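
    # Illustrative values (added for clarity, not part of the original source):
    # with bits=2 and alpha=10.0, scale = (2**2 - 1) / 10.0 = 0.3, so an input of
    # 4.0 is clamped to 4.0 and quantized to round(4.0 * 0.3) / 0.3 = 1.0 / 0.3 ≈ 3.33,
    # i.e. one of the 2**bits uniform levels {0.0, 3.33, 6.67, 10.0} in [0, alpha].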
    @staticmethod
    def backward(ctx, output_gradient: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:  # type: ignore
        # Backward function, code borrowed from
        # https://github.com/obilaniu/GradOverride/blob/master/functional.py
        # The incoming gradient is dL / dy_q.
        x, alpha = ctx.saved_tensors
        # The input gradient is only passed through where x lies in [0, alpha].
        # For alpha, the chain rule gives dL / dalpha = dL / dy_q * dy_q / dy * dy / dalpha,
        # where dy_q / dy * dy / dalpha is 0 or 1 depending on the range x falls into.
        lower_bound = x < 0
        upper_bound = x > alpha
        # x_range = 1.0 - lower_bound - upper_bound
        x_range = ~(lower_bound | upper_bound)
        grad_alpha = torch.sum(output_gradient * torch.ge(x, alpha).float()).view(-1)
        return output_gradient * x_range.float(), grad_alpha, None
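
    # Worked example (added for clarity, not part of the original source): for
    # x = [-1.0, 0.5, 12.0], alpha = 10.0 and an incoming gradient of ones, the
    # returned input gradient is [0.0, 1.0, 0.0] (passed through only inside
    # [0, alpha]) and grad_alpha collects the gradient where x >= alpha, giving [1.0].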
class Pact(Module):
    """PACT activation function taken from https://github.com/KwangHoonAn/PACT.

    Initially proposed in
    Choi, Jungwook, et al. "PACT: Parameterized Clipping Activation for Quantized Neural Networks." (2018)
    """

    def __init__(self, bits: Optional[int] = None) -> None:
        """Create a PACT activation with a learnable clipping threshold alpha.

        Args:
            bits (Optional[int]): number of bits to quantize to (defaults to config.pact_bits)
        """
        super().__init__()
        self.alpha = torch.nn.parameter.Parameter(torch.tensor(10.0))
        self.bits = bits or config.pact_bits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return PactActFn.apply(x, self.alpha, self.bits)
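
# Minimal usage sketch (illustrative addition, not part of the original module):
# it assumes this file is importable as bitorch.layers.pact and that
# config.pact_bits supplies a default bit width. The layer quantizes positive
# activations onto a uniform grid in [0, alpha] and learns alpha during training.
if __name__ == "__main__":
    activation = Pact(bits=4)
    x = torch.randn(8, 16, requires_grad=True)
    y = activation(x)  # outputs lie on 2**4 levels between 0 and alpha
    y.sum().backward()  # gradients reach both x and the clipping threshold alpha
    print(activation.alpha.grad, x.grad.abs().sum())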