Source code for ccsd.src.utils.loader

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""loader.py: code for loading the model, the optimizer, the scheduler, the loss function, etc

Adapted from Jo, J. et al. (2022)
"""

import os
import random
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import networkx as nx
import numpy as np
import torch
from easydict import EasyDict
from toponetx.classes.combinatorial_complex import CombinatorialComplex
from torch.utils.data import DataLoader

from ccsd.src.evaluation.mmd import gaussian, gaussian_emd, gaussian_tv
from ccsd.src.losses import get_sde_loss_fn, get_sde_loss_fn_cc
from ccsd.src.models.ScoreNetwork_A import ScoreNetworkA
from ccsd.src.models.ScoreNetwork_A_Base_CC import ScoreNetworkA_Base_CC
from ccsd.src.models.ScoreNetwork_A_CC import ScoreNetworkA_CC
from ccsd.src.models.ScoreNetwork_F import ScoreNetworkF
from ccsd.src.models.ScoreNetwork_X import ScoreNetworkX, ScoreNetworkX_GMH
from ccsd.src.sde import SDE, VESDE, VPSDE, subVPSDE
from ccsd.src.solver import S4_solver, get_pc_sampler
from ccsd.src.utils.cc_utils import get_rank2_dim
from ccsd.src.utils.data_loader import dataloader, dataloader_cc
from ccsd.src.utils.data_loader_mol import dataloader_mol, dataloader_mol_cc
from ccsd.src.utils.ema import ExponentialMovingAverage


def load_seed(seed: int) -> int:
    """Seed the random number generators of ``random``, NumPy and PyTorch.

    When CUDA is available, the CUDA generators are seeded too and cuDNN is
    put in deterministic, non-benchmarking mode so that runs are reproducible.

    Args:
        seed (int): seed to use

    Returns:
        int: the seed that was applied
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return seed

    # NOTE(review): the cuDNN flags are assumed to belong to the CUDA-only
    # branch (they only affect GPU kernels) -- confirm against the repository.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return seed
def load_device() -> Union[str, List[int]]:
    """Pick the device(s) to run on.

    Returns:
        Union[str, List[int]]: the indices of all visible CUDA devices when
            CUDA is available, otherwise the string ``"cpu"``
    """
    if not torch.cuda.is_available():
        return "cpu"
    n_gpus = torch.cuda.device_count()
    return list(range(n_gpus))
def load_model(params: Dict[str, Any]) -> torch.nn.Module:
    """Instantiate the Score Network described by ``params``.

    The ``"model_type"`` entry selects the class; every other entry is passed
    to its constructor as a keyword argument. ``params`` itself is not modified.

    Args:
        params (dict): parameters to use

    Raises:
        ValueError: raise an error if the model is unknown

    Returns:
        torch.nn.Module: Score Network model to use
    """
    # Work on a copy so that popping "model_type" does not alter the caller's dict
    kwargs = params.copy()
    model_type = kwargs.pop("model_type", None)

    if model_type == "ScoreNetworkX":
        return ScoreNetworkX(**kwargs)
    if model_type == "ScoreNetworkX_GMH":
        return ScoreNetworkX_GMH(**kwargs)
    if model_type == "ScoreNetworkA":
        return ScoreNetworkA(**kwargs)
    if model_type == "ScoreNetworkA_Base_CC":
        return ScoreNetworkA_Base_CC(**kwargs)
    if model_type == "ScoreNetworkA_CC":
        return ScoreNetworkA_CC(**kwargs)
    if model_type == "ScoreNetworkF":
        return ScoreNetworkF(**kwargs)

    raise ValueError(
        f"Model Name <{model_type}> is unknown. Please select from [ScoreNetworkX, ScoreNetworkX_GMH, ScoreNetworkA, ScoreNetworkA_CC, ScoreNetworkA_Base_CC, ScoreNetworkF]"
    )
