|
|
@ -12,12 +12,18 @@ import torchext |
|
|
|
from model import networks |
|
|
|
from model import networks |
|
|
|
from data import dataset |
|
|
|
from data import dataset |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import wandb |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Worker(torchext.Worker): |
|
|
|
class Worker(torchext.Worker): |
|
|
|
def __init__(self, args, num_workers=18, train_batch_size=6, test_batch_size=6, save_frequency=1, **kwargs): |
|
|
|
def __init__(self, args, num_workers=18, train_batch_size=2, test_batch_size=2, save_frequency=1, **kwargs): |
|
|
|
|
|
|
|
if 'batch_size' in dir(args): |
|
|
|
|
|
|
|
train_batch_size = args.batch_size |
|
|
|
|
|
|
|
test_batch_size = args.batch_size |
|
|
|
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, |
|
|
|
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, |
|
|
|
train_batch_size=train_batch_size, test_batch_size=test_batch_size, |
|
|
|
train_batch_size=train_batch_size, test_batch_size=test_batch_size, |
|
|
|
save_frequency=save_frequency, **kwargs) |
|
|
|
save_frequency=save_frequency, **kwargs) |
|
|
|
|
|
|
|
print(args.no_double_heads) |
|
|
|
|
|
|
|
|
|
|
|
self.ms = args.ms |
|
|
|
self.ms = args.ms |
|
|
|
self.pattern_path = args.pattern_path |
|
|
|
self.pattern_path = args.pattern_path |
|
|
@ -28,13 +34,15 @@ class Worker(torchext.Worker): |
|
|
|
self.data_type = args.data_type |
|
|
|
self.data_type = args.data_type |
|
|
|
assert (self.track_length > 1) |
|
|
|
assert (self.track_length > 1) |
|
|
|
|
|
|
|
|
|
|
|
self.imsizes = [(488, 648)] |
|
|
|
|
|
|
|
for iter in range(3): |
|
|
|
|
|
|
|
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with open('config.json') as fp: |
|
|
|
with open('config.json') as fp: |
|
|
|
config = json.load(fp) |
|
|
|
config = json.load(fp) |
|
|
|
data_root = Path(config['DATA_ROOT']) |
|
|
|
data_root = Path(config['DATA_ROOT']) |
|
|
|
|
|
|
|
self.imsizes = [tuple(map(int, config['IMSIZE'].split(',')))] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for iter in range(3): |
|
|
|
|
|
|
|
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.settings_path = data_root / self.data_type / 'settings.pkl' |
|
|
|
self.settings_path = data_root / self.data_type / 'settings.pkl' |
|
|
|
sample_paths = sorted((data_root / self.data_type).glob('0*/')) |
|
|
|
sample_paths = sorted((data_root / self.data_type).glob('0*/')) |
|
|
|
|
|
|
|
|
|
|
@ -84,7 +92,9 @@ class Worker(torchext.Worker): |
|
|
|
Ki = np.linalg.inv(K) |
|
|
|
Ki = np.linalg.inv(K) |
|
|
|
K = torch.from_numpy(K) |
|
|
|
K = torch.from_numpy(K) |
|
|
|
Ki = torch.from_numpy(Ki) |
|
|
|
Ki = torch.from_numpy(Ki) |
|
|
|
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1) |
|
|
|
# FIXME why would i need to increase this? |
|
|
|
|
|
|
|
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.5) |
|
|
|
|
|
|
|
# ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1) |
|
|
|
|
|
|
|
|
|
|
|
self.ph_losses.append(ph_loss) |
|
|
|
self.ph_losses.append(ph_loss) |
|
|
|
self.ge_losses.append(ge_loss) |
|
|
|
self.ge_losses.append(ge_loss) |
|
|
@ -130,7 +140,8 @@ class Worker(torchext.Worker): |
|
|
|
out = out.view(tl, bs, *out.shape[1:]) |
|
|
|
out = out.view(tl, bs, *out.shape[1:]) |
|
|
|
edge = edge.view(tl, bs, *out.shape[1:]) |
|
|
|
edge = edge.view(tl, bs, *out.shape[1:]) |
|
|
|
else: |
|
|
|
else: |
|
|
|
out = [o.view(tl, bs, *o.shape[1:]) for o in out] |
|
|
|
out = [o[0].view(tl, bs, *o[0].shape[1:]) for o in out] |
|
|
|
|
|
|
|
# out = [o.view(tl, bs, *o.shape[1:]) for o in out] |
|
|
|
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge] |
|
|
|
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge] |
|
|
|
return out, edge |
|
|
|
return out, edge |
|
|
|
|
|
|
|
|
|
|
@ -140,6 +151,7 @@ class Worker(torchext.Worker): |
|
|
|
out = [out] |
|
|
|
out = [out] |
|
|
|
vals = [] |
|
|
|
vals = [] |
|
|
|
diffs = [] |
|
|
|
diffs = [] |
|
|
|
|
|
|
|
losses = {} |
|
|
|
|
|
|
|
|
|
|
|
# apply photometric loss |
|
|
|
# apply photometric loss |
|
|
|
for s, l, o in zip(itertools.count(), self.ph_losses, out): |
|
|
|
for s, l, o in zip(itertools.count(), self.ph_losses, out): |
|
|
@ -149,6 +161,7 @@ class Worker(torchext.Worker): |
|
|
|
std = self.data[f'std{s}'] |
|
|
|
std = self.data[f'std{s}'] |
|
|
|
std = std.view(-1, *std.shape[2:]) |
|
|
|
std = std.view(-1, *std.shape[2:]) |
|
|
|
val, pattern_proj = l(o, im[:, 0:1, ...], std) |
|
|
|
val, pattern_proj = l(o, im[:, 0:1, ...], std) |
|
|
|
|
|
|
|
losses['photometric'] = val |
|
|
|
vals.append(val) |
|
|
|
vals.append(val) |
|
|
|
if s == 0: |
|
|
|
if s == 0: |
|
|
|
self.pattern_proj = pattern_proj.detach() |
|
|
|
self.pattern_proj = pattern_proj.detach() |
|
|
@ -159,6 +172,7 @@ class Worker(torchext.Worker): |
|
|
|
edge0 = edge0.view(-1, *edge0.shape[2:]) |
|
|
|
edge0 = edge0.view(-1, *edge0.shape[2:]) |
|
|
|
out0 = out[0].view(-1, *out[0].shape[2:]) |
|
|
|
out0 = out[0].view(-1, *out[0].shape[2:]) |
|
|
|
val = self.disparity_loss(out0, edge0) |
|
|
|
val = self.disparity_loss(out0, edge0) |
|
|
|
|
|
|
|
losses['disparity'] = val * self.dp_weight |
|
|
|
if self.dp_weight > 0: |
|
|
|
if self.dp_weight > 0: |
|
|
|
vals.append(val * self.dp_weight) |
|
|
|
vals.append(val * self.dp_weight) |
|
|
|
|
|
|
|
|
|
|
@ -177,6 +191,7 @@ class Worker(torchext.Worker): |
|
|
|
val = self.edge_loss(e, grad) |
|
|
|
val = self.edge_loss(e, grad) |
|
|
|
else: |
|
|
|
else: |
|
|
|
val = torch.zeros_like(vals[0]) |
|
|
|
val = torch.zeros_like(vals[0]) |
|
|
|
|
|
|
|
losses['edge loss'] = val |
|
|
|
vals.append(val) |
|
|
|
vals.append(val) |
|
|
|
|
|
|
|
|
|
|
|
if train is False: |
|
|
|
if train is False: |
|
|
@ -201,8 +216,10 @@ class Worker(torchext.Worker): |
|
|
|
t1 = t[tidx1] |
|
|
|
t1 = t[tidx1] |
|
|
|
|
|
|
|
|
|
|
|
val = ge_loss(depth0, depth1, R0, t0, R1, t1) |
|
|
|
val = ge_loss(depth0, depth1, R0, t0, R1, t1) |
|
|
|
|
|
|
|
losses['geometric loss'] = val |
|
|
|
vals.append(val * self.ge_weight / ge_num) |
|
|
|
vals.append(val * self.ge_weight / ge_num) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
wandb.log(losses) |
|
|
|
return vals |
|
|
|
return vals |
|
|
|
|
|
|
|
|
|
|
|
def numpy_in_out(self, output): |
|
|
|
def numpy_in_out(self, output): |
|
|
@ -287,6 +304,7 @@ class Worker(torchext.Worker): |
|
|
|
|
|
|
|
|
|
|
|
plt.tight_layout() |
|
|
|
plt.tight_layout() |
|
|
|
plt.savefig(str(out_path)) |
|
|
|
plt.savefig(str(out_path)) |
|
|
|
|
|
|
|
wandb.log({f'results_{"_".join(out_path.stem.split("_")[:-1])}': plt}) |
|
|
|
plt.close(fig) |
|
|
|
plt.close(fig) |
|
|
|
|
|
|
|
|
|
|
|
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): |
|
|
|
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): |
|
|
|