frontend/__init__.py: major refactor, still wip

now uses async to pipeline images for inference, triggered by signals, generates pointclouds from depth and displays them
main
Nils Koch 2 years ago
parent 1eefa2847b
commit ed80c3056f
  1. 330
      frontend/__init__.py

@ -1,4 +1,11 @@
import requests import signal
from time import sleep
# import requests
import httpx as requests
import asyncio
import open3d as o3d
from cv2 import cv2 from cv2 import cv2
import numpy as np import numpy as np
import json import json
@ -13,6 +20,27 @@ img_dir = '../../usable_imgs/'
cv2.namedWindow('Input Image') cv2.namedWindow('Input Image')
cv2.namedWindow('Predicted Disparity') cv2.namedWindow('Predicted Disparity')
vis = o3d.visualization.VisualizerWithKeyCallback()
viscont = o3d.visualization.ViewControl()
# vis.register_key_callback(99)
vis.create_window()
K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0, 0, 1]], dtype=np.float32)
# temporal_init = requests.get(f'{API_URL}/temporal_init')
good_models = [260, 183]
interesting = [214, ]
# new ganz gut bei ca 175
verbose = False
running_tasks = set()
minimal_data = False
with open('frontend.pid', 'w+') as f:
print('writing pid')
f.write(str(os.getpid()))
# epoch 75 ist weird # epoch 75 ist weird
@ -24,6 +52,34 @@ class NumpyEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, obj) return json.JSONEncoder.default(self, obj)
def update_vis(*args):
vis.poll_events()
vis.update_renderer()
# signal.signal(signal.SIGALRM, update_vis)
# signal.setitimer(signal.ITIMER_REAL, 0.1, 0.1)
def ghetto_lcn(img):
# gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray = img
float_gray = gray.astype(np.float32) / 255.0
blur = cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2)
num = float_gray - blur
blur = cv2.GaussianBlur(num * num, (0, 0), sigmaX=20, sigmaY=20)
den = cv2.pow(blur, 0.5)
gray = num / den
# cv2.normalize(gray, dst=gray, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX)
cv2.normalize(gray, dst=gray, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX)
return gray
def normalize_and_colormap(img): def normalize_and_colormap(img):
ret = (img - img.min()) / (img.max() - img.min()) * 255.0 ret = (img - img.min()) / (img.max() - img.min()) * 255.0
ret = ret.astype("uint8") ret = ret.astype("uint8")
@ -31,10 +87,74 @@ def normalize_and_colormap(img):
return ret return ret
def change_epoch(): def reproject(disparity_img):
epoch = input('Enter epoch number or "latest"\n') print('reprojecting')
baseline = 0.075
depth_img = baseline * K[0][0] / disparity_img
pointcloud = o3d.geometry.PointCloud()
intrinsics = o3d.pybind.camera.PinholeCameraIntrinsic()
print('setting intrinsics')
intrinsics.set_intrinsics(width=640, height=480, fx=K[0][0], fy=K[1][1], cx=0., cy=0.)
# depth = open3d.geometry.Image(depth_img.astype('float32'))
rgb = normalize_and_colormap(disparity_img)
rgb = o3d.geometry.Image(rgb * 255)
print(depth_img.max(), depth_img.min())
depth_img = np.log(depth_img + (1 - depth_img.min()) + 1)
print(depth_img.max(), depth_img.min())
depth = o3d.geometry.Image(depth_img.astype('float32'))
rgb_depth = o3d.geometry.RGBDImage().create_from_color_and_depth(
color=rgb,
depth=depth,
depth_scale=1,
convert_rgb_to_intensity=False,
)
print('creating pointcloud')
# depth = open3d.cpu.pybind.t.geometry.Image(depth_img.astype('float32'))
# depth.colorize_depth(1.0, 0., 1.)
# print('now really creating pointcloud')
# dpcd = pointcloud.create_from_depth_image(
# depth=depth,
# intrinsic=intrinsics,
# )
# print(type(depth))
pcd = pointcloud.create_from_rgbd_image(
image=rgb_depth,
intrinsic=intrinsics,
# project_valid_depth_only=False,
)
flip_transform = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
# dpcd.paint_uniform_color(np.asarray([0.5, 0.4, 0.25]))
pcd.transform(flip_transform)
# dpcd.transform(flip_transform)
pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
print('drawing pointcloud')
# vis.clear_geometries()
# vis.add_geometry(pcd, reset_bounding_box=True)
# viscont.rotate(250, 500)
vis.update_geometry(pcd)
vis.poll_events()
vis.update_renderer()
# vis.run()
# o3d.visualization.draw(geometry=[rgb_depth])
# o3d.visualization.draw([dpcd])
def change_epoch(epoch: int = None):
if epoch is None:
epoch = input('Enter epoch number or "latest"\n')
r = requests.post(f'{API_URL}/model/update/{epoch}') r = requests.post(f'{API_URL}/model/update/{epoch}')
print(r.text) # print(r.text)
def change_reference():
r = requests.post(f'{API_URL}/params/update_reference')
print(r.json()['status'])
if r.json()['status'] == 'finished':
change_reference()
def extract_data(data): def extract_data(data):
@ -42,72 +162,190 @@ def extract_data(data):
duration = data['duration'] duration = data['duration']
# get result and rotate 90 deg # get result and rotate 90 deg
pred_disp = cv2.transpose(np.asarray(data['disp'], dtype='uint8')) # pred_disp = cv2.transpose(np.asarray(data['disp'], dtype='uint8'))
raw_disp = np.asarray(data['disp'])
# print(raw_disp.min(), raw_disp.max())
if raw_disp.min() < 0:
# print('Negative disparity detected. shifting...')
raw_disp = raw_disp - raw_disp.min()
if raw_disp.max() > 255:
# print('Excessive disparity detected. scaling...')
raw_disp = raw_disp / (raw_disp.max() / 255)
pred_disp = np.asarray(raw_disp, dtype='uint8')
if 'input' not in data: # if 'input' not in data:
if len(data) == 2:
return pred_disp, duration return pred_disp, duration
ref_pat = data.get('reference', None)
in_img = np.asarray(data['input'], dtype='uint8').transpose((2, 0, 1)) in_img = np.asarray(data['input'], dtype='uint8') # .transpose((2, 0, 1))
ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2, 0, 1)) if ref_pat:
ref_pat = np.asarray(ref_pat, dtype='uint8') # .transpose((2, 0, 1))
return pred_disp, in_img, ref_pat, duration return pred_disp, in_img, ref_pat, duration
def downsize_input_img(): def downsize_input_img(path):
input_img = cv2.imread(img.path) input_img = None
while input_img is None:
input_img = cv2.imread(path)
if input_img.shape == (1024, 1280, 3): if input_img.shape == (1024, 1280, 3):
diff = (512 - 480) // 2 diff = (512 - 480) // 2
downsampled = cv2.pyrDown(input_img) downsampled = cv2.pyrDown(input_img)
input_img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]] input_img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
# print(input_img.shape)
input_img = cv2.normalize(input_img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
# input_img = ghetto_lcn(input_img)
cv2.imwrite('buffer.png', input_img) cv2.imwrite('buffer.png', input_img)
def put_image(img_path): async def put_image(img_path):
openBin = {'file': ('file', open(img_path, 'rb'), 'image/png')} openBin = {'file': ('file', open(img_path, 'rb'), 'image/png')}
print('sending image') if verbose:
r = requests.put(f'{API_URL}/ir', files=openBin) print('sending image')
print('received response') async with requests.AsyncClient() as client:
r = await client.put(f'{API_URL}/ir', files=openBin)
if verbose:
print('received response')
r.raise_for_status() r.raise_for_status()
data = json.loads(json.loads(r.text)) data = json.loads(json.loads(r.text))
return data return data
def change_minimal_data(enabled): def change_minimal_data(current: bool = None):
r = requests.post(f'{API_URL}/params/minimal_data/{not enabled}') global minimal_data
if current is None:
current = minimal_data
minimal_data = not current
r = requests.post(f'{API_URL}/params/minimal_data/{minimal_data}')
cv2.destroyWindow('Input Image') cv2.destroyWindow('Input Image')
cv2.destroyWindow('Reference Image') cv2.destroyWindow('Reference Image')
if __name__ == '__main__': def change_temporal_init(enabled):
while True: global temporal_init
for img in os.scandir(img_dir): r = requests.post(f'{API_URL}/params/temporal_init/{not enabled}')
start = datetime.now() temporal_init = not temporal_init
def handle_keypress(key):
if key == 113:
quit()
elif key == 101:
change_epoch()
elif key == 109:
change_minimal_data()
elif key == 116:
change_temporal_init(temporal_init)
elif key == 99:
change_reference()
async def do_inference():
start = datetime.now()
data = await put_image('buffer.png')
in_img = None
ref_pat = None
if len(data) == 4:
pred_disp, in_img, ref_pat, duration = extract_data(data)
elif len(data) == 2:
pred_disp, duration = extract_data(data)
reproject(pred_disp)
show_results(duration, in_img, pred_disp, ref_pat, start)
# reproject(pred_disp)
def show_results(duration, in_img, pred_disp, ref_pat, start):
print(f"Pred. Disparity: \n\t{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}")
if verbose:
print(f'inference took {duration:1.4f}s')
print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s')
print(f'total {(datetime.now() - start).total_seconds():1.4f}s')
if in_img is not None:
cv2.imshow('Input Image', in_img)
else:
cv2.imshow('Input Image', cv2.imread('buffer.png'))
if ref_pat is not None:
cv2.imshow('Reference Image', ref_pat)
cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp))
cv2.imshow('Predicted Disparity', pred_disp)
key = cv2.waitKey(1000)
handle_keypress(key)
async def fresh_img():
# print('running task')
start = datetime.now()
print(f'started at {start}')
downsize_input_img('kinect_ir.png')
await do_inference()
print(f'task took {(datetime.now() - start).total_seconds()}')
print()
def create_task(*args):
global running_tasks
# print('received signal')
print(f'currently running: {len(running_tasks)} tasks')
task = asyncio.create_task(fresh_img())
# print(f'created task {task.get_name()}')
running_tasks.add(task)
task.add_done_callback(running_tasks.discard)
# await task
# return task
signal.signal(signal.SIGUSR1, create_task)
async def run_test(img_dir, iterate_checkpoints):
img_dir = list(os.scandir(img_dir))
for epoch in range(175, 270):
if iterate_checkpoints:
change_epoch(epoch)
print()
print(f'loaded epoch {epoch}')
for img in img_dir:
if 'ir' not in img.path: if 'ir' not in img.path:
continue continue
# alternatively: use img.path for native size # alternatively: use img.path for native size
downsize_input_img() downsize_input_img(img.path)
data = put_image('buffer.png') # asyncio.run(do_inference())
if 'input' in data: await do_inference()
pred_disp, in_img, ref_pat, duration = extract_data(data) await asyncio.sleep(10)
else:
pred_disp, duration = extract_data(data)
async def main():
print(f'inference took {duration:1.4f}s') use_live_data = True
print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s') iterate_checkpoints = False
print(f"Pred. Disparity: \n\t{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}\n") # change_epoch(good_models[1])
# change_epoch('latest')
if 'input' in data: o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Info)
cv2.imshow('Input Image', in_img)
cv2.imshow('Reference Image', ref_pat) change_epoch(150)
cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp)) change_minimal_data(False)
cv2.imshow('Predicted Disparity', pred_disp) await asyncio.sleep(50000)
key = cv2.waitKey()
# signal.signal(signal.SIGBUS, lambda x: print('received sigbus'))
if key == 113: # loop = asyncio.get_running_loop()
quit() # loop.run_forever()
elif key == 101: # loop = asyncio.get_event_loop()
change_epoch() while True:
elif key == 109: # create_task()
change_minimal_data('input' not in data) # await asyncio.sleep(0.1)
# await run_test(img_dir, iterate_checkpoints)
await asyncio.sleep(50000)
# print('[main] slept')
# if use_live_data:
# signal.pause()
# else:
# await run_test(img_dir, iterate_checkpoints)
#
if __name__ == '__main__':
asyncio.run(main())

Loading…
Cancel
Save