import os
from collections import namedtuple

import cv2
import numpy as np
import torch
import torch.optim as optim
import wandb
import yaml
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPSpawnStrategy

from nets import Model
from dataset import BlenderDataset, CREStereoDataset, CTDDataset

seed_everything(42, workers=True)

def normalize_and_colormap(img):
    """Min-max normalize `img` to [0, 255] and render it with the INFERNO colormap."""
    ret = (img - img.min()) / (img.max() - img.min()) * 255.0
    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()
    ret = ret.astype("uint8")
    ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
    return ret
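
# Minimal usage sketch (illustrative only; the tensor below is synthetic):
# an (H, W) disparity tensor comes back as an (H, W, 3) uint8 BGR heatmap.
#
#   disp = torch.rand(480, 640) * 100   # hypothetical disparity map
#   vis = normalize_and_colormap(disp)  # np.ndarray, shape (480, 640, 3)
#   cv2.imwrite("disp_vis.png", vis)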

def log_images(left, right, pred_disp, gt_disp):
    """Assemble a wandb image-logging payload for the first sample of a batch."""
    batch_idx = 0

    # iterative models return a list of refinements; keep only the final prediction
    if isinstance(pred_disp, list):
        pred_disp = pred_disp[-1]

    pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
    gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
    left = torch.squeeze(left[:, 0, :, :])
    right = torch.squeeze(right[:, 0, :, :])

    disp = pred_disp
    disp_error = gt_disp - disp

    input_left = left[batch_idx].cpu().detach().numpy()
    input_right = right[batch_idx].cpu().detach().numpy()

    wandb_log = dict(
        key='samples',
        images=[
            normalize_and_colormap(pred_disp[batch_idx]),
            normalize_and_colormap(abs(disp_error[batch_idx])),
            normalize_and_colormap(gt_disp[batch_idx]),
            input_left,
            input_right,
        ],
        caption=[
            f"Disparity \n{pred_disp[batch_idx].min():.2f}/{pred_disp[batch_idx].max():.2f}",
            f"Disp. Error\n{disp_error[batch_idx].min():.2f}/{disp_error[batch_idx].max():.2f}\n{abs(disp_error[batch_idx]).mean():.2f}",
            f"GT Disp Vis \n{gt_disp[batch_idx].min():.2f}/{gt_disp[batch_idx].max():.2f}",
            "Input Left",
            "Input Right",
        ],
    )
    return wandb_log
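
# Usage sketch (argument shapes are assumptions based on how this script calls it):
#
#   left, right: (B, C, H, W) input tensors
#   pred_disp:   (B, 2, H, W) tensor, or a list of such refinement tensors
#   gt_disp:     (B, 1, H, W) ground-truth disparity
#
#   payload = log_images(left, right, flow_predictions, gt_disp)
#   wandb_logger.log_image(**payload)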

def parse_yaml(file_path: str) -> namedtuple:
    """Parse a YAML configuration file and return its contents as a `namedtuple`."""
    with open(file_path, "rb") as f:
        cfg: dict = yaml.safe_load(f)
    args = namedtuple("train_args", cfg.keys())(*cfg.values())
    return args
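
# Example (hedged sketch): given a cfgs/train.yaml containing
#
#   base_lr: 0.0004
#   batch_size: 2
#
# the returned object exposes the keys as attributes:
#
#   args = parse_yaml("cfgs/train.yaml")
#   print(args.base_lr)   # 0.0004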

def format_time(elapse):
    """Format a duration given in seconds as HH:MM:SS."""
    elapse = int(elapse)
    hour = elapse // 3600
    minute = elapse % 3600 // 60
    seconds = elapse % 60
    return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)

def outlier_fraction(estimate, target, mask=None, threshold=0):
    """Fraction of masked pixels whose absolute error exceeds `threshold`."""
    def _process_inputs(estimate, target, mask):
        if estimate.shape != target.shape:
            raise Exception(f'estimate and target have to be same shape (expected {estimate.shape} == {target.shape})')
        if mask is None:
            # torch.ones_like (instead of the deprecated np.bool array) keeps the
            # later .cpu() call valid when no mask is given
            mask = torch.ones_like(estimate, dtype=torch.bool)
        else:
            mask = mask != 0
        if estimate.shape != mask.shape:
            raise Exception(f'estimate and mask have to be same shape (expected {estimate.shape} == {mask.shape})')
        return estimate, target, mask

    estimate = torch.squeeze(estimate[:, 0, :, :])
    target = torch.squeeze(target[:, 0, :, :])

    estimate, target, mask = _process_inputs(estimate, target, mask)

    mask = mask.cpu().detach().numpy()
    estimate = estimate.cpu().detach().numpy()
    target = target.cpu().detach().numpy()

    diff = np.abs(estimate[mask] - target[mask])
    m = (diff > threshold).sum() / mask.sum()
    return m
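
# Sanity-check sketch (synthetic tensors; shapes chosen to match this script's usage):
#
#   est = torch.zeros(2, 2, 4, 4)
#   tgt = torch.ones(2, 2, 4, 4)   # every pixel is off by exactly 1.0
#   msk = torch.ones(2, 4, 4)
#   outlier_fraction(est, tgt, msk, threshold=0.5)   # -> 1.0
#   outlier_fraction(est, tgt, msk, threshold=2.0)   # -> 0.0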

def ensure_dir(path):
    """Create `path` (and any parents) if it does not exist yet."""
    os.makedirs(path, exist_ok=True)

def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
    '''
    Exponentially weighted L1 loss over a sequence of flow refinements.

    valid: (B, H, W) validity mask, broadcast to (B, 1, H, W)
    flow_preds[i]: (B, 2, H, W)
    flow_gt: (B, 2, H, W)
    '''
    n_predictions = len(flow_preds)
    flow_loss = 0.0

    for i in range(n_predictions):
        # later (more refined) predictions get exponentially larger weights
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(flow_preds[i] - flow_gt)
        flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()

    return flow_loss
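
# Weighting sketch: with gamma=0.8 and n_predictions=3 the per-iteration weights are
# 0.8**2, 0.8**1, 0.8**0 = 0.64, 0.8, 1.0, so the final refinement dominates the loss.
#
#   preds = [torch.rand(2, 2, 64, 64) for _ in range(3)]   # synthetic refinements
#   gt = torch.rand(2, 2, 64, 64)
#   valid = torch.ones(2, 64, 64)
#   loss = sequence_loss(preds, gt, valid, gamma=0.8)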

class CREStereoLightning(LightningModule):
    def __init__(self, args, logger=None, pattern_path='', data_path=''):
        super().__init__()
        self.batch_size = args.batch_size
        self.wandb_logger = logger
        self.data_type = 'blender' if 'blender' in data_path else 'ctd'
        self.lr = args.base_lr
        print(f'lr = {self.lr}')
        self.T_max = args.t_max if args.t_max else None
        self.pattern_attention = args.pattern_attention
        self.pattern_path = pattern_path
        self.data_path = data_path
        self.model = Model(
            max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
        )
        # kept as attributes so adjust_learning_rate() can access them easily
        self.n_total_epoch = args.n_total_epoch
        self.base_lr = args.base_lr

        # the LR schedule is applied by hand in adjust_learning_rate()
        self.automatic_optimization = False

    def train_dataloader(self):
        if self.data_type == 'blender':
            dataset = BlenderDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                use_lightning=True,
            )
        elif self.data_type == 'ctd':
            dataset = CTDDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                use_lightning=True,
            )
        dataloader = DataLoader(
            dataset,
            self.batch_size,
            shuffle=True,
            num_workers=4,
            drop_last=True,
            persistent_workers=True,
            pin_memory=True,
        )
        return dataloader

    def val_dataloader(self):
        if self.data_type == 'blender':
            test_dataset = BlenderDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )
        elif self.data_type == 'ctd':
            test_dataset = CTDDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )

        test_dataloader = DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )
        return test_dataloader

    def test_dataloader(self):
        # TODO change this to use IRL data?
        # note: the dataset classes were previously swapped relative to
        # data_type here; they now match the train/val loaders
        if self.data_type == 'blender':
            test_dataset = BlenderDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )
        elif self.data_type == 'ctd':
            test_dataset = CTDDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )

        test_dataloader = DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )
        return test_dataloader

    def forward(
        self,
        image1,
        image2,
        flow_init=None,
        iters=10,
        upsample=True,
        test_mode=False,
    ):
        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self.pattern_attention)
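
    # Call sketch (shapes are assumptions based on the comments in the step methods):
    #
    #   left, right: (B, C, H, W) float tensors
    #   returns a list of (B, 2, H, W) flow refinements in train mode, or the final
    #   prediction in test mode (the exact return type is defined by nets.Model).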

    def training_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
        loss = sequence_loss(
            flow_predictions, gt_flow, valid_mask, gamma=0.8
        )
        if batch_idx % 128 == 0:
            image_log = log_images(left, right, flow_predictions, gt_disp)
            image_log['key'] = 'debug_train'
            self.wandb_logger.log_image(**image_log)
        self.log("train_loss", loss)

        # manual optimization: step the hand-rolled LR schedule once per epoch,
        # at the last batch
        if self.trainer.is_last_batch:
            self.adjust_learning_rate()
        return loss

    def validation_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right, test_mode=True)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
        val_loss = sequence_loss(
            flow_predictions, gt_flow, valid_mask, gamma=0.8
        )
        self.log("val_loss", val_loss)
        # self.log() only accepts scalars, so log one metric per threshold
        for threshold in [0.1, 0.5, 1, 2, 5]:
            of = outlier_fraction(flow_predictions[0], gt_flow, valid_mask, threshold)
            self.log(f"outlier_fraction_{threshold}", of)
        if batch_idx % 8 == 0:
            self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

    def test_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
        flow_predictions = self.forward(left, right, test_mode=True)
        test_loss = sequence_loss(
            flow_predictions, gt_flow, valid_mask, gamma=0.8
        )
        self.log("test_loss", test_loss)
        self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # batches are (left, right, gt_disp, valid_mask) tuples, so unpack the two
        # images instead of passing the whole tuple to forward()
        left, right = batch[0], batch[1]
        return self(left, right, test_mode=True)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
        print('len(self.train_dataloader)', len(self.train_dataloader()))
        lr_scheduler = {
            # with automatic_optimization=False this scheduler is not stepped
            # automatically; the effective schedule comes from adjust_learning_rate()
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.T_max if self.T_max else len(self.train_dataloader()) / self.batch_size,
            ),
            'name': 'CosineAnnealingLRScheduler',
        }
        return [optimizer], [lr_scheduler]

    def adjust_learning_rate(self):
        optimizer = self.optimizers().optimizer
        epoch = self.trainer.current_epoch + 1

        warm_up = 0.02
        const_range = 0.6
        min_lr_rate = 0.05

        if epoch <= self.n_total_epoch * warm_up:
            # linear warm-up from min_lr_rate * base_lr up to base_lr
            lr = (1 - min_lr_rate) * self.base_lr / (
                self.n_total_epoch * warm_up
            ) * epoch + min_lr_rate * self.base_lr
        elif self.n_total_epoch * warm_up < epoch <= self.n_total_epoch * const_range:
            # constant plateau at base_lr
            lr = self.base_lr
        else:
            # linear decay down to min_lr_rate * base_lr at the final epoch
            lr = (min_lr_rate - 1) * self.base_lr / (
                (1 - const_range) * self.n_total_epoch
            ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * self.base_lr

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
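
    # Worked example of the schedule above, assuming n_total_epoch=300 and
    # base_lr=4e-4 (illustrative numbers, not values from the config):
    #   epochs 1..6     (2% warm-up):  lr ramps linearly from ~0.05*4e-4 to 4e-4
    #   epochs 7..180   (60% plateau): lr stays at 4e-4
    #   epochs 181..300 (decay):       lr falls linearly to 0.05*4e-4 = 2e-5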

if __name__ == "__main__":
    wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True)
    # train configuration
    args = parse_yaml("cfgs/train.yaml")
    wandb_logger.experiment.config.update(args._asdict())
    config = wandb.config
    if 'blender' in config.training_data_path:
        # this pattern was used for our blender renders
        pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
    elif 'ctd' in config.training_data_path:
        # this one is used for ctd
        pattern_path = '/home/nils/kinect_from_settings.png'
    else:
        raise ValueError(f'cannot derive projector pattern from {config.training_data_path}')

    devices = min(config.nr_gpus, torch.cuda.device_count())
    if devices != config.nr_gpus:
        print(f'Using fewer devices than expected! ({devices} / {config.nr_gpus})')

    model = CREStereoLightning(
        config,
        wandb_logger,
        pattern_path,
        config.training_data_path,
        # lr=0.00017378008287493763, # found with auto_lr_find=True
    )
    # NOTE turn this down once it's working, this might use too much space
    # wandb_logger.watch(model, log_graph=False) #, log='all')

    model_checkpoint = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=2,
        save_last=True,
    )

    trainer = Trainer(
        accelerator='gpu',
        devices=devices,
        max_epochs=config.n_total_epoch,
        callbacks=[
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=16,
            ),
            LearningRateMonitor(),
            model_checkpoint,
        ],
        strategy=DDPSpawnStrategy(find_unused_parameters=False),
        # auto_scale_batch_size='binsearch',
        # auto_lr_find=True,
        # accumulate_grad_batches=4, # needed to disable for manual optimization
        deterministic=True,
        check_val_every_n_epoch=1,
        limit_val_batches=64,
        limit_test_batches=256,
        logger=wandb_logger,
        default_root_dir=config.log_dir_lightning,
    )

    # trainer.tune(model)
    trainer.fit(model)
    # trainer.validate(ckpt_path=model_checkpoint.best_model_path)