bitorch.quantizations.base.STE
- class bitorch.quantizations.base.STE(*args, **kwargs)
Straight-through estimator (STE) for the backward pass: gradients are propagated through the operation unchanged.
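As a rough illustration of the idea, and not bitorch's own source, a straight-through estimator can be written as a plain torch.autograd.Function whose forward pass quantizes and whose backward pass returns the incoming gradient unchanged. The class name SignSTE and the sign-style binarization (with zeros mapped to +1) are assumptions made for this sketch.

```python
import torch


class SignSTE(torch.autograd.Function):
    """Hypothetical example: binarize in forward, straight-through in backward."""

    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor) -> torch.Tensor:
        # Quantize to {-1, +1}; values >= 0 (including 0) become +1.
        ones = torch.ones_like(input_tensor)
        return torch.where(input_tensor >= 0, ones, -ones)

    @staticmethod
    def backward(ctx, output_gradient: torch.Tensor) -> torch.Tensor:
        # Straight-through rule: treat the quantizer as the identity for gradients.
        return output_gradient


x = torch.tensor([-0.7, 0.0, 0.3, 2.5], requires_grad=True)
y = SignSTE.apply(x)
y.sum().backward()
print(y.tolist())       # [-1.0, 1.0, 1.0, 1.0]
print(x.grad.tolist())  # [1.0, 1.0, 1.0, 1.0]
```

Without the custom backward, a sign-like step has a zero derivative almost everywhere, so no learning signal would reach x; the straight-through rule substitutes a derivative of 1.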
Methods
- backward(ctx, output_gradient): Just passes the unchanged output gradient as input gradient.
- forward(ctx, input_tensor): Performs the operation.
- static backward(ctx: Any, output_gradient: Tensor) → Tensor
Passes the output gradient through unchanged as the input gradient.
- Parameters:
ctx (Any) – autograd context
output_gradient (torch.Tensor) – output gradient
- Returns:
the unchanged output gradient
- Return type:
torch.Tensor
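Returning output_gradient as is means any upstream gradient reaches the input untouched. The sketch below checks this with a non-uniform upstream gradient and contrasts it with plain torch.round, whose autograd derivative is zero. RoundSTE is a hypothetical class written for illustration; it only mirrors the documented backward rule and is not part of bitorch.

```python
import torch


class RoundSTE(torch.autograd.Function):
    """Hypothetical example: round in forward, identity gradient in backward."""

    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor) -> torch.Tensor:
        return torch.round(input_tensor)

    @staticmethod
    def backward(ctx, output_gradient: torch.Tensor) -> torch.Tensor:
        # Same rule as the documented STE.backward: return the gradient unchanged.
        return output_gradient


upstream = torch.tensor([0.5, -2.0, 3.0])

x = torch.tensor([0.2, 1.6, -0.4], requires_grad=True)
RoundSTE.apply(x).backward(upstream)
print(x.grad.tolist())  # [0.5, -2.0, 3.0], identical to `upstream`

x_plain = torch.tensor([0.2, 1.6, -0.4], requires_grad=True)
torch.round(x_plain).backward(upstream)
print(x_plain.grad.tolist())  # [0.0, 0.0, 0.0], round() has zero derivative in autograd
```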
- static forward(ctx: BackwardCFunction, input_tensor: Tensor) → Tensor
Performs the operation.
This function is to be overridden by all subclasses.
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
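Because forward may take extra, non-tensor arguments, backward must return one value per forward input, using None for arguments that need no gradient. The sketch assumes a hypothetical ScaledSignSTE with a float scale; treating scale * sign(x) as scale * x in the backward pass is a design choice for this example, not documented bitorch behavior.

```python
import torch


class ScaledSignSTE(torch.autograd.Function):
    """Hypothetical example: forward takes an extra non-tensor argument."""

    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor, scale: float) -> torch.Tensor:
        # Plain Python values (not tensors) may be stored directly on ctx.
        ctx.scale = scale
        return scale * torch.sign(input_tensor)

    @staticmethod
    def backward(ctx, output_gradient: torch.Tensor):
        # One return value per forward input: a gradient for input_tensor
        # (straight-through, as if forward were scale * x) and None for scale.
        return ctx.scale * output_gradient, None


x = torch.tensor([-1.5, 0.5], requires_grad=True)
y = ScaledSignSTE.apply(x, 0.5)
y.sum().backward()
print(y.tolist())       # [-0.5, 0.5]
print(x.grad.tolist())  # [0.5, 0.5]
```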
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or with ctx.save_for_forward() if they are intended to be used in jvp.
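Following the guidance on saving tensors, a backward pass that needs the forward input should retrieve it from ctx.saved_tensors rather than from a ctx attribute. The sketch shows this with a common variant, often called a clipped STE, that zeroes the gradient where |x| > 1. This differs from the documented STE.backward, which passes the gradient through everywhere; ClippedSignSTE is a hypothetical name.

```python
import torch


class ClippedSignSTE(torch.autograd.Function):
    """Hypothetical variant: sign() forward, gradient zeroed where |x| > 1."""

    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor) -> torch.Tensor:
        # Save the input with save_for_backward instead of assigning it to ctx.
        ctx.save_for_backward(input_tensor)
        return torch.sign(input_tensor)

    @staticmethod
    def backward(ctx, output_gradient: torch.Tensor) -> torch.Tensor:
        (input_tensor,) = ctx.saved_tensors
        # Let the gradient through only where the input lies in [-1, 1].
        mask = input_tensor.abs() <= 1
        return output_gradient * mask.to(output_gradient.dtype)


x = torch.tensor([-2.0, -0.5, 0.25, 1.5], requires_grad=True)
ClippedSignSTE.apply(x).sum().backward()
print(x.grad.tolist())  # [0.0, 1.0, 1.0, 0.0]
```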