|
|
@ -13,13 +13,13 @@ from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
from nets import Model |
|
|
|
from nets import Model |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI() |
|
|
|
app = FastAPI() |
|
|
|
|
|
|
|
|
|
|
|
reference_pattern_path = '/home/nils/kinect_reference_cropped.png' |
|
|
|
reference_pattern_path = '/home/nils/kinect_reference_cropped.png' |
|
|
|
reference_pattern = cv2.imread(reference_pattern_path) |
|
|
|
reference_pattern = cv2.imread(reference_pattern_path) |
|
|
|
|
|
|
|
iters = 20 |
|
|
|
|
|
|
|
minimal_data = False |
|
|
|
device = torch.device('cuda:0') |
|
|
|
device = torch.device('cuda:0') |
|
|
|
model = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(epoch): |
|
|
|
def load_model(epoch): |
|
|
@ -91,6 +91,18 @@ async def change_model(epoch: Union[int, Literal['latest']]): |
|
|
|
return {'status': 'success'} |
|
|
|
return {'status': 'success'} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/params/iterations/{iterations}") |
|
|
|
|
|
|
|
async def set_iterations(iterations: int): |
|
|
|
|
|
|
|
global iters |
|
|
|
|
|
|
|
iters = iterations |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/params/minimal_data/{enable}") |
|
|
|
|
|
|
|
async def set_minimal_data(enable: bool): |
|
|
|
|
|
|
|
global minimal_data |
|
|
|
|
|
|
|
minimal_data = enable |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.put("/ir") |
|
|
|
@app.put("/ir") |
|
|
|
async def read_ir_input(file: UploadFile = File(...)): |
|
|
|
async def read_ir_input(file: UploadFile = File(...)): |
|
|
|
try: |
|
|
|
try: |
|
|
@ -110,14 +122,16 @@ async def read_ir_input(file: UploadFile = File(...)): |
|
|
|
ref_pat = reference_pattern.transpose((1, 2, 0)) |
|
|
|
ref_pat = reference_pattern.transpose((1, 2, 0)) |
|
|
|
|
|
|
|
|
|
|
|
start = datetime.now() |
|
|
|
start = datetime.now() |
|
|
|
pred_disp = inference(img, ref_pat, model, 20) |
|
|
|
pred_disp = inference(img, ref_pat, model, iters) |
|
|
|
duration = (datetime.now() - start).total_seconds() |
|
|
|
duration = (datetime.now() - start).total_seconds() |
|
|
|
|
|
|
|
|
|
|
|
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, cls=NumpyEncoder) |
|
|
|
if minimal_data: |
|
|
|
# return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder) |
|
|
|
return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder) |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, |
|
|
|
|
|
|
|
cls=NumpyEncoder) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get('/') |
|
|
|
@app.get('/') |
|
|
|
def main(): |
|
|
|
def main(): |
|
|
|
return {'test': 'abc'} |
|
|
|
return {'test': 'abc'} |
|
|
|
|
|
|
|
|
|
|
|