diff --git a/test_model.py b/test_model.py index 7030a6d..3d9cf12 100644 --- a/test_model.py +++ b/test_model.py @@ -46,10 +46,12 @@ if __name__ == '__main__': left_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/left.png") right_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/right.png") + in_h, in_w = left_img.shape[:2] + # Resize image in case the GPU memory overflows - eval_h, eval_w = (240,426) - imgL = cv2.resize(left, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) - imgR = cv2.resize(right, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) + eval_h, eval_w = (1024//4,1536//4) + imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) + imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) model_path = "models/crestereo_eth3d.pth" @@ -58,15 +60,18 @@ if __name__ == '__main__': model.to(device) model.eval() - disp = inference(imgL, imgR, model, n_iter=20) - disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 + pred = inference(imgL, imgR, model, n_iter=20) + + t = float(in_w) / float(eval_w) + disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t + + disp_vis = (disp - disp.min()) / (256 - disp.min()) * 255.0 disp_vis = disp_vis.astype("uint8") disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) - left_img = cv2.resize(left_img, disp_vis.shape[1::-1]) combined_img = np.hstack((left_img, disp_vis)) cv2.imshow("output", combined_img) - cv2.imwrite("output.jpg", combined_img) + cv2.imwrite("output.jpg", disp_vis) cv2.waitKey(0)