import os
from collections import namedtuple

import cv2
import numpy as np
import torch
import torch.optim as optim
import wandb
import yaml
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPSpawnStrategy

from nets import Model
from dataset import BlenderDataset, CTDDataset

seed_everything(42, workers=True)


def normalize_and_colormap(img):
    """Scale an image to [0, 255] and apply the inferno colormap for visualization."""
    ret = (img - img.min()) / (img.max() - img.min()) * 255.0
    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()
    ret = ret.astype("uint8")
    return cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)


def log_images(left, right, pred_disp, gt_disp):
    """Build a wandb image-logging dict from the first sample of a batch."""
    batch_idx = 0

    if isinstance(pred_disp, list):
        pred_disp = pred_disp[-1]  # use the final (most refined) prediction

    pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
    gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
    left = torch.squeeze(left[:, 0, :, :])
    right = torch.squeeze(right[:, 0, :, :])

    disp_error = gt_disp - pred_disp

    input_left = left[batch_idx].cpu().detach().numpy()
    input_right = right[batch_idx].cpu().detach().numpy()

    wandb_log = dict(
        key='samples',
        images=[
            normalize_and_colormap(pred_disp[batch_idx]),
            normalize_and_colormap(abs(disp_error[batch_idx])),
            normalize_and_colormap(gt_disp[batch_idx]),
            input_left,
            input_right,
        ],
        caption=[
            f"Disparity\n{pred_disp[batch_idx].min():.2f}/{pred_disp[batch_idx].max():.2f}",
            f"Disp. Error\n{disp_error[batch_idx].min():.2f}/{disp_error[batch_idx].max():.2f}\n{abs(disp_error[batch_idx]).mean():.2f}",
            f"GT Disp Vis\n{gt_disp[batch_idx].min():.2f}/{gt_disp[batch_idx].max():.2f}",
            "Input Left",
            "Input Right",
        ],
    )
    return wandb_log
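
# Usage sketch (this matches the calls in training_step/validation_step below):
# the returned dict unpacks directly into WandbLogger.log_image, e.g.
#
#   wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
#
# with left/right of shape (B, C, H, W) and disparities of shape (B, 1, H, W).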


def parse_yaml(file_path: str) -> namedtuple:
    """Parse a yaml configuration file and return it as a `namedtuple`."""
    with open(file_path, "rb") as f:
        cfg: dict = yaml.safe_load(f)
    args = namedtuple("train_args", cfg.keys())(*cfg.values())
    return args


def format_time(elapse):
    """Format a duration in seconds as HH:MM:SS."""
    elapse = int(elapse)
    hour = elapse // 3600
    minute = elapse % 3600 // 60
    seconds = elapse % 60
    return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)


def outlier_fraction(estimate, target, mask=None, threshold=0):
    """Fraction of (masked) pixels whose absolute error exceeds `threshold`."""
    def _process_inputs(estimate, target, mask):
        if estimate.shape != target.shape:
            raise Exception(f'estimate and target have to be same shape (expected {estimate.shape} == {target.shape})')
        if mask is None:
            mask = np.ones(estimate.shape, dtype=bool)  # np.bool was removed from numpy
        else:
            mask = mask != 0
            if estimate.shape != mask.shape:
                raise Exception(f'estimate and mask have to be same shape (expected {estimate.shape} == {mask.shape})')
        return estimate, target, mask

    estimate = torch.squeeze(estimate[:, 0, :, :])
    target = torch.squeeze(target[:, 0, :, :])
    estimate, target, mask = _process_inputs(estimate, target, mask)

    if isinstance(mask, torch.Tensor):  # mask may already be a numpy array
        mask = mask.cpu().detach().numpy()
    estimate = estimate.cpu().detach().numpy()
    target = target.cpu().detach().numpy()

    diff = np.abs(estimate[mask] - target[mask])
    return (diff > threshold).sum() / mask.sum()


def ensure_dir(path):
    os.makedirs(path, exist_ok=True)


def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
    """Exponentially weighted L1 loss over all iterative flow predictions.

    valid: (B, H, W), unsqueezed to (B, 1, H, W) to broadcast over flow channels
    flow_preds[i]: (B, 2, H, W)
    flow_gt: (B, 2, H, W)
    """
    n_predictions = len(flow_preds)
    flow_loss = 0.0
    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(flow_preds[i] - flow_gt)
        flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
    return flow_loss
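
# Illustration of the gamma weighting in sequence_loss: with gamma = 0.8 and
# the default 10 refinement iterations, the weights gamma ** (n - 1 - i) are
# roughly 0.13, 0.17, 0.21, 0.26, 0.33, 0.41, 0.51, 0.64, 0.80, 1.00, so the
# later (more refined) predictions contribute most to the total loss.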


class CREStereoLightning(LightningModule):
    def __init__(self, args, logger=None, pattern_path='', data_path=''):
        super().__init__()
        self.batch_size = args.batch_size
        self.wandb_logger = logger
        self.data_type = 'blender' if 'blender' in data_path else 'ctd'
        self.lr = args.base_lr
        print(f'lr = {self.lr}')
        self.T_max = args.t_max if args.t_max else None
        self.pattern_attention = args.pattern_attention
        self.pattern_path = pattern_path
        self.data_path = data_path
        self.model = Model(
            max_disp=args.max_disp,
            mixed_precision=args.mixed_precision,
            test_mode=False,
        )
        # kept as attributes so adjust_learning_rate can access them easily
        self.n_total_epoch = args.n_total_epoch
        self.base_lr = args.base_lr
        # the learning rate is driven by the custom schedule in adjust_learning_rate
        self.automatic_optimization = False

    def train_dataloader(self):
        if self.data_type == 'blender':
            dataset = BlenderDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                use_lightning=True,
            )
        elif self.data_type == 'ctd':
            dataset = CTDDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                use_lightning=True,
            )
        return DataLoader(
            dataset,
            self.batch_size,
            shuffle=True,
            num_workers=4,
            drop_last=True,
            persistent_workers=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        if self.data_type == 'blender':
            test_dataset = BlenderDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )
        elif self.data_type == 'ctd':
            test_dataset = CTDDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )
        return DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )

    def test_dataloader(self):
        # TODO change this to use IRL data?
        # note: the dataset types are swapped relative to training, so the
        # test set comes from the domain the model was *not* trained on
        if self.data_type == 'blender':
            test_dataset = CTDDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )
        elif self.data_type == 'ctd':
            test_dataset = BlenderDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )
        return DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )

    def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False):
        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self.pattern_attention)

    def training_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # (B, H, W) -> (B, 1, H, W)
        # for rectified stereo the vertical flow component is zero
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # (B, 2, H, W)
        loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)

        # with automatic_optimization disabled, backward and the optimizer
        # step have to be performed manually
        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        if batch_idx % 128 == 0:
            image_log = log_images(left, right, flow_predictions, gt_disp)
            image_log['key'] = 'debug_train'
            self.wandb_logger.log_image(**image_log)
        self.log("train_loss", loss)

        # apply the custom learning rate schedule once per epoch
        if self.trainer.is_last_batch:
            self.adjust_learning_rate()
        return loss

    def validation_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right, test_mode=True)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # (B, H, W) -> (B, 1, H, W)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # (B, 2, H, W)
        val_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("val_loss", val_loss)
        # self.log expects scalar values with string keys, so log each
        # threshold's outlier fraction as its own metric via log_dict
        self.log_dict({
            f"outlier_fraction_{threshold}": outlier_fraction(flow_predictions[0], gt_flow, valid_mask, threshold)
            for threshold in [0.1, 0.5, 1, 2, 5]
        })
        if batch_idx % 8 == 0:
            self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

    def test_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # (B, H, W) -> (B, 1, H, W)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # (B, 2, H, W)
        flow_predictions = self.forward(left, right, test_mode=True)
        test_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("test_loss", test_loss)
        self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # batches are (left, right, gt_disp, valid_mask) tuples, so unpack
        # them instead of passing the whole tuple to forward
        left, right, *_ = batch
        return self(left, right, test_mode=True)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
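        # Note: because automatic_optimization is off, Lightning does not step
        # this scheduler; the per-epoch schedule is applied by
        # adjust_learning_rate() instead, so the scheduler mainly feeds the
        # LearningRateMonitor. The fallback T_max divides len(train_dataloader())
        # (already a batch count) by batch_size again, which is kept as-is
        # here; set t_max in the config to control the period explicitly.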
        print('len(self.train_dataloader)', len(self.train_dataloader()))
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.T_max if self.T_max else len(self.train_dataloader()) / self.batch_size,
            ),
            'name': 'CosineAnnealingLRScheduler',
        }
        return [optimizer], [lr_scheduler]

    def adjust_learning_rate(self):
        """Piecewise schedule: linear warm-up, constant plateau, linear decay."""
        optimizer = self.optimizers().optimizer
        epoch = self.trainer.current_epoch + 1

        warm_up = 0.02      # fraction of epochs spent warming up
        const_range = 0.6   # keep base_lr up to this fraction of epochs
        min_lr_rate = 0.05  # floor as a fraction of base_lr

        if epoch <= self.n_total_epoch * warm_up:
            # ramp linearly from min_lr_rate * base_lr up to base_lr
            lr = (1 - min_lr_rate) * self.base_lr / (
                self.n_total_epoch * warm_up
            ) * epoch + min_lr_rate * self.base_lr
        elif self.n_total_epoch * warm_up < epoch <= self.n_total_epoch * const_range:
            lr = self.base_lr
        else:
            # decay linearly back down towards min_lr_rate * base_lr
            lr = (min_lr_rate - 1) * self.base_lr / (
                (1 - const_range) * self.n_total_epoch
            ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * self.base_lr

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


if __name__ == "__main__":
    wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True)

    # train configuration
    args = parse_yaml("cfgs/train.yaml")
    wandb_logger.experiment.config.update(args._asdict())
    config = wandb.config

    if 'blender' in config.training_data_path:
        # pattern used for our blender renders
        pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
    if 'ctd' in config.training_data_path:
        # pattern used for the connecting-the-dots (ctd) data
        pattern_path = '/home/nils/kinect_from_settings.png'

    devices = min(config.nr_gpus, torch.cuda.device_count())
    if devices != config.nr_gpus:
        print(f'Using fewer devices than expected! ({devices} / {config.nr_gpus})')

    model = CREStereoLightning(
        config,
        wandb_logger,
        pattern_path,
        config.training_data_path,
        # lr=0.00017378008287493763,  # found with auto_lr_find=True
    )
    # NOTE turn this down once it's working, this might use too much space
    # wandb_logger.watch(model, log_graph=False)  # , log='all')

    model_checkpoint = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=2,
        save_last=True,
    )

    trainer = Trainer(
        accelerator='gpu',
        devices=devices,
        max_epochs=config.n_total_epoch,
        callbacks=[
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=16,
            ),
            LearningRateMonitor(),
            model_checkpoint,
        ],
        strategy=DDPSpawnStrategy(find_unused_parameters=False),
        # auto_scale_batch_size='binsearch',
        # auto_lr_find=True,
        # accumulate_grad_batches=4,  # disabled: incompatible with manual optimization
        deterministic=True,
        check_val_every_n_epoch=1,
        limit_val_batches=64,
        limit_test_batches=256,
        logger=wandb_logger,
        default_root_dir=config.log_dir_lightning,
    )

    # trainer.tune(model)
    trainer.fit(model)
    # trainer.validate(ckpt_path=model_checkpoint.best_model_path)
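
# A hypothetical cfgs/train.yaml covering the keys this script actually reads;
# the values below are illustrative only, not the ones used in training:
#
#   batch_size: 2
#   base_lr: 4.0e-4
#   t_max: null            # falls back to the dataloader-derived value above
#   pattern_attention: false
#   max_disp: 256
#   mixed_precision: false
#   n_total_epoch: 300
#   nr_gpus: 2
#   training_data_path: /path/to/blender_renders
#   log_dir_lightning: ./lightning_logs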