# Source code for bitorch.runtime_mode
from enum import Enum
from functools import total_ordering
from types import TracebackType
from typing import Union, Any, Optional, Type, List
import bitorch
# Names exported by ``from bitorch.runtime_mode import *``.
__all__ = [
    "RuntimeMode",
    "runtime_mode_type",
    "change_mode",
    "pause_wrapping",
]

# Either a single RuntimeMode member or a plain int (e.g. a mask of combined mode values).
runtime_mode_type = Union["RuntimeMode", int]
@total_ordering
class RuntimeMode(Enum):
    """
    Enum for BITorch modes:

    - DEFAULT: use the default implementation of all layers
    - CPU: use layer implementations for inference on CPU
    - GPU: use layer implementations for inference on GPU
    - INFERENCE_AUTO: use an automatic layer that uses the fastest implementation available (not recommended)
    - RAW: while in this mode, new layers are created as the default implementation BUT without wrapping, so they
      cannot be switched to other layers later on (it does not influence already wrapped layers)

    Except for RAW (0), every value is a distinct power of two, so modes can be combined into a plain ``int`` bit
    mask (``RuntimeMode.CPU + RuntimeMode.GPU == 6``) and checked with :meth:`mode_compatible`. Members compare
    equal to, order against, and hash like their integer value.
    """

    RAW = 0
    DEFAULT = 1
    CPU = 2
    GPU = 4
    INFERENCE_AUTO = 8

    def __add__(self, other: runtime_mode_type) -> runtime_mode_type:
        """Combine this mode with another mode or an int mode mask.

        Args:
            other: a RuntimeMode member or an int mask of mode values.

        Returns:
            ``self`` if ``other`` has the same value, otherwise the combined mask as an ``int``
            (``NotImplemented`` for operands that are neither a RuntimeMode nor an int).
        """
        if not isinstance(other, (RuntimeMode, int)):
            return NotImplemented
        other_value = self._to_int(other)
        if self.value == other_value:
            return self
        # Bitwise OR instead of ``+``: adding a mode that is already contained in ``other``
        # (e.g. ``CPU + (CPU + GPU)``) must not carry into another mode's bit.
        return self.value | other_value

    def __radd__(self, other: runtime_mode_type) -> runtime_mode_type:
        """Support ``int_mask + mode`` so chains like ``CPU + GPU + DEFAULT`` (first sum is an int) work."""
        return self.__add__(other)

    @staticmethod
    def available_values() -> List["RuntimeMode"]:
        """Return all RuntimeMode members as a list, in definition order."""
        return list(RuntimeMode.__members__.values())

    @staticmethod
    def list_of_names() -> List[str]:
        """Return the (upper-case) names of all RuntimeMode members as a list."""
        return list(RuntimeMode.__members__.keys())

    @staticmethod
    def _max_val() -> int:
        """Return the sum of all member values, i.e. the mask with every mode combined."""
        return sum(member.value for member in RuntimeMode.__members__.values())

    @staticmethod
    def is_single_mode(mode: runtime_mode_type) -> bool:
        """Return True if ``mode`` equals the value of exactly one defined member."""
        return any(x.value == mode for x in RuntimeMode.__members__.values())

    @staticmethod
    def is_combined_mode(mode: runtime_mode_type) -> bool:
        """Return True if ``mode`` lies in the valid mask range ``[0, _max_val()]``."""
        # Inclusive upper bound: _max_val() is the mask of all modes combined, which is itself a valid combination.
        return 0 <= RuntimeMode._to_int(mode) <= RuntimeMode._max_val()

    def __lt__(self, other: Any) -> bool:
        """Order by integer value; ``other`` may be a RuntimeMode or an int."""
        if not isinstance(other, RuntimeMode) and not isinstance(other, int):
            return NotImplemented
        return self.value < self._to_int(other)

    def __eq__(self, other: Any) -> bool:
        """Compare by integer value; ``other`` may be a RuntimeMode or an int."""
        if not isinstance(other, RuntimeMode) and not isinstance(other, int):
            return NotImplemented
        return self.value == self._to_int(other)

    def __hash__(self) -> int:
        # Defining __eq__ without __hash__ implicitly sets __hash__ to None, which made members unusable
        # as dict keys / set elements. Hash by value so that ``a == b`` implies ``hash(a) == hash(b)``,
        # including against the int a member compares equal to.
        return hash(self.value)

    def __str__(self) -> str:
        """Return the lower-case member name (the format accepted by :meth:`from_string`)."""
        return self.name.lower()

    @staticmethod
    def _to_int(mode: runtime_mode_type) -> int:
        """Return the value of a RuntimeMode member; any other argument is returned unchanged."""
        if isinstance(mode, RuntimeMode):
            return mode.value
        return mode

    @staticmethod
    def from_string(level: str) -> "RuntimeMode":
        """Parse a mode name case-insensitively.

        Raises:
            KeyError: if ``level`` is not one of the known mode names.
        """
        return {
            "raw": RuntimeMode.RAW,
            "default": RuntimeMode.DEFAULT,
            "cpu": RuntimeMode.CPU,
            "gpu": RuntimeMode.GPU,
            "inference_auto": RuntimeMode.INFERENCE_AUTO,
        }[level.lower()]

    @staticmethod
    def mode_compatible(required_mode: "RuntimeMode", provided_modes: runtime_mode_type) -> bool:
        """Return True if ``required_mode`` is covered by ``provided_modes``.

        RAW on either side is always compatible; otherwise the two masks must share at least one bit.
        """
        if required_mode == RuntimeMode.RAW.value or provided_modes == RuntimeMode.RAW.value:
            return True
        return bool(RuntimeMode._to_int(required_mode) & RuntimeMode._to_int(provided_modes))

    def is_supported_by(self, provided_modes: runtime_mode_type) -> bool:
        """Return True if this mode is supported by ``provided_modes`` (RAW is supported by anything)."""
        if self._to_int(self) == RuntimeMode.RAW.value:
            return True
        return self.mode_compatible(self, provided_modes)
