Source code for bitorch.layers.extensions.layer_recipe

import typing
from dataclasses import dataclass
from typing import TypeVar, Tuple, Any, Dict

if typing.TYPE_CHECKING:
    from .layer_container import LayerContainer

T = TypeVar("T")


[docs]@dataclass(eq=False, frozen=True) class LayerRecipe: """ Data class to store a layer object and the arguments used to create it. It allows to create other implementations of the same layer later on. """ layer: "LayerContainer" args: Tuple[Any, ...] kwargs: Dict[str, Any]
[docs] def get_positional_arg(self, pos: int) -> Any: """ Get a positional argument from the stored args. Args: pos: the position of the argument if given as a positional arg Returns: the argument value retrieved """ return self.args[pos]
[docs] def get_arg(self, pos: int, key: str, default: T) -> T: """ Get an argument from the stored args or kwargs. Args: pos: the position of the argument if given as a positional arg key: the name of the argument default: the default value of the argument Returns: the argument value retrieved """ if len(self.args) > pos: return self.args[pos] return self.kwargs.get(key, default)