|
|
|
@ -17,10 +17,8 @@ device = 'cuda' |
|
|
|
|
wandb.init(project="crestereo", entity="cpt-captain") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py |
|
|
|
|
# Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py |
|
|
|
|
def inference(left, right, model, n_iter=20): |
|
|
|
|
|
|
|
|
|
print("Model Forwarding...") |
|
|
|
|
imgL = left.transpose(2, 0, 1) |
|
|
|
|
imgR = right.transpose(2, 0, 1) |
|
|
|
@ -53,7 +51,6 @@ def inference(left, right, model, n_iter=20): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20): |
|
|
|
|
|
|
|
|
|
print("Model Forwarding...") |
|
|
|
|
# print(left.shape) |
|
|
|
|
left = left.cpu().detach().numpy() |
|
|
|
@ -67,7 +64,7 @@ def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20): |
|
|
|
|
|
|
|
|
|
imgL = torch.tensor(imgL.astype("float32")).to(device) |
|
|
|
|
imgR = torch.tensor(imgR.astype("float32")).to(device) |
|
|
|
|
imgL = imgL.transpose(2,3).transpose(1,2) |
|
|
|
|
imgL = imgL.transpose(2, 3).transpose(1, 2) |
|
|
|
|
|
|
|
|
|
imgL_dw2 = F.interpolate( |
|
|
|
|
imgL, |
|
|
|
@ -111,13 +108,12 @@ def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20): |
|
|
|
|
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left") |
|
|
|
|
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right") |
|
|
|
|
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1, 2, 0).astype('uint8'), |
|
|
|
|
caption="Input Right") |
|
|
|
|
|
|
|
|
|
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
disp_error = gt_disp - disp |
|
|
|
|
log['disp_error'] = wandb.Image( |
|
|
|
|
normalize_and_colormap(disp_error), |
|
|
|
@ -131,7 +127,7 @@ def do_infer(left_img, right_img, gt_disp, model): |
|
|
|
|
in_h, in_w = left_img.shape[:2] |
|
|
|
|
|
|
|
|
|
# Resize image in case the GPU memory overflows |
|
|
|
|
eval_h, eval_w = (in_h,in_w) |
|
|
|
|
eval_h, eval_w = (in_h, in_w) |
|
|
|
|
|
|
|
|
|
# FIXME borked for some reason, hopefully not very important |
|
|
|
|
|
|
|
|
@ -178,7 +174,6 @@ def do_infer(left_img, right_img, gt_disp, model): |
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
# model_path = "models/crestereo_eth3d.pth" |
|
|
|
|
model_path = "train_log/models/latest.pth" |
|
|
|
@ -197,7 +192,7 @@ if __name__ == '__main__': |
|
|
|
|
wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment}) |
|
|
|
|
|
|
|
|
|
model = Model(max_disp=256, mixed_precision=False, test_mode=True) |
|
|
|
|
model = nn.DataParallel(model,device_ids=[device]) |
|
|
|
|
model = nn.DataParallel(model, device_ids=[device]) |
|
|
|
|
# model.load_state_dict(torch.load(model_path), strict=False) |
|
|
|
|
state_dict = torch.load(model_path)['state_dict'] |
|
|
|
|
model.load_state_dict(state_dict, strict=True) |
|
|
|
@ -211,7 +206,7 @@ if __name__ == '__main__': |
|
|
|
|
in_h, in_w = left_img.shape[:2] |
|
|
|
|
|
|
|
|
|
# Resize image in case the GPU memory overflows |
|
|
|
|
eval_h, eval_w = (in_h,in_w) |
|
|
|
|
eval_h, eval_w = (in_h, in_w) |
|
|
|
|
|
|
|
|
|
# FIXME borked for some reason, hopefully not very important |
|
|
|
|
imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) |
|
|
|
@ -233,7 +228,8 @@ if __name__ == '__main__': |
|
|
|
|
# cv2.waitKey(0) |
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, pattern_path=reference_pattern_path, augment=augment) |
|
|
|
|
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, |
|
|
|
|
pattern_path=reference_pattern_path, augment=augment) |
|
|
|
|
dataloader = DataLoader(dataset, args.batch_size, shuffle=True, |
|
|
|
|
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True) |
|
|
|
|
for batch in dataloader: |
|
|
|
@ -245,4 +241,3 @@ if __name__ == '__main__': |
|
|
|
|
imgR = right.cpu().detach().numpy() |
|
|
|
|
gt_disp = disparity |
|
|
|
|
do_infer(left_img, right_img, gt_disp, model) |
|
|
|
|
|
|
|
|
|