test_model.py: reformat
This commit is contained in:
parent
17bf30fa2a
commit
9740e5d647
327
test_model.py
327
test_model.py
@ -17,43 +17,40 @@ device = 'cuda'
|
|||||||
wandb.init(project="crestereo", entity="cpt-captain")
|
wandb.init(project="crestereo", entity="cpt-captain")
|
||||||
|
|
||||||
|
|
||||||
|
# Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
|
||||||
#Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
|
|
||||||
def inference(left, right, model, n_iter=20):
|
def inference(left, right, model, n_iter=20):
|
||||||
|
print("Model Forwarding...")
|
||||||
|
imgL = left.transpose(2, 0, 1)
|
||||||
|
imgR = right.transpose(2, 0, 1)
|
||||||
|
imgL = np.ascontiguousarray(imgL[None, :, :, :])
|
||||||
|
imgR = np.ascontiguousarray(imgR[None, :, :, :])
|
||||||
|
|
||||||
print("Model Forwarding...")
|
imgL = torch.tensor(imgL.astype("float32")).to(device)
|
||||||
imgL = left.transpose(2, 0, 1)
|
imgR = torch.tensor(imgR.astype("float32")).to(device)
|
||||||
imgR = right.transpose(2, 0, 1)
|
|
||||||
imgL = np.ascontiguousarray(imgL[None, :, :, :])
|
|
||||||
imgR = np.ascontiguousarray(imgR[None, :, :, :])
|
|
||||||
|
|
||||||
imgL = torch.tensor(imgL.astype("float32")).to(device)
|
imgL_dw2 = F.interpolate(
|
||||||
imgR = torch.tensor(imgR.astype("float32")).to(device)
|
imgL,
|
||||||
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
imgR_dw2 = F.interpolate(
|
||||||
|
imgR,
|
||||||
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
# print(imgR_dw2.shape)
|
||||||
|
with torch.inference_mode():
|
||||||
|
pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
|
||||||
|
|
||||||
imgL_dw2 = F.interpolate(
|
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
|
||||||
imgL,
|
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
|
||||||
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
|
||||||
mode="bilinear",
|
|
||||||
align_corners=True,
|
|
||||||
)
|
|
||||||
imgR_dw2 = F.interpolate(
|
|
||||||
imgR,
|
|
||||||
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
|
||||||
mode="bilinear",
|
|
||||||
align_corners=True,
|
|
||||||
)
|
|
||||||
# print(imgR_dw2.shape)
|
|
||||||
with torch.inference_mode():
|
|
||||||
pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
|
|
||||||
|
|
||||||
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
|
return pred_disp
|
||||||
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
|
|
||||||
|
|
||||||
return pred_disp
|
|
||||||
|
|
||||||
|
|
||||||
def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20):
|
def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20):
|
||||||
|
|
||||||
print("Model Forwarding...")
|
print("Model Forwarding...")
|
||||||
# print(left.shape)
|
# print(left.shape)
|
||||||
left = left.cpu().detach().numpy()
|
left = left.cpu().detach().numpy()
|
||||||
@ -61,26 +58,26 @@ def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20):
|
|||||||
imgR = right.cpu().detach().numpy()
|
imgR = right.cpu().detach().numpy()
|
||||||
imgL = np.ascontiguousarray(imgL[None, :, :, :])
|
imgL = np.ascontiguousarray(imgL[None, :, :, :])
|
||||||
imgR = np.ascontiguousarray(imgR[None, :, :, :])
|
imgR = np.ascontiguousarray(imgR[None, :, :, :])
|
||||||
|
|
||||||
# chosen for convenience
|
# chosen for convenience
|
||||||
device = torch.device('cuda:0')
|
device = torch.device('cuda:0')
|
||||||
|
|
||||||
imgL = torch.tensor(imgL.astype("float32")).to(device)
|
imgL = torch.tensor(imgL.astype("float32")).to(device)
|
||||||
imgR = torch.tensor(imgR.astype("float32")).to(device)
|
imgR = torch.tensor(imgR.astype("float32")).to(device)
|
||||||
imgL = imgL.transpose(2,3).transpose(1,2)
|
imgL = imgL.transpose(2, 3).transpose(1, 2)
|
||||||
|
|
||||||
imgL_dw2 = F.interpolate(
|
imgL_dw2 = F.interpolate(
|
||||||
imgL,
|
imgL,
|
||||||
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
align_corners=True,
|
align_corners=True,
|
||||||
)
|
)
|
||||||
imgR_dw2 = F.interpolate(
|
imgR_dw2 = F.interpolate(
|
||||||
imgR,
|
imgR,
|
||||||
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
align_corners=True,
|
align_corners=True,
|
||||||
)
|
)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
|
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
|
||||||
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
|
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
|
||||||
@ -89,160 +86,158 @@ def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20):
|
|||||||
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
|
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
|
||||||
pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
|
pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
|
||||||
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy()
|
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy()
|
||||||
|
|
||||||
pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
||||||
pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
||||||
|
|
||||||
log[f'pred_{i}'] = wandb.Image(
|
log[f'pred_{i}'] = wandb.Image(
|
||||||
np.array([pred_disp.reshape(480, 640)]),
|
np.array([pred_disp.reshape(480, 640)]),
|
||||||
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||||
)
|
)
|
||||||
log[f'pred_norm_{i}'] = wandb.Image(
|
log[f'pred_norm_{i}'] = wandb.Image(
|
||||||
np.array([pred_disp_norm.reshape(480, 640)]),
|
np.array([pred_disp_norm.reshape(480, 640)]),
|
||||||
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||||
)
|
)
|
||||||
|
|
||||||
log[f'pred_dw2_{i}'] = wandb.Image(
|
log[f'pred_dw2_{i}'] = wandb.Image(
|
||||||
np.array([pred_disp_dw2.reshape(240, 320)]),
|
np.array([pred_disp_dw2.reshape(240, 320)]),
|
||||||
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
||||||
)
|
)
|
||||||
log[f'pred_dw2_norm_{i}'] = wandb.Image(
|
log[f'pred_dw2_norm_{i}'] = wandb.Image(
|
||||||
np.array([pred_disp_dw2_norm.reshape(240, 320)]),
|
np.array([pred_disp_dw2_norm.reshape(240, 320)]),
|
||||||
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
|
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
|
||||||
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right")
|
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1, 2, 0).astype('uint8'),
|
||||||
|
caption="Input Right")
|
||||||
|
|
||||||
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
|
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
|
||||||
|
|
||||||
|
|
||||||
disp_error = gt_disp - disp
|
disp_error = gt_disp - disp
|
||||||
log['disp_error'] = wandb.Image(
|
log['disp_error'] = wandb.Image(
|
||||||
normalize_and_colormap(disp_error),
|
normalize_and_colormap(disp_error),
|
||||||
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.mean():.{2}f}",
|
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.mean():.{2}f}",
|
||||||
)
|
)
|
||||||
|
|
||||||
wandb.log(log)
|
wandb.log(log)
|
||||||
|
|
||||||
|
|
||||||
def do_infer(left_img, right_img, gt_disp, model):
|
def do_infer(left_img, right_img, gt_disp, model):
|
||||||
|
in_h, in_w = left_img.shape[:2]
|
||||||
|
|
||||||
|
# Resize image in case the GPU memory overflows
|
||||||
|
eval_h, eval_w = (in_h, in_w)
|
||||||
|
|
||||||
|
# FIXME borked for some reason, hopefully not very important
|
||||||
|
|
||||||
|
imgL = left_img.cpu().detach().numpy() if isinstance(left_img, torch.Tensor) else left_img
|
||||||
|
imgR = right_img.cpu().detach().numpy() if isinstance(right_img, torch.Tensor) else right_img
|
||||||
|
|
||||||
|
imgL = cv2.resize(imgL, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
||||||
|
imgR = cv2.resize(imgR, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
# pred = ctd_inference(imgL, imgR, gt_disp, None, model, None, n_iter=20)
|
||||||
|
pred = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False)
|
||||||
|
|
||||||
|
t = float(in_w) / float(eval_w)
|
||||||
|
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
|
||||||
|
|
||||||
|
disp_vis = normalize_and_colormap(disp)
|
||||||
|
gt_disp_vis = normalize_and_colormap(gt_disp)
|
||||||
|
if gt_disp.shape != disp.shape:
|
||||||
|
gt_disp = gt_disp.reshape(disp.shape)
|
||||||
|
disp_err = gt_disp - disp
|
||||||
|
disp_err = normalize_and_colormap(disp_err.abs())
|
||||||
|
|
||||||
|
wandb.log({
|
||||||
|
'disp_vis': wandb.Image(
|
||||||
|
disp_vis,
|
||||||
|
caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
|
||||||
|
),
|
||||||
|
'gt_disp_vis': wandb.Image(
|
||||||
|
gt_disp_vis,
|
||||||
|
caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
|
||||||
|
),
|
||||||
|
'disp_err': wandb.Image(
|
||||||
|
disp_err,
|
||||||
|
caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
|
||||||
|
),
|
||||||
|
'input_left': wandb.Image(
|
||||||
|
left_img.cpu().detach().numpy().astype('uint8'),
|
||||||
|
caption=f"Input left",
|
||||||
|
),
|
||||||
|
'input_right': wandb.Image(
|
||||||
|
right_img.cpu().detach().numpy().astype('uint8'),
|
||||||
|
caption=f"Input right",
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# model_path = "models/crestereo_eth3d.pth"
|
||||||
|
model_path = "train_log/models/latest.pth"
|
||||||
|
|
||||||
|
# reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
|
||||||
|
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||||
|
# reference_pattern_path = '/home/nils/new_reference.png'
|
||||||
|
# reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
|
||||||
|
# reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
|
||||||
|
|
||||||
|
data_type = 'kinect'
|
||||||
|
augment = False
|
||||||
|
|
||||||
|
args = parse_yaml("cfgs/train.yaml")
|
||||||
|
|
||||||
|
wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment})
|
||||||
|
|
||||||
|
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
||||||
|
model = nn.DataParallel(model, device_ids=[device])
|
||||||
|
# model.load_state_dict(torch.load(model_path), strict=False)
|
||||||
|
state_dict = torch.load(model_path)['state_dict']
|
||||||
|
model.load_state_dict(state_dict, strict=True)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
CTD = True
|
||||||
|
if not CTD:
|
||||||
|
left_img = cv2.imread("../test_imgs/left.png")
|
||||||
|
right_img = cv2.imread("../test_imgs/right.png")
|
||||||
in_h, in_w = left_img.shape[:2]
|
in_h, in_w = left_img.shape[:2]
|
||||||
|
|
||||||
# Resize image in case the GPU memory overflows
|
# Resize image in case the GPU memory overflows
|
||||||
eval_h, eval_w = (in_h,in_w)
|
eval_h, eval_w = (in_h, in_w)
|
||||||
|
|
||||||
# FIXME borked for some reason, hopefully not very important
|
# FIXME borked for some reason, hopefully not very important
|
||||||
|
imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
||||||
imgL = left_img.cpu().detach().numpy() if isinstance(left_img, torch.Tensor) else left_img
|
imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
||||||
imgR = right_img.cpu().detach().numpy() if isinstance(right_img, torch.Tensor) else right_img
|
|
||||||
|
|
||||||
imgL = cv2.resize(imgL, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
pred = inference(imgL, imgR, model, n_iter=20)
|
||||||
imgR = cv2.resize(imgR, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
|
||||||
|
|
||||||
# pred = ctd_inference(imgL, imgR, gt_disp, None, model, None, n_iter=20)
|
|
||||||
pred = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False)
|
|
||||||
|
|
||||||
t = float(in_w) / float(eval_w)
|
t = float(in_w) / float(eval_w)
|
||||||
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
|
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
|
||||||
|
|
||||||
disp_vis = normalize_and_colormap(disp)
|
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
|
||||||
gt_disp_vis = normalize_and_colormap(gt_disp)
|
disp_vis = disp_vis.astype("uint8")
|
||||||
if gt_disp.shape != disp.shape:
|
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
|
||||||
gt_disp = gt_disp.reshape(disp.shape)
|
|
||||||
disp_err = gt_disp - disp
|
|
||||||
disp_err = normalize_and_colormap(disp_err.abs())
|
|
||||||
|
|
||||||
wandb.log({
|
combined_img = np.hstack((left_img, disp_vis))
|
||||||
'disp_vis': wandb.Image(
|
# cv2.namedWindow("output", cv2.WINDOW_NORMAL)
|
||||||
disp_vis,
|
# cv2.imshow("output", combined_img)
|
||||||
caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
|
cv2.imwrite("output.jpg", disp_vis)
|
||||||
),
|
# cv2.waitKey(0)
|
||||||
'gt_disp_vis': wandb.Image(
|
|
||||||
gt_disp_vis,
|
|
||||||
caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
|
|
||||||
),
|
|
||||||
'disp_err': wandb.Image(
|
|
||||||
disp_err,
|
|
||||||
caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
|
|
||||||
),
|
|
||||||
'input_left': wandb.Image(
|
|
||||||
left_img.cpu().detach().numpy().astype('uint8'),
|
|
||||||
caption=f"Input left",
|
|
||||||
),
|
|
||||||
'input_right': wandb.Image(
|
|
||||||
right_img.cpu().detach().numpy().astype('uint8'),
|
|
||||||
caption=f"Input right",
|
|
||||||
),
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# model_path = "models/crestereo_eth3d.pth"
|
|
||||||
model_path = "train_log/models/latest.pth"
|
|
||||||
|
|
||||||
# reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
|
|
||||||
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
|
|
||||||
# reference_pattern_path = '/home/nils/new_reference.png'
|
|
||||||
# reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
|
|
||||||
# reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
|
|
||||||
|
|
||||||
data_type = 'kinect'
|
|
||||||
augment = False
|
|
||||||
|
|
||||||
args = parse_yaml("cfgs/train.yaml")
|
|
||||||
|
|
||||||
wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment})
|
|
||||||
|
|
||||||
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
|
||||||
model = nn.DataParallel(model,device_ids=[device])
|
|
||||||
# model.load_state_dict(torch.load(model_path), strict=False)
|
|
||||||
state_dict = torch.load(model_path)['state_dict']
|
|
||||||
model.load_state_dict(state_dict, strict=True)
|
|
||||||
model.to(device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
CTD = True
|
|
||||||
if not CTD:
|
|
||||||
left_img = cv2.imread("../test_imgs/left.png")
|
|
||||||
right_img = cv2.imread("../test_imgs/right.png")
|
|
||||||
in_h, in_w = left_img.shape[:2]
|
|
||||||
|
|
||||||
# Resize image in case the GPU memory overflows
|
|
||||||
eval_h, eval_w = (in_h,in_w)
|
|
||||||
|
|
||||||
# FIXME borked for some reason, hopefully not very important
|
|
||||||
imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
|
||||||
imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
|
||||||
|
|
||||||
pred = inference(imgL, imgR, model, n_iter=20)
|
|
||||||
|
|
||||||
t = float(in_w) / float(eval_w)
|
|
||||||
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
|
|
||||||
|
|
||||||
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
|
|
||||||
disp_vis = disp_vis.astype("uint8")
|
|
||||||
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
|
|
||||||
|
|
||||||
combined_img = np.hstack((left_img, disp_vis))
|
|
||||||
# cv2.namedWindow("output", cv2.WINDOW_NORMAL)
|
|
||||||
# cv2.imshow("output", combined_img)
|
|
||||||
cv2.imwrite("output.jpg", disp_vis)
|
|
||||||
# cv2.waitKey(0)
|
|
||||||
|
|
||||||
else:
|
|
||||||
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, pattern_path=reference_pattern_path, augment=augment)
|
|
||||||
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
|
||||||
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
|
|
||||||
for batch in dataloader:
|
|
||||||
for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']):
|
|
||||||
right = right.transpose(0, 2).transpose(0, 1)
|
|
||||||
left_img = left
|
|
||||||
imgL = left.cpu().detach().numpy()
|
|
||||||
right_img = right
|
|
||||||
imgR = right.cpu().detach().numpy()
|
|
||||||
gt_disp = disparity
|
|
||||||
do_infer(left_img, right_img, gt_disp, model)
|
|
||||||
|
|
||||||
|
else:
|
||||||
|
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
|
||||||
|
pattern_path=reference_pattern_path, augment=augment)
|
||||||
|
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
||||||
|
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
|
||||||
|
for batch in dataloader:
|
||||||
|
for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']):
|
||||||
|
right = right.transpose(0, 2).transpose(0, 1)
|
||||||
|
left_img = left
|
||||||
|
imgL = left.cpu().detach().numpy()
|
||||||
|
right_img = right
|
||||||
|
imgR = right.cpu().detach().numpy()
|
||||||
|
gt_disp = disparity
|
||||||
|
do_infer(left_img, right_img, gt_disp, model)
|
||||||
|
Loading…
Reference in New Issue
Block a user