class _PauseWrapping:
    """Context manager that sets ``bitorch.mode`` to ``RuntimeMode.RAW`` for the duration of a ``with`` block.

    The mode that was active when the block was entered is restored on exit, also when the body raises
    (``__exit__`` returns None, so exceptions are not suppressed).
    """

    def __init__(self) -> None:
        # One saved mode per active ``with``: capturing on entry (not at construction) means a manager created
        # earlier, reused, or nested restores the mode that was actually active right before each block.
        self._saved_modes: List[Any] = []

    def __enter__(self) -> "_PauseWrapping":
        self._saved_modes.append(bitorch.mode)
        bitorch.mode = RuntimeMode.RAW
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        bitorch.mode = self._saved_modes.pop()
class _SafeModeChanger:
    """Context manager that sets ``bitorch.mode`` to ``new_mode`` for the duration of a ``with`` block.

    The mode that was active when the block was entered is restored on exit, also when the body raises
    (``__exit__`` returns None, so exceptions are not suppressed).

    Args:
        new_mode: the RuntimeMode member to activate; must be a single (not combined) mode.
    """

    def __init__(self, new_mode: RuntimeMode) -> None:
        assert new_mode.is_single_mode(new_mode)
        self._new_mode = new_mode
        # One saved mode per active ``with``: capturing on entry (not at construction) means a manager created
        # earlier, reused, or nested restores the mode that was actually active right before each block.
        self._saved_modes: List[Any] = []

    def __enter__(self) -> "_SafeModeChanger":
        self._saved_modes.append(bitorch.mode)
        bitorch.mode = self._new_mode
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        bitorch.mode = self._saved_modes.pop()
# Public lower-case aliases (listed in __all__) for the two context-manager classes, e.g.
# ``with pause_wrapping(): ...`` or ``with change_mode(RuntimeMode.CPU): ...``.
pause_wrapping = _PauseWrapping
change_mode = _SafeModeChanger