You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
298 lines
11 KiB
298 lines
11 KiB
import torch
|
|
import numpy as np
|
|
import time
|
|
from pathlib import Path
|
|
import logging
|
|
import sys
|
|
import itertools
|
|
import json
|
|
import matplotlib.pyplot as plt
|
|
import co
|
|
import torchext
|
|
from model import networks
|
|
from data import dataset
|
|
|
|
class Worker(torchext.Worker):
|
|
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
    """Configure the experiment: hyper-parameters, data splits, losses, eval mask.

    Args:
        args: Namespace-like object. The attributes read here are
            ``output_dir``, ``exp_name``, ``epochs``, ``ms``,
            ``pattern_path``, ``lcn_radius``, ``dp_weight``, ``ge_weight``,
            ``track_length`` and ``data_type``.
        num_workers: Forwarded to ``torchext.Worker.__init__``.
        train_batch_size: Forwarded to ``torchext.Worker.__init__``.
        test_batch_size: Forwarded to ``torchext.Worker.__init__``.
        save_frequency: Forwarded to ``torchext.Worker.__init__``.
        **kwargs: Forwarded to ``torchext.Worker.__init__``.

    Raises:
        AssertionError: If ``args.track_length`` is not greater than 1.

    Reads ``config.json`` from the current working directory and uses its
    ``DATA_ROOT`` entry to locate the dataset.
    """
    super().__init__(
        args.output_dir,
        args.exp_name,
        epochs=args.epochs,
        num_workers=num_workers,
        train_batch_size=train_batch_size,
        test_batch_size=test_batch_size,
        save_frequency=save_frequency,
        **kwargs,
    )

    self.ms = args.ms
    self.pattern_path = args.pattern_path
    self.lcn_radius = args.lcn_radius
    self.dp_weight = args.dp_weight
    self.ge_weight = args.ge_weight
    self.track_length = args.track_length
    self.data_type = args.data_type
    # Kept as an assert so the exception type seen by callers is unchanged.
    assert self.track_length > 1, 'track_length must be greater than 1'

    # Full resolution followed by three successively halved resolutions.
    self.imsizes = [(480, 640)]
    for _ in range(3):
        height, width = self.imsizes[-1]
        self.imsizes.append((height // 2, width // 2))

    with open('config.json') as fp:
        config = json.load(fp)
    data_root = Path(config['DATA_ROOT'])
    self.settings_path = data_root / self.data_type / 'settings.pkl'
    sample_paths = sorted((data_root / self.data_type).glob('0*/'))

    # The first 2**8 samples are used for testing; training starts at
    # index 2**10, so the samples in between are used by neither split.
    self.train_paths = sample_paths[2**10:]
    self.test_paths = sample_paths[:2**8]

    # supervise the edge encoder with only 2**8 samples
    self.train_edge = len(self.train_paths) - 2**8

    self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
    self.disparity_loss = networks.DisparityLoss()
    self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))

    # evaluate in the region where opencv Block Matching has valid values
    border = 13
    left_margin = 140
    full_h, full_w = self.imsizes[0]
    # Use the builtin ``bool`` dtype: the ``np.bool`` alias no longer exists
    # in current NumPy releases.
    self.eval_mask = np.zeros((full_h, full_w), dtype=bool)
    self.eval_mask[border:full_h - border, left_margin:full_w - border] = True
    self.eval_h = full_h - 2 * border
    self.eval_w = full_w - border - left_margin
|
|
|
|
|
|
def get_train_set(self):
    """Build the training dataset over ``self.train_paths``.

    Returns:
        A ``dataset.TrackSynDataset`` created in training mode with data
        augmentation enabled and tracks of ``self.track_length`` frames.
    """
    return dataset.TrackSynDataset(
        self.settings_path,
        self.train_paths,
        train=True,
        data_aug=True,
        track_length=self.track_length,
    )
