Source code for bitorch.models.common_layers

from typing import List, Optional, Union
from torch import nn

from bitorch.layers.pad import PadModule


def get_initial_layers(
    variant: Optional[Union[List[int], str]], input_channels: int, output_channels: int
) -> List[nn.Module]:
    """Return the initial (stem) layers for the given variant.

    Recognised variants:

    * ``(224, 224)`` / ``[224, 224]`` / ``"imagenet"``: a 7x7 stride-2
      convolution, followed by BatchNorm, ReLU and a 3x3 stride-2 max pool.
    * ``"quicknet_stem"``: ``PadModule(0, 1, 0, 1)`` + 3x3 stride-2 convolution
      to ``output_channels // 4`` channels, BatchNorm, ReLU, a second
      ``PadModule`` + depthwise 3x3 stride-2 convolution, BatchNorm, and a 1x1
      convolution to ``output_channels`` with a final BatchNorm.
    * ``"grouped_stem"``: three 3x3 convolutions (the first with stride 2, the
      last two grouped with ``groups=4`` and ``groups=8``), followed by
      BatchNorm, ReLU and a 3x3 stride-2 max pool.
    * anything else (including ``None``): a single 3x3 convolution with
      padding 1.

    Args:
        variant: name of the stem, or the input image resolution as a list or
            tuple of two ints.
        input_channels: number of channels of the network input.
        output_channels: number of channels produced by the last layer.

    Returns:
        The stem layers, in the order they should be applied.

    Raises:
        AssertionError: if ``variant`` is ``"quicknet_stem"`` and
            ``output_channels`` is not divisible by 4.
    """
    # A list never compares equal to a tuple, so a resolution passed as
    # ``[224, 224]`` (which the annotation allows) must be converted before it
    # is compared against ``(224, 224)``.
    key = tuple(variant) if isinstance(variant, list) else variant
    is_imagenet = key == (224, 224) or key == "imagenet"

    # NOTE(review): torch's BatchNorm ``momentum`` is the weight given to the
    # new batch statistic; 0.9 may have been intended in the Keras sense
    # (i.e. 0.1 here) - confirm before changing, it affects trained models.
    layers: List[nn.Module] = []
    if is_imagenet:
        layers.append(nn.Conv2d(input_channels, output_channels, kernel_size=7, stride=2, padding=3, bias=False))
    elif key == "quicknet_stem":
        assert output_channels % 4 == 0, "output_channels must be divisible by 4 for quicknet_stem"
        stem_channels = output_channels // 4

        # The strided convolutions use no padding of their own; PadModule is
        # applied in front of each of them instead.
        layers.append(PadModule(0, 1, 0, 1))
        layers.append(
            nn.Conv2d(
                input_channels,
                stem_channels,
                kernel_size=3,
                stride=2,
                bias=False,
            )
        )
        layers.append(nn.BatchNorm2d(stem_channels, momentum=0.9))
        layers.append(nn.ReLU())

        layers.append(PadModule(0, 1, 0, 1))
        # groups == channels: one filter per channel (depthwise convolution).
        layers.append(
            nn.Conv2d(
                stem_channels,
                stem_channels,
                kernel_size=3,
                groups=stem_channels,
                stride=2,
                bias=False,
            )
        )
        layers.append(nn.BatchNorm2d(stem_channels, momentum=0.9))

        # Pointwise convolution expands to the requested channel count.
        layers.append(
            nn.Conv2d(
                stem_channels,
                output_channels,
                kernel_size=1,
                bias=False,
            )
        )
        layers.append(nn.BatchNorm2d(output_channels, momentum=0.9))
    elif key == "grouped_stem":
        stem_width = output_channels // 2

        layers.append(nn.Conv2d(input_channels, stem_width, kernel_size=3, stride=2, padding=1, bias=False))
        layers.append(nn.BatchNorm2d(stem_width, momentum=0.9))
        layers.append(nn.ReLU())
        layers.append(nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, groups=4, bias=False))
        layers.append(nn.BatchNorm2d(stem_width, momentum=0.9))
        layers.append(nn.ReLU())
        layers.append(nn.Conv2d(stem_width, output_channels, kernel_size=3, stride=1, padding=1, groups=8, bias=False))
    else:
        layers.append(nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1, bias=False))

    # The ImageNet-style and grouped stems end with normalisation, activation
    # and a stride-2 max pool; the other variants do not.
    if is_imagenet or key == "grouped_stem":
        layers.append(nn.BatchNorm2d(output_channels, momentum=0.9))
        layers.append(nn.ReLU())
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

    return layers