bitorch.quantizations.swish_sign.SwishSignFunction
- class bitorch.quantizations.swish_sign.SwishSignFunction(*args, **kwargs)
SwishSign Function for input binarization.
Methods
backward(ctx, output_grad) – Apply straight-through estimator.
forward – Binarize input tensor using the _sign function.
Attributes
- static backward(ctx: BackwardCFunction, output_grad: Tensor) → Tuple[Tensor, None]
Apply straight-through estimator.
This clamps the output gradient values to the range [-1, 1] and passes the result on as the input gradient.
- Parameters:
ctx (BackwardCFunction) – the autograd gradient context
output_grad (torch.Tensor) – the tensor containing the output gradient
- Returns:
the input gradient (= the clamped output gradient)
- Return type:
torch.Tensor
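The clamping behaviour described for backward can be illustrated with a minimal sketch. This is not the bitorch implementation: the class name ClampedSTESign is hypothetical, mapping zero to +1 in the forward pass is an assumption, and the sketch takes a single forward argument, so it returns one gradient rather than the documented Tuple[Tensor, None].

```python
import torch
from torch.autograd import Function


class ClampedSTESign(Function):
    """Hypothetical illustration of a sign function with a clamped
    straight-through gradient; not the bitorch source."""

    @staticmethod
    def forward(ctx, input_tensor):
        # Binarize to {-1, +1}. Mapping 0 to +1 is an assumption of this sketch.
        ones = torch.ones_like(input_tensor)
        return torch.where(input_tensor >= 0, ones, -ones)

    @staticmethod
    def backward(ctx, output_grad):
        # Straight-through estimator as described above: clamp the
        # incoming gradient to [-1, 1] and pass it on unchanged otherwise.
        return torch.clamp(output_grad, -1.0, 1.0)


x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0], requires_grad=True)
y = ClampedSTESign.apply(x)

# Feed an upstream gradient with values outside [-1, 1] to see the clamp.
upstream = torch.tensor([3.0, -0.25, 0.5, -4.0, 1.0])
y.backward(upstream)

print(y.detach())  # binarized values
print(x.grad)      # upstream gradient clamped to [-1, 1]
```

In this sketch, upstream values of 3.0 and -4.0 reach the input as 1.0 and -1.0, while values already inside the range pass through unchanged.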