bitorch.models.base.Model

class bitorch.models.base.Model(input_shape: List[int], num_classes: int = 0)

Base class for Bitorch models.

Methods

__init__

Initializes internal Module state, shared by both nn.Module and ScriptModule.

add_argparse_arguments

Allows additions to the argument parser if required, e.g. to add a layer count.

convert

forward

Forwards the input tensor through the model.

from_pretrained

initialize

Initializes the model weights using a scheme slightly modified for binary neural networks (BNNs).

model

Getter method for the underlying model.

on_train_batch_end

Used with the PyTorch Lightning on_train_batch_end callback.

Attributes

name

version_table_url

__init__(input_shape: List[int], num_classes: int = 0) → None

Initializes internal Module state, shared by both nn.Module and ScriptModule.

static add_argparse_arguments(parser: ArgumentParser) → None

Allows additions to the argument parser if required, e.g. to add a layer count.

Note: the variable names that argparse infers for any additional CLI arguments are passed as keyword arguments to the constructor of this class.

Parameters:

parser (ArgumentParser) – the argument parser
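A minimal sketch of the pattern described in the note above, using only the standard library. `HypotheticalModel` and its `--num-layers` option are illustrative assumptions, not part of bitorch; the point is that argparse derives the destination name `num_layers` from the flag, and that inferred name is what ends up as a constructor keyword argument.

```python
import argparse
from typing import List


class HypotheticalModel:
    """Illustrative stand-in for a Model subclass (not bitorch code)."""

    def __init__(self, input_shape: List[int], num_classes: int = 0, num_layers: int = 18) -> None:
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.num_layers = num_layers

    @staticmethod
    def add_argparse_arguments(parser: argparse.ArgumentParser) -> None:
        # Hypothetical option: argparse infers the dest name "num_layers" from "--num-layers".
        parser.add_argument("--num-layers", type=int, default=18, help="number of layers")


parser = argparse.ArgumentParser()
HypotheticalModel.add_argparse_arguments(parser)
args = parser.parse_args(["--num-layers", "34"])

# The inferred names become keyword arguments of the constructor.
model = HypotheticalModel(input_shape=[1, 3, 32, 32], num_classes=10, **vars(args))
print(model.num_layers)  # 34
```

In this sketch the parser holds only the model's own options, so `vars(args)` can be unpacked directly; since `HypotheticalModel` takes no `**kwargs`, an option whose inferred name has no matching constructor parameter would raise a `TypeError`.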

forward(x: Tensor) → Tensor

Forwards the input tensor through the model.

Parameters:

x (torch.Tensor) – input tensor

Returns:

the model output

Return type:

torch.Tensor

initialize() → None

Initializes the model weights using a scheme slightly modified for binary neural networks (BNNs).
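The documentation does not spell out what changes for BNNs. As a hedged illustration of where such a scheme plugs in, the sketch below walks all submodules and applies per-type initialization. `initialize_weights` is a hypothetical name, and the Kaiming-normal and unit-weight BatchNorm choices are standard PyTorch placeholders, not bitorch's documented behavior.

```python
import torch
from torch import nn


def initialize_weights(module: nn.Module) -> None:
    """Hypothetical init scheme; bitorch's own initialize() may differ."""
    for m in module.modules():  # visits the module itself and every submodule
        if isinstance(m, nn.Conv2d):
            # Placeholder choice: Kaiming-normal weights, zero bias.
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            # Placeholder choice: identity affine transform.
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)


net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
initialize_weights(net)
```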

model() → Module

Getter method for the underlying model.

Returns:

the main torch.nn.Module of this model

Return type:

Module
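The forward and model() entries suggest a wrapper in which the documented class holds an inner network and delegates to it. The sketch below is a plain `torch.nn.Module` standing in for a `Model` subclass; the `WrapperSketch` name, the `_model` attribute, and the assumption that `input_shape` starts with a batch dimension are all hypothetical.

```python
from typing import List

import torch
from torch import nn


class WrapperSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented interface of Model."""

    def __init__(self, input_shape: List[int], num_classes: int = 0) -> None:
        super().__init__()
        self._input_shape = input_shape
        self._num_classes = num_classes
        # Assumption: input_shape is (batch, channels, height, width).
        in_features = 1
        for dim in input_shape[1:]:
            in_features *= dim
        # Assumption: the actual network lives in an inner module.
        self._model = nn.Sequential(nn.Flatten(), nn.Linear(in_features, num_classes))

    def model(self) -> nn.Module:
        # Getter for the inner torch.nn.Module.
        return self._model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate the forward pass to the inner network.
        return self._model(x)


net = WrapperSketch(input_shape=[1, 1, 28, 28], num_classes=10)
out = net(torch.zeros(4, 1, 28, 28))
print(out.shape)  # torch.Size([4, 10])
```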

on_train_batch_end(layer: Module) → None

Used with the PyTorch Lightning on_train_batch_end callback.

Implement it to, e.g., clip weights after optimization. It is applied recursively to every submodule.

Parameters:

layer (nn.Module) – current layer
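One way to picture the recursive application is `torch.nn.Module.apply`, which calls a function on every submodule and then on the module itself. The hypothetical `clip_weights` below clamps the weights of linear and convolutional layers to [-1, 1], a common practice when training BNNs with latent real-valued weights; it illustrates what an implementation could do and is not bitorch's own code.

```python
import torch
from torch import nn


def clip_weights(layer: nn.Module) -> None:
    """Hypothetical on_train_batch_end body: clip latent weights to [-1, 1]."""
    if isinstance(layer, (nn.Linear, nn.Conv2d)):
        with torch.no_grad():  # in-place edit outside autograd
            layer.weight.clamp_(-1.0, 1.0)


net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    net[0].weight.fill_(3.0)  # simulate weights pushed out of range by an optimizer step
net.apply(clip_weights)  # nn.Module.apply visits every submodule recursively
```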