#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""ccsd.py: Code for the CCSD class for training and/or sampling.
"""
import os
import warnings
import matplotlib
import plotly
import wandb
from easydict import EasyDict
from rdkit import RDLogger
from ccsd.src.parsers.config import get_config, get_general_config
from ccsd.src.sampler import Sampler, get_sampler_from_config
from ccsd.src.trainer import Trainer, get_trainer_from_config
from ccsd.src.utils.print import initial_print
from ccsd.src.utils.time_utils import get_time
# Import-time setup: these statements run once when the module is imported.
# Ignore every warning of category MatplotlibDeprecationWarning.
warnings.filterwarnings("ignore", category=matplotlib.MatplotlibDeprecationWarning)
# Reset the MathJax setting of plotly's kaleido export scope to None.
# NOTE(review): presumably prevents MathJax from being loaded when exporting
# static figures -- confirm against the plotly/kaleido version in use.
plotly.io.kaleido.scope.mathjax = None
# Disable RDKit log output for all loggers matching "rdApp.*".
RDLogger.DisableLog("rdApp.*")
class CCSD:
    """CCSD class for training and/or sampling.

    An instance stores the experiment parameters given at construction time.
    Calling :meth:`run` loads the configuration, then trains and/or samples,
    and keeps the resulting config, trainer and sampler objects on the
    instance so they can be retrieved afterwards.
    """

    def __init__(
        self,
        type: str,
        config: str,
        folder: str = "./",
        comment: str = "",
        seed: int = 42,
    ) -> None:
        """Initialize the CCSD class.

        Args:
            type (str): Type of experiment. Choose from ["train", "sample"].
            config (str): Name of the config file, with or without the ".yaml"
                extension. It is looked up as ``<folder>/config/<config>.yaml``.
            folder (str, optional): Directory to save the results, load checkpoints, load config, etc. Defaults to "./".
            comment (str, optional): A single line comment for the experiment. Defaults to "".
            seed (int, optional): Random seed for reproducibility. Defaults to 42.

        Raises:
            AssertionError: if ``type`` is not one of ["train", "sample"] or
                if the config file does not exist.
        """
        # Validation is done with explicit raises rather than `assert`
        # statements so that it is not stripped when Python runs with -O.
        # AssertionError is kept so existing callers see the same exception.
        if type not in ("train", "sample"):
            raise AssertionError(
                f"Unknown type: {type}. Please select from [train, sample]."
            )

        # Accept the config name with or without its ".yaml" extension
        if config.endswith(".yaml"):
            config = config[:-5]
        if not os.path.exists(os.path.join(folder, "config", f"{config}.yaml")):
            raise AssertionError(f"Config {config} not found.")

        # General experiment parameters
        self.type = type
        self.config = config
        self.folder = folder
        self.comment = comment
        self.seed = seed
        self.args = EasyDict(
            {
                "type": type,
                "config": config,
                "folder": folder,
                "comment": comment,
                "seed": seed,
            }
        )

        # Objects saved during the experiment (all None until run() is called)
        self.cfg = None  # config dictionary
        self.trainer = None  # trainer object
        self.sampler = None  # sampler object

    def __repr__(self) -> str:
        """Representation of the CCSD class.

        The trainer and sampler class names are only included once the
        corresponding objects have been created by :meth:`run`.

        Returns:
            str: representation of the CCSD class.
        """
        string_repr = (
            f"{self.__class__.__name__}("
            f"type={self.type}, "
            f"config={self.config}, "
            f"folder={self.folder}, "
            f"comment={self.comment}, "
            f"seed={self.seed}"
        )
        if self.trainer is not None:
            string_repr += f", trainer={self.trainer.__class__.__name__}"
        if self.sampler is not None:
            string_repr += f", sampler={self.sampler.__class__.__name__}"
        string_repr += ")"
        return string_repr

    def run(self) -> None:
        """Run the code for training and/or sampling.

        For a "train" experiment the model is trained and, if the config has a
        "sample" entry, samples are then generated from the checkpoint that
        was just trained. For a "sample" experiment only sampling is done.

        Raises:
            ValueError: raised if the experiment type is not one of [train, sample].
        """
        # Get the configuration and the general configuration
        config = get_config(self.args.config, self.args.seed, self.args.folder)
        general_config = get_general_config(self.args.folder)

        # Print the initial message
        if general_config.print_initial:
            initial_print(self.args)

        # Current timestamp (name of the experiment)
        timezone = general_config.timezone
        ts = get_time(timezone)

        # Add some information to the config
        config.current_time = ts  # add the timestamp to the config
        config.experiment_type = self.args.type  # add the experiment type to the config
        config.config_name = self.args.config  # add the config name to the config
        config.general_config = general_config  # add the general config to the config
        config.folder = self.args.folder  # add the folder to the config
        self.cfg = config  # save the config object

        # -------- Train --------
        if self.args.type == "train":
            # Initialize wandb
            if general_config.use_wandb:
                run_name = f"{self.args.config}_{ts}"
                wandb.init(
                    project=general_config.project_name,
                    entity=general_config.entity,
                    config=config,
                    name=run_name,
                )
                wandb.run.name = run_name
                wandb.run.save()
                wandb.config.update(config)

            try:
                # Select the trainer based on the config
                trainer = get_trainer_from_config(config)
                # Train the model
                ckpt = trainer.train(ts)
                self.trainer = trainer  # save the trainer object

                if "sample" in config:  # then sample from the trained model
                    config.ckpt = ckpt  # load the model that has just been trained
                    self.cfg = config  # save the updated config object
                    # Select the sampler based on the config
                    sampler = get_sampler_from_config(config)
                    # Sample from the model
                    sampler.sample()
                    self.sampler = sampler  # save the sampler object
            except Exception:
                # Do not leave a wandb run open when training or sampling
                # fails: close it with a non-zero exit code, then re-raise.
                wandb.finish(exit_code=1)
                raise

            # Finish wandb
            wandb.finish()

        # -------- Generation --------
        elif self.args.type == "sample":
            # Select the sampler based on the config
            sampler = get_sampler_from_config(config)
            # Sample from the model
            sampler.sample()
            self.sampler = sampler  # save the sampler object

        else:
            raise ValueError(
                f"Unknown type: {self.args.type}. Please read the documentation and select from [train, sample]."
            )

    def is_trained(self) -> bool:
        """Check if the CCSD model is trained.

        Returns:
            bool: True if a trainer object has been saved by run(), False otherwise.
        """
        return self.trainer is not None

    def get_trainer(self) -> "Trainer | None":
        """Get the trainer object.

        Returns:
            Trainer | None: Trainer object, or None if no training has been run.
        """
        return self.trainer

    def get_sampler(self) -> "Sampler | None":
        """Get the sampler object.

        Returns:
            Sampler | None: Sampler object, or None if no sampling has been run.
        """
        return self.sampler

    def get_config(self) -> "EasyDict | None":
        """Get the config object.

        Returns:
            EasyDict | None: Config object, or None if run() has not been called.
        """
        return self.cfg