Source code for bitorch.models.model_hub

from pathlib import Path
from typing import Dict, Any, Union, Tuple
import numbers
import pandas
import logging
import warnings
import torch
import base64
import hashlib
from torchvision.datasets.utils import download_url


def _md5_hash_file(path: Path) -> Any:
    """Compute the MD5 hash of the file at ``path``.

    The file is consumed in 64 KiB chunks, so it never has to be held in memory at once.

    Args:
        path (Path): file to hash

    Returns:
        the ``hashlib`` md5 object after every byte of the file was fed into it
        (call ``.digest()`` / ``.hexdigest()`` on it to get the value).
    """
    chunk_size = 64 * 1024
    digest = hashlib.md5()
    with path.open("rb") as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:  # b"" signals end of file
                break
            digest.update(block)
    return digest


def _digest_file(path: Union[Path, str]) -> str:
    """Return the MD5 digest of the file at ``path`` as a base64-encoded ASCII string.

    Args:
        path (Union[Path, str]): file to hash

    Returns:
        str: base64 encoding of the raw 16-byte MD5 digest
    """
    raw_digest = _md5_hash_file(Path(path)).digest()
    encoded = base64.b64encode(raw_digest)
    return encoded.decode("ascii")


def convert_dtypes(data: dict) -> dict:
    """Normalize the values of a dict so they can be compared across dataframes and csv files.

    Numeric values (including bools) are kept as they are; every other value is replaced
    by its string representation. Lists are turned into tuples first, so a list value is
    stored as e.g. "(1, 2)".

    Note: ``data`` is modified in place and the same dict object is returned.

    Args:
        data (dict): dict whose values should be converted

    Returns:
        dict: the (same) dict with converted values
    """
    for name in list(data):
        raw = data[name]
        candidate = tuple(raw) if isinstance(raw, list) else raw
        if isinstance(candidate, (numbers.Number, bool)):
            continue
        data[name] = str(candidate)
    return data
def get_matching_row(version_table: pandas.DataFrame, model_kwargs: dict) -> Union[pandas.DataFrame, None]:
    """Search the version table for the row(s) matching ``model_kwargs``.

    The values of ``model_kwargs`` are normalized with ``convert_dtypes`` before comparing.
    The normalization is applied to a copy, so the caller's dict is left untouched.

    Args:
        version_table (pandas.DataFrame): the dataframe to search in
        model_kwargs (dict): the configuration to search for. It does not need a key-value pair
            for every column of version_table, i.e. its keys can be a subset of the columns.

    Returns:
        pandas.DataFrame: rows whose ``model_kwargs.keys()`` columns equal the ``model_kwargs``
            values (the original index labels of version_table are kept). None if no row
            matches, or if some key of ``model_kwargs`` is not a column of version_table
            (e.g. for an empty table).
    """
    # convert_dtypes rewrites values in place; work on a copy so the caller's dict is unchanged
    converted_kwargs = convert_dtypes(dict(model_kwargs))
    columns = list(converted_kwargs)
    # a key without a matching column cannot match any row (avoids a KeyError on e.g. an empty table)
    if any(column not in version_table.columns for column in columns):
        return None
    with warnings.catch_warnings():
        # catch_warnings() alone only saves/restores the warning filters; the explicit
        # simplefilter is what actually silences warnings raised while building the Series
        # and comparing it against the table
        warnings.simplefilter("ignore")
        expected_values = pandas.Series(converted_kwargs)
        row_matches = (version_table[columns] == expected_values).all(axis=1)
    existing_row = version_table[row_matches]
    if existing_row.empty:
        return None
    return existing_row
def get_model_path(version_table: pandas.DataFrame, model_kwargs: dict) -> Tuple[str, str]:
    """Find the model hub artifact matching ``model_kwargs`` in the version table.

    Args:
        version_table (pandas.DataFrame): version table with model configurations and
            corresponding model hub versions; must contain the columns "model_hub_url"
            and "model_digest"
        model_kwargs (dict): model configuration to search for

    Raises:
        RuntimeError: thrown if no matching model can be found in version table

    Returns:
        Tuple[str, str]: url of the matching model hub artifact and its stored digest
            (taken from the first matching row)
    """
    matching_row = get_matching_row(version_table, model_kwargs)
    if matching_row is None:
        raise RuntimeError(
            f"No matching model found in hub with configuration: {model_kwargs}! You can train"
            " it yourself or try to load it from a local checkpoint!"
        )
    # the filtered frame keeps the version table's original index labels, so a label lookup
    # like ``[0]`` only works when the match is the table's first row -> select by position
    first_match = matching_row.iloc[0]
    model_url = first_match["model_hub_url"]
    model_digest = first_match["model_digest"]
    return model_url, model_digest
