|
|
|
@ -17,131 +17,8 @@ device = 'cuda' |
|
|
|
|
wandb.init(project="crestereo", entity="cpt-captain") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py |
|
|
|
|
def inference(left, right, model, n_iter=20): |
|
|
|
|
print("Model Forwarding...") |
|
|
|
|
imgL = left.transpose(2, 0, 1) |
|
|
|
|
imgR = right.transpose(2, 0, 1) |
|
|
|
|
imgL = np.ascontiguousarray(imgL[None, :, :, :]) |
|
|
|
|
imgR = np.ascontiguousarray(imgR[None, :, :, :]) |
|
|
|
|
|
|
|
|
|
imgL = torch.tensor(imgL.astype("float32")).to(device) |
|
|
|
|
imgR = torch.tensor(imgR.astype("float32")).to(device) |
|
|
|
|
|
|
|
|
|
imgL_dw2 = F.interpolate( |
|
|
|
|
imgL, |
|
|
|
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), |
|
|
|
|
mode="bilinear", |
|
|
|
|
align_corners=True, |
|
|
|
|
) |
|
|
|
|
imgR_dw2 = F.interpolate( |
|
|
|
|
imgR, |
|
|
|
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), |
|
|
|
|
mode="bilinear", |
|
|
|
|
align_corners=True, |
|
|
|
|
) |
|
|
|
|
# print(imgR_dw2.shape) |
|
|
|
|
with torch.inference_mode(): |
|
|
|
|
pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None) |
|
|
|
|
|
|
|
|
|
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) |
|
|
|
|
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy() |
|
|
|
|
|
|
|
|
|
return pred_disp |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20): |
|
|
|
|
print("Model Forwarding...") |
|
|
|
|
# print(left.shape) |
|
|
|
|
left = left.cpu().detach().numpy() |
|
|
|
|
imgL = left |
|
|
|
|
imgR = right.cpu().detach().numpy() |
|
|
|
|
imgL = np.ascontiguousarray(imgL[None, :, :, :]) |
|
|
|
|
imgR = np.ascontiguousarray(imgR[None, :, :, :]) |
|
|
|
|
|
|
|
|
|
# chosen for convenience |
|
|
|
|
device = torch.device('cuda:0') |
|
|
|
|
|
|
|
|
|
imgL = torch.tensor(imgL.astype("float32")).to(device) |
|
|
|
|
imgR = torch.tensor(imgR.astype("float32")).to(device) |
|
|
|
|
imgL = imgL.transpose(2, 3).transpose(1, 2) |
|
|
|
|
|
|
|
|
|
imgL_dw2 = F.interpolate( |
|
|
|
|
imgL, |
|
|
|
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), |
|
|
|
|
mode="bilinear", |
|
|
|
|
align_corners=True, |
|
|
|
|
) |
|
|
|
|
imgR_dw2 = F.interpolate( |
|
|
|
|
imgR, |
|
|
|
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), |
|
|
|
|
mode="bilinear", |
|
|
|
|
align_corners=True, |
|
|
|
|
) |
|
|
|
|
with torch.inference_mode(): |
|
|
|
|
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None) |
|
|
|
|
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) |
|
|
|
|
|
|
|
|
|
log = {} |
|
|
|
|
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)): |
|
|
|
|
pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy() |
|
|
|
|
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy() |
|
|
|
|
|
|
|
|
|
pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) |
|
|
|
|
pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) |
|
|
|
|
|
|
|
|
|
log[f'pred_{i}'] = wandb.Image( |
|
|
|
|
np.array([pred_disp.reshape(480, 640)]), |
|
|
|
|
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", |
|
|
|
|
) |
|
|
|
|
log[f'pred_norm_{i}'] = wandb.Image( |
|
|
|
|
np.array([pred_disp_norm.reshape(480, 640)]), |
|
|
|
|
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
log[f'pred_dw2_{i}'] = wandb.Image( |
|
|
|
|
np.array([pred_disp_dw2.reshape(240, 320)]), |
|
|
|
|
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", |
|
|
|
|
) |
|
|
|
|
log[f'pred_dw2_norm_{i}'] = wandb.Image( |
|
|
|
|
np.array([pred_disp_dw2_norm.reshape(240, 320)]), |
|
|
|
|
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left") |
|
|
|
|
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1, 2, 0).astype('uint8'), |
|
|
|
|
caption="Input Right") |
|
|
|
|
|
|
|
|
|
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}") |
|
|
|
|
|
|
|
|
|
disp_error = gt_disp - disp |
|
|
|
|
log['disp_error'] = wandb.Image( |
|
|
|
|
normalize_and_colormap(disp_error), |
|
|
|
|
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.mean():.{2}f}", |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
wandb.log(log) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def do_infer(left_img, right_img, gt_disp, model): |
|
|
|
|
in_h, in_w = left_img.shape[:2] |
|
|
|
|
|
|
|
|
|
# Resize image in case the GPU memory overflows |
|
|
|
|
eval_h, eval_w = (in_h, in_w) |
|
|
|
|
|
|
|
|
|
# FIXME borked for some reason, hopefully not very important |
|
|
|
|
|
|
|
|
|
imgL = left_img.cpu().detach().numpy() if isinstance(left_img, torch.Tensor) else left_img |
|
|
|
|
imgR = right_img.cpu().detach().numpy() if isinstance(right_img, torch.Tensor) else right_img |
|
|
|
|
|
|
|
|
|
imgL = cv2.resize(imgL, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) |
|
|
|
|
imgR = cv2.resize(imgR, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) |
|
|
|
|
|
|
|
|
|
# pred = ctd_inference(imgL, imgR, gt_disp, None, model, None, n_iter=20) |
|
|
|
|
pred = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False) |
|
|
|
|
|
|
|
|
|
t = float(in_w) / float(eval_w) |
|
|
|
|
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t |
|
|
|
|
disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False) |
|
|
|
|
|
|
|
|
|
disp_vis = normalize_and_colormap(disp) |
|
|
|
|
gt_disp_vis = normalize_and_colormap(gt_disp) |
|
|
|
@ -151,6 +28,10 @@ def do_infer(left_img, right_img, gt_disp, model): |
|
|
|
|
disp_err = normalize_and_colormap(disp_err.abs()) |
|
|
|
|
|
|
|
|
|
wandb.log({ |
|
|
|
|
'disp': wandb.Image( |
|
|
|
|
disp, |
|
|
|
|
caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}", |
|
|
|
|
), |
|
|
|
|
'disp_vis': wandb.Image( |
|
|
|
|
disp_vis, |
|
|
|
|
caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}", |
|
|
|
@ -193,7 +74,6 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
|
|
model = Model(max_disp=256, mixed_precision=False, test_mode=True) |
|
|
|
|
model = nn.DataParallel(model, device_ids=[device]) |
|
|
|
|
# model.load_state_dict(torch.load(model_path), strict=False) |
|
|
|
|
state_dict = torch.load(model_path)['state_dict'] |
|
|
|
|
model.load_state_dict(state_dict, strict=True) |
|
|
|
|
model.to(device) |
|
|
|
|