|
|
|
@ -6,7 +6,6 @@ import os |
|
|
|
|
|
|
|
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
API_URL = 'http://127.0.0.1:8000' |
|
|
|
|
|
|
|
|
|
img_dir = '../../usable_imgs/' |
|
|
|
@ -14,9 +13,17 @@ img_dir = '../../usable_imgs/' |
|
|
|
|
cv2.namedWindow('Input Image') |
|
|
|
|
cv2.namedWindow('Predicted Disparity') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# epoch 75 ist weird |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NumpyEncoder(json.JSONEncoder): |
|
|
|
|
def default(self, obj): |
|
|
|
|
if isinstance(obj, np.ndarray): |
|
|
|
|
return obj.tolist() |
|
|
|
|
return json.JSONEncoder.default(self, obj) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_and_colormap(img): |
|
|
|
|
ret = (img - img.min()) / (img.max() - img.min()) * 255.0 |
|
|
|
|
ret = ret.astype("uint8") |
|
|
|
@ -24,42 +31,66 @@ def normalize_and_colormap(img): |
|
|
|
|
return ret |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
while True: |
|
|
|
|
for img in os.scandir(img_dir): |
|
|
|
|
start = datetime.now() |
|
|
|
|
if 'ir' not in img.path: |
|
|
|
|
continue |
|
|
|
|
def change_epoch(): |
|
|
|
|
global r |
|
|
|
|
epoch = input('Enter epoch number or "latest"\n') |
|
|
|
|
r = requests.post(f'{API_URL}/model/update/{epoch}') |
|
|
|
|
print(r.text) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_data(r): |
|
|
|
|
# FIXME yuck, don't json the json |
|
|
|
|
data = json.loads(json.loads(r.text)) |
|
|
|
|
pred_disp = np.asarray(data['disp'], dtype='uint8') |
|
|
|
|
in_img = np.asarray(data['input'], dtype='uint8').transpose((2, 0, 1)) |
|
|
|
|
ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2, 0, 1)) |
|
|
|
|
duration = data['duration'] |
|
|
|
|
pred_disp = cv2.transpose(pred_disp) |
|
|
|
|
return pred_disp, in_img, ref_pat, duration |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def downsize_input_img(): |
|
|
|
|
input_img = cv2.imread(img.path) |
|
|
|
|
if input_img.shape == (1024, 1280, 3): |
|
|
|
|
diff = (512 - 480) // 2 |
|
|
|
|
downsampled = cv2.pyrDown(input_img) |
|
|
|
|
input_img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]] |
|
|
|
|
|
|
|
|
|
openBin = {'file': ('file', open(img.path, 'rb'), 'image/png')} |
|
|
|
|
cv2.imwrite('buffer.png', input_img) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def put_image(img_path): |
|
|
|
|
openBin = {'file': ('file', open(img_path, 'rb'), 'image/png')} |
|
|
|
|
print('sending image') |
|
|
|
|
r = requests.put(f'{API_URL}/ir', files=openBin) |
|
|
|
|
print('received response') |
|
|
|
|
r.raise_for_status() |
|
|
|
|
data = json.loads(json.loads(r.text)) |
|
|
|
|
return r |
|
|
|
|
|
|
|
|
|
# FIXME yuck, don't json the json |
|
|
|
|
pred_disp = np.asarray(data['disp'], dtype='uint8') |
|
|
|
|
in_img = np.asarray(data['input'], dtype='uint8').transpose((2,0,1)) |
|
|
|
|
ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2,0,1)).astype('uint8') |
|
|
|
|
duration = data['duration'] |
|
|
|
|
pred_disp = cv2.transpose(pred_disp) |
|
|
|
|
print(f'inference took {duration}s') |
|
|
|
|
print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration)}s\n') |
|
|
|
|
|
|
|
|
|
while True: |
|
|
|
|
for img in os.scandir(img_dir): |
|
|
|
|
start = datetime.now() |
|
|
|
|
if 'ir' not in img.path: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
# alternatively: use img.path for native size |
|
|
|
|
downsize_input_img() |
|
|
|
|
|
|
|
|
|
r = put_image('buffer.png') |
|
|
|
|
pred_disp, in_img, ref_pat, duration = extract_data(r) |
|
|
|
|
|
|
|
|
|
print(f'inference took {duration:1.4f}s') |
|
|
|
|
print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s') |
|
|
|
|
print(f"Pred. Disparity: \n\t{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}\n") |
|
|
|
|
|
|
|
|
|
cv2.imshow('Input Image', in_img) |
|
|
|
|
cv2.imshow('Reference Image', ref_pat) |
|
|
|
|
# cv2.imshow('Reference Image', ref_pat) |
|
|
|
|
cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp)) |
|
|
|
|
cv2.imshow('Predicted Disparity', pred_disp) |
|
|
|
|
key = cv2.waitKey() |
|
|
|
|
|
|
|
|
|
if key == 113: |
|
|
|
|
quit() |
|
|
|
|
elif key == 101: |
|
|
|
|
epoch = input('Enter epoch number or "latest"\n') |
|
|
|
|
r = requests.post(f'{API_URL}/model/update/{epoch}') |
|
|
|
|
print(r.text) |
|
|
|
|
change_epoch() |
|
|
|
|