bitorch.quantizations.approx_sign.ApproxSignFunction

class bitorch.quantizations.approx_sign.ApproxSignFunction(*args, **kwargs)

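Autograd function for input binarization. The forward pass binarizes the input with a sign function. The backward pass uses the derivative of the ApproxSign approximation instead of the derivative of sign.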

Methods

backward – Compute the input gradient using the derivative of the ApproxSign function.

forward – Binarize the input tensor using the _sign function.

static backward(ctx: BackwardCFunction, output_grad: Tensor) → Tensor

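Compute the input gradient using the derivative of the approximate sign (ApproxSign) function. The true derivative of sign is zero almost everywhere, so this derivative is used in its place. This is used, for example, in Bi-Real Net.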
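
Bi-Real Net defines ApproxSign as a piecewise quadratic approximation of sign(x). If that is the function meant here, its derivative replaces the derivative of sign in the backward pass. Both are given below for reference:

```latex
\operatorname{ApproxSign}(x) =
\begin{cases}
-1, & x < -1,\\
2x + x^2, & -1 \le x < 0,\\
2x - x^2, & 0 \le x < 1,\\
1, & x \ge 1,
\end{cases}
\qquad
\frac{\partial \operatorname{ApproxSign}(x)}{\partial x} =
\begin{cases}
2 - 2\lvert x \rvert, & \lvert x \rvert < 1,\\
0, & \text{otherwise}.
\end{cases}
```

By the chain rule, the input gradient is output_grad multiplied elementwise by this derivative.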

Parameters:
  • ctx (gradient context) – the autograd context object; values saved in forward are retrieved from it

  • output_grad (torch.Tensor) – the gradient with respect to the output of forward

Returns:

the input gradient, i.e. the gradient with respect to input_tensor

Return type:

torch.Tensor
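
The standalone check below assumes the Bi-Real Net ApproxSign definition. It shows that a central finite difference of the piecewise quadratic agrees with the closed form 2 − 2|x| that the backward pass would use. The helper names approx_sign, approx_sign_grad, and finite_difference are hypothetical. The sample points stay away from −1, 0, and 1, where the piecewise definition switches branches.

```python
def approx_sign(x: float) -> float:
    """Piecewise-quadratic approximation of sign(x), as in Bi-Real Net (assumed)."""
    if x < -1:
        return -1.0
    if x < 0:
        return 2 * x + x * x
    if x < 1:
        return 2 * x - x * x
    return 1.0


def approx_sign_grad(x: float) -> float:
    """Closed-form derivative: 2 - 2|x| inside (-1, 1), zero outside."""
    return 2 - 2 * abs(x) if abs(x) < 1 else 0.0


def finite_difference(f, x: float, h: float = 1e-6) -> float:
    # Central difference; exact up to rounding on each quadratic piece.
    return (f(x + h) - f(x - h)) / (2 * h)


for x in [-1.5, -0.5, -0.1, 0.3, 0.9, 2.0]:
    numeric = finite_difference(approx_sign, x)
    analytic = approx_sign_grad(x)
    assert abs(numeric - analytic) < 1e-5, (x, numeric, analytic)
    print(f"x={x:+.2f}  numeric={numeric:.5f}  analytic={analytic:.5f}")
```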

static forward(ctx: BackwardCFunction, input_tensor: Tensor) → Tensor

Binarize the input tensor using the _sign function.

Parameters:

input_tensor (torch.Tensor) – the values to binarize

Returns:

the binarized input tensor

Return type:

torch.Tensor
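
Below is a minimal PyTorch sketch of a Function with this interface, for illustration only. It assumes two details this page does not state. First, zero is binarized to +1. Second, the backward pass uses Bi-Real Net's ApproxSign derivative: 2 − 2|x| for |x| < 1 and 0 otherwise. The class name ApproxSignSketch is hypothetical, and this is not bitorch's source. As with any torch.autograd.Function, it is called through .apply rather than by invoking forward directly.

```python
import torch


class ApproxSignSketch(torch.autograd.Function):
    """Hypothetical re-implementation for illustration; not bitorch's code."""

    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor) -> torch.Tensor:
        # Keep the pre-binarization values; backward needs |x|.
        ctx.save_for_backward(input_tensor)
        # Assumption: zero maps to +1, so every output is exactly -1 or +1.
        return torch.where(input_tensor >= 0,
                           torch.ones_like(input_tensor),
                           -torch.ones_like(input_tensor))

    @staticmethod
    def backward(ctx, output_grad: torch.Tensor) -> torch.Tensor:
        (input_tensor,) = ctx.saved_tensors
        abs_input = input_tensor.abs()
        # Assumed ApproxSign derivative: 2 - 2|x| inside (-1, 1), zero outside.
        local_grad = torch.where(abs_input < 1,
                                 2 - 2 * abs_input,
                                 torch.zeros_like(input_tensor))
        # Chain rule: scale the incoming gradient elementwise.
        return output_grad * local_grad


x = torch.tensor([-1.5, -0.5, 0.0, 0.25, 2.0], requires_grad=True)
y = ApproxSignSketch.apply(x)
y.sum().backward()
print(y.tolist())       # [-1.0, -1.0, 1.0, 1.0, 1.0]
print(x.grad.tolist())  # [0.0, 1.0, 2.0, 1.5, 0.0]
```

The raw input is saved in forward so that backward can recover |x|. Because the gradient is zero outside (−1, 1), inputs whose magnitude is at least 1 pass no gradient back.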