test_model.py: cleanup
This commit is contained in:
parent
7e0305ed91
commit
6ba2cd9e5d
@ -8,7 +8,6 @@ from nets import Model
|
|||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
import random
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from dataset import CTDDataset
|
from dataset import CTDDataset
|
||||||
from train import normalize_and_colormap, parse_yaml, inference as ctd_inference
|
from train import normalize_and_colormap, parse_yaml, inference as ctd_inference
|
||||||
@ -79,45 +78,16 @@ if __name__ == '__main__':
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
CTD = True
|
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
|
||||||
if not CTD:
|
pattern_path=reference_pattern_path, augment=augment)
|
||||||
left_img = cv2.imread("../test_imgs/left.png")
|
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
||||||
right_img = cv2.imread("../test_imgs/right.png")
|
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
|
||||||
in_h, in_w = left_img.shape[:2]
|
for batch in dataloader:
|
||||||
|
for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']):
|
||||||
# Resize image in case the GPU memory overflows
|
right = right.transpose(0, 2).transpose(0, 1)
|
||||||
eval_h, eval_w = (in_h, in_w)
|
left_img = left
|
||||||
|
imgL = left.cpu().detach().numpy()
|
||||||
# FIXME borked for some reason, hopefully not very important
|
right_img = right
|
||||||
imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
imgR = right.cpu().detach().numpy()
|
||||||
imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
gt_disp = disparity
|
||||||
|
do_infer(left_img, right_img, gt_disp, model)
|
||||||
pred = inference(imgL, imgR, model, n_iter=20)
|
|
||||||
|
|
||||||
t = float(in_w) / float(eval_w)
|
|
||||||
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
|
|
||||||
|
|
||||||
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
|
|
||||||
disp_vis = disp_vis.astype("uint8")
|
|
||||||
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
|
|
||||||
|
|
||||||
combined_img = np.hstack((left_img, disp_vis))
|
|
||||||
# cv2.namedWindow("output", cv2.WINDOW_NORMAL)
|
|
||||||
# cv2.imshow("output", combined_img)
|
|
||||||
cv2.imwrite("output.jpg", disp_vis)
|
|
||||||
# cv2.waitKey(0)
|
|
||||||
|
|
||||||
else:
|
|
||||||
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
|
|
||||||
pattern_path=reference_pattern_path, augment=augment)
|
|
||||||
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
|
||||||
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
|
|
||||||
for batch in dataloader:
|
|
||||||
for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']):
|
|
||||||
right = right.transpose(0, 2).transpose(0, 1)
|
|
||||||
left_img = left
|
|
||||||
imgL = left.cpu().detach().numpy()
|
|
||||||
right_img = right
|
|
||||||
imgR = right.cpu().detach().numpy()
|
|
||||||
gt_disp = disparity
|
|
||||||
do_infer(left_img, right_img, gt_disp, model)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user