diff --git a/frontend/__init__.py b/frontend/__init__.py index 4d67868..3940816 100644 --- a/frontend/__init__.py +++ b/frontend/__init__.py @@ -32,20 +32,23 @@ def normalize_and_colormap(img): def change_epoch(): - global r epoch = input('Enter epoch number or "latest"\n') r = requests.post(f'{API_URL}/model/update/{epoch}') print(r.text) -def extract_data(r): +def extract_data(data): # FIXME yuck, don't json the json - data = json.loads(json.loads(r.text)) - pred_disp = np.asarray(data['disp'], dtype='uint8') + duration = data['duration'] + + # get result and rotate 90 deg + pred_disp = cv2.transpose(np.asarray(data['disp'], dtype='uint8')) + + if input not in data: + return pred_disp, duration + in_img = np.asarray(data['input'], dtype='uint8').transpose((2, 0, 1)) ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2, 0, 1)) - duration = data['duration'] - pred_disp = cv2.transpose(pred_disp) return pred_disp, in_img, ref_pat, duration @@ -65,32 +68,34 @@ def put_image(img_path): r = requests.put(f'{API_URL}/ir', files=openBin) print('received response') r.raise_for_status() - return r + data = json.loads(json.loads(r.text)) + return data -while True: - for img in os.scandir(img_dir): - start = datetime.now() - if 'ir' not in img.path: - continue +if __name__ == '__main__': + while True: + for img in os.scandir(img_dir): + start = datetime.now() + if 'ir' not in img.path: + continue - # alternatively: use img.path for native size - downsize_input_img() + # alternatively: use img.path for native size + downsize_input_img() - r = put_image('buffer.png') - pred_disp, in_img, ref_pat, duration = extract_data(r) + data = put_image('buffer.png') + pred_disp, in_img, ref_pat, duration = extract_data(data) - print(f'inference took {duration:1.4f}s') - print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s') - print(f"Pred. Disparity: \n\t{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}\n") + print(f'inference took {duration:1.4f}s') + print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s') + print(f"Pred. Disparity: \n\t{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}\n") - cv2.imshow('Input Image', in_img) - # cv2.imshow('Reference Image', ref_pat) - cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp)) - cv2.imshow('Predicted Disparity', pred_disp) - key = cv2.waitKey() + cv2.imshow('Input Image', in_img) + # cv2.imshow('Reference Image', ref_pat) + cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp)) + cv2.imshow('Predicted Disparity', pred_disp) + key = cv2.waitKey() - if key == 113: - quit() - elif key == 101: - change_epoch() + if key == 113: + quit() + elif key == 101: + change_epoch()