Source code for bitorch.models.quicknet

import logging
from typing import Any, List, Optional

import torch
from torch import nn
from torch.nn import Module
import numpy as np

from .base import Model, NoArgparseArgsMixin
from bitorch.layers import QConv2d, PadModule
from bitorch.models.common_layers import get_initial_layers


class ResidualBlock(nn.Sequential):
    """Residual block: a binarized 3x3 convolution, ReLU, and batch norm,
    with an identity shortcut added in forward()."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.add_module(
            "qconv",
            QConv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                pad_value=1,
                padding="same",
                bias=False,
            ),
        )
        self.add_module("relu", nn.ReLU())
        self.add_module("bnorm", nn.BatchNorm2d(out_channels, momentum=0.9))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x) + x
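# Illustrative sketch, not part of bitorch: the "same" padding in QConv2d keeps
# height and width, and in_channels must equal out_channels for the shortcut in
# forward() to be a valid elementwise addition. The _demo_* helpers here and
# below are hypothetical, for illustration only.
def _demo_residual_block() -> None:
    block = ResidualBlock(64, 64)
    x = torch.randn(1, 64, 32, 32)
    # Output and input shapes match, so the identity shortcut is well-defined.
    assert block(x).shape == x.shape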
class TransitionBlock(nn.Sequential):
    """Downsampling transition between sections: max pooling and a frozen
    depthwise (blur-pool) convolution reduce the spatial resolution, and a
    1x1 convolution changes the channel count."""

    def __init__(self, in_channels: int, out_channels: int, strides: int) -> None:
        super().__init__()
        self.add_module("relu", nn.ReLU())
        self.add_module("pool", nn.MaxPool2d(strides, stride=1))
        self.add_module("pad", PadModule(1, 1, 1, 1))
        self.add_module(
            "depth_conv",
            nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                groups=in_channels,
                stride=strides,
                bias=False,
            ).requires_grad_(False),
        )
        self.add_module("conv", nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False))
        self.add_module("relu2", nn.ReLU())
        self.add_module("norm", nn.BatchNorm2d(out_channels, momentum=0.9))
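# Illustrative sketch, not part of bitorch: with strides=2, the frozen depthwise
# convolution acts as an anti-aliased downsampling filter (see _blurpool_init
# below), halving the spatial resolution, while the 1x1 convolution changes the
# channel count.
def _demo_transition_block() -> None:
    block = TransitionBlock(64, 128, strides=2)
    x = torch.randn(1, 64, 32, 32)
    # 32x32 -> pool(k=2, s=1) 31x31 -> pad 33x33 -> depthwise conv(k=3, s=2) 16x16
    assert block(x).shape == (1, 128, 16, 16)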
class QuickNet(Model):
    """QuickNet model from `"Larq Compute Engine: Design, Benchmark, and Deploy
    State-of-the-Art Binarized Neural Networks" <https://arxiv.org/abs/2011.09398>`_ paper.
    """

    name = "QuickNet"
    def __init__(
        self,
        input_shape: List[int],
        section_filters: Optional[List[int]] = None,
        section_blocks: Optional[List[int]] = None,
        num_classes: int = 0,
    ) -> None:
        super(QuickNet, self).__init__(input_shape, num_classes)
        if section_filters is None:
            section_filters = [64, 128, 256, 512]
        if section_blocks is None:
            section_blocks = [4, 4, 4, 4]
        self.image_channels = self._input_shape[1]
        self.num_classes = num_classes
        self.section_filters = section_filters
        self.section_blocks = section_blocks
        logging.info("building QuickNet")
        self._model = self._build_model()
        self._model.stem.apply(self._initialize_stem)  # type: ignore
        self._model.body.apply(self._initialize_body_top)  # type: ignore
        self._model.top.apply(self._initialize_body_top)  # type: ignore
    def _blurpool_init(self, weight: torch.Tensor) -> None:
        """Initialize an anti-aliasing low-pass filter.
        See the `"Making Convolutional Networks Shift-Invariant Again"
        <https://arxiv.org/abs/1904.11486>`_ paper.
        """
        filters, kernel_size = weight.data.shape[0], weight.data.shape[2]
        if kernel_size == 2:
            base = np.array([1.0, 1.0])
        elif kernel_size == 3:
            base = np.array([1.0, 2.0, 1.0])
        elif kernel_size == 5:
            base = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
        else:
            raise ValueError("kernel size should be one of 2, 3, 5")
        # Normalized outer product of the binomial base vector, repeated per channel.
        new_weights = torch.Tensor(base[:, None] * base[None, :])
        new_weights = new_weights / torch.sum(new_weights)
        new_weights = new_weights[None, None, :, :].repeat((filters, 1, 1, 1))
        weight.data = new_weights

    def _initialize_stem(self, layer: Module) -> None:
        if isinstance(layer, nn.Conv2d):
            if layer.groups == 1:
                nn.init.kaiming_normal_(layer.weight)
            else:
                nn.init.xavier_uniform_(layer.weight)

    def _initialize_body_top(self, layer: Module) -> None:
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            if isinstance(layer, nn.Linear) or layer.groups == 1:
                nn.init.xavier_normal_(layer.weight)
            else:
                # Depthwise convolutions get the fixed blur-pool filter.
                self._blurpool_init(layer.weight)

    def _build_model(self) -> nn.Sequential:
        model = nn.Sequential()
        model.add_module(
            "stem",
            nn.Sequential(*get_initial_layers("quicknet_stem", self.image_channels, self.section_filters[0])),
        )
        body = nn.Sequential()
        for block_num, (layers, filters) in enumerate(zip(self.section_blocks, self.section_filters)):
            residual_blocks: List[Module] = [ResidualBlock(filters, filters) for _ in range(layers)]
            body.add_module(
                "ResidualBlocks_%d" % (block_num + 1),
                nn.Sequential(*residual_blocks),
            )
            if block_num != len(self.section_blocks) - 1:
                body.add_module(
                    "Transition_%d" % (block_num + 1),
                    TransitionBlock(filters, self.section_filters[block_num + 1], 2),
                )
        model.add_module("body", body)
        model.add_module(
            "top",
            nn.Sequential(
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(self.section_filters[-1], self.num_classes),
            ),
        )
        return model
    def clip_weights(self, layer: Module, clip_value: float = 1.25) -> None:
        """Clip the weights of the quantized convolution in each ResidualBlock."""
        if isinstance(layer, ResidualBlock):
            weights = layer.qconv.weight.data  # type: ignore
            layer.qconv.weight.data = weights.clamp(-clip_value, clip_value)  # type: ignore

    def on_train_batch_end(self, layer: Module) -> None:
        self.clip_weights(layer)
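# Illustrative sketch, not part of bitorch, assuming the Model base class
# forwards input through the built nn.Sequential: global average pooling in the
# top makes the logit shape independent of the input resolution, and
# clip_weights acts on one module at a time, so it is spread over the whole
# network with nn.Module.apply (the pattern on_train_batch_end is built on).
def _demo_quicknet() -> None:
    model = QuickNet(input_shape=[1, 3, 224, 224], num_classes=10)
    x = torch.randn(2, 3, 224, 224)
    assert model(x).shape == (2, 10)
    # Clamp the binarized convolution weights of every ResidualBlock.
    model._model.apply(model.clip_weights)
    for module in model._model.modules():
        if isinstance(module, ResidualBlock):
            assert float(module.qconv.weight.data.abs().max()) <= 1.25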
class QuickNetSmall(NoArgparseArgsMixin, QuickNet):
    """QuickNetSmall model from `"Larq Compute Engine: Design, Benchmark, and Deploy
    State-of-the-Art Binarized Neural Networks" <https://arxiv.org/abs/2011.09398>`_ paper.
    """

    name = "QuickNetSmall"
    section_filters = [32, 64, 256, 512]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuickNetSmall, self).__init__(*args, section_filters=self.section_filters, **kwargs)
class QuickNetLarge(NoArgparseArgsMixin, QuickNet):
    """QuickNetLarge model from `"Larq Compute Engine: Design, Benchmark, and Deploy
    State-of-the-Art Binarized Neural Networks" <https://arxiv.org/abs/2011.09398>`_ paper.
    """

    name = "QuickNetLarge"
    section_blocks = [6, 8, 12, 6]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuickNetLarge, self).__init__(*args, section_blocks=self.section_blocks, **kwargs)
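# Illustrative sketch, not part of bitorch: the Small and Large variants only
# override the class-level section_filters / section_blocks defaults before
# delegating to QuickNet.__init__.
def _demo_variants() -> None:
    small = QuickNetSmall(input_shape=[1, 3, 224, 224], num_classes=1000)
    large = QuickNetLarge(input_shape=[1, 3, 224, 224], num_classes=1000)
    assert small.section_filters == [32, 64, 256, 512]
    assert large.section_blocks == [6, 8, 12, 6]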