|
|
@ -1,4 +1,6 @@ |
|
|
|
import json |
|
|
|
import json |
|
|
|
|
|
|
|
from datetime import datetime |
|
|
|
|
|
|
|
from typing import Union, Literal |
|
|
|
from io import BytesIO |
|
|
|
from io import BytesIO |
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import numpy as np |
|
|
@ -16,10 +18,14 @@ app = FastAPI() |
|
|
|
|
|
|
|
|
|
|
|
reference_pattern_path = '/home/nils/kinect_reference_cropped.png' |
|
|
|
reference_pattern_path = '/home/nils/kinect_reference_cropped.png' |
|
|
|
reference_pattern = cv2.imread(reference_pattern_path) |
|
|
|
reference_pattern = cv2.imread(reference_pattern_path) |
|
|
|
model_path = "train_log/models/latest.pth" |
|
|
|
device = torch.device('cuda:0') |
|
|
|
# model_path = "train_log/models/epoch-100.pth" |
|
|
|
model = None |
|
|
|
device = torch.device('cuda') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(epoch): |
|
|
|
|
|
|
|
global model |
|
|
|
|
|
|
|
epoch = f'epoch-{epoch}.pth' if epoch != 'latest' else 'latest.pth' |
|
|
|
|
|
|
|
model_path = f"train_log/models/{epoch}" |
|
|
|
model = Model(max_disp=256, mixed_precision=False, test_mode=True) |
|
|
|
model = Model(max_disp=256, mixed_precision=False, test_mode=True) |
|
|
|
model = nn.DataParallel(model, device_ids=[device]) |
|
|
|
model = nn.DataParallel(model, device_ids=[device]) |
|
|
|
# model.load_state_dict(torch.load(model_path), strict=False) |
|
|
|
# model.load_state_dict(torch.load(model_path), strict=False) |
|
|
@ -27,6 +33,12 @@ state_dict = torch.load(model_path)['state_dict'] |
|
|
|
model.load_state_dict(state_dict, strict=True) |
|
|
|
model.load_state_dict(state_dict, strict=True) |
|
|
|
model.to(device) |
|
|
|
model.to(device) |
|
|
|
model.eval() |
|
|
|
model.eval() |
|
|
|
|
|
|
|
print(f'loaded model {epoch}') |
|
|
|
|
|
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = load_model('latest') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NumpyEncoder(json.JSONEncoder): |
|
|
|
class NumpyEncoder(json.JSONEncoder): |
|
|
|
def default(self, obj): |
|
|
|
def default(self, obj): |
|
|
@ -45,7 +57,6 @@ def inference(left, right, model, n_iter=20): |
|
|
|
imgL = torch.tensor(imgL.astype("float32")).to(device) |
|
|
|
imgL = torch.tensor(imgL.astype("float32")).to(device) |
|
|
|
imgR = torch.tensor(imgR.astype("float32")).to(device) |
|
|
|
imgR = torch.tensor(imgR.astype("float32")).to(device) |
|
|
|
|
|
|
|
|
|
|
|
# Funzt grob |
|
|
|
|
|
|
|
imgR = imgR.transpose(1,2) |
|
|
|
imgR = imgR.transpose(1,2) |
|
|
|
imgL = imgL.transpose(1,2) |
|
|
|
imgL = imgL.transpose(1,2) |
|
|
|
|
|
|
|
|
|
|
@ -71,6 +82,15 @@ def inference(left, right, model, n_iter=20): |
|
|
|
return pred_disp |
|
|
|
return pred_disp |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/model/update/{epoch}") |
|
|
|
|
|
|
|
async def change_model(epoch: Union[int, Literal['latest']]): |
|
|
|
|
|
|
|
global model |
|
|
|
|
|
|
|
print(epoch) |
|
|
|
|
|
|
|
print('updating model') |
|
|
|
|
|
|
|
model = load_model(epoch) |
|
|
|
|
|
|
|
return {'status': 'success'} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.put("/ir") |
|
|
|
@app.put("/ir") |
|
|
|
async def read_ir_input(file: UploadFile = File(...)): |
|
|
|
async def read_ir_input(file: UploadFile = File(...)): |
|
|
|
try: |
|
|
|
try: |
|
|
@ -89,11 +109,15 @@ async def read_ir_input(file: UploadFile = File(...)): |
|
|
|
img = img.transpose((1,2,0)) |
|
|
|
img = img.transpose((1,2,0)) |
|
|
|
ref_pat = reference_pattern.transpose((1,2,0)) |
|
|
|
ref_pat = reference_pattern.transpose((1,2,0)) |
|
|
|
|
|
|
|
|
|
|
|
pred_disp = inference(img, ref_pat, model) |
|
|
|
start = datetime.now() |
|
|
|
|
|
|
|
pred_disp = inference(img, ref_pat, model, 20) |
|
|
|
|
|
|
|
duration = (datetime.now() - start).total_seconds() |
|
|
|
|
|
|
|
|
|
|
|
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img}, cls=NumpyEncoder) |
|
|
|
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, cls=NumpyEncoder) |
|
|
|
|
|
|
|
# return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get('/') |
|
|
|
@app.get('/') |
|
|
|
def main(): |
|
|
|
def main(): |
|
|
|
return {'test': 'abc'} |
|
|
|
return {'test': 'abc'} |
|
|
|
|
|
|
|
|
|
|
|