frontend/__init__.py: restructure

main
Nils Koch 2 years ago
parent b614fcfd74
commit 3bc0e7d575
  1. 85
      frontend/__init__.py

@ -6,7 +6,6 @@ import os
from datetime import datetime from datetime import datetime
API_URL = 'http://127.0.0.1:8000' API_URL = 'http://127.0.0.1:8000'
img_dir = '../../usable_imgs/' img_dir = '../../usable_imgs/'
@ -14,9 +13,17 @@ img_dir = '../../usable_imgs/'
cv2.namedWindow('Input Image') cv2.namedWindow('Input Image')
cv2.namedWindow('Predicted Disparity') cv2.namedWindow('Predicted Disparity')
# epoch 75 ist weird # epoch 75 ist weird
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)
def normalize_and_colormap(img): def normalize_and_colormap(img):
ret = (img - img.min()) / (img.max() - img.min()) * 255.0 ret = (img - img.min()) / (img.max() - img.min()) * 255.0
ret = ret.astype("uint8") ret = ret.astype("uint8")
@ -24,42 +31,66 @@ def normalize_and_colormap(img):
return ret return ret
def change_epoch():
global r
epoch = input('Enter epoch number or "latest"\n')
r = requests.post(f'{API_URL}/model/update/{epoch}')
print(r.text)
def extract_data(r):
# FIXME yuck, don't json the json
data = json.loads(json.loads(r.text))
pred_disp = np.asarray(data['disp'], dtype='uint8')
in_img = np.asarray(data['input'], dtype='uint8').transpose((2, 0, 1))
ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2, 0, 1))
duration = data['duration']
pred_disp = cv2.transpose(pred_disp)
return pred_disp, in_img, ref_pat, duration
def downsize_input_img():
input_img = cv2.imread(img.path)
if input_img.shape == (1024, 1280, 3):
diff = (512 - 480) // 2
downsampled = cv2.pyrDown(input_img)
input_img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
cv2.imwrite('buffer.png', input_img)
def put_image(img_path):
openBin = {'file': ('file', open(img_path, 'rb'), 'image/png')}
print('sending image')
r = requests.put(f'{API_URL}/ir', files=openBin)
print('received response')
r.raise_for_status()
return r
while True: while True:
for img in os.scandir(img_dir): for img in os.scandir(img_dir):
start = datetime.now() start = datetime.now()
if 'ir' not in img.path: if 'ir' not in img.path:
continue continue
input_img = cv2.imread(img.path)
if input_img.shape == (1024, 1280, 3): # alternatively: use img.path for native size
diff = (512 - 480) // 2 downsize_input_img()
downsampled = cv2.pyrDown(input_img)
input_img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]] r = put_image('buffer.png')
pred_disp, in_img, ref_pat, duration = extract_data(r)
openBin = {'file': ('file', open(img.path, 'rb'), 'image/png')}
print(f'inference took {duration:1.4f}s')
print('sending image') print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s')
r = requests.put(f'{API_URL}/ir', files=openBin) print(f"Pred. Disparity: \n\t{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}\n")
print('received response')
r.raise_for_status()
data = json.loads(json.loads(r.text))
# FIXME yuck, don't json the json
pred_disp = np.asarray(data['disp'], dtype='uint8')
in_img = np.asarray(data['input'], dtype='uint8').transpose((2,0,1))
ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2,0,1)).astype('uint8')
duration = data['duration']
pred_disp = cv2.transpose(pred_disp)
print(f'inference took {duration}s')
print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration)}s\n')
cv2.imshow('Input Image', in_img) cv2.imshow('Input Image', in_img)
cv2.imshow('Reference Image', ref_pat) # cv2.imshow('Reference Image', ref_pat)
cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp)) cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp))
cv2.imshow('Predicted Disparity', pred_disp) cv2.imshow('Predicted Disparity', pred_disp)
key = cv2.waitKey() key = cv2.waitKey()
if key == 113: if key == 113:
quit() quit()
elif key == 101: elif key == 101:
epoch = input('Enter epoch number or "latest"\n') change_epoch()
r = requests.post(f'{API_URL}/model/update/{epoch}')
print(r.text)

Loading…
Cancel
Save