import os
import json
from datetime import datetime
from io import BytesIO
from typing import Union, Literal

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt
from fastapi import FastAPI, File, UploadFile
from PIL import Image

from nets import Model
from train import inference as ctd_inference

app = FastAPI()

# reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern_path = '/home/nils/kinect_reference_far.png'
# reference_pattern_path = '/home/nils/kinect_diff_ref.png'
print(reference_pattern_path)
reference_pattern = cv2.imread(reference_pattern_path)

# Shift the reference pattern a few pixels to the left to simulate a more
# distant backdrop. Note: the translation below is currently zero, so this
# warp is an identity until the x-offset is set to a non-zero value.
trans_mat = np.float32([[1, 0, 0], [0, 1, 0]])
reference_pattern = cv2.warpAffine(
    reference_pattern, trans_mat, reference_pattern.shape[1::-1], flags=cv2.INTER_LINEAR
)
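# A minimal sketch of what a non-zero shift for the warpAffine above would
# look like; the 5 px offset is an assumed illustrative value, not one used
# by this service:
# tx = -5  # negative x-offset shifts the pattern to the left
# shift_mat = np.float32([[1, 0, tx], [0, 1, 0]])
# reference_pattern = cv2.warpAffine(
#     reference_pattern, shift_mat, reference_pattern.shape[1::-1],
#     flags=cv2.INTER_LINEAR,
# )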
iters = 20
minimal_data = True
temporal_init = False
last_img = None

device = torch.device('cuda:0')


def downsize(img):
    # 1024x1280 input -> pyrDown -> 512x640 -> crop 16 rows top and bottom -> 480x640
    diff = (512 - 480) // 2
    downsampled = cv2.pyrDown(img)
    img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
    return img


if 1024 in reference_pattern.shape:
    reference_pattern = downsize(reference_pattern)


def ghetto_lcn(img):
    """Quick-and-dirty local contrast normalization."""
    # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = img
    float_gray = gray.astype(np.float32) / 255.0

    blur = cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2)
    num = float_gray - blur

    blur = cv2.GaussianBlur(num * num, (0, 0), sigmaX=20, sigmaY=20)
    den = cv2.pow(blur, 0.5)

    # small epsilon guards against division by zero in flat image regions
    gray = num / (den + 1e-6)

    # cv2.normalize(gray, dst=gray, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX)
    cv2.normalize(gray, dst=gray, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX)

    return gray


# reference_pattern = ghetto_lcn(reference_pattern)


def load_model(epoch, use_tensorrt=False):
    global model
    epoch = f'epoch-{epoch}.pth' if epoch != 'latest' else 'latest.pth'
    model_path = f"train_log/models/{epoch}"

    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
    # FIXME WIP workaround for the DataParallel/TensorRT incompatibility
    model = nn.DataParallel(model, device_ids=[device])
    # model.load_state_dict(torch.load(model_path), strict=False)
    state_dict = torch.load(model_path)['state_dict']
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()

    if use_tensorrt:
        # compile a plain (non-DataParallel) copy of the model for TensorRT
        np_model = Model(max_disp=256, mixed_precision=False, test_mode=True)
        np_model.load_state_dict(model.module.state_dict(), strict=True)
        np_model.to(device)
        np_model.eval()
        spec_dict = {
            "inputs": [
                torch_tensorrt.Input(
                    min_shape=[1, 2, 240, 320],
                    max_shape=[1, 2, 480, 640],
                    opt_shape=[1, 2, 480, 640],
                    dtype=torch.int32,
                ),
                torch_tensorrt.Input(
                    min_shape=[1, 2, 240, 320],
                    max_shape=[1, 2, 480, 640],
                    opt_shape=[1, 2, 480, 640],
                    dtype=torch.int32,
                ),
            ],
            "enabled_precisions": {torch.float, torch.half},
            "refit": False,
            "debug": False,
            "device": {
                "device_type": torch_tensorrt.DeviceType.GPU,
                "gpu_id": 0,
                "dla_core": 0,
                "allow_gpu_fallback": True,
            },
            "capability": torch_tensorrt.EngineCapability.default,
            "num_min_timing_iters": 2,
            "num_avg_timing_iters": 1,
        }
        spec = {
            "forward": torch_tensorrt.ts.TensorRTCompileSpec(**spec_dict)
        }
        # trt_model = torch_tensorrt.compile(np_model,
        #     # inputs=torch_tensorrt.Input(
        #     #     min_shape=[1, 2, 240, 320],
        #     #     max_shape=[1, 2, 480, 640],
        #     #     opt_shape=[1, 2, 480, 640],
        #     #     dtype=torch.int32,
        #     inputs=[torch_tensorrt.Input((1, 2, 480, 640)), torch_tensorrt.Input((1, 2, 480, 640))],  # input shape
        #     enabled_precisions={torch_tensorrt.dtype.half},  # run with FP16
        # )
        # trt_dw_model = torch_tensorrt.compile(np_model,
        #     inputs=[torch_tensorrt.Input((1, 2, 240, 320)), torch_tensorrt.Input((1, 2, 240, 320))],  # input shape
        #     enabled_precisions={torch_tensorrt.dtype.half},  # run with FP16
        # )
        script_model = torch.jit.script(np_model.eval())
        # script_dw_model = torch.jit.script(trt_dw_model.eval())
        # save the TensorRT-embedded TorchScript
        # torch.jit.save(trt_model, 'trt_torchscript_module.ts')
        # torch.jit.save(trt_dw_model, 'trt_torchscript_dw_module.ts')
        # debug output (a print(script_model.forward()) was dropped here:
        # calling forward without arguments raises a TypeError)
        print(script_model)
        print(script_model.forward)
        print(dir(script_model))
        model = torch._C._jit_to_backend("tensorrt", script_model, spec)

    print(f'loaded model {epoch}')
    return model


model = load_model('latest')


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts numpy arrays to (nested) lists."""

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)


def inference(left, right, model, n_iter=20):
    print("Model Forwarding...")
    imgL = np.ascontiguousarray(left[None, :, :, :])
    imgR = np.ascontiguousarray(right[None, :, :, :])
    device = torch.device('cuda')
    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)
    # NHWC -> NCHW (the previous transpose(1, 2) only swapped height and
    # width and left the tensors channels-last)
    imgL = imgL.permute(0, 3, 1, 2)
    imgR = imgR.permute(0, 3, 1, 2)

    imgL_dw2 = F.interpolate(
        imgL,
        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    imgR_dw2 = F.interpolate(
        imgR,
        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    with torch.inference_mode():
        # pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        # pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
        pred_flow_dw2 = model.forward(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model.forward(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()

    return pred_disp


def get_reference():
    refs = [ref.path for ref in os.scandir('/home/nils/references/')]
    for ref in refs:
        reference = cv2.imread(ref)
        yield reference


references = get_reference()


@app.post('/params/update_reference')
async def update_reference():
    global references, reference_pattern
    try:
        reference_pattern = downsize(next(references))
        print(reference_pattern.shape)
        return {'status': 'success'}
    except StopIteration:
        # generator exhausted: rewind for the next round of updates
        references = get_reference()
        return {'status': 'finished'}


@app.post("/model/update/{epoch}")
async def change_model(epoch: Union[int, Literal['latest']]):
    global model
    print(epoch)
    print('updating model')
    model = load_model(epoch)
    return {'status': 'success'}


@app.post("/params/iterations/{iterations}")
async def set_iterations(iterations: int):
    global iters
    iters = iterations


@app.post("/params/minimal_data/{enable}")
async def set_minimal_data(enable: bool):
    global minimal_data
    minimal_data = enable


@app.post("/params/temporal_init/{enable}")
async def set_temporal_init(enable: bool):
    global temporal_init
    temporal_init = enable
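# Hypothetical client calls for the parameter endpoints above (commented out;
# host and port are assumptions for a local uvicorn run, and `requests` is not
# a dependency of this service):
# import requests
# requests.post('http://127.0.0.1:8000/params/iterations/10')
# requests.post('http://127.0.0.1:8000/params/minimal_data/true')
# requests.post('http://127.0.0.1:8000/params/temporal_init/false')
# requests.post('http://127.0.0.1:8000/model/update/latest')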
@app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)):
    global last_img, minimal_data
    try:
        img = np.array(Image.open(BytesIO(await file.read())))
    except Exception as e:
        # str(e): the raw exception object is not JSON serializable
        return {'error': 'couldn\'t read file', 'exception': str(e)}

    # img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    if len(img.shape) == 2:
        # replicate a single-channel image to three channels
        img = cv2.merge([img for _ in range(3)])
    if img.shape == (1024, 1280, 3):
        img = downsize(img)
    # img = img.transpose((1, 2, 0))
    # ref_pat = reference_pattern.transpose((1, 2, 0))
    ref_pat = reference_pattern

    start = datetime.now()
    if temporal_init:
        pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False, last_img)
        last_img = pred_disp
    else:
        pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False)
    # pred_disp = inference(img, ref_pat, model, iters)
    duration = (datetime.now() - start).total_seconds()

    if minimal_data:
        return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
    else:
        # return json.dumps({'disp': pred_disp, 'input': img, 'duration': duration},
        return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration},
                          cls=NumpyEncoder)


@app.get('/temporal_init')
def get_temporal_init():
    return {'status': 'enabled' if temporal_init else 'disabled'}


@app.get('/')
def main():
    return {'test': 'abc'}
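# Sketch of a client round trip for the /ir endpoint (hypothetical file name,
# host, and port; assumes `requests` is installed):
# import requests
# with open('ir_frame.png', 'rb') as f:
#     resp = requests.put('http://127.0.0.1:8000/ir', files={'file': f})
# payload = json.loads(resp.json())  # the endpoint returns a JSON string, so it decodes twice
# print(payload['duration'])


if __name__ == '__main__':
    # Convenience entry point for local development; host and port are
    # assumptions, not part of the original deployment setup.
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=8000)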