bitorch.quantizations.swish_sign.SwishSign

class bitorch.quantizations.swish_sign.SwishSign(beta: Optional[float] = None)[source]

Module for applying the SwishSign function

Methods

__init__

Initializes the quantization with the given beta.

quantize

Forwards the tensor through the SwishSign function.

Attributes

bit_width

name

__init__(beta: Optional[float] = None) → None[source]

Initializes the quantization with the given beta.

Parameters:

beta (float, optional) – the beta parameter of the SwishSign function. Defaults to None.

quantize(x: Tensor) → Tensor[source]

Forwards the tensor through the SwishSign function.

Parameters:

x (torch.Tensor) – tensor to be forwarded.

Returns:

the sign of tensor x

Return type:

torch.Tensor
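
As a rough illustration of how a sign quantizer with a beta-controlled surrogate gradient can work, the sketch below uses plain Python on scalars rather than torch tensors. It assumes the SignSwish formulation known from the binarized-network literature, where the forward pass returns the sign and the backward pass uses the derivative of a smooth approximation. The helper names sign, sign_swish and sign_swish_grad are hypothetical, and this is not taken from the bitorch source.

```python
import math


def sign(x: float) -> float:
    """Forward pass: binarize to -1 or +1 (mapping 0 to +1 is this sketch's choice)."""
    return 1.0 if x >= 0 else -1.0


def sign_swish(x: float, beta: float) -> float:
    """Smooth stand-in for sign (assumed form): 2*s(t)*(1 + t*(1 - s(t))) - 1, t = beta*x.

    s is the logistic sigmoid. Suitable for moderate |beta * x| only.
    """
    t = beta * x
    s = 1.0 / (1.0 + math.exp(-t))
    return 2.0 * s * (1.0 + t * (1.0 - s)) - 1.0


def sign_swish_grad(x: float, beta: float) -> float:
    """Derivative of sign_swish with respect to x, usable as a surrogate gradient."""
    t = beta * x
    if abs(t) > 700.0:
        # math.cosh would overflow here; the derivative is effectively zero.
        return 0.0
    return beta * (2.0 - t * math.tanh(t / 2.0)) / (1.0 + math.cosh(t))


if __name__ == "__main__":
    # Larger beta gives a sharper peak around zero.
    for x in (-1.0, -0.1, 0.0, 0.1, 1.0):
        print(x, sign(x), round(sign_swish_grad(x, beta=5.0), 4))
```

Under these assumptions the surrogate gradient equals beta at x = 0 and decays toward zero as |x| grows, so beta trades off how closely the smooth function follows the sign against how concentrated the gradient is.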