Reformat $EVERYTHING

This commit is contained in:
2021-11-15 16:53:30 +01:00
parent 56f2aa7d5d
commit 43df77fb9b
32 changed files with 4171 additions and 3749 deletions
+215 -178
View File
@@ -12,226 +12,263 @@ import torchext
from model import networks
from data import dataset
class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
self.ms = args.ms
self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight
self.data_type = args.data_type
self.ms = args.ms
self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight
self.data_type = args.data_type
self.imsizes = [(480,640)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2)))
self.imsizes = [(488, 648)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
with open('config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2**10:]
self.test_paths = sample_paths[:2**8]
self.train_paths = sample_paths[2 ** 10:]
self.test_paths = sample_paths[:2 ** 8]
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2**8
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2 ** 8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0]-2*13
self.eval_w = self.imsizes[0][1]-13-140
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=1)
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=1)
return train_set
return train_set
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
# initialize photometric loss modules according to image sizes
self.losses = []
for imsize, pat in zip(test_set.imsizes, test_set.patterns):
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32))
pat = pat.to(self.train_device)
self.lcn_in = self.lcn_in.to(self.train_device)
pat,_ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
self.losses.append( networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) )
# initialize photometric loss modules according to image sizes
self.losses = []
for imsize, pat in zip(test_set.imsizes, test_set.patterns):
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32))
pat = pat.to(self.train_device)
self.lcn_in = self.lcn_in.to(self.train_device)
pat, _ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
self.losses.append(networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat))
return test_sets
return test_sets
def copy_data(self, data, device, requires_grad, train):
self.lcn_in = self.lcn_in.to(device)
def copy_data(self, data, device, requires_grad, train):
self.lcn_in = self.lcn_in.to(device)
self.data = {}
for key, val in data.items():
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
self.data = {}
for key, val in data.items():
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
# apply lcn to IR input
# concatenate the normalized IR input and the original IR image
if 'im' in key and 'blend' not in key:
im = self.data[key]
im_lcn,im_std = self.lcn_in(im)
im_cat = torch.cat((im_lcn, im), dim=1)
key_std = key.replace('im','std')
self.data[key]=im_cat
self.data[key_std] = im_std.to(device).detach()
# apply lcn to IR input
# concatenate the normalized IR input and the original IR image
if 'im' in key and 'blend' not in key:
im = self.data[key]
im_lcn, im_std = self.lcn_in(im)
im_cat = torch.cat((im_lcn, im), dim=1)
key_std = key.replace('im', 'std')
self.data[key] = im_cat
self.data[key_std] = im_std.to(device).detach()
def net_forward(self, net, train):
out = net(self.data['im0'])
return out
def net_forward(self, net, train):
out = net(self.data['im0'])
return out
def loss_forward(self, out, train):
out, edge = out
if not(isinstance(out, tuple) or isinstance(out, list)):
out = [out]
if not(isinstance(edge, tuple) or isinstance(edge, list)):
edge = [edge]
def loss_forward(self, out, train):
out, edge = out
if not (isinstance(out, tuple) or isinstance(out, list)):
out = [out]
if not (isinstance(edge, tuple) or isinstance(edge, list)):
edge = [edge]
vals = []
vals = []
# apply photometric loss
for s,l,o in zip(itertools.count(), self.losses, out):
val, pattern_proj = l(o, self.data[f'im{s}'][:,0:1,...], self.data[f'std{s}'])
if s == 0:
self.pattern_proj = pattern_proj.detach()
vals.append(val)
# apply photometric loss
for s, l, o in zip(itertools.count(), self.losses, out):
val, pattern_proj = l(o, self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
if s == 0:
self.pattern_proj = pattern_proj.detach()
vals.append(val)
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1-torch.sigmoid(edge[0])
val = self.disparity_loss(out[0], edge0)
if self.dp_weight>0:
vals.append(val * self.dp_weight)
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1 - torch.sigmoid(edge[0])
val = self.disparity_loss(out[0], edge0)
if self.dp_weight > 0:
vals.append(val * self.dp_weight)
# apply edge loss on a subset of training samples
for s,e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}']<0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids>self.train_edge
if mask.sum()>0:
val = self.edge_loss(e[mask], grad[mask])
else:
val = torch.zeros_like(vals[0])
if s == 0:
self.edge = e.detach()
self.edge = torch.sigmoid(self.edge)
self.edge_gt = grad.detach()
vals.append(val)
# apply edge loss on a subset of training samples
for s, e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}'] < 0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids > self.train_edge
if mask.sum() > 0:
val = self.edge_loss(e[mask], grad[mask])
else:
val = torch.zeros_like(vals[0])
if s == 0:
self.edge = e.detach()
self.edge = torch.sigmoid(self.edge)
self.edge_gt = grad.detach()
vals.append(val)
return vals
return vals
def numpy_in_out(self, output):
output, edge = output
if not(isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,0:1,...].detach().to('cpu').numpy()
def numpy_in_out(self, output):
output, edge = output
if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy()
ma = gt>0
return es, gt, im, ma
ma = gt > 0
return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma):
logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
def write_img(self, out_path, es, gt, im, ma):
logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
diff = np.abs(es - gt)
diff = np.abs(es - gt)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2*(vmax-vmin)
vmax = vmax + 0.2*(vmax-vmin)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2 * (vmax - vmin)
vmax = vmax + 0.2 * (vmax - vmin)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0]
pattern_diff = np.abs(im_orig - pattern_proj)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0]
pattern_diff = np.abs(im_orig - pattern_proj)
fig = plt.figure(figsize=(16, 16))
es_ = co.cmap.color_depth_map(es, scale=vmax)
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
diff_ = co.cmap.color_error_image(diff, BGR=True)
fig = plt.figure(figsize=(16,16))
es_ = co.cmap.color_depth_map(es, scale=vmax)
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
diff_ = co.cmap.color_error_image(diff, BGR=True)
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3, 3, 1)
plt.imshow(es_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
ax = plt.subplot(3, 3, 2)
plt.imshow(gt_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
ax = plt.subplot(3, 3, 3)
plt.imshow(diff_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Err. {diff.mean():.5f}')
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3,3,1); plt.imshow(es_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
ax = plt.subplot(3,3,2); plt.imshow(gt_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
ax = plt.subplot(3,3,3); plt.imshow(diff_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Err. {diff.mean():.5f}')
# plot edges
edge = self.edge.to('cpu').numpy()[0, 0]
edge_gt = self.edge_gt.to('cpu').numpy()[0, 0]
edge_err = np.abs(edge - edge_gt)
ax = plt.subplot(3, 3, 4);
plt.imshow(edge, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
ax = plt.subplot(3, 3, 5);
plt.imshow(edge_gt, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
ax = plt.subplot(3, 3, 6);
plt.imshow(edge_err, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
# plot edges
edge = self.edge.to('cpu').numpy()[0,0]
edge_gt = self.edge_gt.to('cpu').numpy()[0,0]
edge_err = np.abs(edge - edge_gt)
ax = plt.subplot(3,3,4); plt.imshow(edge, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
ax = plt.subplot(3,3,5); plt.imshow(edge_gt, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
ax = plt.subplot(3,3,6); plt.imshow(edge_err, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
# plot normalized IR input and warped pattern
ax = plt.subplot(3, 3, 7);
plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
ax = plt.subplot(3, 3, 8);
plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
im_std = self.data['std0'].to('cpu').numpy()[0, 0]
ax = plt.subplot(3, 3, 9);
plt.imshow(im_std, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
# plot normalized IR input and warped pattern
ax = plt.subplot(3,3,7); plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
ax = plt.subplot(3,3,8); plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
im_std = self.data['std0'].to('cpu').numpy()[0,0]
ax = plt.subplot(3,3,9); plt.imshow(im_std, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
es, gt, im, ma = self.numpy_in_out(output)
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
es, gt, im, ma = self.crop_output(es, gt, im, ma)
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
es, gt, im, ma = self.numpy_in_out(output)
es = es.reshape(-1, 1)
gt = gt.reshape(-1, 1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1,1)
gt = gt.reshape(-1,1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}')
for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
def crop_output(self, es, gt, im, ma):
bs = es.shape[0]
es = np.reshape(es[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}')
for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
def crop_output(self, es, gt, im, ma):
bs = es.shape[0]
es = np.reshape(es[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
if __name__ == '__main__':
pass
pass
+274 -237
View File
@@ -12,287 +12,324 @@ import torchext
from model import networks
from data import dataset
class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
self.ms = args.ms
self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight
self.ge_weight = args.ge_weight
self.track_length = args.track_length
self.data_type = args.data_type
assert(self.track_length>1)
self.ms = args.ms
self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight
self.ge_weight = args.ge_weight
self.track_length = args.track_length
self.data_type = args.data_type
assert (self.track_length > 1)
self.imsizes = [(480,640)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2)))
self.imsizes = [(480, 640)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
with open('config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2**10:]
self.test_paths = sample_paths[:2**8]
self.train_paths = sample_paths[2 ** 10:]
self.test_paths = sample_paths[:2 ** 8]
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2**8
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2 ** 8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0]-2*13
self.eval_w = self.imsizes[0][1]-13-140
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=self.track_length)
return train_set
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length)
return train_set
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
self.ph_losses = []
self.ge_losses = []
self.d2ds = []
self.ph_losses = []
self.ge_losses = []
self.d2ds = []
self.lcn_in = self.lcn_in.to('cuda')
for sidx in range(len(test_set.imsizes)):
imsize = test_set.imsizes[sidx]
pat = test_set.patterns[sidx]
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
pat, _ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat)
self.lcn_in = self.lcn_in.to('cuda')
for sidx in range(len(test_set.imsizes)):
imsize = test_set.imsizes[sidx]
pat = test_set.patterns[sidx]
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
pat,_ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat)
K = test_set.getK(sidx)
Ki = np.linalg.inv(K)
K = torch.from_numpy(K)
Ki = torch.from_numpy(Ki)
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
K = test_set.getK(sidx)
Ki = np.linalg.inv(K)
K = torch.from_numpy(K)
Ki = torch.from_numpy(Ki)
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
self.ph_losses.append(ph_loss)
self.ge_losses.append(ge_loss)
self.ph_losses.append( ph_loss )
self.ge_losses.append( ge_loss )
d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
self.d2ds.append( d2d )
d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
self.d2ds.append(d2d)
return test_sets
return test_sets
def copy_data(self, data, device, requires_grad, train):
self.data = {}
def copy_data(self, data, device, requires_grad, train):
self.data = {}
self.lcn_in = self.lcn_in.to(device)
for key, val in data.items():
# from
# batch_size x track_length x ...
# to
# track_length x batch_size x ...
if len(val.shape)>2:
if train:
val = val.transpose(0,1)
self.lcn_in = self.lcn_in.to(device)
for key, val in data.items():
# from
# batch_size x track_length x ...
# to
# track_length x batch_size x ...
if len(val.shape) > 2:
if train:
val = val.transpose(0, 1)
else:
val = val.unsqueeze(0)
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
if 'im' in key and 'blend' not in key:
im = self.data[key]
tl = im.shape[0]
bs = im.shape[1]
im_lcn, im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
key_std = key.replace('im', 'std')
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
self.data[key] = im_cat
def net_forward(self, net, train):
im0 = self.data['im0']
tl = im0.shape[0]
bs = im0.shape[1]
im0 = im0.view(-1, *im0.shape[2:])
out, edge = net(im0)
if not (isinstance(out, tuple) or isinstance(out, list)):
out = out.view(tl, bs, *out.shape[1:])
edge = edge.view(tl, bs, *out.shape[1:])
else:
val = val.unsqueeze(0)
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
if 'im' in key and 'blend' not in key:
im = self.data[key]
tl = im.shape[0]
bs = im.shape[1]
im_lcn,im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
key_std = key.replace('im','std')
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
self.data[key] = im_cat
out = [o.view(tl, bs, *o.shape[1:]) for o in out]
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
return out, edge
def net_forward(self, net, train):
im0 = self.data['im0']
tl = im0.shape[0]
bs = im0.shape[1]
im0 = im0.view(-1, *im0.shape[2:])
out, edge = net(im0)
if not(isinstance(out, tuple) or isinstance(out, list)):
out = out.view(tl, bs, *out.shape[1:])
edge = edge.view(tl, bs, *out.shape[1:])
else:
out = [o.view(tl, bs, *o.shape[1:]) for o in out]
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
return out, edge
def loss_forward(self, out, train):
out, edge = out
if not (isinstance(out, tuple) or isinstance(out, list)):
out = [out]
vals = []
diffs = []
def loss_forward(self, out, train):
out, edge = out
if not(isinstance(out, tuple) or isinstance(out, list)):
out = [out]
vals = []
diffs = []
# apply photometric loss
for s, l, o in zip(itertools.count(), self.ph_losses, out):
im = self.data[f'im{s}']
im = im.view(-1, *im.shape[2:])
o = o.view(-1, *o.shape[2:])
std = self.data[f'std{s}']
std = std.view(-1, *std.shape[2:])
val, pattern_proj = l(o, im[:, 0:1, ...], std)
vals.append(val)
if s == 0:
self.pattern_proj = pattern_proj.detach()
# apply photometric loss
for s,l,o in zip(itertools.count(), self.ph_losses, out):
im = self.data[f'im{s}']
im = im.view(-1, *im.shape[2:])
o = o.view(-1, *o.shape[2:])
std = self.data[f'std{s}']
std = std.view(-1, *std.shape[2:])
val, pattern_proj = l(o, im[:,0:1,...], std)
vals.append(val)
if s == 0:
self.pattern_proj = pattern_proj.detach()
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1 - torch.sigmoid(edge[0])
edge0 = edge0.view(-1, *edge0.shape[2:])
out0 = out[0].view(-1, *out[0].shape[2:])
val = self.disparity_loss(out0, edge0)
if self.dp_weight > 0:
vals.append(val * self.dp_weight)
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1-torch.sigmoid(edge[0])
edge0 = edge0.view(-1, *edge0.shape[2:])
out0 = out[0].view(-1, *out[0].shape[2:])
val = self.disparity_loss(out0, edge0)
if self.dp_weight>0:
vals.append(val * self.dp_weight)
# apply edge loss on a subset of training samples
for s, e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}'] < 0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids > self.train_edge
if mask.sum() > 0:
e = e[:, mask, :]
grad = grad[:, mask, :]
e = e.view(-1, *e.shape[2:])
grad = grad.view(-1, *grad.shape[2:])
val = self.edge_loss(e, grad)
else:
val = torch.zeros_like(vals[0])
vals.append(val)
# apply edge loss on a subset of training samples
for s,e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}']<0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids>self.train_edge
if mask.sum()>0:
e = e[:,mask,:]
grad = grad[:,mask,:]
e = e.view(-1, *e.shape[2:])
grad = grad.view(-1, *grad.shape[2:])
val = self.edge_loss(e, grad)
else:
val = torch.zeros_like(vals[0])
vals.append(val)
if train is False:
return vals
if train is False:
return vals
# apply geometric loss
R = self.data['R']
t = self.data['t']
ge_num = self.track_length * (self.track_length - 1) / 2
for sidx in range(len(out)):
d2d = self.d2ds[sidx]
depth = d2d(out[sidx])
ge_loss = self.ge_losses[sidx]
imsize = self.imsizes[sidx]
for tidx0 in range(depth.shape[0]):
for tidx1 in range(tidx0 + 1, depth.shape[0]):
depth0 = depth[tidx0]
R0 = R[tidx0]
t0 = t[tidx0]
depth1 = depth[tidx1]
R1 = R[tidx1]
t1 = t[tidx1]
# apply geometric loss
R = self.data['R']
t = self.data['t']
ge_num = self.track_length * (self.track_length-1) / 2
for sidx in range(len(out)):
d2d = self.d2ds[sidx]
depth = d2d(out[sidx])
ge_loss = self.ge_losses[sidx]
imsize = self.imsizes[sidx]
for tidx0 in range(depth.shape[0]):
for tidx1 in range(tidx0+1, depth.shape[0]):
depth0 = depth[tidx0]
R0 = R[tidx0]
t0 = t[tidx0]
depth1 = depth[tidx1]
R1 = R[tidx1]
t1 = t[tidx1]
val = ge_loss(depth0, depth1, R0, t0, R1, t1)
vals.append(val * self.ge_weight / ge_num)
val = ge_loss(depth0, depth1, R0, t0, R1, t1)
vals.append(val * self.ge_weight / ge_num)
return vals
return vals
def numpy_in_out(self, output):
output, edge = output
if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:, :, 0:1, ...].detach().to('cpu').numpy()
ma = gt > 0
return es, gt, im, ma
def numpy_in_out(self, output):
output, edge = output
if not(isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy()
ma = gt>0
return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma):
logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
def write_img(self, out_path, es, gt, im, ma):
logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
diff = np.abs(es - gt)
diff = np.abs(es - gt)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2 * (vmax - vmin)
vmax = vmax + 0.2 * (vmax - vmin)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2*(vmax-vmin)
vmax = vmax + 0.2*(vmax-vmin)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0, 0]
pattern_diff = np.abs(im_orig - pattern_proj)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0,0]
pattern_diff = np.abs(im_orig - pattern_proj)
fig = plt.figure(figsize=(16, 16))
es0 = co.cmap.color_depth_map(es[0], scale=vmax)
gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
diff0 = co.cmap.color_error_image(diff[0], BGR=True)
fig = plt.figure(figsize=(16,16))
es0 = co.cmap.color_depth_map(es[0], scale=vmax)
gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
diff0 = co.cmap.color_error_image(diff[0], BGR=True)
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3, 3, 1);
plt.imshow(es0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
ax = plt.subplot(3, 3, 2);
plt.imshow(gt0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
ax = plt.subplot(3, 3, 3);
plt.imshow(diff0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3,3,1); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
ax = plt.subplot(3,3,2); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
ax = plt.subplot(3,3,3); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
# plot disparities of the second frame in the track if exists
if es.shape[0] >= 2:
es1 = co.cmap.color_depth_map(es[1], scale=vmax)
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
diff1 = co.cmap.color_error_image(diff[1], BGR=True)
ax = plt.subplot(3, 3, 4);
plt.imshow(es1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
ax = plt.subplot(3, 3, 5);
plt.imshow(gt1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
ax = plt.subplot(3, 3, 6);
plt.imshow(diff1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
# plot disparities of the second frame in the track if exists
if es.shape[0]>=2:
es1 = co.cmap.color_depth_map(es[1], scale=vmax)
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
diff1 = co.cmap.color_error_image(diff[1], BGR=True)
ax = plt.subplot(3,3,4); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
ax = plt.subplot(3,3,5); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
ax = plt.subplot(3,3,6); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
# plot normalized IR inputs
ax = plt.subplot(3, 3, 7);
plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
if es.shape[0] >= 2:
ax = plt.subplot(3, 3, 8);
plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
# plot normalized IR inputs
ax = plt.subplot(3,3,7); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
if es.shape[0]>=2:
ax = plt.subplot(3,3,8); plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
masks = [ m.detach().to('cpu').numpy() for m in masks ]
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0])
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
masks = [m.detach().to('cpu').numpy() for m in masks]
self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
es, gt, im, ma = self.numpy_in_out(output)
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
es, gt, im, ma = self.numpy_in_out(output)
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0])
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
es, gt, im, ma = self.crop_output(es, gt, im, ma)
es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1,1)
gt = gt.reshape(-1,1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
es = es.reshape(-1, 1)
gt = gt.reshape(-1, 1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}')
for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}')
for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
def crop_output(self, es, gt, im, ma):
tl = es.shape[0]
bs = es.shape[1]
es = np.reshape(es[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
def crop_output(self, es, gt, im, ma):
tl = es.shape[0]
bs = es.shape[1]
es = np.reshape(es[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
if __name__ == '__main__':
pass
pass
+436 -423
View File
@@ -8,559 +8,572 @@ import co
class TimedModule(torch.nn.Module):
def __init__(self, mod_name):
super().__init__()
self.mod_name = mod_name
def __init__(self, mod_name):
super().__init__()
self.mod_name = mod_name
def tforward(self, *args, **kwargs):
raise Exception('not implemented')
def tforward(self, *args, **kwargs):
raise Exception('not implemented')
def forward(self, *args, **kwargs):
torch.cuda.synchronize()
with co.gtimer.Ctx(self.mod_name):
x = self.tforward(*args, **kwargs)
torch.cuda.synchronize()
return x
def forward(self, *args, **kwargs):
torch.cuda.synchronize()
with co.gtimer.Ctx(self.mod_name):
x = self.tforward(*args, **kwargs)
torch.cuda.synchronize()
return x
class PosOutput(TimedModule):
def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0):
super().__init__(mod_name='PosOutput')
self.im_width = im_width
self.im_width = im_width
def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0):
super().__init__(mod_name='PosOutput')
self.im_width = im_width
self.im_width = im_width
if type == 'pos':
self.layer = torch.nn.Sequential(
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
)
elif type == 'pos_row':
self.layer = torch.nn.Sequential(
MultiLinear(im_height, channels_in, 1),
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
)
if type == 'pos':
self.layer = torch.nn.Sequential(
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
)
elif type == 'pos_row':
self.layer = torch.nn.Sequential(
MultiLinear(im_height, channels_in, 1),
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
)
self.u_pos = None
self.u_pos = None
def tforward(self, x):
if self.u_pos is None:
self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1,1,1,-1)
self.u_pos = self.u_pos.to(x.device)
pos = self.layer(x)
disp = self.u_pos - pos
return disp
def tforward(self, x):
if self.u_pos is None:
self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1, 1, 1, -1)
self.u_pos = self.u_pos.to(x.device)
pos = self.layer(x)
disp = self.u_pos - pos
return disp
class OutputLayerFactory(object):
'''
Define type of output
type options:
linear: apply only conv channel, used for the edge decoder
disp: estimate the disparity
disp_row: independently estimate the disparity per row
pos: estimate the absolute location
pos_row: independently estimate the absolute location per row
'''
def __init__(self, type='disp', params={}):
self.type = type
self.params = params
'''
Define type of output
type options:
linear: apply only conv channel, used for the edge decoder
disp: estimate the disparity
disp_row: independently estimate the disparity per row
pos: estimate the absolute location
pos_row: independently estimate the absolute location per row
'''
def __call__(self, channels_in, imsize):
def __init__(self, type='disp', params={}):
self.type = type
self.params = params
if self.type == 'linear':
return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1)
def __call__(self, channels_in, imsize):
elif self.type == 'disp':
return torch.nn.Sequential(
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
SigmoidAffine(**self.params)
)
if self.type == 'linear':
return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1)
elif self.type == 'disp_row':
return torch.nn.Sequential(
MultiLinear(imsize[0], channels_in, 1),
SigmoidAffine(**self.params)
)
elif self.type == 'disp':
return torch.nn.Sequential(
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
SigmoidAffine(**self.params)
)
elif self.type == 'pos' or self.type == 'pos_row':
return PosOutput(channels_in, **self.params)
elif self.type == 'disp_row':
return torch.nn.Sequential(
MultiLinear(imsize[0], channels_in, 1),
SigmoidAffine(**self.params)
)
else:
raise Exception('unknown output layer type')
elif self.type == 'pos' or self.type == 'pos_row':
return PosOutput(channels_in, **self.params)
else:
raise Exception('unknown output layer type')
class SigmoidAffine(TimedModule):
def __init__(self, alpha=1, beta=0, gamma=1, offset=0):
super().__init__(mod_name='SigmoidAffine')
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.offset = offset
def __init__(self, alpha=1, beta=0, gamma=1, offset=0):
super().__init__(mod_name='SigmoidAffine')
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.offset = offset
def tforward(self, x):
return torch.sigmoid(x/self.gamma - self.offset) * self.alpha + self.beta
def tforward(self, x):
return torch.sigmoid(x / self.gamma - self.offset) * self.alpha + self.beta
class MultiLinear(TimedModule):
def __init__(self, n, channels_in, channels_out):
super().__init__(mod_name='MultiLinear')
self.channels_out = channels_out
self.mods = torch.nn.ModuleList()
for idx in range(n):
self.mods.append(torch.nn.Linear(channels_in, channels_out))
def tforward(self, x):
x = x.permute(2,0,3,1) # BxCxHxW => HxBxWxC
y = x.new_empty(*x.shape[:-1], self.channels_out)
for hidx in range(x.shape[0]):
y[hidx] = self.mods[hidx](x[hidx])
y = y.permute(1,3,0,2) # HxBxWxC => BxCxHxW
return y
def __init__(self, n, channels_in, channels_out):
super().__init__(mod_name='MultiLinear')
self.channels_out = channels_out
self.mods = torch.nn.ModuleList()
for idx in range(n):
self.mods.append(torch.nn.Linear(channels_in, channels_out))
def tforward(self, x):
x = x.permute(2, 0, 3, 1) # BxCxHxW => HxBxWxC
y = x.new_empty(*x.shape[:-1], self.channels_out)
for hidx in range(x.shape[0]):
y[hidx] = self.mods[hidx](x[hidx])
y = y.permute(1, 3, 0, 2) # HxBxWxC => BxCxHxW
return y
class DispNetS(TimedModule):
'''
Disparity Decoder based on DispNetS
'''
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, channel_multiplier=1):
super(DispNetS, self).__init__(mod_name='DispNetS')
'''
Disparity Decoder based on DispNetS
'''
self.output_ms = output_ms
self.coordconv = coordconv
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False,
channel_multiplier=1):
super(DispNetS, self).__init__(mod_name='DispNetS')
conv_planes = channel_multiplier * np.array( [32, 64, 128, 256, 512, 512, 512] )
self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3])
self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4])
self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
self.output_ms = output_ms
self.coordconv = coordconv
upconv_planes = channel_multiplier * np.array( [512, 512, 256, 128, 64, 32, 16] )
self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0])
self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1])
self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2])
self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3])
self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4])
self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5])
self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6])
conv_planes = channel_multiplier * np.array([32, 64, 128, 256, 512, 512, 512])
self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3])
self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4])
self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6])
upconv_planes = channel_multiplier * np.array([512, 512, 256, 128, 64, 32, 16])
self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0])
self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1])
self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2])
self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3])
self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4])
self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5])
self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6])
if isinstance(output_facs, list):
self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3])
self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2])
self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
else:
self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3])
self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2])
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6])
if isinstance(output_facs, list):
self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3])
self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2])
self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
else:
self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3])
self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2])
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
def init_weights(self):
for m in self.modules():
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
def init_weights(self):
for m in self.modules():
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
def downsample_conv(self, in_planes, out_planes, kernel_size=3):
if self.coordconv:
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
else:
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
return torch.nn.Sequential(
conv,
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2),
torch.nn.ReLU(inplace=True)
)
def downsample_conv(self, in_planes, out_planes, kernel_size=3):
if self.coordconv:
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
padding=(kernel_size - 1) // 2)
else:
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
padding=(kernel_size - 1) // 2)
return torch.nn.Sequential(
conv,
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size - 1) // 2),
torch.nn.ReLU(inplace=True)
)
def conv(self, in_planes, out_planes):
return torch.nn.Sequential(
torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True)
)
def conv(self, in_planes, out_planes):
return torch.nn.Sequential(
torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True)
)
def upconv(self, in_planes, out_planes):
return torch.nn.Sequential(
torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1),
torch.nn.ReLU(inplace=True)
)
def upconv(self, in_planes, out_planes):
return torch.nn.Sequential(
torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1),
torch.nn.ReLU(inplace=True)
)
def crop_like(self, input, ref):
assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
return input[:, :, :ref.size(2), :ref.size(3)]
def crop_like(self, input, ref):
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
return input[:, :, :ref.size(2), :ref.size(3)]
def tforward(self, x):
out_conv1 = self.conv1(x)
out_conv2 = self.conv2(out_conv1)
out_conv3 = self.conv3(out_conv2)
out_conv4 = self.conv4(out_conv3)
out_conv5 = self.conv5(out_conv4)
out_conv6 = self.conv6(out_conv5)
out_conv7 = self.conv7(out_conv6)
def tforward(self, x):
out_conv1 = self.conv1(x)
out_conv2 = self.conv2(out_conv1)
out_conv3 = self.conv3(out_conv2)
out_conv4 = self.conv4(out_conv3)
out_conv5 = self.conv5(out_conv4)
out_conv6 = self.conv6(out_conv5)
out_conv7 = self.conv7(out_conv6)
out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6)
concat7 = torch.cat((out_upconv7, out_conv6), 1)
out_iconv7 = self.iconv7(concat7)
out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6)
concat7 = torch.cat((out_upconv7, out_conv6), 1)
out_iconv7 = self.iconv7(concat7)
out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5)
concat6 = torch.cat((out_upconv6, out_conv5), 1)
out_iconv6 = self.iconv6(concat6)
out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5)
concat6 = torch.cat((out_upconv6, out_conv5), 1)
out_iconv6 = self.iconv6(concat6)
out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4)
concat5 = torch.cat((out_upconv5, out_conv4), 1)
out_iconv5 = self.iconv5(concat5)
out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4)
concat5 = torch.cat((out_upconv5, out_conv4), 1)
out_iconv5 = self.iconv5(concat5)
out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3)
concat4 = torch.cat((out_upconv4, out_conv3), 1)
out_iconv4 = self.iconv4(concat4)
disp4 = self.predict_disp4(out_iconv4)
out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3)
concat4 = torch.cat((out_upconv4, out_conv3), 1)
out_iconv4 = self.iconv4(concat4)
disp4 = self.predict_disp4(out_iconv4)
out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
disp4_up = self.crop_like(torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
out_iconv3 = self.iconv3(concat3)
disp3 = self.predict_disp3(out_iconv3)
out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
disp4_up = self.crop_like(
torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
out_iconv3 = self.iconv3(concat3)
disp3 = self.predict_disp3(out_iconv3)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2)
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1)
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1)
if self.output_ms:
return disp1, disp2, disp3, disp4
else:
return disp1
if self.output_ms:
return disp1, disp2, disp3, disp4
else:
return disp1
class DispNetShallow(DispNetS):
'''
Edge Decoder based on DispNetS with fewer layers
'''
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
self.mod_name = 'DispNetShallow'
conv_planes = [32, 64, 128, 256, 512, 512, 512]
upconv_planes = [512, 512, 256, 128, 64, 32, 16]
self.iconv3 = self.conv(upconv_planes[4] + conv_planes[1], upconv_planes[4])
'''
Edge Decoder based on DispNetS with fewer layers
'''
def tforward(self, x):
out_conv1 = self.conv1(x)
out_conv2 = self.conv2(out_conv1)
out_conv3 = self.conv3(out_conv2)
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
self.mod_name = 'DispNetShallow'
conv_planes = [32, 64, 128, 256, 512, 512, 512]
upconv_planes = [512, 512, 256, 128, 64, 32, 16]
self.iconv3 = self.conv(upconv_planes[4] + conv_planes[1], upconv_planes[4])
out_upconv3 = self.crop_like(self.upconv3(out_conv3), out_conv2)
concat3 = torch.cat((out_upconv3, out_conv2), 1)
out_iconv3 = self.iconv3(concat3)
disp3 = self.predict_disp3(out_iconv3)
def tforward(self, x):
out_conv1 = self.conv1(x)
out_conv2 = self.conv2(out_conv1)
out_conv3 = self.conv3(out_conv2)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2)
out_upconv3 = self.crop_like(self.upconv3(out_conv3), out_conv2)
concat3 = torch.cat((out_upconv3, out_conv2), 1)
out_iconv3 = self.iconv3(concat3)
disp3 = self.predict_disp3(out_iconv3)
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2)
if self.output_ms:
return disp1, disp2, disp3
else:
return disp1
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1)
if self.output_ms:
return disp1, disp2, disp3
else:
return disp1
class DispEdgeDecoders(TimedModule):
'''
Disparity Decoder and Edge Decoder
'''
def __init__(self, *args, max_disp=128, **kwargs):
super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')
'''
Disparity Decoder and Edge Decoder
'''
output_facs = [OutputLayerFactory( type='disp', params={ 'alpha': max_disp/(2**s), 'beta': 0, 'gamma': 1, 'offset': 3}) for s in range(4)]
self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)
def __init__(self, *args, max_disp=128, **kwargs):
super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')
output_facs = [OutputLayerFactory( type='linear' ) for s in range(4)]
self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs)
output_facs = [
OutputLayerFactory(type='disp', params={'alpha': max_disp / (2 ** s), 'beta': 0, 'gamma': 1, 'offset': 3})
for s in range(4)]
self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)
def tforward(self, x):
disp = self.disp_decoder(x)
edge = self.edge_decoder(x)
return disp, edge
output_facs = [OutputLayerFactory(type='linear') for s in range(4)]
self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs)
def tforward(self, x):
disp = self.disp_decoder(x)
edge = self.edge_decoder(x)
return disp, edge
class DispToDepth(TimedModule):
def __init__(self, focal_length, baseline):
super().__init__(mod_name='DispToDepth')
self.baseline_focal_length = baseline * focal_length
def __init__(self, focal_length, baseline):
super().__init__(mod_name='DispToDepth')
self.baseline_focal_length = baseline * focal_length
def tforward(self, disp):
disp = torch.nn.functional.relu(disp) + 1e-12
depth = self.baseline_focal_length / disp
return depth
def tforward(self, disp):
disp = torch.nn.functional.relu(disp) + 1e-12
depth = self.baseline_focal_length / disp
return depth
class PosToDepth(DispToDepth):
def __init__(self, focal_length, baseline, im_height, im_width):
super().__init__(focal_length, baseline)
self.mod_name = 'PosToDepth'
def __init__(self, focal_length, baseline, im_height, im_width):
super().__init__(focal_length, baseline)
self.mod_name = 'PosToDepth'
self.im_height = im_height
self.im_width = im_width
self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1,1,1,-1)
def tforward(self, pos):
self.u_pos = self.u_pos.to(pos.device)
disp = self.u_pos - pos
return super().forward(disp)
self.im_height = im_height
self.im_width = im_width
self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1, 1, 1, -1)
def tforward(self, pos):
self.u_pos = self.u_pos.to(pos.device)
disp = self.u_pos - pos
return super().forward(disp)
class RectifiedPatternSimilarityLoss(TimedModule):
'''
Photometric Loss
'''
def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
super().__init__(mod_name='RectifiedPatternSimilarityLoss')
self.im_height = im_height
self.im_width = im_width
self.pattern = pattern.mean(dim=1, keepdim=True).contiguous()
'''
Photometric Loss
'''
u, v = np.meshgrid(range(im_width), range(im_height))
uv0 = np.stack((u,v), axis=2).reshape(-1,1)
uv0 = uv0.astype(np.float32).reshape(1,-1,2)
self.uv0 = torch.from_numpy(uv0)
def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
super().__init__(mod_name='RectifiedPatternSimilarityLoss')
self.im_height = im_height
self.im_width = im_width
self.pattern = pattern.mean(dim=1, keepdim=True).contiguous()
self.loss_type = loss_type
self.loss_eps = loss_eps
u, v = np.meshgrid(range(im_width), range(im_height))
uv0 = np.stack((u, v), axis=2).reshape(-1, 1)
uv0 = uv0.astype(np.float32).reshape(1, -1, 2)
self.uv0 = torch.from_numpy(uv0)
def tforward(self, disp0, im, std=None):
self.pattern = self.pattern.to(disp0.device)
self.uv0 = self.uv0.to(disp0.device)
self.loss_type = loss_type
self.loss_eps = loss_eps
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
uv1 = torch.empty_like(uv0)
uv1[...,0] = uv0[...,0] - disp0.contiguous().view(disp0.shape[0],-1)
uv1[...,1] = uv0[...,1]
def tforward(self, disp0, im, std=None):
self.pattern = self.pattern.to(disp0.device)
self.uv0 = self.uv0.to(disp0.device)
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border')
mask = torch.ones_like(im)
if std is not None:
mask = mask*std
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
uv1 = torch.empty_like(uv0)
uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1)
uv1[..., 1] = uv0[..., 1]
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border')
mask = torch.ones_like(im)
if std is not None:
mask = mask * std
diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
val = (mask * diff).sum() / mask.sum()
return val, pattern_proj
diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
val = (mask*diff).sum() / mask.sum()
return val, pattern_proj
class DisparityLoss(TimedModule):
'''
Disparity Loss
'''
def __init__(self):
super().__init__(mod_name='DisparityLoss')
self.sobel = SobelFilter(norm=False)
'''
Disparity Loss
'''
#if not edge_gt:
self.b0=0.0503428816795
self.b1=1.07274045944
#else:
# self.b0=0.0587115108967
# self.b1=1.51931190491
def __init__(self):
super().__init__(mod_name='DisparityLoss')
self.sobel = SobelFilter(norm=False)
def tforward(self, disp, edge=None):
self.sobel=self.sobel.to(disp.device)
# if not edge_gt:
self.b0 = 0.0503428816795
self.b1 = 1.07274045944
# else:
# self.b0=0.0587115108967
# self.b1=1.51931190491
if edge is not None:
grad = self.sobel(disp)
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8)
pdf = (1-edge)/self.b0 * torch.exp(-torch.abs(grad)/self.b0) + \
edge/self.b1 * torch.exp(-torch.abs(grad)/self.b1)
val = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
else:
# on qifeng's data we don't have ambient info
# therefore we supress edge everywhere
grad = self.sobel(disp)
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8)
grad= torch.clamp(grad, 0, 1.0)
val = torch.mean(grad)
def tforward(self, disp, edge=None):
self.sobel = self.sobel.to(disp.device)
return val
if edge is not None:
grad = self.sobel(disp)
grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
pdf = (1 - edge) / self.b0 * torch.exp(-torch.abs(grad) / self.b0) + \
edge / self.b1 * torch.exp(-torch.abs(grad) / self.b1)
val = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
else:
# on qifeng's data we don't have ambient info
# therefore we supress edge everywhere
grad = self.sobel(disp)
grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
grad = torch.clamp(grad, 0, 1.0)
val = torch.mean(grad)
return val
class ProjectionBaseLoss(TimedModule):
'''
Base module of the Geometric Loss
'''
def __init__(self, K, Ki, im_height, im_width):
super().__init__(mod_name='ProjectionBaseLoss')
'''
Base module of the Geometric Loss
'''
self.K = K.view(-1,3,3)
def __init__(self, K, Ki, im_height, im_width):
super().__init__(mod_name='ProjectionBaseLoss')
self.im_height = im_height
self.im_width = im_width
self.K = K.view(-1, 3, 3)
u, v = np.meshgrid(range(im_width), range(im_height))
uv = np.stack((u,v,np.ones_like(u)), axis=2).reshape(-1,3)
self.im_height = im_height
self.im_width = im_width
ray = uv @ Ki.numpy().T
u, v = np.meshgrid(range(im_width), range(im_height))
uv = np.stack((u, v, np.ones_like(u)), axis=2).reshape(-1, 3)
ray = ray.reshape(1,-1,3).astype(np.float32)
self.ray = torch.from_numpy(ray)
ray = uv @ Ki.numpy().T
def transform(self, xyz, R=None, t=None):
if t is not None:
bs = xyz.shape[0]
xyz = xyz - t.reshape(bs,1,3)
if R is not None:
xyz = torch.bmm(xyz, R)
return xyz
ray = ray.reshape(1, -1, 3).astype(np.float32)
self.ray = torch.from_numpy(ray)
def unproject(self, depth, R=None, t=None):
self.ray = self.ray.to(depth.device)
bs = depth.shape[0]
def transform(self, xyz, R=None, t=None):
if t is not None:
bs = xyz.shape[0]
xyz = xyz - t.reshape(bs, 1, 3)
if R is not None:
xyz = torch.bmm(xyz, R)
return xyz
xyz = depth.reshape(bs,-1,1) * self.ray
xyz = self.transform(xyz, R, t)
return xyz
def unproject(self, depth, R=None, t=None):
self.ray = self.ray.to(depth.device)
bs = depth.shape[0]
def project(self, xyz, R, t):
self.K = self.K.to(xyz.device)
bs = xyz.shape[0]
xyz = depth.reshape(bs, -1, 1) * self.ray
xyz = self.transform(xyz, R, t)
return xyz
xyz = torch.bmm(xyz, R.transpose(1,2))
xyz = xyz + t.reshape(bs,1,3)
def project(self, xyz, R, t):
self.K = self.K.to(xyz.device)
bs = xyz.shape[0]
Kt = self.K.transpose(1,2).expand(bs,-1,-1)
uv = torch.bmm(xyz, Kt)
xyz = torch.bmm(xyz, R.transpose(1, 2))
xyz = xyz + t.reshape(bs, 1, 3)
d = uv[:,:,2:3]
Kt = self.K.transpose(1, 2).expand(bs, -1, -1)
uv = torch.bmm(xyz, Kt)
# avoid division by zero
uv = uv[:,:,:2] / (torch.nn.functional.relu(d) + 1e-12)
return uv, d
d = uv[:, :, 2:3]
# avoid division by zero
uv = uv[:, :, :2] / (torch.nn.functional.relu(d) + 1e-12)
return uv, d
def tforward(self, depth0, R0, t0, R1, t1):
xyz = self.unproject(depth0, R0, t0)
return self.project(xyz, R1, t1)
def tforward(self, depth0, R0, t0, R1, t1):
xyz = self.unproject(depth0, R0, t0)
return self.project(xyz, R1, t1)
class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
'''
Geometric Loss
'''
def __init__(self, *args, clamp=-1):
super().__init__(*args)
self.mod_name = 'ProjectionDepthSimilarityLoss'
self.clamp = clamp
'''
Geometric Loss
'''
def fwd(self, depth0, depth1, R0, t0, R1, t1):
uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
def __init__(self, *args, clamp=-1):
super().__init__(*args)
self.mod_name = 'ProjectionDepthSimilarityLoss'
self.clamp = clamp
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
def fwd(self, depth0, depth1, R0, t0, R1, t1):
uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border')
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
diff = torch.abs(d1.view(-1) - depth10.view(-1))
depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border')
if self.clamp > 0:
diff = torch.clamp(diff, 0, self.clamp)
diff = torch.abs(d1.view(-1) - depth10.view(-1))
# return diff without clamping for debugging
return diff.mean()
if self.clamp > 0:
diff = torch.clamp(diff, 0, self.clamp)
def tforward(self, depth0, depth1, R0, t0, R1, t1):
l0 = self.fwd(depth0, depth1, R0, t0, R1, t1)
l1 = self.fwd(depth1, depth0, R1, t1, R0, t0)
return l0+l1
# return diff without clamping for debugging
return diff.mean()
def tforward(self, depth0, depth1, R0, t0, R1, t1):
l0 = self.fwd(depth0, depth1, R0, t0, R1, t1)
l1 = self.fwd(depth1, depth0, R1, t1, R0, t0)
return l0 + l1
class LCN(TimedModule):
'''
Local Contract Normalization
'''
def __init__(self, radius, epsilon):
super().__init__(mod_name='LCN')
self.box_conv = torch.nn.Sequential(
torch.nn.ReflectionPad2d(radius),
torch.nn.Conv2d(1, 1, kernel_size=2*radius+1, bias=False)
)
self.box_conv[1].weight.requires_grad=False
self.box_conv[1].weight.fill_(1.)
'''
Local Contract Normalization
'''
self.epsilon = epsilon
self.radius = radius
def __init__(self, radius, epsilon):
super().__init__(mod_name='LCN')
self.box_conv = torch.nn.Sequential(
torch.nn.ReflectionPad2d(radius),
torch.nn.Conv2d(1, 1, kernel_size=2 * radius + 1, bias=False)
)
self.box_conv[1].weight.requires_grad = False
self.box_conv[1].weight.fill_(1.)
def tforward(self, data):
boxs = self.box_conv(data)
self.epsilon = epsilon
self.radius = radius
avgs = boxs / (2*self.radius+1)**2
boxs_n2 = boxs**2
boxs_2n = self.box_conv(data**2)
def tforward(self, data):
boxs = self.box_conv(data)
stds = torch.sqrt(boxs_2n / (2*self.radius+1)**2 - avgs**2 + 1e-6)
stds = stds + self.epsilon
avgs = boxs / (2 * self.radius + 1) ** 2
boxs_n2 = boxs ** 2
boxs_2n = self.box_conv(data ** 2)
return (data - avgs) / stds, stds
stds = torch.sqrt(boxs_2n / (2 * self.radius + 1) ** 2 - avgs ** 2 + 1e-6)
stds = stds + self.epsilon
return (data - avgs) / stds, stds
class SobelFilter(TimedModule):
'''
Sobel Filter
'''
def __init__(self, norm=False):
super(SobelFilter, self).__init__(mod_name='SobelFilter')
kx = np.array([[-5, -4, 0, 4, 5],
[-8, -10, 0, 10, 8],
[-10, -20, 0, 20, 10],
[-8, -10, 0, 10, 8],
[-5, -4, 0, 4, 5]])/240.0
ky = kx.copy().transpose(1,0)
'''
Sobel Filter
'''
self.conv_x=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_x.weight=torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))
def __init__(self, norm=False):
super(SobelFilter, self).__init__(mod_name='SobelFilter')
kx = np.array([[-5, -4, 0, 4, 5],
[-8, -10, 0, 10, 8],
[-10, -20, 0, 20, 10],
[-8, -10, 0, 10, 8],
[-5, -4, 0, 4, 5]]) / 240.0
ky = kx.copy().transpose(1, 0)
self.conv_y=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_y.weight=torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))
self.conv_x = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_x.weight = torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))
self.norm=norm
self.conv_y = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_y.weight = torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))
def tforward(self,x):
x = F.pad(x, (2,2,2,2), "replicate")
gx = self.conv_x(x)
gy = self.conv_y(x)
if self.norm:
return torch.sqrt(gx**2 + gy**2 + 1e-8)
else:
return torch.cat((gx, gy), dim=1)
self.norm = norm
def tforward(self, x):
x = F.pad(x, (2, 2, 2, 2), "replicate")
gx = self.conv_x(x)
gy = self.conv_y(x)
if self.norm:
return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
else:
return torch.cat((gx, gy), dim=1)