#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""numpytupledataset.py: NumpyTupleDataset class.
Just used in data/preprocess.py.
Original code from MoFlow (under MIT License) https://github.com/calvin-zcx/moflow
Code from Jo, J. et al. (2022)
Left untouched.
"""
import os
import numpy as np
from torch.utils.data import Dataset
class NumpyTupleDataset(Dataset):
    """Dataset that groups several equal-length datasets into tuples.

    Item ``i`` of this dataset is the tuple made of item ``i`` of every
    underlying dataset. Used in data/preprocess.py.

    Original code from MoFlow (under MIT License)
    https://github.com/calvin-zcx/moflow
    """

    def __init__(self, datasets, transform=None):
        """Store the underlying datasets after checking their lengths agree.

        Args:
            datasets: Non-empty sequence of indexable, sized objects (for
                example NumPy arrays) that all have the same ``len``.
            transform: Optional callable applied to whatever
                ``__getitem__`` is about to return.

        Raises:
            ValueError: If ``datasets`` is empty, or if one of the datasets
                does not have the same length as the first one.
        """
        if not datasets:
            raise ValueError("no datasets are given")
        length = len(datasets[0])
        for i, dataset in enumerate(datasets):
            if len(dataset) != length:
                raise ValueError(
                    "dataset of the index {} has a wrong length".format(i)
                )
        self._datasets = datasets
        self._length = length
        self.transform = transform

    def __len__(self):
        """Return the common length of the underlying datasets."""
        return self._length

    def __getitem__(self, index):
        """Return the example(s) selected by ``index``.

        Args:
            index: A single index, or a ``slice`` / ``list`` /
                ``np.ndarray`` selecting several examples at once.

        Returns:
            For a single index, a tuple with one entry per underlying
            dataset. For a slice, list or array, a list of such tuples.
            If a transform was given, its result on that value is returned
            instead (for a multi-example index it receives the whole list).
        """
        batches = [dataset[index] for dataset in self._datasets]
        if isinstance(index, (slice, list, np.ndarray)):
            # Each entry of `batches` holds several examples of one dataset;
            # regroup them so that each tuple holds one example of every
            # dataset.
            length = len(batches[0])
            batches = [
                tuple(batch[i] for batch in batches) for i in range(length)
            ]
        else:
            batches = tuple(batches)

        if self.transform:
            batches = self.transform(batches)
        return batches

    def get_datasets(self):
        """Return the underlying datasets exactly as passed to ``__init__``."""
        return self._datasets

    @classmethod
    def save(cls, filepath, numpy_tuple_dataset):
        """Write the underlying datasets of a dataset to an ``.npz`` archive.

        The arrays are passed positionally to ``np.savez``, so they are
        stored under the keys ``arr_0``, ``arr_1``, ... which is the layout
        ``load`` reads back. Note that ``np.savez`` appends ``.npz`` to a
        path that does not already end with it.

        Args:
            filepath: Destination path handed to ``np.savez``.
            numpy_tuple_dataset: The ``NumpyTupleDataset`` to serialize.

        Raises:
            TypeError: If ``numpy_tuple_dataset`` is not a
                ``NumpyTupleDataset``.
        """
        if not isinstance(numpy_tuple_dataset, NumpyTupleDataset):
            raise TypeError(
                "numpy_tuple_dataset is not instance of "
                "NumpyTupleDataset, got {}".format(type(numpy_tuple_dataset))
            )

        np.savez(filepath, *numpy_tuple_dataset._datasets)
        print("Save {} done.".format(filepath))

    @classmethod
    def load(cls, filepath, transform=None):
        """Build a dataset from an archive written by ``save``.

        Arrays are read from the consecutive keys ``arr_0``, ``arr_1``, ...
        and reading stops at the first missing key.

        Args:
            filepath: Path of the ``.npz`` archive to read.
            transform: Optional transform forwarded to the constructor.

        Returns:
            A new instance of ``cls`` wrapping the loaded arrays.

        Raises:
            ValueError: If ``filepath`` does not exist, or (raised by the
                constructor) if the archive contains no ``arr_0`` entry.
        """
        print("Loading file {}".format(filepath))
        if not os.path.exists(filepath):
            raise ValueError("Invalid filepath {} for dataset".format(filepath))

        result = []
        # Use the archive as a context manager so that its file handle is
        # released once the arrays have been read into memory.
        with np.load(filepath) as load_data:
            available = set(load_data.files)
            i = 0
            key = "arr_{}".format(i)
            while key in available:
                result.append(load_data[key])
                i += 1
                key = "arr_{}".format(i)
        return cls(result, transform)