last batch of live-fixes and improvments
This commit is contained in:
parent
2731ef1ada
commit
6f6ac23175
@ -22,7 +22,8 @@ from train import inference as ctd_inference
|
|||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
# reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
|
# reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||||
reference_pattern_path = '/home/nils/kinect_reference_far.png'
|
# reference_pattern_path = '/home/nils/kinect_reference_far.png'
|
||||||
|
reference_pattern_path = '/home/nils/mpc/kinect_downshift_rotate_left-1.png'
|
||||||
# reference_pattern_path = '/home/nils/kinect_diff_ref.png'
|
# reference_pattern_path = '/home/nils/kinect_diff_ref.png'
|
||||||
print(reference_pattern_path)
|
print(reference_pattern_path)
|
||||||
reference_pattern = cv2.imread(reference_pattern_path)
|
reference_pattern = cv2.imread(reference_pattern_path)
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
seed: 0
|
seed: 0
|
||||||
mixed_precision: true
|
mixed_precision: true
|
||||||
base_lr: 4.0e-4
|
base_lr: 0.00025
|
||||||
# base_lr: 0.00001
|
|
||||||
t_max: 16100
|
t_max: 16100
|
||||||
|
scheduler: "cosineannealing"
|
||||||
|
|
||||||
nr_gpus: 3
|
nr_gpus: 3
|
||||||
batch_size: 3
|
batch_size: 3
|
||||||
n_total_epoch: 100
|
n_total_epoch: 64
|
||||||
minibatch_per_epoch: 500
|
minibatch_per_epoch: 500
|
||||||
|
|
||||||
loadmodel: ~
|
loadmodel: ~
|
||||||
@ -17,17 +17,9 @@ model_save_freq_epoch: 1
|
|||||||
max_disp: 256
|
max_disp: 256
|
||||||
image_width: 640
|
image_width: 640
|
||||||
image_height: 480
|
image_height: 480
|
||||||
# dataset: "blender"
|
|
||||||
# training_data_path: "./stereo_trainset/crestereo"
|
|
||||||
# training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
|
|
||||||
# training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
|
|
||||||
training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders_ctd_randomize_light/data"
|
training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders_ctd_randomize_light/data"
|
||||||
|
test_data_path: "./eval_kinect"
|
||||||
|
data_limit: 1.
|
||||||
# FIXME any of this??
|
|
||||||
pattern_attention: false
|
|
||||||
scene_attention: true
|
|
||||||
ignore_pattern_completely: false
|
|
||||||
|
|
||||||
|
|
||||||
log_level: "logging.INFO"
|
log_level: "logging.INFO"
|
||||||
|
96
dataset.py
96
dataset.py
@ -239,17 +239,22 @@ class CREStereoDataset(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
class CTDDataset(Dataset):
|
class CTDDataset(Dataset):
|
||||||
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=True):
|
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=True, data_limit=1.):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rng = np.random.RandomState(0)
|
self.rng = np.random.RandomState(0)
|
||||||
self.augment = augment
|
self.augment = augment
|
||||||
self.blur = blur
|
self.blur = blur
|
||||||
self.use_lightning = use_lightning
|
self.use_lightning = use_lightning
|
||||||
imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
|
self.data_type = data_type
|
||||||
if test_set:
|
|
||||||
|
imgs = glob.glob(os.path.join(root, f"{data_type if not 'syn' in root else ''}/*/im0_0*.npy"), recursive=True)
|
||||||
|
if not test_set:
|
||||||
self.imgs = imgs[:int(split * len(imgs))]
|
self.imgs = imgs[:int(split * len(imgs))]
|
||||||
else:
|
else:
|
||||||
self.imgs = imgs[int(split * len(imgs)):]
|
self.imgs = imgs[int(split * len(imgs)):]
|
||||||
|
|
||||||
|
self.imgs = self.imgs[:int(data_limit * len(self.imgs))]
|
||||||
|
|
||||||
self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
|
self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
|
||||||
|
|
||||||
if resize_pattern and self.pattern.shape != (480, 640, 3):
|
if resize_pattern and self.pattern.shape != (480, 640, 3):
|
||||||
@ -325,9 +330,30 @@ class CTDDataset(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
class BlenderDataset(CTDDataset):
|
class BlenderDataset(CTDDataset):
|
||||||
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=False):
|
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=False, disp_avail=False, data_limit=1.):
|
||||||
super().__init__(root, pattern_path)
|
super().__init__(root, pattern_path, augment=augment)
|
||||||
self.use_lightning = use_lightning
|
self.use_lightning = use_lightning
|
||||||
|
self.disp_avail = disp_avail
|
||||||
|
self.data_type = data_type
|
||||||
|
|
||||||
|
self.get_imgs(root, test_set, split)
|
||||||
|
self.imgs = self.imgs[:int(data_limit * len(self.imgs))]
|
||||||
|
|
||||||
|
self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
|
||||||
|
|
||||||
|
if resize_pattern and self.pattern.shape != (480, 640, 3):
|
||||||
|
self.pattern = downsample(self.pattern)
|
||||||
|
|
||||||
|
self.augmentor = Augmentor(
|
||||||
|
image_height=480,
|
||||||
|
image_width=640,
|
||||||
|
max_disp=256,
|
||||||
|
scale_min=0.6,
|
||||||
|
scale_max=1.0,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_imgs(self, root, test_set, split):
|
||||||
additional_img_types = {
|
additional_img_types = {
|
||||||
'depth',
|
'depth',
|
||||||
'disp',
|
'disp',
|
||||||
@ -347,25 +373,13 @@ class BlenderDataset(CTDDataset):
|
|||||||
self.imgs = imgs[:int(split * len(imgs))]
|
self.imgs = imgs[:int(split * len(imgs))]
|
||||||
else:
|
else:
|
||||||
self.imgs = imgs[int(split * len(imgs)):]
|
self.imgs = imgs[int(split * len(imgs)):]
|
||||||
|
|
||||||
self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
|
|
||||||
|
|
||||||
if resize_pattern and self.pattern.shape != (480, 640, 3):
|
|
||||||
self.pattern = downsample(self.pattern)
|
|
||||||
|
|
||||||
self.augmentor = Augmentor(
|
|
||||||
image_height=480,
|
|
||||||
image_width=640,
|
|
||||||
max_disp=256,
|
|
||||||
scale_min=0.6,
|
|
||||||
scale_max=1.0,
|
|
||||||
seed=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
# find path
|
# find path
|
||||||
left_path = self.imgs[index]
|
left_path = self.imgs[index]
|
||||||
left_disp_path = left_path.split('.')[0] + '_disp0001.png'
|
left_disp_path = left_path.rsplit('.', maxsplit=1)[0] + '_disp0001.png'
|
||||||
|
if not self.disp_avail:
|
||||||
|
left_depth_path = left_path.rsplit('.', maxsplit=1)[0] + '_depth0001.png'
|
||||||
|
|
||||||
# read img, disp
|
# read img, disp
|
||||||
left_img = cv2.imread(left_path)
|
left_img = cv2.imread(left_path)
|
||||||
@ -379,9 +393,23 @@ class BlenderDataset(CTDDataset):
|
|||||||
left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
|
left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
|
||||||
|
|
||||||
right_img = self.pattern
|
right_img = self.pattern
|
||||||
# left_disp = self.get_disp(left_disp_path)
|
|
||||||
disp = cv2.imread(left_disp_path, cv2.IMREAD_UNCHANGED)
|
# In some cases, we have disparity as floats in the range [0..1]. Thus we need to upscale the values.
|
||||||
left_disp = downsample(disp)
|
# 64 has been arbitrarily chosen as max_disp for this case, as this is roughly the max disparity of the CTD dataset
|
||||||
|
max_disp = 64
|
||||||
|
if not self.disp_avail:
|
||||||
|
left_disp = self.get_disp(left_depth_path)
|
||||||
|
if left_disp.max() <= 1.:
|
||||||
|
left_disp = (left_disp * max_disp).astype('uint8')
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
left_disp = cv2.imread(left_disp_path, cv2.IMREAD_UNCHANGED)
|
||||||
|
if left_disp.max() <= 1.:
|
||||||
|
left_disp = (left_disp * max_disp).astype('uint8')
|
||||||
|
if left_disp.shape != (480, 640, 3):
|
||||||
|
left_disp = downsample(left_disp)
|
||||||
|
except:
|
||||||
|
print(f'something happened, probably couldn\'t find {left_disp_path}')
|
||||||
|
|
||||||
if False: # self.rng.binomial(1, 0.5):
|
if False: # self.rng.binomial(1, 0.5):
|
||||||
left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
|
left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
|
||||||
@ -398,6 +426,9 @@ class BlenderDataset(CTDDataset):
|
|||||||
_left_img, _right_img, _left_disp, disp_mask = self.augmentor(
|
_left_img, _right_img, _left_disp, disp_mask = self.augmentor(
|
||||||
left_img, right_img, left_disp
|
left_img, right_img, left_disp
|
||||||
)
|
)
|
||||||
|
left_img = left_img.astype('float32')
|
||||||
|
right_img = right_img.astype('float32')
|
||||||
|
left_disp = left_disp.astype('float32')
|
||||||
else:
|
else:
|
||||||
left_img, right_img, left_disp, disp_mask = self.augmentor(
|
left_img, right_img, left_disp, disp_mask = self.augmentor(
|
||||||
left_img, right_img, left_disp
|
left_img, right_img, left_disp
|
||||||
@ -418,14 +449,13 @@ class BlenderDataset(CTDDataset):
|
|||||||
def get_disp(self, path):
|
def get_disp(self, path):
|
||||||
baseline = 0.075 # meters
|
baseline = 0.075 # meters
|
||||||
fl = 560. # as per CTD
|
fl = 560. # as per CTD
|
||||||
depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
# depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
depth = downsample(depth)
|
depth = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
|
||||||
# disp = np.load(path).transpose(1,2,0)
|
if depth.shape != (480, 640):
|
||||||
# disp = baseline * fl / depth
|
depth = downsample(depth)
|
||||||
# return disp.astype(np.float32) / 32
|
|
||||||
# FIXME temporarily increase disparity until new data with better depth values is generated
|
disp = (baseline * fl) / depth
|
||||||
# higher values seem to speedup convergence, but introduce much stronger artifacting
|
|
||||||
mystery_factor = 35
|
disp[disp == np.inf] = 0
|
||||||
# mystery_factor = 1
|
|
||||||
disp = (baseline * fl * mystery_factor) / depth
|
|
||||||
return disp.astype(np.float32)
|
return disp.astype(np.float32)
|
||||||
|
@ -76,7 +76,7 @@ class LocalFeatureTransformer(nn.Module):
|
|||||||
if p.dim() > 1:
|
if p.dim() > 1:
|
||||||
nn.init.xavier_uniform_(p)
|
nn.init.xavier_uniform_(p)
|
||||||
|
|
||||||
def forward(self, feat0, feat1, mask0=None, mask1=None):
|
def forward(self, feat0, feat1, mask0=None, mask1=None, ignore_second_feat=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
feat0 (torch.Tensor): [N, L, C]
|
feat0 (torch.Tensor): [N, L, C]
|
||||||
@ -97,6 +97,9 @@ class LocalFeatureTransformer(nn.Module):
|
|||||||
name = self.layer_names[i]
|
name = self.layer_names[i]
|
||||||
if name == 'self':
|
if name == 'self':
|
||||||
feat0 = layer(feat0, feat0, mask0, mask0)
|
feat0 = layer(feat0, feat0, mask0, mask0)
|
||||||
|
if ignore_second_feat:
|
||||||
|
# save some compute
|
||||||
|
continue
|
||||||
feat1 = layer(feat1, feat1, mask1, mask1)
|
feat1 = layer(feat1, feat1, mask1, mask1)
|
||||||
elif name == 'cross':
|
elif name == 'cross':
|
||||||
feat0 = layer(feat0, feat1, mask0, mask1)
|
feat0 = layer(feat0, feat1, mask0, mask1)
|
||||||
@ -104,4 +107,6 @@ class LocalFeatureTransformer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise KeyError
|
raise KeyError
|
||||||
|
|
||||||
|
if ignore_second_feat:
|
||||||
|
return feat0
|
||||||
return feat0, feat1
|
return feat0, feat1
|
||||||
|
@ -151,7 +151,8 @@ class CREStereo(nn.Module):
|
|||||||
|
|
||||||
# FIXME experimental ! no self-attention for pattern
|
# FIXME experimental ! no self-attention for pattern
|
||||||
if not self_attend_right:
|
if not self_attend_right:
|
||||||
fmap1_dw16, _ = self.self_att_fn(fmap1_dw16, fmap2_dw16)
|
print('skipping right attention')
|
||||||
|
fmap1_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16, ignore_second_feat=True)
|
||||||
else:
|
else:
|
||||||
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
|
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
|
||||||
|
|
||||||
|
@ -31,8 +31,22 @@ import numpy as np
|
|||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
def normalize_and_colormap(img):
|
def normalize_and_colormap(img, reduce_dynamic_range=False):
|
||||||
|
# print(img.min())
|
||||||
|
# print(img.max())
|
||||||
|
# print(img.mean())
|
||||||
ret = (img - img.min()) / (img.max() - img.min()) * 255.0
|
ret = (img - img.min()) / (img.max() - img.min()) * 255.0
|
||||||
|
# print(ret.min())
|
||||||
|
# print(ret.max())
|
||||||
|
# print(ret.mean())
|
||||||
|
|
||||||
|
# FIXME do I need to compress dynamic range somehow or something?
|
||||||
|
if reduce_dynamic_range and img.max() > 5*img.mean():
|
||||||
|
ret = (img - img.min()) / (5*img.mean() - img.min()) * 255.0
|
||||||
|
# print(ret.min())
|
||||||
|
# print(ret.max())
|
||||||
|
# print(ret.mean())
|
||||||
|
|
||||||
if isinstance(ret, torch.Tensor):
|
if isinstance(ret, torch.Tensor):
|
||||||
ret = ret.cpu().detach().numpy()
|
ret = ret.cpu().detach().numpy()
|
||||||
ret = ret.astype("uint8")
|
ret = ret.astype("uint8")
|
||||||
@ -47,34 +61,71 @@ def log_images(left, right, pred_disp, gt_disp):
|
|||||||
if isinstance(pred_disp, list):
|
if isinstance(pred_disp, list):
|
||||||
pred_disp = pred_disp[-1]
|
pred_disp = pred_disp[-1]
|
||||||
|
|
||||||
pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
|
|
||||||
gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
|
|
||||||
left = torch.squeeze(left[:, 0, :, :])
|
left = torch.squeeze(left[:, 0, :, :])
|
||||||
right = torch.squeeze(right[:, 0, :, :])
|
right = torch.squeeze(right[:, 0, :, :])
|
||||||
|
pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
|
||||||
|
gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
|
||||||
|
|
||||||
|
# print('gt_disp debug')
|
||||||
|
# print(gt_disp.shape)
|
||||||
|
|
||||||
|
singular_batch = False
|
||||||
|
if len(left.shape) == 2:
|
||||||
|
singular_batch = True
|
||||||
|
print('batch_size seems to be 1')
|
||||||
|
input_left = left.cpu().detach().numpy()
|
||||||
|
input_right = right.cpu().detach().numpy()
|
||||||
|
else:
|
||||||
|
input_left = left[batch_idx].cpu().detach().numpy()# .transpose(1,2,0)
|
||||||
|
input_right = right[batch_idx].cpu().detach().numpy()# .transpose(1,2,0)
|
||||||
|
|
||||||
disp = pred_disp
|
disp = pred_disp
|
||||||
disp_error = gt_disp - disp
|
disp_error = gt_disp - disp
|
||||||
|
|
||||||
|
# print('gt_disp debug normalize')
|
||||||
|
# print(gt_disp.max(), gt_disp.min())
|
||||||
|
# print(gt_disp.dtype)
|
||||||
|
|
||||||
input_left = left[batch_idx].cpu().detach().numpy()# .transpose(1,2,0)
|
if singular_batch:
|
||||||
input_right = right[batch_idx].cpu().detach().numpy()# .transpose(1,2,0)
|
wandb_log = dict(
|
||||||
|
key='samples',
|
||||||
wandb_log = dict(
|
images=[
|
||||||
key='samples',
|
pred_disp,
|
||||||
images=[
|
normalize_and_colormap(pred_disp),
|
||||||
normalize_and_colormap(pred_disp[batch_idx]),
|
normalize_and_colormap(abs(disp_error), reduce_dynamic_range=True),
|
||||||
normalize_and_colormap(abs(disp_error[batch_idx])),
|
normalize_and_colormap(gt_disp, reduce_dynamic_range=True),
|
||||||
normalize_and_colormap(gt_disp[batch_idx]),
|
input_left,
|
||||||
input_left,
|
input_right,
|
||||||
input_right,
|
],
|
||||||
],
|
caption=[
|
||||||
caption=[
|
f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||||
f"Disparity \n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}",
|
f"Disparity (vis) \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||||
f"Disp. Error\n{disp_error[batch_idx].min():.{2}f}/{disp_error[batch_idx].max():.{2}f}\n{abs(disp_error[batch_idx]).mean():.{2}f}",
|
f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}",
|
||||||
f"GT Disp Vis \n{gt_disp[batch_idx].min():.{2}f}/{gt_disp[batch_idx].max():.{2}f}",
|
f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
|
||||||
"Input Left",
|
"Input Left",
|
||||||
"Input Right"
|
"Input Right"
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
wandb_log = dict(
|
||||||
|
key='samples',
|
||||||
|
images=[
|
||||||
|
# pred_disp.cpu().detach().numpy().transpose(1,2,0),
|
||||||
|
normalize_and_colormap(pred_disp[batch_idx]),
|
||||||
|
normalize_and_colormap(abs(disp_error[batch_idx])),
|
||||||
|
normalize_and_colormap(gt_disp[batch_idx]),
|
||||||
|
input_left,
|
||||||
|
input_right,
|
||||||
|
],
|
||||||
|
caption=[
|
||||||
|
# f"Disparity \n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}",
|
||||||
|
f"Disparity (vis)\n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}",
|
||||||
|
f"Disp. Error\n{disp_error[batch_idx].min():.{2}f}/{disp_error[batch_idx].max():.{2}f}\n{abs(disp_error[batch_idx]).mean():.{2}f}",
|
||||||
|
f"GT Disp Vis \n{gt_disp[batch_idx].min():.{2}f}/{gt_disp[batch_idx].max():.{2}f}",
|
||||||
|
"Input Left",
|
||||||
|
"Input Right"
|
||||||
|
],
|
||||||
|
)
|
||||||
return wandb_log
|
return wandb_log
|
||||||
|
|
||||||
|
|
||||||
@ -104,7 +155,10 @@ def outlier_fraction(estimate, target, mask=None, threshold=0):
|
|||||||
else:
|
else:
|
||||||
mask = mask != 0
|
mask = mask != 0
|
||||||
if estimate.shape != mask.shape:
|
if estimate.shape != mask.shape:
|
||||||
raise Exception(f'estimate and mask have to be same shape (expected {estimate.shape} == {mask.shape})')
|
if len(mask.shape) == 3:
|
||||||
|
mask = mask[0]
|
||||||
|
if estimate.shape != mask.shape:
|
||||||
|
raise Exception(f'estimate and mask have to be same shape (expected {estimate.shape} == {mask.shape})')
|
||||||
return estimate, target, mask
|
return estimate, target, mask
|
||||||
estimate = torch.squeeze(estimate[:, 0, :, :])
|
estimate = torch.squeeze(estimate[:, 0, :, :])
|
||||||
target = torch.squeeze(target[:, 0, :, :])
|
target = torch.squeeze(target[:, 0, :, :])
|
||||||
@ -131,27 +185,9 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
|
|||||||
flow_preds[0]: (B, 2, H, W)
|
flow_preds[0]: (B, 2, H, W)
|
||||||
flow_gt: (B, 2, H, W)
|
flow_gt: (B, 2, H, W)
|
||||||
'''
|
'''
|
||||||
"""
|
|
||||||
if test:
|
|
||||||
# print('sequence loss')
|
|
||||||
if valid.shape != (2, 480, 640):
|
|
||||||
valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2)
|
|
||||||
# print(valid.shape)
|
|
||||||
#valid = torch.stack([valid, valid])
|
|
||||||
# print(valid.shape)
|
|
||||||
if valid.shape != (2, 480, 640):
|
|
||||||
valid = valid.transpose(0,1)
|
|
||||||
# print(valid.shape)
|
|
||||||
"""
|
|
||||||
# print(valid.shape)
|
|
||||||
# print(flow_preds[0].shape)
|
|
||||||
# print(flow_gt.shape)
|
|
||||||
n_predictions = len(flow_preds)
|
n_predictions = len(flow_preds)
|
||||||
flow_loss = 0.0
|
flow_loss = 0.0
|
||||||
|
|
||||||
# TEST
|
|
||||||
# flow_gt = torch.squeeze(flow_gt, dim=-1)
|
|
||||||
|
|
||||||
for i in range(n_predictions):
|
for i in range(n_predictions):
|
||||||
i_weight = gamma ** (n_predictions - i - 1)
|
i_weight = gamma ** (n_predictions - i - 1)
|
||||||
i_loss = torch.abs(flow_preds[i] - flow_gt)
|
i_loss = torch.abs(flow_preds[i] - flow_gt)
|
||||||
@ -161,38 +197,50 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
|
|||||||
|
|
||||||
|
|
||||||
class CREStereoLightning(LightningModule):
|
class CREStereoLightning(LightningModule):
|
||||||
def __init__(self, args, logger=None, pattern_path='', data_path=''):
|
def __init__(self, args, logger=None, pattern_path=''):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch_size = args.batch_size
|
self.batch_size = args.batch_size
|
||||||
self.wandb_logger = logger
|
self.wandb_logger = logger
|
||||||
self.data_type = 'blender' if 'blender' in data_path else 'ctd'
|
self.imwidth = args.image_width
|
||||||
|
self.imheight = args.image_height
|
||||||
|
self.data_type = 'blender' if 'blender' in args.training_data_path else 'ctd'
|
||||||
|
self.eval_type = 'kinect' if 'kinect' in args.test_data_path else args.training_data_path
|
||||||
self.lr = args.base_lr
|
self.lr = args.base_lr
|
||||||
print(f'lr = {self.lr}')
|
|
||||||
self.T_max = args.t_max if args.t_max else None
|
self.T_max = args.t_max if args.t_max else None
|
||||||
self.pattern_attention = args.pattern_attention
|
self.pattern_attention = args.pattern_attention
|
||||||
self.pattern_path = pattern_path
|
self.pattern_path = pattern_path
|
||||||
self.data_path = data_path
|
self.data_path = args.training_data_path
|
||||||
|
self.test_data_path = args.test_data_path
|
||||||
|
self.data_limit = args.data_limit # between 0 and 1.
|
||||||
self.model = Model(
|
self.model = Model(
|
||||||
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
|
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
|
||||||
)
|
)
|
||||||
# so I can access it in adjust learn rate more easily
|
|
||||||
|
if args.scheduler == 'default':
|
||||||
|
self.automatic_optimization = False
|
||||||
|
# so I can access it in adjust learn rate more easily
|
||||||
|
|
||||||
self.n_total_epoch = args.n_total_epoch
|
self.n_total_epoch = args.n_total_epoch
|
||||||
self.base_lr = args.base_lr
|
self.base_lr = args.base_lr
|
||||||
|
|
||||||
self.automatic_optimization = False
|
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
|
# we never train on kinect
|
||||||
|
is_kinect = False
|
||||||
if self.data_type == 'blender':
|
if self.data_type == 'blender':
|
||||||
dataset = BlenderDataset(
|
dataset = BlenderDataset(
|
||||||
root=self.data_path,
|
root=self.data_path,
|
||||||
pattern_path=self.pattern_path,
|
pattern_path=self.pattern_path,
|
||||||
use_lightning=True,
|
use_lightning=True,
|
||||||
|
data_type='kinect' if is_kinect else 'blender',
|
||||||
|
disp_avail=not is_kinect,
|
||||||
|
data_limit = self.data_limit,
|
||||||
)
|
)
|
||||||
elif self.data_type == 'ctd':
|
elif self.data_type == 'ctd':
|
||||||
dataset = CTDDataset(
|
dataset = CTDDataset(
|
||||||
root=self.data_path,
|
root=self.data_path,
|
||||||
pattern_path=self.pattern_path,
|
pattern_path=self.pattern_path,
|
||||||
use_lightning=True,
|
use_lightning=True,
|
||||||
|
data_limit = self.data_limit,
|
||||||
)
|
)
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
@ -203,16 +251,20 @@ class CREStereoLightning(LightningModule):
|
|||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
|
# we also don't want to validate on kinect data
|
||||||
|
is_kinect = False
|
||||||
if self.data_type == 'blender':
|
if self.data_type == 'blender':
|
||||||
test_dataset = BlenderDataset(
|
test_dataset = BlenderDataset(
|
||||||
root=self.data_path,
|
root=self.data_path,
|
||||||
pattern_path=self.pattern_path,
|
pattern_path=self.pattern_path,
|
||||||
test_set=True,
|
test_set=True,
|
||||||
use_lightning=True,
|
use_lightning=True,
|
||||||
|
data_type='kinect' if is_kinect else 'blender',
|
||||||
|
disp_avail=not is_kinect,
|
||||||
|
data_limit = self.data_limit,
|
||||||
)
|
)
|
||||||
elif self.data_type == 'ctd':
|
elif self.data_type == 'ctd':
|
||||||
test_dataset = CTDDataset(
|
test_dataset = CTDDataset(
|
||||||
@ -220,6 +272,7 @@ class CREStereoLightning(LightningModule):
|
|||||||
pattern_path=self.pattern_path,
|
pattern_path=self.pattern_path,
|
||||||
test_set=True,
|
test_set=True,
|
||||||
use_lightning=True,
|
use_lightning=True,
|
||||||
|
data_limit = self.data_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
test_dataloader = DataLoader(
|
test_dataloader = DataLoader(
|
||||||
@ -231,29 +284,35 @@ class CREStereoLightning(LightningModule):
|
|||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
pin_memory=True
|
pin_memory=True
|
||||||
)
|
)
|
||||||
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
|
|
||||||
return test_dataloader
|
return test_dataloader
|
||||||
|
|
||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
# TODO change this to use IRL data?
|
is_kinect = self.eval_type == 'kinect'
|
||||||
if self.data_type == 'blender':
|
if self.data_type == 'blender':
|
||||||
test_dataset = CTDDataset(
|
test_dataset = BlenderDataset(
|
||||||
root=self.data_path,
|
root=self.test_data_path,
|
||||||
pattern_path=self.pattern_path,
|
pattern_path=self.pattern_path,
|
||||||
test_set=True,
|
test_set=True,
|
||||||
|
split=0. if is_kinect else 0.9, # if we test on kinect data, use all available samples for test set
|
||||||
use_lightning=True,
|
use_lightning=True,
|
||||||
|
augment=False,
|
||||||
|
disp_avail=not is_kinect,
|
||||||
|
data_type='kinect' if is_kinect else 'blender',
|
||||||
|
data_limit = self.data_limit,
|
||||||
)
|
)
|
||||||
elif self.data_type == 'ctd':
|
elif self.data_type == 'ctd':
|
||||||
test_dataset = BlenderDataset(
|
test_dataset = CTDDataset(
|
||||||
root=self.data_path,
|
root=self.test_data_path,
|
||||||
pattern_path=self.pattern_path,
|
pattern_path=self.pattern_path,
|
||||||
test_set=True,
|
test_set=True,
|
||||||
use_lightning=True,
|
use_lightning=True,
|
||||||
|
augment=False,
|
||||||
|
data_limit = self.data_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
test_dataloader = DataLoader(
|
test_dataloader = DataLoader(
|
||||||
test_dataset,
|
test_dataset,
|
||||||
self.batch_size,
|
1 if is_kinect else self.batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
@ -307,7 +366,8 @@ class CREStereoLightning(LightningModule):
|
|||||||
self.log("outlier_fraction", of)
|
self.log("outlier_fraction", of)
|
||||||
# print(', '.join(f'of{thr}={val}' for thr, val in of.items()))
|
# print(', '.join(f'of{thr}={val}' for thr, val in of.items()))
|
||||||
if batch_idx % 8 == 0:
|
if batch_idx % 8 == 0:
|
||||||
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
|
images = log_images(left, right, flow_predictions, gt_disp)
|
||||||
|
self.wandb_logger.log_image(**images)
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
left, right, gt_disp, valid_mask = batch
|
left, right, gt_disp, valid_mask = batch
|
||||||
@ -318,20 +378,28 @@ class CREStereoLightning(LightningModule):
|
|||||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||||
)
|
)
|
||||||
self.log("test_loss", test_loss)
|
self.log("test_loss", test_loss)
|
||||||
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
|
of = {}
|
||||||
|
for threshold in [0.1, 0.5, 1, 2, 5]:
|
||||||
|
of[str(threshold)] = outlier_fraction(flow_predictions[0], gt_flow, valid_mask, threshold)
|
||||||
|
self.log("outlier_fraction", of)
|
||||||
|
images = log_images(left, right, flow_predictions, gt_disp)
|
||||||
|
images['images'].append(gt_disp)
|
||||||
|
images['caption'].append('GT Disp')
|
||||||
|
self.wandb_logger.log_image(**images)
|
||||||
|
|
||||||
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
||||||
return self(batch)
|
return self(batch)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
|
optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
|
||||||
print('len(self.train_dataloader)', len(self.train_dataloader()))
|
if not self.automatic_optimization:
|
||||||
|
return optimizer
|
||||||
lr_scheduler = {
|
lr_scheduler = {
|
||||||
'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
|
'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||||
optimizer,
|
optimizer,
|
||||||
T_max=self.T_max if self.T_max else len(self.train_dataloader())/self.batch_size,
|
T_max=self.T_max if self.T_max else len(self.train_dataloader())/self.batch_size,
|
||||||
),
|
),
|
||||||
'name': 'CosineAnnealingLRScheduler',
|
'name': 'LR Scheduler',
|
||||||
}
|
}
|
||||||
return [optimizer], [lr_scheduler]
|
return [optimizer], [lr_scheduler]
|
||||||
|
|
||||||
@ -356,18 +424,21 @@ class CREStereoLightning(LightningModule):
|
|||||||
|
|
||||||
for param_group in optimizer.param_groups:
|
for param_group in optimizer.param_groups:
|
||||||
param_group['lr'] = lr
|
param_group['lr'] = lr
|
||||||
|
self.log('train/lr', lr)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# wandb.init(project='crestereo-lightning')
|
||||||
wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True)
|
wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True)
|
||||||
# train configuration
|
# train configuration
|
||||||
args = parse_yaml("cfgs/train.yaml")
|
args = parse_yaml("cfgs/train.yaml")
|
||||||
wandb_logger.experiment.config.update(args._asdict())
|
wandb_logger.experiment.config.update(args._asdict())
|
||||||
config = wandb.config
|
config = wandb.config
|
||||||
|
data_limit = config.data_limit
|
||||||
if 'blender' in config.training_data_path:
|
if 'blender' in config.training_data_path:
|
||||||
# this was used for our blender renders
|
# this was used for our blender renders
|
||||||
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
|
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
|
||||||
if 'ctd' in config.training_data_path:
|
elif 'ctd' in config.training_data_path:
|
||||||
# this one is used (i hope) for ctd
|
# this one is used (i hope) for ctd
|
||||||
pattern_path = '/home/nils/kinect_from_settings.png'
|
pattern_path = '/home/nils/kinect_from_settings.png'
|
||||||
|
|
||||||
@ -381,7 +452,6 @@ if __name__ == "__main__":
|
|||||||
config,
|
config,
|
||||||
wandb_logger,
|
wandb_logger,
|
||||||
pattern_path,
|
pattern_path,
|
||||||
config.training_data_path,
|
|
||||||
# lr=0.00017378008287493763, # found with auto_lr_find=True
|
# lr=0.00017378008287493763, # found with auto_lr_find=True
|
||||||
)
|
)
|
||||||
# NOTE turn this down once it's working, this might use too much space
|
# NOTE turn this down once it's working, this might use too much space
|
||||||
@ -394,31 +464,59 @@ if __name__ == "__main__":
|
|||||||
save_last=True,
|
save_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(
|
if config.scheduler == 'default':
|
||||||
accelerator='gpu',
|
trainer = Trainer(
|
||||||
devices=devices,
|
accelerator='gpu',
|
||||||
max_epochs=config.n_total_epoch,
|
devices=devices,
|
||||||
callbacks=[
|
max_epochs=config.n_total_epoch,
|
||||||
EarlyStopping(
|
callbacks=[
|
||||||
monitor="val_loss",
|
EarlyStopping(
|
||||||
mode="min",
|
monitor="val_loss",
|
||||||
patience=16,
|
mode="min",
|
||||||
),
|
patience=8,
|
||||||
LearningRateMonitor(),
|
),
|
||||||
model_checkpoint,
|
LearningRateMonitor(),
|
||||||
],
|
model_checkpoint,
|
||||||
strategy=DDPSpawnStrategy(find_unused_parameters=False),
|
],
|
||||||
# auto_scale_batch_size='binsearch',
|
strategy=DDPSpawnStrategy(find_unused_parameters=False),
|
||||||
# auto_lr_find=True,
|
# auto_scale_batch_size='binsearch',
|
||||||
# accumulate_grad_batches=4, # needed to disable for manual optimization
|
# auto_lr_find=True,
|
||||||
deterministic=True,
|
# accumulate_grad_batches=4, # needed to disable for manual optimization
|
||||||
check_val_every_n_epoch=1,
|
deterministic=True,
|
||||||
limit_val_batches=64,
|
check_val_every_n_epoch=1,
|
||||||
limit_test_batches=256,
|
limit_val_batches=64,
|
||||||
logger=wandb_logger,
|
limit_test_batches=256,
|
||||||
default_root_dir=config.log_dir_lightning,
|
logger=wandb_logger,
|
||||||
)
|
default_root_dir=config.log_dir_lightning,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
trainer = Trainer(
|
||||||
|
accelerator='gpu',
|
||||||
|
devices=devices,
|
||||||
|
max_epochs=config.n_total_epoch,
|
||||||
|
callbacks=[
|
||||||
|
EarlyStopping(
|
||||||
|
monitor="val_loss",
|
||||||
|
mode="min",
|
||||||
|
patience=8,
|
||||||
|
),
|
||||||
|
LearningRateMonitor(),
|
||||||
|
model_checkpoint,
|
||||||
|
],
|
||||||
|
strategy=DDPSpawnStrategy(find_unused_parameters=False),
|
||||||
|
# auto_scale_batch_size='binsearch',
|
||||||
|
# auto_lr_find=True,
|
||||||
|
accumulate_grad_batches=4, # needed to disable for manual optimization
|
||||||
|
deterministic=True,
|
||||||
|
check_val_every_n_epoch=1,
|
||||||
|
limit_val_batches=64,
|
||||||
|
limit_test_batches=256,
|
||||||
|
logger=wandb_logger,
|
||||||
|
default_root_dir=config.log_dir_lightning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# trainer.tune(model)
|
# trainer.tune(model)
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
# trainer.validate(chkpt_path=model_checkpoint.best_model_path)
|
# trainer.validate(chkpt_path=model_checkpoint.best_model_path)
|
||||||
|
trainer.test(model)
|
||||||
|
Loading…
Reference in New Issue
Block a user