bitorch.models.base.Model
- class bitorch.models.base.Model(input_shape: List[int], num_classes: int = 0)
Base class for bitorch models.
Methods
__init__ – Initializes internal Module state, shared by both nn.Module and ScriptModule.
add_argparse_arguments – Allows additions to the argument parser if required, e.g. to add a layer count.
convert
forward – Forwards the input tensor through the model.
from_pretrained
Initializes model weights a little differently for BNNs.
Getter method for the model.
Is used with the PyTorch Lightning on_train_batch_end callback.
Attributes
name
version_table_url
- __init__(input_shape: List[int], num_classes: int = 0) → None
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- static add_argparse_arguments(parser: ArgumentParser) → None
Allows subclasses to add arguments to the argument parser if required, e.g. to add a layer count.
Note: the variable names that argparse infers for these additional CLI arguments are passed as keyword arguments to the constructor of this class, so they must match the constructor's parameter names.
- Parameters:
parser (ArgumentParser) – the argument parser
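The keyword-argument contract above can be illustrated with the standard library alone. This is a minimal sketch, not bitorch code: `HypotheticalDeepModel` and its `num_layers` parameter are invented for illustration, and it assumes the parser holds only the model's own arguments so that every parsed value can be forwarded to the constructor. It relies on argparse turning `--num-layers` into the variable name `num_layers`.

```python
from argparse import ArgumentParser
from typing import List


class HypotheticalDeepModel:
    """Hypothetical stand-in for a Model subclass (not part of bitorch)."""

    def __init__(self, input_shape: List[int], num_classes: int = 0, num_layers: int = 18) -> None:
        self.input_shape = input_shape
        self.num_classes = num_classes
        # receives the value parsed from --num-layers
        self.num_layers = num_layers

    @staticmethod
    def add_argparse_arguments(parser: ArgumentParser) -> None:
        # argparse infers the variable name "num_layers" from "--num-layers",
        # which must match the constructor parameter above
        parser.add_argument("--num-layers", type=int, default=18, choices=[18, 34, 50])


parser = ArgumentParser()
HypotheticalDeepModel.add_argparse_arguments(parser)
args = parser.parse_args(["--num-layers", "34"])

# pass the inferred variable names as keyword arguments to the constructor
model = HypotheticalDeepModel(input_shape=[1, 3, 32, 32], num_classes=10, **vars(args))
print(model.num_layers)  # 34
```

If the option string and the constructor parameter disagree (e.g. `--layers` vs. `num_layers`), the `**vars(args)` call raises a `TypeError` for the unexpected keyword, which is why the note above matters.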
- forward(x: Tensor) → Tensor
Forwards the input tensor through the model.
- Parameters:
x (torch.Tensor) – input tensor
- Returns:
the model output
- Return type:
torch.Tensor
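A common way to use a base class like this is to build the network in the subclass constructor and let `forward` pass the input through it. The sketch below assumes that pattern; `SketchModel`, `TinyClassifier`, and the `_model` attribute are hypothetical stand-ins mirroring the documented signatures, not bitorch's actual implementation. It also assumes `input_shape` includes a leading batch dimension (e.g. `[1, 1, 28, 28]`).

```python
from typing import List

import torch
from torch import Tensor, nn


class SketchModel(nn.Module):
    """Hypothetical stand-in mirroring the documented Model signature (not bitorch's code)."""

    def __init__(self, input_shape: List[int], num_classes: int = 0) -> None:
        super().__init__()
        self._input_shape = input_shape
        self._num_classes = num_classes
        # assumption: subclasses replace this placeholder with their network
        self._model: nn.Module = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        # forwards the input tensor through the wrapped network
        return self._model(x)


class TinyClassifier(SketchModel):
    """Hypothetical subclass: flatten the input, then one linear layer."""

    def __init__(self, input_shape: List[int], num_classes: int = 0) -> None:
        super().__init__(input_shape, num_classes)
        # assumption: input_shape is [batch, channels, height, width]
        in_features = 1
        for dim in input_shape[1:]:
            in_features *= dim
        self._model = nn.Sequential(nn.Flatten(), nn.Linear(in_features, num_classes))


model = TinyClassifier(input_shape=[1, 1, 28, 28], num_classes=10)
out = model(torch.zeros(4, 1, 28, 28))
print(out.shape)  # torch.Size([4, 10])
```

Calling `model(x)` rather than `model.forward(x)` goes through `nn.Module.__call__`, which also runs any registered hooks before dispatching to `forward`.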