def load_model_optimizer(
    params: Dict[str, Any],
    config_train: EasyDict,
    device: Union[str, List[str], List[int]],
) -> Tuple[
    Union[torch.nn.Module, torch.nn.DataParallel],
    torch.optim.Optimizer,
    Optional[torch.optim.lr_scheduler.LRScheduler],
]:
    """Return the model, the optimizer and the scheduler in function of the parameters

    The model is built with ``load_model`` and moved to ``device``. When
    ``device`` is a list with more than one entry the model is wrapped in
    ``torch.nn.DataParallel``; in every list case the model is placed on the
    first entry. An Adam optimizer is created on the model's parameters and,
    if ``config_train.lr_schedule`` is truthy, an exponential learning-rate
    scheduler is attached to it.

    Args:
        params (Dict[str, Any]): model parameters
        config_train (EasyDict): configuration for training
            (reads ``lr``, ``weight_decay``, ``lr_schedule`` and, when the
            schedule is enabled, ``lr_decay``)
        device (Union[str, List[str], List[int]]): device to use. List entries
            may be bare CUDA indices (``0`` or ``"0"``) or complete device
            specifications (``"cuda:0"``, ``"cpu"``, ``torch.device`` objects).

    Raises:
        AssertionError: if ``device`` is an empty list or contains an entry
            that is not an int, a str or a ``torch.device``

    Returns:
        Tuple[Union[torch.nn.Module, torch.nn.DataParallel], torch.optim.Optimizer, Optional[torch.optim.lr_scheduler.LRScheduler]]:
            the model, the optimizer and the scheduler (``None`` when no
            learning-rate schedule is configured)
    """
    model = load_model(params)

    if isinstance(device, list):
        assert len(device) > 0, "At least one device must be provided"
        assert all(
            isinstance(dev, (int, str, torch.device)) for dev in device
        ), "Device(s) must be device ids (integers, strings, or torch.device objects)"
        if len(device) > 1:  # multi-gpu
            model = torch.nn.DataParallel(model, device_ids=device)
        primary = device[0]
        # Only bare indices get the "cuda:" prefix; a complete device
        # specification ("cuda:0", "cpu", torch.device(...)) is used as is,
        # instead of being turned into an invalid name such as "cuda:cpu".
        if isinstance(primary, torch.device) or (
            isinstance(primary, str) and not primary.isdigit()
        ):
            model = model.to(primary)
        else:
            model = model.to(f"cuda:{primary}")
    else:
        model = model.to(device)  # "cpu" or "cuda"

    optimizer = torch.optim.Adam(
        model.parameters(), lr=config_train.lr, weight_decay=config_train.weight_decay
    )

    scheduler = None
    if config_train.lr_schedule:
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=config_train.lr_decay
        )
    return model, optimizer, scheduler
def load_ema(model: torch.nn.Module, decay: float = 0.999) -> ExponentialMovingAverage:
    """Build an exponential moving average tracker over a model's parameters.

    Args:
        model (torch.nn.Module): model whose parameters are tracked
        decay (float, optional): decay parameter. Defaults to 0.999.

    Returns:
        ExponentialMovingAverage: exponential moving average object for the model's parameters
    """
    return ExponentialMovingAverage(model.parameters(), decay=decay)
def load_ema_from_ckpt(
    model: torch.nn.Module, ema_state_dict: Dict[str, Any], decay: float = 0.999
) -> ExponentialMovingAverage:
    """Rebuild an exponential moving average tracker from a checkpointed state.

    A fresh tracker is created over the model's parameters and its state is
    then replaced by ``ema_state_dict``.

    Args:
        model (torch.nn.Module): model whose parameters are tracked
        ema_state_dict (Dict[str, Any]): saved state of the exponential moving average
        decay (float, optional): decay parameter used to create the tracker. Defaults to 0.999.

    Returns:
        ExponentialMovingAverage: exponential moving average object for the model's parameters
    """
    restored = ExponentialMovingAverage(model.parameters(), decay=decay)
    restored.load_state_dict(ema_state_dict)
    return restored
