"""Small FastAPI service wrapping the disparity model from ``..nets``.

``PUT /ir`` takes an image, pairs it with a fixed reference pattern loaded at
startup, runs the model and returns the normalized disparity map.
The ``/items`` routes appear to be FastAPI tutorial examples.
"""
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
# Plain ``import cv2``: the ``from cv2 import cv2`` workaround raises
# ImportError on newer opencv-python releases, this form works on all of them.
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..nets import Model

app = FastAPI()

# TODO
# load both models and assign one GPU to each
# build routes that accept images, each of which is then interpreted by one
# of the models
# return the results
#
# extend input validation (currently only shape/raggedness is checked)
# make parameters (image size etc.) configurable, or detect them automatically?
# can CTD cope with anything else at all?


class Item(BaseModel):
    name: str
    price: float
    is_offer: Optional[bool] = None


class IrImage(BaseModel):
    """Request body for ``PUT /ir``."""

    # pydantic cannot build a validator for ``np.array`` (a function, not a
    # type), so the image is sent as nested lists indexed [row][col][channel]
    # and converted to an ndarray in the route.
    image: List[List[List[float]]]


reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern = cv2.imread(reference_pattern_path)
# cv2.imread returns None instead of raising when the file is missing or
# unreadable; fail at startup rather than on the first request.
if reference_pattern is None:
    raise FileNotFoundError(
        f"could not read reference pattern from {reference_pattern_path!r}"
    )

# NOTE(review): resolved relative to the current working directory, not to
# this file.
model_path = "../train_log/models/latest.pth"
device = torch.device('cuda:0')

model = Model(max_disp=256, mixed_precision=False, test_mode=True)
# The checkpoint is loaded into a DataParallel wrapper (presumably because it
# was saved from one, i.e. with "module."-prefixed keys -- confirm). An empty
# ``device_ids`` list makes DataParallel index ``device_ids[0]`` and crash, so
# pin it to the single inference device.
model = nn.DataParallel(model, device_ids=[device.index])
state_dict = torch.load(model_path, map_location=device)['state_dict']
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()


def normalize_and_colormap(img):
    """Min-max scale ``img`` to 0..255 and apply the INFERNO colormap.

    Args:
        img: numpy array or torch tensor; tensors are detached and moved to
            the CPU first.

    Returns:
        The ``uint8`` image produced by ``cv2.applyColorMap``. A constant
        image (max == min) is mapped as all zeros instead of dividing by zero.
    """
    if isinstance(img, torch.Tensor):
        img = img.cpu().detach().numpy()
    img = np.asarray(img)
    lo, hi = img.min(), img.max()
    if hi > lo:
        scaled = (img - lo) / (hi - lo) * 255.0
    else:
        scaled = np.zeros_like(img)
    return cv2.applyColorMap(scaled.astype("uint8"), cv2.COLORMAP_INFERNO)


def _to_nchw(img, device):
    """Convert an (H, W, C) ndarray to a (1, C, H, W) float32 tensor on ``device``."""
    batch = np.ascontiguousarray(img[None, :, :, :]).astype("float32")
    return torch.tensor(batch).to(device).permute(0, 3, 1, 2)


def inference_ctd(left, right, model, n_iter=20):
    """Run the two-pass (half, then full resolution) disparity prediction.

    Args:
        left: (H, W, C) image array -- the image received by the route.
        right: (H, W, C) image array of the same shape -- the reference pattern.
        model: the loaded network, called as
            ``model(image1, image2, iters=..., flow_init=...)``.
        n_iter: forwarded to the model's ``iters`` argument.

    Returns:
        (H, W) ``uint8`` array: channel 0 of the predicted flow, min-max
        normalized to 0..255.
    """
    print("Model Forwarding...")
    # Use the device the model's weights live on instead of a hard-coded
    # 'cuda:0'.
    device = next(model.parameters()).device

    # Both inputs need their channel axis moved to NCHW; previously only the
    # left image was transposed, so the right one reached the model as NHWC.
    imgL = _to_nchw(left, device)
    imgR = _to_nchw(right, device)

    half_size = (imgL.shape[2] // 2, imgL.shape[3] // 2)
    imgL_dw2 = F.interpolate(imgL, size=half_size, mode="bilinear", align_corners=True)
    imgR_dw2 = F.interpolate(imgR, size=half_size, mode="bilinear", align_corners=True)

    with torch.inference_mode():
        # Half-resolution pass; its output initializes the full-resolution pass.
        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)

    # The old per-element loop applied four indices to each element, which
    # fails for a single (N, 2, H, W) tensor. Accept either a tensor or a
    # sequence of predictions (last = final); which one the model returns in
    # test_mode is not visible here -- confirm.
    if isinstance(pred_flow, (list, tuple)):
        pred_flow = pred_flow[-1]
    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
    return cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)


@app.put("/ir")
def read_ir_input(ir_image: IrImage):
    """Predict a disparity map for ``ir_image`` against the reference pattern.

    Returns:
        ``{"pred_disp": ...}`` with the (H, W) uint8 map as nested lists.

    Raises:
        HTTPException(422): the image is ragged or its shape differs from the
            reference pattern's (both images are fed to the model together).
    """
    try:
        image = np.asarray(ir_image.image, dtype=np.float32)
    except ValueError as err:
        raise HTTPException(
            status_code=422,
            detail="image must be a rectangular (H, W, C) array",
        ) from err
    if image.shape != reference_pattern.shape:
        raise HTTPException(
            status_code=422,
            detail=f"expected image of shape {reference_pattern.shape}, got {image.shape}",
        )
    # ``model`` must be passed explicitly; omitting it raised TypeError on
    # every request.
    pred_disp = inference_ctd(image, reference_pattern, model)
    # FastAPI cannot JSON-encode ndarrays, so send nested lists.
    return {"pred_disp": pred_disp.tolist()}


@app.get("/items/{item_id}")
def read_item(item_id: int, q: Optional[str] = None):
    return {"item_id": item_id, "q": q}


@app.put("/items/{item_id}")
def update_item(item_id: int, item: Item):
    return {"item_price": item.price, "item_id": item_id}