api_server.py: reformat, more customizability

parent 3bc0e7d575
commit 50581efa01
--- a/api_server.py
+++ b/api_server.py
@@ -13,13 +13,13 @@ from PIL import Image
 from nets import Model


 app = FastAPI()

 reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
 reference_pattern = cv2.imread(reference_pattern_path)
+iters = 20
+minimal_data = False
 device = torch.device('cuda:0')
-model = None


 def load_model(epoch):
@@ -57,8 +57,8 @@ def inference(left, right, model, n_iter=20):
     imgL = torch.tensor(imgL.astype("float32")).to(device)
     imgR = torch.tensor(imgR.astype("float32")).to(device)

-    imgR = imgR.transpose(1,2)
-    imgL = imgL.transpose(1,2)
+    imgR = imgR.transpose(1, 2)
+    imgL = imgL.transpose(1, 2)

     imgL_dw2 = F.interpolate(
         imgL,
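The two edits in this hunk are whitespace-only; torch.Tensor.transpose(1, 2) swaps exactly the two named dimensions rather than permuting all of them. A standalone illustration, with a hypothetical shape not taken from the server code:

import torch

x = torch.zeros(1, 480, 512)   # hypothetical (C, H, W) layout
y = x.transpose(1, 2)          # swaps dims 1 and 2 only
print(y.shape)                 # torch.Size([1, 512, 480])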
@@ -91,6 +91,18 @@ async def change_model(epoch: Union[int, Literal['latest']]):
     return {'status': 'success'}


+@app.post("/params/iterations/{iterations}")
+async def set_iterations(iterations: int):
+    global iters
+    iters = iterations
+
+
+@app.post("/params/minimal_data/{enable}")
+async def set_minimal_data(enable: bool):
+    global minimal_data
+    minimal_data = enable
+
+
 @app.put("/ir")
 async def read_ir_input(file: UploadFile = File(...)):
     try:
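The two new endpoints rebind the module-level iters and minimal_data values introduced in the first hunk (hence the global statements), so a value posted here applies to every subsequent /ir request. A minimal client sketch; the host, port, and example values are illustrative and not part of the commit:

import requests

base = 'http://localhost:8000'   # hypothetical host/port of the running api_server.py

# use 40 refinement iterations for future inference calls
requests.post(f'{base}/params/iterations/40')

# strip the reference/input images from future /ir responses
requests.post(f'{base}/params/minimal_data/true')

FastAPI parses the path segments into int and bool respectively, so values such as 'true', 'True', or '1' are all accepted for the flag.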
@@ -104,20 +116,22 @@ async def read_ir_input(file: UploadFile = File(...)):
     if img.shape == (1024, 1280, 3):
         diff = (512 - 480) // 2
         downsampled = cv2.pyrDown(img)
-        img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+        img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]

-    img = img.transpose((1,2,0))
-    ref_pat = reference_pattern.transpose((1,2,0))
+    img = img.transpose((1, 2, 0))
+    ref_pat = reference_pattern.transpose((1, 2, 0))

     start = datetime.now()
-    pred_disp = inference(img, ref_pat, model, 20)
+    pred_disp = inference(img, ref_pat, model, iters)
     duration = (datetime.now() - start).total_seconds()

-    return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, cls=NumpyEncoder)
-    # return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
+    if minimal_data:
+        return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
+    else:
+        return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration},
+                          cls=NumpyEncoder)


 @app.get('/')
 def main():
     return {'test': 'abc'}
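With the minimal_data switch, read_ir_input now returns one of two JSON payload shapes, and because the handler returns the result of json.dumps, FastAPI encodes that string once more, so clients receive a JSON-encoded string. A hedged client sketch; the file name and host are illustrative, and it assumes NumpyEncoder serialises ndarrays as nested lists, as such encoders usually do:

import json
import numpy as np
import requests

base = 'http://localhost:8000'            # hypothetical host/port

with open('ir_frame.png', 'rb') as f:     # hypothetical IR capture
    resp = requests.put(f'{base}/ir', files={'file': f})

payload = json.loads(resp.json())         # unwrap the double-encoded JSON string
disp = np.asarray(payload['disp'])        # disparity map back as an array
print(payload['duration'], disp.shape)

if 'input' in payload:                    # present only when minimal_data is off
    ref = np.asarray(payload['reference'])
    ir = np.asarray(payload['input'])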