def load_from_hub(
    model_version_table_path: str, download_path: str = "bitorch_models", **model_kwargs: Any
) -> Any:
    """Load the model matching the configuration in ``model_kwargs`` from the model hub.

    The model file is only downloaded if it is not present in ``download_path`` yet or if its
    digest differs from the one stored in the version table.

    Args:
        model_version_table_path (str): path to model version table on model hub
        download_path (str, optional): directory to store the downloaded files in.
            Defaults to "bitorch_models".
        **model_kwargs: model configuration used to look up the model in the version table

    Raises:
        RuntimeError: if no model in the version table matches ``model_kwargs``

    Returns:
        the loaded artifact; pytorch lightning checkpoints are converted to a plain state dict
    """
    download_dir = Path(download_path)
    download_dir.mkdir(parents=True, exist_ok=True)
    version_table = download_version_table(model_version_table_path)
    model_path, model_digest = get_model_path(version_table, model_kwargs)
    # the last url component is used as local file name and as the md5 passed to download_url
    model_checksum = model_path.split("/")[-1]
    model_local_path = download_dir / model_checksum
    if not model_local_path.exists() or _digest_file(model_local_path) != model_digest:
        logging.info("downloading model...")
        download_url(model_path, model_local_path.parent, model_local_path.name, model_checksum)
        logging.info("Model downloaded!")
    else:
        logging.info(f"Using already downloaded model at {model_local_path}")
    # NOTE(review): torch.load unpickles the file, which can execute code; only load
    # artifacts from a trusted hub.
    artifact = torch.load(model_local_path, map_location="cpu")
    # pytorch lightning checkpoints keep the weights under "state_dict"; a plain torch state
    # dict is itself a dict (OrderedDict), so the key check is needed to tell them apart
    if isinstance(artifact, dict) and "state_dict" in artifact:
        return lightning_checkpoint_to_state_dict(artifact)  # type: ignore
    return artifact
def lightning_checkpoint_to_state_dict(artifact: Dict[Any, Any]) -> Dict[Any, Any]:
    """Convert a pytorch lightning checkpoint to a plain torch state dict.

    Args:
        artifact (Dict[Any, Any]): checkpoint dict holding the weights under its "state_dict"
            key; every key of that state dict must start with "model."

    Raises:
        ValueError: if a key of the checkpoint's state dict does not start with "model."

    Returns:
        Dict[Any, Any]: state dict for the model, with the leading "model." removed from
            every key (e.g. "model._model.arg" -> "_model.arg")
    """
    prefix = "model."
    state_dict = artifact["state_dict"]
    # explicit check instead of ``assert``: asserts are stripped when python runs with -O
    for key in state_dict:
        if not key.startswith(prefix):
            raise ValueError(f"Unexpected malformed state dict key {key}.")
    return {key[len(prefix):]: value for key, value in state_dict.items()}
def download_version_table(model_table_path: str, no_exception: bool = False) -> pandas.DataFrame:
    """Download the newest model version table from the model hub.

    The table is stored as "bitorch_model_version_table.csv" in the system temp directory.
    Any previous copy is deleted first. torchvision's ``download_url`` skips the download when
    the target file already exists and no md5 is given, which would otherwise keep serving a
    stale table.

    Args:
        model_table_path (str): path on hub to model version table
        no_exception (bool, optional): whether to return an empty table instead of raising when
            the version table cannot be downloaded or read. Defaults to False.

    Raises:
        Exception: thrown if the version table cannot be downloaded / read and no_exception is
            False (chained to the original error)

    Returns:
        pandas.DataFrame: model version table (empty if retrieval failed and no_exception is True)
    """
    import tempfile

    logging.info("downloading model version table from hub...")
    table_path = Path(tempfile.gettempdir()) / "bitorch_model_version_table.csv"
    try:
        if table_path.exists():
            table_path.unlink()
        download_url(model_table_path, str(table_path.parent), table_path.name)
        version_table = pandas.read_csv(table_path)
    except Exception as e:
        logging.info(f"could not retrieve model version table from {model_table_path}: {e}")
        if no_exception:
            logging.info("creating empty table...")
            return pandas.DataFrame()
        raise Exception(e) from e
    return version_table