You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
297 lines
9.5 KiB
297 lines
9.5 KiB
import os
|
|
import json
|
|
from datetime import datetime
|
|
from typing import Union, Literal
|
|
from io import BytesIO
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch_tensorrt
|
|
from cv2 import cv2
|
|
from fastapi import FastAPI, File, UploadFile
|
|
from PIL import Image
|
|
|
|
|
|
from nets import Model
|
|
|
|
from train import inference as ctd_inference
|
|
|
|
|
|
app = FastAPI()

# Candidate reference patterns; exactly one assignment is active at a time.
# reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
# reference_pattern_path = '/home/nils/kinect_reference_far.png'
reference_pattern_path = '/home/nils/mpc/kinect_downshift_rotate_left-1.png'
# reference_pattern_path = '/home/nils/kinect_diff_ref.png'
print(reference_pattern_path)
reference_pattern = cv2.imread(reference_pattern_path)
if reference_pattern is None:
    # cv2.imread reports a missing/undecodable file by returning None; fail
    # here with a clear message instead of an AttributeError on `.shape` below.
    raise FileNotFoundError(
        f'could not read reference pattern: {reference_pattern_path}'
    )

# Hook for shifting the reference pattern horizontally to simulate a backdrop
# that is further away. The matrix below is the identity (zero translation),
# so the warp currently leaves the pattern unchanged; set trans_mat[0][2] to
# the desired pixel offset to enable the shift.
trans_mat = np.float32([[1, 0, 0], [0, 1, 0]])
reference_pattern = cv2.warpAffine(
    reference_pattern, trans_mat, reference_pattern.shape[1::-1], flags=cv2.INTER_LINEAR
)

# Runtime parameters; the /params/* endpoints overwrite these globals.
iters = 20
minimal_data = True
temporal_init = False
# Previous prediction, reused as initialisation when temporal_init is enabled.
last_img = None
device = torch.device('cuda:0')
|
|
|
|
|
|
def downsize(img):
    """Halve ``img`` with ``cv2.pyrDown`` and trim a fixed margin of rows.

    The margin is ``(512 - 480) // 2`` rows, removed from both the top and the
    bottom of the downsampled image; columns are kept in full.

    Args:
        img: Image array accepted by ``cv2.pyrDown``.

    Returns:
        The downsampled, row-cropped image.
    """
    margin = (512 - 480) // 2
    halved = cv2.pyrDown(img)
    rows, cols = halved.shape[0], halved.shape[1]
    return halved[margin:rows - margin, 0:cols]
|
|
|
|
|
|
# Downsize the startup reference pattern when any of its dimensions is 1024.
if any(side == 1024 for side in reference_pattern.shape):
    reference_pattern = downsize(reference_pattern)
|
|
|
|
|
|
def ghetto_lcn(img, eps=1e-6):
    """Apply a simple local contrast normalisation to ``img``.

    The image is scaled to ``[0, 1]``, a Gaussian-blurred copy (sigma 2) is
    subtracted, and the result is divided by a local standard-deviation
    estimate (square root of the sigma-20 blur of the squared residual).
    The output is finally min-max normalised to ``[0, 255]``.

    Args:
        img: Input image array; it is converted to ``float32`` internally.
        eps: Lower bound applied to the local standard deviation so that
            perfectly flat regions do not cause a division by zero (which
            would put NaN/inf into the min-max normalisation).

    Returns:
        A ``float32`` array with values normalised to ``[0, 255]``.
    """
    float_gray = img.astype(np.float32) / 255.0

    # Residual after removing the local mean.
    centered = float_gray - cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2)

    # Local standard deviation of the residual over a much wider window.
    local_var = cv2.GaussianBlur(centered * centered, (0, 0), sigmaX=20, sigmaY=20)
    local_std = np.maximum(cv2.pow(local_var, 0.5), eps)

    normalized = centered / local_std

    cv2.normalize(normalized, dst=normalized, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX)
    return normalized
|
|
|
|
|
|
# reference_pattern = ghetto_lcn(reference_pattern)
|
|
|
|
|
|
def _tensorrt_compile_spec():
    """Build the ``{"forward": TensorRTCompileSpec}`` mapping for backend lowering."""

    def _image_input():
        # Both network inputs share the same dynamic-shape specification.
        return torch_tensorrt.Input(
            min_shape=[1, 2, 240, 320],
            max_shape=[1, 2, 480, 640],
            opt_shape=[1, 2, 480, 640],
            dtype=torch.int32,
        )

    spec_dict = {
        "inputs": [_image_input(), _image_input()],
        "enabled_precisions": {torch.float, torch.half},
        "refit": False,
        "debug": False,
        "device": {
            "device_type": torch_tensorrt.DeviceType.GPU,
            "gpu_id": 0,
            "dla_core": 0,
            "allow_gpu_fallback": True
        },
        "capability": torch_tensorrt.EngineCapability.default,
        "num_min_timing_iters": 2,
        "num_avg_timing_iters": 1,
    }
    return {"forward": torch_tensorrt.ts.TensorRTCompileSpec(**spec_dict)}


