bitorch.quantizations.ste_heaviside.SteHeavisideFunction

class bitorch.quantizations.ste_heaviside.SteHeavisideFunction(*args, **kwargs)

Methods

backward – Passes the output gradient through unchanged as the input gradient.

forward – Quantizes the input tensor and returns the result.
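The static forward/backward pair taking a ctx argument suggests that this class is a torch.autograd.Function subclass. The sketch below is a hypothetical re-implementation under that assumption. SteHeavisideSketch is an illustrative name, not part of bitorch. It also assumes the Heaviside rule maps positive inputs to 1 and everything else to 0. bitorch's actual quantization levels and its handling of zero may differ.

```python
from typing import Any

import torch


class SteHeavisideSketch(torch.autograd.Function):
    """Hypothetical re-implementation for illustration; not bitorch's source."""

    @staticmethod
    def forward(ctx: Any, input_tensor: torch.Tensor) -> torch.Tensor:
        # Assumed Heaviside rule: 1 where input > 0, else 0 (including at 0).
        return torch.where(
            input_tensor > 0,
            torch.ones_like(input_tensor),
            torch.zeros_like(input_tensor),
        )

    @staticmethod
    def backward(ctx: Any, output_gradient: torch.Tensor) -> torch.Tensor:
        # Straight-through estimator: the incoming gradient is returned unchanged
        # instead of the step function's zero-almost-everywhere derivative.
        return output_gradient


x = torch.tensor([-1.5, 0.0, 0.3, 2.0], requires_grad=True)
y = SteHeavisideSketch.apply(x)  # custom autograd Functions are invoked via .apply
upstream = torch.tensor([0.1, 0.2, 0.3, 0.4])
y.backward(upstream)  # non-scalar output, so an explicit upstream gradient is passed

print(y.tolist())       # [0.0, 0.0, 1.0, 1.0]
print(torch.equal(x.grad, upstream))  # True: gradient passed straight through
```

Assuming SteHeavisideFunction is a Function subclass, it would be called the same way: SteHeavisideFunction.apply(x).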


static backward(ctx: Any, output_gradient: Tensor) → Tensor

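Passes the output gradient through unchanged as the input gradient. This identity gradient is the straight-through estimator (STE). The step function's true derivative is zero almost everywhere and would otherwise block gradient flow.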

Parameters:
  • ctx (Any) – autograd context

  • output_gradient (torch.Tensor) – gradient of the loss with respect to the output of forward

Returns:

The output gradient, unchanged, used as the gradient with respect to the forward input.

Return type:

torch.Tensor
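The backward rule can be written as a short derivation. It assumes that forward applies the step function H elementwise, with H(0) = 0. That value at zero is an assumption, not confirmed by this page. Here L is a scalar loss and y = H(x) is the forward output. The STE substitutes the identity for the step's derivative:

```latex
\begin{aligned}
y &= H(x) = \begin{cases} 1, & x > 0 \\ 0, & x \le 0 \end{cases}
  && \text{(forward, elementwise)} \\
\frac{\partial y}{\partial x} &= 0 \quad \text{for } x \neq 0
  && \text{(true derivative; undefined at } x = 0\text{)} \\
\frac{\partial L}{\partial x} &:= \frac{\partial L}{\partial y}
  && \text{(STE: identity in place of } \partial y / \partial x\text{)}
\end{aligned}
```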

static forward(ctx: BackwardCFunction, input_tensor: Tensor) → Tensor

Quantizes the input tensor and returns the result.

Parameters:
  • ctx (BackwardCFunction) – autograd context

  • input_tensor (torch.Tensor) – tensor to quantize

Returns:

The quantized input tensor.

Return type:

torch.Tensor
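The following sketch shows the assumed quantization rule concretely, outside autograd. It assumes positive inputs map to 1 and everything else to 0. Under that assumption, the rule matches PyTorch's built-in torch.heaviside when the value at zero is set to 0. heaviside_quantize is a hypothetical helper name. Whether bitorch's forward uses exactly these levels is not stated on this page.

```python
import torch


def heaviside_quantize(input_tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: 1.0 where input > 0, else 0.0 (assumed rule)."""
    # The boolean mask is cast back to the input dtype so the levels are 0.0/1.0.
    return (input_tensor > 0).to(input_tensor.dtype)


x = torch.tensor([[-0.7, 0.0], [0.01, 3.0]])
q = heaviside_quantize(x)

# torch.heaviside(input, values) returns `values` where input == 0;
# both tensors are float32 here, as torch.heaviside requires matching dtypes.
reference = torch.heaviside(x, torch.tensor(0.0))

print(q.tolist())                 # [[0.0, 0.0], [1.0, 1.0]]
print(torch.equal(q, reference))  # True
print(q.shape == x.shape)         # True: quantization is elementwise
```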