def load_data(
    config: EasyDict,
    get_list: bool = False,
    is_cc: bool = False,
) -> Union[
    Tuple[DataLoader, DataLoader],
    Union[
        Tuple[List[nx.Graph], List[nx.Graph]],
        Tuple[List[CombinatorialComplex], List[CombinatorialComplex]],
    ],
]:
    """Load the dataset named in the configuration.

    The molecular loaders are used for the "QM9" and "ZINC250k" datasets and
    the generic loaders for everything else; ``is_cc`` selects the
    combinatorial-complex variant of the chosen loader.

    Args:
        config (EasyDict): configuration for training
        get_list (bool, optional): if True, returns lists of graph or combinatorial complexes instead of dataloaders. Defaults to False.
        is_cc (bool, optional): if True, the dataset is made of combinatorial complexes. Defaults to False.

    Returns:
        Union[Tuple[DataLoader, DataLoader], Union[Tuple[List[nx.Graph], List[nx.Graph]], Tuple[List[CombinatorialComplex], List[CombinatorialComplex]]]]: DataLoader object or list of objects for training
    """
    is_molecular = config.data.data in ["QM9", "ZINC250k"]
    if is_molecular:
        loader = dataloader_mol_cc if is_cc else dataloader_mol
    else:
        loader = dataloader_cc if is_cc else dataloader
    return loader(config, get_list)
def load_batch(
    batch: List[torch.Tensor],
    device: Union[str, List[str], List[int]],
    is_cc: bool = False,
) -> Union[
    Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
]:
    """Load the batch on the device

    When ``device`` is a list, the tensors are moved to its first entry.

    Args:
        batch (List[torch.Tensor]): input batch; the first two entries are
            returned as ``x`` and ``adj``, the third as ``rank2`` when ``is_cc``
        device (Union[str, List[str], List[int]]): device to use. List entries
            may be bare CUDA indices (``0`` or ``"0"``) or complete device
            specifications (``"cuda:0"``, ``"cpu"``, ``torch.device`` objects).
        is_cc (bool, optional): if True, the elements of the input batch are combinatorial complexes. Defaults to False.

    Returns:
        Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: input batch on the device
    """
    if isinstance(device, list):
        primary = device[0]
        # Only bare indices get the "cuda:" prefix; prefixing a complete
        # specification would give an invalid name such as "cuda:cuda:0".
        if isinstance(primary, torch.device) or (
            isinstance(primary, str) and not primary.isdigit()
        ):
            device_id = primary
        else:
            device_id = f"cuda:{primary}"
    else:
        device_id = device

    x_b = batch[0].to(device_id)
    adj_b = batch[1].to(device_id)
    if not is_cc:
        return x_b, adj_b
    rank2_b = batch[2].to(device_id)
    return x_b, adj_b, rank2_b
def load_sde(config_sde: EasyDict) -> SDE:
    """Build the stochastic differential equation (SDE) described by the configuration.

    ``config_sde.type`` selects the class ("VP", "VE" or "subVP"). For "VE",
    ``beta_min`` and ``beta_max`` are passed as ``sigma_min`` and ``sigma_max``.

    Args:
        config_sde (EasyDict): configuration for the SDE
            (reads ``type``, ``beta_min``, ``beta_max`` and ``num_scales``)

    Raises:
        NotImplementedError: raise an error if the SDE is unknown

    Returns:
        SDE: SDE to use
    """
    kind = config_sde.type
    lower = config_sde.beta_min
    upper = config_sde.beta_max
    n_steps = config_sde.num_scales

    if kind == "VP":
        return VPSDE(beta_min=lower, beta_max=upper, N=n_steps)
    if kind == "VE":
        return VESDE(sigma_min=lower, sigma_max=upper, N=n_steps)
    if kind == "subVP":
        return subVPSDE(beta_min=lower, beta_max=upper, N=n_steps)
    raise NotImplementedError(f"SDE class {kind} not (yet) supported.")
