From 3bc0e7d57595698b1ac0d7c276e57d35d1edcf93 Mon Sep 17 00:00:00 2001 From: Nils Koch Date: Thu, 2 Jun 2022 14:41:12 +0200 Subject: [PATCH] frontend/__init__.py: restructure --- frontend/__init__.py | 85 ++++++++++++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 27 deletions(-) diff --git a/frontend/__init__.py b/frontend/__init__.py index df4ce14..4d67868 100644 --- a/frontend/__init__.py +++ b/frontend/__init__.py @@ -6,7 +6,6 @@ import os from datetime import datetime - API_URL = 'http://127.0.0.1:8000' img_dir = '../../usable_imgs/' @@ -14,9 +13,17 @@ img_dir = '../../usable_imgs/' cv2.namedWindow('Input Image') cv2.namedWindow('Predicted Disparity') + # epoch 75 ist weird +class NumpyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) + + def normalize_and_colormap(img): ret = (img - img.min()) / (img.max() - img.min()) * 255.0 ret = ret.astype("uint8") @@ -24,42 +31,66 @@ def normalize_and_colormap(img): return ret +def change_epoch(): + global r + epoch = input('Enter epoch number or "latest"\n') + r = requests.post(f'{API_URL}/model/update/{epoch}') + print(r.text) + + +def extract_data(r): + # FIXME yuck, don't json the json + data = json.loads(json.loads(r.text)) + pred_disp = np.asarray(data['disp'], dtype='uint8') + in_img = np.asarray(data['input'], dtype='uint8').transpose((2, 0, 1)) + ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2, 0, 1)) + duration = data['duration'] + pred_disp = cv2.transpose(pred_disp) + return pred_disp, in_img, ref_pat, duration + + +def downsize_input_img(): + input_img = cv2.imread(img.path) + if input_img.shape == (1024, 1280, 3): + diff = (512 - 480) // 2 + downsampled = cv2.pyrDown(input_img) + input_img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]] + + cv2.imwrite('buffer.png', input_img) + + +def put_image(img_path): + openBin = {'file': ('file', open(img_path, 'rb'), 'image/png')} + print('sending image') + r = requests.put(f'{API_URL}/ir', files=openBin) + print('received response') + r.raise_for_status() + return r + + while True: for img in os.scandir(img_dir): start = datetime.now() if 'ir' not in img.path: continue - input_img = cv2.imread(img.path) - if input_img.shape == (1024, 1280, 3): - diff = (512 - 480) // 2 - downsampled = cv2.pyrDown(input_img) - input_img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]] - - openBin = {'file': ('file', open(img.path, 'rb'), 'image/png')} - - print('sending image') - r = requests.put(f'{API_URL}/ir', files=openBin) - print('received response') - r.raise_for_status() - data = json.loads(json.loads(r.text)) - - # FIXME yuck, don't json the json - pred_disp = np.asarray(data['disp'], dtype='uint8') - in_img = np.asarray(data['input'], dtype='uint8').transpose((2,0,1)) - ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2,0,1)).astype('uint8') - duration = data['duration'] - pred_disp = cv2.transpose(pred_disp) - print(f'inference took {duration}s') - print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration)}s\n') + + # alternatively: use img.path for native size + downsize_input_img() + + r = put_image('buffer.png') + pred_disp, in_img, ref_pat, duration = extract_data(r) + + print(f'inference took {duration:1.4f}s') + print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s') + print(f"Pred. Disparity: \n\t{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}\n") cv2.imshow('Input Image', in_img) - cv2.imshow('Reference Image', ref_pat) + # cv2.imshow('Reference Image', ref_pat) cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp)) cv2.imshow('Predicted Disparity', pred_disp) key = cv2.waitKey() + if key == 113: quit() elif key == 101: - epoch = input('Enter epoch number or "latest"\n') - r = requests.post(f'{API_URL}/model/update/{epoch}') - print(r.text) + change_epoch()