From 37c537ca314b0feca1d90377a96d39586a7dfdc4 Mon Sep 17 00:00:00 2001
From: "Cpt.Captain"
Date: Sat, 27 Aug 2022 11:21:00 +0200
Subject: [PATCH] fix lightning, prepare sweeps

---
 cfgs/train.yaml    |   6 +-
 dataset.py         |   6 +-
 nets/crestereo.py  |  14 +--
 train.py           |   1 +
 train_lightning.py | 213 +++++++++++++++++++++++++++++----------------
 5 files changed, 152 insertions(+), 88 deletions(-)

diff --git a/cfgs/train.yaml b/cfgs/train.yaml
index 7232137..6c152f0 100644
--- a/cfgs/train.yaml
+++ b/cfgs/train.yaml
@@ -1,6 +1,8 @@
 seed: 0
 mixed_precision: false
-base_lr: 4.0e-4
+# base_lr: 4.0e-4
+base_lr: 0.001
+t_max: 161
 
 nr_gpus: 3
 batch_size: 2
@@ -16,7 +18,7 @@ max_disp: 256
 image_width: 640
 image_height: 480
 # training_data_path: "./stereo_trainset/crestereo"
-pattern_attention: true
+pattern_attention: false
 dataset: "blender"
 # training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
 training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
diff --git a/dataset.py b/dataset.py
index 3e4f36e..2649992 100644
--- a/dataset.py
+++ b/dataset.py
@@ -384,7 +384,7 @@ class BlenderDataset(CTDDataset):
         )
 
         if not self.use_lightning:
-            right_img = right_img.transpose((2, 0, 1)).astype("uint8")
+            # right_img = right_img.transpose((2, 0, 1)).astype("uint8")
         return {
             "left": left_img,
             "right": right_img,
@@ -408,7 +408,7 @@
         # return disp.astype(np.float32) / 32
         # FIXME temporarily increase disparity until new data with better depth values is generated
         # higher values seem to speedup convergence, but introduce much stronger artifacting
-        # mystery_factor = 150
-        mystery_factor = 1
+        mystery_factor = 150
+        # mystery_factor = 1
         disp = (baseline * fl * mystery_factor) / depth
         return disp.astype(np.float32)
diff --git a/nets/crestereo.py b/nets/crestereo.py
index dc6e997..6cc0bfd 100644
--- a/nets/crestereo.py
+++ b/nets/crestereo.py
@@ -38,10 +38,10 @@ class CREStereo(nn.Module):
         self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
 
         # # NOTE Position_encoding as workaround for TensorRt
-        image1_shape = [1, 2, 480, 640]
-        self.pos_encoding_fn_small = PositionEncodingSine(
-            d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
-        )
+        # image1_shape = [1, 2, 480, 640]
+        # self.pos_encoding_fn_small = PositionEncodingSine(
+        #     d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
+        # )
 
         # loftr
         self.self_att_fn = LocalFeatureTransformer(
@@ -141,10 +141,12 @@ class CREStereo(nn.Module):
             d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
         )
         # 'n c h w -> n (h w) c'
-        x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
+        # x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
+        x_tmp = pos_encoding_fn_small(fmap1_dw16)
         fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
         # 'n c h w -> n (h w) c'
-        x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
+        # x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
+        x_tmp = pos_encoding_fn_small(fmap2_dw16)
         fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
 
         # FIXME experimental ! no self-attention for pattern
diff --git a/train.py b/train.py
index b036e03..245a213 100644
--- a/train.py
+++ b/train.py
@@ -419,6 +419,7 @@
             # print(f'left {left.shape}, right {right.shape}')
             # left = left.transpose([2, 0, 1])
             right = right.transpose([1, 2, 0])
+            # right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
 
             # print(f'left {left.shape}, right {right.shape}')
diff --git a/train_lightning.py b/train_lightning.py
index 56ee5cd..de5bce7 100644
--- a/train_lightning.py
+++ b/train_lightning.py
@@ -5,10 +5,8 @@
 import logging
 from collections import namedtuple
 import yaml
 
-# from tensorboardX import SummaryWriter
 from nets import Model
-# from dataset import CREStereoDataset
 from dataset import BlenderDataset, CREStereoDataset, CTDDataset
 
 import torch
@@ -18,8 +16,11 @@
 import torch.optim as optim
 from torch.utils.data import DataLoader
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning import Trainer, seed_everything
-from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.strategies import DDPSpawnStrategy
 
 seed_everything(42, workers=True)
@@ -39,11 +40,9 @@ def normalize_and_colormap(img):
     return ret
 
 
-def log_images(left, right, pred_disp, gt_disp, wandb_logger=None):
-    # wandb_logger.log_text('test')
-    # return
+def log_images(left, right, pred_disp, gt_disp):
     log = {}
-    batch_idx = 1
+    batch_idx = 0
 
     if isinstance(pred_disp, list):
         pred_disp = pred_disp[-1]
@@ -100,32 +99,13 @@ def ensure_dir(path):
     os.makedirs(path, exist_ok=True)
 
 
-def adjust_learning_rate(optimizer, epoch):
-
-    warm_up = 0.02
-    const_range = 0.6
-    min_lr_rate = 0.05
-
-    if epoch <= args.n_total_epoch * warm_up:
-        lr = (1 - min_lr_rate) * args.base_lr / (
-            args.n_total_epoch * warm_up
-        ) * epoch + min_lr_rate * args.base_lr
-    elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
-        lr = args.base_lr
-    else:
-        lr = (min_lr_rate - 1) * args.base_lr / (
-            (1 - const_range) * args.n_total_epoch
-        ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr
-
-    for param_group in optimizer.param_groups:
-        param_group['lr'] = lr
-
-
 def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
     '''
     valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
     flow_preds[0]: (B, 2, H, W)
    flow_gt: (B, 2, H, W)
     '''
+    """
     if test:
         # print('sequence loss')
         if valid.shape != (2, 480, 640):
             valid = valid.transpose(0,1)
         # print(valid.shape)
@@ -136,6 +116,7 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
         if valid.shape != (2, 480, 640):
             valid = valid.transpose(0,1)
         # print(valid.shape)
+    """
     # print(valid.shape)
     # print(flow_preds[0].shape)
     # print(flow_gt.shape)
@@ -143,7 +124,7 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
     n_predictions = len(flow_preds)
     flow_loss = 0.0
     # TEST
-    flow_gt = torch.squeeze(flow_gt, dim=-1)
+    # flow_gt = torch.squeeze(flow_gt, dim=-1)
 
     for i in range(n_predictions):
         i_weight = gamma ** (n_predictions - i - 1)
@@ -155,16 +136,88 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
 
 
 class CREStereoLightning(LightningModule):
-    def __init__(self, args, logger):
+    def __init__(self, args, logger, pattern_path, data_path):
         super().__init__()
         self.batch_size = args.batch_size
         self.wandb_logger = logger
+        self.lr = args.base_lr
+        print(f'lr = {self.lr}')
+        self.T_max = args.t_max if args.t_max else None
+        self.pattern_attention = args.pattern_attention
+        self.pattern_path = pattern_path
+        self.data_path = data_path
         self.model = Model(
             max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
         )
 
-    def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
-        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)
+    def train_dataloader(self):
+        dataset = BlenderDataset(
+            root=self.data_path,
+            pattern_path=self.pattern_path,
+            use_lightning=True,
+        )
+        dataloader = DataLoader(
+            dataset,
+            self.batch_size,
+            shuffle=True,
+            num_workers=4,
+            drop_last=True,
+            persistent_workers=True,
+            pin_memory=True,
+        )
+        # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
+        return dataloader
+
+    def val_dataloader(self):
+        test_dataset = BlenderDataset(
+            root=self.data_path,
+            pattern_path=self.pattern_path,
+            test_set=True,
+            use_lightning=True,
+        )
+
+        test_dataloader = DataLoader(
+            test_dataset,
+            self.batch_size,
+            shuffle=False,
+            num_workers=4,
+            drop_last=False,
+            persistent_workers=True,
+            pin_memory=True
+        )
+        # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
+        return test_dataloader
+
+    def test_dataloader(self):
+        # TODO change this to use IRL data?
+        test_dataset = BlenderDataset(
+            root=self.data_path,
+            pattern_path=self.pattern_path,
+            test_set=True,
+            use_lightning=True,
+        )
+
+        test_dataloader = DataLoader(
+            test_dataset,
+            self.batch_size,
+            shuffle=False,
+            num_workers=4,
+            drop_last=False,
+            persistent_workers=True,
+            pin_memory=True
+        )
+        return test_dataloader
+
+    def forward(
+        self,
+        image1,
+        image2,
+        flow_init=None,
+        iters=10,
+        upsample=True,
+        test_mode=False,
+    ):
+        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self.pattern_attention)
 
     def training_step(self, batch, batch_idx):
         left, right, gt_disp, valid_mask = batch
@@ -174,6 +227,10 @@ class CREStereoLightning(LightningModule):
         loss = sequence_loss(
             flow_predictions, gt_flow, valid_mask, gamma=0.8
         )
+        if batch_idx % 128 == 0:
+            image_log = log_images(left, right, flow_predictions, gt_disp)
+            image_log['key'] = 'debug_train'
+            self.wandb_logger.log_image(**image_log)
         self.log("train_loss", loss)
         return loss
 
@@ -186,22 +243,31 @@ class CREStereoLightning(LightningModule):
             flow_predictions, gt_flow, valid_mask, gamma=0.8
         )
         self.log("val_loss", val_loss)
-        if batch_idx % 4 == 0:
+        if batch_idx % 8 == 0:
             self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
 
     def test_step(self, batch, batch_idx):
         left, right, gt_disp, valid_mask = batch
-        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
-        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
+        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
+        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
         flow_predictions = self.forward(left, right, test_mode=True)
         test_loss = sequence_loss(
             flow_predictions, gt_flow, valid_mask, gamma=0.8
         )
         self.log("test_loss", test_loss)
-        print('test_batch_idx:', batch_idx)
         self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
 
     def configure_optimizers(self):
-        return optim.Adam(self.model.parameters(), lr=0.1, betas=(0.9, 0.999))
+        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
+        print('len(self.train_dataloader)', len(self.train_dataloader()))
+        lr_scheduler = {
+            'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
+                optimizer,
+                T_max=self.T_max if self.T_max else len(self.train_dataloader())/self.batch_size,
+            ),
+            'name': 'CosineAnnealingLRScheduler',
+        }
+        return [optimizer], [lr_scheduler]
 
 
 if __name__ == "__main__":
@@ -209,61 +275,54 @@
     args = parse_yaml("cfgs/train.yaml")
     pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
 
-    wandb_logger = WandbLogger(project="crestereo-lightning")
-    wandb.config.update(args._asdict())
-
-    model = CREStereoLightning(args, wandb_logger)
-
-    dataset = BlenderDataset(
-        root=args.training_data_path,
-        pattern_path=pattern_path,
-        use_lightning=True,
-    )
-    test_dataset = BlenderDataset(
-        root=args.training_data_path,
-        pattern_path=pattern_path,
-        test_set=True,
-        use_lightning=True,
-    )
-
-    dataloader = DataLoader(
-        dataset,
-        args.batch_size,
-        shuffle=True,
-        num_workers=16,
-        drop_last=True,
-        persistent_workers=True,
-        pin_memory=True,
-    )
-    # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
-    test_dataloader = DataLoader(
-        test_dataset,
-        args.batch_size,
-        shuffle=False,
-        num_workers=16,
-        drop_last=False,
-        persistent_workers=True,
-        pin_memory=True
-    )
+    run = wandb.init(project="crestereo-lightning", config=args._asdict(), tags=['new_scheduler', 'default_lr', f'{"" if args.pattern_attention else "no-"}pattern-attention'], notes='')
+    run.config.update(args._asdict())
+    config = wandb.config
+    wandb_logger = WandbLogger(project="crestereo-lightning", id=run.id, log_model=True)
+    # wandb_logger = WandbLogger(project="crestereo-lightning", log_model='all')
+    # wandb_logger.experiment.config.update(args._asdict())
+
+    model = CREStereoLightning(
+        # args,
+        config,
+        wandb_logger,
+        pattern_path,
+        args.training_data_path,
+        # lr=0.00017378008287493763,  # found with auto_lr_find=True
+    )
+    # NOTE turn this down once it's working, this might use too much space
+    # wandb_logger.watch(model, log_graph=False)  #, log='all')
 
     trainer = Trainer(
         accelerator='gpu',
-        devices=2,
+        devices=args.nr_gpus,
         max_epochs=args.n_total_epoch,
         callbacks=[
             EarlyStopping(
                 monitor="val_loss",
                 mode="min",
-                patience=4,
+                patience=16,
+            ),
+            LearningRateMonitor(),
+            ModelCheckpoint(
+                monitor="val_loss",
+                mode="min",
+                save_top_k=2,
+                save_last=True,
             )
         ],
-        accumulate_grad_batches=8,
+        strategy=DDPSpawnStrategy(find_unused_parameters=False),
+        # auto_scale_batch_size='binsearch',
+        # auto_lr_find=True,
+        accumulate_grad_batches=4,
        deterministic=True,
         check_val_every_n_epoch=1,
-        limit_val_batches=24,
-        limit_test_batches=24,
+        limit_val_batches=64,
+        limit_test_batches=256,
         logger=wandb_logger,
         default_root_dir=args.log_dir_lightning,
     )
 
-    trainer.fit(model, dataloader, test_dataloader)
+    # trainer.tune(model)
+    trainer.fit(model)
+    trainer.validate()
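
Notes (standalone sketches of the mechanisms the patch touches; placeholder values are marked, and none of this is part of the patch itself):

The rewritten configure_optimizers() drops the hand-rolled adjust_learning_rate() warm-up/decay in favour of Adam plus CosineAnnealingLR, with the learning rate taken from base_lr and the annealing period from the new t_max key in cfgs/train.yaml (falling back to a value derived from the training dataloader). A minimal sketch of that pairing, with a placeholder model:

    import torch
    from torch import nn, optim

    model = nn.Linear(8, 1)  # placeholder for the CREStereo model
    t_max = 161              # mirrors t_max in cfgs/train.yaml

    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)

    for step in range(t_max):
        # forward/backward would run here
        optimizer.step()
        scheduler.step()  # lr follows a half cosine from the base lr toward eta_min (default 0)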
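
The dataset.py change flips mystery_factor back to 150 in the depth-to-disparity conversion disp = (baseline * fl * mystery_factor) / depth. A sketch of that formula with made-up numbers (baseline, focal length, and depth below are illustrative, not values from the dataset):

    import numpy as np

    def depth_to_disp(depth, baseline, fl, mystery_factor=150):
        # same formula as BlenderDataset: disp = (baseline * fl * mystery_factor) / depth
        disp = (baseline * fl * mystery_factor) / depth
        return disp.astype(np.float32)

    # e.g. baseline 0.075 m, focal length 560 px, depth 1.5 m:
    # 0.075 * 560 * 150 / 1.5 = 4200.0
    print(depth_to_disp(np.array([1.5]), 0.075, 560.0))  # -> [4200.]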
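
In nets/crestereo.py the fixed 480x640 PositionEncodingSine buffer (kept as a TensorRT workaround) is commented out, and the encoding is now built per forward pass from the actual input shape. A sketch of that per-call flow; the import path here is an assumption from the repo layout and may differ:

    import torch
    from nets.attention.position_encoding import PositionEncodingSine  # path assumed

    fmap1_dw16 = torch.randn(1, 256, 480 // 16, 640 // 16)  # 1/16-scale feature map
    pos_encoding_fn_small = PositionEncodingSine(
        d_model=256, max_shape=(fmap1_dw16.shape[2], fmap1_dw16.shape[3])
    )
    x = pos_encoding_fn_small(fmap1_dw16)
    # 'n c h w -> n (h w) c' for the LoFTR attention blocks
    tokens = x.permute(0, 2, 3, 1).reshape(x.shape[0], x.shape[2] * x.shape[3], x.shape[1])
    print(tokens.shape)  # torch.Size([1, 1200, 256])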
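
sequence_loss() weights the model's iterative refinement outputs so that later, more refined predictions dominate: prediction i of n is weighted by gamma ** (n - i - 1). A simplified sketch that keeps only that weighting (the real function also masks invalid pixels and fixes up shapes):

    import torch

    def toy_sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
        n = len(flow_preds)
        loss = 0.0
        for i, pred in enumerate(flow_preds):
            weight = gamma ** (n - i - 1)  # 0.8**2, 0.8**1, 0.8**0 for n = 3
            loss += weight * (valid * (pred - flow_gt).abs()).mean()
        return loss

    preds = [torch.zeros(2, 2, 480, 640) for _ in range(3)]
    gt = torch.ones(2, 2, 480, 640)
    mask = torch.ones(2, 1, 480, 640)
    print(toy_sequence_loss(preds, gt, mask))  # tensor(2.4400) = 0.64 + 0.8 + 1.0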
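
For the sweeps this patch prepares, __main__ now starts the wandb run first and binds the Lightning WandbLogger to it via id=run.id, so hyperparameters are read back from wandb.config and a sweep agent can override them per run. A sketch of that handshake (the config values are just defaults, not sweep results):

    import wandb
    from pytorch_lightning.loggers import WandbLogger

    run = wandb.init(project="crestereo-lightning",
                     config={"base_lr": 1e-3, "batch_size": 2, "t_max": 161})
    config = wandb.config  # a sweep agent overwrites these values per run
    wandb_logger = WandbLogger(project="crestereo-lightning", id=run.id, log_model=True)
    print(config.base_lr)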