From 9740e5d647d955197f28924b6b1430876bf3e9cb Mon Sep 17 00:00:00 2001 From: Nils Koch Date: Mon, 30 May 2022 16:27:13 +0200 Subject: [PATCH] test_model.py: reformat --- test_model.py | 337 +++++++++++++++++++++++++------------------------- 1 file changed, 166 insertions(+), 171 deletions(-) diff --git a/test_model.py b/test_model.py index 9e04be7..f91bf13 100644 --- a/test_model.py +++ b/test_model.py @@ -17,43 +17,40 @@ device = 'cuda' wandb.init(project="crestereo", entity="cpt-captain") - -#Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py +# Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py def inference(left, right, model, n_iter=20): + print("Model Forwarding...") + imgL = left.transpose(2, 0, 1) + imgR = right.transpose(2, 0, 1) + imgL = np.ascontiguousarray(imgL[None, :, :, :]) + imgR = np.ascontiguousarray(imgR[None, :, :, :]) + + imgL = torch.tensor(imgL.astype("float32")).to(device) + imgR = torch.tensor(imgR.astype("float32")).to(device) - print("Model Forwarding...") - imgL = left.transpose(2, 0, 1) - imgR = right.transpose(2, 0, 1) - imgL = np.ascontiguousarray(imgL[None, :, :, :]) - imgR = np.ascontiguousarray(imgR[None, :, :, :]) - - imgL = torch.tensor(imgL.astype("float32")).to(device) - imgR = torch.tensor(imgR.astype("float32")).to(device) - - imgL_dw2 = F.interpolate( - imgL, - size=(imgL.shape[2] // 2, imgL.shape[3] // 2), - mode="bilinear", - align_corners=True, - ) - imgR_dw2 = F.interpolate( - imgR, - size=(imgL.shape[2] // 2, imgL.shape[3] // 2), - mode="bilinear", - align_corners=True, - ) - # print(imgR_dw2.shape) - with torch.inference_mode(): - pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None) - - pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) - pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy() - - return pred_disp + imgL_dw2 = F.interpolate( + imgL, + size=(imgL.shape[2] // 2, imgL.shape[3] // 2), + mode="bilinear", + align_corners=True, + ) + imgR_dw2 = F.interpolate( + imgR, + size=(imgL.shape[2] // 2, imgL.shape[3] // 2), + mode="bilinear", + align_corners=True, + ) + # print(imgR_dw2.shape) + with torch.inference_mode(): + pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None) + pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) + pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy() + + return pred_disp -def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20): +def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20): print("Model Forwarding...") # print(left.shape) left = left.cpu().detach().numpy() @@ -61,26 +58,26 @@ def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20): imgR = right.cpu().detach().numpy() imgL = np.ascontiguousarray(imgL[None, :, :, :]) imgR = np.ascontiguousarray(imgR[None, :, :, :]) - + # chosen for convenience device = torch.device('cuda:0') imgL = torch.tensor(imgL.astype("float32")).to(device) imgR = torch.tensor(imgR.astype("float32")).to(device) - imgL = imgL.transpose(2,3).transpose(1,2) + imgL = imgL.transpose(2, 3).transpose(1, 2) imgL_dw2 = F.interpolate( - imgL, - size=(imgL.shape[2] // 2, imgL.shape[3] // 2), - mode="bilinear", - align_corners=True, - ) + imgL, + size=(imgL.shape[2] // 2, imgL.shape[3] // 2), + mode="bilinear", + align_corners=True, + ) imgR_dw2 = F.interpolate( - imgR, - size=(imgL.shape[2] // 2, imgL.shape[3] // 2), - mode="bilinear", - align_corners=True, - ) + imgR, + size=(imgL.shape[2] // 2, imgL.shape[3] // 2), + mode="bilinear", + align_corners=True, + ) with torch.inference_mode(): pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None) pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) @@ -89,160 +86,158 @@ def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20): for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)): pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy() pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy() - + pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) log[f'pred_{i}'] = wandb.Image( - np.array([pred_disp.reshape(480, 640)]), - caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", - ) + np.array([pred_disp.reshape(480, 640)]), + caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", + ) log[f'pred_norm_{i}'] = wandb.Image( - np.array([pred_disp_norm.reshape(480, 640)]), - caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", - ) + np.array([pred_disp_norm.reshape(480, 640)]), + caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", + ) log[f'pred_dw2_{i}'] = wandb.Image( - np.array([pred_disp_dw2.reshape(240, 320)]), - caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", - ) + np.array([pred_disp_dw2.reshape(240, 320)]), + caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", + ) log[f'pred_dw2_norm_{i}'] = wandb.Image( - np.array([pred_disp_dw2_norm.reshape(240, 320)]), - caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", - ) - + np.array([pred_disp_dw2_norm.reshape(240, 320)]), + caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", + ) log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left") - log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right") + log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1, 2, 0).astype('uint8'), + caption="Input Right") log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}") - disp_error = gt_disp - disp log['disp_error'] = wandb.Image( - normalize_and_colormap(disp_error), - caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.mean():.{2}f}", - ) + normalize_and_colormap(disp_error), + caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.mean():.{2}f}", + ) wandb.log(log) def do_infer(left_img, right_img, gt_disp, model): + in_h, in_w = left_img.shape[:2] + + # Resize image in case the GPU memory overflows + eval_h, eval_w = (in_h, in_w) + + # FIXME borked for some reason, hopefully not very important + + imgL = left_img.cpu().detach().numpy() if isinstance(left_img, torch.Tensor) else left_img + imgR = right_img.cpu().detach().numpy() if isinstance(right_img, torch.Tensor) else right_img + + imgL = cv2.resize(imgL, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) + imgR = cv2.resize(imgR, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) + + # pred = ctd_inference(imgL, imgR, gt_disp, None, model, None, n_iter=20) + pred = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False) + + t = float(in_w) / float(eval_w) + disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t + + disp_vis = normalize_and_colormap(disp) + gt_disp_vis = normalize_and_colormap(gt_disp) + if gt_disp.shape != disp.shape: + gt_disp = gt_disp.reshape(disp.shape) + disp_err = gt_disp - disp + disp_err = normalize_and_colormap(disp_err.abs()) + + wandb.log({ + 'disp_vis': wandb.Image( + disp_vis, + caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}", + ), + 'gt_disp_vis': wandb.Image( + gt_disp_vis, + caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}", + ), + 'disp_err': wandb.Image( + disp_err, + caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}", + ), + 'input_left': wandb.Image( + left_img.cpu().detach().numpy().astype('uint8'), + caption=f"Input left", + ), + 'input_right': wandb.Image( + right_img.cpu().detach().numpy().astype('uint8'), + caption=f"Input right", + ), + }) + + +if __name__ == '__main__': + # model_path = "models/crestereo_eth3d.pth" + model_path = "train_log/models/latest.pth" + + # reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png' + reference_pattern_path = '/home/nils/kinect_reference_cropped.png' + # reference_pattern_path = '/home/nils/new_reference.png' + # reference_pattern_path = '/home/nils/kinect_reference_high_res.png' + # reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png' + + data_type = 'kinect' + augment = False + + args = parse_yaml("cfgs/train.yaml") + + wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment}) + + model = Model(max_disp=256, mixed_precision=False, test_mode=True) + model = nn.DataParallel(model, device_ids=[device]) + # model.load_state_dict(torch.load(model_path), strict=False) + state_dict = torch.load(model_path)['state_dict'] + model.load_state_dict(state_dict, strict=True) + model.to(device) + model.eval() + + CTD = True + if not CTD: + left_img = cv2.imread("../test_imgs/left.png") + right_img = cv2.imread("../test_imgs/right.png") in_h, in_w = left_img.shape[:2] # Resize image in case the GPU memory overflows - eval_h, eval_w = (in_h,in_w) + eval_h, eval_w = (in_h, in_w) # FIXME borked for some reason, hopefully not very important - - imgL = left_img.cpu().detach().numpy() if isinstance(left_img, torch.Tensor) else left_img - imgR = right_img.cpu().detach().numpy() if isinstance(right_img, torch.Tensor) else right_img + imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) + imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) - imgL = cv2.resize(imgL, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) - imgR = cv2.resize(imgR, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) - - # pred = ctd_inference(imgL, imgR, gt_disp, None, model, None, n_iter=20) - pred = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False) + pred = inference(imgL, imgR, model, n_iter=20) t = float(in_w) / float(eval_w) disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t - disp_vis = normalize_and_colormap(disp) - gt_disp_vis = normalize_and_colormap(gt_disp) - if gt_disp.shape != disp.shape: - gt_disp = gt_disp.reshape(disp.shape) - disp_err = gt_disp - disp - disp_err = normalize_and_colormap(disp_err.abs()) - - wandb.log({ - 'disp_vis': wandb.Image( - disp_vis, - caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}", - ), - 'gt_disp_vis': wandb.Image( - gt_disp_vis, - caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}", - ), - 'disp_err': wandb.Image( - disp_err, - caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}", - ), - 'input_left': wandb.Image( - left_img.cpu().detach().numpy().astype('uint8'), - caption=f"Input left", - ), - 'input_right': wandb.Image( - right_img.cpu().detach().numpy().astype('uint8'), - caption=f"Input right", - ), - }) - - - -if __name__ == '__main__': - # model_path = "models/crestereo_eth3d.pth" - model_path = "train_log/models/latest.pth" - - # reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png' - reference_pattern_path = '/home/nils/kinect_reference_cropped.png' - # reference_pattern_path = '/home/nils/new_reference.png' - # reference_pattern_path = '/home/nils/kinect_reference_high_res.png' - # reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png' - - data_type = 'kinect' - augment = False - - args = parse_yaml("cfgs/train.yaml") - - wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment}) - - model = Model(max_disp=256, mixed_precision=False, test_mode=True) - model = nn.DataParallel(model,device_ids=[device]) - # model.load_state_dict(torch.load(model_path), strict=False) - state_dict = torch.load(model_path)['state_dict'] - model.load_state_dict(state_dict, strict=True) - model.to(device) - model.eval() - - CTD = True - if not CTD: - left_img = cv2.imread("../test_imgs/left.png") - right_img = cv2.imread("../test_imgs/right.png") - in_h, in_w = left_img.shape[:2] - - # Resize image in case the GPU memory overflows - eval_h, eval_w = (in_h,in_w) - - # FIXME borked for some reason, hopefully not very important - imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) - imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) - - pred = inference(imgL, imgR, model, n_iter=20) - - t = float(in_w) / float(eval_w) - disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t - - disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 - disp_vis = disp_vis.astype("uint8") - disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) - - combined_img = np.hstack((left_img, disp_vis)) - # cv2.namedWindow("output", cv2.WINDOW_NORMAL) - # cv2.imshow("output", combined_img) - cv2.imwrite("output.jpg", disp_vis) - # cv2.waitKey(0) - - else: - dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, pattern_path=reference_pattern_path, augment=augment) - dataloader = DataLoader(dataset, args.batch_size, shuffle=True, - num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True) - for batch in dataloader: - for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']): - right = right.transpose(0, 2).transpose(0, 1) - left_img = left - imgL = left.cpu().detach().numpy() - right_img = right - imgR = right.cpu().detach().numpy() - gt_disp = disparity - do_infer(left_img, right_img, gt_disp, model) - + disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 + disp_vis = disp_vis.astype("uint8") + disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) + + combined_img = np.hstack((left_img, disp_vis)) + # cv2.namedWindow("output", cv2.WINDOW_NORMAL) + # cv2.imshow("output", combined_img) + cv2.imwrite("output.jpg", disp_vis) + # cv2.waitKey(0) + + else: + dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, + pattern_path=reference_pattern_path, augment=augment) + dataloader = DataLoader(dataset, args.batch_size, shuffle=True, + num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True) + for batch in dataloader: + for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']): + right = right.transpose(0, 2).transpose(0, 1) + left_img = left + imgL = left.cpu().detach().numpy() + right_img = right + imgR = right.cpu().detach().numpy() + gt_disp = disparity + do_infer(left_img, right_img, gt_disp, model)