|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import numpy as np
|
|
|
|
import cv2
|
|
|
|
from imread_from_url import imread_from_url
|
|
|
|
|
|
|
|
from nets import Model
|
|
|
|
|
|
|
|
device = 'cuda'
|
|
|
|
|
|
|
|
#Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
|
|
|
|
def inference(left, right, model, n_iter=20):
|
|
|
|
|
|
|
|
print("Model Forwarding...")
|
|
|
|
imgL = left.transpose(2, 0, 1)
|
|
|
|
imgR = right.transpose(2, 0, 1)
|
|
|
|
imgL = np.ascontiguousarray(imgL[None, :, :, :])
|
|
|
|
imgR = np.ascontiguousarray(imgR[None, :, :, :])
|
|
|
|
|
|
|
|
imgL = torch.tensor(imgL.astype("float32")).to(device)
|
|
|
|
imgR = torch.tensor(imgR.astype("float32")).to(device)
|
|
|
|
|
|
|
|
imgL_dw2 = F.interpolate(
|
|
|
|
imgL,
|
|
|
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
|
|
|
mode="bilinear",
|
|
|
|
align_corners=True,
|
|
|
|
)
|
|
|
|
imgR_dw2 = F.interpolate(
|
|
|
|
imgR,
|
|
|
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
|
|
|
mode="bilinear",
|
|
|
|
align_corners=True,
|
|
|
|
)
|
|
|
|
# print(imgR_dw2.shape)
|
|
|
|
|
|
|
|
pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
|
|
|
|
|
|
|
|
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
|
|
|
|
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
|
|
|
|
|
|
|
|
return pred_disp
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
|
|
left_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/left.png")
|
|
|
|
right_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/right.png")
|
|
|
|
|
|
|
|
# Resize image in case the GPU memory overflows
|
|
|
|
eval_h, eval_w = (240,426)
|
|
|
|
imgL = cv2.resize(left, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
|
|
|
imgR = cv2.resize(right, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
|
|
|
|
|
|
|
model_path = "models/crestereo_eth3d.pth"
|
|
|
|
|
|
|
|
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
|
|
|
model.load_state_dict(torch.load(model_path), strict=True)
|
|
|
|
model.to(device)
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
disp = inference(imgL, imgR, model, n_iter=20)
|
|
|
|
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
|
|
|
|
disp_vis = disp_vis.astype("uint8")
|
|
|
|
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
|
|
|
|
left_img = cv2.resize(left_img, disp_vis.shape[1::-1])
|
|
|
|
|
|
|
|
combined_img = np.hstack((left_img, disp_vis))
|
|
|
|
cv2.imshow("output", combined_img)
|
|
|
|
cv2.imwrite("output.jpg", combined_img)
|
|
|
|
cv2.waitKey(0)
|
|
|
|
|
|
|
|
|
|
|
|
|