#1 Added weight conversion to Pytorch
parent
ecd3a40505
commit
1bdbb27240
@ -0,0 +1,26 @@ |
|||||||
|
import copy |
||||||
|
import torch |
||||||
|
import numpy as np |
||||||
|
import megengine as mge |
||||||
|
|
||||||
|
from nets import Model |
||||||
|
|
||||||
|
# Read Megengine parameters |
||||||
|
pretrained_dict = mge.load("models/crestereo_eth3d.mge") |
||||||
|
|
||||||
|
model = Model(max_disp=256, mixed_precision=False, test_mode=True) |
||||||
|
model.eval() |
||||||
|
|
||||||
|
state_dict = model.state_dict() |
||||||
|
for key, value in pretrained_dict['state_dict'].items(): |
||||||
|
|
||||||
|
print(f"Converting {key}") |
||||||
|
# Fix shape mismatch |
||||||
|
if value.shape[0] == 1: |
||||||
|
value = np.squeeze(value) |
||||||
|
|
||||||
|
state_dict[key] = torch.tensor(value) |
||||||
|
|
||||||
|
output_path = "models/crestereo_eth3d.pth" |
||||||
|
torch.save(state_dict, output_path) |
||||||
|
print(f"\nModel saved to: {output_path}") |
After Width: | Height: | Size: 20 KiB |
@ -1,13 +1,63 @@ |
|||||||
import torch |
import torch |
||||||
|
import torch.nn.functional as F |
||||||
|
import numpy as np |
||||||
|
import cv2 |
||||||
|
from imread_from_url import imread_from_url |
||||||
|
|
||||||
from nets import Model |
from nets import Model |
||||||
|
|
||||||
model = Model(max_disp=256, mixed_precision=False, test_mode=True) |
#Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py |
||||||
model.eval() |
def inference(left, right, model, n_iter=20): |
||||||
|
|
||||||
|
print("Model Forwarding...") |
||||||
|
imgL = left.transpose(2, 0, 1) |
||||||
|
imgR = right.transpose(2, 0, 1) |
||||||
|
imgL = np.ascontiguousarray(imgL[None, :, :, :]) |
||||||
|
imgR = np.ascontiguousarray(imgR[None, :, :, :]) |
||||||
|
|
||||||
|
imgL = torch.tensor(imgL.astype("float32")) |
||||||
|
imgR = torch.tensor(imgR.astype("float32")) |
||||||
|
|
||||||
|
imgL_dw2 = F.interpolate( |
||||||
|
imgL, |
||||||
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), |
||||||
|
mode="bilinear", |
||||||
|
align_corners=True, |
||||||
|
) |
||||||
|
imgR_dw2 = F.interpolate( |
||||||
|
imgR, |
||||||
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), |
||||||
|
mode="bilinear", |
||||||
|
align_corners=True, |
||||||
|
) |
||||||
|
# print(imgR_dw2.shape) |
||||||
|
|
||||||
|
pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None) |
||||||
|
|
||||||
|
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) |
||||||
|
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).detach().numpy() |
||||||
|
|
||||||
|
return pred_disp |
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
|
||||||
|
left_img = imread_from_url("https://vision.middlebury.edu/stereo/data/scenes2003/newdata/cones/im2.png") |
||||||
|
right_img = imread_from_url("https://vision.middlebury.edu/stereo/data/scenes2003/newdata/cones/im6.png") |
||||||
|
|
||||||
|
model_path = "models/crestereo_eth3d.pth" |
||||||
|
|
||||||
|
model = Model(max_disp=256, mixed_precision=False, test_mode=True) |
||||||
|
model.load_state_dict(torch.load(model_path), strict=True) |
||||||
|
model.eval() |
||||||
|
|
||||||
|
disp = inference(left_img, right_img, model, n_iter=20) |
||||||
|
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 |
||||||
|
disp_vis = disp_vis.astype("uint8") |
||||||
|
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) |
||||||
|
|
||||||
|
cv2.imshow("output", disp_vis) |
||||||
|
cv2.imwrite("output.jpg", disp_vis) |
||||||
|
cv2.waitKey(0) |
||||||
|
|
||||||
t1 = torch.rand(1, 3, 480, 640) |
|
||||||
t2 = torch.rand(1, 3, 480, 640) |
|
||||||
|
|
||||||
output = model(t1,t2) |
|
||||||
print(output.shape) |
|
||||||
|
|
||||||
|
Loading…
Reference in new issue