Source code for bitorch.models.lenet

import argparse
from typing import Optional, List
from bitorch.layers.debug_layers import ShapePrintDebug
from bitorch.layers import QLinear, QConv2d, QActivation
from torch import nn
from .base import Model


class LeNet(Model):
    """LeNet-style convolutional network, available in quantized and full precision variants."""

    # Number of output channels produced by each of the two convolution layers.
    num_channels_conv = 64
    # Activation class that is instantiated wherever a regular (non-quantized) activation is used.
    activation_function = nn.Tanh
    # Number of units of the hidden fully connected layer.
    num_fc = 1000
    name = "LeNet"

    def generate_quant_model(
        self,
        weight_quant: str,
        input_quant: str,
        weight_quant_2: Optional[str] = None,
        input_quant_2: Optional[str] = None,
    ) -> nn.Sequential:
        """Assemble the quantized variant of the network.

        The first convolution and the final linear layer are regular ``torch.nn``
        layers; the second convolution and the hidden linear layer are the
        quantized ``QConv2d`` / ``QLinear`` layers.

        Args:
            weight_quant (str): weight quantization passed to the quantized convolution.
            input_quant (str): input quantization passed to the quantized convolution.
            weight_quant_2 (Optional[str]): weight quantization of the quantized linear
                layer. Falls back to ``weight_quant`` when falsy.
            input_quant_2 (Optional[str]): quantization applied by the ``QActivation``
                in front of the quantized linear layer. Falls back to ``input_quant``
                when falsy.

        Returns:
            nn.Sequential: the assembled quantized model.
        """
        if not weight_quant_2:
            weight_quant_2 = weight_quant
        if not input_quant_2:
            input_quant_2 = input_quant

        # Layers are created strictly in network order so that parameter
        # initialization and the module indices inside the Sequential stay stable.
        feature_layers: List[nn.Module] = [
            nn.Conv2d(self.input_channels, self.num_channels_conv, kernel_size=5),
            self.activation_function(),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(self.num_channels_conv),
            QConv2d(
                self.num_channels_conv,
                self.num_channels_conv,
                kernel_size=5,
                input_quantization=input_quant,
                weight_quantization=weight_quant,
            ),
            nn.BatchNorm2d(self.num_channels_conv),
            nn.MaxPool2d(2, 2),
            # NOTE(review): debug layer that is part of the model; dropping it would
            # shift the indices of all following modules (and their state_dict keys).
            ShapePrintDebug(),
        ]

        # The flattened size is hard-coded to a 4x4 feature map per channel. With two
        # unpadded 5x5 convolutions and two 2x2 poolings this corresponds to a 28x28
        # spatial input (28 -> 24 -> 12 -> 8 -> 4) -- presumably MNIST-sized images.
        flattened_size = self.num_channels_conv * 4 * 4
        classifier_layers: List[nn.Module] = [
            nn.Flatten(),
            QActivation(activation=input_quant_2),
            QLinear(
                flattened_size,
                self.num_fc,
                weight_quantization=weight_quant_2,
            ),
            nn.BatchNorm1d(self.num_fc),
            self.activation_function(),
            nn.Linear(self.num_fc, self.num_output),
        ]

        return nn.Sequential(*feature_layers, *classifier_layers)

    def __init__(self, input_shape: List[int], num_classes: int = 0, lenet_version: int = 0) -> None:
        """Build either one of the quantized variants or the full precision model.

        Args:
            input_shape (List[int]): input shape of the images; the entry at index 1
                is used as the number of input channels.
            num_classes (int, optional): number of output classes. Defaults to 0.
            lenet_version (int, optional): selects the quantization preset (0 to 3).
                Any other value builds the full precision model. Defaults to 0.

        NOTE(review): earlier documentation mentioned a ``ValueError`` for a missing
        number of classes. Nothing in this class raises it; presumably it originates
        from the base ``Model`` constructor -- confirm.
        """
        super().__init__(input_shape, num_classes)
        self.input_channels = self._input_shape[1]
        self.num_output = self._num_classes

        # (version, positional arguments, keyword arguments) for generate_quant_model.
        quant_presets = (
            (0, ("sign", "sign"), {}),
            (1, ("weightdorefa", "weightdorefa"), {}),
            (2, ("sign", "weightdorefa"), {"weight_quant_2": "sign"}),
            (3, ("sign", "weightdorefa"), {}),
        )
        for version, positional, keywords in quant_presets:
            if lenet_version == version:
                self._model = self.generate_quant_model(*positional, **keywords)
                return

        # No preset matched: fall back to the full precision network.
        full_precision_layers: List[nn.Module] = [
            nn.Conv2d(self.input_channels, self.num_channels_conv, kernel_size=5),
            nn.BatchNorm2d(self.num_channels_conv),
            self.activation_function(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(self.num_channels_conv, self.num_channels_conv, kernel_size=5),
            nn.BatchNorm2d(self.num_channels_conv),
            self.activation_function(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            # Same hard-coded 4x4 feature map assumption as in the quantized variant.
            nn.Linear(self.num_channels_conv * 4 * 4, self.num_fc),
            nn.BatchNorm1d(self.num_fc),
            self.activation_function(),
            nn.Linear(self.num_fc, self.num_output),
        ]
        self._model = nn.Sequential(*full_precision_layers)

    @staticmethod
    def add_argparse_arguments(parser: argparse.ArgumentParser) -> None:
        """Register the LeNet specific command line option on the given parser.

        Args:
            parser (argparse.ArgumentParser): parser that receives the ``--version`` option.
        """
        parser.add_argument(
            "--version",
            type=int,
            default=0,
            help="choose a version of lenet",
        )