|
|
|
|
def get_test_sets(self):
    """Create the test sets and the per-scale loss modules.

    Besides returning the test sets, this fills ``self.ph_losses``,
    ``self.ge_losses`` and ``self.d2ds`` with one entry per image size
    reported by the test dataset, and moves ``self.lcn_in`` to ``'cuda'``.

    Returns:
        A ``torchext.TestSets`` holding a single set named ``'simple'``.
    """
    test_sets = torchext.TestSets()
    test_set = dataset.TrackSynDataset(
        self.settings_path,
        self.test_paths,
        train=False,
        data_aug=True,
        track_length=1,
    )
    test_sets.append('simple', test_set, test_frequency=1)

    self.ph_losses = []
    self.ge_losses = []
    self.d2ds = []

    self.lcn_in = self.lcn_in.to('cuda')
    for scale, size in enumerate(test_set.imsizes):
        # Average the pattern over its last axis, add two leading singleton
        # axes, normalise it with the LCN, then replicate it three times
        # along dim 1.
        pattern = test_set.patterns[scale].mean(axis=2)
        pattern = torch.from_numpy(pattern[None, None].astype(np.float32)).to('cuda')
        pattern, _ = self.lcn_in(pattern)
        pattern = torch.cat((pattern, pattern, pattern), dim=1)
        photometric = networks.RectifiedPatternSimilarityLoss(size[0], size[1], pattern=pattern)

        # Matrix returned by getK and its inverse for the geometric loss.
        intrinsics = test_set.getK(scale)
        intrinsics_inv = np.linalg.inv(intrinsics)
        geometric = networks.ProjectionDepthSimilarityLoss(
            torch.from_numpy(intrinsics),
            torch.from_numpy(intrinsics_inv),
            size[0],
            size[1],
            clamp=0.1,
        )

        self.ph_losses.append(photometric)
        self.ge_losses.append(geometric)

        self.d2ds.append(
            networks.DispToDepth(float(test_set.focal_lengths[scale]), float(test_set.baseline))
        )

    return test_sets
|
|
|
|
def copy_data(self, data, device, requires_grad, train):
    """Move a batch to ``device`` and store it in ``self.data``.

    Args:
        data: Mapping from key to tensor.
        device: Target device for every tensor and for ``self.lcn_in``.
        requires_grad: Gradient flag applied to entries whose key contains
            ``'im'``; all other entries get ``False``.
        train: Selects how tensors with more than two dimensions are laid
            out (see below).

    Tensors with more than two dimensions are rearranged from
    ``batch_size x track_length x ...`` to ``track_length x batch_size x ...``
    when training; otherwise a leading axis of size one is added.

    For every key that contains ``'im'`` but not ``'blend'``, the LCN output
    is concatenated in front of the image along dim 2, and the second LCN
    output is stored under the same key with ``'im'`` replaced by ``'std'``.
    """
    self.data = {}

    self.lcn_in = self.lcn_in.to(device)
    for name, tensor in data.items():
        if len(tensor.shape) > 2:
            tensor = tensor.transpose(0, 1) if train else tensor.unsqueeze(0)

        is_image = 'im' in name
        tensor = tensor.to(device).requires_grad_(requires_grad=is_image and requires_grad)
        self.data[name] = tensor

        if not is_image or 'blend' in name:
            continue

        track_len, batch = tensor.shape[0], tensor.shape[1]
        frame_shape = tensor.shape[2:]
        # The LCN is applied to a flat batch, so fold track and batch axes
        # together and unfold them again afterwards.
        normalized, local_std = self.lcn_in(tensor.contiguous().view(-1, *frame_shape))
        self.data[name.replace('im', 'std')] = local_std.view(track_len, batch, *frame_shape).to(device)
        self.data[name] = torch.cat((normalized.view(track_len, batch, *frame_shape), tensor), dim=2)
|
|
|
|
def net_forward(self, net, train):
    """Run ``net`` on ``self.data['im0']`` and restore the track/batch axes.

    Args:
        net: Callable taking the flattened image batch and returning a pair
            ``(out, edge)``. Each element is either a tensor or a
            tuple/list of tensors.
        train: Unused here.

    Returns:
        ``(out, edge)`` where every tensor has its leading axis split back
        into ``(track_length, batch_size)``. Tuple/list outputs are returned
        as lists.
    """
    im0 = self.data['im0']
    tl, bs = im0.shape[0], im0.shape[1]

    # The network sees track and batch folded into one leading axis.
    out, edge = net(im0.view(-1, *im0.shape[2:]))

    def unfold(x):
        return x.view(tl, bs, *x.shape[1:])

    if isinstance(out, (tuple, list)):
        out = [unfold(o) for o in out]
        edge = [unfold(e) for e in edge]
    else:
        out = unfold(out)
        # Reshape edge from its OWN shape. Using the already reshaped `out`
        # here would insert an extra batch axis or fail outright.
        edge = unfold(edge)
    return out, edge
