diff --git a/convert_weights.py b/convert_weights.py
new file mode 100644
index 0000000..c765f39
--- /dev/null
+++ b/convert_weights.py
@@ -0,0 +1,26 @@
+import copy
+import torch
+import numpy as np
+import megengine as mge
+
+from nets import Model
+
+# Read Megengine parameters
+pretrained_dict = mge.load("models/crestereo_eth3d.mge")
+
+model = Model(max_disp=256, mixed_precision=False, test_mode=True)
+model.eval()
+
+state_dict = model.state_dict()
+for key, value in pretrained_dict['state_dict'].items():
+
+    print(f"Converting {key}")
+    # Fix shape mismatch
+    if value.shape[0] == 1:
+        value = np.squeeze(value)
+
+    state_dict[key] = torch.tensor(value)
+
+output_path = "models/crestereo_eth3d.pth"
+torch.save(state_dict, output_path)
+print(f"\nModel saved to: {output_path}")
\ No newline at end of file
diff --git a/doc/img/output.jpg b/doc/img/output.jpg
new file mode 100644
index 0000000..6fc8447
Binary files /dev/null and b/doc/img/output.jpg differ
diff --git a/models/.gitkeep b/models/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/test_model.py b/test_model.py
index 796c7bd..7e32363 100644
--- a/test_model.py
+++ b/test_model.py
@@ -1,13 +1,63 @@
 import torch
+import torch.nn.functional as F
+import numpy as np
+import cv2
+from imread_from_url import imread_from_url
 
 from nets import Model
 
-model = Model(max_disp=256, mixed_precision=False, test_mode=True)
-model.eval()
+#Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
+def inference(left, right, model, n_iter=20):
+
+    print("Model Forwarding...")
+    imgL = left.transpose(2, 0, 1)
+    imgR = right.transpose(2, 0, 1)
+    imgL = np.ascontiguousarray(imgL[None, :, :, :])
+    imgR = np.ascontiguousarray(imgR[None, :, :, :])
+
+    imgL = torch.tensor(imgL.astype("float32"))
+    imgR = torch.tensor(imgR.astype("float32"))
+
+    imgL_dw2 = F.interpolate(
+        imgL,
+        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
+        mode="bilinear",
+        align_corners=True,
+    )
+    imgR_dw2 = F.interpolate(
+        imgR,
+        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
+        mode="bilinear",
+        align_corners=True,
+    )
+    # print(imgR_dw2.shape)
+
+    pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
+
+    pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
+    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).detach().numpy()
+
+    return pred_disp
+
+if __name__ == '__main__':
+
+    left_img = imread_from_url("https://vision.middlebury.edu/stereo/data/scenes2003/newdata/cones/im2.png")
+    right_img = imread_from_url("https://vision.middlebury.edu/stereo/data/scenes2003/newdata/cones/im6.png")
+
+    model_path = "models/crestereo_eth3d.pth"
+
+    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
+    model.load_state_dict(torch.load(model_path), strict=True)
+    model.eval()
+
+    disp = inference(left_img, right_img, model, n_iter=20)
+
+    disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
+    disp_vis = disp_vis.astype("uint8")
+    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
+
+    cv2.imshow("output", disp_vis)
+    cv2.imwrite("output.jpg", disp_vis)
+    cv2.waitKey(0)
-
-t1 = torch.rand(1, 3, 480, 640)
-t2 = torch.rand(1, 3, 480, 640)
-
-output = model(t1,t2)
-print(output.shape)
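
Not part of the patch, but a possible review aid: a minimal sanity-check sketch one could run after convert_weights.py to confirm that the saved .pth holds the same tensors as the MegEngine checkpoint. It assumes the same file paths and the pretrained_dict['state_dict'] layout used in the patch, and mirrors the squeeze applied there as the shape fix.

# sanity_check_weights.py (sketch, not part of this patch)
import numpy as np
import torch
import megengine as mge

# Assumes the same paths and checkpoint layout as convert_weights.py above
mge_weights = mge.load("models/crestereo_eth3d.mge")["state_dict"]
pth_weights = torch.load("models/crestereo_eth3d.pth")

for key, value in mge_weights.items():
    value = np.asarray(value)
    # Mirror the squeeze applied during conversion to drop singleton dims
    if value.shape[0] == 1:
        value = np.squeeze(value)
    converted = pth_weights[key].numpy()
    assert value.shape == converted.shape, f"Shape mismatch for {key}"
    assert np.allclose(value, converted), f"Value mismatch for {key}"

print("All converted parameters match the MegEngine checkpoint.")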