CREStereo Repository for the 'Towards accurate and robust depth estimation' project
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..nets import Model
app = FastAPI()
# TODO
# load both models, assigning each one its own GPU (see the sketch after the model setup below)
# build routes that accept images, which are then interpreted by one of the models
# return the results
#
# don't forget input validation
# make parameters (image size etc.) configurable, or detect them automatically?
# can CTD cope with anything else at all?
# Example model kept from the FastAPI quickstart, unrelated to the stereo pipeline
class Item(BaseModel):
    name: str
    price: float
    is_offer: Optional[bool] = None
class IrImage(BaseModel):
    # np.ndarray is not a pydantic-native (or JSON-serializable) type; see the sketch below
    image: np.ndarray

    class Config:
        arbitrary_types_allowed = True
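# Sketch (assumption, not part of the current API): a JSON-friendly request model that
# carries the IR image as base64-encoded image bytes instead of a raw ndarray; the
# class name, field name and 3-channel decoding are placeholders.
import base64

class IrImageB64(BaseModel):
    image_b64: str

    def to_ndarray(self) -> np.ndarray:
        # decode base64 -> raw bytes -> 3-channel image (matching cv2.imread of the reference pattern)
        data = np.frombuffer(base64.b64decode(self.image_b64), dtype=np.uint8)
        return cv2.imdecode(data, cv2.IMREAD_COLOR)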
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern = cv2.imread(reference_pattern_path)
model_path = "../train_log/models/latest.pth"
device = torch.device('cuda:0')
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
# DataParallel wrapper with no extra devices (calls pass straight through),
# likely here so the 'module.'-prefixed checkpoint keys load with strict=True
model = nn.DataParallel(model, device_ids=[])
# model.load_state_dict(torch.load(model_path), strict=False)
state_dict = torch.load(model_path)['state_dict']
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()
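# Sketch for the TODO above (assumption): a second model instance pinned to a second
# GPU; the checkpoint path and device id are placeholders for a not-yet-chosen model.
# model2 = Model(max_disp=256, mixed_precision=False, test_mode=True)
# model2 = nn.DataParallel(model2, device_ids=[])
# model2.load_state_dict(torch.load("../train_log/models/other.pth", map_location="cuda:1")['state_dict'], strict=True)
# model2.to(torch.device('cuda:1'))
# model2.eval()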
def normalize_and_colormap(img):
    ret = (img - img.min()) / (img.max() - img.min()) * 255.0
    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()
    ret = ret.astype("uint8")
    ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
    return ret
def inference_ctd(left, right, model, n_iter=20):
    print("Model Forwarding...")
    # print(left.shape)
    # left = left.cpu().detach().numpy()
    # imgL = left
    # imgR = right.cpu().detach().numpy()
    imgL = np.ascontiguousarray(left[None, :, :, :])
    imgR = np.ascontiguousarray(right[None, :, :, :])
    # chosen for convenience
    device = torch.device('cuda:0')
    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)
    # NHWC -> NCHW for both inputs
    imgL = imgL.transpose(2, 3).transpose(1, 2)
    imgR = imgR.transpose(2, 3).transpose(1, 2)
    # coarse pass at half resolution, used as flow_init for the full-resolution pass
    imgL_dw2 = F.interpolate(
        imgL,
        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    imgR_dw2 = F.interpolate(
        imgR,
        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    with torch.inference_mode():
        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
    for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
        # keep only the horizontal flow component (disparity) of the prediction
        pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
        pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    return pred_disp_norm
@app.put("/ir")
def read_ir_input(ir_image: IrImage):
pred_disp = inference_ctd(ir_image.image, reference_pattern)
return {"pred_disp": pred_disp}
@app.get("/items/{item_id}")
def read_item(item_id: int, q: Optional[str] = None):
return {"item_id": item_id, "q": q}
@app.put("/items/{item_id}")
def update_item(item_id: int, item: Item):
return {"item_price": item.price, "item_id": item_id}