"""Train a stereo matching network (`nets.Model`) on `CTDDataset` and log to Weights & Biases.

The script reads its configuration from ``cfgs/train.yaml``, resumes from
``<log_dir>/models/latest.pth`` when present, and periodically logs intermediate
disparity predictions as images.
"""

import logging
import os
import sys
import time
from collections import namedtuple

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
import yaml
from torch.utils.data import DataLoader

from dataset import CREStereoDataset, CTDDataset
from nets import Model

# Reference pattern used when the configuration does not provide `pattern_path`.
DEFAULT_PATTERN_PATH = '/home/nils/kinect_reference_cropped.png'


def normalize_and_colormap(img):
    """Min-max normalize `img` to 0..255 and apply the INFERNO colormap.

    Args:
        img: numpy array or torch tensor holding the values to visualize.

    Returns:
        The array returned by ``cv2.applyColorMap`` for the normalized image.
    """
    lowest = img.min()
    span = img.max() - lowest
    if span == 0:
        # Constant image: the original formula divided by zero and produced
        # NaNs; map everything to the lowest color instead.
        ret = (img - lowest) * 0.0
    else:
        ret = (img - lowest) / span * 255.0
    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()
    ret = ret.astype("uint8")
    return cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)


def _value_range(values):
    """Format the min/max of `values` as ``"min/max"`` with two decimals."""
    return f"{values.min():.2f}/{values.max():.2f}"


def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True):
    """Run the model on one stereo pair and optionally log the results to wandb.

    The pair is evaluated at half resolution first; that result is passed as
    ``flow_init`` to a second, full-resolution forward pass.

    Args:
        left: left image tensor; converted to numpy and treated as
            channel-last (H, W, C) — assumption based on how it is transposed
            and on ``left.shape[:2]`` being used as (height, width).
        right: right image tensor; transposed to channel-first only when its
            batched shape differs from the left image's.
        gt_disp: ground-truth disparity, used for the logged images/error.
        mask: unused; kept for interface compatibility.
        model: callable network accepting ``(image1, image2, iters, flow_init)``.
        epoch: unused; kept for interface compatibility.
        n_iter: number of refinement iterations passed to the model as ``iters``.
        wandb_log: when False, return the final disparity as a numpy array
            instead of logging.

    Returns:
        The predicted disparity (numpy array) when `wandb_log` is False,
        otherwise None.
    """
    print("Model Forwarding...")

    left = left.cpu().detach().numpy()
    right_np = right.cpu().detach().numpy()

    # Follow the model's own device instead of a hard-coded GPU; fall back to
    # the previously hard-coded 'cuda:0' when the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        device = torch.device('cuda:0')

    imgL = np.ascontiguousarray(left[None, :, :, :])
    imgR = np.ascontiguousarray(right_np[None, :, :, :])
    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)

    # (1, H, W, C) -> (1, C, H, W)
    imgL = imgL.transpose(2, 3).transpose(1, 2)
    if imgL.shape != imgR.shape:
        # NOTE(review): presumably the right image may already be channel-first;
        # it is only rearranged when the shapes disagree.
        imgR = imgR.transpose(2, 3).transpose(1, 2)

    half_size = (imgL.shape[2] // 2, imgL.shape[3] // 2)
    imgL_dw2 = F.interpolate(imgL, size=half_size, mode="bilinear", align_corners=True)
    imgR_dw2 = F.interpolate(imgR, size=half_size, mode="bilinear", align_corners=True)

    with torch.inference_mode():
        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)

    if not wandb_log:
        # The logging path below iterates over per-iteration predictions, so the
        # model may return a sequence; use its last entry as the final flow.
        final_flow = pred_flow[-1] if isinstance(pred_flow, (list, tuple)) else pred_flow
        return torch.squeeze(final_flow[:, 0, :, :]).cpu().detach().numpy()

    log = {}
    in_h, in_w = left.shape[:2]
    disp = None  # stays None if the loop never reaches iteration n_iter - 1

    for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
        pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
        pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy()

        pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)

        if i == n_iter - 1:
            # Evaluation runs at the input resolution, so no disparity rescaling
            # is needed after the resize (the scale factor was always 1).
            disp = cv2.resize(pred_disp, (in_w, in_h), interpolation=cv2.INTER_LINEAR)
            log['disp_vis'] = wandb.Image(
                normalize_and_colormap(disp),
                caption=f"Disparity \n{_value_range(pred_disp)}",
            )

        # A leading axis is added instead of reshaping to a hard-coded
        # 480x640 / 240x320, so other input resolutions work as well.
        log[f'pred_{i}'] = wandb.Image(
            pred_disp[np.newaxis],
            caption=f"Pred. Disp. It {i}\n{_value_range(pred_disp)}",
        )
        log[f'pred_norm_{i}'] = wandb.Image(
            pred_disp_norm[np.newaxis],
            caption=f"Pred. Disp. It {i}\n{_value_range(pred_disp)}",
        )
        log[f'pred_dw2_{i}'] = wandb.Image(
            pred_disp_dw2[np.newaxis],
            caption=f"Pred. Disp. Dw2 It {i}\n{_value_range(pred_disp_dw2)}",
        )
        log[f'pred_dw2_norm_{i}'] = wandb.Image(
            pred_disp_dw2_norm[np.newaxis],
            caption=f"Pred. Disp. Dw2 It {i}\n{_value_range(pred_disp_dw2)}",
        )

    log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
    log['input_right'] = wandb.Image(
        right_np.transpose(1, 2, 0).astype('uint8'), caption="Input Right"
    )
    log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{_value_range(gt_disp)}")

    if disp is not None:
        disp_error = gt_disp - disp
        log['disp_error'] = wandb.Image(
            normalize_and_colormap(disp_error.abs()),
            caption=(
                f"Disp. Error\n{_value_range(disp_error)}"
                f"\n{disp_error.abs().mean():.2f}"
            ),
        )

    log['gt_disp_vis'] = wandb.Image(
        normalize_and_colormap(gt_disp),
        caption=f"GT Disp Vis \n{_value_range(gt_disp)}",
    )

    wandb.log(log)