def load_loss_fn(
    config: EasyDict,
    is_cc: bool = False,
) -> Union[
    Callable[
        [
            torch.nn.Module,
            torch.nn.Module,
            torch.Tensor,
            torch.Tensor,
        ],
        Tuple[torch.Tensor, torch.Tensor],
    ],
    Callable[
        [
            torch.nn.Module,
            torch.nn.Module,
            torch.nn.Module,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
        ],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ],
]:
    """Build the training loss function described by the configuration.

    The SDEs for ``x`` and ``adj`` (and ``rank2`` when ``is_cc``) are loaded
    from ``config.sde``. The loss is always requested in training mode,
    continuous, and without likelihood weighting.

    Args:
        config (EasyDict): configuration to use
        is_cc (bool, optional): if True, loss function for combinatorial complexes. Defaults to False.

    Returns:
        Union[Callable[[torch.nn.Module, torch.nn.Module, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]], Callable[[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: loss function that returns 2 or 3 losses, for x, adj and rank2 if cc
    """
    use_mean = config.train.reduce_mean
    x_sde = load_sde(config.sde.x)
    adj_sde = load_sde(config.sde.adj)

    if is_cc:
        rank2_sde = load_sde(config.sde.rank2)
        return get_sde_loss_fn_cc(
            x_sde,
            adj_sde,
            rank2_sde,
            d_min=config.data.d_min,
            d_max=config.data.d_max,
            train=True,
            reduce_mean=use_mean,
            continuous=True,
            likelihood_weighting=False,
            eps=config.train.eps,
        )

    return get_sde_loss_fn(
        x_sde,
        adj_sde,
        train=True,
        reduce_mean=use_mean,
        continuous=True,
        likelihood_weighting=False,
        eps=config.train.eps,
    )
def load_sampling_fn(
    config_train: EasyDict,
    config_module: EasyDict,
    config_sample: EasyDict,
    device: Union[str, List[str], List[int]],
    is_cc: bool = False,
    d_min: Optional[int] = None,
    d_max: Optional[int] = None,
    divide_batch: Optional[int] = None,
) -> Union[
    Callable[
        [torch.nn.Module, torch.nn.Module, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, float],
    ],
    Callable[
        [torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float],
    ],
]:
    """Load the sampling function from the configuration

    The sampler is ``S4_solver`` when ``config_module.predictor`` is "S4" and
    ``get_pc_sampler`` otherwise. The number of samples per call is
    ``config_sample.n_samples`` for the "QM9" and "ZINC250k" datasets and
    ``config_train.data.batch_size`` for the others, integer-divided by
    ``divide_batch`` when it is given.

    Args:
        config_train (EasyDict): configuration for training
        config_module (EasyDict): configuration for the module
        config_sample (EasyDict): configuration for the sampling
        device (Union[str, List[str], List[int]]): device to use. When a list
            is given its first entry is used; entries may be bare CUDA indices
            (``0`` or ``"0"``) or complete device specifications
            (``"cuda:0"``, ``"cpu"``, ``torch.device`` objects).
        is_cc (bool, optional): if True, we sample combinatorial complexes. Defaults to False.
        d_min (Optional[int], optional): minimum size of rank2 cells (for cc). Defaults to None.
        d_max (Optional[int], optional): maximum size of rank2 cells (for cc). Defaults to None.
        divide_batch (Optional[int], optional): if not None, divide the samples by this number to bypass RAM saturation. Defaults to None.

    Returns:
        Union[Callable[[torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, float]], Callable[[torch.nn.Module, torch.nn.Module, torch.nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]]]: sampling function
    """
    sde_x = load_sde(config_train.sde.x)
    sde_adj = load_sde(config_train.sde.adj)
    if is_cc:
        sde_rank2 = load_sde(config_train.sde.rank2)
    max_node_num = config_train.data.max_node_num

    if isinstance(device, list):
        primary = device[0]
        # Only bare indices get the "cuda:" prefix; prefixing a complete
        # specification would give an invalid name such as "cuda:cuda:0".
        if isinstance(primary, torch.device) or (
            isinstance(primary, str) and not primary.isdigit()
        ):
            device_id = primary
        else:
            device_id = f"cuda:{primary}"
    else:
        device_id = device

    # Get sampler
    get_sampler = S4_solver if config_module.predictor == "S4" else get_pc_sampler

    # Get shape in function of dataset: only the number of samples depends on
    # the dataset, the remaining dimensions are shared.
    if config_train.data.data in ["QM9", "ZINC250k"]:
        n_total = config_sample.n_samples
    else:
        n_total = config_train.data.batch_size
    batch_size = n_total if divide_batch is None else n_total // divide_batch

    shape_x = (batch_size, max_node_num, config_train.data.max_feat_num)
    shape_adj = (batch_size, max_node_num, max_node_num)

    # Get sampling function
    sampler_kwargs = dict(
        sde_x=sde_x,
        sde_adj=sde_adj,
        shape_x=shape_x,
        shape_adj=shape_adj,
        predictor=config_module.predictor,
        corrector=config_module.corrector,
        snr=config_module.snr,
        scale_eps=config_module.scale_eps,
        n_steps=config_module.n_steps,
        probability_flow=config_sample.probability_flow,
        continuous=True,
        denoise=config_sample.noise_removal,
        eps=config_sample.eps,
        device=device_id,
    )
    if is_cc:
        # The rank-2 arguments are only passed for combinatorial complexes
        rank2_dim = get_rank2_dim(max_node_num, d_min, d_max)
        sampler_kwargs.update(
            is_cc=is_cc,
            sde_rank2=sde_rank2,
            shape_rank2=(batch_size, rank2_dim[0], rank2_dim[1]),
            d_min=d_min,
            d_max=d_max,
        )
    return get_sampler(**sampler_kwargs)
