Source code for ccsd.src.trainer

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""trainer.py: code for training the model.
"""

import abc
import os
import pickle
from time import perf_counter
from typing import Dict, List, Optional

import numpy as np
import torch
import wandb
from easydict import EasyDict
from tqdm import tqdm, trange

from ccsd.src.utils.loader import (
    load_batch,
    load_data,
    load_device,
    load_ema,
    load_loss_fn,
    load_model_optimizer,
    load_model_params,
    load_seed,
)
from ccsd.src.utils.logger import (
    Logger,
    device_log,
    model_parameters_log,
    set_log,
    start_log,
    time_log,
    train_log,
)
from ccsd.src.utils.plot import plot_lc


class Trainer(abc.ABC):
    """Abstract class for a Trainer.

    Concrete subclasses implement `train`. The learning-curve helpers below
    read `self.log_dir` and `self.ckpt`, which are NOT assigned by this base
    class: a subclass must set them before calling those helpers.
    """

    def __init__(self, config: Optional[EasyDict] = None) -> None:
        """Initialize the trainer.

        Args:
            config (Optional[EasyDict], optional): the config object to use.
                Defaults to None.
        """
        super().__init__()
        self.config = config

    @abc.abstractmethod
    def train(self, ts: str) -> str:
        """Train method to load the models, the optimizers, etc,
        train the model and save the checkpoint.

        Args:
            ts (str): checkpoint name (usually a timestamp)

        Returns:
            str: checkpoint name
        """
        pass

    def _learning_curves_stem(self) -> str:
        """Build the base file name shared by the saved curves and their plot.

        Returns:
            str: "<config_name>_<ckpt>_learning_curves"
        """
        return f"{self.config.config_name}_{self.ckpt}_learning_curves"

    def save_learning_curves(self, learning_curves: Dict[str, List[float]]) -> None:
        """Save the learning curves in a .npy file.

        Despite the ``.npy`` extension, the file is written with
        `pickle.dump` (highest protocol), so it must be read back with
        `pickle.load` rather than treated as a raw NumPy array file.

        Args:
            learning_curves (Dict[str, List[float]]): the learning curves to save
        """
        out_path = os.path.join(self.log_dir, f"{self._learning_curves_stem()}.npy")
        with open(out_path, "wb") as f:
            pickle.dump(learning_curves, f, protocol=pickle.HIGHEST_PROTOCOL)

    def plot_learning_curves(self, learning_curves: Dict[str, List[float]]) -> None:
        """Plot the learning curves.

        Args:
            learning_curves (Dict[str, List[float]]): the learning curves to plot
        """
        # Call the plot function from utils
        plot_lc(
            self.config,
            learning_curves,
            self.log_dir,
            self._learning_curves_stem(),
        )
class Trainer_Graph(Trainer):
    """Trainer class for training the model with graphs.

    Two models are trained side by side (one for `x`, one for `adj`),
    each with its own optimizer, scheduler and EMA helper.

    Adapted from Jo, J. & al (2022)
    """

    def __init__(self, config: EasyDict) -> None:
        """Initialize the trainer with the different configs.

        Args:
            config (EasyDict): the config object to use
        """
        super(Trainer_Graph, self).__init__(config)
        # Load general config
        self.config = config
        self.log_folder_name, self.log_dir, self.ckpt_dir = set_log(
            self.config, is_train=True, folder=self.config.folder
        )
        self.is_cc = self.config.is_cc
        # Load training config
        self.seed = load_seed(self.config.seed)
        self.device = load_device()
        self.train_loader, self.test_loader = load_data(self.config)
        self.params_x, self.params_adj = load_model_params(self.config)

    def __repr__(self) -> str:
        """Return the string representation of the Trainer_Graph class.

        Returns:
            str: the string representation of the Trainer_Graph class
        """
        return f"{self.__class__.__name__}(is_cc={self.is_cc})"

    def _checkpoint_state(self) -> Dict[str, object]:
        """Collect everything that is written into a checkpoint file.

        Returns:
            Dict[str, object]: config, model params, model and EMA state dicts
        """
        return {
            "model_config": self.config,
            "params_x": self.params_x,
            "params_adj": self.params_adj,
            "x_state_dict": self.model_x.state_dict(),
            "adj_state_dict": self.model_adj.state_dict(),
            "ema_x": self.ema_x.state_dict(),
            "ema_adj": self.ema_adj.state_dict(),
        }

    def _checkpoint_path(self, name: str) -> str:
        """Build the path of a checkpoint file.

        Args:
            name (str): checkpoint file name without the ".pth" extension

        Returns:
            str: <folder>/checkpoints/<dataset>/<name>.pth
        """
        return os.path.join(
            self.config.folder,
            "checkpoints",
            f"{self.config.data.data}",
            f"{name}.pth",
        )

    @staticmethod
    def _wandb_log(payload: Dict[str, object]) -> None:
        """Forward `payload` to wandb only when a run is active.

        `wandb.log` raises when no run was initialised; the rest of this
        class treats wandb as optional, so logging is skipped in that case.

        Args:
            payload (Dict[str, object]): the values to log
        """
        if wandb.run is not None:
            wandb.log(payload)

    def _train_one_epoch(self) -> None:
        """Run one optimisation pass over the train loader.

        Resets and fills `self.train_x` / `self.train_adj` with the per-batch
        losses of this epoch.
        """
        self.train_x = []
        self.train_adj = []
        self.model_x.train()
        self.model_adj.train()
        for train_b in self.train_loader:
            self.optimizer_x.zero_grad()
            self.optimizer_adj.zero_grad()
            x, adj = load_batch(train_b, self.device, is_cc=self.is_cc)
            loss_x, loss_adj = self.loss_fn(self.model_x, self.model_adj, x, adj)
            loss_x.backward()
            loss_adj.backward()

            torch.nn.utils.clip_grad_norm_(
                self.model_x.parameters(), self.config.train.grad_norm
            )
            torch.nn.utils.clip_grad_norm_(
                self.model_adj.parameters(), self.config.train.grad_norm
            )

            self.optimizer_x.step()
            self.optimizer_adj.step()

            # -------- EMA update --------
            self.ema_x.update(self.model_x.parameters())
            self.ema_adj.update(self.model_adj.parameters())

            self.train_x.append(loss_x.item())
            self.train_adj.append(loss_adj.item())

        # Schedulers are stepped once per epoch, not per batch
        if self.config.train.lr_schedule:
            self.scheduler_x.step()
            self.scheduler_adj.step()

    def _test_one_epoch(self) -> None:
        """Compute the test losses of the current epoch.

        Resets and fills `self.test_x` / `self.test_adj` with per-batch losses.
        The EMA helpers' `store` / `copy_to` / `restore` are called once around
        the whole loop instead of once per batch: the parameters used for
        every batch are the same, only the redundant copies are removed.
        """
        self.test_x = []
        self.test_adj = []
        self.model_x.eval()
        self.model_adj.eval()

        with torch.no_grad():
            self.ema_x.store(self.model_x.parameters())
            self.ema_adj.store(self.model_adj.parameters())
        try:
            with torch.no_grad():
                self.ema_x.copy_to(self.model_x.parameters())
                self.ema_adj.copy_to(self.model_adj.parameters())
                for test_b in self.test_loader:
                    x, adj = load_batch(test_b, self.device, is_cc=self.is_cc)
                    loss_x, loss_adj = self.loss_fn(
                        self.model_x, self.model_adj, x, adj
                    )
                    self.test_x.append(loss_x.item())
                    self.test_adj.append(loss_adj.item())
        finally:
            # Always put the stored parameters back, even if evaluation failed
            self.ema_x.restore(self.model_x.parameters())
            self.ema_adj.restore(self.model_adj.parameters())

    def train(self, ts: str) -> str:
        """Train method to load the models, the optimizers, etc,
        train the model and save the checkpoint.

        Args:
            ts (str): checkpoint name (usually a timestamp)

        Returns:
            str: checkpoint name
        """
        self.config.exp_name = ts
        self.ckpt = f"{ts}"
        print("\033[91m" + f"{self.ckpt}" + "\033[0m")

        # -------- Load models, optimizers, ema --------
        self.model_x, self.optimizer_x, self.scheduler_x = load_model_optimizer(
            self.params_x, self.config.train, self.device
        )
        self.model_adj, self.optimizer_adj, self.scheduler_adj = load_model_optimizer(
            self.params_adj, self.config.train, self.device
        )
        self.ema_x = load_ema(self.model_x, decay=self.config.train.ema)
        self.ema_adj = load_ema(self.model_adj, decay=self.config.train.ema)

        log_name = f"{self.config.config_name}_{self.ckpt}"
        logger = Logger(str(os.path.join(self.log_dir, f"{log_name}.log")), mode="a")
        logger.log(f"{self.ckpt}", verbose=False)
        device_log(logger, self.device)
        for m in [self.model_x, self.model_adj]:
            logger.log(
                f"Model {m.__class__.__name__} loaded on {next(m.parameters()).device.type}"
            )
        start_log(logger, self.config)
        train_log(logger, self.config)
        model_parameters_log(logger, [self.model_x, self.model_adj])

        self.loss_fn = load_loss_fn(self.config)

        # -------- Training --------
        print("Training started...")
        start_train_time = perf_counter()
        # One mean value per epoch. The per-batch lists (self.train_x, ...) are
        # reset every epoch, so using them directly would only keep the last
        # epoch in the saved / plotted curves.
        learning_curves: Dict[str, List[float]] = {
            "train_x": [],
            "train_adj": [],
            "test_x": [],
            "test_adj": [],
        }
        save_interval = self.config.train.save_interval
        print_interval = self.config.train.print_interval
        for epoch in trange(
            0, (self.config.train.num_epochs), desc="[Epoch]", position=1, leave=False
        ):
            t_start = perf_counter()

            self._train_one_epoch()
            self._test_one_epoch()

            mean_train_x = np.mean(self.train_x)
            mean_train_adj = np.mean(self.train_adj)
            mean_test_x = np.mean(self.test_x)
            mean_test_adj = np.mean(self.test_adj)

            learning_curves["train_x"].append(float(mean_train_x))
            learning_curves["train_adj"].append(float(mean_train_adj))
            learning_curves["test_x"].append(float(mean_test_x))
            learning_curves["test_adj"].append(float(mean_test_adj))

            # -------- Log losses --------
            # Logger
            logger.log(
                f"{epoch+1:03d} | {perf_counter()-t_start:.2f}s | "
                f"test x: {mean_test_x:.3e} | test adj: {mean_test_adj:.3e} | "
                f"train x: {mean_train_x:.3e} | train adj: {mean_train_adj:.3e} | ",
                verbose=False,
            )
            # Wandb
            self._wandb_log(
                {
                    "epoch": epoch + 1,
                    "time": perf_counter() - t_start,
                    "test_x_loss": mean_test_x,
                    "test_adj_loss": mean_test_adj,
                    "train_x_loss": mean_train_x,
                    "train_adj_loss": mean_train_adj,
                }
            )

            # -------- Save checkpoints --------
            if epoch % save_interval == save_interval - 1:
                # The last epoch is saved without the "_<epoch>" suffix
                save_name = (
                    f"_{epoch+1}" if epoch < self.config.train.num_epochs - 1 else ""
                )
                torch.save(
                    self._checkpoint_state(),
                    self._checkpoint_path(self.ckpt + save_name),
                )

            if epoch % print_interval == print_interval - 1:
                tqdm.write(
                    f"[EPOCH {epoch+1:04d}] test adj: {mean_test_adj:.3e} | train adj: {mean_train_adj:.3e} | "
                    f"test x: {mean_test_x:.3e} | train x: {mean_train_x:.3e}"
                )
        print("Training complete.")
        training_time = perf_counter() - start_train_time
        time_log(logger, time_type="train", elapsed_time=training_time)
        self._wandb_log({"Training time": training_time})

        # -------- Save final model --------
        torch.save(
            self._checkpoint_state(),
            self._checkpoint_path(f"{self.ckpt}_final"),
        )

        # -------- Save learning curves and plots --------
        self.save_learning_curves(learning_curves)
        self.plot_learning_curves(learning_curves)
        if (
            (self.config.experiment_type == "train")
            and self.config.general_config.use_wandb
            and wandb.run is not None
        ):
            # add plots to wandb
            img_path = os.path.join(
                self.log_dir,
                "fig",
                f"{self.config.config_name}_{self.ckpt}_learning_curves.png",
            )
            wandb.log({"Learning Curves": wandb.Image(img_path)})

        return f"{self.ckpt}_final"
class Trainer_CC(Trainer):
    """Trainer class for training the model with combinatorial complexes.

    Three models are trained side by side (`x`, `adj` and `rank2`),
    each with its own optimizer, scheduler and EMA helper.
    """

    def __init__(self, config: EasyDict) -> None:
        """Initialize the trainer with the different configs.

        Args:
            config (EasyDict): the config object to use
        """
        super(Trainer_CC, self).__init__(config)
        # Load general config
        self.config = config
        self.log_folder_name, self.log_dir, self.ckpt_dir = set_log(
            self.config, is_train=True, folder=self.config.folder
        )
        self.is_cc = self.config.is_cc
        # Load training config
        self.seed = load_seed(self.config.seed)
        self.device = load_device()
        self.train_loader, self.test_loader = load_data(self.config, is_cc=True)
        self.params_x, self.params_adj, self.params_rank2 = load_model_params(
            self.config, is_cc=True
        )

    def __repr__(self) -> str:
        """Return the string representation of the Trainer_CC class.

        Returns:
            str: the string representation of the Trainer_CC class
        """
        return f"{self.__class__.__name__}(is_cc={self.is_cc})"

    def _checkpoint_state(self) -> Dict[str, object]:
        """Collect everything that is written into a checkpoint file.

        Returns:
            Dict[str, object]: config, model params, model and EMA state dicts
        """
        return {
            "model_config": self.config,
            "params_x": self.params_x,
            "params_adj": self.params_adj,
            "params_rank2": self.params_rank2,
            "x_state_dict": self.model_x.state_dict(),
            "adj_state_dict": self.model_adj.state_dict(),
            "rank2_state_dict": self.model_rank2.state_dict(),
            "ema_x": self.ema_x.state_dict(),
            "ema_adj": self.ema_adj.state_dict(),
            "ema_rank2": self.ema_rank2.state_dict(),
        }

    def _checkpoint_path(self, name: str) -> str:
        """Build the path of a checkpoint file.

        Args:
            name (str): checkpoint file name without the ".pth" extension

        Returns:
            str: <folder>/checkpoints/<dataset>/<name>.pth
        """
        return os.path.join(
            self.config.folder,
            "checkpoints",
            f"{self.config.data.data}",
            f"{name}.pth",
        )

    @staticmethod
    def _wandb_log(payload: Dict[str, object]) -> None:
        """Forward `payload` to wandb only when a run is active.

        `wandb.log` raises when no run was initialised; the rest of this
        class treats wandb as optional, so logging is skipped in that case.

        Args:
            payload (Dict[str, object]): the values to log
        """
        if wandb.run is not None:
            wandb.log(payload)

    def _train_one_epoch(self) -> None:
        """Run one optimisation pass over the train loader.

        Resets and fills `self.train_x` / `self.train_adj` / `self.train_rank2`
        with the per-batch losses of this epoch.
        """
        self.train_x = []
        self.train_adj = []
        self.train_rank2 = []
        self.model_x.train()
        self.model_adj.train()
        self.model_rank2.train()
        for train_b in self.train_loader:
            self.optimizer_x.zero_grad()
            self.optimizer_adj.zero_grad()
            self.optimizer_rank2.zero_grad()
            x, adj, rank2 = load_batch(train_b, self.device, is_cc=self.is_cc)
            loss_x, loss_adj, loss_rank2 = self.loss_fn(
                self.model_x, self.model_adj, self.model_rank2, x, adj, rank2
            )
            loss_x.backward()
            loss_adj.backward()
            loss_rank2.backward()

            torch.nn.utils.clip_grad_norm_(
                self.model_x.parameters(), self.config.train.grad_norm
            )
            torch.nn.utils.clip_grad_norm_(
                self.model_adj.parameters(), self.config.train.grad_norm
            )
            torch.nn.utils.clip_grad_norm_(
                self.model_rank2.parameters(), self.config.train.grad_norm
            )

            self.optimizer_x.step()
            self.optimizer_adj.step()
            self.optimizer_rank2.step()

            # -------- EMA update --------
            self.ema_x.update(self.model_x.parameters())
            self.ema_adj.update(self.model_adj.parameters())
            self.ema_rank2.update(self.model_rank2.parameters())

            self.train_x.append(loss_x.item())
            self.train_adj.append(loss_adj.item())
            self.train_rank2.append(loss_rank2.item())

        # Schedulers are stepped once per epoch, not per batch
        if self.config.train.lr_schedule:
            self.scheduler_x.step()
            self.scheduler_adj.step()
            self.scheduler_rank2.step()

    def _test_one_epoch(self) -> None:
        """Compute the test losses of the current epoch.

        Resets and fills `self.test_x` / `self.test_adj` / `self.test_rank2`
        with per-batch losses. The EMA helpers' `store` / `copy_to` /
        `restore` are called once around the whole loop instead of once per
        batch: the parameters used for every batch are the same, only the
        redundant copies are removed.
        """
        self.test_x = []
        self.test_adj = []
        self.test_rank2 = []
        self.model_x.eval()
        self.model_adj.eval()
        self.model_rank2.eval()

        with torch.no_grad():
            self.ema_x.store(self.model_x.parameters())
            self.ema_adj.store(self.model_adj.parameters())
            self.ema_rank2.store(self.model_rank2.parameters())
        try:
            with torch.no_grad():
                self.ema_x.copy_to(self.model_x.parameters())
                self.ema_adj.copy_to(self.model_adj.parameters())
                self.ema_rank2.copy_to(self.model_rank2.parameters())
                for test_b in self.test_loader:
                    x, adj, rank2 = load_batch(
                        test_b, self.device, is_cc=self.is_cc
                    )
                    loss_x, loss_adj, loss_rank2 = self.loss_fn(
                        self.model_x, self.model_adj, self.model_rank2, x, adj, rank2
                    )
                    self.test_x.append(loss_x.item())
                    self.test_adj.append(loss_adj.item())
                    self.test_rank2.append(loss_rank2.item())
        finally:
            # Always put the stored parameters back, even if evaluation failed
            self.ema_x.restore(self.model_x.parameters())
            self.ema_adj.restore(self.model_adj.parameters())
            self.ema_rank2.restore(self.model_rank2.parameters())

    def train(self, ts: str) -> str:
        """Train method to load the models, the optimizers, etc,
        train the model and save the checkpoint.

        Args:
            ts (str): checkpoint name (usually a timestamp)

        Returns:
            str: checkpoint name
        """
        self.config.exp_name = ts
        self.ckpt = f"{ts}"
        print("\033[91m" + f"{self.ckpt}" + "\033[0m")

        # -------- Load models, optimizers, ema --------
        self.model_x, self.optimizer_x, self.scheduler_x = load_model_optimizer(
            self.params_x, self.config.train, self.device
        )
        self.model_adj, self.optimizer_adj, self.scheduler_adj = load_model_optimizer(
            self.params_adj, self.config.train, self.device
        )
        (
            self.model_rank2,
            self.optimizer_rank2,
            self.scheduler_rank2,
        ) = load_model_optimizer(self.params_rank2, self.config.train, self.device)
        self.ema_x = load_ema(self.model_x, decay=self.config.train.ema)
        self.ema_adj = load_ema(self.model_adj, decay=self.config.train.ema)
        self.ema_rank2 = load_ema(self.model_rank2, decay=self.config.train.ema)

        log_name = f"{self.config.config_name}_{self.ckpt}"
        logger = Logger(str(os.path.join(self.log_dir, f"{log_name}.log")), mode="a")
        logger.log(f"{self.ckpt}", verbose=False)
        device_log(logger, self.device)
        for m in [self.model_x, self.model_adj, self.model_rank2]:
            logger.log(
                f"Model {m.__class__.__name__} loaded on {next(m.parameters()).device.type}"
            )
        start_log(logger, self.config)
        train_log(logger, self.config)
        model_parameters_log(logger, [self.model_x, self.model_adj, self.model_rank2])

        self.loss_fn = load_loss_fn(self.config, is_cc=True)

        # -------- Training --------
        print("Training started...")
        start_train_time = perf_counter()
        # One mean value per epoch. The per-batch lists (self.train_x, ...) are
        # reset every epoch, so using them directly would only keep the last
        # epoch in the saved / plotted curves.
        learning_curves: Dict[str, List[float]] = {
            "train_x": [],
            "train_adj": [],
            "train_rank2": [],
            "test_x": [],
            "test_adj": [],
            "test_rank2": [],
        }
        save_interval = self.config.train.save_interval
        print_interval = self.config.train.print_interval
        for epoch in trange(
            0, (self.config.train.num_epochs), desc="[Epoch]", position=1, leave=False
        ):
            t_start = perf_counter()

            self._train_one_epoch()
            self._test_one_epoch()

            mean_train_x = np.mean(self.train_x)
            mean_train_adj = np.mean(self.train_adj)
            mean_train_rank2 = np.mean(self.train_rank2)
            mean_test_x = np.mean(self.test_x)
            mean_test_adj = np.mean(self.test_adj)
            mean_test_rank2 = np.mean(self.test_rank2)

            learning_curves["train_x"].append(float(mean_train_x))
            learning_curves["train_adj"].append(float(mean_train_adj))
            learning_curves["train_rank2"].append(float(mean_train_rank2))
            learning_curves["test_x"].append(float(mean_test_x))
            learning_curves["test_adj"].append(float(mean_test_adj))
            learning_curves["test_rank2"].append(float(mean_test_rank2))

            # -------- Log losses --------
            logger.log(
                f"{epoch+1:03d} | {perf_counter()-t_start:.2f}s | "
                f"test x: {mean_test_x:.3e} | test adj: {mean_test_adj:.3e} | test rank2: {mean_test_rank2:.3e} | "
                f"train x: {mean_train_x:.3e} | train adj: {mean_train_adj:.3e} | train rank2: {mean_train_rank2:.3e} |",
                verbose=False,
            )
            # Wandb
            self._wandb_log(
                {
                    "epoch": epoch + 1,
                    "time": perf_counter() - t_start,
                    "test_x_loss": mean_test_x,
                    "test_adj_loss": mean_test_adj,
                    "test_rank2_loss": mean_test_rank2,
                    "train_x_loss": mean_train_x,
                    "train_adj_loss": mean_train_adj,
                    "train_rank2_loss": mean_train_rank2,
                }
            )

            # -------- Save checkpoints --------
            if epoch % save_interval == save_interval - 1:
                # The last epoch is saved without the "_<epoch>" suffix
                save_name = (
                    f"_{epoch+1}" if epoch < self.config.train.num_epochs - 1 else ""
                )
                torch.save(
                    self._checkpoint_state(),
                    self._checkpoint_path(self.ckpt + save_name),
                )

            if epoch % print_interval == print_interval - 1:
                tqdm.write(
                    f"[EPOCH {epoch+1:04d}] test adj: {mean_test_adj:.3e} | train adj: {mean_train_adj:.3e} | "
                    f"test x: {mean_test_x:.3e} | train x: {mean_train_x:.3e} | "
                    f"test rank2: {mean_test_rank2:.3e} | train rank2: {mean_train_rank2:.3e}"
                )
        print("Training complete.")
        training_time = perf_counter() - start_train_time
        time_log(logger, time_type="train", elapsed_time=training_time)
        self._wandb_log({"Training time": training_time})

        # -------- Save final model --------
        torch.save(
            self._checkpoint_state(),
            self._checkpoint_path(f"{self.ckpt}_final"),
        )

        # -------- Save learning curves and plots --------
        self.save_learning_curves(learning_curves)
        self.plot_learning_curves(learning_curves)
        if (
            (self.config.experiment_type == "train")
            and self.config.general_config.use_wandb
            and wandb.run is not None
        ):
            # add plots to wandb
            img_path = os.path.join(
                self.log_dir,
                "fig",
                f"{self.config.config_name}_{self.ckpt}_learning_curves.png",
            )
            wandb.log({"Learning Curves": wandb.Image(img_path)})

        return f"{self.ckpt}_final"
def get_trainer_from_config(
    config: EasyDict,
) -> Trainer:
    """Build the trainer matching a configuration.

    Args:
        config (EasyDict): configuration of the experiment; its `is_cc`
            flag selects the trainer class

    Returns:
        Trainer: a `Trainer_CC` when `config.is_cc` is truthy,
            a `Trainer_Graph` otherwise
    """
    return Trainer_CC(config) if config.is_cc else Trainer_Graph(config)