#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""sampler.py: code for sampling from the model.
"""
import abc
import math
import os
import pickle
from time import perf_counter
import torch
import wandb
from easydict import EasyDict
from moses.metrics.metrics import get_all_metrics
from ccsd.src.evaluation.stats import eval_graph_list
from ccsd.src.utils.cc_utils import (
cc_from_incidence,
convert_CC_to_graphs,
convert_graphs_to_CCs,
eval_CC_list,
init_flags,
load_cc_eval_settings,
mols_to_cc,
)
from ccsd.src.utils.graph_utils import (
adjs_to_graphs,
nxs_to_mols,
quantize,
quantize_mol,
)
from ccsd.src.utils.loader import (
load_ckpt,
load_data,
load_device,
load_ema_from_ckpt,
load_eval_settings,
load_model_from_ckpt,
load_sampling_fn,
load_seed,
)
from ccsd.src.utils.logger import (
Logger,
check_log,
device_log,
sample_log,
set_log,
start_log,
time_log,
train_log,
)
from ccsd.src.utils.mol_utils import (
canonicalize_smiles,
gen_mol,
is_molecular_config,
load_smiles,
mols_to_nx,
mols_to_smiles,
)
from ccsd.src.utils.plot import (
diffusion_animation,
plot_3D_molecule,
plot_cc_list,
plot_graphs_list,
plot_molecule_list,
rotate_molecule_animation,
save_cc_list,
save_graph_list,
save_molecule_list,
)
class Sampler(abc.ABC):
    """Abstract class for Sampler objects.

    Concrete samplers store the configuration and implement :meth:`sample`.
    """

    def __init__(self, config: EasyDict) -> None:
        """Initialize the sampler.

        Args:
            config (EasyDict): the config object to use
        """
        self.config = config

    @abc.abstractmethod
    def sample(self) -> None:
        """Sample from the model. Loads the checkpoint, load the modes, generates samples, evaluates, saves and plot them."""
class Sampler_Graph(Sampler):
    """Sampler for generic graph generation tasks
    Adapted from Jo, J. & al (2022)
    """

    def __init__(self, config: EasyDict) -> None:
        """Initialize the sampler with the config and the device.

        Args:
            config (EasyDict): the config object to use
        """
        super().__init__(config)  # stores self.config
        self.device = load_device()
        self.device0 = self.device[0] if isinstance(self.device, list) else self.device
        # Device to compute the score metrics
        if self.device0 == "cpu":
            self.device_score = "cpu"
        elif "cuda" in str(self.device0):
            self.device_score = str(self.device0)
        else:
            # presumably a bare GPU index (e.g. 0) -> "cuda:0"
            self.device_score = f"cuda:{self.device0}"
        # Number of samples requested per sampling round
        self.n_samples = self.config.sample.get(
            "n_samples", self.config.data.get("batch_size", None)
        )
        self.cc_nb_eval = None
        # Worker kwargs for CC eval
        self.worker_kwargs = {
            "min_node_val": self.config.data.min_node_val,
            "max_node_val": self.config.data.max_node_val,
            "node_label": self.config.data.node_label,
            "min_edge_val": self.config.data.min_edge_val,
            "max_edge_val": self.config.data.max_edge_val,
            "edge_label": self.config.data.edge_label,
            "d_min": self.config.data.d_min,
            "d_max": self.config.data.d_max,
            "N": self.config.data.max_node_num,
        }
        # Number of sampling-function calls each round is split into
        self.divide_batch = self.config.sample.get("divide_batch", 1)

    def __repr__(self) -> str:
        """Return the string representation of the sampler."""
        return f"{self.__class__.__name__}(batch_size={self.config.data.batch_size}, cc_nb_eval={self.cc_nb_eval})"

    # -------- Private helpers --------

    def _use_wandb(self) -> bool:
        """Return whether results must be pushed to Weights & Biases."""
        return bool(
            self.config.experiment_type == "train"
            and self.config.general_config.use_wandb
        )

    def _fig_dir(self) -> str:
        """Return the folder where the figures of this sampling run are written."""
        return os.path.join(
            self.config.folder, "samples", "fig", self.log_folder_name
        )

    def _log_wandb_image(self, key: str, img_path: str) -> None:
        """Upload the image at ``img_path`` to wandb under ``key`` when wandb is enabled."""
        if self._use_wandb():
            wandb.log({key: wandb.Image(img_path)})

    def _animate_diffusion(
        self,
        diff_traj,
        filename: str,
        wandb_key: str,
        is_molecule: bool = False,
        cropped: bool = False,
    ) -> None:
        """Render a diffusion-trajectory GIF in the figure folder and log it to wandb.

        Args:
            diff_traj: trajectory returned by the sampling function
            filename (str): name of the GIF file to write
            wandb_key (str): wandb key used to log the GIF
            is_molecule (bool): forwarded to ``diffusion_animation``
            cropped (bool): when True, ``cropped=True`` is forwarded
        """
        filedir = self._fig_dir()
        anim_kwargs = dict(
            diff_traj=diff_traj,
            is_molecule=is_molecule,
            filedir=filedir,
            filename=filename,
            fps=25,
            overwrite=True,
            engine=self.config.general_config.engine,
        )
        if cropped:
            # Only forwarded when requested; uncropped calls rely on the default
            anim_kwargs["cropped"] = True
        diffusion_animation(**anim_kwargs)
        self._log_wandb_image(wandb_key, os.path.join(filedir, filename))

    def _samples_per_round(self, qty_to_generate) -> int:
        """Return the number of samples one sampling round is expected to produce.

        A round makes ``divide_batch`` calls of ``qty_to_generate`` samples
        (assuming ``init_flags`` sizes each batch by that quantity). When no
        positive quantity is known, the checkpoint batch size is used, which
        matches the previous round computation.
        """
        if qty_to_generate:
            return qty_to_generate * self.divide_batch
        return self.configt.data.batch_size

    def _sample_batch(self, qty_to_generate):
        """Generate one round of samples.

        The sampling function is called ``divide_batch`` times, each time with
        freshly initialised flags, and the outputs are concatenated along dim 0.

        Args:
            qty_to_generate: quantity forwarded to ``init_flags`` for each call
                (may be None)

        Returns:
            tuple: ``(x, adj, diff_traj)``; ``diff_traj`` is taken from the first call.
        """
        self.init_flags = init_flags(
            self.train_graph_list, self.configt, qty_to_generate
        ).to(self.device0)
        x, adj, _, diff_traj = self.sampling_fn(
            self.model_x, self.model_adj, self.init_flags
        )
        for _ in range(1, self.divide_batch):
            self.init_flags = init_flags(
                self.train_graph_list, self.configt, qty_to_generate
            ).to(self.device0)
            x_, adj_, _, _ = self.sampling_fn(
                self.model_x, self.model_adj, self.init_flags
            )
            x = torch.cat((x, x_), dim=0)
            adj = torch.cat((adj, adj_), dim=0)
        return x, adj, diff_traj

    def sample(self) -> None:
        """Sample from the model. Loads the checkpoint, load the modes, generates samples, evaluates, saves and plot them."""
        # -------- Load checkpoint --------
        self.ckpt_dict = load_ckpt(self.config, self.device, is_cc=False)
        self.configt = self.ckpt_dict["config"]
        load_seed(self.configt.seed)
        self.train_graph_list, self.test_graph_list = load_data(
            self.configt, get_list=True
        )
        self.log_folder_name, self.log_dir, _ = set_log(
            self.configt, is_train=False, folder=self.config.folder
        )
        self.log_name = f"{self.config.config_name}_{self.config.ckpt}-sample_{self.config.current_time}"
        logger = Logger(
            str(os.path.join(self.log_dir, f"{self.log_name}.log")), mode="a"
        )
        if not check_log(self.log_folder_name, self.log_name):
            logger.log(f"{self.log_name}")
            device_log(logger, self.device)
            start_log(logger, self.configt)
            train_log(logger, self.configt)
        sample_log(logger, self.config)

        # -------- Load models --------
        self.model_x = load_model_from_ckpt(
            self.ckpt_dict["params_x"], self.ckpt_dict["x_state_dict"], self.device
        )
        self.model_adj = load_model_from_ckpt(
            self.ckpt_dict["params_adj"], self.ckpt_dict["adj_state_dict"], self.device
        )
        for m in [self.model_x, self.model_adj]:
            logger.log(
                f"Model {m.__class__.__name__} loaded on {next(m.parameters()).device.type}"
            )
        if self.config.sample.use_ema:
            # Overwrite the model weights with their EMA counterparts
            self.ema_x = load_ema_from_ckpt(
                self.model_x, self.ckpt_dict["ema_x"], self.configt.train.ema
            )
            self.ema_adj = load_ema_from_ckpt(
                self.model_adj, self.ckpt_dict["ema_adj"], self.configt.train.ema
            )
            self.ema_x.copy_to(self.model_x.parameters())
            self.ema_adj.copy_to(self.model_adj.parameters())
        self.sampling_fn = load_sampling_fn(
            self.configt,
            self.config.sampler,
            self.config.sample,
            self.device,
            divide_batch=self.divide_batch,
        )

        # -------- Generate samples --------
        logger.log(f"GEN SEED: {self.config.sample.seed}")
        load_seed(self.config.sample.seed)
        # Loop-invariant: size of each of the `divide_batch` calls of a round
        qty_to_generate = (
            self.n_samples // self.divide_batch if self.n_samples is not None else None
        )
        # Derive the number of rounds from what a round actually produces so
        # that at least len(test_graph_list) graphs are generated.
        samples_per_round = self._samples_per_round(qty_to_generate)
        num_sampling_rounds = math.ceil(
            len(self.test_graph_list) / samples_per_round
        )
        gen_graph_list = []
        logger.log(
            f"Number sampling rounds: {num_sampling_rounds}, number of samples per round: {samples_per_round}"
        )
        start_sampling_time = perf_counter()
        for r in range(num_sampling_rounds):
            t_start = perf_counter()
            _, adj, diff_traj = self._sample_batch(qty_to_generate)
            logger.log(f"Round {r} : {perf_counter()-t_start:.2f}s")
            samples_int = quantize(adj)
            gen_graph_list.extend(adjs_to_graphs(samples_int, True))
        # Drop surplus samples beyond the size of the test set
        gen_graph_list = gen_graph_list[: len(self.test_graph_list)]
        print("Sampling done.")
        sampling_time = perf_counter() - start_sampling_time
        time_log(logger, time_type="sample", elapsed_time=sampling_time)
        if self._use_wandb():
            wandb.log({"Sampling time": sampling_time})

        # -------- Evaluation --------
        # Eval graphs
        methods, kernels = load_eval_settings(self.config.data.data)
        result_dict_graph = eval_graph_list(
            self.test_graph_list,
            gen_graph_list,
            methods=methods,
            kernels=kernels,
            folder=self.config.folder,
        )
        # Eval lifted CCs: lift both the test and the generated graphs into CCs
        lifting_kwargs = dict(
            is_molecule=False,
            lifting_procedure=self.config.data.lifting_procedure,
            lifting_procedure_kwargs=self.config.data.lifting_procedure_kwargs,
            max_nb_nodes=self.config.data.max_node_num,
        )
        self.test_CC_list = convert_graphs_to_CCs(
            self.test_graph_list, **lifting_kwargs
        )
        gen_CC_list = convert_graphs_to_CCs(gen_graph_list, **lifting_kwargs)
        cc_methods, cc_kernels = load_cc_eval_settings()
        result_dict_CC = eval_CC_list(
            self.test_CC_list,
            gen_CC_list,
            worker_kwargs=self.worker_kwargs,
            methods=cc_methods,
            kernels=cc_kernels,
            cc_nb_eval=self.cc_nb_eval,
        )
        # verbose=False: the results were already printed by the eval functions
        logger.log(f"CCs Eval @{self.cc_nb_eval} {result_dict_CC}", verbose=False)
        logger.log(f"MMD_full {result_dict_graph}", verbose=False)
        logger.log(100 * "=")
        if self._use_wandb():
            # add scores to wandb (graph and CC metrics, as in Sampler_CC)
            wandb.log(result_dict_graph)
            wandb.log(result_dict_CC)

        # -------- Save samples & Plot --------
        # Graphs
        save_dir = save_graph_list(
            self.config, self.log_folder_name, self.log_name + "_graphs", gen_graph_list
        )
        with open(save_dir, "rb") as f:
            sample_graph_list = pickle.load(f)  # file written just above
        plot_graphs_list(
            config=self.config,
            graphs=sample_graph_list,
            title=f"graphs_{self.log_name}",
            max_num=16,
            save_dir=self.log_folder_name,
        )
        self._log_wandb_image(
            "Generated Graphs",
            os.path.join(self._fig_dir(), f"graphs_{self.log_name}.png"),
        )
        # Diffusion trajectory animation (trajectory of the last round)
        if self.config.general_config.plotly_fig:
            self._animate_diffusion(
                diff_traj,
                f"diff_traj_graphs_{self.log_name}.gif",
                "Diffusion Trajectory Graph",
            )
            # Cropped
            self._animate_diffusion(
                diff_traj,
                f"diff_traj_graphs_cropped_{self.log_name}.gif",
                "Diffusion Trajectory Graph Cropped",
                cropped=True,
            )
class Sampler_CC(Sampler):
    """Sampler for generic combinatorial complexes generation tasks
    Adapted from Jo, J. & al (2022)
    """

    def __init__(self, config: EasyDict) -> None:
        """Initialize the sampler with the config and the device.

        Args:
            config (EasyDict): the config object to use
        """
        super().__init__(config)  # stores self.config
        self.device = load_device()
        self.device0 = self.device[0] if isinstance(self.device, list) else self.device
        # Device to compute the score metrics
        if self.device0 == "cpu":
            self.device_score = "cpu"
        elif "cuda" in str(self.device0):
            self.device_score = str(self.device0)
        else:
            # presumably a bare GPU index (e.g. 0) -> "cuda:0"
            self.device_score = f"cuda:{self.device0}"
        # Number of samples requested per sampling round
        self.n_samples = self.config.sample.get(
            "n_samples", self.config.data.get("batch_size", None)
        )
        self.cc_nb_eval = None
        # Worker kwargs for CC eval
        self.worker_kwargs = {
            "min_node_val": self.config.data.min_node_val,
            "max_node_val": self.config.data.max_node_val,
            "node_label": self.config.data.node_label,
            "min_edge_val": self.config.data.min_edge_val,
            "max_edge_val": self.config.data.max_edge_val,
            "edge_label": self.config.data.edge_label,
            "d_min": self.config.data.d_min,
            "d_max": self.config.data.d_max,
            "N": self.config.data.max_node_num,
        }
        # Number of sampling-function calls each round is split into
        self.divide_batch = self.config.sample.get("divide_batch", 1)

    def __repr__(self) -> str:
        """Return the string representation of the sampler."""
        return f"{self.__class__.__name__}(batch_size={self.config.data.batch_size}, cc_nb_eval={self.cc_nb_eval})"

    # -------- Private helpers --------

    def _use_wandb(self) -> bool:
        """Return whether results must be pushed to Weights & Biases."""
        return bool(
            self.config.experiment_type == "train"
            and self.config.general_config.use_wandb
        )

    def _fig_dir(self) -> str:
        """Return the folder where the figures of this sampling run are written."""
        return os.path.join(
            self.config.folder, "samples", "fig", self.log_folder_name
        )

    def _log_wandb_image(self, key: str, img_path: str) -> None:
        """Upload the image at ``img_path`` to wandb under ``key`` when wandb is enabled."""
        if self._use_wandb():
            wandb.log({key: wandb.Image(img_path)})

    def _animate_diffusion(
        self,
        diff_traj,
        filename: str,
        wandb_key: str,
        is_molecule: bool = False,
        cropped: bool = False,
    ) -> None:
        """Render a diffusion-trajectory GIF in the figure folder and log it to wandb.

        Args:
            diff_traj: trajectory returned by the sampling function
            filename (str): name of the GIF file to write
            wandb_key (str): wandb key used to log the GIF
            is_molecule (bool): forwarded to ``diffusion_animation``
            cropped (bool): when True, ``cropped=True`` is forwarded
        """
        filedir = self._fig_dir()
        anim_kwargs = dict(
            diff_traj=diff_traj,
            is_molecule=is_molecule,
            filedir=filedir,
            filename=filename,
            fps=25,
            overwrite=True,
            engine=self.config.general_config.engine,
        )
        if cropped:
            # Only forwarded when requested; uncropped calls rely on the default
            anim_kwargs["cropped"] = True
        diffusion_animation(**anim_kwargs)
        self._log_wandb_image(wandb_key, os.path.join(filedir, filename))

    def _samples_per_round(self, qty_to_generate) -> int:
        """Return the number of samples one sampling round is expected to produce.

        A round makes ``divide_batch`` calls of ``qty_to_generate`` samples
        (assuming ``init_flags`` sizes each batch by that quantity). When no
        positive quantity is known, the checkpoint batch size is used, which
        matches the previous round computation.
        """
        if qty_to_generate:
            return qty_to_generate * self.divide_batch
        return self.configt.data.batch_size

    def _sample_batch(self, qty_to_generate):
        """Generate one round of samples.

        The sampling function is called ``divide_batch`` times, each time with
        freshly initialised flags, and the outputs are concatenated along dim 0.

        Args:
            qty_to_generate: quantity forwarded to ``init_flags`` for each call
                (may be None)

        Returns:
            tuple: ``(x, adj, rank2, diff_traj)``; ``diff_traj`` is taken from the first call.
        """
        self.init_flags = init_flags(
            self.train_CC_list,
            self.configt,
            qty_to_generate,
            is_cc=True,
        ).to(self.device0)
        x, adj, rank2, _, diff_traj = self.sampling_fn(
            self.model_x, self.model_adj, self.model_rank2, self.init_flags
        )
        for _ in range(1, self.divide_batch):
            self.init_flags = init_flags(
                self.train_CC_list,
                self.configt,
                qty_to_generate,
                is_cc=True,
            ).to(self.device0)
            x_, adj_, rank2_, _, _ = self.sampling_fn(
                self.model_x, self.model_adj, self.model_rank2, self.init_flags
            )
            x = torch.cat((x, x_), dim=0)
            adj = torch.cat((adj, adj_), dim=0)
            rank2 = torch.cat((rank2, rank2_), dim=0)
        return x, adj, rank2, diff_traj

    def sample(self) -> None:
        """Sample from the model. Loads the checkpoint, load the modes, generates samples, evaluates, saves and plot them."""
        # -------- Load checkpoint --------
        self.ckpt_dict = load_ckpt(self.config, self.device, is_cc=True)
        self.configt = self.ckpt_dict["config"]
        load_seed(self.configt.seed)
        self.train_CC_list, self.test_CC_list = load_data(
            self.configt, get_list=True, is_cc=True
        )
        self.log_folder_name, self.log_dir, _ = set_log(
            self.configt, is_train=False, folder=self.config.folder
        )
        self.log_name = f"{self.config.config_name}_{self.config.ckpt}-sample_{self.config.current_time}"
        logger = Logger(
            str(os.path.join(self.log_dir, f"{self.log_name}.log")), mode="a"
        )
        if not check_log(self.log_folder_name, self.log_name):
            logger.log(f"{self.log_name}")
            device_log(logger, self.device)
            start_log(logger, self.configt)
            train_log(logger, self.configt)
        sample_log(logger, self.config)

        # -------- Load models --------
        self.model_x = load_model_from_ckpt(
            self.ckpt_dict["params_x"], self.ckpt_dict["x_state_dict"], self.device
        )
        self.model_adj = load_model_from_ckpt(
            self.ckpt_dict["params_adj"], self.ckpt_dict["adj_state_dict"], self.device
        )
        self.model_rank2 = load_model_from_ckpt(
            self.ckpt_dict["params_rank2"],
            self.ckpt_dict["rank2_state_dict"],
            self.device,
        )
        for m in [self.model_x, self.model_adj, self.model_rank2]:
            logger.log(
                f"Model {m.__class__.__name__} loaded on {next(m.parameters()).device.type}"
            )
        if self.config.sample.use_ema:
            # Overwrite the model weights with their EMA counterparts
            self.ema_x = load_ema_from_ckpt(
                self.model_x, self.ckpt_dict["ema_x"], self.configt.train.ema
            )
            self.ema_adj = load_ema_from_ckpt(
                self.model_adj, self.ckpt_dict["ema_adj"], self.configt.train.ema
            )
            self.ema_rank2 = load_ema_from_ckpt(
                self.model_rank2, self.ckpt_dict["ema_rank2"], self.configt.train.ema
            )
            self.ema_x.copy_to(self.model_x.parameters())
            self.ema_adj.copy_to(self.model_adj.parameters())
            self.ema_rank2.copy_to(self.model_rank2.parameters())
        self.sampling_fn = load_sampling_fn(
            self.configt,
            self.config.sampler,
            self.config.sample,
            self.device,
            is_cc=True,
            d_min=self.config.data.d_min,
            d_max=self.config.data.d_max,
            divide_batch=self.divide_batch,
        )

        # -------- Generate samples --------
        logger.log(f"GEN SEED: {self.config.sample.seed}")
        load_seed(self.config.sample.seed)
        # Loop-invariant: size of each of the `divide_batch` calls of a round
        qty_to_generate = (
            self.n_samples // self.divide_batch if self.n_samples is not None else None
        )
        # Derive the number of rounds from what a round actually produces so
        # that at least len(test_CC_list) CCs are generated.
        samples_per_round = self._samples_per_round(qty_to_generate)
        num_sampling_rounds = math.ceil(len(self.test_CC_list) / samples_per_round)
        gen_CC_list = []
        logger.log(
            f"Number sampling rounds: {num_sampling_rounds}, number of samples per round: {samples_per_round}"
        )
        start_sampling_time = perf_counter()
        for r in range(num_sampling_rounds):
            t_start = perf_counter()
            x, adj, rank2, diff_traj = self._sample_batch(qty_to_generate)
            logger.log(f"Round {r} : {perf_counter()-t_start:.2f}s")
            samples_int = quantize(adj)
            sample_int_rank2 = quantize(rank2)
            # Rebuild one CC per sample from its node features and quantized incidences
            gen_CC_list.extend(
                cc_from_incidence(
                    [x_i, adj_i, rank2_i],
                    d_min=self.config.data.d_min,
                    d_max=self.config.data.d_max,
                    is_molecule=False,
                )
                for x_i, adj_i, rank2_i in zip(x, samples_int, sample_int_rank2)
            )
        # Drop surplus samples beyond the size of the test set
        gen_CC_list = gen_CC_list[: len(self.test_CC_list)]
        print("Sampling done.")
        sampling_time = perf_counter() - start_sampling_time
        time_log(logger, time_type="sample", elapsed_time=sampling_time)
        if self._use_wandb():
            wandb.log({"Sampling time": sampling_time})

        # -------- Evaluation --------
        # Convert CC into graphs for evaluation
        self.test_graph_list = convert_CC_to_graphs(self.test_CC_list)
        gen_graph_list = convert_CC_to_graphs(gen_CC_list)
        # Eval graphs
        methods, kernels = load_eval_settings(self.config.data.data)
        result_dict_graph = eval_graph_list(
            self.test_graph_list,
            gen_graph_list,
            methods=methods,
            kernels=kernels,
            folder=self.config.folder,
        )
        # Eval CCs
        cc_methods, cc_kernels = load_cc_eval_settings()
        result_dict_CC = eval_CC_list(
            self.test_CC_list,
            gen_CC_list,
            worker_kwargs=self.worker_kwargs,
            methods=cc_methods,
            kernels=cc_kernels,
            cc_nb_eval=self.cc_nb_eval,
        )
        # verbose=False: the results were already printed by the eval functions
        logger.log(f"MMD_full {result_dict_graph}", verbose=False)
        logger.log(f"CCs eval @{self.cc_nb_eval} {result_dict_CC}", verbose=False)
        logger.log(100 * "=")
        if self._use_wandb():
            # add scores to wandb
            wandb.log(result_dict_graph)
            wandb.log(result_dict_CC)

        # -------- Save samples & Plot --------
        # Ccs
        save_dir = save_cc_list(
            self.config, self.log_folder_name, self.log_name + "_ccs", gen_CC_list
        )
        with open(save_dir, "rb") as f:
            sample_CC_list = pickle.load(f)  # file written just above
        plot_cc_list(
            config=self.config,
            ccs=sample_CC_list,
            title=f"ccs_{self.log_name}",
            max_num=16,
            save_dir=self.log_folder_name,
        )
        self._log_wandb_image(
            "Generated Combinatorial Complexes",
            os.path.join(self._fig_dir(), f"ccs_{self.log_name}.png"),
        )
        # Graphs
        save_dir = save_graph_list(
            self.config, self.log_folder_name, self.log_name + "_graphs", gen_graph_list
        )
        with open(save_dir, "rb") as f:
            sample_graph_list = pickle.load(f)  # file written just above
        plot_graphs_list(
            config=self.config,
            graphs=sample_graph_list,
            title=f"graphs_{self.log_name}",
            max_num=16,
            save_dir=self.log_folder_name,
        )
        self._log_wandb_image(
            "Generated Graphs",
            os.path.join(self._fig_dir(), f"graphs_{self.log_name}.png"),
        )
        # Diffusion trajectory animation (trajectory of the last round)
        if self.config.general_config.plotly_fig:
            self._animate_diffusion(
                diff_traj,
                f"diff_traj_graphs_{self.log_name}.gif",
                "Diffusion Trajectory Graph",
            )
            # Cropped
            self._animate_diffusion(
                diff_traj,
                f"diff_traj_graphs_cropped_{self.log_name}.gif",
                "Diffusion Trajectory Graph Cropped",
                cropped=True,
            )
class Sampler_mol_Graph(Sampler):
    """Sampler for molecule generation tasks"""

    def __init__(self, config: EasyDict) -> None:
        """Initialize the sampler with the config and the device.

        Args:
            config (EasyDict): the config object to use
        """
        super().__init__(config)  # stores self.config
        self.device = load_device()
        self.device0 = self.device[0] if isinstance(self.device, list) else self.device
        # Device to compute the score metrics
        if self.device0 == "cpu":
            self.device_score = "cpu"
        elif "cuda" in str(self.device0):
            self.device_score = str(self.device0)
        else:
            # presumably a bare GPU index (e.g. 0) -> "cuda:0"
            self.device_score = f"cuda:{self.device0}"
        # Total number of molecules to sample
        self.n_samples = self.config.sample.get(
            "n_samples", self.config.data.get("batch_size", None)
        )
        self.cc_nb_eval = self.config.sample.cc_nb_eval
        # Worker kwargs for CC eval
        self.worker_kwargs = {
            "min_node_val": self.config.data.min_node_val,
            "max_node_val": self.config.data.max_node_val,
            "node_label": self.config.data.node_label,
            "min_edge_val": self.config.data.min_edge_val,
            "max_edge_val": self.config.data.max_edge_val,
            "edge_label": self.config.data.edge_label,
            "d_min": self.config.data.d_min,
            "d_max": self.config.data.d_max,
            "N": self.config.data.max_node_num,
        }
        # Number of sampling-function calls the generation is split into
        self.divide_batch = self.config.sample.get("divide_batch", 1)

    def __repr__(self) -> str:
        """Return the string representation of the sampler."""
        return f"{self.__class__.__name__}(n_samples={self.n_samples}, cc_nb_eval={self.cc_nb_eval})"

    # -------- Private helpers --------

    def _use_wandb(self) -> bool:
        """Return whether results must be pushed to Weights & Biases."""
        return bool(
            self.config.experiment_type == "train"
            and self.config.general_config.use_wandb
        )

    def _fig_dir(self) -> str:
        """Return the folder where the figures of this sampling run are written."""
        return os.path.join(
            self.config.folder, "samples", "fig", self.log_folder_name
        )

    def _log_wandb_image(self, key: str, img_path: str) -> None:
        """Upload the image at ``img_path`` to wandb under ``key`` when wandb is enabled."""
        if self._use_wandb():
            wandb.log({key: wandb.Image(img_path)})

    def _animate_diffusion(
        self,
        diff_traj,
        filename: str,
        wandb_key: str,
        is_molecule: bool = False,
        cropped: bool = False,
    ) -> None:
        """Render a diffusion-trajectory GIF in the figure folder and log it to wandb.

        Args:
            diff_traj: trajectory returned by the sampling function
            filename (str): name of the GIF file to write
            wandb_key (str): wandb key used to log the GIF
            is_molecule (bool): forwarded to ``diffusion_animation``
            cropped (bool): when True, ``cropped=True`` is forwarded
        """
        filedir = self._fig_dir()
        anim_kwargs = dict(
            diff_traj=diff_traj,
            is_molecule=is_molecule,
            filedir=filedir,
            filename=filename,
            fps=25,
            overwrite=True,
            engine=self.config.general_config.engine,
        )
        if cropped:
            # Only forwarded when requested; uncropped calls rely on the default
            anim_kwargs["cropped"] = True
        diffusion_animation(**anim_kwargs)
        self._log_wandb_image(wandb_key, os.path.join(filedir, filename))

    def _sample_batch(self, qty_to_generate):
        """Generate the molecule graphs.

        The sampling function is called ``divide_batch`` times, each time with
        freshly initialised flags, and the outputs are concatenated along dim 0.

        Args:
            qty_to_generate: quantity forwarded to ``init_flags`` for each call
                (may be None)

        Returns:
            tuple: ``(x, adj, diff_traj)``; ``diff_traj`` is taken from the first call.
        """
        self.init_flags = init_flags(
            self.train_graph_list, self.configt, qty_to_generate
        ).to(self.device0)
        x, adj, _, diff_traj = self.sampling_fn(
            self.model_x, self.model_adj, self.init_flags
        )
        for _ in range(1, self.divide_batch):
            self.init_flags = init_flags(
                self.train_graph_list, self.configt, qty_to_generate
            ).to(self.device0)
            x_, adj_, _, _ = self.sampling_fn(
                self.model_x, self.model_adj, self.init_flags
            )
            x = torch.cat((x, x_), dim=0)
            adj = torch.cat((adj, adj_), dim=0)
        return x, adj, diff_traj

    def sample(self) -> None:
        """Sample from the model. Loads the checkpoint, load the modes, generates samples, evaluates and saves them."""
        # -------- Load checkpoint --------
        self.ckpt_dict = load_ckpt(self.config, self.device)
        self.configt = self.ckpt_dict["config"]
        load_seed(self.config.seed)
        self.log_folder_name, self.log_dir, _ = set_log(
            self.configt, is_train=False, folder=self.config.folder
        )
        self.log_name = f"{self.config.config_name}_{self.config.ckpt}-sample_{self.config.current_time}"
        logger = Logger(
            str(os.path.join(self.log_dir, f"{self.log_name}.log")), mode="a"
        )
        if not check_log(self.log_folder_name, self.log_name):
            logger.log(f"{self.log_name}")
            device_log(logger, self.device)
            start_log(logger, self.configt)
            train_log(logger, self.configt)
        sample_log(logger, self.config)

        # -------- Load models --------
        self.model_x = load_model_from_ckpt(
            self.ckpt_dict["params_x"], self.ckpt_dict["x_state_dict"], self.device
        )
        self.model_adj = load_model_from_ckpt(
            self.ckpt_dict["params_adj"], self.ckpt_dict["adj_state_dict"], self.device
        )
        for m in [self.model_x, self.model_adj]:
            logger.log(
                f"Model {m.__class__.__name__} loaded on {next(m.parameters()).device.type}"
            )
        # NOTE(review): unlike Sampler_Graph / Sampler_CC, config.sample.use_ema
        # is not honoured here (no EMA weights are copied) - confirm intended.
        self.sampling_fn = load_sampling_fn(
            self.configt,
            self.config.sampler,
            self.config.sample,
            self.device,
            divide_batch=self.divide_batch,
        )

        # -------- Generate samples --------
        logger.log(f"GEN SEED: {self.config.sample.seed}")
        load_seed(self.config.sample.seed)
        # Read the SMILES from the same data folder as the test graphs below
        # (consistent with Sampler_mol_CC).
        train_smiles, test_smiles = load_smiles(
            self.configt.data.data, self.config.folder
        )
        train_smiles = canonicalize_smiles(train_smiles)
        test_smiles = canonicalize_smiles(test_smiles)
        self.train_graph_list, _ = load_data(
            self.configt, get_list=True
        )  # for init_flags
        with open(
            os.path.join(
                self.config.folder,
                "data",
                f"{self.configt.data.data.lower()}_test_nx.pkl",
            ),
            "rb",
        ) as f:
            self.test_graph_list = pickle.load(f)  # for NSPDK MMD
        logger.log(f"Sampling {self.n_samples} samples ...")
        start_sampling_time = perf_counter()
        qty_to_generate = (
            self.n_samples // self.divide_batch if self.n_samples is not None else None
        )
        x, adj, diff_traj = self._sample_batch(qty_to_generate)
        # Re-encode bond orders so that "no bond" becomes the last class:
        # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2
        samples_int = quantize_mol(adj)
        samples_int = samples_int - 1
        samples_int[samples_int == -1] = 3
        # One-hot the bond classes and move the class axis to dim 1
        adj = torch.nn.functional.one_hot(
            torch.tensor(samples_int), num_classes=4
        ).permute(0, 3, 1, 2)
        x = torch.where(x > 0.5, 1, 0)
        # Append one channel equal to 1 - (sum of the other channels),
        # e.g. (32, 9, 4) -> (32, 9, 5)
        x = torch.concat([x, 1 - x.sum(dim=-1, keepdim=True)], dim=-1)
        gen_mols, num_mols_wo_correction = gen_mol(x, adj, self.configt.data.data)
        num_mols = len(gen_mols)
        # Keep only non-empty SMILES strings
        gen_smiles = [smi for smi in mols_to_smiles(gen_mols) if len(smi)]
        # Convert generated molecules into graphs
        gen_graph_list = mols_to_nx(gen_mols)
        print("Sampling done.")
        sampling_time = perf_counter() - start_sampling_time
        time_log(logger, time_type="sample", elapsed_time=sampling_time)
        if self._use_wandb():
            wandb.log({"Sampling time": sampling_time})

        # -------- Save generated molecules --------
        with open(os.path.join(self.log_dir, f"{self.log_name}.txt"), "a") as f:
            f.writelines(f"{smiles}\n" for smiles in gen_smiles)

        # -------- Evaluation --------
        # Eval molecules (n_jobs configurable through config.sample.n_jobs)
        scores = get_all_metrics(
            gen=gen_smiles,
            k=len(gen_smiles),
            device=self.device_score,
            n_jobs=self.config.sample.get("n_jobs", 8),
            test=test_smiles,
            train=train_smiles,
        )
        scores_nspdk = eval_graph_list(
            self.test_graph_list,
            gen_graph_list,
            methods=["nspdk"],
            folder=self.config.folder,
        )["nspdk"]
        # Eval lifted CCs: the test graphs are turned into molecules first so
        # that test and generated CCs are built the same way.
        test_mol_list = nxs_to_mols(self.test_graph_list)
        self.test_CC_list = mols_to_cc(test_mol_list)
        gen_CC_list = mols_to_cc(gen_mols)
        cc_methods, cc_kernels = load_cc_eval_settings()
        result_dict_CC = eval_CC_list(
            self.test_CC_list,
            gen_CC_list,
            worker_kwargs=self.worker_kwargs,
            methods=cc_methods,
            kernels=cc_kernels,
            cc_nb_eval=self.cc_nb_eval,
        )
        # verbose=False: the results were already printed by the eval function
        logger.log(f"CCs Eval @{self.cc_nb_eval} {result_dict_CC}", verbose=False)
        logger.log(f"Number of molecules: {num_mols}")
        logger.log(f"validity w/o correction: {num_mols_wo_correction / num_mols}")
        unique_key = f"unique@{len(gen_smiles)}"
        for metric in ["valid", unique_key, "FCD/Test", "Novelty"]:
            logger.log(f"{metric}: {scores[metric]}")
        logger.log(f"NSPDK MMD: {scores_nspdk}")
        logger.log(100 * "=")
        if self._use_wandb():
            # add scores to wandb (molecule and CC metrics, as in Sampler_CC)
            wandb.log(
                {
                    "validity": scores["valid"],
                    unique_key: scores[unique_key],
                    "FCD/Test": scores["FCD/Test"],
                    "Novelty": scores["Novelty"],
                    "NSPDK MMD": scores_nspdk,
                }
            )
            wandb.log(result_dict_CC)

        # -------- Save samples & Plot --------
        # Graphs
        save_dir = save_graph_list(
            self.config,
            self.log_folder_name,
            self.log_name + "_mol_graphs",
            gen_graph_list,
        )
        with open(save_dir, "rb") as f:
            sample_graph_list = pickle.load(f)  # file written just above
        plot_graphs_list(
            config=self.config,
            graphs=sample_graph_list,
            title=f"mol_graphs_{self.log_name}",
            max_num=16,
            save_dir=self.log_folder_name,
        )
        self._log_wandb_image(
            "Generated Mol Graphs",
            os.path.join(self._fig_dir(), f"mol_graphs_{self.log_name}.png"),
        )
        # Molecules
        save_dir = save_molecule_list(
            self.config, self.log_folder_name, self.log_name + "_mols", gen_mols
        )
        with open(save_dir, "rb") as f:
            sample_mol_list = pickle.load(f)  # file written just above
        plot_molecule_list(
            config=self.config,
            mols=sample_mol_list,
            title=f"mols_{self.log_name}",
            max_num=16,
            save_dir=self.log_folder_name,
        )
        self._log_wandb_image(
            "Generated Molecules",
            os.path.join(self._fig_dir(), f"mols_{self.log_name}.png"),
        )
        if self.config.general_config.plotly_fig:
            # 3D Molecule (first generated molecule)
            mol_3d = plot_3D_molecule(gen_mols[0])
            filedir = self._fig_dir()
            filename = f"mols_3d_{self.log_name}.gif"
            rotate_molecule_animation(
                mol_3d,
                filedir=filedir,
                filename=filename,
                duration=1.0,
                frames=30,
                rotations_per_sec=1.0,
                overwrite=True,
                engine=self.config.general_config.engine,
            )
            self._log_wandb_image(
                "Generated Molecules 3D", os.path.join(filedir, filename)
            )
            # Diffusion trajectory animation - Graphs and molecules
            self._animate_diffusion(
                diff_traj,
                f"diff_traj_graphs_{self.log_name}.gif",
                "Diffusion Trajectory Graph",
            )
            self._animate_diffusion(
                diff_traj,
                f"diff_traj_graphs_cropped_{self.log_name}.gif",
                "Diffusion Trajectory Graph Cropped",
                cropped=True,
            )
            self._animate_diffusion(
                diff_traj,
                f"diff_traj_mol_{self.log_name}.gif",
                "Diffusion Trajectory Molecule",
                is_molecule=True,
            )
            self._animate_diffusion(
                diff_traj,
                f"diff_traj_mol_cropped_{self.log_name}.gif",
                "Diffusion Trajectory Molecule Cropped",
                is_molecule=True,
                cropped=True,
            )
[docs]
class Sampler_mol_CC(Sampler):
    """Sampler for molecule generation tasks with combinatorial complexes.

    Loads the ``x``, ``adj`` and ``rank2`` models from a checkpoint, samples
    molecules, evaluates them (MOSES metrics, NSPDK MMD, combinatorial-complex
    statistics), then saves and plots the generated combinatorial complexes,
    graphs and molecules, optionally logging results and figures to wandb.
    """

    def __init__(self, config: EasyDict) -> None:
        """Initialize the sampler with the config and the device.

        Args:
            config (EasyDict): the config object to use
        """
        super(Sampler_mol_CC, self).__init__(config)
        self.config = config
        self.device = load_device()
        # Main device: the first entry when load_device returns a list
        self.device0 = (
            self.device[0] if isinstance(self.device, list) else self.device
        )
        # Device to compute the score metrics
        if self.device0 == "cpu":
            self.device_score = "cpu"
        elif "cuda" in str(self.device0):
            self.device_score = str(self.device0)
        else:
            # presumably a bare GPU index (e.g. 0) — TODO confirm against load_device
            self.device_score = f"cuda:{self.device0}"
        # Number of samples to generate, falling back to the data batch size
        self.n_samples = self.config.sample.get(
            "n_samples", self.config.data.get("batch_size", None)
        )
        self.cc_nb_eval = self.config.sample.cc_nb_eval
        # Worker kwargs for CC eval
        self.worker_kwargs = {
            "min_node_val": self.config.data.min_node_val,
            "max_node_val": self.config.data.max_node_val,
            "node_label": self.config.data.node_label,
            "min_edge_val": self.config.data.min_edge_val,
            "max_edge_val": self.config.data.max_edge_val,
            "edge_label": self.config.data.edge_label,
            "d_min": self.config.data.d_min,
            "d_max": self.config.data.d_max,
            "N": self.config.data.max_node_num,
        }
        # Number of successive sampling batches the generation is split into
        self.divide_batch = self.config.sample.get("divide_batch", 1)

    def __repr__(self) -> str:
        """Return the string representation of the sampler."""
        return f"{self.__class__.__name__}(n_samples={self.n_samples}, cc_nb_eval={self.cc_nb_eval})"

    # -------- Private helpers --------

    def _use_wandb(self) -> bool:
        """Return whether results are logged to wandb (train experiments with wandb enabled)."""
        return (
            self.config.experiment_type == "train"
        ) and self.config.general_config.use_wandb

    def _fig_dir(self) -> str:
        """Return the directory in which the figures of this sampling run are written."""
        return os.path.join(
            self.config.folder, "samples", "fig", self.log_folder_name
        )

    def _log_image_to_wandb(self, key: str, img_path: str) -> None:
        """Log the image stored at ``img_path`` to wandb under ``key`` when wandb is enabled."""
        if self._use_wandb():
            wandb.log({key: wandb.Image(img_path)})

    def _setup(self) -> Logger:
        """Load the checkpoint, set up logging, and load the models and the sampling function.

        Sets ``ckpt_dict``, ``configt`` (training config stored in the checkpoint),
        ``log_folder_name``, ``log_dir``, ``log_name``, the three models and
        ``sampling_fn`` on the instance.

        Returns:
            Logger: the logger of this sampling run
        """
        # -------- Load checkpoint --------
        self.ckpt_dict = load_ckpt(self.config, self.device, is_cc=True)
        self.configt = self.ckpt_dict["config"]
        load_seed(self.config.seed)
        self.log_folder_name, self.log_dir, _ = set_log(
            self.configt, is_train=False, folder=self.config.folder
        )
        self.log_name = f"{self.config.config_name}_{self.config.ckpt}-sample_{self.config.current_time}"
        logger = Logger(
            str(os.path.join(self.log_dir, f"{self.log_name}.log")), mode="a"
        )
        if not check_log(self.log_folder_name, self.log_name):
            logger.log(f"{self.log_name}")
            device_log(logger, self.device)
            start_log(logger, self.configt)
            train_log(logger, self.configt)
        sample_log(logger, self.config)

        # -------- Load models --------
        self.model_x = load_model_from_ckpt(
            self.ckpt_dict["params_x"], self.ckpt_dict["x_state_dict"], self.device
        )
        self.model_adj = load_model_from_ckpt(
            self.ckpt_dict["params_adj"], self.ckpt_dict["adj_state_dict"], self.device
        )
        self.model_rank2 = load_model_from_ckpt(
            self.ckpt_dict["params_rank2"],
            self.ckpt_dict["rank2_state_dict"],
            self.device,
        )
        for model in (self.model_x, self.model_adj, self.model_rank2):
            logger.log(
                f"Model {model.__class__.__name__} loaded on {next(model.parameters()).device.type}"
            )
        self.sampling_fn = load_sampling_fn(
            self.configt,
            self.config.sampler,
            self.config.sample,
            self.device,
            is_cc=True,
            d_min=self.config.data.d_min,
            d_max=self.config.data.d_max,
            divide_batch=self.divide_batch,
        )
        return logger

    def _load_reference_data(self):
        """Load the reference data used for flag initialization and evaluation.

        Sets ``train_CC_list`` (for ``init_flags``), ``test_graph_list``
        (for NSPDK MMD) and ``test_CC_list`` (for the CC metrics).

        Returns:
            tuple: the canonicalized train SMILES and test SMILES
        """
        train_smiles, test_smiles = load_smiles(
            self.configt.data.data, self.config.folder
        )
        train_smiles = canonicalize_smiles(train_smiles)
        test_smiles = canonicalize_smiles(test_smiles)
        self.train_CC_list, _ = load_data(
            self.configt, get_list=True, is_cc=True
        )  # for init_flags
        test_nx_path = os.path.join(
            self.config.folder,
            "data",
            f"{self.configt.data.data.lower()}_test_nx.pkl",
        )
        # NOTE: pickle from the local dataset folder, assumed trusted
        with open(test_nx_path, "rb") as f:
            self.test_graph_list = pickle.load(f)  # for NSPDK MMD
        # Create test_CC_list based on test_graph_list via a conversion to molecules
        self.test_CC_list = mols_to_cc(nxs_to_mols(self.test_graph_list))
        return train_smiles, test_smiles

    def _sample_batch(self, qty_to_generate):
        """Draw fresh flags from the training CCs and run the sampling function once.

        Args:
            qty_to_generate: number of samples requested from ``init_flags``

        Returns:
            tuple: the 5 outputs of ``sampling_fn`` (x, adj, rank2, _, diff_traj)
        """
        self.init_flags = init_flags(
            self.train_CC_list,
            self.configt,
            qty_to_generate,
            is_cc=True,
        ).to(self.device0)
        return self.sampling_fn(
            self.model_x, self.model_adj, self.model_rank2, self.init_flags
        )

    def _generate(self):
        """Run ``divide_batch`` sampling batches and stack their node/adjacency outputs.

        The rank-2 outputs are not needed downstream and are discarded.

        Returns:
            tuple: concatenated ``x``, concatenated ``adj``, and the diffusion
            trajectory of the first batch
        """
        # NOTE(review): when n_samples is not a multiple of divide_batch, only
        # divide_batch * (n_samples // divide_batch) samples are generated.
        qty_to_generate = (
            self.n_samples // self.divide_batch if self.n_samples is not None else None
        )
        x, adj, _, _, diff_traj = self._sample_batch(qty_to_generate)
        xs, adjs = [x], [adj]
        for _batch in range(1, self.divide_batch):
            x_batch, adj_batch, _, _, _ = self._sample_batch(qty_to_generate)
            xs.append(x_batch)
            adjs.append(adj_batch)
        # Single concatenation instead of re-copying the accumulated tensors each batch
        return torch.cat(xs, dim=0), torch.cat(adjs, dim=0), diff_traj

    def _to_molecules(self, x, adj):
        """Quantize the raw sampled tensors and build molecules from them.

        Args:
            x: sampled node features
            adj: sampled adjacency tensor

        Returns:
            tuple: the molecules and the count returned by ``gen_mol`` as
            "without correction"
        """
        samples_int = quantize_mol(adj)
        samples_int = samples_int - 1
        samples_int[samples_int == -1] = 3  # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2
        adj = torch.nn.functional.one_hot(
            torch.tensor(samples_int), num_classes=4
        ).permute(0, 3, 1, 2)
        x = torch.where(x > 0.5, 1, 0)
        # Append a channel equal to 1 - (number of active feature channels);
        # presumably the "no atom" type expected by gen_mol — TODO confirm
        x = torch.concat(
            [x, 1 - x.sum(dim=-1, keepdim=True)], dim=-1
        )  # e.g. 32, 9, 4 -> 32, 9, 5
        return gen_mol(x, adj, self.configt.data.data)

    def _save_and_plot(self, gen_CC_list, gen_graph_list, gen_mols) -> None:
        """Save the generated CCs, graphs and molecules, then plot them (max_num=16).

        Each list is reloaded from the pickle path returned by its ``save_*_list``
        helper and the reloaded list is plotted.

        Args:
            gen_CC_list: generated combinatorial complexes
            gen_graph_list: generated graphs (from ``mols_to_nx``)
            gen_mols: generated molecules
        """
        fig_dir = self._fig_dir()
        # Ccs
        save_dir = save_cc_list(
            self.config, self.log_folder_name, self.log_name + "_mol_ccs", gen_CC_list
        )
        with open(save_dir, "rb") as f:
            sample_CC_list = pickle.load(f)
        plot_cc_list(
            config=self.config,
            ccs=sample_CC_list,
            title=f"mol_ccs_{self.log_name}",
            max_num=16,
            save_dir=self.log_folder_name,
        )
        self._log_image_to_wandb(
            "Generated Mol Combinatorial Complexes",
            os.path.join(fig_dir, f"mol_ccs_{self.log_name}.png"),
        )
        # Graphs
        save_dir = save_graph_list(
            self.config,
            self.log_folder_name,
            self.log_name + "_mol_graphs",
            gen_graph_list,
        )
        with open(save_dir, "rb") as f:
            sample_graph_list = pickle.load(f)
        plot_graphs_list(
            config=self.config,
            graphs=sample_graph_list,
            title=f"mol_graphs_{self.log_name}",
            max_num=16,
            save_dir=self.log_folder_name,
        )
        self._log_image_to_wandb(
            "Generated Mol Graphs",
            os.path.join(fig_dir, f"mol_graphs_{self.log_name}.png"),
        )
        # Molecules
        save_dir = save_molecule_list(
            self.config, self.log_folder_name, self.log_name + "_mols", gen_mols
        )
        with open(save_dir, "rb") as f:
            sample_mol_list = pickle.load(f)
        plot_molecule_list(
            config=self.config,
            mols=sample_mol_list,
            title=f"mols_{self.log_name}",
            max_num=16,
            save_dir=self.log_folder_name,
        )
        self._log_image_to_wandb(
            "Generated Molecules",
            os.path.join(fig_dir, f"mols_{self.log_name}.png"),
        )

    def _plot_animations(self, gen_mols, diff_traj) -> None:
        """Render the rotating 3D view of the first molecule and the diffusion trajectory GIFs.

        Args:
            gen_mols: generated molecules (only the first one is rendered in 3D)
            diff_traj: diffusion trajectory returned by the sampling function
        """
        filedir = self._fig_dir()
        engine = self.config.general_config.engine
        # 3D Molecule
        mol_3d = plot_3D_molecule(gen_mols[0])
        filename = f"mols_3d_{self.log_name}.gif"
        rotate_molecule_animation(
            mol_3d,
            filedir=filedir,
            filename=filename,
            duration=1.0,
            frames=30,
            rotations_per_sec=1.0,
            overwrite=True,
            engine=engine,
        )
        self._log_image_to_wandb(
            "Generated Molecule 3D", os.path.join(filedir, filename)
        )
        # Diffusion trajectory animation - Graphs and molecules
        # (is_molecule, cropped, filename stem, wandb key)
        animations = (
            (False, False, "diff_traj_graphs", "Diffusion Trajectory Graph"),
            (False, True, "diff_traj_graphs_cropped", "Diffusion Trajectory Graph Cropped"),
            (True, False, "diff_traj_mol", "Diffusion Trajectory Molecule"),
            (True, True, "diff_traj_mol_cropped", "Diffusion Trajectory Molecule Cropped"),
        )
        for is_molecule, cropped, stem, wandb_key in animations:
            filename = f"{stem}_{self.log_name}.gif"
            # Only pass `cropped` when set, exactly as the explicit calls did
            extra_kwargs = {"cropped": True} if cropped else {}
            diffusion_animation(
                diff_traj=diff_traj,
                is_molecule=is_molecule,
                filedir=filedir,
                filename=filename,
                fps=25,
                overwrite=True,
                engine=engine,
                **extra_kwargs,
            )
            self._log_image_to_wandb(wandb_key, os.path.join(filedir, filename))

    def sample(self) -> None:
        """Sample from the model. Loads the checkpoint, load the models, generates samples, evaluates and saves them."""
        logger = self._setup()

        # -------- Generate samples --------
        logger.log(f"GEN SEED: {self.config.sample.seed}")
        load_seed(self.config.sample.seed)
        train_smiles, test_smiles = self._load_reference_data()
        logger.log(f"Sampling {self.n_samples} samples ...")
        start_sampling_time = perf_counter()
        x, adj, diff_traj = self._generate()
        gen_mols, num_mols_wo_correction = self._to_molecules(x, adj)
        num_mols = len(gen_mols)
        gen_smiles = [smi for smi in mols_to_smiles(gen_mols) if smi]
        # Convert generated molecules into graphs and combinatorial complexes
        gen_graph_list = mols_to_nx(gen_mols)
        gen_CC_list = mols_to_cc(gen_mols)
        print("Sampling done.")
        sampling_time = perf_counter() - start_sampling_time
        time_log(logger, time_type="sample", elapsed_time=sampling_time)
        if self._use_wandb():
            wandb.log({"Sampling time": sampling_time})

        # -------- Save generated molecules --------
        smiles_path = os.path.join(self.log_dir, f"{self.log_name}_smiles.txt")
        with open(smiles_path, "a", encoding="utf-8") as f:
            f.writelines(f"{smiles}\n" for smiles in gen_smiles)

        # -------- Evaluation --------
        # Eval molecules
        unique_key = f"unique@{len(gen_smiles)}"
        scores = get_all_metrics(
            gen=gen_smiles,
            k=len(gen_smiles),
            device=self.device_score,
            n_jobs=8,
            test=test_smiles,
            train=train_smiles,
        )
        scores_nspdk = eval_graph_list(
            self.test_graph_list,
            gen_graph_list,
            methods=["nspdk"],
            folder=self.config.folder,
        )["nspdk"]
        # Eval CCs
        methods, kernels = load_cc_eval_settings()
        result_dict_CC = eval_CC_list(
            self.test_CC_list,
            gen_CC_list,
            worker_kwargs=self.worker_kwargs,
            methods=methods,
            kernels=kernels,
            cc_nb_eval=self.cc_nb_eval,
        )
        logger.log(
            f"CCs Eval @{self.cc_nb_eval} {result_dict_CC}", verbose=False
        )  # verbose=False cause already printed
        logger.log(f"Number of molecules: {num_mols}")
        logger.log(f"validity w/o correction: {num_mols_wo_correction / num_mols}")
        for metric in ["valid", unique_key, "FCD/Test", "Novelty"]:
            logger.log(f"{metric}: {scores[metric]}")
        logger.log(f"NSPDK MMD: {scores_nspdk}")
        logger.log(100 * "=")
        if self._use_wandb():
            # add scores to wandb
            wandb.log(
                {
                    "validity": scores["valid"],
                    unique_key: scores[unique_key],
                    "FCD/Test": scores["FCD/Test"],
                    "Novelty": scores["Novelty"],
                    "NSPDK MMD": scores_nspdk,
                }
            )
            wandb.log(result_dict_CC)

        # -------- Save samples & Plot --------
        self._save_and_plot(gen_CC_list, gen_graph_list, gen_mols)
        if self.config.general_config.plotly_fig:
            self._plot_animations(gen_mols, diff_traj)
[docs]
def get_sampler_from_config(
    config: EasyDict,
) -> Sampler:
    """Build the sampler matching the experiment described by ``config``.

    The sampler class is selected on two axes: combinatorial complexes vs.
    graphs (``config.is_cc``) and molecular vs. generic data
    (``is_molecular_config(config)``).

    Args:
        config (EasyDict): configuration file

    Returns:
        Sampler: sampler to use for the experiment
    """
    use_cc = config.is_cc
    molecular = is_molecular_config(config)
    if use_cc:
        sampler_cls = Sampler_mol_CC if molecular else Sampler_CC
    else:
        sampler_cls = Sampler_mol_Graph if molecular else Sampler_Graph
    return sampler_cls(config)