def parse_yaml(file_path: str) -> namedtuple:
    """Parse yaml configuration file and return the object in `namedtuple`."""
    with open(file_path, "rb") as f:
        cfg: dict = yaml.safe_load(f)
    args = namedtuple("train_args", cfg.keys())(*cfg.values())
    return args


def format_time(elapse):
    """Format a duration in seconds as ``HH:MM:SS`` (fractions are truncated)."""
    minutes, seconds = divmod(int(elapse), 60)
    hours, minutes = divmod(minutes, 60)
    return "{:02d}:{:02d}:{:02d}".format(hours, minutes, seconds)


def ensure_dir(path):
    """Create `path` (including parents) if it does not exist yet."""
    # `exist_ok=True` already covers the "already there" case; a separate
    # existence check only adds a race window.
    os.makedirs(path, exist_ok=True)


def adjust_learning_rate(optimizer, epoch, base_lr=None, n_total_epoch=None):
    """Set the learning rate of every parameter group for the given epoch.

    Schedule: linear warm-up from 5% of the base rate over the first 2% of the
    epochs, constant base rate until 60% of the epochs, then linear decay back
    to 5% of the base rate at the final epoch.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current (1-based) epoch index.
        base_lr: peak learning rate; defaults to the module-level
            ``args.base_lr`` (only defined when run as a script).
        n_total_epoch: total number of epochs; defaults to the module-level
            ``args.n_total_epoch``.
    """
    if base_lr is None:
        base_lr = args.base_lr
    if n_total_epoch is None:
        n_total_epoch = args.n_total_epoch

    warm_up = 0.02
    const_range = 0.6
    min_lr_rate = 0.05

    if epoch <= n_total_epoch * warm_up:
        lr = (1 - min_lr_rate) * base_lr / (
            n_total_epoch * warm_up
        ) * epoch + min_lr_rate * base_lr
    elif epoch <= n_total_epoch * const_range:
        lr = base_lr
    else:
        lr = (min_lr_rate - 1) * base_lr / (
            (1 - const_range) * n_total_epoch
        ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * base_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
    """Exponentially weighted L1 loss over a sequence of flow predictions.

    Later predictions get larger weights: prediction ``i`` of ``n`` is weighted
    by ``gamma ** (n - i - 1)``.

    Args:
        flow_preds: sequence of predicted flows; per the original notes each is
            (B, 2, H, W).
        flow_gt: ground-truth flow; per the original notes (B, 2, H, W).
        valid: validity mask; per the original notes (B, H, W). It is expanded
            with a channel axis before masking the per-pixel loss.
        gamma: decay factor for the weights of earlier predictions.

    Returns:
        The accumulated loss (``0.0`` when `flow_preds` is empty).
    """
    n_predictions = len(flow_preds)
    flow_loss = 0.0

    # TEST: drops a trailing singleton axis from the ground truth, if present.
    flow_gt = torch.squeeze(flow_gt, dim=-1)

    valid_mask = valid.unsqueeze(1)
    for i, flow_pred in enumerate(flow_preds):
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(flow_pred - flow_gt)
        flow_loss += i_weight * (valid_mask * i_loss).mean()

    return flow_loss


def _resolve_log_level(level):
    """Translate a configured log level into a numeric `logging` level.

    Accepts ints, digit strings, and names such as ``"INFO"`` or
    ``"logging.INFO"``. This replaces a former ``eval`` of the configuration
    value, which would have executed arbitrary code from the config file.

    Raises:
        ValueError: if the value does not name a known logging level.
    """
    if isinstance(level, int):
        return level
    name = str(level).strip()
    if name.isdigit():
        return int(name)
    prefix = "logging."
    if name.startswith(prefix):
        name = name[len(prefix):]
    resolved = getattr(logging, name.upper(), None)
    if not isinstance(resolved, int):
        raise ValueError(f"Unsupported log level: {level!r}")
    return resolved


def _build_worklog(log_dir, log_level):
    """Create the training logger writing to ``<log_dir>/worklog.txt`` and stdout."""
    logging.basicConfig(level=_resolve_log_level(log_level))
    worklog = logging.getLogger("train_logger")
    worklog.propagate = False

    file_handler = logging.FileHandler(
        os.path.join(log_dir, "worklog.txt"), mode="a", encoding="utf8"
    )
    file_handler.setFormatter(
        logging.Formatter(fmt="%(asctime)s %(message)s", datefmt="%Y/%m/%d %H:%M:%S")
    )

    console_handler = logging.StreamHandler(sys.stdout)
    # Timestamp is wrapped in ANSI green for console output.
    console_handler.setFormatter(
        logging.Formatter(
            fmt="\x1b[32m%(asctime)s\x1b[0m %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
        )
    )

    worklog.handlers = [file_handler, console_handler]
    return worklog


def main(args):
    """Train the model according to `args` and write checkpoints to ``<log_dir>/models``.

    Args:
        args: configuration object (as returned by `parse_yaml`) providing at
            least: seed, log_dir, log_level, max_disp, mixed_precision,
            loadmodel, training_data_path, batch_size, minibatch_per_epoch,
            n_total_epoch, base_lr and model_save_freq_epoch. An optional
            ``pattern_path`` overrides the default reference pattern.
    """
    # initial info
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    world_size = torch.cuda.device_count()  # number of GPU(s)

    # directory check
    log_model_dir = os.path.join(args.log_dir, "models")
    ensure_dir(log_model_dir)

    # model / optimizer
    model = Model(
        max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
    )
    model = nn.DataParallel(model, device_ids=list(range(world_size)))
    model.cuda()
    # The initial lr is a placeholder; adjust_learning_rate overwrites it at the
    # start of every epoch.
    optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))

    wandb.watch(model)
    metrics = {}

    # worklog
    worklog = _build_worklog(args.log_dir, args.log_level)

    # params stat
    worklog.info(f"Use {world_size} GPU(s)")
    worklog.info("Params: %s" % sum([p.numel() for p in model.parameters()]))

    # load pretrained model if exist
    chk_path = os.path.join(log_model_dir, "latest.pth")
    if args.loadmodel is not None:
        chk_path = args.loadmodel
    elif not os.path.exists(chk_path):
        chk_path = None

    if chk_path is not None:
        worklog.info(f"loading model: {chk_path}")
        state_dict = torch.load(chk_path)
        model.load_state_dict(state_dict['state_dict'])
        optimizer.load_state_dict(state_dict['optim_state_dict'])
        start_epoch_idx = state_dict["epoch"] + 1
        start_iters = state_dict["iters"]
    else:
        start_epoch_idx = 1
        start_iters = 0

    # datasets
    pattern_path = getattr(args, "pattern_path", None) or DEFAULT_PATTERN_PATH
    dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
    worklog.info(f"Dataset size: {len(dataset)}")

    dataloader = DataLoader(
        dataset,
        args.batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=True,
        persistent_workers=False,
        pin_memory=True,
    )

    # counter
    cur_iters = start_iters
    total_iters = args.minibatch_per_epoch * args.n_total_epoch
    t0 = time.perf_counter()

    for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1):
        # adjust learning rate
        epoch_total_train_loss = 0
        adjust_learning_rate(
            optimizer,
            epoch_idx,
            base_lr=args.base_lr,
            n_total_epoch=args.n_total_epoch,
        )
        model.train()

        t1 = time.perf_counter()

        for batch_idx, mini_batch_data in enumerate(dataloader):
            # An "epoch" is capped at minibatch_per_epoch batches.
            if batch_idx % args.minibatch_per_epoch == 0 and batch_idx != 0:
                break
            cur_iters += 1

            # parse data
            left = mini_batch_data["left"]
            right = mini_batch_data["right"]
            gt_disp = mini_batch_data["disparity"].cuda()
            valid_mask = mini_batch_data["mask"].cuda()

            t2 = time.perf_counter()
            optimizer.zero_grad()

            # pre-process: disparity becomes the x-component of a 2-channel
            # flow whose y-component is zero.
            gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
            gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]

            # forward: left goes from channel-last to channel-first.
            left = left.transpose(1, 3).transpose(2, 3)
            # NOTE(review): the original applied transpose(2, 3) twice to
            # `right`, which cancels out, so it is passed through unchanged —
            # presumably it is already channel-first; confirm against CTDDataset.
            flow_predictions = model(left.cuda(), right.cuda())

            # loss & backward
            loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)

            if batch_idx % 128 == 0:
                # Log a visual sample from the first element of the batch.
                inference(
                    mini_batch_data['left'][0],
                    mini_batch_data['right'][0],
                    mini_batch_data['disparity'][0],
                    mini_batch_data['mask'][0],
                    model,
                    batch_idx,
                )

            # loss stats
            loss_item = loss.item()
            epoch_total_train_loss += loss_item
            loss.backward()
            optimizer.step()
            t3 = time.perf_counter()

            if cur_iters % 10 == 0:
                tdata = t2 - t1
                time_train_passed = t3 - t0
                time_iter_passed = t3 - t1
                step_passed = cur_iters - start_iters
                eta = (
                    (total_iters - cur_iters)
                    / max(step_passed, 1e-7)
                    * time_train_passed
                )

                meta_info = [
                    "{:.2g} b/s".format(1.0 / time_iter_passed),
                    "passed:{}".format(format_time(time_train_passed)),
                    "eta:{}".format(format_time(eta)),
                    "data_time:{:.2g}".format(tdata / time_iter_passed),
                    "lr:{:.5g}".format(optimizer.param_groups[0]["lr"]),
                    "[{}/{}:{}/{}]".format(
                        epoch_idx,
                        args.n_total_epoch,
                        batch_idx,
                        args.minibatch_per_epoch,
                    ),
                ]
                loss_info = [" ==> {}:{:.4g}".format("loss", loss_item)]
                info = [",".join(meta_info)] + loss_info
                worklog.info("".join(info))

                # minibatch loss
                metrics['train/loss_batch'] = loss_item
                metrics['train/lr'] = optimizer.param_groups[0]["lr"]
                wandb.log(metrics)

            t1 = time.perf_counter()

        metrics['train/loss'] = epoch_total_train_loss / args.minibatch_per_epoch
        wandb.log(metrics)

        # save model params
        ckp_data = {
            "epoch": epoch_idx,
            "iters": cur_iters,
            "batch_size": args.batch_size,
            "epoch_size": args.minibatch_per_epoch,
            "train_loss": epoch_total_train_loss / args.minibatch_per_epoch,
            "state_dict": model.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        }
        torch.save(ckp_data, os.path.join(log_model_dir, "latest.pth"))
        if epoch_idx % args.model_save_freq_epoch == 0:
            save_path = os.path.join(log_model_dir, "epoch-%d.pth" % epoch_idx)
            worklog.info(f"Model params saved: {save_path}")
            torch.save(ckp_data, save_path)

    worklog.info("Training is done, exit.")


if __name__ == "__main__":
    # train configuration
    args = parse_yaml("cfgs/train.yaml")
    wandb.init(project="crestereo", entity="cpt-captain")
    wandb.config.update(args._asdict())
    main(args)