From bb15dcd0a12911ce4fbcb23d3ce53047e330ffcd Mon Sep 17 00:00:00 2001
From: Nils Koch
Date: Thu, 2 Jun 2022 12:12:53 +0200
Subject: [PATCH] api_server.py: report timing info, allow on the fly changing
 of model

---
 api_server.py | 54 +++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 39 insertions(+), 15 deletions(-)

diff --git a/api_server.py b/api_server.py
index 26227e7..78e09f3 100644
--- a/api_server.py
+++ b/api_server.py
@@ -1,4 +1,6 @@
 import json
+from datetime import datetime
+from typing import Union, Literal
 from io import BytesIO
 
 import numpy as np
@@ -16,17 +18,27 @@ app = FastAPI()
 reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
 reference_pattern = cv2.imread(reference_pattern_path)
 
-model_path = "train_log/models/latest.pth"
-# model_path = "train_log/models/epoch-100.pth"
-device = torch.device('cuda')
-
-model = Model(max_disp=256, mixed_precision=False, test_mode=True)
-model = nn.DataParallel(model, device_ids=[device])
-# model.load_state_dict(torch.load(model_path), strict=False)
-state_dict = torch.load(model_path)['state_dict']
-model.load_state_dict(state_dict, strict=True)
-model.to(device)
-model.eval()
+device = torch.device('cuda:0')
+model = None
+
+
+def load_model(epoch):
+    global model
+    epoch = f'epoch-{epoch}.pth' if epoch != 'latest' else 'latest.pth'
+    model_path = f"train_log/models/{epoch}"
+    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
+    model = nn.DataParallel(model, device_ids=[device])
+    # model.load_state_dict(torch.load(model_path), strict=False)
+    state_dict = torch.load(model_path)['state_dict']
+    model.load_state_dict(state_dict, strict=True)
+    model.to(device)
+    model.eval()
+    print(f'loaded model {epoch}')
+    return model
+
+
+model = load_model('latest')
+
 
 class NumpyEncoder(json.JSONEncoder):
     def default(self, obj):
@@ -45,10 +57,9 @@ def inference(left, right, model, n_iter=20):
     imgL = torch.tensor(imgL.astype("float32")).to(device)
     imgR = torch.tensor(imgR.astype("float32")).to(device)
 
-    # Funzt grob
     imgR = imgR.transpose(1,2)
     imgL = imgL.transpose(1,2)
-    
+
     imgL_dw2 = F.interpolate(
         imgL,
         size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
@@ -71,6 +82,15 @@ def inference(left, right, model, n_iter=20):
     return pred_disp
 
 
+@app.post("/model/update/{epoch}")
+async def change_model(epoch: Union[int, Literal['latest']]):
+    global model
+    print(epoch)
+    print('updating model')
+    model = load_model(epoch)
+    return {'status': 'success'}
+
+
 @app.put("/ir")
 async def read_ir_input(file: UploadFile = File(...)):
     try:
@@ -89,11 +109,15 @@ async def read_ir_input(file: UploadFile = File(...)):
         img = img.transpose((1,2,0))
         ref_pat = reference_pattern.transpose((1,2,0))
 
-        pred_disp = inference(img, ref_pat, model)
+        start = datetime.now()
+        pred_disp = inference(img, ref_pat, model, 20)
+        duration = (datetime.now() - start).total_seconds()
 
-        return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img}, cls=NumpyEncoder)
+        return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, cls=NumpyEncoder)
+        # return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
 
 
 @app.get('/')
 def main():
     return {'test': 'abc'}
+
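
A minimal client sketch for the two endpoints this patch touches, assuming the server runs under uvicorn at http://localhost:8000 and that ir.png is an IR frame on disk; the base URL, the file name, and the use of the requests library are placeholders, not part of the patch.

import json

import requests

BASE_URL = 'http://localhost:8000'  # assumed uvicorn address, adjust to your deployment

# Switch the served checkpoint, e.g. to train_log/models/epoch-100.pth;
# pass 'latest' to go back to latest.pth.
resp = requests.post(f'{BASE_URL}/model/update/100')
print(resp.json())  # {'status': 'success'}

# Send an IR frame and read back the disparity plus the reported timing.
with open('ir.png', 'rb') as f:  # placeholder file name
    resp = requests.put(f'{BASE_URL}/ir', files={'file': ('ir.png', f, 'image/png')})

# The handler returns json.dumps(...), so the response body is a JSON string
# that needs a second decode to get the dict back.
result = json.loads(resp.json())
print(f"inference took {result['duration']:.3f} s")

Note that the reported duration wraps only the inference() call on the server, not the upload or the JSON encoding of the response.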