|
|
|
@ -10,6 +10,7 @@ import matplotlib.pyplot as plt |
|
|
|
|
import cv2 |
|
|
|
|
import torchvision.transforms as transforms |
|
|
|
|
|
|
|
|
|
import wandb |
|
|
|
|
|
|
|
|
|
import co |
|
|
|
|
import torchext |
|
|
|
@ -18,28 +19,33 @@ from data import dataset |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Worker(torchext.Worker): |
|
|
|
|
def __init__(self, args, num_workers=18, train_batch_size=4, test_batch_size=4, save_frequency=1, **kwargs): |
|
|
|
|
def __init__(self, args, num_workers=18, train_batch_size=6, test_batch_size=6, save_frequency=1, **kwargs): |
|
|
|
|
if 'batch_size' in dir(args): |
|
|
|
|
train_batch_size = args.batch_size |
|
|
|
|
test_batch_size = args.batch_size |
|
|
|
|
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, |
|
|
|
|
train_batch_size=train_batch_size, test_batch_size=test_batch_size, |
|
|
|
|
save_frequency=save_frequency, **kwargs) |
|
|
|
|
|
|
|
|
|
self.ms = args.ms |
|
|
|
|
self.pattern_path = args.pattern_path |
|
|
|
|
self.lcn_radius = args.lcn_radius |
|
|
|
|
self.dp_weight = args.dp_weight |
|
|
|
|
self.data_type = args.data_type |
|
|
|
|
|
|
|
|
|
self.imsizes = [(488, 648)] |
|
|
|
|
for iter in range(3): |
|
|
|
|
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2))) |
|
|
|
|
|
|
|
|
|
with open('config.json') as fp: |
|
|
|
|
config = json.load(fp) |
|
|
|
|
data_root = Path(config['DATA_ROOT']) |
|
|
|
|
self.imsizes = [tuple(map(int, config['IMSIZE'].split(',')))] |
|
|
|
|
|
|
|
|
|
for iter in range(3): |
|
|
|
|
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2))) |
|
|
|
|
|
|
|
|
|
self.settings_path = data_root / self.data_type / 'settings.pkl' |
|
|
|
|
sample_paths = sorted((data_root / self.data_type).glob('0*/')) |
|
|
|
|
|
|
|
|
|
self.train_paths = sample_paths[2 ** 10:] |
|
|
|
|
# FIXME just for testing, make this bigger at some point |
|
|
|
|
# self.train_paths = sample_paths[2 ** 3:] |
|
|
|
|
self.test_paths = sample_paths[:2 ** 8] |
|
|
|
|
|
|
|
|
|
# supervise the edge encoder with only 2**8 samples |
|
|
|
@ -51,6 +57,25 @@ class Worker(torchext.Worker): |
|
|
|
|
self.sup_disp_loss = torch.nn.MSELoss() |
|
|
|
|
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device)) |
|
|
|
|
|
|
|
|
|
# FIXME L2 Regularization, try it!! |
|
|
|
|
# l2_lambda = 0.001 |
|
|
|
|
# l2_norm = sum(p.pow(2.0).sum() |
|
|
|
|
# for p in net.parameters()) |
|
|
|
|
# self.sup_disp_loss = torch.nn.MSELoss() + l2_lambda * l2_norm |
|
|
|
|
# FIXME try using log of this loss, otherwise it's very large compared to others |
|
|
|
|
# self.sup_disp_loss = torch.nn.MSELoss() |
|
|
|
|
class RMSLELoss(torch.nn.Module): |
|
|
|
|
def __init__(self): |
|
|
|
|
super().__init__() |
|
|
|
|
self.mse = torch.nn.MSELoss() |
|
|
|
|
|
|
|
|
|
def forward(self, pred, actual): |
|
|
|
|
# FIXME rename this if log is better than sqrt |
|
|
|
|
return torch.log(self.mse(pred, actual)) |
|
|
|
|
# return torch.sqrt(self.mse(torch.log(pred + 1), torch.log(actual + 1))) |
|
|
|
|
|
|
|
|
|
self.sup_disp_loss = RMSLELoss() |
|
|
|
|
|
|
|
|
|
# evaluate in the region where opencv Block Matching has valid values |
|
|
|
|
self.eval_mask = np.zeros(self.imsizes[0]) |
|
|
|
|
self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1 |
|
|
|
@ -108,6 +133,7 @@ class Worker(torchext.Worker): |
|
|
|
|
|
|
|
|
|
def loss_forward(self, out, train): |
|
|
|
|
out, edge = out |
|
|
|
|
losses = {} |
|
|
|
|
if not (isinstance(out, tuple) or isinstance(out, list)): |
|
|
|
|
out = [out] |
|
|
|
|
if not (isinstance(edge, tuple) or isinstance(edge, list)): |
|
|
|
@ -118,11 +144,11 @@ class Worker(torchext.Worker): |
|
|
|
|
# apply photometric loss |
|
|
|
|
for s, l, o in zip(itertools.count(), self.losses, out): |
|
|
|
|
val, pattern_proj = l(o[0], self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}']) |
|
|
|
|
|
|
|
|
|
if s == 0: |
|
|
|
|
self.pattern_proj = pattern_proj.detach() |
|
|
|
|
vals.append(val) |
|
|
|
|
|
|
|
|
|
# apply disparity loss |
|
|
|
|
losses['photometric'] = val |
|
|
|
|
# 1-edge as ground truth edge if inversed |
|
|
|
|
if isinstance(edge, tuple): |
|
|
|
|
edge0 = 1 - torch.sigmoid(edge[0][0]) |
|
|
|
@ -130,11 +156,18 @@ class Worker(torchext.Worker): |
|
|
|
|
edge0 = 1 - torch.sigmoid(edge[0]) |
|
|
|
|
val = 0 |
|
|
|
|
if isinstance(out[0], tuple): |
|
|
|
|
sup_loss = self.sup_disp_loss(out[0][1], self.data['disp0']) |
|
|
|
|
val += sup_loss |
|
|
|
|
|
|
|
|
|
disp_loss = self.disparity_loss(out[0][0], edge0) |
|
|
|
|
val += disp_loss |
|
|
|
|
|
|
|
|
|
val += self.sup_disp_loss(out[0][1], self.data['disp0']) |
|
|
|
|
val += self.disparity_loss(out[0][0], edge0) |
|
|
|
|
losses['GT Supervised disparity loss'] = sup_loss * self.dp_weight |
|
|
|
|
losses['OG disparity loss'] = disp_loss * self.dp_weight |
|
|
|
|
else: |
|
|
|
|
val += self.disparity_loss(out[0], edge0) |
|
|
|
|
disp_loss = self.disparity_loss(out[0], edge0) |
|
|
|
|
val += disp_loss |
|
|
|
|
losses['OG disparity loss'] = disp_loss * self.dp_weight |
|
|
|
|
if self.dp_weight > 0: |
|
|
|
|
vals.append(val * self.dp_weight) |
|
|
|
|
|
|
|
|
@ -159,15 +192,20 @@ class Worker(torchext.Worker): |
|
|
|
|
self.edge = e.detach() |
|
|
|
|
self.edge = torch.sigmoid(self.edge) |
|
|
|
|
self.edge_gt = grad.detach() |
|
|
|
|
losses['edge'] = val |
|
|
|
|
vals.append(val) |
|
|
|
|
|
|
|
|
|
wandb.log(losses) |
|
|
|
|
return vals |
|
|
|
|
|
|
|
|
|
def numpy_in_out(self, output): |
|
|
|
|
output, edge = output |
|
|
|
|
if not (isinstance(output, tuple) or isinstance(output, list)): |
|
|
|
|
output = [output] |
|
|
|
|
es = output[0][0].detach().to('cpu').numpy() |
|
|
|
|
if isinstance(output[0], tuple): |
|
|
|
|
es = output[0][0].detach().to('cpu').numpy() |
|
|
|
|
else: |
|
|
|
|
es = output[0].detach().to('cpu').numpy() |
|
|
|
|
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32) |
|
|
|
|
im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy() |
|
|
|
|
|
|
|
|
@ -250,6 +288,7 @@ class Worker(torchext.Worker): |
|
|
|
|
|
|
|
|
|
plt.tight_layout() |
|
|
|
|
plt.savefig(str(out_path)) |
|
|
|
|
wandb.log({f'results_{"_".join(out_path.stem.split("_")[:-1])}': plt}) |
|
|
|
|
plt.close(fig) |
|
|
|
|
|
|
|
|
|
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]): |
|
|
|
@ -293,9 +332,4 @@ class Worker(torchext.Worker): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
# FIXME Nicolas fixe idee |
|
|
|
|
# SGBM nutzen, um GT zu finden |
|
|
|
|
# bei dispnet (oder w/e) letzte paar layer 'dublizieren' (zweiten head bauen) und so mehrere Loss funktionen gleichzeitig trainieren |
|
|
|
|
# L1 + L2 und dann im selben Backwardspass optimieren |
|
|
|
|
# für das ganze forward pass anpassen |
|
|
|
|
pass |
|
|
|
|