def load_model_params(
    config: EasyDict,
    is_cc: bool = False,
) -> Union[
    Tuple[Dict[str, Any], Dict[str, Any]],
    Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]],
]:
    """Assemble the constructor parameters of the score networks.

    The parameters of the ``x`` network depend on whether its model name
    contains "GMH". For combinatorial complexes, the ``adj`` parameters are
    extended when the ``adj`` model is "ScoreNetworkA_CC" or
    "ScoreNetworkA_Base_CC", and a third dictionary is built for the rank-2
    network.

    Args:
        config (EasyDict): configuration to use
        is_cc (bool, optional): whether to model using combinatorial complexes. Defaults to False.

    Raises:
        AssertionError: if ``is_cc`` differs from ``config.is_cc``

    Returns:
        Union[Tuple[Dict[str, Any], Dict[str, Any]], Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]: parameters for x, adj, and rank-2 cells if cc
    """
    assert is_cc == config.is_cc, "is_cc should be the same in config and function call"

    model_cfg = config.model
    data_cfg = config.data
    n_feat = data_cfg.max_feat_num
    n_node = data_cfg.max_node_num

    # Parameters of the network for the node features
    params_x = {
        "is_cc": is_cc,
        "model_type": model_cfg.x,
        "max_feat_num": n_feat,
        "depth": model_cfg.depth,
        "nhid": model_cfg.nhid,
    }
    if "GMH" in model_cfg.x:
        gmh_keys = (
            "num_linears",
            "c_init",
            "c_hid",
            "c_final",
            "adim",
            "num_heads",
            "conv",
        )
        for key in gmh_keys:
            params_x[key] = getattr(model_cfg, key)
    params_x["use_bn"] = model_cfg.use_bn

    # Parameters of the network for the adjacency matrices
    params_adj = {
        "is_cc": is_cc,
        "model_type": model_cfg.adj,
        "max_feat_num": n_feat,
        "max_node_num": n_node,
    }
    adj_keys = (
        "nhid",
        "num_layers",
        "num_linears",
        "c_init",
        "c_hid",
        "c_final",
        "adim",
        "num_heads",
        "conv",
        "use_bn",
    )
    for key in adj_keys:
        params_adj[key] = getattr(model_cfg, key)

    if not is_cc:
        return params_x, params_adj

    # If is_cc, also load rank-2 parameters and some additional parameters to
    # params_adj if ScoreNetworkA_CC or ScoreNetworkA_Base_CC
    d_min = data_cfg.d_min
    d_max = data_cfg.d_max
    adj_type = model_cfg.adj
    if adj_type in ("ScoreNetworkA_CC", "ScoreNetworkA_Base_CC"):
        params_adj["d_min"] = d_min
        params_adj["d_max"] = d_max
        # Keys shared by both CC variants, followed by the variant-specific ones
        extra_keys = ["nhid_h", "num_layers_h", "num_linears_h", "c_hid_h", "c_final_h"]
        if adj_type == "ScoreNetworkA_CC":
            extra_keys += ["adim_h", "num_heads_h", "conv_hodge"]
        else:
            extra_keys.append("hidden_h")
        for key in extra_keys:
            params_adj[key] = getattr(model_cfg, key)

    # The rank-2 network takes its sizes from the "_h" entries of the model config
    params_rank2 = {
        "is_cc": config.is_cc,
        "model_type": model_cfg.rank2,
        "num_layers_mlp": model_cfg.num_layers_mlp,
        "num_layers": model_cfg.num_layers_h,
        "num_linears": model_cfg.num_linears_h,
        "nhid": model_cfg.nhid_h,
        "c_hid": model_cfg.c_hid_h,
        "c_final": model_cfg.c_final_h,
        "cnum": model_cfg.cnum,
        "max_node_num": n_node,
        "d_min": d_min,
        "d_max": d_max,
        "use_hodge_mask": model_cfg.use_hodge_mask,
        "use_bn": model_cfg.use_bn,
    }
    return params_x, params_adj, params_rank2
