#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""logger.py: utility functions for logging.
Adapted from Jo, J. et al. (2022), left almost untouched.
"""
import os
from typing import Any, List, Optional, Tuple, Union
import torch
from easydict import EasyDict
from ccsd.src.utils.models_utils import get_nb_parameters
[docs]
class Logger:
    """Minimal file logger that can also echo each message to stdout.

    The target file is reopened for every message, so no file handle is kept
    open between calls. An optional lock object can be supplied to serialize
    writes coming from several processes or threads.
    """

    def __init__(self, filepath: str, mode: str, lock: Optional[Any] = None) -> None:
        """Initialize the Logger class.

        Args:
            filepath (str): the file where to write
            mode (str): 'w' to start from an empty file, 'a' to append to the existing content
            lock (Optional[Any], optional): pass a shared lock (any object exposing
                acquire() and release()) for multi process write access. Defaults to None.

        Raises:
            AssertionError: if mode is neither 'w' nor 'a'
        """
        if mode not in ("w", "a"):
            # Raised explicitly instead of `assert False`: assert statements are
            # removed under `python -O`, which would silently leave `self.mode`
            # unset. The exception type is kept for existing callers.
            raise AssertionError("Mode must be one of w or a")
        self.filepath = filepath
        self.mode = mode
        self.lock = lock
        # True once this instance has successfully opened the file. From then
        # on the file is opened in append mode so that mode 'w' truncates it
        # only once instead of on every message.
        self._file_started = False

    def __repr__(self) -> str:
        """Return the string representation of the Logger class.

        Returns:
            str: the string representation of the Logger class
        """
        return f"Logger(filepath={self.filepath}, mode={self.mode})"

    def log(self, str: str, verbose: bool = True) -> None:
        """Log a string to the file and optionally print it.

        Write errors are reported on stdout and do not propagate, so a failing
        log file never interrupts the caller.

        Args:
            str (str): string to log (the name shadows the builtin but is kept
                because it is part of the keyword interface)
            verbose (bool, optional): whether or not we print the message. Defaults to True.
        """
        lock = self.lock
        if lock is not None:
            lock.acquire()
        try:
            # Only the first successful open uses the configured mode.
            open_mode = "a" if self._file_started else self.mode
            with open(self.filepath, open_mode) as f:
                self._file_started = True
                f.write(str + "\n")
        except Exception as e:
            # Best effort: report the problem and carry on.
            print(e)
        finally:
            # Always give the lock back, whatever interrupted the write.
            if lock is not None:
                lock.release()
        if verbose:
            print(str)
[docs]
def set_log(
    config: EasyDict, is_train: bool = True, folder: str = "./"
) -> Tuple[str, str, str]:
    """Set the log folder name, log directory and checkpoint directory.

    The log directory is always created. The checkpoint directory is only
    created when training; its path is returned in both cases.

    Args:
        config (EasyDict): the config object (reads config.data.data and config.train.name)
        is_train (bool, optional): True if we are training, False if we are sampling. Defaults to True.
        folder (str, optional): the general saving folder. Defaults to "./".

    Returns:
        Tuple[str, str, str]: the name of the folder, the log directory and the checkpoint directory of the log
    """
    data = config.data.data
    log_folder_name = os.path.join(data, config.train.name)
    root = "logs_train" if is_train else "logs_sample"
    log_dir = os.path.join(folder, root, log_folder_name)
    # exist_ok=True replaces the former "isdir then makedirs" sequence, which
    # could fail when two processes created the same directory concurrently.
    os.makedirs(log_dir, exist_ok=True)
    ckpt_dir = os.path.join(folder, "checkpoints", data)
    if is_train:
        os.makedirs(ckpt_dir, exist_ok=True)
    print(100 * "-")
    print("Make Directory {} in Logs".format(log_folder_name))
    return log_folder_name, log_dir, ckpt_dir
[docs]
def check_log(log_folder_name: str, log_name: str, folder: str = "./") -> bool:
    """Check if a sampling log file exists.

    The file looked up is <folder>/logs_sample/<log_folder_name>/<log_name>.log.

    Args:
        log_folder_name (str): given log folder name
        log_name (str): given log name (without the ".log" extension)
        folder (str, optional): the general saving folder that contains
            "logs_sample". Defaults to "./", i.e. the current working directory.

    Returns:
        bool: True if the log file exists, False otherwise
    """
    filepath = os.path.join(folder, "logs_sample", log_folder_name, f"{log_name}.log")
    return os.path.isfile(filepath)
[docs]
def data_log(logger: Logger, config: EasyDict) -> None:
    """Write a one-line summary of the dataset settings to the logger.

    Args:
        logger (Logger): Logger object
        config (EasyDict): current configuration used (reads config.data and config.seed)
    """
    data_cfg = config.data
    summary = (
        f"[{data_cfg.data}] init={data_cfg.init} ({data_cfg.max_feat_num}) "
        f"seed={config.seed} batch_size={data_cfg.batch_size}"
    )
    logger.log(summary)
[docs]
def sde_log(logger: Logger, config_sde: EasyDict, is_cc: bool = False) -> None:
    """Write a one-line summary of the SDE settings to the logger.

    The x and adj components are always reported; the rank2 component is
    appended only for combinatorial complexes.

    Args:
        logger (Logger): Logger object
        config_sde (EasyDict): sde configuration
        is_cc (bool, optional): True if we are modelling with combinatorial complexes. Defaults to False.
    """

    def describe(label: str, sde) -> str:
        # One "(label:type)=(beta_min, beta_max) N=num_scales" entry.
        return (
            f"({label}:{sde.type})=({sde.beta_min:.2f}, {sde.beta_max:.2f}) "
            f"N={sde.num_scales}"
        )

    entries = [describe("x", config_sde.x), describe("adj", config_sde.adj)]
    # The x and adj entries are each followed by a single space.
    message = "".join(f"{entry} " for entry in entries)
    if is_cc:
        message += describe("rank2", config_sde.rank2)
    logger.log(message)
[docs]
def model_log(logger: Logger, config: EasyDict, is_cc: bool = False) -> None:
    """Write a one-line summary of the model settings to the logger.

    Args:
        logger (Logger): Logger object
        config (EasyDict): current configuration used (reads config.model)
        is_cc (bool, optional): True if we are modelling with combinatorial complexes. Defaults to False.
    """
    m = config.model
    # Header: the model names, then " : ", then the hyperparameters.
    pieces = [f"({m.x})+({m.adj}={m.conv},{m.num_heads})"]
    if is_cc:
        mask_desc = "no hodge mask" if not m.hodge_mask else "hodge mask"
        pieces.append(f"+({m.rank2}={mask_desc}, {m.num_layers_mlp} {m.cnum})")
    pieces.append(" : ")
    pieces.append(
        f"depth={m.depth} adim={m.adim} nhid={m.nhid} layers={m.num_layers} "
    )
    pieces.append(f"linears={m.num_linears} c=({m.c_init} {m.c_hid} {m.c_final})")
    logger.log("".join(pieces))
[docs]
def device_log(
    logger: Logger, device: Union[str, List[int], List[str], List[torch.device]]
) -> None:
    """Log the device(s) that will be used as detected by PyTorch.

    A list is treated as a list of CUDA devices: each entry is normalized to a
    "cuda..." string and its GPU name is queried from PyTorch. Any other value
    is logged as is.

    Args:
        logger (Logger): Logger object
        device (Union[str, List[int], List[str], List[torch.device]]): device(s) used as detected
    """
    print(100 * "-")
    if isinstance(device, list):
        # Bare indices (e.g. 0 or "0") become "cuda:0"; entries that already
        # mention cuda are kept unchanged.
        device_str_list = [
            str(dev) if "cuda" in str(dev) else f"cuda:{dev}" for dev in device
        ]
        # Pass a torch.device instead of parsing the index out of the string,
        # so that an un-indexed "cuda" entry no longer raises an IndexError.
        device_str_list_names = [
            torch.cuda.get_device_name(torch.device(d)) for d in device_str_list
        ]
        device_str = f"GPU: {device_str_list}\nGPU names: {device_str_list_names}"
    else:
        device_str = f"{device}"
    logger.log(f"Using device: {device_str}")
[docs]
def start_log(logger: Logger, config: EasyDict) -> None:
    """Log the opening banner: the data summary framed by two separator lines.

    Args:
        logger (Logger): Logger object
        config (EasyDict): configuration used
    """
    separator = "-" * 100
    logger.log(separator)
    data_log(logger, config)
    logger.log(separator)
[docs]
def train_log(logger: Logger, config: EasyDict, is_cc: bool = False) -> None:
    """Log configuration used for training.

    Logs the optimisation settings, then the model and SDE summaries, and
    finally a separator line.

    Args:
        logger (Logger): Logger object
        config (EasyDict): configuration used (reads config.train, config.model and config.sde)
        is_cc (bool, optional): True if we are modelling with combinatorial complexes;
            forwarded to model_log and sde_log so that the rank2 settings are
            logged as well. Defaults to False.
    """
    train_cfg = config.train
    logger.log(
        f"lr={train_cfg.lr} schedule={train_cfg.lr_schedule} ema={train_cfg.ema} "
        f"epochs={train_cfg.num_epochs} reduce={train_cfg.reduce_mean} eps={train_cfg.eps}"
    )
    model_log(logger, config, is_cc)
    sde_log(logger, config.sde, is_cc)
    logger.log(100 * "-")
[docs]
def sample_log(logger: Logger, config: EasyDict) -> None:
    """Write the sampling settings to the logger, followed by a separator line.

    Args:
        logger (Logger): Logger object
        config (EasyDict): configuration used (reads config.sampler and config.sample)
    """
    sampler_cfg, sample_cfg = config.sampler, config.sample
    pieces = [
        f"({sampler_cfg.predictor})+({sampler_cfg.corrector}): ",
        f"eps={sample_cfg.eps} denoise={sample_cfg.noise_removal} ",
        f"ema={sample_cfg.use_ema} ",
    ]
    # The Langevin corrector has extra parameters worth reporting.
    if sampler_cfg.corrector == "Langevin":
        pieces.append(f"|| snr={sampler_cfg.snr} seps={sampler_cfg.scale_eps} ")
        pieces.append(f"n_steps={sampler_cfg.n_steps} ")
    logger.log("".join(pieces))
    logger.log("-" * 100)
[docs]
def model_parameters_log(logger: Logger, models: List[torch.nn.Module]) -> None:
    """Log the parameter count of each model and the overall total.

    Args:
        logger (Logger): Logger object
        models (List[torch.nn.Module]): list of models.
    """
    # Gather every count (and the total) before anything is logged.
    per_model = [(net.__class__.__name__, get_nb_parameters(net)) for net in models]
    grand_total = sum(entry[1] for entry in per_model)

    separator = "-" * 100
    logger.log(separator)
    logger.log("\nNumber of parameters:\n")
    for name, count in per_model:
        logger.log(f"\t{name}: {count}\n")
    logger.log(f"\nTotal: {grand_total}\n")
    logger.log(separator)
[docs]
def time_log(logger: Logger, time_type: str, elapsed_time: float) -> None:
    """Log the elapsed training or sampling time, rounded to 3 decimals.

    Args:
        logger (Logger): Logger object
        time_type (str): type of time. Must be in ["train", "sample"].
        elapsed_time (float): elapsed time since the start of the training/sampling

    Raises:
        ValueError: raise an error if time_type is not in ["train", "sample"]
    """
    # Reject unknown types first so that the rest only handles valid values.
    if time_type not in ("train", "sample"):
        raise ValueError(f"time_type must be in ['train', 'sample'], not {time_type}")
    phase = "Training" if time_type == "train" else "Sampling"
    logger.log(f"{phase} time: {round(elapsed_time, 3)} seconds")