From 6ba2cd9e5de6f62c6f680b3b6a39d93f6da8d17b Mon Sep 17 00:00:00 2001 From: Nils Koch Date: Thu, 2 Jun 2022 15:20:28 +0200 Subject: [PATCH] test_model.py: cleanup --- test_model.py | 56 ++++++++++++--------------------------------------- 1 file changed, 13 insertions(+), 43 deletions(-) diff --git a/test_model.py b/test_model.py index e2f7121..902e455 100644 --- a/test_model.py +++ b/test_model.py @@ -8,7 +8,6 @@ from nets import Model import wandb -import random from torch.utils.data import DataLoader from dataset import CTDDataset from train import normalize_and_colormap, parse_yaml, inference as ctd_inference @@ -79,45 +78,16 @@ if __name__ == '__main__': model.to(device) model.eval() - CTD = True - if not CTD: - left_img = cv2.imread("../test_imgs/left.png") - right_img = cv2.imread("../test_imgs/right.png") - in_h, in_w = left_img.shape[:2] - - # Resize image in case the GPU memory overflows - eval_h, eval_w = (in_h, in_w) - - # FIXME borked for some reason, hopefully not very important - imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) - imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) - - pred = inference(imgL, imgR, model, n_iter=20) - - t = float(in_w) / float(eval_w) - disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t - - disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 - disp_vis = disp_vis.astype("uint8") - disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) - - combined_img = np.hstack((left_img, disp_vis)) - # cv2.namedWindow("output", cv2.WINDOW_NORMAL) - # cv2.imshow("output", combined_img) - cv2.imwrite("output.jpg", disp_vis) - # cv2.waitKey(0) - - else: - dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, - pattern_path=reference_pattern_path, augment=augment) - dataloader = DataLoader(dataset, args.batch_size, shuffle=True, - num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True) - for batch in dataloader: - for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']): - right = right.transpose(0, 2).transpose(0, 1) - left_img = left - imgL = left.cpu().detach().numpy() - right_img = right - imgR = right.cpu().detach().numpy() - gt_disp = disparity - do_infer(left_img, right_img, gt_disp, model) + dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, + pattern_path=reference_pattern_path, augment=augment) + dataloader = DataLoader(dataset, args.batch_size, shuffle=True, + num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True) + for batch in dataloader: + for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']): + right = right.transpose(0, 2).transpose(0, 1) + left_img = left + imgL = left.cpu().detach().numpy() + right_img = right + imgR = right.cpu().detach().numpy() + gt_disp = disparity + do_infer(left_img, right_img, gt_disp, model)