Source code for bitorch.models.meliusnet

import argparse
import logging
from typing import Optional, List, Any

import torch
from torch import nn
from torch.nn import Module

from .densenet import BaseNetDense, DOWNSAMPLE_STRUCT, basedensenet_constructor
from .base import Model, NoArgparseArgsMixin
from bitorch.layers import QConv2d


# Blocks
class ImprovementBlock(Module):
    """ImprovementBlock which improves the last n channels.

    The body (``BatchNorm2d`` followed by a 3x3 ``QConv2d``) turns an input with
    ``in_channels`` channels into an output with ``channels`` channels. That
    output is added onto the last ``channels`` channels of the input, while the
    leading ``in_channels - channels`` channels are passed through untouched.
    When ``channels == in_channels`` the body output is added to the whole input.
    """

    def __init__(self, channels: int, in_channels: int, dilation: int = 1):
        """Build the block.

        Args:
            channels: number of channels produced by the body, i.e. how many
                trailing input channels get improved.
            in_channels: number of channels of the incoming tensor.
            dilation: dilation of the convolution; also used as its padding.

        Raises:
            ValueError: if ``channels`` is larger than ``in_channels``.
        """
        super(ImprovementBlock, self).__init__()

        # Validate with a real exception instead of ``assert``: asserts are
        # removed under ``python -O`` and a too-large ``channels`` would then
        # yield a negative slice boundary that only fails later in forward().
        if channels > in_channels:
            raise ValueError(
                "ImprovementBlock requires channels <= in_channels, got channels=%d, in_channels=%d"
                % (channels, in_channels)
            )

        self.body_layers: List[Module] = [
            nn.BatchNorm2d(in_channels),
            QConv2d(in_channels, channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation),
        ]

        self.use_sliced_addition = channels != in_channels
        if self.use_sliced_addition:
            # Channel boundaries of the two slices and whether the body output
            # is added to each of them: [0, split) is kept, [split, end) is improved.
            self.slices = [0, in_channels - channels, in_channels]
            self.slices_add_x = [False, True]

        self.body = nn.Sequential(*self.body_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the body and add its output to the (trailing) input channels.

        Args:
            x: input tensor; channels are taken along dimension 1.

        Returns:
            A tensor with the same number of channels as the input.
        """
        residual = x
        x = self.body(x)
        if not self.use_sliced_addition:
            return x + residual

        parts = []
        for add_x, slice_begin, slice_end in zip(self.slices_add_x, self.slices[:-1], self.slices[1:]):
            length = slice_end - slice_begin
            if length == 0:
                # Nothing to take for an empty slice.
                continue
            result = torch.narrow(residual, dim=1, start=slice_begin, length=length)
            if add_x:
                result = result + x
            parts.append(result)
        return torch.cat(parts, dim=1)
class _MeliusNet(BaseNetDense):
    """BaseNetDense variant whose base block is a dense layer followed by an ImprovementBlock."""

    def _add_base_block_structure(self, layer_num: int, dilation: int) -> None:
        """Append one dense layer and one ImprovementBlock to the current dense block.

        Args:
            layer_num: zero-based layer index; the ImprovementBlock is registered
                under the one-based name ``ImprovementBlock<layer_num + 1>``.
            dilation: dilation forwarded to both the dense layer and the block.
        """
        self._add_dense_layer(layer_num, dilation)

        block_name = "ImprovementBlock%d" % (layer_num + 1)
        improvement = ImprovementBlock(self.growth_rate, self.num_features, dilation=dilation)
        self.current_dense_block.add_module(block_name, improvement)
class MeliusNet(Model):
    """MeliusNet model, assembled by ``basedensenet_constructor`` from ``_MeliusNet``.

    ``meliusnet_spec`` holds one entry per variant; each value is a tuple of
    ``(block_config, reduction_factors, downsampling)``. The keys correspond to
    the accepted ``num_layers`` values (presumably looked up inside
    ``basedensenet_constructor`` -- confirm there).
    """

    name = "MeliusNet"

    meliusnet_spec = {
        # name: block_config, reduction_factors, downsampling
        None: (
            None,
            [1 / 2, 1 / 2, 1 / 2],
            DOWNSAMPLE_STRUCT,
        ),
        "23": (
            [2, 4, 6, 6],
            [128 / 192, 192 / 384, 288 / 576],
            DOWNSAMPLE_STRUCT.replace("fp_conv", "cs,fp_conv:8"),
        ),
        "22": (
            [4, 5, 4, 4],
            [160 / 320, 224 / 480, 256 / 480],
            DOWNSAMPLE_STRUCT,
        ),
        "29": (
            [4, 6, 8, 6],
            [128 / 320, 192 / 512, 256 / 704],
            DOWNSAMPLE_STRUCT,
        ),
        "42": (
            [5, 8, 14, 10],
            [160 / 384, 256 / 672, 416 / 1152],
            DOWNSAMPLE_STRUCT,
        ),
        "59": (
            [6, 12, 24, 12],
            [192 / 448, 320 / 960, 544 / 1856],
            DOWNSAMPLE_STRUCT,
        ),
        "a": (
            [4, 5, 5, 6],
            [160 / 320, 256 / 480, 288 / 576],
            DOWNSAMPLE_STRUCT.replace("fp_conv", "cs,fp_conv:4"),
        ),
        "b": (
            [4, 6, 8, 6],
            [160 / 320, 224 / 544, 320 / 736],
            DOWNSAMPLE_STRUCT.replace("fp_conv", "cs,fp_conv:2"),
        ),
        "c": (
            [3, 5, 10, 6],
            [128 / 256, 192 / 448, 288 / 832],
            DOWNSAMPLE_STRUCT.replace("fp_conv", "cs,fp_conv:4"),
        ),
    }

    def __init__(
        self,
        num_layers: Optional[str],
        input_shape: List[int],
        num_classes: int = 0,
        num_init_features: int = 64,
        growth_rate: int = 64,
        bn_size: int = 0,
        dropout: float = 0,
        dilated: bool = False,
        flex_block_config: Optional[List[int]] = None,
    ) -> None:
        """Create the wrapped network and store it in ``self._model``.

        Args:
            num_layers: variant key, one of the keys of ``meliusnet_spec``.
            input_shape: shape of the model input, passed on to ``Model``.
            num_classes: number of classes, passed on to ``Model``.
            num_init_features: forwarded to ``basedensenet_constructor``.
            growth_rate: forwarded to ``basedensenet_constructor``.
            bn_size: forwarded to ``basedensenet_constructor``.
            dropout: forwarded to ``basedensenet_constructor``.
            dilated: forwarded to ``basedensenet_constructor``.
            flex_block_config: forwarded to ``basedensenet_constructor``.
        """
        super().__init__(input_shape, num_classes)

        # Positional arguments in the exact order the constructor expects them;
        # the last two are the trailing two entries and entry 1 of the input shape.
        constructor_args = (
            self.meliusnet_spec,
            _MeliusNet,
            num_layers,
            num_init_features,
            growth_rate,
            bn_size,
            dropout,
            dilated,
            flex_block_config,
            self._num_classes,
            self._input_shape[-2:],
            self._input_shape[1],
        )
        self._model = basedensenet_constructor(*constructor_args)
        logging.info(f"building MeliusNet with {str(num_layers)} layers...")

    @staticmethod
    def add_argparse_arguments(parser: argparse.ArgumentParser) -> None:
        """Register the MeliusNet command line options on ``parser``.

        Args:
            parser: the argument parser that receives the options.
        """
        parser.add_argument(
            "--num-layers",
            type=str,
            choices=[None, "22", "23", "29", "42", "59", "a", "b", "c"],
            required=True,
            help="number of layers to be used inside meliusnet",
        )

        # The remaining options are all optional and differ only in flag,
        # value type and help text, so they are registered from a table.
        optional_arguments = (
            ("--reduction", str, "divide channels by this number in transition blocks"),
            ("--growth-rate", int, "add this many features each block"),
            ("--init-features", int, "start with this many filters in the first layer"),
            (
                "--downsample-structure",
                str,
                "layers in downsampling branch (available: bn,relu,conv,fp_conv,pool,max_pool)",
            ),
        )
        for flag, value_type, help_text in optional_arguments:
            parser.add_argument(flag, type=value_type, required=False, help=help_text)
class MeliusNetFlex(MeliusNet):
    """MeliusNet-Flex model from `"MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?"
    <https://arxiv.org/abs/2001.05936>`_ paper.

    Always passes ``None`` as ``num_layers`` to ``MeliusNet``; all other
    arguments are forwarded unchanged.
    """

    name = "MeliusNetFlex"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``MeliusNet.__init__`` with ``num_layers`` set to ``None``."""
        super().__init__(None, *args, **kwargs)

    @staticmethod
    def add_argparse_arguments(parser: argparse.ArgumentParser) -> None:
        """Register the MeliusNet options plus the required ``--block-config`` option.

        Args:
            parser: the argument parser that receives the options.
        """
        # NOTE(review): the inherited --num-layers option is registered as
        # required although __init__ always passes None -- confirm this is intended.
        MeliusNet.add_argparse_arguments(parser)

        block_config_options = {
            "type": str,
            "required": True,
            "help": "how many blocks to use in a flex model",
        }
        parser.add_argument("--block-config", **block_config_options)
class MeliusNet22(NoArgparseArgsMixin, MeliusNet):
    """MeliusNet-22 model from `"MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?"
    <https://arxiv.org/abs/2001.05936>`_ paper.
    """

    name = "MeliusNet22"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``MeliusNet.__init__`` with ``num_layers`` fixed to ``"22"``."""
        variant_key = "22"
        super().__init__(variant_key, *args, **kwargs)
class MeliusNet23(NoArgparseArgsMixin, MeliusNet):
    """MeliusNet-23 model from `"MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?"
    <https://arxiv.org/abs/2001.05936>`_ paper.
    """

    name = "MeliusNet23"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``MeliusNet.__init__`` with ``num_layers`` fixed to ``"23"``."""
        variant_key = "23"
        super().__init__(variant_key, *args, **kwargs)
class MeliusNet29(NoArgparseArgsMixin, MeliusNet):
    """MeliusNet-29 model from `"MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?"
    <https://arxiv.org/abs/2001.05936>`_ paper.
    """

    name = "MeliusNet29"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``MeliusNet.__init__`` with ``num_layers`` fixed to ``"29"``."""
        variant_key = "29"
        super().__init__(variant_key, *args, **kwargs)
class MeliusNet42(NoArgparseArgsMixin, MeliusNet):
    """MeliusNet-42 model from `"MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?"
    <https://arxiv.org/abs/2001.05936>`_ paper.
    """

    name = "MeliusNet42"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``MeliusNet.__init__`` with ``num_layers`` fixed to ``"42"``."""
        variant_key = "42"
        super().__init__(variant_key, *args, **kwargs)
class MeliusNet59(NoArgparseArgsMixin, MeliusNet):
    """MeliusNet-59 model from `"MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?"
    <https://arxiv.org/abs/2001.05936>`_ paper.
    """

    name = "MeliusNet59"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``MeliusNet.__init__`` with ``num_layers`` fixed to ``"59"``."""
        variant_key = "59"
        super().__init__(variant_key, *args, **kwargs)
class MeliusNetA(NoArgparseArgsMixin, MeliusNet):
    """MeliusNet-A model from `"MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?"
    <https://arxiv.org/abs/2001.05936>`_ paper.
    """

    name = "MeliusNetA"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``MeliusNet.__init__`` with ``num_layers`` fixed to ``"a"``."""
        variant_key = "a"
        super().__init__(variant_key, *args, **kwargs)
class MeliusNetB(NoArgparseArgsMixin, MeliusNet):
    """MeliusNet-B model from `"MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?"
    <https://arxiv.org/abs/2001.05936>`_ paper.
    """

    name = "MeliusNetB"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``MeliusNet.__init__`` with ``num_layers`` fixed to ``"b"``."""
        variant_key = "b"
        super().__init__(variant_key, *args, **kwargs)
class MeliusNetC(NoArgparseArgsMixin, MeliusNet):
    """MeliusNet-C model from `"MeliusNet: Can Binary Neural Networks Achieve MobileNet-level Accuracy?"
    <https://arxiv.org/abs/2001.05936>`_ paper.
    """

    name = "MeliusNetC"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``MeliusNet.__init__`` with ``num_layers`` fixed to ``"c"``."""
        variant_key = "c"
        super().__init__(variant_key, *args, **kwargs)