def load_model(epoch, use_tensorrt=False):
    """Load a checkpoint from ``train_log/models`` and publish it as the global ``model``.

    Args:
        epoch: Epoch number, or ``'latest'`` to load ``latest.pth``; any other
            value selects ``epoch-{epoch}.pth``.
        use_tensorrt: When true, an unwrapped copy of the network is scripted
            with TorchScript and lowered through the ``"tensorrt"`` backend
            (work in progress).

    Returns:
        The loaded model, which is also stored in the module-level ``model``.
    """
    global model
    checkpoint_name = 'latest.pth' if epoch == 'latest' else f'epoch-{epoch}.pth'
    model_path = f"train_log/models/{checkpoint_name}"

    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
    # FIXME WIP Workaround Dataparallel TensorRT incompatibility
    model = nn.DataParallel(model, device_ids=[device])
    # map_location places the checkpoint tensors on the serving device no
    # matter which device they were saved from.
    state_dict = torch.load(model_path, map_location=device)['state_dict']
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()

    if use_tensorrt:
        # TensorRT lowering works on a plain module, so rebuild the network
        # without the DataParallel wrapper and copy the weights across.
        np_model = Model(max_disp=256, mixed_precision=False, test_mode=True)
        np_model.load_state_dict(model.module.state_dict(), strict=True)
        np_model.to(device)
        np_model.eval()

        script_model = torch.jit.script(np_model.eval())

        # Debug output. The previous `print(script_model.forward())` invoked
        # the network without any inputs and aborted this path before the
        # backend lowering below could run, so it has been removed.
        print(script_model)
        print(script_model.forward)
        print(dir(script_model))

        model = torch._C._jit_to_backend("tensorrt", script_model, _tensorrt_compile_spec())

    print(f'loaded model {checkpoint_name}')
    return model
|
|
|
|
|
|
# Load the 'latest' checkpoint at import time; load_model() also stores the result in the global `model`.
model = load_model('latest')
|
|
|
|
|
|
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy arrays and NumPy scalar values."""

    def default(self, obj):
        """Convert NumPy objects to plain Python values.

        Arrays become nested lists; NumPy scalars (e.g. ``np.float32``,
        ``np.int64``, ``np.bool_``) become the matching Python scalar.
        Anything else is delegated to ``json.JSONEncoder.default``, which
        raises ``TypeError``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            value = obj.item()
            # Some dtypes (e.g. longdouble) have no Python equivalent and
            # item() hands back a NumPy scalar again; do not loop on those.
            if not isinstance(value, np.generic):
                return value
        return super().default(obj)
