|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import numpy as np
|
|
|
|
import cv2
|
|
|
|
import os
|
|
|
|
|
|
|
|
from nets import Model
|
|
|
|
|
|
|
|
import wandb
|
|
|
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from dataset import CTDDataset
|
|
|
|
from train import normalize_and_colormap, parse_yaml, inference as ctd_inference
|
|
|
|
|
|
|
|
# All models/tensors go to the default CUDA device (string form; used by
# model.to(device) and as a DataParallel device id below).
device = 'cuda'




# Side effect at import/run time: starts a wandb tracking run for this evaluation.
wandb.init(project="crestereo", entity="cpt-captain")
|
|
|
|
|
|
|
|
|
|
|
|
def do_infer(left_img, right_img, gt_disp, model, attend_pattern=True):
    """Run stereo inference on one image pair and log inputs/outputs to wandb.

    Args:
        left_img: left input image (numpy array or torch.Tensor).
        right_img: right input image / reference pattern (numpy array or torch.Tensor).
        gt_disp: ground-truth disparity map, or None when unavailable.
        model: stereo network passed through to ``ctd_inference``.
        attend_pattern: forwarded to ``ctd_inference`` (pattern self-attention flag).

    Side effects:
        Logs predicted disparity (raw + colormapped), both inputs, and — when
        ``gt_disp`` is given — the colormapped ground truth via ``wandb.log``.
    """
    disp = ctd_inference(left_img, right_img, gt_disp, None, model, None,
                         n_iter=20, wandb_log=False, attend_pattern=attend_pattern)

    disp_vis = normalize_and_colormap(disp)

    # Convert each input independently: the originals were only converted when
    # left_img was a tensor, which crashed on mixed numpy/tensor call sites.
    if isinstance(left_img, torch.Tensor):
        left_img = left_img.cpu().detach().numpy().astype('uint8')
    if isinstance(right_img, torch.Tensor):
        right_img = right_img.cpu().detach().numpy().astype('uint8')

    # Same caption for the raw and the colormapped prediction.
    pred_caption = f"Pred. Disparity \n{disp.min():.2f}/{disp.max():.2f}"
    results = {
        'disp': wandb.Image(
            disp,
            caption=pred_caption,
        ),
        'disp_vis': wandb.Image(
            disp_vis,
            caption=pred_caption,
        ),
        'input_left': wandb.Image(
            left_img,
            caption="Input left",
        ),
        'input_right': wandb.Image(
            right_img,
            caption="Input right",
        ),
    }

    if gt_disp is not None:
        print('logging gt')
        print(f'gt: {gt_disp.max()}/{gt_disp.min()}/{gt_disp.mean()}')
        gt_disp_vis = normalize_and_colormap(gt_disp)
        results.update({
            'gt_disp_vis': wandb.Image(
                gt_disp_vis,
                caption=f"GT Disparity \n{gt_disp.min():.2f}/{gt_disp.max():.2f}",
            )})
    wandb.log(results)
|
|
|
|
|
|
|
|
|
|
|
|
def downsample(img, half_height_out=480):
    """Halve the image resolution and center-crop its height.

    Args:
        img: input image as a numpy array (H, W[, C]).
        half_height_out: desired number of rows after downsampling.

    Returns:
        The pyramid-downsampled image, vertically center-cropped to exactly
        ``half_height_out`` rows (full width kept). If the downsampled image
        is already at or below the target height it is returned unchanged.
    """
    downsampled = cv2.pyrDown(img)
    diff = (downsampled.shape[0] - half_height_out) // 2
    # Guard: the original slice wrapped around for negative diff (image smaller
    # than the target), silently returning a few bottom rows.
    if diff <= 0:
        return downsampled
    # Crop exactly half_height_out rows; the original [diff:H-diff] slice left
    # one extra row whenever the height difference was odd.
    return downsampled[diff:diff + half_height_out, 0:downsampled.shape[1]]
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # --- checkpoint selection (alternatives kept for quick switching) ---
    # model_path = "models/crestereo_eth3d.pth"
    model_path = "train_log/models/latest.pth"
    # model_path = "train_log/models/epoch-120.pth"
    # model_path = "train_log/models/epoch-250.pth"

    print(model_path)
    # --- projector reference pattern; used as the fixed "right" image of every pair ---
    # reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
    # reference_pattern_path = '/home/nils/new_reference.png'
    # reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
    reference_pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png'
    # reference_pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
    # reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
    # reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'

    # 'blender' images come with rendered depth maps (ground truth available);
    # 'kinect' IR captures do not, so gt_disp stays None in that mode.
    # data_type = 'kinect'
    data_type = 'blender'
    augment = False

    args = parse_yaml("cfgs/train.yaml")

    # Record the full evaluation configuration in the wandb run.
    wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment, 'data_type': data_type, 'pattern_self_attention': args.pattern_attention})

    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
    # NOTE(review): device_ids usually takes GPU indices (e.g. [0]); passing the
    # 'cuda' string relies on torch's device-index coercion — confirm intended.
    model = nn.DataParallel(model, device_ids=[device])
    # Checkpoint stores the weights under the 'state_dict' key.
    state_dict = torch.load(model_path)['state_dict']
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()

    # dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
    # pattern_path=reference_pattern_path, augment=augment)
    # dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
    # num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)

    # for batch in dataloader:
    gt_disp = None
    # Reference pattern is loaded once and reused as the right image for every pair.
    right = downsample(cv2.imread(reference_pattern_path))

    if data_type == 'blender':
        img_path = '/media/Data1/connecting_the_dots_data/blender_renders/data/'
    elif data_type == 'kinect':
        img_path = '/home/nils/kinect_pngs/ir/'

    # Evaluate the first 25 non-depth images in deterministic (name-sorted) order.
    for img in sorted(list(entry for entry in os.scandir(img_path) if 'depth' not in entry.name), key=lambda x:x.name)[:25]:
        print(img.path)
        if data_type == 'blender':
            baseline = 0.075  # meters
            fl = 560.  # focal length, as per CTD

            # Rendered depth map shares the image stem plus a '_depth0001' suffix.
            gt_path = img.path.rsplit('.')[0] + '_depth0001.png'
            gt_depth = downsample(cv2.imread(gt_path))

            # disparity = baseline * focal_length / depth, times an empirical factor.
            # NOTE(review): zero-valued depth pixels would divide by zero here — confirm
            # the rendered depth maps are strictly positive.
            mystery_factor = 35  # we don't get reasonable disparities due to incorrect depth scaling (or something like that)
            gt_disp = (baseline * fl * mystery_factor) / gt_depth

        left = downsample(cv2.imread(img.path))

        do_infer(left, right, gt_disp, model, attend_pattern=args.pattern_attention)
|