CREStereo Repository for the 'Towards accurate and robust depth estimation' project
CREStereo-pytorch-nxt/train_lightning.py


import os
from collections import namedtuple

import cv2
import torch
import torch.optim as optim
import wandb
import yaml
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPSpawnStrategy

from dataset import BlenderDataset
from nets import Model

seed_everything(42, workers=True)

def normalize_and_colormap(img):
    """Min-max normalize an image to [0, 255] and apply the inferno colormap."""
    ret = (img - img.min()) / (img.max() - img.min()) * 255.0
    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()
    ret = ret.astype("uint8")
    return cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
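
# Usage sketch (shapes are illustrative assumptions, not fixed by this repo):
# a (H, W) float disparity map comes back as an (H, W, 3) uint8 BGR image,
#   disp = torch.rand(480, 640) * 64
#   vis = normalize_and_colormap(disp)  # vis.shape == (480, 640, 3)
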
def log_images(left, right, pred_disp, gt_disp):
    """Build a wandb image-logging payload for the first sample of the batch."""
    batch_idx = 0
    if isinstance(pred_disp, list):
        # The model returns one prediction per refinement iteration; log the last.
        pred_disp = pred_disp[-1]
    pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
    gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
    left = torch.squeeze(left[:, 0, :, :])
    right = torch.squeeze(right[:, 0, :, :])
    disp_error = gt_disp - pred_disp
    input_left = left[batch_idx].cpu().detach().numpy()
    input_right = right[batch_idx].cpu().detach().numpy()
    wandb_log = dict(
        key="samples",
        images=[
            normalize_and_colormap(pred_disp[batch_idx]),
            normalize_and_colormap(abs(disp_error[batch_idx])),
            normalize_and_colormap(gt_disp[batch_idx]),
            input_left,
            input_right,
        ],
        caption=[
            f"Disparity\n{pred_disp[batch_idx].min():.2f}/{pred_disp[batch_idx].max():.2f}",
            f"Disp. Error\n{disp_error[batch_idx].min():.2f}/{disp_error[batch_idx].max():.2f}\n{abs(disp_error[batch_idx]).mean():.2f}",
            f"GT Disp Vis\n{gt_disp[batch_idx].min():.2f}/{gt_disp[batch_idx].max():.2f}",
            "Input Left",
            "Input Right",
        ],
    )
    return wandb_log

def parse_yaml(file_path: str) -> namedtuple:
    """Parse a yaml configuration file and return its contents as a `namedtuple`."""
    with open(file_path, "rb") as f:
        cfg: dict = yaml.safe_load(f)
    args = namedtuple("train_args", cfg.keys())(*cfg.values())
    return args
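
# The contents of cfgs/train.yaml are not part of this file; the sketch below
# is an assumption inferred from the attribute accesses in this script, with
# illustrative values only:
#
#   batch_size: 2
#   base_lr: 4.0e-4
#   t_max: null
#   pattern_attention: true
#   max_disp: 256
#   mixed_precision: false
#   nr_gpus: 1
#   n_total_epoch: 100
#   training_data_path: /path/to/blender/dataset
#   log_dir_lightning: ./lightning_logs
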
def format_time(elapse):
    """Format a duration in seconds as HH:MM:SS."""
    elapse = int(elapse)
    hour = elapse // 3600
    minute = elapse % 3600 // 60
    seconds = elapse % 60
    return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)

def ensure_dir(path):
    """Create `path` (including parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)

def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
    """Exponentially weighted L1 loss over the iterative flow predictions.

    valid: (B, H, W), broadcast to (B, 1, H, W)
    flow_preds[i]: (B, 2, H, W)
    flow_gt: (B, 2, H, W)
    """
    n_predictions = len(flow_preds)
    flow_loss = 0.0
    for i in range(n_predictions):
        # Later (more refined) predictions get exponentially larger weights.
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(flow_preds[i] - flow_gt)
        flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
    return flow_loss
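
# Minimal sanity check (shapes are illustrative assumptions): with four
# all-zero predictions against an all-ones target and a full validity mask,
# every per-iteration mean L1 error is 1, so the loss is just the sum of
# weights 0.8**3 + 0.8**2 + 0.8 + 1.0 = 2.952:
#   preds = [torch.zeros(2, 2, 480, 640) for _ in range(4)]
#   gt = torch.ones(2, 2, 480, 640)
#   mask = torch.ones(2, 480, 640)
#   sequence_loss(preds, gt, mask)  # tensor(2.9520)
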
class CREStereoLightning(LightningModule):
    def __init__(self, args, logger, pattern_path, data_path):
        super().__init__()
        self.batch_size = args.batch_size
        self.wandb_logger = logger
        self.lr = args.base_lr
        print(f"lr = {self.lr}")
        self.T_max = args.t_max if args.t_max else None
        self.pattern_attention = args.pattern_attention
        self.pattern_path = pattern_path
        self.data_path = data_path
        self.model = Model(
            max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
        )
    def train_dataloader(self):
        dataset = BlenderDataset(
            root=self.data_path,
            pattern_path=self.pattern_path,
            use_lightning=True,
        )
        dataloader = DataLoader(
            dataset,
            self.batch_size,
            shuffle=True,
            num_workers=4,
            drop_last=True,
            persistent_workers=True,
            pin_memory=True,
        )
        return dataloader
    def val_dataloader(self):
        test_dataset = BlenderDataset(
            root=self.data_path,
            pattern_path=self.pattern_path,
            test_set=True,
            use_lightning=True,
        )
        test_dataloader = DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )
        return test_dataloader
    def test_dataloader(self):
        # TODO change this to use IRL data?
        test_dataset = BlenderDataset(
            root=self.data_path,
            pattern_path=self.pattern_path,
            test_set=True,
            use_lightning=True,
        )
        test_dataloader = DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )
        return test_dataloader
    def forward(
        self,
        image1,
        image2,
        flow_init=None,
        iters=10,
        upsample=True,
        test_mode=False,
    ):
        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self.pattern_attention)
    def training_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right)
        # Disparity is treated as the x-component of a flow field with a
        # zero y-component, matching the model's flow-style predictions.
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
        loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        if batch_idx % 128 == 0:
            image_log = log_images(left, right, flow_predictions, gt_disp)
            image_log["key"] = "debug_train"
            self.wandb_logger.log_image(**image_log)
        self.log("train_loss", loss)
        return loss
    def validation_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right, test_mode=True)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
        val_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("val_loss", val_loss)
        if batch_idx % 8 == 0:
            self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
    def test_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
        flow_predictions = self.forward(left, right, test_mode=True)
        test_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("test_loss", test_loss)
        self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
        lr_scheduler = {
            # Anneal over t_max steps if configured, otherwise over one
            # epoch's worth of batches (len(dataloader) is already the
            # number of batches, not samples).
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.T_max if self.T_max else len(self.train_dataloader()),
            ),
            "name": "CosineAnnealingLRScheduler",
        }
        return [optimizer], [lr_scheduler]
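
# For reference, PyTorch's CosineAnnealingLR decays the learning rate as
#   lr_t = lr_min + (lr_max - lr_min) * (1 + cos(pi * t / T_max)) / 2
# with lr_min = 0 by default, so the rate falls smoothly from base_lr
# towards 0 over T_max scheduler steps.
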
if __name__ == "__main__":
    # train configuration
    args = parse_yaml("cfgs/train.yaml")
    pattern_path = "/home/nils/miniprojekt/kinect_syn_ref.png"
    run = wandb.init(
        project="crestereo-lightning",
        config=args._asdict(),
        tags=[
            "new_scheduler",
            "default_lr",
            f'{"" if args.pattern_attention else "no-"}pattern-attention',
        ],
        notes="",
    )
    run.config.update(args._asdict())
    config = wandb.config
    wandb_logger = WandbLogger(project="crestereo-lightning", id=run.id, log_model=True)
    model = CREStereoLightning(
        config,
        wandb_logger,
        pattern_path,
        args.training_data_path,
        # lr=0.00017378008287493763,  # found with auto_lr_find=True
    )
    # NOTE wandb_logger.watch(model) would also log gradients/weights, but it
    # can use a lot of storage; enable it selectively once training works.
    trainer = Trainer(
        accelerator="gpu",
        devices=args.nr_gpus,
        max_epochs=args.n_total_epoch,
        callbacks=[
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=16,
            ),
            LearningRateMonitor(),
            ModelCheckpoint(
                monitor="val_loss",
                mode="min",
                save_top_k=2,
                save_last=True,
            ),
        ],
        strategy=DDPSpawnStrategy(find_unused_parameters=False),
        # auto_scale_batch_size='binsearch',
        # auto_lr_find=True,
        accumulate_grad_batches=4,
        deterministic=True,
        check_val_every_n_epoch=1,
        limit_val_batches=64,
        limit_test_batches=256,
        logger=wandb_logger,
        default_root_dir=args.log_dir_lightning,
    )
    # trainer.tune(model)  # only needed with auto_scale_batch_size / auto_lr_find
    trainer.fit(model)
    trainer.validate(model)
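
# To resume an interrupted run, Lightning can restart from a saved checkpoint
# via the ckpt_path argument (the path below is illustrative, not from this
# repo):
#   trainer.fit(model, ckpt_path="checkpoints/last.ckpt")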