diff --git a/model/exp_synphge.py b/model/exp_synphge.py index 23da8e5..499bac1 100644 --- a/model/exp_synphge.py +++ b/model/exp_synphge.py @@ -12,12 +12,18 @@ import torchext from model import networks from data import dataset +import wandb + class Worker(torchext.Worker): - def __init__(self, args, num_workers=18, train_batch_size=6, test_batch_size=6, save_frequency=1, **kwargs): + def __init__(self, args, num_workers=18, train_batch_size=2, test_batch_size=2, save_frequency=1, **kwargs): + if 'batch_size' in dir(args): + train_batch_size = args.batch_size + test_batch_size = args.batch_size super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs) + print(args.no_double_heads) self.ms = args.ms self.pattern_path = args.pattern_path @@ -28,13 +34,15 @@ class Worker(torchext.Worker): self.data_type = args.data_type assert (self.track_length > 1) - self.imsizes = [(488, 648)] - for iter in range(3): - self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2))) - with open('config.json') as fp: config = json.load(fp) data_root = Path(config['DATA_ROOT']) + self.imsizes = [tuple(map(int, config['IMSIZE'].split(',')))] + + for iter in range(3): + self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2))) + + self.settings_path = data_root / self.data_type / 'settings.pkl' sample_paths = sorted((data_root / self.data_type).glob('0*/')) @@ -84,7 +92,9 @@ class Worker(torchext.Worker): Ki = np.linalg.inv(K) K = torch.from_numpy(K) Ki = torch.from_numpy(Ki) - ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1) + # FIXME why would i need to increase this? + ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.5) + # ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1) self.ph_losses.append(ph_loss) self.ge_losses.append(ge_loss) @@ -130,7 +140,8 @@ class Worker(torchext.Worker): out = out.view(tl, bs, *out.shape[1:]) edge = edge.view(tl, bs, *out.shape[1:]) else: - out = [o.view(tl, bs, *o.shape[1:]) for o in out] + out = [o[0].view(tl, bs, *o[0].shape[1:]) for o in out] + # out = [o.view(tl, bs, *o.shape[1:]) for o in out] edge = [e.view(tl, bs, *e.shape[1:]) for e in edge] return out, edge @@ -140,6 +151,7 @@ class Worker(torchext.Worker): out = [out] vals = [] diffs = [] + losses = {} # apply photometric loss for s, l, o in zip(itertools.count(), self.ph_losses, out): @@ -149,6 +161,7 @@ class Worker(torchext.Worker): std = self.data[f'std{s}'] std = std.view(-1, *std.shape[2:]) val, pattern_proj = l(o, im[:, 0:1, ...], std) + losses['photometric'] = val vals.append(val) if s == 0: self.pattern_proj = pattern_proj.detach() @@ -159,6 +172,7 @@ class Worker(torchext.Worker): edge0 = edge0.view(-1, *edge0.shape[2:]) out0 = out[0].view(-1, *out[0].shape[2:]) val = self.disparity_loss(out0, edge0) + losses['disparity'] = val * self.dp_weight if self.dp_weight > 0: vals.append(val * self.dp_weight) @@ -177,6 +191,7 @@ class Worker(torchext.Worker): val = self.edge_loss(e, grad) else: val = torch.zeros_like(vals[0]) + losses['edge loss'] = val vals.append(val) if train is False: @@ -201,8 +216,10 @@ class Worker(torchext.Worker): t1 = t[tidx1] val = ge_loss(depth0, depth1, R0, t0, R1, t1) + losses['geometric loss'] = val vals.append(val * self.ge_weight / ge_num) + wandb.log(losses) return vals def numpy_in_out(self, output): @@ -287,6 +304,7 @@ class Worker(torchext.Worker): plt.tight_layout() plt.savefig(str(out_path)) + wandb.log({f'results_{"_".join(out_path.stem.split("_")[:-1])}': plt}) plt.close(fig) def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):