bitorch.quantizations.swish_sign.SwishSignFunction

class bitorch.quantizations.swish_sign.SwishSignFunction(*args, **kwargs)

SwishSign Function for input binarization.

Methods

backward – Apply the straight-through estimator.

forward – Binarize the input tensor using the _sign function.

static backward(ctx: BackwardCFunction, output_grad: Tensor) -> Tuple[Tensor, None]

Apply the straight-through estimator.

This passes the output gradient through as the input gradient after clamping its values to the range [-1, 1].

Parameters:
  • ctx (gradient context) – the autograd context object

  • output_grad (torch.Tensor) – the tensor containing the output gradient

Returns:

a tuple of the input gradient (= the clamped output gradient) and None, since beta receives no gradient

Return type:

Tuple[torch.Tensor, None]
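As a rough, framework-free illustration of the backward behaviour described above, the sketch below clamps each incoming gradient value to [-1, 1] and returns it together with None for beta. The function name `ste_backward` and the use of plain Python lists are assumptions for illustration only; this is not bitorch's implementation. In PyTorch the clamping step would typically be written as `torch.clamp(output_grad, -1, 1)`.

```python
from typing import List, Tuple


def ste_backward(output_grad: List[float]) -> Tuple[List[float], None]:
    """Hypothetical sketch of the documented straight-through backward pass.

    Each incoming gradient value is clamped to [-1, 1] and passed through
    as the input gradient. The second element mirrors the None returned
    for beta, which receives no gradient.
    """
    # Clamp every gradient value into the closed range [-1, 1].
    input_grad = [max(-1.0, min(1.0, g)) for g in output_grad]
    return input_grad, None


grads, beta_grad = ste_backward([-2.5, -0.3, 0.0, 0.7, 4.0])
print(grads)      # [-1.0, -0.3, 0.0, 0.7, 1.0]
print(beta_grad)  # None
```

A common alternative straight-through estimator (used in Binarized Neural Networks by Hubara et al.) leaves the gradient values unchanged and instead zeroes them wherever the forward input satisfies |x| > 1; the sketch follows the gradient clamping described in this section.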

static forward(ctx: BackwardCFunction, input_tensor: Tensor, beta: float = 1.0) -> Tensor

Binarize input tensor using the _sign function.

Parameters:

  • input_tensor (torch.Tensor) – the input values to the sign function

Returns:

binarized input tensor

Return type:

torch.Tensor
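A minimal, framework-free sketch of sign binarization follows. It assumes that zero inputs are mapped to +1 so that every output is exactly -1 or +1; how bitorch's _sign handles zeros is not stated in this section, and the beta parameter is not modelled because its role in the forward pass is not described here. The name `sign_binarize` is hypothetical.

```python
from typing import List


def sign_binarize(values: List[float]) -> List[float]:
    """Hypothetical sketch of sign binarization.

    Maps every value to -1.0 or +1.0. Zeros are mapped to +1.0 here; this
    is an assumed convention, not confirmed for bitorch's _sign.
    """
    # Non-negative values (including 0) become +1.0, negative values -1.0.
    return [1.0 if v >= 0 else -1.0 for v in values]


print(sign_binarize([-2.5, -0.3, 0.0, 0.7, 4.0]))  # [-1.0, -1.0, 1.0, 1.0, 1.0]
```

Note that `torch.sign` returns 0 for zero inputs, so a PyTorch binarizer that must emit strictly ±1 values needs an extra step (for example `torch.where`) to handle them.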