# Source code for bitorch.layers.debug_layers

from typing import Optional, Any
import torch
from .config import config


class _Debug(torch.nn.Module):
    def __init__(self, debug_interval: int = 100, num_outputs: int = 10, name: str = "Debug") -> None:
        """stores the debug settings and resets the forward counter.

        Args:
            debug_interval (int, optional): interval at which debug output shall be prompted. Defaults to 100.
            num_outputs (int, optional): number of weights/inputs that shall be debugged. Defaults to 10.
            name (str, optional): name of debug layer, only relevant for print debugging
        """
        super().__init__()
        self.name = name
        self._num_outputs = num_outputs
        self._debug_interval = debug_interval

        # counts how many tensors were handed to _debug_tensor so far
        self._forward_counter = 0

    def _debug(self, x: torch.Tensor) -> None:
        """debug hook, does nothing in this base class and is overridden by subclasses

        Args:
            x (torch.Tensor): tensor to be debugged
        """
        return None

    def _debug_tensor(self, debug_tensor: torch.Tensor) -> None:
        """calls the debug hook for the given tensor on every debug_interval-th call, if debugging is activated

        Args:
            debug_tensor (torch.Tensor): tensor to be debugged
        """
        # the interval is only evaluated if debugging is activated in the config (short circuit)
        output_due = config.debug_activated and self._forward_counter % self._debug_interval == 0
        if output_due:
            self._debug(debug_tensor)

        # the counter advances on every call, regardless of whether output was produced
        self._forward_counter += 1


class _PrintDebug(_Debug):
    def _debug(self, debug_tensor: torch.Tensor) -> None:
        """prints the layer name followed by at most the first num_outputs entries of debug_tensor

        Entries are counted along the first dimension of the tensor. Tensors without any
        dimension (scalars) are printed unchanged.

        Args:
            debug_tensor (torch.Tensor): tensor to be debugged
        """
        # len() raises a TypeError for 0-d tensors, so scalars must bypass the length check
        if debug_tensor.dim() == 0 or len(debug_tensor) < self._num_outputs:
            output = debug_tensor
        else:
            output = debug_tensor[: self._num_outputs]
        print(self.name, ":", output)


class _GraphicalDebug(_Debug):
    def __init__(
        self,
        figure: Optional[object] = None,
        images: Optional[list] = None,
        debug_interval: int = 100,
        num_outputs: int = 10,
    ) -> None:
        """Debugs the given layer by drawing weights/inputs in given matplotlib plot images.

        Args:
            figure (object): figure to draw in
            images (list): list of images to update with given data
            debug_interval (int, optional): interval at which debug output shall be prompted. Defaults to 100.
            num_outputs (int, optional): number of weights/inputs that shall be debugged. Defaults to 10.

        Raises:
            ValueError: raised if number of images does not match desired number of outputs.
        """
        super(_GraphicalDebug, self).__init__(debug_interval, num_outputs)
        self.set_figure(figure)
        self.set_images(images)

    def set_figure(self, figure: Optional[object] = None) -> None:
        """setter for figure object

        Args:
            figure (object): the figure object
        """
        self._figure = figure

    def set_images(self, images: Optional[list] = None) -> None:
        """setter for images list

        Args:
            images (list): list of image objects to output the graphical information to

        Raises:
            ValueError: raised if number of images does not match desired number of outputs.
        """
        # validate before storing, so a rejected list never replaces the current images
        if images is not None and len(images) != self._num_outputs:
            raise ValueError(
                f"number of given images ({len(images)}) must match "
                f"number of desired outputs ({self._num_outputs})!"
            )
        self._images = images

    def _debug(self, debug_tensor: torch.Tensor) -> None:
        """draws graphical debug information about given debug tensor into figure

        The leading dimensions of the tensor enumerate the filters: a 2d tensor is one filter,
        for 3d/4d tensors the first two dimensions and for 5d tensors the first three dimensions
        are iterated. Tensors of any other dimensionality yield no filters. Each filter is
        shifted to a minimum of 0 and scaled to a maximum of 1 before it is handed to its image.

        Args:
            debug_tensor (torch.Tensor): tensor to be debugged

        Raises:
            ValueError: raised if either no figure or no images were given
        """
        if self._figure is None or self._images is None:
            raise ValueError("no subplot given to debug into!")
        # work on a detached copy, the normalization below is done in place
        debug_tensor = debug_tensor.clone().detach()
        dimensionality = len(debug_tensor.shape)
        # depending of dimensionality select the filters to be drawn to the images;
        # flattening the leading dimensions keeps the order of the nested index loops
        filters: Any = []
        if dimensionality == 2:
            filters = debug_tensor.unsqueeze(0)
        elif dimensionality == 3 or dimensionality == 4:
            filters = debug_tensor.flatten(0, 1)
        elif dimensionality == 5:
            filters = debug_tensor.flatten(0, 2)
        # only as many filters as there are images can be drawn
        filters = filters[: len(self._images)]
        # normalize all filters
        for image, fltr in zip(self._images, filters):
            fltr -= torch.min(fltr)
            value_range = torch.max(fltr)
            # a constant filter has a range of zero, dividing would fill the image with NaN
            if value_range > 0:
                fltr /= value_range
            image.set_data(fltr)
        self._figure.canvas.draw()