def load_ckpt(
    config: EasyDict,
    device: Union[str, List[str], List[int]],
    ts: Optional[str] = None,
    return_ckpt: bool = False,
    is_cc: bool = False,
) -> Dict[str, Any]:
    """Load the checkpoint from the configuration

    The checkpoint is read from
    ``<config.folder>/checkpoints/<config.data.data>/<config.ckpt>.pth``.
    When ``ts`` is given, ``config.ckpt`` is overwritten with it first. The
    ``folder`` entry of the loaded model configuration is replaced by
    ``config.folder``.

    Args:
        config (EasyDict): configuration to use
        device (Union[str, List[str], List[int]]): device to use. When a list
            is given its first entry is used; entries may be bare CUDA indices
            (``0`` or ``"0"``) or complete device specifications
            (``"cuda:0"``, ``"cpu"``, ``torch.device`` objects).
        ts (Optional[str], optional): timestamp (checkpoint name). Defaults to None.
        return_ckpt (bool, optional): if True, add the checkpoint in the resulting dictionary (key: "ckpt"). Defaults to False.
        is_cc (bool, optional): whether to model using combinatorial complexes. Defaults to False.

    Returns:
        Dict[str, Any]: loaded checkpoint parameters and configuration
    """
    if isinstance(device, list):
        primary = device[0]
        # Only bare indices get the "cuda:" prefix; prefixing a complete
        # specification would give an invalid name such as "cuda:cuda:0".
        if isinstance(primary, torch.device) or (
            isinstance(primary, str) and not primary.isdigit()
        ):
            device_id = primary
        else:
            device_id = f"cuda:{primary}"
    else:
        device_id = device

    if ts is not None:
        config.ckpt = ts
    path = os.path.join(
        config.folder, "checkpoints", f"{config.data.data}", f"{config.ckpt}.pth"
    )
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(path, map_location=device_id)
    print(f"{path} loaded")

    ckpt_dict = {
        "config": ckpt["model_config"],
        "params_x": ckpt["params_x"],
        "x_state_dict": ckpt["x_state_dict"],
        "params_adj": ckpt["params_adj"],
        "adj_state_dict": ckpt["adj_state_dict"],
    }
    if is_cc:
        ckpt_dict["params_rank2"] = ckpt["params_rank2"]
        ckpt_dict["rank2_state_dict"] = ckpt["rank2_state_dict"]
    if config.sample.use_ema:
        ckpt_dict["ema_x"] = ckpt["ema_x"]
        ckpt_dict["ema_adj"] = ckpt["ema_adj"]
        if is_cc:
            ckpt_dict["ema_rank2"] = ckpt["ema_rank2"]
    if return_ckpt:
        ckpt_dict["ckpt"] = ckpt
    # Change folder with the one provided
    ckpt_dict["config"]["folder"] = config.folder
    return ckpt_dict
