# Source code for bitorch.layers.qlinear
"""Module containing the quantized linear layer"""
from typing import Optional, Any, Type, Union, Dict
import torch
from torch.nn import Linear
from torch.nn.functional import linear
from bitorch import RuntimeMode
from bitorch.quantizations import Quantization
from .config import config
from .extensions import LayerRecipe, DefaultImplementationMixin
from .qactivation import QActivation
from .register import QLinearImplementation
class QLinearBase(Linear):
    """Linear layer that quantizes its input and its weight before the linear operation.

    The input is passed through a :class:`QActivation` and the weight through the
    configured weight quantization; the results are handed to
    :func:`torch.nn.functional.linear` together with the layer's bias.
    """

    def __init__(
        self,
        *args: int,
        input_quantization: Optional[Union[str, Quantization]] = None,
        gradient_cancellation_threshold: Union[float, None] = None,
        weight_quantization: Optional[Union[str, Quantization]] = None,
        **kwargs: bool,
    ) -> None:
        """Set up the underlying linear layer plus the input and weight quantization.

        Args:
            *args: positional arguments, forwarded unchanged to the linear layer.
            input_quantization (Union[str, Quantization], optional): quantization module
                (or its name) for the input; forwarded to :class:`QActivation`.
                Defaults to None.
            gradient_cancellation_threshold (Union[float, None], optional): threshold for
                input gradient cancellation; forwarded to :class:`QActivation`.
                Defaults to None.
            weight_quantization (Union[str, Quantization], optional): quantization module
                or name of the quantization function for the weight. When falsy,
                ``config.weight_quantization`` is used instead. Defaults to None.
            **kwargs: keyword arguments, forwarded unchanged to the linear layer.
        """
        super().__init__(*args, **kwargs)  # type: ignore

        # Fall back to the globally configured weight quantization if none was given.
        weight_quantization_spec = weight_quantization or config.weight_quantization
        self.weight_quantization = config.get_quantization_function(weight_quantization_spec)

        self.activation = QActivation(input_quantization, gradient_cancellation_threshold)

    @staticmethod
    def get_args_as_kwargs(recipe: LayerRecipe) -> Dict[str, Any]:
        """Collect the arguments a QLinear layer was created with, keyed by argument name.

        The result can be used to recreate a layer with identical arguments.

        Args:
            recipe: the recipe holding the layer and the arguments it was built from.

        Returns:
            A dictionary with all arguments (the key is the argument name as a string,
            even for arguments that were passed positionally).
        """
        layer = recipe.layer

        feature_sizes = {
            "in_features": recipe.get_positional_arg(0),
            "out_features": recipe.get_positional_arg(1),
        }
        # The quantization settings are read back from the layer itself.
        quantization_settings = {
            "input_quantization": layer.input_quantization,
            "gradient_cancellation_threshold": layer.gradient_cancellation_threshold,
            "weight_quantization": layer.weight_quantization,
        }
        # NOTE(review): indices 5-7 follow the key order of the returned dictionary, but
        # torch.nn.Linear takes bias/device/dtype positionally at 2/3/4 -- confirm against
        # LayerRecipe.get_arg that a positionally passed bias/device/dtype is picked up.
        linear_options = {
            "bias": recipe.get_arg(5, "bias", True),
            "device": recipe.get_arg(6, "device", None),
            "dtype": recipe.get_arg(7, "dtype", None),
        }
        return {**feature_sizes, **quantization_settings, **linear_options}

    @property
    def input_quantization(self) -> Quantization:
        """The activation function held by this layer's :class:`QActivation`."""
        return self.activation.activation_function

    @property
    def gradient_cancellation_threshold(self) -> float:
        """The gradient cancellation threshold held by this layer's :class:`QActivation`."""
        return self.activation.gradient_cancellation_threshold

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward x through the quantized linear layer.

        Args:
            x (torch.Tensor): tensor to forward

        Returns:
            torch.Tensor: forwarded tensor
        """
        quantized_input = self.activation(x)
        quantized_weight = self.weight_quantization(self.weight)
        return linear(quantized_input, quantized_weight, self.bias)
class _QLinearComposed(DefaultImplementationMixin, QLinearBase):
    """Default implementation of a QLinear layer.

    All of the actual behaviour is inherited: the layer logic comes from
    :class:`QLinearBase`, combined here with :class:`DefaultImplementationMixin`.
    To write a custom QLinear implementation, subclass :class:`QLinearBase` instead
    of this class.
    """
# QLinear is the public name of the layer: it is bound to whatever the
# QLinearImplementation(RuntimeMode.DEFAULT) decorator returns for _QLinearComposed.
# NOTE(review): presumably this registers _QLinearComposed as the default-mode
# implementation -- confirm in the .register module.
# The bare string below is the attribute docstring for QLinear.
QLinear: Type[_QLinearComposed] = QLinearImplementation(RuntimeMode.DEFAULT)(_QLinearComposed) # type: ignore
"""
This class provides the current implementation of a QLinear layer (which is actually implemented by :class:`QLinearBase`).
To implement a custom QLinear implementation use :class:`QLinearBase` as a super class instead.
"""