Source code for ccsd.src.utils.data_loader_mol

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""data_loader_mol.py: utility functions for loading the graph data (molecular ones).

Only dataloader_mol is kept largely unchanged from Jo, J. et al. (2022).
"""

import json
import os
from time import perf_counter
from typing import Any, Callable, List, Tuple, Union

import networkx as nx
import numpy as np
import torch
from easydict import EasyDict
from toponetx.classes.combinatorial_complex import CombinatorialComplex
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from ccsd.data.data_generators import load_dataset, save_dataset
from ccsd.src.utils.cc_utils import (
    cc_from_incidence,
    create_incidence_1_2,
    get_all_mol_rings,
    get_mol_from_x_adj,
)


def _arrays_to_pairs(load_data: Any) -> List[Tuple[Any, Any]]:
    """Zip the ``arr_0`` and ``arr_1`` entries of loaded data element-wise.

    Args:
        load_data (Any): object returned by ``np.load``: either a mapping with
            ``arr_{i}`` keys (e.g. an ``.npz`` archive) or an ``np.ndarray``
            wrapping such a mapping (unwrapped with ``.item()``).

    Raises:
        ValueError: if ``arr_0`` or ``arr_1`` is missing.

    Returns:
        List[Tuple[Any, Any]]: list of ``(arr_0[k], arr_1[k])`` tuples, truncated
            to the shorter of the two arrays.
    """
    if isinstance(load_data, np.ndarray):
        # a dict saved with np.save comes back wrapped in an object array
        load_data = load_data.item()
    keys = load_data.keys()
    missing = [key for key in ("arr_0", "arr_1") if key not in keys]
    if missing:
        raise ValueError(
            f"Dataset must contain arrays 'arr_0' and 'arr_1'; missing {missing}"
        )
    # reading the entries materializes them in memory, so they stay valid
    # after the underlying file is closed
    node_features = load_data["arr_0"]
    adjacencies = load_data["arr_1"]
    return list(zip(node_features, adjacencies))


def load_mol(filepath: str) -> List[Tuple[Any, Any]]:
    """Load molecular dataset from filepath.

    Adapted from GraphEBM.

    Args:
        filepath (str): filepath to the dataset

    Raises:
        ValueError: raise an error if the filepath is invalid, or if the file
            does not contain both ``arr_0`` and ``arr_1``

    Returns:
        List[Tuple[Any, Any]]: list of tuples of (node features, adjacency matrix)
    """
    print(f"Loading file {filepath}")
    if not os.path.exists(filepath):
        raise ValueError(f"Invalid filepath {filepath} for dataset")

    # NOTE(review): allow_pickle=True can execute arbitrary code embedded in a
    # crafted file; only load trusted dataset files.
    try:
        load_data = np.load(filepath, allow_pickle=True)
    except Exception:
        # Fallback: retry through an explicit binary handle. An .npz archive is
        # read lazily from the handle, so the arrays must be extracted BEFORE
        # the `with` block closes it.
        with open(filepath, "rb") as f:
            return _arrays_to_pairs(np.load(f, allow_pickle=True))

    try:
        return _arrays_to_pairs(load_data)
    finally:
        # NpzFile keeps its file open until closed; plain ndarrays have no close()
        close = getattr(load_data, "close", None)
        if callable(close):
            close()
class MolDataset(Dataset):
    """Dataset object for molecular dataset.

    Stores raw (node features, adjacency matrix) pairs and applies the
    user-supplied ``transform`` lazily, each time an item is requested.
    """

    def __init__(
        self,
        mols: List[Tuple[np.ndarray, np.ndarray]],
        transform: Union[
            Callable[
                [Tuple[np.ndarray, np.ndarray]], Tuple[torch.Tensor, torch.Tensor]
            ],
            Callable[
                [Tuple[np.ndarray, np.ndarray]],
                Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            ],
        ],
    ) -> None:
        """Keep the raw molecules and the per-item transform.

        Args:
            mols (List[Tuple[np.ndarray, np.ndarray]]): raw molecules as
                (node features, adjacency matrix) tuples
            transform (Callable): maps one raw tuple to tensors: two tensors
                for graph-based modelisation, three tensors (with the rank-2
                incidence matrix) for combinatorial complex-based modelisation
        """
        self.transform = transform
        self.mols = mols

    def __len__(self) -> int:
        """Return the number of molecules stored in the dataset.

        Returns:
            int: length of the dataset
        """
        return len(self.mols)

    def __getitem__(
        self, idx: int
    ) -> Union[
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        """Transform and return the molecule stored at ``idx``.

        Args:
            idx (int): index of the item

        Returns:
            Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
                whatever ``transform`` returns for that molecule: (node features,
                adjacency matrix) or (node features, adjacency matrix, rank2
                incidence matrix)
        """
        raw_pair = self.mols[idx]
        return self.transform(raw_pair)

    def __repr__(self) -> str:
        """Return the class name as the string representation.

        Returns:
            str: the string representation of the MolDataset class
        """
        return type(self).__name__
# Atomic-number vocabulary of the ZINC250k node features, in one-hot channel order:
# C, N, O, F, P, S, Cl, Br, I, and 0 for virtual (padding) nodes.
_ZINC250K_ATOMIC_NUMS = [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]


def _qm9_node_features(x: np.ndarray) -> torch.Tensor:
    """One-hot encode QM9 atomic numbers and drop the virtual-node channel.

    Atomic numbers 6, 7, 8, 9 (C, N, O, F) map to channels 0-3. Every value
    below 6 (e.g. the 0 used for virtual nodes) maps to channel 4, which is then
    removed, so virtual nodes become all-zero rows.

    Args:
        x (np.ndarray): 1-D array of atomic numbers, one per (possibly virtual) node

    Returns:
        torch.Tensor: float32 tensor of shape (len(x), 4)
    """
    num_nodes = len(x)
    one_hot = np.zeros((num_nodes, 5))
    channels = np.where(x >= 6, x - 6, 4)
    one_hot[np.arange(num_nodes), channels] = 1
    return torch.tensor(one_hot).to(torch.float32)[:, :-1]


def _zinc250k_node_features(x: np.ndarray) -> torch.Tensor:
    """One-hot encode ZINC250k atomic numbers and drop the virtual-node channel.

    Args:
        x (np.ndarray): 1-D array of atomic numbers, one per (possibly virtual)
            node; each value must appear in ``_ZINC250K_ATOMIC_NUMS``

    Raises:
        ValueError: if an atomic number is outside the vocabulary (from ``list.index``)

    Returns:
        torch.Tensor: float32 tensor of shape (len(x), 9)
    """
    one_hot = np.zeros((len(x), len(_ZINC250K_ATOMIC_NUMS)), dtype=np.float32)
    for row, atomic_num in enumerate(x):
        one_hot[row, _ZINC250K_ATOMIC_NUMS.index(atomic_num)] = 1.0
    # the last channel marks virtual nodes; dropping it leaves them as zero rows
    return torch.tensor(one_hot)[:, :-1]


def _bond_type_matrix(adj: np.ndarray) -> torch.Tensor:
    """Collapse one-hot bond channels into a single bond-order matrix.

    Channels ``adj[0]``, ``adj[1]`` and ``adj[2]`` are taken as single, double and
    triple bonds. A fourth "no bond" channel is built as their complement. The
    per-entry argmax channel 0, 1, 2, 3 is remapped to 1, 2, 3, 0, so no-bond
    (including virtual) entries are 0.

    Args:
        adj (np.ndarray): channel-first bond tensor; only its first 3 channels are read

    Returns:
        torch.Tensor: float32 matrix with values in {0, 1, 2, 3}
    """
    bonds = adj[:3]
    no_bond = 1 - np.sum(bonds, axis=0, keepdims=True)
    channels = np.concatenate([bonds, no_bond], axis=0).astype(np.float32)
    bond_idx = torch.tensor(channels.argmax(axis=0))
    return torch.where(bond_idx == 3, 0, bond_idx + 1).to(torch.float32)


def _rank2_incidence(
    x: torch.Tensor, adj: torch.Tensor, d_min: Any, d_max: Any
) -> torch.Tensor:
    """Build the rank-1/rank-2 incidence matrix of a molecule.

    The molecule is rebuilt with ``get_mol_from_x_adj``. Its rings, from
    ``get_all_mol_rings``, are passed as the rank-2 cells to ``create_incidence_1_2``.

    Args:
        x (torch.Tensor): node features (one row per node)
        adj (torch.Tensor): bond-order matrix
        d_min (Any): forwarded to ``create_incidence_1_2``
        d_max (Any): forwarded to ``create_incidence_1_2``

    Returns:
        torch.Tensor: float32 rank2 incidence matrix
    """
    mol = get_mol_from_x_adj(x, adj)
    rings = {ring: {} for ring in get_all_mol_rings(mol)}  # convert to dict
    rank2 = create_incidence_1_2(x.shape[0], adj, d_min, d_max, two_rank_cells=rings)
    return torch.tensor(rank2).to(torch.float32)


def get_transform_fn(
    dataset: str,
    is_cc: bool = False,
    **kwargs: Any,
) -> Union[
    Callable[[Tuple[np.ndarray, np.ndarray]], Tuple[torch.Tensor, torch.Tensor]],
    Callable[
        [Tuple[np.ndarray, np.ndarray]], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ],
]:
    """Get the transform function for the given dataset.

    The returned function turns one raw ``(node atomic numbers, bond channels)``
    pair into ``(node features, bond-order matrix)``. For combinatorial-complex
    modelisation it returns ``(node features, bond-order matrix, rank2 incidence
    matrix)``. Node features are one-hot over the dataset's atom vocabulary,
    with virtual nodes as zero rows. Bond orders are 1/2/3, and 0 means no bond.

    Args:
        dataset (str): name of the dataset ("QM9" or "ZINC250k")
        is_cc (bool, optional): if True, the transform function returns three tensors
            for combinatorial complexes modelisation. Defaults to False.
        **kwargs: ``d_min`` and ``d_max`` are required when ``is_cc`` is True and are
            forwarded to the rank-2 incidence construction.

    Raises:
        ValueError: raise an error if the dataset is invalid/unsupported
        KeyError: if ``is_cc`` is True and ``d_min``/``d_max`` are not given

    Returns:
        Union[Callable[[Tuple[np.ndarray, np.ndarray]], Tuple[torch.Tensor, torch.Tensor]], Callable[[Tuple[np.ndarray, np.ndarray]], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
            transform function that transforms the data into tensors with some preprocessing.
            Two tensors for graph-based modelisation and three tensors for combinatorial
            complex-based modelisation.
    """
    node_encoders = {
        "QM9": _qm9_node_features,
        "ZINC250k": _zinc250k_node_features,
    }
    if dataset not in node_encoders:
        raise ValueError(f"Invalid dataset {dataset}")
    encode_nodes = node_encoders[dataset]

    if not is_cc:

        def transform_graph(
            data: Tuple[np.ndarray, np.ndarray],
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """Transform (node matrix, adj matrix) into (node features, adjacency) tensors."""
            x, adj = data
            return encode_nodes(x), _bond_type_matrix(adj)

        return transform_graph

    d_min = kwargs["d_min"]
    d_max = kwargs["d_max"]

    def transform_cc(
        data: Tuple[np.ndarray, np.ndarray],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform (node matrix, adj matrix) into (node features, adjacency, rank2 incidence) tensors."""
        x, adj = data
        x, adj = encode_nodes(x), _bond_type_matrix(adj)
        return x, adj, _rank2_incidence(x, adj, d_min, d_max)

    return transform_cc
def _mols_to_nx_graphs(dataset: Dataset) -> List[nx.Graph]:
    """Convert every item of a graph-mode MolDataset into a NetworkX graph.

    Each item is expected to be ``(node features, adjacency matrix)``. Only the
    adjacency matrix is used, converted with ``nx.from_numpy_array``.
    ``nx.from_numpy_matrix`` was removed in NetworkX 3.0.

    Args:
        dataset (Dataset): dataset whose items are 2-tuples

    Returns:
        List[nx.Graph]: one graph per dataset item, in index order
    """
    graphs = []
    for i in tqdm(range(len(dataset))):
        _, adj = dataset[i]
        graphs.append(nx.from_numpy_array(np.array(adj)))
    return graphs


def dataloader_mol(
    config: EasyDict, get_graph_list: bool = False
) -> Union[Tuple[DataLoader, DataLoader], Tuple[List[nx.Graph], List[nx.Graph]]]:
    """Load the dataset and return the train and test dataloader for the given molecular dataset.

    If ``<folder>/<data.dir>/<data>_graphs_<get_graph_list>_train.pkl`` exists, the
    cached train/test objects are loaded and returned as-is. Cache files are only
    written on the ``get_graph_list=True`` path. Otherwise, molecules are read
    from ``<data>_kekulized.npz`` and split into train/test using the indices in
    ``valid_idx_<data>.json``; every index not listed there is a training molecule.

    Args:
        config (EasyDict): configuration to use
        get_graph_list (bool, optional): if True, the dataloader are lists of graphs. Defaults to False.

    Returns:
        Union[Tuple[DataLoader, DataLoader], Tuple[List[nx.Graph], List[nx.Graph]]]: train and test dataloader (tensors or lists of graphs)
    """
    dataset_name = f"{config.data.data}_graphs_{get_graph_list}"
    data_dir = os.path.join(config.folder, config.data.dir)
    if os.path.exists(os.path.join(data_dir, f"{dataset_name}_train.pkl")):
        # Load the data
        print("Loading existing files...")
        train = load_dataset(data_dir=data_dir, file_name=f"{dataset_name}_train")
        test = load_dataset(data_dir=data_dir, file_name=f"{dataset_name}_test")
        return train, test

    # If the data does not exist, create it
    start_time = perf_counter()
    data_prefix = config.data.data.lower()
    mols = load_mol(os.path.join(data_dir, f"{data_prefix}_kekulized.npz"))
    with open(
        os.path.join(data_dir, f"valid_idx_{data_prefix}.json"), encoding="utf-8"
    ) as f:
        test_idx = json.load(f)
    if config.data.data == "QM9":  # process QM9 differently
        test_idx = test_idx["valid_idxs"]
    # set for O(1) membership tests; note test molecules follow set iteration order
    test_idx = {int(i) for i in test_idx}
    train_idx = [i for i in range(len(mols)) if i not in test_idx]
    print(
        f"Number of training mols: {len(train_idx)} | Number of test mols: {len(test_idx)}"
    )
    train_mols = [mols[i] for i in train_idx]
    test_mols = [mols[i] for i in test_idx]

    # Create MolDataset objects
    train_dataset = MolDataset(train_mols, get_transform_fn(config.data.data))
    test_dataset = MolDataset(test_mols, get_transform_fn(config.data.data))

    if get_graph_list:
        print("Loading train graphs...")
        train_mols_nx = _mols_to_nx_graphs(train_dataset)
        print("Loading test graphs...")
        test_mols_nx = _mols_to_nx_graphs(test_dataset)
        save_dataset(
            data_dir=data_dir,
            obj=train_mols_nx,
            save_name=f"{dataset_name}_train",
            save_txt=False,
        )
        save_dataset(
            data_dir=data_dir,
            obj=test_mols_nx,
            save_name=f"{dataset_name}_test",
            save_txt=False,
        )
        print(f"{perf_counter() - start_time:.2f} sec elapsed for data loading")
        return train_mols_nx, test_mols_nx

    print("Loading train dataloader...")
    train_dataloader = DataLoader(
        train_dataset, batch_size=config.data.batch_size, shuffle=True
    )
    print("Loading test dataloader...")
    # NOTE(review): the test loader is shuffled as well; confirm this is intended
    test_dataloader = DataLoader(
        test_dataset, batch_size=config.data.batch_size, shuffle=True
    )
    print(f"{perf_counter() - start_time:.2f} sec elapsed for data loading")
    return train_dataloader, test_dataloader
def dataloader_mol_cc(
    config: EasyDict, get_cc_list: bool = False
) -> Union[
    Tuple[DataLoader, DataLoader],
    Tuple[List[CombinatorialComplex], List[CombinatorialComplex]],
]:
    """Build train/test data for a molecular dataset as combinatorial complexes.

    A cached ``<data>_cc_<get_cc_list>_train.pkl`` / ``..._test`` pair under
    ``<folder>/<data.dir>`` is returned directly when present. Otherwise, the
    molecules in ``<data>_kekulized.npz`` are split with the indices in
    ``valid_idx_<data>.json``, and each split is wrapped in a MolDataset that
    yields (node features, adjacency, rank2 incidence) tensors.

    Args:
        config (EasyDict): configuration to use (reads ``folder``, ``data.data``,
            ``data.dir``, ``data.d_min``, ``data.d_max``, ``data.batch_size``)
        get_cc_list (bool, optional): if True, return (and cache) lists of
            combinatorial complexes instead of dataloaders. Defaults to False.

    Returns:
        Union[Tuple[DataLoader, DataLoader], Tuple[List[CombinatorialComplex], List[CombinatorialComplex]]]:
            train and test dataloader (tensors or lists of combinatorial complexes)
    """
    data_name = config.data.data
    dataset_name = f"{data_name}_cc_{get_cc_list}"
    data_dir = os.path.join(config.folder, config.data.dir)

    # Short-circuit on a previously saved split
    cached_train = os.path.join(data_dir, f"{dataset_name}_train.pkl")
    if os.path.exists(cached_train):
        print("Loading existing files...")
        return (
            load_dataset(data_dir=data_dir, file_name=f"{dataset_name}_train"),
            load_dataset(data_dir=data_dir, file_name=f"{dataset_name}_test"),
        )

    start_time = perf_counter()
    prefix = data_name.lower()
    mols = load_mol(os.path.join(data_dir, f"{prefix}_kekulized.npz"))
    with open(os.path.join(data_dir, f"valid_idx_{prefix}.json")) as handle:
        raw_idx = json.load(handle)
    if data_name == "QM9":
        # the QM9 index file nests the indices under "valid_idxs"
        raw_idx = raw_idx["valid_idxs"]
    test_idx = {int(i) for i in raw_idx}

    train_mols = [mol for i, mol in enumerate(mols) if i not in test_idx]
    print(
        f"Number of training mols: {len(train_mols)} | Number of test mols: {len(test_idx)}"
    )
    test_mols = [mols[i] for i in test_idx]

    def build_dataset(subset):
        # a fresh transform per split, with the CC-specific size bounds
        return MolDataset(
            subset,
            get_transform_fn(
                data_name,
                is_cc=True,
                d_min=config.data.d_min,
                d_max=config.data.d_max,
            ),
        )

    train_dataset = build_dataset(train_mols)
    test_dataset = build_dataset(test_mols)

    if not get_cc_list:
        print("Loading train dataloader...")
        train_dataloader = DataLoader(
            train_dataset, batch_size=config.data.batch_size, shuffle=True
        )
        print("Loading test dataloader...")
        test_dataloader = DataLoader(
            test_dataset, batch_size=config.data.batch_size, shuffle=True
        )
        print(f"{perf_counter() - start_time:.2f} sec elapsed for data loading")
        return train_dataloader, test_dataloader

    d_min, d_max = config.data.d_min, config.data.d_max

    def to_complexes(dataset):
        # rebuild one CombinatorialComplex per molecule from its incidence tensors
        complexes = []
        for position in tqdm(range(len(dataset))):
            x, adj, rank2 = dataset[position]
            complexes.append(
                cc_from_incidence([x, adj, rank2], d_min, d_max, is_molecule=True)
            )
        return complexes

    print("Loading train combinatorial complexes...")
    train_mols_cc = to_complexes(train_dataset)
    print("Loading test combinatorial complexes...")
    test_mols_cc = to_complexes(test_dataset)

    save_dataset(
        data_dir=data_dir,
        obj=train_mols_cc,
        save_name=f"{dataset_name}_train",
        save_txt=False,
    )
    save_dataset(
        data_dir=data_dir,
        obj=test_mols_cc,
        save_name=f"{dataset_name}_test",
        save_txt=False,
    )
    print(f"{perf_counter() - start_time:.2f} sec elapsed for data loading")
    return train_mols_cc, test_mols_cc