# Reconstructed from the patch "Add real world dataset"
# (commit cebf7767146089dd692142ce0f2763e72a5b5b35, target file data/dataset.py).
# The class relies on names the target module is expected to provide already:
# np, os, pickle, torchext and augment_image.


class RealWorldDataset(torchext.BaseDataset):
    '''
    Load a locally saved real-world dataset.

    The dataset has to be generated beforehand. Every entry of
    ``sample_paths`` is a directory holding ``im0.npy``, ``ambient0.npy``,
    ``grad0.npy``, ``blend_im.npy`` and, per track index ``t``,
    ``disp0_{t}.npy``, ``R_{t}.npy`` and ``t_{t}.npy``.
    '''

    # Keys of a sample that are never indexed by frame.
    _UNSTACKED_KEYS = ('blend_im', 'id')

    def __init__(self, settings_path, sample_paths, track_length=1, train=True, data_aug=False):
        '''
        Args:
            settings_path: path of the pickled settings dict; it must provide
                the keys 'imsizes', 'patterns', 'focal_lengths', 'baseline'
                and 'K'.
            sample_paths: sequence of sample directories, one per item.
            track_length: number of track indices drawn per item in training
                mode; at most 4.
            train: training mode flag, also forwarded to the base class.
            data_aug: whether __getitem__ applies augment_image.

        Raises:
            ValueError: if track_length is larger than 4.
        '''
        super().__init__(train=train)

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if track_length > 4:
            raise ValueError(f'track_length must be <= 4, got {track_length}')

        self.settings_path = settings_path
        self.sample_paths = sample_paths
        self.data_aug = data_aug
        self.train = train
        self.track_length = track_length

        # NOTE(review): pickle.load can execute arbitrary code; settings_path
        # must only ever point at trusted, locally generated files.
        with open(str(settings_path), 'rb') as f:
            settings = pickle.load(f)
        self.imsizes = settings['imsizes']
        self.patterns = settings['patterns']
        self.focal_lengths = settings['focal_lengths']
        self.baseline = settings['baseline']
        self.K = settings['K']

        self.scale = 1

        # Augmentation limits handed to augment_image.
        self.max_shift = 0
        self.max_blur = 0.5
        self.max_noise = 3.0
        self.max_sp_noise = 0.0005

    def __len__(self):
        '''Return the number of sample directories.'''
        return len(self.sample_paths)

    @staticmethod
    def _load(sample_path, filename):
        '''Load one .npy file from a sample directory.'''
        # NOTE(review): allow_pickle=True unpickles object arrays, which can
        # execute arbitrary code; only load trusted dataset files.
        return np.load(os.path.join(sample_path, filename), allow_pickle=True)

    def _augment_scale(self, ret, sidx, rng):
        '''Augment the frames stored for scale sidx in ret, in place.'''
        aug_kwargs = dict(max_shift=self.max_shift, max_blur=self.max_blur,
                          max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
        img = ret[f'im{sidx}']
        img_aug = np.zeros_like(img)

        if sidx == 0:
            # Disparity and gradient exist only at full resolution and are
            # augmented jointly with the image.
            disp = ret[f'disp{sidx}']
            grad = ret[f'grad{sidx}']
            # NOTE(review): the disp/grad buffers take shape and dtype from
            # img (as in the original code) - confirm that this is intended.
            disp_aug = np.zeros_like(img)
            grad_aug = np.zeros_like(img)
            for i in range(img.shape[0]):
                img_aug_, disp_aug_, grad_aug_ = augment_image(
                    img[i, 0], rng, disp=disp[i, 0], grad=grad[i, 0], **aug_kwargs)
                img_aug[i] = img_aug_[None].astype(np.float32)
                disp_aug[i] = disp_aug_[None].astype(np.float32)
                grad_aug[i] = grad_aug_[None].astype(np.float32)
            ret[f'im{sidx}'] = img_aug
            ret[f'disp{sidx}'] = disp_aug
            ret[f'grad{sidx}'] = grad_aug
        else:
            for i in range(img.shape[0]):
                img_aug_, _, _ = augment_image(img[i, 0], rng, **aug_kwargs)
                img_aug[i] = img_aug_[None].astype(np.float32)
            ret[f'im{sidx}'] = img_aug

    def __getitem__(self, idx):
        '''
        Load sample idx as a dict of arrays.

        Returned keys: 'id', 'im0', 'ambient0', 'grad0', 'disp0', 'R', 't'
        and 'blend_im'. Stacked entries carry a leading frame axis, which is
        dropped when exactly one frame was selected.
        '''
        # Evaluation takes the per-index RNG of the base class, training a
        # freshly seeded one.
        rng = np.random.RandomState() if self.train else self.get_rng(idx)
        sample_path = self.sample_paths[idx]

        if self.train:
            track_ind = np.random.permutation(4)[0:self.track_length]
        else:
            track_ind = [0]
        n_frames = len(track_ind)

        ret = {'id': idx}

        # The real-world data stores one image/ambient/grad triple per sample,
        # so each file is read once and repeated along the frame axis. The
        # original re-read the same files for every scale and every track
        # index with an identical result.
        # FIXME do this for our stuff (per-scale / per-frame files).
        if len(self.imsizes) > 0:
            for key in ('im0', 'ambient0', 'grad0'):
                arr = self._load(sample_path, f'{key}.npy')
                ret[key] = np.stack([arr] * n_frames, axis=0)

        # Disparity and pose are stored per track index, full resolution only.
        for key, stem in (('disp0', 'disp0'), ('R', 'R'), ('t', 't')):
            ret[key] = np.stack(
                [self._load(sample_path, f'{stem}_{tidx}.npy') for tidx in track_ind],
                axis=0)

        ret['blend_im'] = self._load(sample_path, 'blend_im.npy').astype(np.float32)

        # Apply data augmentation at different scales separately; only works
        # for max_shift=0.
        if self.data_aug:
            for sidx in range(len(self.imsizes)):
                # Only scale 0 is loaded above. Skipping the other scales
                # avoids the KeyError the unguarded loop raised as soon as the
                # settings listed more than one image size.
                if f'im{sidx}' in ret:
                    self._augment_scale(ret, sidx, rng)

        # A single frame is returned without the leading frame axis.
        if n_frames == 1:
            for key in ret:
                if key not in self._UNSTACKED_KEYS:
                    ret[key] = ret[key][0]

        return ret

    def getK(self, sidx=0):
        '''
        Return a copy of the settings matrix K scaled for scale level sidx.

        All entries are divided by 2 ** sidx and K[2, 2] is reset to 1
        (presumably K is the 3x3 camera intrinsics - confirm against the
        settings generator).
        '''
        K = self.K.copy() / (2 ** sidx)
        K[2, 2] = 1
        return K


if __name__ == '__main__':
    pass