CREStereo Repository for the 'Towards accurate and robust depth estimation' project
CREStereo-pytorch-nxt/train_lightning.py


import os
from collections import namedtuple

import cv2
import numpy as np
import torch
import torch.optim as optim
import wandb
import yaml
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPSpawnStrategy
from torch.utils.data import DataLoader

from dataset import BlenderDataset, CTDDataset
from nets import Model

seed_everything(42, workers=True)

def normalize_and_colormap(img, reduce_dynamic_range=False):
    """Normalize an image to [0, 255] and apply the inferno colormap.

    If reduce_dynamic_range is set and a few large values dominate the mean,
    the image is rescaled relative to 5 * mean instead of the maximum, so
    outliers do not wash out the visualization.
    """
    ret = (img - img.min()) / (img.max() - img.min()) * 255.0
    if reduce_dynamic_range and img.max() > 5 * img.mean():
        ret = (img - img.min()) / (5 * img.mean() - img.min()) * 255.0
    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()
    # clamp before the uint8 cast: the reduced-range rescaling can push
    # values above 255, which would otherwise wrap around
    ret = np.clip(ret, 0, 255).astype("uint8")
    ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
    return ret
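
# Usage sketch (hypothetical values): turn a disparity map into a color image.
#
#   disp = torch.rand(480, 640) * 100        # fake disparity in pixels
#   vis = normalize_and_colormap(disp)       # uint8 BGR image, inferno colors
#   cv2.imwrite('disp_vis.png', vis)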

def log_images(left, right, pred_disp, gt_disp):
    """Assemble a wandb image panel from one sample of the batch."""
    batch_idx = 0
    if isinstance(pred_disp, list):
        pred_disp = pred_disp[-1]

    left = torch.squeeze(left[:, 0, :, :])
    right = torch.squeeze(right[:, 0, :, :])
    pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
    gt_disp = torch.squeeze(gt_disp[:, 0, :, :])

    singular_batch = False
    if len(left.shape) == 2:
        singular_batch = True
        print('batch_size seems to be 1')
        input_left = left.cpu().detach().numpy()
        input_right = right.cpu().detach().numpy()
    else:
        input_left = left[batch_idx].cpu().detach().numpy()
        input_right = right[batch_idx].cpu().detach().numpy()

    disp_error = gt_disp - pred_disp

    if singular_batch:
        wandb_log = dict(
            key='samples',
            images=[
                pred_disp,
                normalize_and_colormap(pred_disp),
                normalize_and_colormap(abs(disp_error), reduce_dynamic_range=True),
                normalize_and_colormap(gt_disp, reduce_dynamic_range=True),
                input_left,
                input_right,
            ],
            caption=[
                f"Disparity \n{pred_disp.min():.2f}/{pred_disp.max():.2f}",
                f"Disparity (vis) \n{pred_disp.min():.2f}/{pred_disp.max():.2f}",
                f"Disp. Error\n{disp_error.min():.2f}/{disp_error.max():.2f}\n{abs(disp_error).mean():.2f}",
                f"GT Disp Vis \n{gt_disp.min():.2f}/{gt_disp.max():.2f}",
                "Input Left",
                "Input Right",
            ],
        )
    else:
        wandb_log = dict(
            key='samples',
            images=[
                normalize_and_colormap(pred_disp[batch_idx]),
                normalize_and_colormap(abs(disp_error[batch_idx])),
                normalize_and_colormap(gt_disp[batch_idx]),
                input_left,
                input_right,
            ],
            caption=[
                f"Disparity (vis)\n{pred_disp[batch_idx].min():.2f}/{pred_disp[batch_idx].max():.2f}",
                f"Disp. Error\n{disp_error[batch_idx].min():.2f}/{disp_error[batch_idx].max():.2f}\n{abs(disp_error[batch_idx]).mean():.2f}",
                f"GT Disp Vis \n{gt_disp[batch_idx].min():.2f}/{gt_disp[batch_idx].max():.2f}",
                "Input Left",
                "Input Right",
            ],
        )
    return wandb_log

def parse_yaml(file_path: str) -> namedtuple:
    """Parse a yaml configuration file and return it as a `namedtuple`."""
    with open(file_path, "rb") as f:
        cfg: dict = yaml.safe_load(f)
    args = namedtuple("train_args", cfg.keys())(*cfg.values())
    return args
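
# A minimal cfgs/train.yaml sketch. The keys below are the ones this script
# reads; the values are placeholders, not the project's actual settings.
#
#   batch_size: 2
#   image_width: 512
#   image_height: 384
#   max_disp: 256
#   mixed_precision: false
#   base_lr: 0.0004
#   t_max: null
#   scheduler: default
#   n_total_epoch: 300
#   nr_gpus: 1
#   pattern_attention: false
#   data_limit: 1.0
#   training_data_path: /path/to/blender_renders
#   test_data_path: /path/to/blender_renders
#   log_dir_lightning: ./lightning_logs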

def format_time(elapse):
    elapse = int(elapse)
    hour = elapse // 3600
    minute = elapse % 3600 // 60
    seconds = elapse % 60
    return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)

def outlier_fraction(estimate, target, mask=None, threshold=0):
    """Fraction of masked pixels whose absolute error exceeds `threshold`."""
    def _process_inputs(estimate, target, mask):
        if estimate.shape != target.shape:
            raise Exception(
                f'estimate and target have to be same shape (expected {estimate.shape} == {target.shape})'
            )
        if mask is None:
            # np.bool was removed from numpy; use a boolean tensor instead
            mask = torch.ones_like(estimate, dtype=torch.bool)
        else:
            mask = mask != 0
            if estimate.shape != mask.shape:
                if len(mask.shape) == 3:
                    mask = mask[0]
                if estimate.shape != mask.shape:
                    raise Exception(
                        f'estimate and mask have to be same shape (expected {estimate.shape} == {mask.shape})'
                    )
        return estimate, target, mask

    estimate = torch.squeeze(estimate[:, 0, :, :])
    target = torch.squeeze(target[:, 0, :, :])
    estimate, target, mask = _process_inputs(estimate, target, mask)
    mask = mask.cpu().detach().numpy()
    estimate = estimate.cpu().detach().numpy()
    target = target.cpu().detach().numpy()
    diff = np.abs(estimate[mask] - target[mask])
    m = (diff > threshold).sum() / mask.sum()
    return m
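
# Usage sketch (hypothetical shapes): fraction of masked pixels with
# |error| > 1 px, for (B, 2, H, W) flow tensors as built in the *_step methods.
#
#   est = torch.zeros(2, 2, 384, 512)
#   gt = torch.ones(2, 2, 384, 512) * 2.0
#   mask = torch.ones(2, 384, 512)
#   outlier_fraction(est, gt, mask, threshold=1)  # -> 1.0, every pixel is off by 2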

def ensure_dir(path):
    os.makedirs(path, exist_ok=True)

def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
    """Exponentially weighted L1 loss over the iterative flow predictions.

    valid: (B, H, W), broadcast to (B, 1, H, W)
    flow_preds[i]: (B, 2, H, W)
    flow_gt: (B, 2, H, W)
    """
    n_predictions = len(flow_preds)
    flow_loss = 0.0
    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(flow_preds[i] - flow_gt)
        flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
    return flow_loss
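
# Worked example: with gamma=0.8 and four predictions, the iteration weights
# are 0.8**3, 0.8**2, 0.8**1, 0.8**0 = 0.512, 0.64, 0.8, 1.0, so later
# refinement iterations dominate the loss.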

class CREStereoLightning(LightningModule):
    def __init__(self, args, logger=None, pattern_path=''):
        super().__init__()
        self.batch_size = args.batch_size
        self.wandb_logger = logger
        self.imwidth = args.image_width
        self.imheight = args.image_height
        self.data_type = 'blender' if 'blender' in args.training_data_path else 'ctd'
        self.eval_type = 'kinect' if 'kinect' in args.test_data_path else args.training_data_path
        self.lr = args.base_lr
        self.T_max = args.t_max if args.t_max else None
        self.pattern_attention = args.pattern_attention
        self.pattern_path = pattern_path
        self.data_path = args.training_data_path
        self.test_data_path = args.test_data_path
        self.data_limit = args.data_limit  # between 0 and 1
        self.model = Model(
            max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
        )
        # kept as attributes so adjust_learning_rate() can access them easily
        self.n_total_epoch = args.n_total_epoch
        self.base_lr = args.base_lr
        if args.scheduler == 'default':
            # the hand-rolled schedule in adjust_learning_rate() needs manual optimization
            self.automatic_optimization = False

    def train_dataloader(self):
        # we never train on kinect data
        is_kinect = False
        if self.data_type == 'blender':
            dataset = BlenderDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                use_lightning=True,
                data_type='kinect' if is_kinect else 'blender',
                disp_avail=not is_kinect,
                data_limit=self.data_limit,
            )
        elif self.data_type == 'ctd':
            dataset = CTDDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                use_lightning=True,
                data_limit=self.data_limit,
            )
        dataloader = DataLoader(
            dataset,
            self.batch_size,
            shuffle=True,
            num_workers=4,
            drop_last=True,
            persistent_workers=True,
            pin_memory=True,
        )
        return dataloader

    def val_dataloader(self):
        # we don't validate on kinect data either
        is_kinect = False
        if self.data_type == 'blender':
            test_dataset = BlenderDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
                data_type='kinect' if is_kinect else 'blender',
                disp_avail=not is_kinect,
                data_limit=self.data_limit,
            )
        elif self.data_type == 'ctd':
            test_dataset = CTDDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
                data_limit=self.data_limit,
            )
        test_dataloader = DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )
        return test_dataloader

    def test_dataloader(self):
        is_kinect = self.eval_type == 'kinect'
        if self.data_type == 'blender':
            test_dataset = BlenderDataset(
                root=self.test_data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                # when testing on kinect data, use all available samples for the test set
                split=0. if is_kinect else 0.9,
                use_lightning=True,
                augment=False,
                disp_avail=not is_kinect,
                data_type='kinect' if is_kinect else 'blender',
                data_limit=self.data_limit,
            )
        elif self.data_type == 'ctd':
            test_dataset = CTDDataset(
                root=self.test_data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
                augment=False,
                data_limit=self.data_limit,
            )
        test_dataloader = DataLoader(
            test_dataset,
            1 if is_kinect else self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )
        return test_dataloader

    def forward(
        self,
        image1,
        image2,
        flow_init=None,
        iters=10,
        upsample=True,
        test_mode=False,
    ):
        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self.pattern_attention)

    def training_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # (B, H, W) -> (B, 1, H, W)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # (B, 2, H, W)
        loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        # with the hand-rolled scheduler we run manual optimization,
        # so backward and the optimizer step happen here
        if not self.automatic_optimization:
            opt = self.optimizers()
            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()
        if batch_idx % 128 == 0:
            image_log = log_images(left, right, flow_predictions, gt_disp)
            image_log['key'] = 'debug_train'
            self.wandb_logger.log_image(**image_log)
        self.log("train_loss", loss)
        # step the hand-rolled schedule once per epoch, on the last batch
        if not self.automatic_optimization and self.trainer.is_last_batch:
            self.adjust_learning_rate()
        return loss

    def validation_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right, test_mode=True)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # (B, H, W) -> (B, 1, H, W)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # (B, 2, H, W)
        val_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("val_loss", val_loss)
        # log the fraction of pixels whose error exceeds each threshold;
        # self.log() takes scalars, so log one metric per threshold
        for threshold in [0.1, 0.5, 1, 2, 5]:
            of = outlier_fraction(flow_predictions[0], gt_flow, valid_mask, threshold)
            self.log(f"outlier_fraction/{threshold}", of)
        if batch_idx % 8 == 0:
            images = log_images(left, right, flow_predictions, gt_disp)
            self.wandb_logger.log_image(**images)

    def test_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # (B, H, W) -> (B, 1, H, W)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # (B, 2, H, W)
        flow_predictions = self.forward(left, right, test_mode=True)
        test_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("test_loss", test_loss)
        for threshold in [0.1, 0.5, 1, 2, 5]:
            of = outlier_fraction(flow_predictions[0], gt_flow, valid_mask, threshold)
            self.log(f"outlier_fraction/{threshold}", of)
        images = log_images(left, right, flow_predictions, gt_disp)
        images['images'].append(gt_disp)
        images['caption'].append('GT Disp')
        self.wandb_logger.log_image(**images)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # batches come as (left, right, gt_disp, valid_mask); only the
        # stereo pair is needed for inference
        left, right, *_ = batch
        return self(left, right, test_mode=True)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
        if not self.automatic_optimization:
            # manual optimization: the LR is set directly in adjust_learning_rate()
            return optimizer
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.T_max if self.T_max else len(self.train_dataloader()) / self.batch_size,
            ),
            'name': 'LR Scheduler',
        }
        return [optimizer], [lr_scheduler]

    def adjust_learning_rate(self):
        optimizer = self.optimizers().optimizer
        epoch = self.trainer.current_epoch + 1
        warm_up = 0.02
        const_range = 0.6
        min_lr_rate = 0.05
        if epoch <= self.n_total_epoch * warm_up:
            # linear warm-up from min_lr_rate * base_lr to base_lr
            lr = (1 - min_lr_rate) * self.base_lr / (
                self.n_total_epoch * warm_up
            ) * epoch + min_lr_rate * self.base_lr
        elif self.n_total_epoch * warm_up < epoch <= self.n_total_epoch * const_range:
            # constant plateau at base_lr
            lr = self.base_lr
        else:
            # linear decay down to min_lr_rate * base_lr at the final epoch
            lr = (min_lr_rate - 1) * self.base_lr / (
                (1 - const_range) * self.n_total_epoch
            ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * self.base_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        self.log('train/lr', lr)
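
# Worked example for the schedule above (hypothetical config values): with
# n_total_epoch=300 and base_lr=4e-4, the LR ramps linearly from 2e-5 to 4e-4
# over epochs 1-6 (2% warm-up), stays at 4e-4 until epoch 180 (60%), then
# decays linearly back to 2e-5 at epoch 300.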

if __name__ == "__main__":
    # wandb.init(project='crestereo-lightning')
    wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True)
    # train configuration
    args = parse_yaml("cfgs/train.yaml")
    wandb_logger.experiment.config.update(args._asdict())
    config = wandb.config

    if 'blender' in config.training_data_path:
        # projector pattern used for our blender renders
        pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
    elif 'ctd' in config.training_data_path:
        # projector pattern used for ctd
        pattern_path = '/home/nils/kinect_from_settings.png'
    else:
        raise ValueError(f'no projector pattern known for {config.training_data_path}')

    devices = min(config.nr_gpus, torch.cuda.device_count())
    if devices != config.nr_gpus:
        print(f'Using fewer devices than requested! ({devices} / {config.nr_gpus})')

    model = CREStereoLightning(
        config,
        wandb_logger,
        pattern_path,
        # lr=0.00017378008287493763,  # found with auto_lr_find=True
    )
    model_checkpoint = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=2,
        save_last=True,
    )

    # the two scheduler variants share everything except gradient accumulation,
    # which has to stay disabled for manual optimization
    common_trainer_kwargs = dict(
        accelerator='gpu',
        devices=devices,
        max_epochs=config.n_total_epoch,
        callbacks=[
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=8,
            ),
            LearningRateMonitor(),
            model_checkpoint,
        ],
        strategy=DDPSpawnStrategy(find_unused_parameters=False),
        # auto_scale_batch_size='binsearch',
        # auto_lr_find=True,
        deterministic=True,
        check_val_every_n_epoch=1,
        limit_val_batches=64,
        limit_test_batches=256,
        logger=wandb_logger,
        default_root_dir=config.log_dir_lightning,
    )
    if config.scheduler == 'default':
        trainer = Trainer(**common_trainer_kwargs)
    else:
        trainer = Trainer(accumulate_grad_batches=4, **common_trainer_kwargs)

    # trainer.tune(model)
    trainer.fit(model)
    # trainer.validate(ckpt_path=model_checkpoint.best_model_path)
    trainer.test(model)
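
# Typical invocation (a sketch; assumes cfgs/train.yaml is filled in and wandb
# is authenticated, e.g. via `wandb login`):
#
#   python train_lightning.py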