You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
66 lines
1.8 KiB
66 lines
1.8 KiB
import torch
|
|
import torch.utils.data
|
|
import numpy as np
|
|
|
|
|
|
class TestSet:
    """Record bundling a dataset with a label and a test frequency.

    The constructor arguments are stored unchanged as attributes of the
    same name; no validation or copying is performed.

    Args:
        name: Label identifying this test set.
        dset: The dataset object itself.
        test_frequency: Stored as-is (default 1). Presumably the interval
            at which the set is evaluated -- confirm against the caller.
    """

    def __init__(self, name, dset, test_frequency=1):
        self.name, self.dset = name, dset
        self.test_frequency = test_frequency
|
|
|
|
|
|
class TestSets(list):
    """A list whose ``append`` builds its ``TestSet`` elements in place."""

    def append(self, name, dset, test_frequency=1):
        """Create a ``TestSet`` from the arguments and add it at the end.

        Note that this replaces the one-argument ``list.append`` signature.
        """
        entry = TestSet(name, dset, test_frequency)
        list.append(self, entry)
|
|
|
|
|
|
class MultiDataset(torch.utils.data.Dataset):
    """Several datasets exposed as one, addressed by a single flat index.

    Samples are ordered dataset by dataset, in the order the datasets were
    added. The cumulative sample counts are cached in ``cum_n_samples``;
    call ``dataset_updated()`` to recompute them from the current lengths
    of the member datasets.

    Args:
        *datasets: Objects supporting ``len()`` and integer indexing.
    """

    def __init__(self, *datasets):
        self.current_epoch = 0

        self.datasets = []
        # cum_n_samples[i] is the flat index of the first sample of
        # datasets[i]; the last entry is the total number of samples.
        self.cum_n_samples = [0]

        for dataset in datasets:
            self.append(dataset)

    def append(self, dataset):
        """Add ``dataset`` after the existing ones and extend the offsets."""
        self.datasets.append(dataset)
        self.__update_cum_n_samples(dataset)

    def __update_cum_n_samples(self, dataset):
        """Push the running total that includes ``dataset``'s length."""
        n_samples = self.cum_n_samples[-1] + len(dataset)
        self.cum_n_samples.append(n_samples)

    def dataset_updated(self):
        """Rebuild ``cum_n_samples`` from the datasets' current lengths."""
        self.cum_n_samples = [0]
        for dset in self.datasets:
            self.__update_cum_n_samples(dset)

    def __len__(self):
        """Return the total number of samples over all datasets."""
        return self.cum_n_samples[-1]

    def __getitem__(self, idx):
        """Return the sample at flat position ``idx``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_total = len(self)
        original_idx = idx
        if idx < 0:
            # Without this, searchsorted would select dataset -1 with a
            # local index shifted by the full length.
            idx += n_total
        if not 0 <= idx < n_total:
            raise IndexError(
                f"index {original_idx} is out of range for MultiDataset "
                f"with {n_total} samples"
            )

        # side='right' maps an index equal to an offset to the dataset that
        # starts there, which also skips over empty datasets.
        didx = int(np.searchsorted(self.cum_n_samples, idx, side='right')) - 1
        sidx = idx - self.cum_n_samples[didx]
        return self.datasets[didx][sidx]
|
|
|
|
|
|
class BaseDataset(torch.utils.data.Dataset):
    """Dataset base class providing a deterministic per-sample RNG.

    Subclasses must implement ``__len__``; ``get_rng`` calls ``len(self)``
    in training mode.

    Args:
        train: Selects the training seeding scheme in ``get_rng`` when
            true, the evaluation scheme otherwise.
        fix_seed_per_epoch: In training mode, when true the seed ignores
            ``current_epoch`` so each sample gets the same RNG every epoch.
    """

    # np.random.RandomState only accepts seeds in [0, 2**32 - 1].
    _SEED_MODULUS = 2 ** 32

    def __init__(self, train=True, fix_seed_per_epoch=False):
        # Nothing in this class advances current_epoch; it is read by
        # get_rng and is expected to be set from outside.
        self.current_epoch = 0
        self.train = train
        self.fix_seed_per_epoch = fix_seed_per_epoch

    def get_rng(self, idx):
        """Return a ``np.random.RandomState`` seeded for sample ``idx``.

        Seeds used:
            * evaluation (``train`` false): ``idx``
            * training, ``fix_seed_per_epoch``: ``len(self) + idx``
            * training otherwise: ``(current_epoch + 1) * len(self) + idx``

        The seed is reduced modulo 2**32 so that large datasets or long
        runs cannot push it outside the range RandomState accepts; seeds
        already inside that range are unchanged.
        """
        if not self.train:
            seed = idx
        elif self.fix_seed_per_epoch:
            seed = 1 * len(self) + idx
        else:
            seed = (self.current_epoch + 1) * len(self) + idx
        return np.random.RandomState(int(seed) % self._SEED_MODULUS)
|
|
|