diff --git a/api_server.py b/api_server.py index 78e09f3..0448071 100644 --- a/api_server.py +++ b/api_server.py @@ -13,13 +13,13 @@ from PIL import Image from nets import Model - app = FastAPI() reference_pattern_path = '/home/nils/kinect_reference_cropped.png' reference_pattern = cv2.imread(reference_pattern_path) +iters = 20 +minimal_data = False device = torch.device('cuda:0') -model = None def load_model(epoch): @@ -57,8 +57,8 @@ def inference(left, right, model, n_iter=20): imgL = torch.tensor(imgL.astype("float32")).to(device) imgR = torch.tensor(imgR.astype("float32")).to(device) - imgR = imgR.transpose(1,2) - imgL = imgL.transpose(1,2) + imgR = imgR.transpose(1, 2) + imgL = imgL.transpose(1, 2) imgL_dw2 = F.interpolate( imgL, @@ -91,6 +91,18 @@ async def change_model(epoch: Union[int, Literal['latest']]): return {'status': 'success'} +@app.post("/params/iterations/{iterations}") +async def set_iterations(iterations: int): + global iters + iters = iterations + + +@app.post("/params/minimal_data/{enable}") +async def set_minimal_data(enable: bool): + global minimal_data + minimal_data = enable + + @app.put("/ir") async def read_ir_input(file: UploadFile = File(...)): try: @@ -104,20 +116,22 @@ async def read_ir_input(file: UploadFile = File(...)): if img.shape == (1024, 1280, 3): diff = (512 - 480) // 2 downsampled = cv2.pyrDown(img) - img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]] + img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]] - img = img.transpose((1,2,0)) - ref_pat = reference_pattern.transpose((1,2,0)) + img = img.transpose((1, 2, 0)) + ref_pat = reference_pattern.transpose((1, 2, 0)) start = datetime.now() - pred_disp = inference(img, ref_pat, model, 20) + pred_disp = inference(img, ref_pat, model, iters) duration = (datetime.now() - start).total_seconds() - return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, cls=NumpyEncoder) - # return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder) + if minimal_data: + return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder) + else: + return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, + cls=NumpyEncoder) @app.get('/') def main(): return {'test': 'abc'} -