api_server.py: reformat, more customizability

main
Nils Koch 2 years ago
parent 3bc0e7d575
commit 50581efa01
  1. 26
      api_server.py

@ -13,13 +13,13 @@ from PIL import Image
from nets import Model from nets import Model
app = FastAPI() app = FastAPI()
reference_pattern_path = '/home/nils/kinect_reference_cropped.png' reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern = cv2.imread(reference_pattern_path) reference_pattern = cv2.imread(reference_pattern_path)
iters = 20
minimal_data = False
device = torch.device('cuda:0') device = torch.device('cuda:0')
model = None
def load_model(epoch): def load_model(epoch):
@ -91,6 +91,18 @@ async def change_model(epoch: Union[int, Literal['latest']]):
return {'status': 'success'} return {'status': 'success'}
@app.post("/params/iterations/{iterations}")
async def set_iterations(iterations: int):
global iters
iters = iterations
@app.post("/params/minimal_data/{enable}")
async def set_minimal_data(enable: bool):
global minimal_data
minimal_data = enable
@app.put("/ir") @app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)): async def read_ir_input(file: UploadFile = File(...)):
try: try:
@ -110,14 +122,16 @@ async def read_ir_input(file: UploadFile = File(...)):
ref_pat = reference_pattern.transpose((1, 2, 0)) ref_pat = reference_pattern.transpose((1, 2, 0))
start = datetime.now() start = datetime.now()
pred_disp = inference(img, ref_pat, model, 20) pred_disp = inference(img, ref_pat, model, iters)
duration = (datetime.now() - start).total_seconds() duration = (datetime.now() - start).total_seconds()
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, cls=NumpyEncoder) if minimal_data:
# return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder) return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
else:
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration},
cls=NumpyEncoder)
@app.get('/') @app.get('/')
def main(): def main():
return {'test': 'abc'} return {'test': 'abc'}

Loading…
Cancel
Save