import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
from nets import Model
import wandb
from torch.utils.data import DataLoader
from dataset import CTDDataset
from train import normalize_and_colormap, parse_yaml, inference as ctd_inference

device = 'cuda'
wandb.init(project="crestereo", entity="cpt-captain")


def do_infer(left_img, right_img, gt_disp, model, attend_pattern=True):
    """Run stereo inference on one image pair and log the results to wandb.

    Calls ``ctd_inference`` with ``n_iter=20`` and its own wandb logging
    disabled, then logs a single wandb step containing the raw and the
    colormapped predicted disparity, both input images and, when ``gt_disp``
    is given, the colormapped ground-truth disparity.

    Args:
        left_img: Left input image. If it is a ``torch.Tensor``, both inputs
            are moved to CPU and cast to uint8 numpy arrays before logging
            (assumes ``right_img`` is then a tensor as well — TODO confirm).
        right_img: Right input image (in this script: the reference pattern).
        gt_disp: Ground-truth disparity array, or ``None`` to skip GT logging.
        model: The stereo model passed through to ``ctd_inference``.
        attend_pattern: Forwarded to ``ctd_inference``.
    """
    disp = ctd_inference(left_img, right_img, gt_disp, None, model, None,
                         n_iter=20, wandb_log=False, attend_pattern=attend_pattern)
    disp_vis = normalize_and_colormap(disp)

    if isinstance(left_img, torch.Tensor):
        left_img = left_img.cpu().detach().numpy().astype('uint8')
        right_img = right_img.cpu().detach().numpy().astype('uint8')

    results = {
        'disp': wandb.Image(
            disp,
            caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
        ),
        'disp_vis': wandb.Image(
            disp_vis,
            caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
        ),
        'input_left': wandb.Image(
            left_img,
            caption="Input left",
        ),
        'input_right': wandb.Image(
            right_img,
            caption="Input right",
        ),
    }

    if gt_disp is not None:
        print('logging gt')
        print(f'gt: {gt_disp.max()}/{gt_disp.min()}/{gt_disp.mean()}')
        gt_disp_vis = normalize_and_colormap(gt_disp)
        results.update({
            'gt_disp_vis': wandb.Image(
                gt_disp_vis,
                caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
            )})
    wandb.log(results)


def downsample(img, half_height_out=480):
    """Halve ``img`` with ``cv2.pyrDown`` and center-crop it vertically.

    The crop keeps exactly ``half_height_out`` rows (all columns/channels).
    If the downsampled image is already at most ``half_height_out`` rows tall,
    it is returned uncropped.
    """
    downsampled = cv2.pyrDown(img)
    height = downsampled.shape[0]
    if height <= half_height_out:
        # Nothing to crop; a negative offset would otherwise slice from the
        # bottom of the image and return the wrong rows.
        return downsampled
    top = (height - half_height_out) // 2
    # Slice by explicit length so an odd excess does not yield one extra row.
    return downsampled[top:top + half_height_out]


def _imread_or_raise(path):
    """Read ``path`` with ``cv2.imread``; raise instead of returning ``None``."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return img


def _blender_gt_disparity(image_path, baseline=0.075, fl=560., mystery_factor=35):
    """Load the blender GT depth render for ``image_path`` and convert it to disparity.

    Args:
        image_path: Path of the rendered input image; the GT depth file is
            expected next to it as ``<stem>_depth0001.png``.
        baseline: Stereo baseline in meters.
        fl: Focal length (value as per CTD).
        mystery_factor: Empirical scale; without it the disparities are not
            reasonable, presumably due to incorrect depth scaling.

    Returns:
        A float64 disparity array with zero-depth pixels set to 0.
    """
    # splitext strips only the final extension; rsplit('.')[0] would cut at
    # the first dot anywhere in the path (e.g. in a directory name).
    gt_path = os.path.splitext(image_path)[0] + '_depth0001.png'
    gt_depth = downsample(_imread_or_raise(gt_path)).astype(np.float64)
    gt_disp = np.zeros_like(gt_depth)
    # Zero-depth pixels would otherwise become inf, which also ends up in the
    # logged min/max captions; leave them at 0 instead.
    np.divide(baseline * fl * mystery_factor, gt_depth, out=gt_disp, where=gt_depth > 0)
    return gt_disp


def _list_input_images(img_path, limit=25):
    """Return up to ``limit`` non-depth files in ``img_path``, sorted by name."""
    # The context manager closes the directory handle deterministically.
    with os.scandir(img_path) as entries:
        images = [entry for entry in entries
                  if entry.is_file() and 'depth' not in entry.name]
    images.sort(key=lambda entry: entry.name)
    return images[:limit]


def main():
    """Load the checkpoint and log predictions for the first input images."""
    # model_path = "models/crestereo_eth3d.pth"
    model_path = "train_log/models/latest.pth"
    # model_path = "train_log/models/epoch-120.pth"
    # model_path = "train_log/models/epoch-250.pth"
    print(model_path)

    # Alternative reference patterns:
    # reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
    # reference_pattern_path = '/home/nils/new_reference.png'
    # reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
    # reference_pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
    # reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
    # reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
    reference_pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png'

    # data_type = 'kinect'
    data_type = 'blender'
    augment = False

    args = parse_yaml("cfgs/train.yaml")

    wandb.config.update({'model_path': model_path,
                         'reference_pattern': reference_pattern_path,
                         'augment': augment,
                         'data_type': data_type,
                         'pattern_self_attention': args.pattern_attention})

    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
    model = nn.DataParallel(model, device_ids=[device])

    state_dict = torch.load(model_path)['state_dict']
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()

    # Alternative: iterate over the CTD dataset instead of raw image folders.
    # dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
    #                      pattern_path=reference_pattern_path, augment=augment)
    # dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
    #                         num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)

    right = downsample(_imread_or_raise(reference_pattern_path))

    image_dirs = {
        'blender': '/media/Data1/connecting_the_dots_data/blender_renders/data/',
        'kinect': '/home/nils/kinect_pngs/ir/',
    }
    try:
        img_path = image_dirs[data_type]
    except KeyError:
        raise ValueError(
            f"unknown data_type {data_type!r}; expected one of {sorted(image_dirs)}"
        ) from None

    for img in _list_input_images(img_path, limit=25):
        print(img.path)
        # Only the blender renders ship with GT depth.
        gt_disp = _blender_gt_disparity(img.path) if data_type == 'blender' else None
        left = downsample(_imread_or_raise(img.path))
        do_infer(left, right, gt_disp, model, attend_pattern=args.pattern_attention)


if __name__ == '__main__':
    main()