|
|
|
|
def loss_forward(self, out, train):
    """Compute the list of loss terms for one forward pass.

    Args:
        out: Pair ``(out, edge)`` as produced by ``net_forward``. Each
            element is either a tensor or a tuple/list of tensors (one per
            scale) with leading axes ``(track_length, batch_size)``.
        train: When exactly ``False`` the geometric terms are skipped.

    Returns:
        A list of loss values, in this order: one photometric term per
        scale, the weighted disparity term (only if ``self.dp_weight > 0``),
        one edge term per scale, and, unless ``train is False``, one
        weighted geometric term per scale and per pair of frames.

    Side effects:
        Stores the detached pattern projection of scale 0 in
        ``self.pattern_proj``.
    """
    out, edge = out
    # A single-scale network returns bare tensors. Wrap BOTH of them so that
    # index 0 always means "scale 0" and never "first frame of the track".
    if not isinstance(out, (tuple, list)):
        out = [out]
    if not isinstance(edge, (tuple, list)):
        edge = [edge]
    vals = []

    # apply photometric loss
    for s, (ph_loss, o) in enumerate(zip(self.ph_losses, out)):
        im = self.data[f'im{s}']
        im = im.view(-1, *im.shape[2:])
        o = o.view(-1, *o.shape[2:])
        std = self.data[f'std{s}']
        std = std.view(-1, *std.shape[2:])
        # Only channel 0 of the stored image is passed to the loss.
        val, pattern_proj = ph_loss(o, im[:, 0:1, ...], std)
        vals.append(val)
        if s == 0:
            self.pattern_proj = pattern_proj.detach()

    # apply disparity loss
    # 1-edge as ground truth edge if inversed
    if self.dp_weight > 0:
        edge0 = 1 - torch.sigmoid(edge[0])
        edge0 = edge0.view(-1, *edge0.shape[2:])
        out0 = out[0].view(-1, *out[0].shape[2:])
        vals.append(self.disparity_loss(out0, edge0) * self.dp_weight)

    # apply edge loss on a subset of training samples
    if edge:
        # The sample selection does not depend on the scale.
        mask = self.data['id'] > self.train_edge
        any_selected = mask.sum() > 0
    for s, e in enumerate(edge):
        if any_selected:
            # inversed ground truth edge where 0 means edge
            grad = (self.data[f'grad{s}'] < 0.2).to(torch.float32)
            e = e[:, mask, :]
            grad = grad[:, mask, :]
            e = e.view(-1, *e.shape[2:])
            grad = grad.view(-1, *grad.shape[2:])
            val = self.edge_loss(e, grad)
        else:
            val = torch.zeros_like(vals[0])
        vals.append(val)

    if train is False:
        return vals

    # apply geometric loss
    R = self.data['R']
    t = self.data['t']
    # Number of unordered frame pairs in a track; used to average the terms.
    ge_num = self.track_length * (self.track_length - 1) / 2
    for sidx, o in enumerate(out):
        depth = self.d2ds[sidx](o)
        ge_loss = self.ge_losses[sidx]
        for tidx0, tidx1 in itertools.combinations(range(depth.shape[0]), 2):
            val = ge_loss(depth[tidx0], depth[tidx1], R[tidx0], t[tidx0], R[tidx1], t[tidx1])
            vals.append(val * self.ge_weight / ge_num)

    return vals
|
|
|
|
def numpy_in_out(self, output):
    """Convert the network output and matching inputs to NumPy arrays.

    Args:
        output: Pair ``(disparity, edge)``. ``disparity`` is a tensor or a
            tuple/list of tensors, of which only the first is used.
            ``edge`` is ignored.

    Returns:
        ``(es, gt, im, ma)``: the estimate, ``self.data['disp0']`` as
        float32, channel 0 of ``self.data['im0']`` (channel axis kept), and
        the boolean mask ``gt > 0``.
    """
    disparity, _edge = output
    first_scale = disparity[0] if isinstance(disparity, (tuple, list)) else disparity

    es = first_scale.detach().to('cpu').numpy()
    gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
    im = self.data['im0'][:, :, 0:1, ...].detach().to('cpu').numpy()
    return es, gt, im, gt > 0
