change a bunch of stuff, add wip lightning implementation

parent 11959eef61
commit 63da24f429
api_server.py (181 lines changed)
@@ -1,3 +1,4 @@
+import os
 import json
 from datetime import datetime
 from typing import Union, Literal
@@ -7,32 +8,149 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_tensorrt
 from cv2 import cv2
 from fastapi import FastAPI, File, UploadFile
 from PIL import Image
 
 
 from nets import Model
 
+from train import inference as ctd_inference
 
 
 app = FastAPI()
 
-reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
+# reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
+reference_pattern_path = '/home/nils/kinect_reference_far.png'
+# reference_pattern_path = '/home/nils/kinect_diff_ref.png'
+print(reference_pattern_path)
 reference_pattern = cv2.imread(reference_pattern_path)
 
+# shift reference pattern a few pixels to the left to simulate further backdrop
+trans_mat = np.float32([[1, 0, 0], [0, 1, 0]])
+reference_pattern = cv2.warpAffine(
+    reference_pattern, trans_mat, reference_pattern.shape[1::-1], flags=cv2.INTER_LINEAR
+)
+
 iters = 20
-minimal_data = False
+minimal_data = True
+temporal_init = False
+last_img = None
 device = torch.device('cuda:0')
 
 
-def load_model(epoch):
+def downsize(img):
+    diff = (512 - 480) // 2
+    downsampled = cv2.pyrDown(img)
+    img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
+    return img
+
+
+if 1024 in reference_pattern.shape:
+    reference_pattern = downsize(reference_pattern)
+
+
+def ghetto_lcn(img):
+    # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    gray = img
+
+    float_gray = gray.astype(np.float32) / 255.0
+
+    blur = cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2)
+    num = float_gray - blur
+
+    blur = cv2.GaussianBlur(num * num, (0, 0), sigmaX=20, sigmaY=20)
+    den = cv2.pow(blur, 0.5)
+
+    gray = num / den
+
+    # cv2.normalize(gray, dst=gray, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX)
+    cv2.normalize(gray, dst=gray, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX)
+    return gray
+
+
+# reference_pattern = ghetto_lcn(reference_pattern)
+
+
+def load_model(epoch, use_tensorrt=False):
     global model
     epoch = f'epoch-{epoch}.pth' if epoch != 'latest' else 'latest.pth'
     model_path = f"train_log/models/{epoch}"
     model = Model(max_disp=256, mixed_precision=False, test_mode=True)
+    # FIXME WIP Workaround Dataparallel TensorRT incompatibility
     model = nn.DataParallel(model, device_ids=[device])
     # model.load_state_dict(torch.load(model_path), strict=False)
    state_dict = torch.load(model_path)['state_dict']
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
+    if use_tensorrt:
+        np_model = Model(max_disp=256, mixed_precision=False, test_mode=True)
+        np_model.load_state_dict(model.module.state_dict(), strict=True)
+        np_model.to(device)
+        np_model.eval()
+
+        spec_dict = {
+            "inputs": [
+                torch_tensorrt.Input(
+                    min_shape=[1, 2, 240, 320],
+                    max_shape=[1, 2, 480, 640],
+                    opt_shape=[1, 2, 480, 640],
+                    dtype=torch.int32,
+                ),
+                torch_tensorrt.Input(
+                    min_shape=[1, 2, 240, 320],
+                    max_shape=[1, 2, 480, 640],
+                    opt_shape=[1, 2, 480, 640],
+                    dtype=torch.int32,
+                ),
+            ],
+            "enabled_precisions": {torch.float, torch.half},
+            "refit": False,
+            "debug": False,
+            "device": {
+                "device_type": torch_tensorrt.DeviceType.GPU,
+                "gpu_id": 0,
+                "dla_core": 0,
+                "allow_gpu_fallback": True
+            },
+            "capability": torch_tensorrt.EngineCapability.default,
+            "num_min_timing_iters": 2,
+            "num_avg_timing_iters": 1,
+        }
+        spec = {
+            "forward":
+                torch_tensorrt.ts.TensorRTCompileSpec(**spec_dict)
+        }
+
+        # trt_model = torch_tensorrt.compile(np_model,
+        #     inputs=torch_tensorrt.Input(
+        #         min_shape=[1, 2, 240, 320],
+        #         max_shape=[1, 2, 480, 640],
+        #         opt_shape=[1, 2, 480, 640],
+        #         dtype=torch.int32,
+        #     inputs = [torch_tensorrt.Input((1, 2, 480, 640)), torch_tensorrt.Input((1, 2, 480, 640))],  # input shape
+        #     enabled_precisions = {torch_tensorrt.dtype.half}  # Run with FP16
+        # )
+        # trt_dw_model = torch_tensorrt.compile(np_model,
+        #     inputs = [torch_tensorrt.Input((1, 2, 240, 320)), torch_tensorrt.Input((1, 2, 240, 320))],  # input shape
+        #     enabled_precisions = {torch_tensorrt.dtype.half}  # Run with FP16
+        # )
+
+        script_model = torch.jit.script(np_model.eval())
+        # script_dw_model = torch.jit.script(trt_dw_model.eval())
+
+        # save the TensorRT embedded Torchscript
+        # torch.jit.save(trt_model, 'trt_torchscript_module.ts')
+        # torch.jit.save(trt_dw_model, 'trt_torchscript_dw_module.ts')
+        print(script_model)
+        print(script_model.forward)
+        print(script_model.forward())
+        print(dir(script_model))
+
+        model = torch._C._jit_to_backend("tensorrt", script_model, spec)
+
     print(f'loaded model {epoch}')
     return model
 
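Note on the DataParallel/TensorRT workaround above: TorchScript cannot script an `nn.DataParallel` wrapper, so the commit copies the weights into a fresh, unwrapped `Model` before scripting and backend lowering. A minimal sketch of that pattern (the `unwrap_for_scripting` / `fresh_model` names are illustrative, not from the diff):

```python
import torch
import torch.nn as nn

def unwrap_for_scripting(wrapped: nn.DataParallel, fresh_model: nn.Module) -> torch.jit.ScriptModule:
    # Copy the trained weights out of the DataParallel wrapper into a plain
    # module of the same architecture, which torch.jit.script can handle.
    fresh_model.load_state_dict(wrapped.module.state_dict(), strict=True)
    return torch.jit.script(fresh_model.eval())
```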
@@ -74,14 +192,37 @@ def inference(left, right, model, n_iter=20):
     )
 
     with torch.inference_mode():
-        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
-        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
+        # pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
+        # pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
+        pred_flow_dw2 = model.forward(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
+        pred_flow = model.forward(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
 
     pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
 
     return pred_disp
+
+
+def get_reference():
+    refs = [ref.path for ref in os.scandir('/home/nils/references/')]
+    for ref in refs:
+        reference = cv2.imread(ref)
+        yield reference
+
+
+references = get_reference()
+
+
+@app.post('/params/update_reference')
+async def update_reference():
+    global references, reference_pattern
+    try:
+        reference_pattern = downsize(next(references))
+        print(reference_pattern.shape)
+        return {'status': 'success'}
+    except StopIteration:
+        references = get_reference()
+        return {'status': 'finished'}
 
 
 @app.post("/model/update/{epoch}")
 async def change_model(epoch: Union[int, Literal['latest']]):
     global model
@@ -103,8 +244,15 @@ async def set_minimal_data(enable: bool):
     minimal_data = enable
 
 
+@app.post("/params/temporal_init/{enable}")
+async def set_temporal_init(enable: bool):
+    global temporal_init
+    temporal_init = enable
+
+
 @app.put("/ir")
 async def read_ir_input(file: UploadFile = File(...)):
+    global last_img, minimal_data
     try:
         img = np.array(Image.open(BytesIO(await file.read())))
     except Exception as e:
@@ -114,24 +262,35 @@ async def read_ir_input(file: UploadFile = File(...)):
     if len(img.shape) == 2:
         img = cv2.merge([img for _ in range(3)])
     if img.shape == (1024, 1280, 3):
-        diff = (512 - 480) // 2
-        downsampled = cv2.pyrDown(img)
-        img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
+        img = downsize(img)
 
-    img = img.transpose((1, 2, 0))
-    ref_pat = reference_pattern.transpose((1, 2, 0))
+    # img = img.transpose((1, 2, 0))
+    # ref_pat = reference_pattern.transpose((1, 2, 0))
+    ref_pat = reference_pattern
+
     start = datetime.now()
-    pred_disp = inference(img, ref_pat, model, iters)
+    if temporal_init:
+        pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False, last_img)
+        last_img = pred_disp
+    else:
+        pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False)
+    # pred_disp = inference(img, ref_pat, model, iters)
     duration = (datetime.now() - start).total_seconds()
+
     if minimal_data:
         return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
     else:
+        # return json.dumps({'disp': pred_disp, 'input': img, 'duration': duration},
         return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration},
                           cls=NumpyEncoder)
 
 
+@app.get('/temporal_init')
+def get_temporal_init():
+    return {'status': 'enabled' if temporal_init else 'disabled'}
+
+
 @app.get('/')
 def main():
     return {'test': 'abc'}
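Example client usage of the new endpoints (a sketch; the host/port and the input file name `ir_frame.png` are assumptions, not from the commit):

```python
import requests

BASE = 'http://localhost:8000'  # assumed host/port for the FastAPI server

# toggle temporal initialization, then query its state
requests.post(f'{BASE}/params/temporal_init/true')
print(requests.get(f'{BASE}/temporal_init').json())  # {'status': 'enabled'}

# submit an IR frame for inference
with open('ir_frame.png', 'rb') as f:  # hypothetical input file
    r = requests.put(f'{BASE}/ir', files={'file': f})
```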
@@ -4,18 +4,22 @@ base_lr: 4.0e-4
 
 nr_gpus: 3
 batch_size: 4
-n_total_epoch: 600
+n_total_epoch: 300
 minibatch_per_epoch: 500
 
 loadmodel: ~
 log_dir: "./train_log"
+log_dir_lightning: "./train_log_lightning"
 model_save_freq_epoch: 1
 
 max_disp: 256
 image_width: 640
 image_height: 480
 # training_data_path: "./stereo_trainset/crestereo"
-training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
+pattern_attention: true
+dataset: "blender"
+# training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
+training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
 
 log_level: "logging.INFO"
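These keys are read via `parse_yaml` (defined in train.py), which wraps the YAML mapping in a namedtuple — presumably this hunk is the `cfgs/train.yaml` that test_model.py and train.py load (sketch):

```python
from train import parse_yaml

args = parse_yaml("cfgs/train.yaml")
print(args.dataset)            # "blender"
print(args.pattern_attention)  # True
print(args.n_total_epoch)      # 300
```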
dataset.py (111 lines changed)
@@ -17,7 +17,7 @@ class Augmentor:
         scale_min=0.6,
         scale_max=1.0,
         seed=0,
     ):
         super().__init__()
         self.image_height = image_height
         self.image_width = image_width
@@ -234,12 +234,16 @@ class CREStereoDataset(Dataset):
 
 
 class CTDDataset(Dataset):
-    def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False):
+    def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False):
         super().__init__()
         self.rng = np.random.RandomState(0)
         self.augment = augment
         self.blur = blur
-        self.imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
+        imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
+        if test_set:
+            self.imgs = imgs[:int(split * len(imgs))]
+        else:
+            self.imgs = imgs[int(split * len(imgs)):]
         self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
 
         if resize_pattern and self.pattern.shape != (480, 640, 3):
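Slice semantics of the new `split`/`test_set` logic (illustration): with the default `split=0.9`, `test_set=True` selects the *first* 90% of images and training keeps the last 10%, which is inverted relative to the conventional 90/10 train/test split and may be worth double-checking:

```python
imgs = list(range(100))
split = 0.9
test = imgs[:int(split * len(imgs))]   # 90 items
train = imgs[int(split * len(imgs)):]  # 10 items
```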
@@ -271,6 +275,10 @@ class CTDDataset(Dataset):
 
         # read img, disp
         left_img = np.load(left_path)
+
+        if left_img.dtype == 'float32':
+            left_img = (left_img * 255).astype('uint8')
+
         left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
 
         right_img = self.pattern
@@ -307,3 +315,100 @@ class CTDDataset(Dataset):
 
     def __len__(self):
         return len(self.imgs)
+
+
+class BlenderDataset(CTDDataset):
+    def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=False):
+        super().__init__(root, pattern_path)
+        self.use_lightning = use_lightning
+        imgs = [f for f in glob.glob(f"{root}/im_*.png", recursive=True) if not 'depth0001' in f]
+        if test_set:
+            self.imgs = imgs[:int(split * len(imgs))]
+        else:
+            self.imgs = imgs[int(split * len(imgs)):]
+        self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
+
+        if resize_pattern and self.pattern.shape != (480, 640, 3):
+            downsampled = cv2.pyrDown(self.pattern)
+            diff = (downsampled.shape[0] - 480) // 2
+            self.pattern = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+
+        self.augmentor = Augmentor(
+            image_height=480,
+            image_width=640,
+            max_disp=256,
+            scale_min=0.6,
+            scale_max=1.0,
+            seed=0,
+        )
+
+    def __getitem__(self, index):
+        # find path
+        left_path = self.imgs[index]
+        left_disp_path = left_path.split('.')[0] + '_depth0001.png'
+
+        # read img, disp
+        left_img = cv2.imread(left_path)
+
+        if left_img.dtype == 'float32':
+            left_img = (left_img * 255).astype('uint8')
+
+        if left_img.shape != (480, 640, 3):
+            downsampled = cv2.pyrDown(left_img)
+            diff = (downsampled.shape[0] - 480) // 2
+            left_img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+            if left_img.shape[-1] != 3:
+                left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
+
+        right_img = self.pattern
+        left_disp = self.get_disp(left_disp_path)
+
+        if False:  # self.rng.binomial(1, 0.5):
+            left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
+            left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp)
+        left_disp[left_disp == np.inf] = 0
+
+        if self.blur:
+            kernel_size = random.sample([1,3,5,7,9], 1)[0]
+            kernel = (kernel_size, kernel_size)
+            left_img = cv2.GaussianBlur(left_img, kernel, 0)
+
+        # augmentation
+        if not self.augment:
+            _left_img, _right_img, _left_disp, disp_mask = self.augmentor(
+                left_img, right_img, left_disp
+            )
+        else:
+            left_img, right_img, left_disp, disp_mask = self.augmentor(
+                left_img, right_img, left_disp
+            )
+
+        if not self.use_lightning:
+            right_img = right_img.transpose((2, 0, 1)).astype("uint8")
+            return {
+                "left": left_img,
+                "right": right_img,
+                "disparity": left_disp,
+                "mask": disp_mask,
+            }
+
+        right_img = right_img.transpose((2, 0, 1)).astype("uint8")
+        left_img = left_img.transpose((2, 0, 1)).astype("uint8")
+        return left_img, right_img, left_disp, disp_mask
+
+    def get_disp(self, path):
+        baseline = 0.075  # meters
+        fl = 560.  # as per CTD
+        depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        downsampled = cv2.pyrDown(depth)
+        diff = (downsampled.shape[0] - 480) // 2
+        depth = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+        # disp = np.load(path).transpose(1,2,0)
+        # disp = baseline * fl / depth
+        # return disp.astype(np.float32) / 32
+        # FIXME temporarily increase disparity until new data with better depth values is generated
+        # higher values seem to speedup convergence, but introduce much stronger artifacting
+        # mystery_factor = 150
+        mystery_factor = 1
+        disp = (baseline * fl * mystery_factor) / depth
+        return disp.astype(np.float32)
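Worked example of the depth-to-disparity conversion in `get_disp` (with `mystery_factor = 1`): for the stated baseline and focal length, a point at 1 m depth maps to 42 px of disparity.

```python
baseline = 0.075  # meters, as in get_disp
fl = 560.         # focal length in pixels, as per CTD
depth = 1.0       # meters, illustrative value

disp = baseline * fl / depth
assert abs(disp - 42.0) < 1e-9  # pixels
```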
@@ -86,8 +86,15 @@ class LocalFeatureTransformer(nn.Module):
         """
         assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
 
-        for layer, name in zip(self.layers, self.layer_names):
+        # NOTE Workaround for non statically determinable zip
+        # for layer, name in zip(self.layers, self.layer_names):
+        # layer_zip = ((layer, self.layer_names[i]) for i, layer in enumerate(self.layers))
+        # layer_zip = []
+        # for i, layer in enumerate(self.layers):
+        #     layer_zip.append((layer, self.layer_names[i]))
+
+        for i, layer in enumerate(self.layers):
+            name = self.layer_names[i]
             if name == 'self':
                 feat0 = layer(feat0, feat0, mask0, mask0)
                 feat1 = layer(feat1, feat1, mask1, mask1)
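Minimal reproduction of the TorchScript limitation behind the NOTE above: `zip()` over a ModuleList and a Python list is not statically typeable, while index-based iteration scripts cleanly (a sketch, assuming a recent PyTorch):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])
        self.layer_names = ['self', 'self']

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # zip(self.layers, self.layer_names) fails to script;
        # enumerate over the ModuleList is statically analyzable.
        for i, layer in enumerate(self.layers):
            name = self.layer_names[i]
            if name == 'self':
                x = layer(x)
        return x

scripted = torch.jit.script(Toy())  # compiles without the zip error
```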
@@ -97,4 +104,4 @@ class LocalFeatureTransformer(nn.Module):
             else:
                 raise KeyError
 
         return feat0, feat1
@@ -36,6 +36,12 @@ class CREStereo(nn.Module):
         self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout)
         self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
 
+        # # NOTE Position_encoding as workaround for TensorRt
+        # image1_shape = [1, 2, 480, 640]
+        # self.pos_encoding_fn_small = PositionEncodingSine(
+        #     d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
+        # )
+
         # loftr
         self.self_att_fn = LocalFeatureTransformer(
             d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear"
@@ -81,7 +87,7 @@ class CREStereo(nn.Module):
         zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
         return zero_flow
 
-    def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False):
+    def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
         """ Estimate optical flow between pair of frames """
 
         image1 = 2 * (image1 / 255.0) - 1.0
@@ -130,17 +136,22 @@ class CREStereo(nn.Module):
         inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
 
         # positional encoding and self-attention
-        pos_encoding_fn_small = PositionEncodingSine(
-            d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
-        )
+        # pos_encoding_fn_small = PositionEncodingSine(
+        #     d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
+        # )
         # 'n c h w -> n (h w) c'
-        x_tmp = pos_encoding_fn_small(fmap1_dw16)
+        x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
         fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
         # 'n c h w -> n (h w) c'
-        x_tmp = pos_encoding_fn_small(fmap2_dw16)
+        x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
         fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
 
-        fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
+        # FIXME experimental ! no self-attention for pattern
+        if not self_attend_right:
+            fmap1_dw16, _ = self.self_att_fn(fmap1_dw16, fmap2_dw16)
+        else:
+            fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
+
         fmap1_dw16, fmap2_dw16 = [
             x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2)
             for x in [fmap1_dw16, fmap2_dw16]
@@ -258,3 +269,4 @@ class CREStereo(nn.Module):
             return flow_up
 
         return predictions
+
@@ -1,6 +1,8 @@
+from typing import List
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 
 # Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
 class ResidualBlock(nn.Module):
@@ -96,28 +98,43 @@ class BasicEncoder(nn.Module):
         self.in_planes = dim
         return nn.Sequential(*layers)
 
-    def forward(self, x):
+    def forward(self, x: List[Tensor]):
+        # NOTE always assume list, otherwise TensorRT is sad
+        # batch_dim = x[0].shape[0]
+        # x_tensor = torch.cat(list(x), dim=0)
+
         # if input is list, combine batch dimension
         is_list = isinstance(x, tuple) or isinstance(x, list)
         if is_list:
             batch_dim = x[0].shape[0]
-            x = torch.cat(x, dim=0)
+            x_tensor = torch.cat(x, dim=0)
+        else:
+            x_tensor = x
 
-        x = self.conv1(x)
-        x = self.norm1(x)
-        x = self.relu1(x)
+        print()
+        print()
+        print(x_tensor.shape)
+        print()
+        print()
 
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
+        x_tensor = self.conv1(x_tensor)
+        x_tensor = self.norm1(x_tensor)
+        x_tensor = self.relu1(x_tensor)
 
-        x = self.conv2(x)
+        x_tensor = self.layer1(x_tensor)
+        x_tensor = self.layer2(x_tensor)
+        x_tensor = self.layer3(x_tensor)
+
+        x_tensor = self.conv2(x_tensor)
 
         if self.dropout is not None:
-            x = self.dropout(x)
+            x_tensor = self.dropout(x_tensor)
 
         if is_list:
-            x = torch.split(x, x.shape[0]//2, dim=0)
+            x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
+            return x_list
 
-        return x
+        x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
+        return x_list
+
+        # return list(x)
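The now-unconditional `torch.split` at the end of the refactored `forward` always halves the batch, matching the two images that were concatenated along the batch dimension on entry (illustration):

```python
import torch

x = torch.randn(2, 3, 8, 8)                      # two inputs stacked on dim 0
halves = torch.split(x, x.shape[0] // 2, dim=0)  # split back into two tensors
assert len(halves) == 2
assert halves[0].shape == (1, 3, 8, 8)
```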
@@ -77,7 +77,7 @@ class BasicUpdateBlock(nn.Module):
             nn.ReLU(inplace=True),
             nn.Conv2d(256, mask_size**2 *9, 1, padding=0))
 
-    def forward(self, net, inp, corr, flow, upsample=True):
+    def forward(self, net, inp, corr, flow, upsample: bool=True):
         # print(inp.shape, corr.shape, flow.shape)
         motion_features = self.encoder(flow, corr)
         # print(motion_features.shape, inp.shape)
test_model.py (109 lines changed)
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 import cv2
+import os
 
 from nets import Model
 
@@ -16,17 +17,20 @@ device = 'cuda'
 wandb.init(project="crestereo", entity="cpt-captain")
 
 
-def do_infer(left_img, right_img, gt_disp, model):
-    disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False)
+def do_infer(left_img, right_img, gt_disp, model, attend_pattern=True):
+    disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False, attend_pattern=attend_pattern)
 
     disp_vis = normalize_and_colormap(disp)
-    gt_disp_vis = normalize_and_colormap(gt_disp)
-    if gt_disp.shape != disp.shape:
-        gt_disp = gt_disp.reshape(disp.shape)
-    disp_err = gt_disp - disp
-    disp_err = normalize_and_colormap(disp_err.abs())
+    # gt_disp_vis = normalize_and_colormap(gt_disp)
+    # if gt_disp.shape != disp.shape:
+    #     gt_disp = gt_disp.reshape(disp.shape)
+    # disp_err = gt_disp - disp
+    # disp_err = normalize_and_colormap(disp_err.abs())
+    if isinstance(left_img, torch.Tensor):
+        left_img = left_img.cpu().detach().numpy().astype('uint8')
+        right_img = right_img.cpu().detach().numpy().astype('uint8')
 
-    wandb.log({
+    results = {
         'disp': wandb.Image(
             disp,
             caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
@@ -35,41 +39,60 @@ def do_infer(left_img, right_img, gt_disp, model):
             disp_vis,
             caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
         ),
-        'gt_disp_vis': wandb.Image(
-            gt_disp_vis,
-            caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
-        ),
-        'disp_err': wandb.Image(
-            disp_err,
-            caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
-        ),
+        # 'disp_err': wandb.Image(
+        #     disp_err,
+        #     caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
+        # ),
         'input_left': wandb.Image(
-            left_img.cpu().detach().numpy().astype('uint8'),
+            left_img,
             caption=f"Input left",
         ),
         'input_right': wandb.Image(
-            right_img.cpu().detach().numpy().astype('uint8'),
+            right_img,
             caption=f"Input right",
         ),
-    })
+    }
+
+    if gt_disp is not None:
+        print('logging gt')
+        print(f'gt: {gt_disp.max()}/{gt_disp.min()}/{gt_disp.mean()}')
+        gt_disp_vis = normalize_and_colormap(gt_disp)
+        results.update({
+            'gt_disp_vis': wandb.Image(
+                gt_disp_vis,
+                caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
+            )})
+    wandb.log(results)
+
+
+def downsample(img, half_height_out=480):
+    downsampled = cv2.pyrDown(img)
+    diff = (downsampled.shape[0] - half_height_out) // 2
+    return downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
+
+
 if __name__ == '__main__':
     # model_path = "models/crestereo_eth3d.pth"
     model_path = "train_log/models/latest.pth"
+    # model_path = "train_log/models/epoch-120.pth"
+    # model_path = "train_log/models/epoch-250.pth"
+
+    print(model_path)
     # reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
-    reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
     # reference_pattern_path = '/home/nils/new_reference.png'
     # reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
+    reference_pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png'
+    # reference_pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
+    # reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
     # reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
 
-    data_type = 'kinect'
+    # data_type = 'kinect'
+    data_type = 'blender'
     augment = False
 
     args = parse_yaml("cfgs/train.yaml")
 
-    wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment})
+    wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment, 'data_type': data_type, 'pattern_self_attention': args.pattern_attention})
 
     model = Model(max_disp=256, mixed_precision=False, test_mode=True)
     model = nn.DataParallel(model, device_ids=[device])
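Shape walk-through of the new `downsample` helper: `cv2.pyrDown` halves each dimension, so a 1024x1280 Kinect IR frame becomes 512x640, and cropping (512 - 480) // 2 = 16 rows from top and bottom yields 480x640:

```python
import numpy as np
import cv2

frame = np.zeros((1024, 1280), dtype=np.uint8)  # full-res IR frame
half = cv2.pyrDown(frame)                       # -> (512, 640)
diff = (half.shape[0] - 480) // 2               # 16
out = half[diff:half.shape[0] - diff, :]
assert out.shape == (480, 640)
```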
@@ -78,16 +101,32 @@ if __name__ == '__main__':
     model.to(device)
     model.eval()
 
-    dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
-                         pattern_path=reference_pattern_path, augment=augment)
-    dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
-                            num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
-    for batch in dataloader:
-        for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']):
-            right = right.transpose(0, 2).transpose(0, 1)
-            left_img = left
-            imgL = left.cpu().detach().numpy()
-            right_img = right
-            imgR = right.cpu().detach().numpy()
-            gt_disp = disparity
-            do_infer(left_img, right_img, gt_disp, model)
+    # dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
+    #                      pattern_path=reference_pattern_path, augment=augment)
+    # dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
+    #                         num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
+
+    # for batch in dataloader:
+    gt_disp = None
+    right = downsample(cv2.imread(reference_pattern_path))
+
+    if data_type == 'blender':
+        img_path = '/media/Data1/connecting_the_dots_data/blender_renders/data/'
+    elif data_type == 'kinect':
+        img_path = '/home/nils/kinect_pngs/ir/'
+
+    for img in sorted(list(entry for entry in os.scandir(img_path) if 'depth' not in entry.name), key=lambda x:x.name)[:25]:
+        print(img.path)
+        if data_type == 'blender':
+            baseline = 0.075  # meters
+            fl = 560.  # as per CTD
+
+            gt_path = img.path.rsplit('.')[0] + '_depth0001.png'
+            gt_depth = downsample(cv2.imread(gt_path))
+
+            mystery_factor = 35  # we don't get reasonable disparities due to incorrect depth scaling (or something like that)
+            gt_disp = (baseline * fl * mystery_factor) / gt_depth
+
+        left = downsample(cv2.imread(img.path))
+
+        do_infer(left, right, gt_disp, model, attend_pattern=args.pattern_attention)
train.py (191 lines changed)
@@ -9,7 +9,7 @@ import yaml
 
 from nets import Model
 # from dataset import CREStereoDataset
-from dataset import CREStereoDataset, CTDDataset
+from dataset import BlenderDataset, CREStereoDataset, CTDDataset
 
 import torch
 import torch.nn as nn
@@ -32,14 +32,18 @@ def normalize_and_colormap(img):
     return ret
 
 
-def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True):
+def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True):
 
     print("Model Forwarding...")
-    left = left.cpu().detach().numpy()
+    if isinstance(left, torch.Tensor):
+        left = left.cpu().detach().numpy()
+        imgR = right.cpu().detach().numpy()
     imgL = left
-    imgR = right.cpu().detach().numpy()
+    imgR = right
     imgL = np.ascontiguousarray(imgL[None, :, :, :])
     imgR = np.ascontiguousarray(imgR[None, :, :, :])
 
+    flow_init = None
+
     # chosen for convenience
     device = torch.device('cuda:0')
@@ -55,19 +59,54 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True):
         size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
         mode="bilinear",
         align_corners=True,
-    )
+    ).clamp(min=0, max=255)
     imgR_dw2 = F.interpolate(
         imgR,
         size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
         mode="bilinear",
         align_corners=True,
-    )
+    ).clamp(min=0, max=255)
+    if last_img is not None:
+        print('using flow_initialization')
+        print(last_img.shape)
+        # FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help
+        print(last_img.max(), last_img.min())
+        if last_img.min() < 0:
+            # print('Negative disparity detected. shifting...')
+            last_img = last_img - last_img.min()
+        if last_img.max() > 255:
+            # print('Excessive disparity detected. scaling...')
+            last_img = last_img / (last_img.max() / 255)
+
+        last_img = np.dstack([last_img, last_img])
+        # last_img = np.dstack([last_img, last_img, last_img])
+        last_img = np.dstack([last_img])
+        last_img = last_img.reshape((1, 2, 480, 640))
+        # print(last_img.shape)
+        # print(last_img.dtype)
+        # print(last_img.max(), last_img.min())
+        flow_init = torch.tensor(last_img.astype("float32")).to(device)
+        # flow_init = F.interpolate(
+        #     last_img,
+        #     size=(last_img.shape[0] // 2, last_img.shape[1] // 2),
+        #     mode="bilinear",
+        #     align_corners=True,
+        # )
     with torch.inference_mode():
-        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
-        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
+        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern)
+        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern)
+    pf_base = pred_flow
+    if isinstance(pf_base, list):
+        pf_base = pred_flow[0]
+    pf = torch.squeeze(pf_base[:, 0, :, :]).cpu().detach().numpy()
+    print('pred_flow max min')
+    print(pf.max(), pf.min())
+
     if not wandb_log:
+        if test:
+            return pred_flow
         return torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
 
     log = {}
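The flow-initialization block above squeezes the previous disparity into [0, 255] and duplicates it into the two flow channels the model expects. A sketch of that preprocessing as a standalone function; note the diff builds the (1, 2, 480, 640) tensor with a direct `reshape` of the (H, W, 2) stack, whereas this sketch uses a transpose, which keeps the channel layout explicit:

```python
import numpy as np
import torch

def disparity_to_flow_init(disp: np.ndarray) -> torch.Tensor:
    # shift/scale into 8-bit range, as in the diff
    if disp.min() < 0:
        disp = disp - disp.min()
    if disp.max() > 255:
        disp = disp / (disp.max() / 255)
    flow = np.dstack([disp, disp])        # (H, W, 2)
    flow = flow.transpose(2, 0, 1)[None]  # (1, 2, H, W)
    return torch.tensor(flow.astype("float32"))
```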
@@ -96,30 +135,36 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True):
             np.array([pred_disp.reshape(480, 640)]),
             caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
         )
-        log[f'pred_norm_{i}'] = wandb.Image(
-            np.array([pred_disp_norm.reshape(480, 640)]),
-            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
-        )
+        # log[f'pred_norm_{i}'] = wandb.Image(
+        #     np.array([pred_disp_norm.reshape(480, 640)]),
+        #     caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
+        # )
 
-        log[f'pred_dw2_{i}'] = wandb.Image(
-            np.array([pred_disp_dw2.reshape(240, 320)]),
-            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
-        )
-        log[f'pred_dw2_norm_{i}'] = wandb.Image(
-            np.array([pred_disp_dw2_norm.reshape(240, 320)]),
-            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
-        )
+        # log[f'pred_dw2_{i}'] = wandb.Image(
+        #     np.array([pred_disp_dw2.reshape(240, 320)]),
+        #     caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
+        # )
+        # log[f'pred_dw2_norm_{i}'] = wandb.Image(
+        #     np.array([pred_disp_dw2_norm.reshape(240, 320)]),
+        #     caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
+        # )
 
     log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
-    log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right")
+    input_right = right.cpu().detach().numpy() if isinstance(right, torch.Tensor) else right
+    if input_right.shape != (480, 640, 3):
+        input_right.transpose(1,2,0)
+    log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right")
+
     log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
 
+    gt_disp = gt_disp.cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp
+    disp = disp.cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp
+
     disp_error = gt_disp - disp
     log['disp_error'] = wandb.Image(
-        normalize_and_colormap(disp_error.abs()),
+        normalize_and_colormap(abs(disp_error)),
-        caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.abs().mean():.{2}f}",
+        caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}",
     )
@@ -129,6 +174,7 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True):
     )
 
     wandb.log(log)
+    return pred_flow
 
 
 def parse_yaml(file_path: str) -> namedtuple:
@@ -172,12 +218,25 @@ def adjust_learning_rate(optimizer, epoch):
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
 
-def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
+def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
     '''
     valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
     flow_preds[0]: (B, 2, H, W)
     flow_gt: (B, 2, H, W)
     '''
+    if test:
+        # print('sequence loss')
+        if valid.shape != (2, 480, 640):
+            valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2)
+        # print(valid.shape)
+        #valid = torch.stack([valid, valid])
+        # print(valid.shape)
+        if valid.shape != (2, 480, 640):
+            valid = valid.transpose(0,1)
+        # print(valid.shape)
+        # print(valid.shape)
+        # print(flow_preds[0].shape)
+        # print(flow_gt.shape)
     n_predictions = len(flow_preds)
     flow_loss = 0.0
 
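For reference, the RAFT-style weighting that `sequence_loss` applies (the body below the shown lines is elided in the diff; this sketch assumes it matches the upstream CREStereo formulation): later refinement iterations receive exponentially larger weights gamma ** (n - 1 - i).

```python
import torch

def sequence_loss_sketch(flow_preds, flow_gt, valid, gamma=0.8):
    n = len(flow_preds)
    loss = torch.tensor(0.0)
    for i, pred in enumerate(flow_preds):
        weight = gamma ** (n - 1 - i)       # emphasize later iterations
        i_loss = (pred - flow_gt).abs()     # (B, 2, H, W)
        loss = loss + weight * (valid.unsqueeze(1) * i_loss).mean()
    return loss
```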
@@ -260,20 +319,41 @@ def main(args):
     start_iters = 0
 
     # pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
-    pattern_path = '/home/nils/kinect_reference_cropped.png'
+    # pattern_path = '/home/nils/kinect_reference_cropped.png'
+    # pattern_path = '/home/nils/kinect_reference_far.png'
+    # pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
+    pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png'
     # pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
     # datasets
     # dataset = CREStereoDataset(args.training_data_path)
-    dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
+    # dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
+    # test_dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path, test_set=True)
+    if args.dataset == 'blender':
+        pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
+        dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path)
+        test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True)
+    elif args.dataset == 'ctd':
+        dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
+        test_dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path, test_set=True)
+    else:
+        print('unrecognized dataset')
+        quit()
+
+    test_data_iter = iter(test_dataset)
     # if rank == 0:
     worklog.info(f"Dataset size: {len(dataset)}")
+    print(args.batch_size)
     dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
-                            num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
+                            num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
+    #                       num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
+    test_dataloader = DataLoader(test_dataset, args.batch_size, shuffle=False,
+                                 num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
 
     # counter
     cur_iters = start_iters
     total_iters = args.minibatch_per_epoch * args.n_total_epoch
     t0 = time.perf_counter()
+    test_idx = 0
     for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1):
 
         # adjust learning rate
@@ -310,24 +390,59 @@ def main(args):
             # forward
             # left = left.transpose(1, 2).transpose(2, 3)
             left = left.transpose(1, 3).transpose(2, 3)
-            right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
-            flow_predictions = model(left.cuda(), right.cuda())
+            # right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
+            flow_predictions = model(left.cuda(), right.cuda(), self_attend_right=args.pattern_attention)
 
             # loss & backword
             loss = sequence_loss(
                 flow_predictions, gt_flow, valid_mask, gamma=0.8
             )
 
-            if batch_idx % 128 == 0:
-                inference(
-                    mini_batch_data['left'][0],
-                    mini_batch_data['right'][0],
-                    mini_batch_data['disparity'][0],
-                    mini_batch_data['mask'][0],
-                    model,
-                    batch_idx,
-                )
+            if batch_idx % 512 == 0:
+                test_idx = 0
+                test_loss = 0
+                for i, test_batch in enumerate(test_dataset):
+                    # test_batch = next(test_data_iter)
+                    if i >= 24:
+                        break
+
+                    # TODO refactor, DRY
+                    left, right, gt_disp, valid_mask = (
+                        test_batch['left'],
+                        test_batch['right'],
+                        torch.Tensor(test_batch['disparity']).cuda(),
+                        torch.Tensor(test_batch['mask']).cuda(),
+                    )
+                    gt_disp = torch.dstack([gt_disp, gt_disp]).transpose(2,0).transpose(1,2)
+                    gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
+                    gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
+                    # print(f'left {left.shape}, right {right.shape}')
+                    # left = left.transpose([2, 0, 1])
+                    right = right.transpose([1, 2, 0])
+                    # right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
+                    # print(f'left {left.shape}, right {right.shape}')
+
+                    model.eval()
+                    flow_predictions = inference(
+                        left,
+                        right,
+                        # gt_disp,
+                        torch.Tensor(test_batch['disparity']).cuda(),
+                        valid_mask,
+                        model,
+                        test_idx,
+                        wandb_log=i % 4 == 0,
+                        test=True,
+                    )
+                    test_idx += 1
+                    test_loss += sequence_loss(
+                        flow_predictions, gt_flow, valid_mask, gamma=0.8, test=True
+                    ).data.item()
+                model.train()
+
+                avg_test_loss = test_loss / test_idx
+                print(f'test_loss: {test_loss}\nlen test: {test_idx}\navg. loss: {avg_test_loss}')
+                metrics['test/loss'] = avg_test_loss
             # loss stats
             loss_item = loss.data.item()
             epoch_total_train_loss += loss_item
train_lightning.py (new file, 346 lines)
@ -0,0 +1,346 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
# from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
|
from nets import Model
|
||||||
|
# from dataset import CREStereoDataset
|
||||||
|
from dataset import BlenderDataset, CREStereoDataset, CTDDataset
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from pytorch_lightning.lite import LightningLite
|
||||||
|
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
|
||||||
|
from pytorch_lightning import Trainer, seed_everything
|
||||||
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
|
seed_everything(42, workers=True)
|
||||||
|
|
||||||
|
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_and_colormap(img):
|
||||||
|
ret = (img - img.min()) / (img.max() - img.min()) * 255.0
|
||||||
|
if isinstance(ret, torch.Tensor):
|
||||||
|
ret = ret.cpu().detach().numpy()
|
||||||
|
ret = ret.astype("uint8")
|
||||||
|
ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True):
|
||||||
|
|
||||||
|
print("Model Forwarding...")
|
||||||
|
if isinstance(left, torch.Tensor):
|
||||||
|
left = left# .cpu().detach().numpy()
|
||||||
|
imgR = right# .cpu().detach().numpy()
|
||||||
|
imgL = left
|
||||||
|
imgR = right
|
||||||
|
imgL = np.ascontiguousarray(imgL[None, :, :, :])
|
||||||
|
imgR = np.ascontiguousarray(imgR[None, :, :, :])
|
||||||
|
|
||||||
|
flow_init = None
|
||||||
|
|
||||||
|
# chosen for convenience
|
||||||
|
|
||||||
|
imgL = torch.tensor(imgL.astype("float32"))
|
||||||
|
imgR = torch.tensor(imgR.astype("float32"))
|
||||||
|
imgL = imgL.transpose(2,3).transpose(1,2)
|
||||||
|
if imgL.shape != imgR.shape:
|
||||||
|
imgR = imgR.transpose(2,3).transpose(1,2)
|
||||||
|
|
||||||
|
imgL_dw2 = F.interpolate(
|
||||||
|
imgL,
|
||||||
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
).clamp(min=0, max=255)
|
||||||
|
imgR_dw2 = F.interpolate(
|
||||||
|
imgR,
|
||||||
|
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
).clamp(min=0, max=255)
|
||||||
|
if last_img is not None:
|
||||||
|
print('using flow_initialization')
|
||||||
|
print(last_img.shape)
|
||||||
|
# FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help
|
||||||
|
print(last_img.max(), last_img.min())
|
||||||
|
if last_img.min() < 0:
|
||||||
|
# print('Negative disparity detected. shifting...')
|
||||||
|
last_img = last_img - last_img.min()
|
||||||
|
if last_img.max() > 255:
|
||||||
|
# print('Excessive disparity detected. scaling...')
|
||||||
|
last_img = last_img / (last_img.max() / 255)
|
||||||
|
|
||||||
|
|
||||||
|
last_img = np.dstack([last_img, last_img])
|
||||||
|
# last_img = np.dstack([last_img, last_img, last_img])
|
||||||
|
last_img = np.dstack([last_img])
|
||||||
|
last_img = last_img.reshape((1, 2, 480, 640))
|
||||||
|
# print(last_img.shape)
|
||||||
|
# print(last_img.dtype)
|
||||||
|
# print(last_img.max(), last_img.min())
|
||||||
|
flow_init = torch.tensor(last_img.astype("float32"))
|
||||||
|
# flow_init = F.interpolate(
|
||||||
|
# last_img,
|
||||||
|
# size=(last_img.shape[0] // 2, last_img.shape[1] // 2),
|
||||||
|
# mode="bilinear",
|
||||||
|
# align_corners=True,
|
||||||
|
# )
|
||||||
|
with torch.inference_mode():
|
||||||
|
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern)
|
||||||
|
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern)
|
||||||
|
pf_base = pred_flow
|
||||||
|
if isinstance(pf_base, list):
|
||||||
|
pf_base = pred_flow[0]
|
||||||
|
pf = torch.squeeze(pf_base[:, 0, :, :])# .cpu().detach().numpy()
|
||||||
|
print('pred_flow max min')
|
||||||
|
print(pf.max(), pf.min())
|
||||||
|
|
||||||
|
|
||||||
|
if not wandb_log:
|
||||||
|
if test:
|
||||||
|
return pred_flow
|
||||||
|
return torch.squeeze(pred_flow[:, 0, :, :])# .cpu().detach().numpy()
|
||||||
|
|
||||||
|
log = {}
|
||||||
|
in_h, in_w = left.shape[:2]
|
||||||
|
|
||||||
|
# Resize image in case the GPU memory overflows
|
||||||
|
eval_h, eval_w = (in_h,in_w)
|
||||||
|
|
||||||
|

for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
    pred_disp = torch.squeeze(pf[:, 0, :, :])
    pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :])

    # pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    # pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)

    if i == n_iter - 1:
        # rescale the final disparity map to the input resolution
        t = float(in_w) / float(eval_w)
        disp = cv2.resize(pred_disp.cpu().detach().numpy(), (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t

        log['disp_vis'] = wandb.Image(
            normalize_and_colormap(disp),
            caption=f"Disparity \n{pred_disp.min():.2f}/{pred_disp.max():.2f}",
        )

    log[f'pred_{i}'] = wandb.Image(
        np.array([pred_disp.cpu().detach().numpy().reshape(480, 640)]),
        caption=f"Pred. Disp. It {i}\n{pred_disp.min():.2f}/{pred_disp.max():.2f}",
    )
    # log[f'pred_norm_{i}'] = wandb.Image(
    #     np.array([pred_disp_norm.reshape(480, 640)]),
    #     caption=f"Pred. Disp. It {i}\n{pred_disp.min():.2f}/{pred_disp.max():.2f}",
    # )
    # log[f'pred_dw2_{i}'] = wandb.Image(
    #     np.array([pred_disp_dw2.reshape(240, 320)]),
    #     caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.2f}/{pred_disp_dw2.max():.2f}",
    # )
    # log[f'pred_dw2_norm_{i}'] = wandb.Image(
    #     np.array([pred_disp_dw2_norm.reshape(240, 320)]),
    #     caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.2f}/{pred_disp_dw2.max():.2f}",
    # )

log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
input_right = right
if input_right.shape != (480, 640, 3):
    # CHW -> HWC; transpose() is not in-place, so assign the result
    input_right = input_right.transpose(1, 2, 0)
log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right")

log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.2f}/{gt_disp.max():.2f}")

disp_error = gt_disp - disp
log['disp_error'] = wandb.Image(
    normalize_and_colormap(abs(disp_error)),
    caption=f"Disp. Error\n{disp_error.min():.2f}/{disp_error.max():.2f}\n{abs(disp_error).mean():.2f}",
)

log['gt_disp_vis'] = wandb.Image(
    normalize_and_colormap(gt_disp),
    caption=f"GT Disp Vis \n{gt_disp.min():.2f}/{gt_disp.max():.2f}",
)

wandb.log(log)
return pred_flow


def parse_yaml(file_path: str) -> namedtuple:
    """Parse a yaml configuration file and return its values as a `namedtuple`."""
    with open(file_path, "rb") as f:
        cfg: dict = yaml.safe_load(f)
    args = namedtuple("train_args", cfg.keys())(*cfg.values())
    return args
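
# Usage sketch (the keys shown are hypothetical, not the real config):
# given a file lr.yaml containing
#   base_lr: 0.0004
#   n_total_epoch: 300
# parse_yaml("lr.yaml") returns a namedtuple with
#   args.base_lr == 0.0004 and args.n_total_epoch == 300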


def format_time(elapse):
    elapse = int(elapse)
    hour = elapse // 3600
    minute = elapse % 3600 // 60
    seconds = elapse % 60
    return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)
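
# e.g. format_time(3725) -> "01:02:05" (1 hour, 2 minutes, 5 seconds)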


def ensure_dir(path):
    # exist_ok=True already tolerates existing directories, so no separate
    # existence check is needed
    os.makedirs(path, exist_ok=True)


def adjust_learning_rate(optimizer, epoch):
    # piecewise schedule (relies on the global `args`): linear warm-up,
    # constant plateau, then linear decay down to min_lr_rate * base_lr
    warm_up = 0.02
    const_range = 0.6
    min_lr_rate = 0.05

    if epoch <= args.n_total_epoch * warm_up:
        lr = (1 - min_lr_rate) * args.base_lr / (
            args.n_total_epoch * warm_up
        ) * epoch + min_lr_rate * args.base_lr
    elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
        lr = args.base_lr
    else:
        lr = (min_lr_rate - 1) * args.base_lr / (
            (1 - const_range) * args.n_total_epoch
        ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
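
# A quick sanity check of the schedule under assumed values
# (n_total_epoch=300, base_lr=4e-4 are hypothetical, not from the config):
#   epoch 0   -> 0.05 * 4e-4 = 2e-5   (start of warm-up)
#   epoch 6   -> 4e-4                 (warm-up ends at 300 * 0.02)
#   epoch 180 -> 4e-4                 (plateau ends at 300 * 0.6)
#   epoch 300 -> 0.05 * 4e-4 = 2e-5   (decayed back to the floor)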


def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
    '''
    valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
    flow_preds[0]: (B, 2, H, W)
    flow_gt: (B, 2, H, W)
    '''
    if test:
        # bring the valid mask into (B, H, W) shape
        if valid.shape != (2, 480, 640):
            valid = torch.stack([valid, valid])
        if valid.shape != (2, 480, 640):
            valid = valid.transpose(0, 1)

    n_predictions = len(flow_preds)
    flow_loss = 0.0

    # TEST
    flow_gt = torch.squeeze(flow_gt, dim=-1)

    # later refinement iterations are weighted more strongly (exponential in gamma)
    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(flow_preds[i] - flow_gt)
        flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()

    return flow_loss
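
# Weighting example: with gamma=0.8 and n_predictions=4, the per-iteration
# weights are 0.8**3, 0.8**2, 0.8**1, 0.8**0 = 0.512, 0.64, 0.8, 1.0,
# so the final refinement iteration dominates the loss.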


class CREStereoLightning(LightningModule):
    def __init__(self, args):
        super().__init__()
        self.batch_size = args.batch_size
        self.model = Model(
            max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
        )

    def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)

    def training_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        # Lightning already moves batches to the right device; kept defensively
        # for non-tensor inputs
        left = torch.Tensor(left).to(self.device)
        right = torch.Tensor(right).to(self.device)
        flow_predictions = self.forward(left, right)
        loss = sequence_loss(
            flow_predictions, gt_disp, valid_mask, gamma=0.8
        )
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        left = torch.Tensor(left).to(self.device)
        right = torch.Tensor(right).to(self.device)
        print(left.shape)
        print(right.shape)
        flow_predictions = self.forward(left, right)
        val_loss = sequence_loss(
            flow_predictions, gt_disp, valid_mask, gamma=0.8
        )
        self.log("val_loss", val_loss)

    def test_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        # alternative unpacking for dict-style batches:
        # left, right, gt_disp, valid_mask = (
        #     batch["left"], batch["right"], batch["disparity"], batch["mask"]
        # )
        left = torch.Tensor(left).to(self.device)
        right = torch.Tensor(right).to(self.device)
        flow_predictions = self.forward(left, right)
        test_loss = sequence_loss(
            flow_predictions, gt_disp, valid_mask, gamma=0.8
        )
        self.log("test_loss", test_loss)

    def configure_optimizers(self):
        # NOTE lr is hard-coded here and does not use args.base_lr yet
        return optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999))
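
    # A sketch of wiring a real schedule in instead of the fixed lr
    # (hypothetical: assumes the args namedtuple were stored on self and that
    # a cosine schedule is an acceptable stand-in for adjust_learning_rate):
    # def configure_optimizers(self):
    #     optimizer = optim.Adam(self.model.parameters(), lr=self.args.base_lr, betas=(0.9, 0.999))
    #     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.args.n_total_epoch)
    #     return {"optimizer": optimizer, "lr_scheduler": scheduler}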


if __name__ == "__main__":
    # train configuration
    args = parse_yaml("cfgs/train.yaml")
    # wandb.init(project="crestereo-lightning", entity="cpt-captain")
    # Lite(strategy='dp', accelerator='gpu', devices=2).run(args)
    pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
    model = CREStereoLightning(args)
    dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, use_lightning=True)
    test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True, use_lightning=True)
    print(len(dataset))
    print(len(test_dataset))
    wandb_logger = WandbLogger(project="crestereo-lightning")
    wandb.config.update(args._asdict())

    trainer = Trainer(
        max_epochs=args.n_total_epoch,
        accelerator='gpu',
        devices=2,
        # auto_scale_batch_size='binsearch',
        # strategy='ddp',
        deterministic=True,
        check_val_every_n_epoch=1,
        limit_val_batches=24,
        limit_test_batches=24,
        logger=wandb_logger,
        default_root_dir=args.log_dir_lightning,
    )
    # trainer.tune(model)
    # fit() expects DataLoaders, not Datasets, so wrap them here; the test
    # split doubles as validation data (batch size from the config,
    # num_workers=4 is an arbitrary choice)
    from torch.utils.data import DataLoader  # local import; move to the top-level imports
    train_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    trainer.fit(model, train_loader, val_loader)