From d8169e01bcd2d83596a8f977ba0e24f80b885d34 Mon Sep 17 00:00:00 2001
From: "Cpt.Captain"
Date: Wed, 24 Aug 2022 19:18:20 +0200
Subject: [PATCH] finish lightningification

Training still seems borked
---
 cfgs/train.yaml    |   2 +-
 nets/crestereo.py  |  17 +--
 nets/extractor.py  |   6 -
 train_lightning.py | 265 ++++++++++++++++-----------------------------
 4 files changed, 104 insertions(+), 186 deletions(-)

diff --git a/cfgs/train.yaml b/cfgs/train.yaml
index c30a1ec..7232137 100644
--- a/cfgs/train.yaml
+++ b/cfgs/train.yaml
@@ -3,7 +3,7 @@ mixed_precision: false
 base_lr: 4.0e-4
 
 nr_gpus: 3
-batch_size: 4
+batch_size: 2
 n_total_epoch: 300
 minibatch_per_epoch: 500
 
diff --git a/nets/crestereo.py b/nets/crestereo.py
index 377e36b..dc6e997 100644
--- a/nets/crestereo.py
+++ b/nets/crestereo.py
@@ -22,9 +22,10 @@ except:
 
 #Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py
 class CREStereo(nn.Module):
-    def __init__(self, max_disp=192, mixed_precision=False, test_mode=False):
+    def __init__(self, max_disp=192, mixed_precision=False, test_mode=False, batch_size=4):
         super(CREStereo, self).__init__()
 
+        self.batch_size = batch_size
         self.max_flow = max_disp
         self.mixed_precision = mixed_precision
         self.test_mode = test_mode
@@ -37,10 +38,10 @@ class CREStereo(nn.Module):
         self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
 
         # # NOTE Position_encoding as workaround for TensorRt
-        # image1_shape = [1, 2, 480, 640]
-        # self.pos_encoding_fn_small = PositionEncodingSine(
-        #     d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
-        # )
+        image1_shape = [1, 2, 480, 640]
+        self.pos_encoding_fn_small = PositionEncodingSine(
+            d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
+        )
 
         # loftr
         self.self_att_fn = LocalFeatureTransformer(
@@ -136,9 +137,9 @@ class CREStereo(nn.Module):
         inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
 
         # positional encoding and self-attention
-        # pos_encoding_fn_small = PositionEncodingSine(
-        #     d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
-        # )
+        pos_encoding_fn_small = PositionEncodingSine(
+            d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
+        )
         # 'n c h w -> n (h w) c'
         x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
         fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
diff --git a/nets/extractor.py b/nets/extractor.py
index 367a416..47cd8a3 100644
--- a/nets/extractor.py
+++ b/nets/extractor.py
@@ -111,12 +111,6 @@ class BasicEncoder(nn.Module):
         else:
             x_tensor = x
 
-        print()
-        print()
-        print(x_tensor.shape)
-        print()
-        print()
-
         x_tensor = self.conv1(x_tensor)
         x_tensor = self.norm1(x_tensor)
         x_tensor = self.relu1(x_tensor)
diff --git a/train_lightning.py b/train_lightning.py
index b65b58e..56ee5cd 100644
--- a/train_lightning.py
+++ b/train_lightning.py
@@ -16,10 +16,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 from torch.utils.data import DataLoader
-from pytorch_lightning.lite import LightningLite
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 
 seed_everything(42, workers=True)
 
@@ -39,148 +39,44 @@ def normalize_and_colormap(img):
     return ret
 
 
-def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True):
-
-    print("Model Forwarding...")
-    if isinstance(left, torch.Tensor):
-        left = left# .cpu().detach().numpy()
-        imgR = right# .cpu().detach().numpy()
-    imgL = left
-    imgR = right
-    imgL = np.ascontiguousarray(imgL[None, :, :, :])
-    imgR = np.ascontiguousarray(imgR[None, :, :, :])
-
-    flow_init = None
-
-    # chosen for convenience
-
-    imgL = torch.tensor(imgL.astype("float32"))
-    imgR = torch.tensor(imgR.astype("float32"))
-    imgL = imgL.transpose(2,3).transpose(1,2)
-    if imgL.shape != imgR.shape:
-        imgR = imgR.transpose(2,3).transpose(1,2)
-
-    imgL_dw2 = F.interpolate(
-        imgL,
-        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
-        mode="bilinear",
-        align_corners=True,
-    ).clamp(min=0, max=255)
-    imgR_dw2 = F.interpolate(
-        imgR,
-        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
-        mode="bilinear",
-        align_corners=True,
-    ).clamp(min=0, max=255)
-    if last_img is not None:
-        print('using flow_initialization')
-        print(last_img.shape)
-        # FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help
-        print(last_img.max(), last_img.min())
-        if last_img.min() < 0:
-            # print('Negative disparity detected. shifting...')
-            last_img = last_img - last_img.min()
-        if last_img.max() > 255:
-            # print('Excessive disparity detected. scaling...')
-            last_img = last_img / (last_img.max() / 255)
-
-
-        last_img = np.dstack([last_img, last_img])
-        # last_img = np.dstack([last_img, last_img, last_img])
-        last_img = np.dstack([last_img])
-        last_img = last_img.reshape((1, 2, 480, 640))
-        # print(last_img.shape)
-        # print(last_img.dtype)
-        # print(last_img.max(), last_img.min())
-        flow_init = torch.tensor(last_img.astype("float32"))
-        # flow_init = F.interpolate(
-        #     last_img,
-        #     size=(last_img.shape[0] // 2, last_img.shape[1] // 2),
-        #     mode="bilinear",
-        #     align_corners=True,
-        # )
-    with torch.inference_mode():
-        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern)
-        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern)
-    pf_base = pred_flow
-    if isinstance(pf_base, list):
-        pf_base = pred_flow[0]
-    pf = torch.squeeze(pf_base[:, 0, :, :])# .cpu().detach().numpy()
-    print('pred_flow max min')
-    print(pf.max(), pf.min())
-
-
-    if not wandb_log:
-        if test:
-            return pred_flow
-        return torch.squeeze(pred_flow[:, 0, :, :])# .cpu().detach().numpy()
-
+def log_images(left, right, pred_disp, gt_disp, wandb_logger=None):
+    # wandb_logger.log_text('test')
+    # return
     log = {}
-    in_h, in_w = left.shape[:2]
-
-    # Resize image in case the GPU memory overflows
-    eval_h, eval_w = (in_h,in_w)
-
-    for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
-        pred_disp = torch.squeeze(pf[:, 0, :, :])# .cpu().detach().numpy()
-        pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :])# .cpu().detach().numpy()
-
-        # pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
-        # pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
-
-        if i == n_iter-1:
-            t = float(in_w) / float(eval_w)
-            disp = cv2.resize(pred_disp.cpu().detach().numpy(), (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
-
-            log[f'disp_vis'] = wandb.Image(
-                normalize_and_colormap(disp),
-                caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
-            )
-
-        log[f'pred_{i}'] = wandb.Image(
-            np.array([pred_disp.cpu().detach().numpy().reshape(480, 640)]),
-            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
-        )
-        # log[f'pred_norm_{i}'] = wandb.Image(
-        #     np.array([pred_disp_norm.reshape(480, 640)]),
-        #     caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
-        # )
-
-        # log[f'pred_dw2_{i}'] = wandb.Image(
-        #     np.array([pred_disp_dw2.reshape(240, 320)]),
-        #     caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
-        # )
-        # log[f'pred_dw2_norm_{i}'] = wandb.Image(
-        #     np.array([pred_disp_dw2_norm.reshape(240, 320)]),
-        #     caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
-        # )
-
+    batch_idx = 1
 
-    log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
-    input_right = right# .cpu().detach().numpy() if isinstance(right, torch.Tensor) else right
-    if input_right.shape != (480, 640, 3):
-        input_right.transpose(1,2,0)
-    log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right")
+    if isinstance(pred_disp, list):
+        pred_disp = pred_disp[-1]
 
-    log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
-
-    gt_disp = gt_disp# .cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp
-    disp = disp# .cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp
+    pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
+    gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
+    left = torch.squeeze(left[:, 0, :, :])
+    right = torch.squeeze(right[:, 0, :, :])
+    disp = pred_disp
     disp_error = gt_disp - disp
-    log['disp_error'] = wandb.Image(
-        normalize_and_colormap(abs(disp_error)),
-        caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}",
-    )
-
-
-    log[f'gt_disp_vis'] = wandb.Image(
-        normalize_and_colormap(gt_disp),
-        caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
-    )
-    wandb.log(log)
-    return pred_flow
+    input_left = left[batch_idx].cpu().detach().numpy()# .transpose(1,2,0)
+    input_right = right[batch_idx].cpu().detach().numpy()# .transpose(1,2,0)
+
+    wandb_log = dict(
+        key='samples',
+        images=[
+            normalize_and_colormap(pred_disp[batch_idx]),
+            normalize_and_colormap(abs(disp_error[batch_idx])),
+            normalize_and_colormap(gt_disp[batch_idx]),
+            input_left,
+            input_right,
+        ],
+        caption=[
+            f"Disparity \n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}",
+            f"Disp. Error\n{disp_error[batch_idx].min():.{2}f}/{disp_error[batch_idx].max():.{2}f}\n{abs(disp_error[batch_idx]).mean():.{2}f}",
+            f"GT Disp Vis \n{gt_disp[batch_idx].min():.{2}f}/{gt_disp[batch_idx].max():.{2}f}",
+            "Input Left",
+            "Input Right"
+        ],
+    )
+    return wandb_log
 
 
 def parse_yaml(file_path: str) -> namedtuple:
@@ -259,9 +155,10 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
 
 
 class CREStereoLightning(LightningModule):
-    def __init__(self, args):
+    def __init__(self, args, logger):
         super().__init__()
         self.batch_size = args.batch_size
+        self.wandb_logger = logger
         self.model = Model(
             max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
         )
@@ -270,13 +167,10 @@ class CREStereoLightning(LightningModule):
         return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)
 
     def training_step(self, batch, batch_idx):
-        # loss = self(batch)
         left, right, gt_disp, valid_mask = batch
-        left = torch.Tensor(left).to(self.device)
-        right = torch.Tensor(right).to(self.device)
-        left = left
-        right = right
         flow_predictions = self.forward(left, right)
+        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
+        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
         loss = sequence_loss(
             flow_predictions, gt_flow, valid_mask, gamma=0.8
         )
@@ -285,62 +179,91 @@
 
     def validation_step(self, batch, batch_idx):
         left, right, gt_disp, valid_mask = batch
-        left = torch.Tensor(left).to(self.device)
-        right = torch.Tensor(right).to(self.device)
-        print(left.shape)
-        print(right.shape)
-        flow_predictions = self.forward(left, right)
+        flow_predictions = self.forward(left, right, test_mode=True)
+        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
+        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
         val_loss = sequence_loss(
             flow_predictions, gt_flow, valid_mask, gamma=0.8
         )
         self.log("val_loss", val_loss)
+        if batch_idx % 4 == 0:
+            self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
 
     def test_step(self, batch, batch_idx):
         left, right, gt_disp, valid_mask = batch
-        # left, right, gt_disp, valid_mask = (
-        #     batch["left"],
-        #     batch["right"],
-        #     batch["disparity"],
-        #     batch["mask"],
-        # )
-        left = torch.Tensor(left).to(self.device)
-        right = torch.Tensor(right).to(self.device)
-        flow_predictions = self.forward(left, right)
+        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
         gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
+        flow_predictions = self.forward(left, right, test_mode=True)
         test_loss = sequence_loss(
             flow_predictions, gt_flow, valid_mask, gamma=0.8
         )
         self.log("test_loss", test_loss)
+        print('test_batch_idx:', batch_idx)
+        self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
 
     def configure_optimizers(self):
-        return optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999))
+        return optim.Adam(self.model.parameters(), lr=0.1, betas=(0.9, 0.999))
 
 
 if __name__ == "__main__":
     # train configuration
     args = parse_yaml("cfgs/train.yaml")
-    # wandb.init(project="crestereo-lightning", entity="cpt-captain")
-    # Lite(strategy='dp', accelerator='gpu', devices=2).run(args)
     pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
-    model = CREStereoLightning(args)
-    dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, use_lightning=True)
-    test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True, use_lightning=True)
-    print(len(dataset))
-    print(len(test_dataset))
+    wandb_logger = WandbLogger(project="crestereo-lightning")
     wandb.config.update(args._asdict())
+    model = CREStereoLightning(args, wandb_logger)
+
+    dataset = BlenderDataset(
+        root=args.training_data_path,
+        pattern_path=pattern_path,
+        use_lightning=True,
+    )
+    test_dataset = BlenderDataset(
+        root=args.training_data_path,
+        pattern_path=pattern_path,
+        test_set=True,
+        use_lightning=True,
+    )
+
+    dataloader = DataLoader(
+        dataset,
+        args.batch_size,
+        shuffle=True,
+        num_workers=16,
+        drop_last=True,
+        persistent_workers=True,
+        pin_memory=True,
+    )
+    # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
+    test_dataloader = DataLoader(
+        test_dataset,
+        args.batch_size,
+        shuffle=False,
+        num_workers=16,
+        drop_last=False,
+        persistent_workers=True,
+        pin_memory=True
+    )
+
     trainer = Trainer(
-        max_epochs=args.n_total_epoch, accelerator='gpu', devices=2,
-        # auto_scale_batch_size='binsearch',
-        # strategy='ddp',
+        max_epochs=args.n_total_epoch,
+        callbacks=[
+            EarlyStopping(
+                monitor="val_loss",
+                mode="min",
+                patience=4,
+            )
+        ],
+        accumulate_grad_batches=8,
         deterministic=True,
         check_val_every_n_epoch=1,
         limit_val_batches=24,
         limit_test_batches=24,
         logger=wandb_logger,
         default_root_dir=args.log_dir_lightning,
-        )
-    # trainer.tune(model)
-    trainer.fit(model, dataset, test_dataset)
+    )
+
+    trainer.fit(model, dataloader, test_dataloader)
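
For context: both `training_step` and `validation_step` above feed the per-iteration predictions into `sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)`, with the ground-truth disparity packed as the x-channel of a two-channel flow tensor and a zero y-channel (the `gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)` lines). The body of `sequence_loss` is not part of this patch (only its signature appears as hunk context); `gamma=0.8` matches the RAFT convention of exponentially up-weighting later refinement iterations. A minimal sketch under that assumption, illustrative only and not the repository's actual implementation:

    import torch

    def sequence_loss_sketch(flow_preds, flow_gt, valid, gamma=0.8):
        """RAFT-style sequence loss sketch: exponentially weighted L1 over iterations."""
        n_predictions = len(flow_preds)
        total_loss = 0.0
        for i, flow_pred in enumerate(flow_preds):
            # later iterations get exponentially larger weight: gamma^(N-1-i)
            weight = gamma ** (n_predictions - i - 1)
            i_loss = (flow_pred - flow_gt).abs()  # [B, 2, H, W]
            # count only pixels with valid ground truth ([B, H, W] mask, broadcast over channels)
            total_loss = total_loss + weight * (valid[:, None] * i_loss).mean()
        return total_loss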