def load_model_from_ckpt(
    params: Dict[str, Any],
    state_dict: Dict[str, Any],
    device: Union[str, List[str], List[torch.device], List[int]],
) -> Union[torch.nn.Module, torch.nn.DataParallel]:
    """Load the model from the checkpoint

    The model is built with ``load_model``, its weights are restored from
    ``state_dict`` and it is moved to ``device``. When ``device`` is a list
    with more than one entry the model is wrapped in
    ``torch.nn.DataParallel``; in every list case the model is placed on the
    first entry.

    Args:
        params (Dict[str, Any]): parameters of the model
        state_dict (Dict[str, Any]): state dictionary of the model. A
            "module." prefix on its keys (as saved from a DataParallel
            model) is removed before loading.
        device (Union[str, List[str], List[torch.device], List[int]]): device
            to use. List entries may be bare CUDA indices (``0`` or ``"0"``)
            or complete device specifications (``"cuda:0"``, ``"cpu"``,
            ``torch.device`` objects).

    Raises:
        AssertionError: if ``device`` is an empty list or contains an entry
            that is not an int, a str or a ``torch.device``

    Returns:
        Union[torch.nn.Module, torch.nn.DataParallel]: loaded model
    """
    model = load_model(params)

    # Strip 'module.' at front; for DataParallel models. The prefix is only
    # removed when every key actually starts with it, so keys that merely
    # contain "module." elsewhere are never truncated, and an empty state
    # dictionary no longer raises an IndexError here.
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: val for key, val in state_dict.items()}
    model.load_state_dict(state_dict)

    if isinstance(device, list):
        assert len(device) > 0, "At least one device must be provided"
        assert all(
            isinstance(dev, (int, str, torch.device)) for dev in device
        ), "Device(s) must be device ids (integers, strings, or torch.device objects)"
        if len(device) > 1:  # multi-gpu
            model = torch.nn.DataParallel(model, device_ids=device)
        primary = device[0]
        # Only bare indices get the "cuda:" prefix; a complete device
        # specification ("cuda:0", "cpu", torch.device(...)) is used as is,
        # instead of being turned into an invalid name such as "cuda:cpu".
        if isinstance(primary, torch.device) or (
            isinstance(primary, str) and not primary.isdigit()
        ):
            model = model.to(primary)
        else:
            model = model.to(f"cuda:{primary}")
    else:
        model = model.to(device)  # "cpu" or "cuda"
    return model
def load_eval_settings(
    data: str, orbit_on: bool = True
) -> Tuple[List[str], Dict[str, Callable[[np.ndarray, np.ndarray], float]]]:
    """Return the evaluation methods and kernels for generic graph generation.

    Both arguments are accepted but currently ignored: the same settings are
    returned whatever their values.

    Args:
        data (str): dataset to use. UNUSED HERE.
        orbit_on (bool, optional): whether to use orbit distance. UNUSED HERE. Defaults to True.

    Returns:
        Tuple[List[str], Dict[str, Callable[[np.ndarray, np.ndarray], float]]]: methods and kernels, used for generic graph generation
    """
    # Kernel used for each method
    # (chosen from [gaussian, gaussian_emd, gaussian_tv], see evaluation/mmd.py)
    kernels = {
        "degree": gaussian_emd,
        "cluster": gaussian_emd,
        "orbit": gaussian,
        "spectral": gaussian_emd,
    }
    # Methods to evaluate
    # (chosen from [degree, cluster, orbit, spectral, nspdk], see evaluation/stats.py)
    methods = ["degree", "cluster", "orbit", "spectral"]
    return methods, kernels