allow pickle, otherwise we can't load all data

master
CptCaptain 3 years ago
parent d63fedd297
commit 0d18f71c6d
  1. 14
      data/dataset.py

@ -77,9 +77,9 @@ class TrackSynDataset(torchext.BaseDataset):
ambs = [] ambs = []
grads = [] grads = []
for tidx in track_ind: for tidx in track_ind:
imgs.append(np.load(os.path.join(sample_path, f'im{sidx}_{tidx}.npy'))) imgs.append(np.load(os.path.join(sample_path, f'im{sidx}_{tidx}.npy')), allow_pickle=True)
ambs.append(np.load(os.path.join(sample_path, f'ambient{sidx}_{tidx}.npy'))) ambs.append(np.load(os.path.join(sample_path, f'ambient{sidx}_{tidx}.npy')), allow_pickle=True)
grads.append(np.load(os.path.join(sample_path, f'grad{sidx}_{tidx}.npy'))) grads.append(np.load(os.path.join(sample_path, f'grad{sidx}_{tidx}.npy')), allow_pickle=True)
ret[f'im{sidx}'] = np.stack(imgs, axis=0) ret[f'im{sidx}'] = np.stack(imgs, axis=0)
ret[f'ambient{sidx}'] = np.stack(ambs, axis=0) ret[f'ambient{sidx}'] = np.stack(ambs, axis=0)
ret[f'grad{sidx}'] = np.stack(grads, axis=0) ret[f'grad{sidx}'] = np.stack(grads, axis=0)
@ -89,14 +89,14 @@ class TrackSynDataset(torchext.BaseDataset):
R = [] R = []
t = [] t = []
for tidx in track_ind: for tidx in track_ind:
disps.append(np.load(os.path.join(sample_path, f'disp0_{tidx}.npy'))) disps.append(np.load(os.path.join(sample_path, f'disp0_{tidx}.npy')), allow_pickle=True)
R.append(np.load(os.path.join(sample_path, f'R_{tidx}.npy'))) R.append(np.load(os.path.join(sample_path, f'R_{tidx}.npy')), allow_pickle=True)
t.append(np.load(os.path.join(sample_path, f't_{tidx}.npy'))) t.append(np.load(os.path.join(sample_path, f't_{tidx}.npy')), allow_pickle=True)
ret[f'disp0'] = np.stack(disps, axis=0) ret[f'disp0'] = np.stack(disps, axis=0)
ret['R'] = np.stack(R, axis=0) ret['R'] = np.stack(R, axis=0)
ret['t'] = np.stack(t, axis=0) ret['t'] = np.stack(t, axis=0)
blend_im = np.load(os.path.join(sample_path, 'blend_im.npy')) blend_im = np.load(os.path.join(sample_path, 'blend_im.npy'), allow_pickle=True)
ret['blend_im'] = blend_im.astype(np.float32) ret['blend_im'] = blend_im.astype(np.float32)
#### apply data augmentation at different scales seperately, only work for max_shift=0 #### apply data augmentation at different scales seperately, only work for max_shift=0

Loading…
Cancel
Save