CREStereo Repository for the 'Towards accurate and robust depth estimation' project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
CREStereo-pytorch-nxt/api_server.py

99 lines
2.8 KiB

import json
from io import BytesIO
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from cv2 import cv2
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from nets import Model
# Application object; the route decorators further down in this file register
# their handlers on it.
app = FastAPI()

# Reference image, read once at import time and passed as the second image to
# inference() by the /ir handler.
# NOTE(review): cv2.imread is documented to return None instead of raising when
# the file cannot be read, so a missing file only surfaces at request time --
# confirm the path exists on the deployment host.
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern = cv2.imread(reference_pattern_path)

# Checkpoint to load; the commented-out line below names an alternative one.
model_path = "train_log/models/latest.pth"
# model_path = "train_log/models/epoch-100.pth"

# No device index is given, so this refers to the current CUDA device.
device = torch.device('cuda')

# Build the network and wrap it in DataParallel before loading the weights.
# NOTE(review): presumably the checkpoint was saved from a DataParallel-wrapped
# model so that its keys match under strict=True -- verify against the
# training script.
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model = nn.DataParallel(model, device_ids=[device])
# Earlier loading variant (checkpoint used directly as the state dict, non-strict):
# model.load_state_dict(torch.load(model_path), strict=False)
# The checkpoint is indexed with 'state_dict', i.e. the weights sit under that key.
state_dict = torch.load(model_path)['state_dict']
model.load_state_dict(state_dict, strict=True)

# Move the weights to the device and switch to evaluation mode for serving.
model.to(device)
model.eval()
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that understands NumPy arrays and NumPy scalars.

    Arrays are emitted as (nested) lists and scalar types such as
    ``np.float32``, ``np.int64`` or ``np.bool_`` as the matching built-in
    Python value. Anything else is delegated to ``json.JSONEncoder``, which
    raises ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalars are not ndarrays; without this branch a bare
            # np.float32 / np.int64 in the payload raises TypeError.
            return obj.item()
        return super().default(obj)
def _resolve_device(model):
    """Return the device on which input tensors for ``model`` should be placed."""
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    # Parameter-free modules and plain callables: use CUDA when present.
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def inference(left, right, model, n_iter=20):
    """Run the two-pass (half resolution, then full resolution) forward pass.

    Args:
        left: 3-D array for the first image. A batch axis is prepended and
            axes 1 and 2 of the batched tensor are swapped before the model
            is called, so an input laid out as (A, C, B) reaches the model
            as (1, C, A, B).
        right: 3-D array for the second image, handled like ``left``.
        model: callable invoked as
            ``model(image1, image2, iters=..., flow_init=...)``.
        n_iter: iteration count forwarded to the model as ``iters``.

    Returns:
        Channel 0 of the full-resolution model output as a NumPy array with
        all singleton dimensions squeezed out.
    """
    print("Model Forwarding...")

    # Follow the model's own device instead of hard-coding CUDA so the
    # function also works on CPU-only hosts and with models placed elsewhere.
    device = _resolve_device(model)

    imgL = np.ascontiguousarray(left[None, :, :, :])
    imgR = np.ascontiguousarray(right[None, :, :, :])
    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)

    # Original author's note: "works roughly".
    # NOTE(review): this swap assumes the caller passes arrays whose channel
    # axis is in the middle -- confirm against the caller's transpose.
    imgR = imgR.transpose(1, 2)
    imgL = imgL.transpose(1, 2)

    # Both half-resolution images are sized from imgL so that they always
    # share one shape.
    half_size = (imgL.shape[2] // 2, imgL.shape[3] // 2)
    imgL_dw2 = F.interpolate(
        imgL,
        size=half_size,
        mode="bilinear",
        align_corners=True,
    )
    imgR_dw2 = F.interpolate(
        imgR,
        size=half_size,
        mode="bilinear",
        align_corners=True,
    )

    with torch.inference_mode():
        # The coarse pass output initialises the full-resolution pass.
        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)

    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
    return pred_disp
@app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)):
    """Decode an uploaded image, run inference against the reference pattern.

    Args:
        file: uploaded image file; it is decoded with PIL.

    Returns:
        On success, a JSON string (built with ``NumpyEncoder``) holding
        ``disp`` (the inference result), ``reference`` and ``input`` (the
        transposed arrays that were fed to ``inference``). On failure, a dict
        with an ``error`` key; a decoding failure also carries the exception
        message under ``exception``.

    NOTE(review): the success value is already a JSON string, so the framework
    presumably encodes it a second time and clients have to decode twice. It
    is kept as-is so existing clients keep working.
    """
    try:
        img = np.array(Image.open(BytesIO(await file.read())))
    except Exception as e:
        # Report the message text: the exception object itself is not
        # JSON-serialisable and would lose the actual reason.
        return {'error': 'couldn\'t read file', 'exception': str(e)}

    if reference_pattern is None:
        # The reference image failed to load at start-up; without this guard
        # the transpose below fails with AttributeError.
        return {'error': 'reference pattern not loaded'}

    # img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)  (disabled)

    if len(img.shape) == 2:
        # Single-channel input: replicate it into three identical channels.
        img = cv2.merge([img for _ in range(3)])

    if img.shape == (1024, 1280, 3):
        # Downsample once with pyrDown, then trim an equal number of rows from
        # the top and bottom: (512 - 480) // 2 = 16 rows per side.
        downsampled_rows, target_rows = 512, 480
        diff = (downsampled_rows - target_rows) // 2
        downsampled = cv2.pyrDown(img)
        img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]

    # Input and reference get the same axis permutation so that inference()
    # receives both in one layout.
    img = img.transpose((1, 2, 0))
    ref_pat = reference_pattern.transpose((1, 2, 0))

    pred_disp = inference(img, ref_pat, model)
    return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img}, cls=NumpyEncoder)
@app.get('/')
def main():
    """Serve a fixed placeholder payload on the root route."""
    payload = {'test': 'abc'}
    return payload