import os
import sys
import time
import logging
from collections import namedtuple

import yaml
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import wandb

from nets import Model
from dataset import BlenderDataset, CREStereoDataset, CTDDataset

seed_everything(42, workers=True)


def normalize_and_colormap(img):
    """Scale an image/disparity map to [0, 255] and apply the INFERNO colormap."""
    ret = (img - img.min()) / (img.max() - img.min()) * 255.0
    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()
    ret = ret.astype("uint8")
    ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
    return ret


def log_images(left, right, pred_disp, gt_disp, wandb_logger=None):
    """Build the keyword arguments for `WandbLogger.log_image` from one batch."""
    batch_idx = 1  # visualize the second sample of the batch

    if isinstance(pred_disp, list):
        pred_disp = pred_disp[-1]

    pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
    gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
    left = torch.squeeze(left[:, 0, :, :])
    right = torch.squeeze(right[:, 0, :, :])

    disp = pred_disp
    disp_error = gt_disp - disp
    input_left = left[batch_idx].cpu().detach().numpy()
    input_right = right[batch_idx].cpu().detach().numpy()

    wandb_log = dict(
        key='samples',
        images=[
            normalize_and_colormap(pred_disp[batch_idx]),
            normalize_and_colormap(abs(disp_error[batch_idx])),
            normalize_and_colormap(gt_disp[batch_idx]),
            input_left,
            input_right,
        ],
        caption=[
            f"Disparity\n{pred_disp[batch_idx].min():.2f}/{pred_disp[batch_idx].max():.2f}",
            f"Disp. Error\n{disp_error[batch_idx].min():.2f}/{disp_error[batch_idx].max():.2f}\n{abs(disp_error[batch_idx]).mean():.2f}",
            f"GT Disp Vis\n{gt_disp[batch_idx].min():.2f}/{gt_disp[batch_idx].max():.2f}",
            "Input Left",
            "Input Right",
        ],
    )
    return wandb_log
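
# --- Illustrative only: a minimal, self-contained smoke test for the logging
# helpers above, using random tensors instead of real network output. The
# shapes (batch of 2, 480x640, single-channel inputs) are an assumption made
# for this sketch, not something enforced by this script. Never called here.
def _example_log_images_smoke_test():
    left = torch.rand(2, 1, 480, 640)
    right = torch.rand(2, 1, 480, 640)
    pred_disp = [torch.rand(2, 2, 480, 640)]   # list of flow predictions; the last is visualized
    gt_disp = torch.rand(2, 1, 480, 640)
    wandb_kwargs = log_images(left, right, pred_disp, gt_disp)
    # wandb_kwargs can be forwarded as WandbLogger.log_image(**wandb_kwargs)
    return wandb_kwargs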


def parse_yaml(file_path: str) -> namedtuple:
    """Parse a yaml configuration file and return its contents as a `namedtuple`."""
    with open(file_path, "rb") as f:
        cfg: dict = yaml.safe_load(f)
    args = namedtuple("train_args", cfg.keys())(*cfg.values())
    return args


def format_time(elapse):
    """Format a duration in seconds as HH:MM:SS."""
    elapse = int(elapse)
    hour = elapse // 3600
    minute = elapse % 3600 // 60
    seconds = elapse % 60
    return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)


def ensure_dir(path):
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)


def adjust_learning_rate(optimizer, epoch):
    """Warm-up / constant / linear-decay schedule. Relies on the global `args`."""
    warm_up = 0.02
    const_range = 0.6
    min_lr_rate = 0.05

    if epoch <= args.n_total_epoch * warm_up:
        lr = (1 - min_lr_rate) * args.base_lr / (
            args.n_total_epoch * warm_up
        ) * epoch + min_lr_rate * args.base_lr
    elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
        lr = args.base_lr
    else:
        lr = (min_lr_rate - 1) * args.base_lr / (
            (1 - const_range) * args.n_total_epoch
        ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
    """Exponentially weighted L1 loss over the sequence of flow predictions.

    valid:         (B, H, W), broadcast to (B, 1, H, W)
    flow_preds[i]: (B, 2, H, W)
    flow_gt:       (B, 2, H, W)
    """
    if test:
        # In test mode the valid mask may arrive without a batch dimension;
        # this reshaping assumes a batch size of 2 at 480x640 resolution.
        if valid.shape != (2, 480, 640):
            valid = torch.stack([valid, valid])
        if valid.shape != (2, 480, 640):
            valid = valid.transpose(0, 1)

    n_predictions = len(flow_preds)
    flow_loss = 0.0
    flow_gt = torch.squeeze(flow_gt, dim=-1)  # drop a trailing singleton dimension if present

    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(flow_preds[i] - flow_gt)
        flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()

    return flow_loss


class CREStereoLightning(LightningModule):
    def __init__(self, args, logger):
        super().__init__()
        self.batch_size = args.batch_size
        self.wandb_logger = logger
        self.model = Model(
            max_disp=args.max_disp,
            mixed_precision=args.mixed_precision,
            test_mode=False,
        )

    def forward(self, image1, image2, flow_init=None, iters=10, upsample=True,
                test_mode=False, self_attend_right=True):
        return self.model(image1, image2, flow_init, iters, upsample,
                          test_mode, self_attend_right)

    def training_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)           # [B, H, W] -> [B, 1, H, W]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [B, 2, H, W]
        loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right, test_mode=True)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)           # [B, H, W] -> [B, 1, H, W]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [B, 2, H, W]
        val_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("val_loss", val_loss)
        if batch_idx % 4 == 0:
            self.wandb_logger.log_image(
                **log_images(left, right, flow_predictions, gt_disp)
            )

    def test_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        gt_disp = torch.unsqueeze(gt_disp, dim=1)           # [B, H, W] -> [B, 1, H, W]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [B, 2, H, W]
        flow_predictions = self.forward(left, right, test_mode=True)
        test_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("test_loss", test_loss)
        self.wandb_logger.log_image(
            **log_images(left, right, flow_predictions, gt_disp)
        )

    def configure_optimizers(self):
        # Note: the learning rate is hard-coded here; `args.base_lr` and
        # `adjust_learning_rate` above are not wired into the Lightning loop.
        return optim.Adam(self.model.parameters(), lr=0.1, betas=(0.9, 0.999))
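

# --- Illustrative only: a minimal smoke test for `sequence_loss`, using random
# tensors with the (B, 2, H, W) / (B, H, W) shapes described in its docstring.
# The sequence length of 12 predictions is an assumption for this sketch (the
# real length is whatever the model returns). Never called here.
def _example_sequence_loss_smoke_test():
    b, h, w = 2, 480, 640
    flow_gt = torch.rand(b, 2, h, w)
    flow_preds = [flow_gt + 0.1 * torch.randn(b, 2, h, w) for _ in range(12)]
    valid = torch.ones(b, h, w)
    # Later predictions receive exponentially larger weights: gamma ** (n - i - 1).
    return sequence_loss(flow_preds, flow_gt, valid, gamma=0.8)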


if __name__ == "__main__":
    # train configuration
    args = parse_yaml("cfgs/train.yaml")
    pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'

    wandb_logger = WandbLogger(project="crestereo-lightning")
    # Attach the training configuration to the wandb run via the logger's
    # experiment handle (this also triggers wandb initialization).
    wandb_logger.experiment.config.update(args._asdict())

    model = CREStereoLightning(args, wandb_logger)

    dataset = BlenderDataset(
        root=args.training_data_path,
        pattern_path=pattern_path,
        use_lightning=True,
    )
    test_dataset = BlenderDataset(
        root=args.training_data_path,
        pattern_path=pattern_path,
        test_set=True,
        use_lightning=True,
    )

    dataloader = DataLoader(
        dataset,
        args.batch_size,
        shuffle=True,
        num_workers=16,
        drop_last=True,
        persistent_workers=True,
        pin_memory=True,
    )
    test_dataloader = DataLoader(
        test_dataset,
        args.batch_size,
        shuffle=False,
        num_workers=16,
        drop_last=False,
        persistent_workers=True,
        pin_memory=True,
    )

    trainer = Trainer(
        accelerator='gpu',
        devices=2,
        max_epochs=args.n_total_epoch,
        callbacks=[
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=4,
            )
        ],
        accumulate_grad_batches=8,
        deterministic=True,
        check_val_every_n_epoch=1,
        limit_val_batches=24,
        limit_test_batches=24,
        logger=wandb_logger,
        default_root_dir=args.log_dir_lightning,
    )

    # The held-out Blender split serves as the validation set during fit().
    trainer.fit(model, dataloader, test_dataloader)
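
# --- Illustrative only: a sketch of the keys `parse_yaml("cfgs/train.yaml")`
# is expected to provide, inferred from how `args` is used in this script.
# The values below are placeholders, not the settings used for any experiment.
#
#   batch_size: 2
#   max_disp: 256
#   mixed_precision: false
#   base_lr: 4.0e-4
#   n_total_epoch: 300
#   training_data_path: /path/to/blender/dataset
#   log_dir_lightning: ./train_log_lightning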