Source code for ccsd.src.utils.models_utils

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""models_utils.py: utility functions related to the models.
"""

from functools import lru_cache
from typing import Sequence, Union

import torch


[docs] def get_model_device(model: Union[torch.nn.Module, torch.nn.DataParallel]) -> str: """Get the the device on which the model is loaded ("cpu", "cuda", etc?) Args: model (Union[torch.nn.Module, torch.nn.DataParallel]): Pytorch model Returns: str: device on which the model is loaded """ return next(model.parameters()).device.type
[docs] def get_nb_parameters(model: torch.nn.Module) -> int: """Get the number of parameters of the model. Args: model (torch.nn.Module): model. Returns: int: number of parameters of the model. """ return sum(p.numel() for p in model.parameters() if p.requires_grad)
[docs] @lru_cache(maxsize=64) def get_ones_cache(shape: Sequence[int], device: str) -> torch.Tensor: """Cached function to get a tensor of ones of the given shape and device. Args: shape (Sequence[int]): shape of the tensor device (str): device on which the tensor should be allocated Returns: torch.Tensor: tensor of ones of the given shape and device """ return torch.ones(shape, dtype=torch.float32, device=device)
[docs] def get_ones(shape: Sequence[int], device: str) -> torch.Tensor: """Function to get a tensor of ones of the given shape and device. Call the cached version of the function and clone it. Args: shape (Sequence[int]): shape of the tensor device (str): device on which the tensor should be allocated Returns: torch.Tensor: tensor of ones of the given shape and device """ return get_ones_cache(shape, device).clone()