|
|
|
|
def write_img(self, out_path, es, gt, im, ma):
    """Save a 3x3 overview figure for up to two frames of a track.

    Rows one and two show estimate, ground truth and error for frame 0 and
    (if present) frame 1; row three shows the corresponding IR inputs.

    Args:
        out_path: Destination file; passed to ``plt.savefig`` as a string.
        es: Estimated disparities, indexed by frame on the first axis.
        gt: Ground-truth disparities, same layout as ``es``.
        im: Input images, same layout as ``es``.
        ma: Unused here; kept for interface compatibility.
    """
    logging.info(f'write img {out_path}')

    diff = np.abs(es - gt)

    vmin, vmax = np.nanmin(gt), np.nanmax(gt)
    vmin = vmin - 0.2 * (vmax - vmin)
    # NOTE(review): the upper bound is padded using the already lowered
    # vmin. Kept as is so the colour scale of existing plots does not
    # change; confirm whether a symmetric margin was intended.
    vmax = vmax + 0.2 * (vmax - vmin)

    n_frames = min(es.shape[0], 2)

    def show(cell, image, title, **imshow_kwargs):
        # Draw one tick-less panel of the 3x3 grid.
        ax = plt.subplot(3, 3, cell)
        plt.imshow(image, **imshow_kwargs)
        plt.xticks([])
        plt.yticks([])
        ax.set_title(title)

    fig = plt.figure(figsize=(16, 16))
    try:
        # plot disparities, ground truth disparity is shown only for reference
        for frame in range(n_frames):
            es_col = co.cmap.color_depth_map(es[frame], scale=vmax)
            gt_col = co.cmap.color_depth_map(gt[frame], scale=vmax)
            diff_col = co.cmap.color_error_image(diff[frame], BGR=True)
            first_cell = 3 * frame + 1
            # The channel order is reversed before display; presumably the
            # colour maps are BGR and matplotlib expects RGB (TODO confirm).
            show(first_cell, es_col[..., [2, 1, 0]],
                 f'F{frame} Disparity Est. {es_col.min():.4f}/{es_col.max():.4f}')
            show(first_cell + 1, gt_col[..., [2, 1, 0]],
                 f'F{frame} Disparity GT {np.nanmin(gt_col):.4f}/{np.nanmax(gt_col):.4f}')
            show(first_cell + 2, diff_col[..., [2, 1, 0]],
                 f'F{frame} Disparity Err. {diff_col.mean():.5f}')

        # plot normalized IR inputs
        for frame in range(n_frames):
            show(7 + frame, im[frame],
                 f'F{frame} IR input {im[frame].mean():.5f}/{im[frame].std():.5f}',
                 vmin=im.min(), vmax=im.max(), cmap='gray')

        plt.tight_layout()
        plt.savefig(str(out_path))
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
|
|
|
|
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
    """Write a debug image for every 512th training batch.

    Args:
        net: Unused here.
        errs: Unused here.
        output: Network output, forwarded to ``numpy_in_out``.
        epoch: Epoch number, used in the file name.
        batch_idx: Batch number; images are written when it is a multiple
            of 512.
        masks: Unused here. It is no longer copied to the CPU, because the
            copy was never read.
    """
    if batch_idx % 512 != 0:
        return

    out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
    es, gt, im, ma = self.numpy_in_out(output)
    # Pass the arrays indexed at position 0 of their second and third axes.
    self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
|
|
|
|
def callback_test_start(self, epoch, set_idx):
    """Create a fresh ``self.metric`` at the start of a test run.

    The metric combines a distance metric with an outlier-fraction metric
    evaluated at thresholds 0.1, 0.5, 1, 2 and 5.

    Args:
        epoch: Unused here.
        set_idx: Unused here.
    """
    distance = co.metric.DistanceMetric(vec_length=1)
    outliers = co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
    self.metric = co.metric.MultipleMetric(distance, outliers)
|
|
|
|
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
    """Add one test batch to ``self.metric``; write an image every 8th batch.

    Args:
        epoch: Epoch number, used in the image file name.
        set_idx: Unused here.
        batch_idx: Batch number; an image is written when it is a multiple
            of 8.
        n_batches: Unused here.
        output: Network output, forwarded to ``numpy_in_out``.
        masks: Unused here.
    """
    estimate, truth, image, valid = self.numpy_in_out(output)

    if batch_idx % 8 == 0:
        target = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
        self.write_img(target, estimate[:, 0, 0], truth[:, 0, 0], image[:, 0, 0], valid[:, 0, 0])

    # Restrict everything to the evaluation region before scoring.
    estimate, truth, image, valid = self.crop_output(estimate, truth, image, valid)

    self.metric.add(estimate.reshape(-1, 1), truth.reshape(-1, 1), valid.ravel())
|
|
|
|
def callback_test_stop(self, epoch, set_idx, loss):
    """Log the accumulated metric and record each of its entries.

    Args:
        epoch: Forwarded to ``metric_add_test``.
        set_idx: Forwarded to ``metric_add_test``.
        loss: Unused here.
    """
    logging.info(f'{self.metric}')
    for name, value in self.metric.items():
        self.metric_add_test(epoch, set_idx, name, value)
|
|
|
|
def crop_output(self, es, gt, im, ma):
    """Crop all four arrays to the evaluation region.

    Each array is indexed with ``self.eval_mask`` over its trailing axes and
    reshaped to ``[tl*bs, 1, self.eval_h, self.eval_w]``, where ``tl`` and
    ``bs`` are the first two dimensions of ``es``.

    Returns:
        The cropped ``(es, gt, im, ma)``.
    """
    cropped_shape = [es.shape[0] * es.shape[1], 1, self.eval_h, self.eval_w]

    def crop(array):
        return np.reshape(array[..., self.eval_mask], cropped_shape)

    return crop(es), crop(gt), crop(im), crop(ma)
|
|
|
|
# Running this module as a script intentionally performs no work.
if __name__ == '__main__':
    pass
|
|
|