bitorch.quantizations.base.STE

class bitorch.quantizations.base.STE(*args, **kwargs)

Straight-through estimator (STE) for the backward pass.

Methods

  • backward – Passes the output gradient through unchanged as the input gradient.

  • forward – Performs the operation.

static backward(ctx: Any, output_gradient: Tensor) → Tensor

Passes the output gradient through unchanged as the input gradient.

Parameters:
  • ctx (Any) – autograd context

  • output_gradient (torch.Tensor) – gradient of the loss with respect to the output of forward

Returns:

the unchanged output gradient

Return type:

torch.Tensor
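
As an illustration of the pass-through behaviour described above, here is a minimal sketch of a straight-through estimator built directly on torch.autograd.Function. The class name SignSTE and the choice of a sign-like forward step are assumptions made for this example; they are not taken from bitorch and may differ from how its subclasses are written.

```python
import torch
from torch.autograd import Function


class SignSTE(Function):
    """Hypothetical example class, not part of bitorch."""

    @staticmethod
    def forward(ctx, input_tensor):
        # Non-differentiable step: map every value to -1 or +1
        # (zero is mapped to +1 here by choice).
        ones = torch.ones_like(input_tensor)
        return torch.where(input_tensor >= 0, ones, -ones)

    @staticmethod
    def backward(ctx, output_gradient):
        # Straight-through: return the incoming gradient unchanged.
        return output_gradient


x = torch.tensor([-1.5, 0.2, 3.0], requires_grad=True)
y = SignSTE.apply(x)

# Weight each output so the incoming gradient is easy to recognise.
weights = torch.tensor([1.0, 2.0, 3.0])
(y * weights).sum().backward()
```

After the backward call, x.grad holds the same values as weights, because the gradient reaching the output is handed back to the input without modification.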

static forward(ctx: BackwardCFunction, input_tensor: Tensor) → Tensor

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced, for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or with ctx.save_for_forward() if they are intended to be used in jvp.
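
The following sketch shows one way the context might be used in practice: a tensor is saved with ctx.save_for_backward() and a plain Python value is stored as an attribute on ctx. The class ClippedSignSTE, its threshold argument, and the gradient-masking rule are hypothetical choices for this example and are not part of the documented class.

```python
import torch
from torch.autograd import Function


class ClippedSignSTE(Function):
    """Hypothetical example class, not part of bitorch."""

    @staticmethod
    def forward(ctx, input_tensor, threshold):
        # Tensors needed in backward go through save_for_backward.
        ctx.save_for_backward(input_tensor)
        # Non-tensor data can be stored directly on the context.
        ctx.threshold = threshold
        ones = torch.ones_like(input_tensor)
        return torch.where(input_tensor >= 0, ones, -ones)

    @staticmethod
    def backward(ctx, output_gradient):
        (input_tensor,) = ctx.saved_tensors
        # Pass the gradient through only where |input| <= threshold.
        mask = (input_tensor.abs() <= ctx.threshold).to(output_gradient.dtype)
        # One return value per forward argument; threshold gets no gradient.
        return output_gradient * mask, None


x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
y = ClippedSignSTE.apply(x, 1.0)
y.sum().backward()
```

Here backward returns one value per argument of forward, so the non-tensor threshold receives None. With a threshold of 1.0, the gradient is zeroed for the two inputs whose magnitude exceeds it and passed through for the other two.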