|
|
@ -8,7 +8,6 @@ from nets import Model |
|
|
|
|
|
|
|
|
|
|
|
import wandb |
|
|
|
import wandb |
|
|
|
|
|
|
|
|
|
|
|
import random |
|
|
|
|
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from dataset import CTDDataset |
|
|
|
from dataset import CTDDataset |
|
|
|
from train import normalize_and_colormap, parse_yaml, inference as ctd_inference |
|
|
|
from train import normalize_and_colormap, parse_yaml, inference as ctd_inference |
|
|
@ -79,35 +78,6 @@ if __name__ == '__main__': |
|
|
|
model.to(device) |
|
|
|
model.to(device) |
|
|
|
model.eval() |
|
|
|
model.eval() |
|
|
|
|
|
|
|
|
|
|
|
CTD = True |
|
|
|
|
|
|
|
if not CTD: |
|
|
|
|
|
|
|
left_img = cv2.imread("../test_imgs/left.png") |
|
|
|
|
|
|
|
right_img = cv2.imread("../test_imgs/right.png") |
|
|
|
|
|
|
|
in_h, in_w = left_img.shape[:2] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Resize image in case the GPU memory overflows |
|
|
|
|
|
|
|
eval_h, eval_w = (in_h, in_w) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# FIXME borked for some reason, hopefully not very important |
|
|
|
|
|
|
|
imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) |
|
|
|
|
|
|
|
imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pred = inference(imgL, imgR, model, n_iter=20) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
t = float(in_w) / float(eval_w) |
|
|
|
|
|
|
|
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 |
|
|
|
|
|
|
|
disp_vis = disp_vis.astype("uint8") |
|
|
|
|
|
|
|
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
combined_img = np.hstack((left_img, disp_vis)) |
|
|
|
|
|
|
|
# cv2.namedWindow("output", cv2.WINDOW_NORMAL) |
|
|
|
|
|
|
|
# cv2.imshow("output", combined_img) |
|
|
|
|
|
|
|
cv2.imwrite("output.jpg", disp_vis) |
|
|
|
|
|
|
|
# cv2.waitKey(0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, |
|
|
|
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, |
|
|
|
pattern_path=reference_pattern_path, augment=augment) |
|
|
|
pattern_path=reference_pattern_path, augment=augment) |
|
|
|
dataloader = DataLoader(dataset, args.batch_size, shuffle=True, |
|
|
|
dataloader = DataLoader(dataset, args.batch_size, shuffle=True, |
|
|
|