api_server.py: rework so it actually kinda works

results are weirdly poor, but idk
main
Nils Koch 3 years ago
parent 46a6ae44af
commit 1bdf2e7776
  1. 80
      api_server.py

@ -1,40 +1,24 @@
import json
from io import BytesIO
from typing import Optional
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
import numpy as np
from cv2 import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets import Model
from cv2 import cv2
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from nets import Model
app = FastAPI()
# TODO
# beide modelle laden, jeweils eine gpu zuweisen
# routen bauen, gegen die man bilder werfen kann, die dann jeweils von einem modell interpretiert werden
# ergebnisse zurueck geben
#
# input validierung nicht vergessen
# paramter (bildgroesse etc.) konfigurierbar machen oder automatisch rausfinden?
# kommt ctd überhaupt mit was anderem klar?
class IrImage(BaseModel):
image: np.array
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern = cv2.imread(reference_pattern_path)
model_path = "train_log/models/latest.pth"
device = torch.device('cuda:0')
# model_path = "train_log/models/epoch-100.pth"
device = torch.device('cuda')
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model = nn.DataParallel(model, device_ids=[device])
@ -44,32 +28,27 @@ model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)
def normalize_and_colormap(img):
ret = (img - img.min()) / (img.max() - img.min()) * 255.0
if isinstance(ret, torch.Tensor):
ret = ret.cpu().detach().numpy()
ret = ret.astype("uint8")
ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
return ret
def inference_ctd(left, right, model, n_iter=20):
def inference(left, right, model, n_iter=20):
print("Model Forwarding...")
# print(left.shape)
# left = left.cpu().detach().numpy()
# imgL = left
# imgR = right.cpu().detach().numpy()
imgL = np.ascontiguousarray(left[None, :, :, :])
imgR = np.ascontiguousarray(right[None, :, :, :])
# chosen for convenience
device = torch.device('cuda:0')
device = torch.device('cuda')
imgL = torch.tensor(imgL.astype("float32")).to(device)
imgR = torch.tensor(imgR.astype("float32")).to(device)
imgL = imgL.transpose(2, 3).transpose(1, 2)
# Funzt grob
imgR = imgR.transpose(1,2)
imgL = imgL.transpose(1,2)
imgL_dw2 = F.interpolate(
imgL,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
@ -82,15 +61,14 @@ def inference_ctd(left, right, model, n_iter=20):
mode="bilinear",
align_corners=True,
)
with torch.inference_mode():
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
return pred_disp_norm
return pred_disp
@app.put("/ir")
@ -99,11 +77,21 @@ async def read_ir_input(file: UploadFile = File(...)):
img = np.array(Image.open(BytesIO(await file.read())))
except Exception as e:
return {'error': 'couldn\'t read file', 'exception': e}
print(img.shape)
# img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
if len(img.shape) == 2:
img = np.stack((img for _ in range(3)))
pred_disp = inference_ctd(np.array(img), reference_pattern, None)
return {"pred_disp": pred_disp}
img = cv2.merge([img for _ in range(3)])
if img.shape == (1024, 1280, 3):
diff = (512 - 480) // 2
downsampled = cv2.pyrDown(img)
img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
img = img.transpose((1,2,0))
ref_pat = reference_pattern.transpose((1,2,0))
pred_disp = inference(img, ref_pat, model)
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img}, cls=NumpyEncoder)
@app.get('/')

Loading…
Cancel
Save