"""
The classes above are internal. Use the classes below for debugging.
"""


class InputPrintDebug(_PrintDebug):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """forwards the given tensor without modification, debug output if activated

        Args:
            x (torch.Tensor): tensor to be debugged

        Returns:
            torch.Tensor: input tensor x
        """
        self._debug_tensor(x)
        return x
class InputGraphicalDebug(_GraphicalDebug):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """forwards the given tensor without modification, debug output if activated

        Args:
            x (torch.Tensor): tensor to be debugged

        Returns:
            torch.Tensor: input tensor x
        """
        self._debug_tensor(x)
        return x
class WeightPrintDebug(_PrintDebug):
    def __init__(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> None:
        """stores given module

        Args:
            module (torch.nn.Module): module the weights of which shall be debugged
            *args (Any): positional arguments passed on to the print debug base class
            **kwargs (Any): keyword arguments passed on to the print debug base class
        """
        super(WeightPrintDebug, self).__init__(*args, **kwargs)
        self._debug_module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forwards the input tensor through the debug module and outputs debug information about the given
        module's weights.

        Args:
            x (torch.Tensor): the tensor to be forwarded through the debug module.

        Returns:
            torch.Tensor: the output of the debug module
        """
        x = self._debug_module(x)

        # the weights are copied so that debugging never touches the module's own parameter
        weight = self._debug_module.weight.clone()  # type: ignore
        # check if given module is a quantized module
        if hasattr(self._debug_module, "quantize"):
            weight = self._debug_module.quantize(weight)  # type: ignore
        self._debug_tensor(weight)
        return x
class WeightGraphicalDebug(_GraphicalDebug):
    def __init__(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> None:
        """stores given module

        Args:
            module (torch.nn.Module): module the weights of which shall be debugged
            *args (Any): positional arguments passed on to the graphical debug base class
            **kwargs (Any): keyword arguments passed on to the graphical debug base class
        """
        super(WeightGraphicalDebug, self).__init__(*args, **kwargs)
        self._debug_module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forwards the input tensor through the debug module and outputs debug information about the given
        module's weights.

        Args:
            x (torch.Tensor): the tensor to be forwarded through the debug module.

        Returns:
            torch.Tensor: the output of the debug module
        """
        x = self._debug_module(x)

        # the weights are copied so that debugging never touches the module's own parameter
        weight = self._debug_module.weight.clone()  # type: ignore
        # check if given module is a quantized module
        if hasattr(self._debug_module, "quantize"):
            weight = self._debug_module.quantize(weight)  # type: ignore
        self._debug_tensor(weight)
        return x
class ShapePrintDebug(_PrintDebug):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """prints the shape of x, leaves x untouched

        Args:
            x (torch.Tensor): the tensor to be debugged

        Returns:
            torch.Tensor: input tensor x
        """
        # the shape is wrapped in a tensor so it can go through the regular tensor debug path
        self._debug_tensor(torch.tensor(x.shape))
        return x