Source code for ccsd.data.utils.numpytupledataset

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""numpytupledataset.py: NumpyTupleDataset class.
Only used in data/preprocess.py.
Original code from MoFlow (under MIT License) https://github.com/calvin-zcx/moflow
Code from Jo, J. et al. (2022)

Left untouched.
"""

import os

import numpy as np
from torch.utils.data import Dataset


class NumpyTupleDataset(Dataset):
    """Dataset that zips several equally long indexable datasets together.

    Indexing returns one element from every underlying dataset, grouped in
    a tuple. Only used in data/preprocess.py.
    Original code from MoFlow (under MIT License)
    https://github.com/calvin-zcx/moflow
    """

    def __init__(self, datasets, transform=None):
        """Store the datasets after checking that they share one length.

        Args:
            datasets: non-empty sequence of indexable objects supporting
                ``len()``; all of them must have the same length.
            transform: optional callable applied to whatever
                ``__getitem__`` is about to return.

        Raises:
            ValueError: if ``datasets`` is empty or the lengths differ.
        """
        if not datasets:
            raise ValueError("no datasets are given")
        # The first dataset defines the reference length for all others.
        length = len(datasets[0])
        for i, dataset in enumerate(datasets):
            if len(dataset) != length:
                raise ValueError(
                    "dataset of the index {} has a wrong length".format(i)
                )
        self._datasets = datasets
        self._length = length
        self.transform = transform

    def __len__(self):
        """Return the common length of the underlying datasets."""
        return self._length

    def __getitem__(self, index):
        """Return the element(s) at ``index`` from every dataset.

        Args:
            index: an int-like index, or a slice / list / ``np.ndarray``
                selecting several elements at once.

        Returns:
            For a slice, list or array index: a list with one tuple per
            selected element. Otherwise: a single tuple holding one value
            per dataset. If ``self.transform`` is truthy, its result on
            that value is returned instead.
        """
        batches = [dataset[index] for dataset in self._datasets]
        if isinstance(index, (slice, list, np.ndarray)):
            # Multi-element access: regroup the per-dataset batches into
            # one tuple per selected element.
            length = len(batches[0])
            batches = [
                tuple(batch[i] for batch in batches) for i in range(length)
            ]
        else:
            batches = tuple(batches)

        if self.transform:
            batches = self.transform(batches)
        return batches

    def get_datasets(self):
        """Return the underlying datasets exactly as given to ``__init__``."""
        return self._datasets

    @classmethod
    def save(cls, filepath, numpy_tuple_dataset):
        """Write the datasets of ``numpy_tuple_dataset`` with ``np.savez``.

        The arrays are passed positionally, so they are stored under the
        keys ``arr_0``, ``arr_1``, ... which is what ``load`` reads back.
        NOTE: per the NumPy documentation ``np.savez`` appends ``.npz`` to
        a filename that lacks it, so pass a path ending in ``.npz`` if the
        same path is later given to ``load``.

        Args:
            filepath: destination path handed to ``np.savez``.
            numpy_tuple_dataset: the ``NumpyTupleDataset`` to save.

        Raises:
            TypeError: if ``numpy_tuple_dataset`` is not a
                ``NumpyTupleDataset``.
        """
        if not isinstance(numpy_tuple_dataset, NumpyTupleDataset):
            raise TypeError(
                "numpy_tuple_dataset is not instance of "
                "NumpyTupleDataset, got {}".format(type(numpy_tuple_dataset))
            )

        np.savez(filepath, *numpy_tuple_dataset._datasets)
        print("Save {} done.".format(filepath))

    @classmethod
    def load(cls, filepath, transform=None):
        """Build a dataset from an archive written by ``save``.

        Arrays are read from the keys ``arr_0``, ``arr_1``, ... in order,
        stopping at the first missing key.

        Args:
            filepath: path of the archive to read.
            transform: optional transform forwarded to the constructor.

        Returns:
            A new instance of ``cls`` wrapping the loaded arrays.

        Raises:
            ValueError: if ``filepath`` does not exist, or (from the
                constructor) if the archive holds no ``arr_0`` entry.
        """
        print("Loading file {}".format(filepath))
        if not os.path.exists(filepath):
            raise ValueError("Invalid filepath {} for dataset".format(filepath))

        result = []
        # The context manager closes the archive's file handle once every
        # array has been read into memory (previously it was left open).
        with np.load(filepath) as load_data:
            stored_keys = load_data.files
            key = "arr_0"
            while key in stored_keys:
                result.append(load_data[key])
                key = "arr_{}".format(len(result))
        return cls(result, transform)