|
|
|
|
|
|
def inference(left, right, model, n_iter=20):
    """Run a coarse-to-fine forward pass of ``model`` on CUDA.

    The model is first evaluated on half-resolution copies of both images;
    that prediction is then passed as ``flow_init`` to a second evaluation at
    full resolution.

    Args:
        left: First image as a 3-D NumPy array.
        right: Second image as a 3-D NumPy array.
        model: Object whose ``forward`` accepts two image tensors plus
            ``iters`` and ``flow_init`` keyword arguments.
        n_iter: Value forwarded as ``iters`` to both passes.

    Returns:
        Channel 0 of the full-resolution prediction, squeezed, as a NumPy array.
    """
    print("Model Forwarding...")
    cuda = torch.device('cuda')

    def _prepare(image):
        # Add a batch axis, move to the GPU as float32, then swap axes 1 and 2.
        # NOTE(review): confirm the array layout callers are expected to pass.
        batched = np.ascontiguousarray(image[None, :, :, :])
        return torch.tensor(batched.astype("float32")).to(cuda).transpose(1, 2)

    imgL = _prepare(left)
    imgR = _prepare(right)

    # Both half-resolution copies use the size derived from the first image.
    half_size = (imgL.shape[2] // 2, imgL.shape[3] // 2)
    imgL_dw2, imgR_dw2 = (
        F.interpolate(full, size=half_size, mode="bilinear", align_corners=True)
        for full in (imgL, imgR)
    )

    with torch.inference_mode():
        coarse = model.forward(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model.forward(imgL, imgR, iters=n_iter, flow_init=coarse)

    return torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
|
|
|
|
|
|
def get_reference():
    """Yield the decodable images found in ``/home/nils/references/``.

    Paths are collected up front (closing the directory iterator right away)
    and visited in sorted order so that successive runs cycle through the
    references deterministically. Entries that ``cv2.imread`` cannot decode
    (it returns ``None`` for them) are skipped.

    Yields:
        Each readable reference image as returned by ``cv2.imread``.
    """
    with os.scandir('/home/nils/references/') as entries:
        paths = sorted(entry.path for entry in entries)
    for path in paths:
        reference = cv2.imread(path)
        if reference is None:
            continue
        yield reference
|
|
|
|
# Module-level generator of reference images; update_reference() advances it and recreates it once exhausted.
references = get_reference()
|
|
|
|
|
|
@app.post('/params/update_reference')
async def update_reference():
    """Switch the active reference pattern to the next image from ``references``.

    Returns ``{'status': 'success'}`` after installing the next (downsized)
    reference. When the generator is exhausted it is recreated and
    ``{'status': 'finished'}`` is returned without changing the pattern.
    """
    global references, reference_pattern
    # Unique marker so that any value produced by the generator stays distinguishable.
    exhausted = object()
    upcoming = next(references, exhausted)
    if upcoming is exhausted:
        references = get_reference()
        return {'status': 'finished'}
    reference_pattern = downsize(upcoming)
    print(reference_pattern.shape)
    return {'status': 'success'}
|
|
|
|
|
|
@app.post("/model/update/{epoch}")
async def change_model(epoch: Union[int, Literal['latest']]):
    """Replace the served model with the checkpoint selected by ``epoch``.

    ``epoch`` is either an integer or the literal ``'latest'`` and is passed
    straight to ``load_model``. Returns ``{'status': 'success'}``.
    """
    global model
    for message in (epoch, 'updating model'):
        print(message)
    model = load_model(epoch)
    return {'status': 'success'}
|
|
|
|
|
|
@app.post("/params/iterations/{iterations}")
async def set_iterations(iterations: int):
    """Store ``iterations`` in the global ``iters`` used by later inference calls."""
    global iters
    iters = iterations
|
|
|
|
|
|
@app.post("/params/minimal_data/{enable}")
async def set_minimal_data(enable: bool):
    """Set the global ``minimal_data`` flag, which selects the response payload of ``/ir``."""
    global minimal_data
    minimal_data = enable
|
|
|
|
|
|
@app.post("/params/temporal_init/{enable}")
async def set_temporal_init(enable: bool):
    """Set the global ``temporal_init`` flag read by the ``/ir`` endpoint."""
    global temporal_init
    temporal_init = enable
|
|
|
|
|
|
@app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)):
    """Decode an uploaded image, run inference against the reference pattern and return the result.

    Args:
        file: Uploaded image file; it is decoded with PIL.

    Returns:
        On success, a JSON *string* produced with ``NumpyEncoder``. It holds
        ``disp`` and ``duration`` (seconds spent in inference) and, when
        ``minimal_data`` is off, also ``reference`` and ``input``.
        If the upload cannot be decoded, a dict with ``error`` and the
        exception message under ``exception``.
    """
    global last_img
    try:
        img = np.array(Image.open(BytesIO(await file.read())))
    except Exception as e:
        # Report the message text: the exception object itself has no
        # meaningful JSON representation for the client.
        return {'error': 'couldn\'t read file', 'exception': str(e)}

    # Single-channel input is replicated to three identical channels.
    if len(img.shape) == 2:
        img = cv2.merge([img for _ in range(3)])
    # Full-resolution frames are reduced with downsize().
    if img.shape == (1024, 1280, 3):
        img = downsize(img)

    ref_pat = reference_pattern

    start = datetime.now()
    if temporal_init:
        # Pass the previous prediction along and remember the new one.
        # NOTE(review): meaning of the positional arguments is defined by
        # train.inference - confirm there.
        pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False, last_img)
        last_img = pred_disp
    else:
        pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False)
    duration = (datetime.now() - start).total_seconds()

    if minimal_data:
        payload = {'disp': pred_disp, 'duration': duration}
    else:
        payload = {'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}
    # NOTE(review): this returns an already-encoded JSON string, as before;
    # clients presumably decode the response body twice - confirm before changing.
    return json.dumps(payload, cls=NumpyEncoder)
|
|
|
|
|
|
@app.get('/temporal_init')
def get_temporal_init():
    """Report whether the global ``temporal_init`` flag is currently set."""
    if temporal_init:
        return {'status': 'enabled'}
    return {'status': 'disabled'}
|
|
|
|
|
|
|
|
@app.get('/')
def main():
    """Root endpoint; returns a fixed placeholder payload."""
    placeholder = {'test': 'abc'}
    return placeholder
|
|
|