remove debug prints and other useless comments

main
Cpt.Captain 2 years ago
parent 6f6ac23175
commit 495617279b
train_lightning.py

@@ -32,20 +32,11 @@ import cv2
 def normalize_and_colormap(img, reduce_dynamic_range=False):
-    # print(img.min())
-    # print(img.max())
-    # print(img.mean())
     ret = (img - img.min()) / (img.max() - img.min()) * 255.0
-    # print(ret.min())
-    # print(ret.max())
-    # print(ret.mean())
     # FIXME do I need to compress dynamic range somehow or something?
     if reduce_dynamic_range and img.max() > 5*img.mean():
         ret = (img - img.min()) / (5*img.mean() - img.min()) * 255.0
-    # print(ret.min())
-    # print(ret.max())
-    # print(ret.mean())
     if isinstance(ret, torch.Tensor):
         ret = ret.cpu().detach().numpy()
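
Note: the logic kept above is a plain min-max rescale to [0, 255]; when reduce_dynamic_range is set and the maximum is more than five times the mean, the upper bound is clamped to 5*mean so a few outlier pixels do not wash out the rest. A minimal standalone sketch of that behaviour (the helper name is hypothetical, and the clipping step is an assumption about how out-of-range values are handled later in the function):

    import numpy as np
    import torch

    def normalize_to_uint8(img, reduce_dynamic_range=False):
        # Work on a NumPy copy if a tensor was passed in.
        if isinstance(img, torch.Tensor):
            img = img.cpu().detach().numpy()
        hi = img.max()
        # Clamp the upper bound when the max is an outlier (> 5x the mean).
        if reduce_dynamic_range and hi > 5 * img.mean():
            hi = 5 * img.mean()
        ret = (img - img.min()) / (hi - img.min()) * 255.0
        # Values above the clamped bound would exceed 255, so clip before casting.
        return np.clip(ret, 0, 255).astype(np.uint8)
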
@@ -66,8 +57,6 @@ def log_images(left, right, pred_disp, gt_disp):
     pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
     gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
-    # print('gt_disp debug')
-    # print(gt_disp.shape)
     singular_batch = False
     if len(left.shape) == 2:
@@ -76,15 +65,12 @@ def log_images(left, right, pred_disp, gt_disp):
         input_left = left.cpu().detach().numpy()
         input_right = right.cpu().detach().numpy()
     else:
-        input_left = left[batch_idx].cpu().detach().numpy()  # .transpose(1,2,0)
-        input_right = right[batch_idx].cpu().detach().numpy()  # .transpose(1,2,0)
+        input_left = left[batch_idx].cpu().detach().numpy()
+        input_right = right[batch_idx].cpu().detach().numpy()
     disp = pred_disp
     disp_error = gt_disp - disp
-    # print('gt_disp debug normalize')
-    # print(gt_disp.max(), gt_disp.min())
-    # print(gt_disp.dtype)
     if singular_batch:
         wandb_log = dict(
@@ -110,7 +96,6 @@ def log_images(left, right, pred_disp, gt_disp):
         wandb_log = dict(
             key='samples',
             images=[
-                # pred_disp.cpu().detach().numpy().transpose(1,2,0),
                 normalize_and_colormap(pred_disp[batch_idx]),
                 normalize_and_colormap(abs(disp_error[batch_idx])),
                 normalize_and_colormap(gt_disp[batch_idx]),
@@ -118,7 +103,6 @@ def log_images(left, right, pred_disp, gt_disp):
                 input_right,
             ],
             caption=[
-                # f"Disparity \n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}",
                 f"Disparity (vis)\n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}",
                 f"Disp. Error\n{disp_error[batch_idx].min():.{2}f}/{disp_error[batch_idx].max():.{2}f}\n{abs(disp_error[batch_idx]).mean():.{2}f}",
                 f"GT Disp Vis \n{gt_disp[batch_idx].min():.{2}f}/{gt_disp[batch_idx].max():.{2}f}",
@@ -439,7 +423,7 @@ if __name__ == "__main__":
         # this was used for our blender renders
         pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
     elif 'ctd' in config.training_data_path:
-        # this one is used (i hope) for ctd
+        # this one is used for ctd
         pattern_path = '/home/nils/kinect_from_settings.png'
@@ -454,8 +438,6 @@ if __name__ == "__main__":
         pattern_path,
         # lr=0.00017378008287493763,  # found with auto_lr_find=True
     )
-    # NOTE turn this down once it's working, this might use too much space
-    # wandb_logger.watch(model, log_graph=False)  # , log='all'
     model_checkpoint = ModelCheckpoint(
         monitor="val_loss",
