diff --git a/frontend/__init__.py b/frontend/__init__.py
index ef58e2a..6f51964 100644
--- a/frontend/__init__.py
+++ b/frontend/__init__.py
@@ -1,4 +1,11 @@
-import requests
+import signal
+from time import sleep
+
+# import requests
+import httpx as requests
+import asyncio
+import open3d as o3d
+
 from cv2 import cv2
 import numpy as np
 import json
@@ -13,6 +20,28 @@ img_dir = '../../usable_imgs/'
 cv2.namedWindow('Input Image')
 cv2.namedWindow('Predicted Disparity')
 
+vis = o3d.visualization.VisualizerWithKeyCallback()
+viscont = o3d.visualization.ViewControl()
+# vis.register_key_callback(99)
+vis.create_window()
+
+K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0, 0, 1]], dtype=np.float32)
+
+# temporal_init = requests.get(f'{API_URL}/temporal_init')
+
+good_models = [260, 183]
+interesting = [214, ]
+# new checkpoints are quite good at around epoch 175
+verbose = False
+
+running_tasks = set()
+minimal_data = False
+temporal_init = False  # flipped by change_temporal_init() via the 't' key
+
+with open('frontend.pid', 'w+') as f:
+    print('writing pid')
+    f.write(str(os.getpid()))
+
 # epoch 75 is weird
 
 
@@ -24,6 +53,35 @@ class NumpyEncoder(json.JSONEncoder):
         return json.JSONEncoder.default(self, obj)
 
 
+def update_vis(*args):
+    vis.poll_events()
+    vis.update_renderer()
+
+
+# signal.signal(signal.SIGALRM, update_vis)
+# signal.setitimer(signal.ITIMER_REAL, 0.1, 0.1)
+
+
+def ghetto_lcn(img):
+    # local contrast normalization: subtract the local mean, divide by the local std
+    # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    gray = img
+
+    float_gray = gray.astype(np.float32) / 255.0
+
+    blur = cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2)
+    num = float_gray - blur
+
+    blur = cv2.GaussianBlur(num * num, (0, 0), sigmaX=20, sigmaY=20)
+    den = cv2.pow(blur, 0.5)
+
+    gray = num / den
+
+    # cv2.normalize(gray, dst=gray, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX)
+    cv2.normalize(gray, dst=gray, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX)
+    return gray
+
+
 def normalize_and_colormap(img):
     ret = (img - img.min()) / (img.max() - img.min()) * 255.0
     ret = ret.astype("uint8")
@@ -31,10 +89,77 @@ def normalize_and_colormap(img):
     return ret
 
 
-def change_epoch():
-    epoch = input('Enter epoch number or "latest"\n')
+def reproject(disparity_img):
+    print('reprojecting')
+    baseline = 0.075
+    depth_img = baseline * K[0][0] / disparity_img
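+    # Pinhole stereo geometry: depth = baseline * fx / disparity, with the
+    # baseline presumably in metres and fx in pixels, so e.g. a disparity of
+    # 30 px corresponds to roughly 0.075 * 567.6 / 30 ≈ 1.42 m.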
+    intrinsics = o3d.camera.PinholeCameraIntrinsic()
+    print('setting intrinsics')
+    intrinsics.set_intrinsics(width=640, height=480, fx=K[0][0], fy=K[1][1], cx=0., cy=0.)
+    # depth = open3d.geometry.Image(depth_img.astype('float32'))
+    rgb = normalize_and_colormap(disparity_img)
+    rgb = o3d.geometry.Image(rgb)  # already uint8 in [0, 255]; multiplying by 255 would wrap around
+    print(depth_img.max(), depth_img.min())
+    depth_img = np.log(depth_img + (1 - depth_img.min()) + 1)
+    print(depth_img.max(), depth_img.min())
+    depth = o3d.geometry.Image(depth_img.astype('float32'))
+
+    rgb_depth = o3d.geometry.RGBDImage.create_from_color_and_depth(
+        color=rgb,
+        depth=depth,
+        depth_scale=1,
+        convert_rgb_to_intensity=False,
+    )
+    print('creating pointcloud')
+    # depth = open3d.cpu.pybind.t.geometry.Image(depth_img.astype('float32'))
+    # depth.colorize_depth(1.0, 0., 1.)
+    # print('now really creating pointcloud')
+    # dpcd = o3d.geometry.PointCloud.create_from_depth_image(
+    #     depth=depth,
+    #     intrinsic=intrinsics,
+    # )
+    # print(type(depth))
+
+    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
+        image=rgb_depth,
+        intrinsic=intrinsics,
+        # project_valid_depth_only=False,
+    )
+
+    flip_transform = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+    # dpcd.paint_uniform_color(np.asarray([0.5, 0.4, 0.25]))
+    pcd.transform(flip_transform)
+    # dpcd.transform(flip_transform)
+    pcd, _ = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)  # returns (cloud, inlier indices)
+    print('drawing pointcloud')
+    # update_geometry() only refreshes geometries added earlier, and this pcd
+    # is a fresh object on every call, so clear and re-add instead.
+    vis.clear_geometries()
+    vis.add_geometry(pcd, reset_bounding_box=True)
+    # viscont.rotate(250, 500)
+    vis.poll_events()
+    vis.update_renderer()
+
+    # vis.run()
+    # o3d.visualization.draw(geometry=[rgb_depth])
+    # o3d.visualization.draw([dpcd])
+
+
+def change_epoch(epoch: int = None):
+    if epoch is None:
+        epoch = input('Enter epoch number or "latest"\n')
     r = requests.post(f'{API_URL}/model/update/{epoch}')
-    print(r.text)
+    # print(r.text)
+
+
+def change_reference():
+    r = requests.post(f'{API_URL}/params/update_reference')
+    print(r.json()['status'])
+    if r.json()['status'] == 'finished':
+        change_reference()
 
 
 def extract_data(data):
@@ -42,72 +167,196 @@ def extract_data(data):
     duration = data['duration']
 
     # get result and rotate 90 deg
-    pred_disp = cv2.transpose(np.asarray(data['disp'], dtype='uint8'))
+    # pred_disp = cv2.transpose(np.asarray(data['disp'], dtype='uint8'))
+    raw_disp = np.asarray(data['disp'])
+    # print(raw_disp.min(), raw_disp.max())
+    if raw_disp.min() < 0:
+        # print('Negative disparity detected. shifting...')
+        raw_disp = raw_disp - raw_disp.min()
+    if raw_disp.max() > 255:
+        # print('Excessive disparity detected. scaling...')
+        raw_disp = raw_disp / (raw_disp.max() / 255)
+    pred_disp = np.asarray(raw_disp, dtype='uint8')
 
-    if 'input' not in data:
+    # if 'input' not in data:
+    if len(data) == 2:
         return pred_disp, duration
-
-    in_img = np.asarray(data['input'], dtype='uint8').transpose((2, 0, 1))
-    ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2, 0, 1))
+    ref_pat = data.get('reference', None)
+    in_img = np.asarray(data['input'], dtype='uint8')  # .transpose((2, 0, 1))
+    if ref_pat:
+        ref_pat = np.asarray(ref_pat, dtype='uint8')  # .transpose((2, 0, 1))
     return pred_disp, in_img, ref_pat, duration
 
 
-def downsize_input_img():
-    input_img = cv2.imread(img.path)
+def downsize_input_img(path):
+    input_img = cv2.imread(path)
+    while input_img is None:  # the capture process may still be writing the file
+        sleep(0.01)
+        input_img = cv2.imread(path)
     if input_img.shape == (1024, 1280, 3):
         diff = (512 - 480) // 2
         downsampled = cv2.pyrDown(input_img)
         input_img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
+    # print(input_img.shape)
+    input_img = cv2.normalize(input_img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
+    # input_img = ghetto_lcn(input_img)
     cv2.imwrite('buffer.png', input_img)
 
 
-def put_image(img_path):
+async def put_image(img_path):
     openBin = {'file': ('file', open(img_path, 'rb'), 'image/png')}
-    print('sending image')
-    r = requests.put(f'{API_URL}/ir', files=openBin)
-    print('received response')
+    if verbose:
+        print('sending image')
+    async with requests.AsyncClient() as client:
+        r = await client.put(f'{API_URL}/ir', files=openBin)
+    if verbose:
+        print('received response')
     r.raise_for_status()
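+    # the backend evidently double-encodes its payload (a JSON string that
+    # itself contains JSON), hence the nested json.loads()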
     data = json.loads(json.loads(r.text))
     return data
 
 
-def change_minimal_data(enabled):
-    r = requests.post(f'{API_URL}/params/minimal_data/{not enabled}')
+def change_minimal_data(current: bool = None):
+    global minimal_data
+    if current is None:
+        current = minimal_data
+    minimal_data = not current
+    r = requests.post(f'{API_URL}/params/minimal_data/{minimal_data}')
     cv2.destroyWindow('Input Image')
     cv2.destroyWindow('Reference Image')
 
 
-if __name__ == '__main__':
-    while True:
-        for img in os.scandir(img_dir):
-            start = datetime.now()
+def change_temporal_init(enabled):
+    global temporal_init
+    r = requests.post(f'{API_URL}/params/temporal_init/{not enabled}')
+    temporal_init = not temporal_init
+
+
+def handle_keypress(key):
+    if key == 113:  # 'q'
+        quit()
+    elif key == 101:  # 'e'
+        change_epoch()
+    elif key == 109:  # 'm'
+        change_minimal_data()
+    elif key == 116:  # 't'
+        change_temporal_init(temporal_init)
+    elif key == 99:  # 'c'
+        change_reference()
+
+
+async def do_inference():
+    start = datetime.now()
+    data = await put_image('buffer.png')
+    in_img = None
+    ref_pat = None
+    if len(data) == 4:
+        pred_disp, in_img, ref_pat, duration = extract_data(data)
+    elif len(data) == 2:
+        pred_disp, duration = extract_data(data)
+    else:
+        return  # unexpected payload; pred_disp would be unbound below
+    reproject(pred_disp)
+    show_results(duration, in_img, pred_disp, ref_pat, start)
+    # reproject(pred_disp)
+
+
+def show_results(duration, in_img, pred_disp, ref_pat, start):
+    print(f"Pred. Disparity: \n\t{pred_disp.min():.2f}/{pred_disp.max():.2f}")
+    if verbose:
+        print(f'inference took {duration:1.4f}s')
+        print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s')
+        print(f'total {(datetime.now() - start).total_seconds():1.4f}s')
+    if in_img is not None:
+        cv2.imshow('Input Image', in_img)
+    else:
+        cv2.imshow('Input Image', cv2.imread('buffer.png'))
+    if ref_pat is not None:
+        cv2.imshow('Reference Image', ref_pat)
+    cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp))
+    cv2.imshow('Predicted Disparity', pred_disp)
+    key = cv2.waitKey(1000)
+    handle_keypress(key)
+
+
+async def fresh_img():
+    # print('running task')
+    start = datetime.now()
+    print(f'started at {start}')
+    downsize_input_img('kinect_ir.png')
+    await do_inference()
+    print(f'task took {(datetime.now() - start).total_seconds()}')
+    print()
+
+
+def create_task(*args):
+    global running_tasks
+    # print('received signal')
+    print(f'currently running: {len(running_tasks)} tasks')
+    task = asyncio.create_task(fresh_img())
+    # print(f'created task {task.get_name()}')
+    running_tasks.add(task)
+    task.add_done_callback(running_tasks.discard)
+    # await task
+    # return task
+
+
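+# NOTE: loop.add_signal_handler(signal.SIGUSR1, create_task) would be the
+# asyncio-native way to register this; plain signal.signal() appears to work
+# because the handler runs on the main thread alongside the event loop.
+# Another process can trigger a frame with: kill -USR1 "$(cat frontend.pid)"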
+signal.signal(signal.SIGUSR1, create_task)
+
+
+async def run_test(img_dir, iterate_checkpoints):
+    img_dir = list(os.scandir(img_dir))
+    for epoch in range(175, 270):
+        if iterate_checkpoints:
+            change_epoch(epoch)
+            print()
+            print(f'loaded epoch {epoch}')
+        for img in img_dir:
             if 'ir' not in img.path:
                 continue
             # alternatively: use img.path for native size
-            downsize_input_img()
-
-            data = put_image('buffer.png')
-            if 'input' in data:
-                pred_disp, in_img, ref_pat, duration = extract_data(data)
-            else:
-                pred_disp, duration = extract_data(data)
-
-            print(f'inference took {duration:1.4f}s')
-            print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s')
-            print(f"Pred. Disparity: \n\t{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}\n")
-
-            if 'input' in data:
-                cv2.imshow('Input Image', in_img)
-                cv2.imshow('Reference Image', ref_pat)
-            cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp))
-            cv2.imshow('Predicted Disparity', pred_disp)
-            key = cv2.waitKey()
-
-            if key == 113:
-                quit()
-            elif key == 101:
-                change_epoch()
-            elif key == 109:
-                change_minimal_data('input' not in data)
+            downsize_input_img(img.path)
+
+            # asyncio.run(do_inference())
+            await do_inference()
+            await asyncio.sleep(10)
+
+
+async def main():
+    use_live_data = True  # currently unused; see the commented signal.pause() branch below
+    iterate_checkpoints = False
+    # change_epoch(good_models[1])
+    # change_epoch('latest')
+    o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Info)
+
+    change_epoch(150)
+    change_minimal_data(False)
+    await asyncio.sleep(50000)
+
+    # signal.signal(signal.SIGBUS, lambda x: print('received sigbus'))
+    # loop = asyncio.get_running_loop()
+    # loop.run_forever()
+    # loop = asyncio.get_event_loop()
+    while True:
+        # create_task()
+        # await asyncio.sleep(0.1)
+        # await run_test(img_dir, iterate_checkpoints)
+        await asyncio.sleep(50000)
+        # print('[main] slept')
+        # if use_live_data:
+        #     signal.pause()
+
+        # else:
+        #     await run_test(img_dir, iterate_checkpoints)
+
+
+if __name__ == '__main__':
+    asyncio.run(main())