diff --git a/api_server.py b/api_server.py
index 015c1b2..26227e7 100644
--- a/api_server.py
+++ b/api_server.py
@@ -1,40 +1,24 @@
+import json
 from io import BytesIO
-from typing import Optional
-
-from fastapi import FastAPI, File, UploadFile
-from pydantic import BaseModel
 import numpy as np
-from cv2 import cv2
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from nets import Model
+from cv2 import cv2
+from fastapi import FastAPI, File, UploadFile
 from PIL import Image
+from nets import Model
 
 app = FastAPI()
 
-# TODO
-# load both models, assign each one a GPU
-# build routes that images can be thrown at, which are then each interpreted by one of the models
-# return the results
-#
-# don't forget input validation
-# make parameters (image size etc.) configurable, or figure them out automatically?
-# can ctd even cope with anything else?
-
-
-
-class IrImage(BaseModel):
-    image: np.array
-
-
 reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
 reference_pattern = cv2.imread(reference_pattern_path)
 model_path = "train_log/models/latest.pth"
 
-device = torch.device('cuda:0')
+# model_path = "train_log/models/epoch-100.pth"
+device = torch.device('cuda')
 
 model = Model(max_disp=256, mixed_precision=False, test_mode=True)
 model = nn.DataParallel(model, device_ids=[device])
@@ -44,32 +28,27 @@ model.load_state_dict(state_dict, strict=True)
 model.to(device)
 model.eval()
 
+class NumpyEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return json.JSONEncoder.default(self, obj)
 
-def normalize_and_colormap(img):
-    ret = (img - img.min()) / (img.max() - img.min()) * 255.0
-    if isinstance(ret, torch.Tensor):
-        ret = ret.cpu().detach().numpy()
-    ret = ret.astype("uint8")
-    ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
-    return ret
-
-def inference_ctd(left, right, model, n_iter=20):
+def inference(left, right, model, n_iter=20):
     print("Model Forwarding...")
-    # print(left.shape)
-    # left = left.cpu().detach().numpy()
-    # imgL = left
-    # imgR = right.cpu().detach().numpy()
     imgL = np.ascontiguousarray(left[None, :, :, :])
     imgR = np.ascontiguousarray(right[None, :, :, :])
 
-    # chosen for convenience
-    device = torch.device('cuda:0')
+    device = torch.device('cuda')
 
     imgL = torch.tensor(imgL.astype("float32")).to(device)
     imgR = torch.tensor(imgR.astype("float32")).to(device)
 
-    imgL = imgL.transpose(2, 3).transpose(1, 2)
+    # Works, roughly
+    imgR = imgR.transpose(1, 2)
+    imgL = imgL.transpose(1, 2)
+
     imgL_dw2 = F.interpolate(
         imgL,
         size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
@@ -82,15 +61,14 @@ def inference_ctd(left, right, model, n_iter=20):
         mode="bilinear",
         align_corners=True,
     )
+
     with torch.inference_mode():
         pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
 
         pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
-    for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
-        pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
-        pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
+    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
 
-    return pred_disp_norm
+    return pred_disp
 
 
 @app.put("/ir")
@@ -99,11 +77,21 @@ async def read_ir_input(file: UploadFile = File(...)):
         img = np.array(Image.open(BytesIO(await file.read())))
     except Exception as e:
         return {'error': 'couldn\'t read file', 'exception': e}
-    print(img.shape)
+
+    # img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
     if len(img.shape) == 2:
-        img = np.stack((img for _ in range(3)))
-    pred_disp = inference_ctd(np.array(img), reference_pattern, None)
-    return {"pred_disp": pred_disp}
+        img = cv2.merge([img for _ in range(3)])
+    if img.shape == (1024, 1280, 3):
+        diff = (512 - 480) // 2
+        downsampled = cv2.pyrDown(img)
+        img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
+
+    img = img.transpose((1, 2, 0))
+    ref_pat = reference_pattern.transpose((1, 2, 0))
+
+    pred_disp = inference(img, ref_pat, model)
+
+    return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img}, cls=NumpyEncoder)
 
 
 @app.get('/')
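
Note: a minimal client-side sketch of how the new /ir route could be exercised. The host/port, the file name ir_frame.png, and the `requests` dependency are illustrative assumptions, not part of this change:

    import json

    import requests

    # The route is declared with @app.put, so the request must be a PUT.
    with open("ir_frame.png", "rb") as f:
        resp = requests.put(
            "http://localhost:8000/ir",
            files={"file": ("ir_frame.png", f, "image/png")},
        )

    # The endpoint returns a json.dumps() string, which FastAPI's default
    # JSONResponse encodes again, so the body parses to a string that
    # itself contains the payload.
    payload = json.loads(resp.json())
    print(payload.keys())  # expected keys: 'disp', 'reference', 'input'

Since the cv2.normalize step was dropped from inference(), the 'disp' field now carries raw float disparities (as nested lists via NumpyEncoder) rather than values scaled to 0-255.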