From 43df77fb9bbdec58cf4fb2e8b1f69d9c7800de9e Mon Sep 17 00:00:00 2001 From: CptCaptain Date: Mon, 15 Nov 2021 16:53:30 +0100 Subject: [PATCH] Reformat $EVERYTHING --- co/__init__.py | 3 +- co/args.py | 13 +- co/cmap.py | 61 +- co/geometry.py | 1336 ++++++++++++++++--------------- co/gtimer.py | 42 +- co/io3d.py | 471 +++++------ co/metric.py | 426 +++++----- co/plt.py | 161 ++-- co/plt2d.py | 95 +-- co/plt3d.py | 68 +- co/table.py | 804 ++++++++++--------- co/utils.py | 108 +-- data/commons.py | 138 ++-- data/create_syn_data.py | 454 ++++++----- data/dataset.py | 238 +++--- data/lcn/lcn.html | 395 ++++++--- data/lcn/setup.py | 2 +- data/lcn/test_lcn.py | 26 +- hyperdepth/hyperparam_search.py | 65 +- hyperdepth/setup.py | 37 +- hyperdepth/vis_eval.py | 13 +- model/exp_synph.py | 475 ++++++----- model/exp_synphge.py | 597 +++++++------- model/networks.py | 953 +++++++++++----------- readme.md | 68 +- renderer/setup.py | 36 +- torchext/dataset.py | 92 +-- torchext/functions.py | 233 +++--- torchext/modules.py | 36 +- torchext/setup.py | 15 +- torchext/worker.py | 960 +++++++++++----------- train_val.py | 15 +- 32 files changed, 4429 insertions(+), 4007 deletions(-) diff --git a/co/__init__.py b/co/__init__.py index 3dcfa5d..f256f9b 100644 --- a/co/__init__.py +++ b/co/__init__.py @@ -7,8 +7,9 @@ # set matplotlib backend depending on env import os import matplotlib + if os.name == 'posix' and "DISPLAY" not in os.environ: - matplotlib.use('Agg') + matplotlib.use('Agg') from . import geometry from . import plt diff --git a/co/args.py b/co/args.py index c55ba7b..19fe75e 100644 --- a/co/args.py +++ b/co/args.py @@ -12,14 +12,14 @@ def parse_args(): parser.add_argument('--loss', help='Train with \'ph\' for the first stage without geometric loss, \ train with \'phge\' for the second stage with geometric loss', - default='ph', choices=['ph','phge'], type=str) + default='ph', choices=['ph', 'phge'], type=str) parser.add_argument('--data_type', default='syn', choices=['syn'], type=str) # - parser.add_argument('--cmd', - help='Start training or test', + parser.add_argument('--cmd', + help='Start training or test', default='resume', choices=['retrain', 'resume', 'retest', 'test_init'], type=str) - parser.add_argument('--epoch', + parser.add_argument('--epoch', help='If larger than -1, retest on the specified epoch', default=-1, type=int) parser.add_argument('--epochs', @@ -55,7 +55,7 @@ def parse_args(): parser.add_argument('--blend_im', help='Parameter for adding texture', default=0.6, type=float) - + args = parser.parse_args() args.exp_name = get_exp_name(args) @@ -66,6 +66,3 @@ def parse_args(): def get_exp_name(args): name = f"exp_{args.data_type}" return name - - - diff --git a/co/cmap.py b/co/cmap.py index 58a7e17..78bfb67 100644 --- a/co/cmap.py +++ b/co/cmap.py @@ -1,19 +1,20 @@ import numpy as np _color_map_errors = np.array([ - [149, 54, 49], #0: log2(x) = -infinity - [180, 117, 69], #0.0625: log2(x) = -4 - [209, 173, 116], #0.125: log2(x) = -3 - [233, 217, 171], #0.25: log2(x) = -2 - [248, 243, 224], #0.5: log2(x) = -1 - [144, 224, 254], #1.0: log2(x) = 0 - [97, 174, 253], #2.0: log2(x) = 1 - [67, 109, 244], #4.0: log2(x) = 2 - [39, 48, 215], #8.0: log2(x) = 3 - [38, 0, 165], #16.0: log2(x) = 4 - [38, 0, 165] #inf: log2(x) = inf + [149, 54, 49], # 0: log2(x) = -infinity + [180, 117, 69], # 0.0625: log2(x) = -4 + [209, 173, 116], # 0.125: log2(x) = -3 + [233, 217, 171], # 0.25: log2(x) = -2 + [248, 243, 224], # 0.5: log2(x) = -1 + [144, 224, 254], # 1.0: log2(x) = 0 + [97, 174, 253], # 2.0: log2(x) = 1 + [67, 109, 244], # 4.0: log2(x) = 2 + [39, 48, 215], # 8.0: log2(x) = 3 + [38, 0, 165], # 16.0: log2(x) = 4 + [38, 0, 165] # inf: log2(x) = inf ]).astype(float) + def color_error_image(errors, scale=1, mask=None, BGR=True): """ Color an input error map. @@ -27,31 +28,33 @@ def color_error_image(errors, scale=1, mask=None, BGR=True): Returns: colored_errors -- HxWx3 numpy array visualizing the errors """ - + errors_flat = errors.flatten() errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9) i0 = np.floor(errors_color_indices).astype(int) f1 = errors_color_indices - i0.astype(float) - colored_errors_flat = _color_map_errors[i0, :] * (1-f1).reshape(-1,1) + _color_map_errors[i0+1, :] * f1.reshape(-1,1) + colored_errors_flat = _color_map_errors[i0, :] * (1 - f1).reshape(-1, 1) + _color_map_errors[i0 + 1, + :] * f1.reshape(-1, 1) if mask is not None: colored_errors_flat[mask.flatten() == 0] = 255 if not BGR: - colored_errors_flat = colored_errors_flat[:,[2,1,0]] + colored_errors_flat = colored_errors_flat[:, [2, 1, 0]] return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int) + _color_map_depths = np.array([ - [0, 0, 0], # 0.000 - [0, 0, 255], # 0.114 - [255, 0, 0], # 0.299 - [255, 0, 255], # 0.413 - [0, 255, 0], # 0.587 - [0, 255, 255], # 0.701 - [255, 255, 0], # 0.886 - [255, 255, 255], # 1.000 - [255, 255, 255], # 1.000 + [0, 0, 0], # 0.000 + [0, 0, 255], # 0.114 + [255, 0, 0], # 0.299 + [255, 0, 255], # 0.413 + [0, 255, 0], # 0.587 + [0, 255, 255], # 0.701 + [255, 255, 0], # 0.886 + [255, 255, 255], # 1.000 + [255, 255, 255], # 1.000 ]).astype(float) _color_map_bincenters = np.array([ 0.0, @@ -62,9 +65,10 @@ _color_map_bincenters = np.array([ 0.701, 0.886, 1.000, - 2.000, # doesn't make a difference, just strictly higher than 1 + 2.000, # doesn't make a difference, just strictly higher than 1 ]) + def color_depth_map(depths, scale=None): """ Color an input depth map. @@ -82,12 +86,13 @@ def color_depth_map(depths, scale=None): values = np.clip(depths.flatten() / scale, 0, 1) # for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value? - lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1,-1)) * np.arange(0,9)).max(axis=1) + lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1, -1)) * np.arange(0, 9)).max(axis=1) lower_bin_value = _color_map_bincenters[lower_bin] higher_bin_value = _color_map_bincenters[lower_bin + 1] alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value) - colors = _color_map_depths[lower_bin] * (1-alphas).reshape(-1,1) + _color_map_depths[lower_bin + 1] * alphas.reshape(-1,1) + colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[ + lower_bin + 1] * alphas.reshape(-1, 1) return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8) -#from utils.debug import save_color_numpy -#save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000)) +# from utils.debug import save_color_numpy +# save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000)) diff --git a/co/geometry.py b/co/geometry.py index 5cf194e..6295dc8 100644 --- a/co/geometry.py +++ b/co/geometry.py @@ -2,799 +2,849 @@ import numpy as np import scipy.spatial import scipy.linalg + def nullspace(A, atol=1e-13, rtol=0): - u, s, vh = np.linalg.svd(A) - tol = max(atol, rtol * s[0]) - nnz = (s >= tol).sum() - ns = vh[nnz:].conj().T - return ns + u, s, vh = np.linalg.svd(A) + tol = max(atol, rtol * s[0]) + nnz = (s >= tol).sum() + ns = vh[nnz:].conj().T + return ns + def nearest_orthogonal_matrix(R): - U,S,Vt = np.linalg.svd(R) - return U @ np.eye(3,dtype=R.dtype) @ Vt + U, S, Vt = np.linalg.svd(R) + return U @ np.eye(3, dtype=R.dtype) @ Vt + def power_iters(A, n_iters=10): - b = np.random.uniform(-1,1, size=(A.shape[0], A.shape[1], 1)) - for iter in range(n_iters): - b = A @ b - b = b / np.linalg.norm(b, axis=1, keepdims=True) - return b + b = np.random.uniform(-1, 1, size=(A.shape[0], A.shape[1], 1)) + for iter in range(n_iters): + b = A @ b + b = b / np.linalg.norm(b, axis=1, keepdims=True) + return b + def rayleigh_quotient(A, b): - return (b.transpose(0,2,1) @ A @ b) / (b.transpose(0,2,1) @ b) + return (b.transpose(0, 2, 1) @ A @ b) / (b.transpose(0, 2, 1) @ b) def cross_prod_mat(x): - x = x.reshape(-1,3) - X = np.empty((x.shape[0],3,3), dtype=x.dtype) - X[:,0,0] = 0 - X[:,0,1] = -x[:,2] - X[:,0,2] = x[:,1] - X[:,1,0] = x[:,2] - X[:,1,1] = 0 - X[:,1,2] = -x[:,0] - X[:,2,0] = -x[:,1] - X[:,2,1] = x[:,0] - X[:,2,2] = 0 - return X.squeeze() + x = x.reshape(-1, 3) + X = np.empty((x.shape[0], 3, 3), dtype=x.dtype) + X[:, 0, 0] = 0 + X[:, 0, 1] = -x[:, 2] + X[:, 0, 2] = x[:, 1] + X[:, 1, 0] = x[:, 2] + X[:, 1, 1] = 0 + X[:, 1, 2] = -x[:, 0] + X[:, 2, 0] = -x[:, 1] + X[:, 2, 1] = x[:, 0] + X[:, 2, 2] = 0 + return X.squeeze() + def hat_operator(x): - return cross_prod_mat(x) + return cross_prod_mat(x) + def vee_operator(X): - X = X.reshape(-1,3,3) - x = np.empty((X.shape[0], 3), dtype=X.dtype) - x[:,0] = X[:,2,1] - x[:,1] = X[:,0,2] - x[:,2] = X[:,1,0] - return x.squeeze() + X = X.reshape(-1, 3, 3) + x = np.empty((X.shape[0], 3), dtype=X.dtype) + x[:, 0] = X[:, 2, 1] + x[:, 1] = X[:, 0, 2] + x[:, 2] = X[:, 1, 0] + return x.squeeze() def rot_x(x, dtype=np.float32): - x = np.array(x, copy=False) - x = x.reshape(-1,1) - R = np.zeros((x.shape[0],3,3), dtype=dtype) - R[:,0,0] = 1 - R[:,1,1] = np.cos(x).ravel() - R[:,1,2] = -np.sin(x).ravel() - R[:,2,1] = np.sin(x).ravel() - R[:,2,2] = np.cos(x).ravel() - return R.squeeze() + x = np.array(x, copy=False) + x = x.reshape(-1, 1) + R = np.zeros((x.shape[0], 3, 3), dtype=dtype) + R[:, 0, 0] = 1 + R[:, 1, 1] = np.cos(x).ravel() + R[:, 1, 2] = -np.sin(x).ravel() + R[:, 2, 1] = np.sin(x).ravel() + R[:, 2, 2] = np.cos(x).ravel() + return R.squeeze() + def rot_y(y, dtype=np.float32): - y = np.array(y, copy=False) - y = y.reshape(-1,1) - R = np.zeros((y.shape[0],3,3), dtype=dtype) - R[:,0,0] = np.cos(y).ravel() - R[:,0,2] = np.sin(y).ravel() - R[:,1,1] = 1 - R[:,2,0] = -np.sin(y).ravel() - R[:,2,2] = np.cos(y).ravel() - return R.squeeze() + y = np.array(y, copy=False) + y = y.reshape(-1, 1) + R = np.zeros((y.shape[0], 3, 3), dtype=dtype) + R[:, 0, 0] = np.cos(y).ravel() + R[:, 0, 2] = np.sin(y).ravel() + R[:, 1, 1] = 1 + R[:, 2, 0] = -np.sin(y).ravel() + R[:, 2, 2] = np.cos(y).ravel() + return R.squeeze() + def rot_z(z, dtype=np.float32): - z = np.array(z, copy=False) - z = z.reshape(-1,1) - R = np.zeros((z.shape[0],3,3), dtype=dtype) - R[:,0,0] = np.cos(z).ravel() - R[:,0,1] = -np.sin(z).ravel() - R[:,1,0] = np.sin(z).ravel() - R[:,1,1] = np.cos(z).ravel() - R[:,2,2] = 1 - return R.squeeze() + z = np.array(z, copy=False) + z = z.reshape(-1, 1) + R = np.zeros((z.shape[0], 3, 3), dtype=dtype) + R[:, 0, 0] = np.cos(z).ravel() + R[:, 0, 1] = -np.sin(z).ravel() + R[:, 1, 0] = np.sin(z).ravel() + R[:, 1, 1] = np.cos(z).ravel() + R[:, 2, 2] = 1 + return R.squeeze() + def xyz_from_rotm(R): - R = R.reshape(-1,3,3) - xyz = np.empty((R.shape[0],3), dtype=R.dtype) - for bidx in range(R.shape[0]): - if R[bidx,0,2] < 1: - if R[bidx,0,2] > -1: - xyz[bidx,1] = np.arcsin(R[bidx,0,2]) - xyz[bidx,0] = np.arctan2(-R[bidx,1,2], R[bidx,2,2]) - xyz[bidx,2] = np.arctan2(-R[bidx,0,1], R[bidx,0,0]) - else: - xyz[bidx,1] = -np.pi/2 - xyz[bidx,0] = -np.arctan2(R[bidx,1,0],R[bidx,1,1]) - xyz[bidx,2] = 0 - else: - xyz[bidx,1] = np.pi/2 - xyz[bidx,0] = np.arctan2(R[bidx,1,0], R[bidx,1,1]) - xyz[bidx,2] = 0 - return xyz.squeeze() + R = R.reshape(-1, 3, 3) + xyz = np.empty((R.shape[0], 3), dtype=R.dtype) + for bidx in range(R.shape[0]): + if R[bidx, 0, 2] < 1: + if R[bidx, 0, 2] > -1: + xyz[bidx, 1] = np.arcsin(R[bidx, 0, 2]) + xyz[bidx, 0] = np.arctan2(-R[bidx, 1, 2], R[bidx, 2, 2]) + xyz[bidx, 2] = np.arctan2(-R[bidx, 0, 1], R[bidx, 0, 0]) + else: + xyz[bidx, 1] = -np.pi / 2 + xyz[bidx, 0] = -np.arctan2(R[bidx, 1, 0], R[bidx, 1, 1]) + xyz[bidx, 2] = 0 + else: + xyz[bidx, 1] = np.pi / 2 + xyz[bidx, 0] = np.arctan2(R[bidx, 1, 0], R[bidx, 1, 1]) + xyz[bidx, 2] = 0 + return xyz.squeeze() + def zyx_from_rotm(R): - R = R.reshape(-1,3,3) - zyx = np.empty((R.shape[0],3), dtype=R.dtype) - for bidx in range(R.shape[0]): - if R[bidx,2,0] < 1: - if R[bidx,2,0] > -1: - zyx[bidx,1] = np.arcsin(-R[bidx,2,0]) - zyx[bidx,0] = np.arctan2(R[bidx,1,0], R[bidx,0,0]) - zyx[bidx,2] = np.arctan2(R[bidx,2,1], R[bidx,2,2]) - else: - zyx[bidx,1] = np.pi / 2 - zyx[bidx,0] = -np.arctan2(-R[bidx,1,2], R[bidx,1,1]) - zyx[bidx,2] = 0 - else: - zyx[bidx,1] = -np.pi / 2 - zyx[bidx,0] = np.arctan2(-R[bidx,1,2], R[bidx,1,1]) - zyx[bidx,2] = 0 - return zyx.squeeze() + R = R.reshape(-1, 3, 3) + zyx = np.empty((R.shape[0], 3), dtype=R.dtype) + for bidx in range(R.shape[0]): + if R[bidx, 2, 0] < 1: + if R[bidx, 2, 0] > -1: + zyx[bidx, 1] = np.arcsin(-R[bidx, 2, 0]) + zyx[bidx, 0] = np.arctan2(R[bidx, 1, 0], R[bidx, 0, 0]) + zyx[bidx, 2] = np.arctan2(R[bidx, 2, 1], R[bidx, 2, 2]) + else: + zyx[bidx, 1] = np.pi / 2 + zyx[bidx, 0] = -np.arctan2(-R[bidx, 1, 2], R[bidx, 1, 1]) + zyx[bidx, 2] = 0 + else: + zyx[bidx, 1] = -np.pi / 2 + zyx[bidx, 0] = np.arctan2(-R[bidx, 1, 2], R[bidx, 1, 1]) + zyx[bidx, 2] = 0 + return zyx.squeeze() + def rotm_from_xyz(xyz): - xyz = np.array(xyz, copy=False).reshape(-1,3) - return (rot_x(xyz[:,0]) @ rot_y(xyz[:,1]) @ rot_z(xyz[:,2])).squeeze() + xyz = np.array(xyz, copy=False).reshape(-1, 3) + return (rot_x(xyz[:, 0]) @ rot_y(xyz[:, 1]) @ rot_z(xyz[:, 2])).squeeze() + def rotm_from_zyx(zyx): - zyx = np.array(zyx, copy=False).reshape(-1,3) - return (rot_z(zyx[:,0]) @ rot_y(zyx[:,1]) @ rot_x(zyx[:,2])).squeeze() + zyx = np.array(zyx, copy=False).reshape(-1, 3) + return (rot_z(zyx[:, 0]) @ rot_y(zyx[:, 1]) @ rot_x(zyx[:, 2])).squeeze() + def rotm_from_quat(q): - q = q.reshape(-1,4) - w, x, y, z = q[:,0], q[:,1], q[:,2], q[:,3] - R = np.array([ - [1 - 2*y*y - 2*z*z, 2*x*y - 2*z*w, 2*x*z + 2*y*w], - [2*x*y + 2*z*w, 1 - 2*x*x - 2*z*z, 2*y*z - 2*x*w], - [2*x*z - 2*y*w, 2*y*z + 2*x*w, 1 - 2*x*x - 2*y*y] - ], dtype=q.dtype) - R = R.transpose((2,0,1)) - return R.squeeze() + q = q.reshape(-1, 4) + w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3] + R = np.array([ + [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w], + [2 * x * y + 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * x * w], + [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y] + ], dtype=q.dtype) + R = R.transpose((2, 0, 1)) + return R.squeeze() + def rotm_from_axisangle(a): - # exponential - a = a.reshape(-1,3) - phi = np.linalg.norm(a, axis=1).reshape(-1,1,1) - iphi = np.zeros_like(phi) - np.divide(1, phi, out=iphi, where=phi != 0) - A = cross_prod_mat(a) * iphi - R = np.eye(3, dtype=a.dtype) + np.sin(phi) * A + (1 - np.cos(phi)) * A @ A - return R.squeeze() + # exponential + a = a.reshape(-1, 3) + phi = np.linalg.norm(a, axis=1).reshape(-1, 1, 1) + iphi = np.zeros_like(phi) + np.divide(1, phi, out=iphi, where=phi != 0) + A = cross_prod_mat(a) * iphi + R = np.eye(3, dtype=a.dtype) + np.sin(phi) * A + (1 - np.cos(phi)) * A @ A + return R.squeeze() + def rotm_from_lookat(dir, up=None): - dir = dir.reshape(-1,3) - if up is None: - up = np.zeros_like(dir) - up[:,1] = 1 - dir /= np.linalg.norm(dir, axis=1, keepdims=True) - up /= np.linalg.norm(up, axis=1, keepdims=True) - x = dir[:,None,:] @ cross_prod_mat(up).transpose(0,2,1) - y = x @ cross_prod_mat(dir).transpose(0,2,1) - x = x.squeeze() - y = y.squeeze() - x /= np.linalg.norm(x, axis=1, keepdims=True) - y /= np.linalg.norm(y, axis=1, keepdims=True) - R = np.empty((dir.shape[0],3,3), dtype=dir.dtype) - R[:,0,0] = x[:,0] - R[:,0,1] = y[:,0] - R[:,0,2] = dir[:,0] - R[:,1,0] = x[:,1] - R[:,1,1] = y[:,1] - R[:,1,2] = dir[:,1] - R[:,2,0] = x[:,2] - R[:,2,1] = y[:,2] - R[:,2,2] = dir[:,2] - return R.transpose(0,2,1).squeeze() + dir = dir.reshape(-1, 3) + if up is None: + up = np.zeros_like(dir) + up[:, 1] = 1 + dir /= np.linalg.norm(dir, axis=1, keepdims=True) + up /= np.linalg.norm(up, axis=1, keepdims=True) + x = dir[:, None, :] @ cross_prod_mat(up).transpose(0, 2, 1) + y = x @ cross_prod_mat(dir).transpose(0, 2, 1) + x = x.squeeze() + y = y.squeeze() + x /= np.linalg.norm(x, axis=1, keepdims=True) + y /= np.linalg.norm(y, axis=1, keepdims=True) + R = np.empty((dir.shape[0], 3, 3), dtype=dir.dtype) + R[:, 0, 0] = x[:, 0] + R[:, 0, 1] = y[:, 0] + R[:, 0, 2] = dir[:, 0] + R[:, 1, 0] = x[:, 1] + R[:, 1, 1] = y[:, 1] + R[:, 1, 2] = dir[:, 1] + R[:, 2, 0] = x[:, 2] + R[:, 2, 1] = y[:, 2] + R[:, 2, 2] = dir[:, 2] + return R.transpose(0, 2, 1).squeeze() + def rotm_distance_identity(R0, R1): - # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 - # in [0, 2*sqrt(2)] - R0 = R0.reshape(-1,3,3) - R1 = R1.reshape(-1,3,3) - dists = np.linalg.norm(np.eye(3,dtype=R0.dtype) - R0 @ R1.transpose(0,2,1), axis=(1,2)) - return dists.squeeze() + # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 + # in [0, 2*sqrt(2)] + R0 = R0.reshape(-1, 3, 3) + R1 = R1.reshape(-1, 3, 3) + dists = np.linalg.norm(np.eye(3, dtype=R0.dtype) - R0 @ R1.transpose(0, 2, 1), axis=(1, 2)) + return dists.squeeze() -def rotm_distance_geodesic(R0, R1): - # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 - # in [0, pi) - R0 = R0.reshape(-1,3,3) - R1 = R1.reshape(-1,3,3) - RtR = R0 @ R1.transpose(0,2,1) - aa = axisangle_from_rotm(RtR) - S = cross_prod_mat(aa).reshape(-1,3,3) - dists = np.linalg.norm(S, axis=(1,2)) - return dists.squeeze() +def rotm_distance_geodesic(R0, R1): + # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 + # in [0, pi) + R0 = R0.reshape(-1, 3, 3) + R1 = R1.reshape(-1, 3, 3) + RtR = R0 @ R1.transpose(0, 2, 1) + aa = axisangle_from_rotm(RtR) + S = cross_prod_mat(aa).reshape(-1, 3, 3) + dists = np.linalg.norm(S, axis=(1, 2)) + return dists.squeeze() def axisangle_from_rotm(R): - # logarithm of rotation matrix - # R = R.reshape(-1,3,3) - # tr = np.trace(R, axis1=1, axis2=2) - # phi = np.arccos(np.clip((tr - 1) / 2, -1, 1)) - # scale = np.zeros_like(phi) - # div = 2 * np.sin(phi) - # np.divide(phi, div, out=scale, where=np.abs(div) > 1e-6) - # A = (R - R.transpose(0,2,1)) * scale.reshape(-1,1,1) - # aa = np.stack((A[:,2,1], A[:,0,2], A[:,1,0]), axis=1) - # return aa.squeeze() - R = R.reshape(-1,3,3) - omega = np.empty((R.shape[0], 3), dtype=R.dtype) - omega[:,0] = R[:,2,1] - R[:,1,2] - omega[:,1] = R[:,0,2] - R[:,2,0] - omega[:,2] = R[:,1,0] - R[:,0,1] - r = np.linalg.norm(omega, axis=1).reshape(-1,1) - t = np.trace(R, axis1=1, axis2=2).reshape(-1,1) - omega = np.arctan2(r, t-1) * omega - aa = np.zeros_like(omega) - np.divide(omega, r, out=aa, where=r != 0) - return aa.squeeze() + # logarithm of rotation matrix + # R = R.reshape(-1,3,3) + # tr = np.trace(R, axis1=1, axis2=2) + # phi = np.arccos(np.clip((tr - 1) / 2, -1, 1)) + # scale = np.zeros_like(phi) + # div = 2 * np.sin(phi) + # np.divide(phi, div, out=scale, where=np.abs(div) > 1e-6) + # A = (R - R.transpose(0,2,1)) * scale.reshape(-1,1,1) + # aa = np.stack((A[:,2,1], A[:,0,2], A[:,1,0]), axis=1) + # return aa.squeeze() + R = R.reshape(-1, 3, 3) + omega = np.empty((R.shape[0], 3), dtype=R.dtype) + omega[:, 0] = R[:, 2, 1] - R[:, 1, 2] + omega[:, 1] = R[:, 0, 2] - R[:, 2, 0] + omega[:, 2] = R[:, 1, 0] - R[:, 0, 1] + r = np.linalg.norm(omega, axis=1).reshape(-1, 1) + t = np.trace(R, axis1=1, axis2=2).reshape(-1, 1) + omega = np.arctan2(r, t - 1) * omega + aa = np.zeros_like(omega) + np.divide(omega, r, out=aa, where=r != 0) + return aa.squeeze() + def axisangle_from_quat(q): - q = q.reshape(-1,4) - phi = 2 * np.arccos(q[:,0]) - denom = np.zeros_like(q[:,0]) - np.divide(1, np.sqrt(1 - q[:,0]**2), out=denom, where=q[:,0] != 1) - axis = q[:,1:] * denom.reshape(-1,1) - denom = np.linalg.norm(axis, axis=1).reshape(-1,1) - a = np.zeros_like(axis) - np.divide(phi.reshape(-1,1) * axis, denom, out=a, where=denom != 0) - aa = a.astype(q.dtype) - return aa.squeeze() + q = q.reshape(-1, 4) + phi = 2 * np.arccos(q[:, 0]) + denom = np.zeros_like(q[:, 0]) + np.divide(1, np.sqrt(1 - q[:, 0] ** 2), out=denom, where=q[:, 0] != 1) + axis = q[:, 1:] * denom.reshape(-1, 1) + denom = np.linalg.norm(axis, axis=1).reshape(-1, 1) + a = np.zeros_like(axis) + np.divide(phi.reshape(-1, 1) * axis, denom, out=a, where=denom != 0) + aa = a.astype(q.dtype) + return aa.squeeze() + def axisangle_apply(aa, x): - # working only with single aa and single x at the moment - xshape = x.shape - aa = aa.reshape(3,) - x = x.reshape(3,) - phi = np.linalg.norm(aa) - e = np.zeros_like(aa) - np.divide(aa, phi, out=e, where=phi != 0) - xr = np.cos(phi) * x + np.sin(phi) * np.cross(e, x) + (1 - np.cos(phi)) * (e.T @ x) * e - return xr.reshape(xshape) + # working only with single aa and single x at the moment + xshape = x.shape + aa = aa.reshape(3, ) + x = x.reshape(3, ) + phi = np.linalg.norm(aa) + e = np.zeros_like(aa) + np.divide(aa, phi, out=e, where=phi != 0) + xr = np.cos(phi) * x + np.sin(phi) * np.cross(e, x) + (1 - np.cos(phi)) * (e.T @ x) * e + return xr.reshape(xshape) def exp_so3(R): - w = axisangle_from_rotm(R) - return w + w = axisangle_from_rotm(R) + return w + def log_so3(w): - R = rotm_from_axisangle(w) - return R + R = rotm_from_axisangle(w) + return R + def exp_se3(R, t): - R = R.reshape(-1,3,3) - t = t.reshape(-1,3) + R = R.reshape(-1, 3, 3) + t = t.reshape(-1, 3) + + w = exp_so3(R).reshape(-1, 3) - w = exp_so3(R).reshape(-1,3) + phi = np.linalg.norm(w, axis=1).reshape(-1, 1, 1) + A = cross_prod_mat(w) + Vi = np.eye(3, dtype=R.dtype) - A / 2 + (1 - (phi * np.sin(phi) / (2 * (1 - np.cos(phi))))) / phi ** 2 * A @ A + u = t.reshape(-1, 1, 3) @ Vi.transpose(0, 2, 1) - phi = np.linalg.norm(w, axis=1).reshape(-1,1,1) - A = cross_prod_mat(w) - Vi = np.eye(3, dtype=R.dtype) - A/2 + (1 - (phi * np.sin(phi) / (2 * (1 - np.cos(phi))))) / phi**2 * A @ A - u = t.reshape(-1,1,3) @ Vi.transpose(0,2,1) + # v = (u, w) + v = np.empty((R.shape[0], 6), dtype=R.dtype) + v[:, :3] = u.squeeze() + v[:, 3:] = w - # v = (u, w) - v = np.empty((R.shape[0],6), dtype=R.dtype) - v[:,:3] = u.squeeze() - v[:,3:] = w + return v.squeeze() - return v.squeeze() def log_se3(v): - # v = (u, w) - v = v.reshape(-1,6) - u = v[:,:3] - w = v[:,3:] + # v = (u, w) + v = v.reshape(-1, 6) + u = v[:, :3] + w = v[:, 3:] - R = log_so3(w) + R = log_so3(w) - phi = np.linalg.norm(w, axis=1).reshape(-1,1,1) - A = cross_prod_mat(w) - V = np.eye(3, dtype=v.dtype) + (1 - np.cos(phi)) / phi**2 * A + (phi - np.sin(phi)) / phi**3 * A @ A - t = u.reshape(-1,1,3) @ V.transpose(0,2,1) + phi = np.linalg.norm(w, axis=1).reshape(-1, 1, 1) + A = cross_prod_mat(w) + V = np.eye(3, dtype=v.dtype) + (1 - np.cos(phi)) / phi ** 2 * A + (phi - np.sin(phi)) / phi ** 3 * A @ A + t = u.reshape(-1, 1, 3) @ V.transpose(0, 2, 1) - return R.squeeze(), t.squeeze() + return R.squeeze(), t.squeeze() def quat_from_rotm(R): - R = R.reshape(-1,3,3) - q = np.empty((R.shape[0], 4,), dtype=R.dtype) - q[:,0] = np.sqrt( np.maximum(0, 1 + R[:,0,0] + R[:,1,1] + R[:,2,2]) ) - q[:,1] = np.sqrt( np.maximum(0, 1 + R[:,0,0] - R[:,1,1] - R[:,2,2]) ) - q[:,2] = np.sqrt( np.maximum(0, 1 - R[:,0,0] + R[:,1,1] - R[:,2,2]) ) - q[:,3] = np.sqrt( np.maximum(0, 1 - R[:,0,0] - R[:,1,1] + R[:,2,2]) ) - q[:,1] *= np.sign(q[:,1] * (R[:,2,1] - R[:,1,2])) - q[:,2] *= np.sign(q[:,2] * (R[:,0,2] - R[:,2,0])) - q[:,3] *= np.sign(q[:,3] * (R[:,1,0] - R[:,0,1])) - q /= np.linalg.norm(q,axis=1,keepdims=True) - return q.squeeze() + R = R.reshape(-1, 3, 3) + q = np.empty((R.shape[0], 4,), dtype=R.dtype) + q[:, 0] = np.sqrt(np.maximum(0, 1 + R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2])) + q[:, 1] = np.sqrt(np.maximum(0, 1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2])) + q[:, 2] = np.sqrt(np.maximum(0, 1 - R[:, 0, 0] + R[:, 1, 1] - R[:, 2, 2])) + q[:, 3] = np.sqrt(np.maximum(0, 1 - R[:, 0, 0] - R[:, 1, 1] + R[:, 2, 2])) + q[:, 1] *= np.sign(q[:, 1] * (R[:, 2, 1] - R[:, 1, 2])) + q[:, 2] *= np.sign(q[:, 2] * (R[:, 0, 2] - R[:, 2, 0])) + q[:, 3] *= np.sign(q[:, 3] * (R[:, 1, 0] - R[:, 0, 1])) + q /= np.linalg.norm(q, axis=1, keepdims=True) + return q.squeeze() + def quat_from_axisangle(a): - a = a.reshape(-1, 3) - phi = np.linalg.norm(a, axis=1) - iphi = np.zeros_like(phi) - np.divide(1, phi, out=iphi, where=phi != 0) - a = a * iphi.reshape(-1,1) - theta = phi / 2.0 - r = np.cos(theta) - stheta = np.sin(theta) - q = np.stack((r, stheta*a[:,0], stheta*a[:,1], stheta*a[:,2]), axis=1) - q /= np.linalg.norm(q, axis=1).reshape(-1,1) - return q.squeeze() + a = a.reshape(-1, 3) + phi = np.linalg.norm(a, axis=1) + iphi = np.zeros_like(phi) + np.divide(1, phi, out=iphi, where=phi != 0) + a = a * iphi.reshape(-1, 1) + theta = phi / 2.0 + r = np.cos(theta) + stheta = np.sin(theta) + q = np.stack((r, stheta * a[:, 0], stheta * a[:, 1], stheta * a[:, 2]), axis=1) + q /= np.linalg.norm(q, axis=1).reshape(-1, 1) + return q.squeeze() + def quat_identity(n=1, dtype=np.float32): - q = np.zeros((n,4), dtype=dtype) - q[:,0] = 1 - return q.squeeze() + q = np.zeros((n, 4), dtype=dtype) + q[:, 0] = 1 + return q.squeeze() + def quat_conjugate(q): - shape = q.shape - q = q.reshape(-1,4).copy() - q[:,1:] *= -1 - return q.reshape(shape) + shape = q.shape + q = q.reshape(-1, 4).copy() + q[:, 1:] *= -1 + return q.reshape(shape) -def quat_product(q1, q2): - # q1 . q2 is equivalent to R(q1) @ R(q2) - shape = q1.shape - q1, q2 = q1.reshape(-1,4), q2.reshape(-1, 4) - q = np.empty((max(q1.shape[0], q2.shape[0]), 4), dtype=q1.dtype) - a1,b1,c1,d1 = q1[:,0], q1[:,1], q1[:,2], q1[:,3] - a2,b2,c2,d2 = q2[:,0], q2[:,1], q2[:,2], q2[:,3] - q[:,0] = a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2 - q[:,1] = a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2 - q[:,2] = a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2 - q[:,3] = a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2 - return q.squeeze() -def quat_apply(q, x): - xshape = x.shape - x = x.reshape(-1, 3) - qshape = q.shape - q = q.reshape(-1, 4) +def quat_product(q1, q2): + # q1 . q2 is equivalent to R(q1) @ R(q2) + shape = q1.shape + q1, q2 = q1.reshape(-1, 4), q2.reshape(-1, 4) + q = np.empty((max(q1.shape[0], q2.shape[0]), 4), dtype=q1.dtype) + a1, b1, c1, d1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3] + a2, b2, c2, d2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3] + q[:, 0] = a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2 + q[:, 1] = a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2 + q[:, 2] = a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2 + q[:, 3] = a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2 + return q.squeeze() - p = np.empty((x.shape[0], 4), dtype=x.dtype) - p[:,0] = 0 - p[:,1:] = x - r = quat_product(quat_product(q, p), quat_conjugate(q)) - if r.ndim == 1: - return r[1:].reshape(xshape) - else: - return r[:,1:].reshape(xshape) +def quat_apply(q, x): + xshape = x.shape + x = x.reshape(-1, 3) + qshape = q.shape + q = q.reshape(-1, 4) + + p = np.empty((x.shape[0], 4), dtype=x.dtype) + p[:, 0] = 0 + p[:, 1:] = x + + r = quat_product(quat_product(q, p), quat_conjugate(q)) + if r.ndim == 1: + return r[1:].reshape(xshape) + else: + return r[:, 1:].reshape(xshape) def quat_random(rng=None, n=1): - # http://planning.cs.uiuc.edu/node198.html - if rng is not None: - u = rng.uniform(0, 1, size=(3,n)) - else: - u = np.random.uniform(0, 1, size=(3,n)) - q = np.array(( - np.sqrt(1 - u[0]) * np.sin(2 * np.pi * u[1]), - np.sqrt(1 - u[0]) * np.cos(2 * np.pi * u[1]), - np.sqrt(u[0]) * np.sin(2 * np.pi * u[2]), - np.sqrt(u[0]) * np.cos(2 * np.pi * u[2]) - )).T - q /= np.linalg.norm(q,axis=1,keepdims=True) - return q.squeeze() + # http://planning.cs.uiuc.edu/node198.html + if rng is not None: + u = rng.uniform(0, 1, size=(3, n)) + else: + u = np.random.uniform(0, 1, size=(3, n)) + q = np.array(( + np.sqrt(1 - u[0]) * np.sin(2 * np.pi * u[1]), + np.sqrt(1 - u[0]) * np.cos(2 * np.pi * u[1]), + np.sqrt(u[0]) * np.sin(2 * np.pi * u[2]), + np.sqrt(u[0]) * np.cos(2 * np.pi * u[2]) + )).T + q /= np.linalg.norm(q, axis=1, keepdims=True) + return q.squeeze() + def quat_distance_angle(q0, q1): - # https://math.stackexchange.com/questions/90081/quaternion-distance - # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 - q0 = q0.reshape(-1,4) - q1 = q1.reshape(-1,4) - dists = np.arccos(np.clip(2 * np.sum(q0 * q1, axis=1)**2 - 1, -1, 1)) - return dists + # https://math.stackexchange.com/questions/90081/quaternion-distance + # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 + q0 = q0.reshape(-1, 4) + q1 = q1.reshape(-1, 4) + dists = np.arccos(np.clip(2 * np.sum(q0 * q1, axis=1) ** 2 - 1, -1, 1)) + return dists + def quat_distance_normdiff(q0, q1): - # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 - # \phi_4 - # [0, 1] - q0 = q0.reshape(-1,4) - q1 = q1.reshape(-1,4) - return 1 - np.sum(q0 * q1, axis=1)**2 + # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 + # \phi_4 + # [0, 1] + q0 = q0.reshape(-1, 4) + q1 = q1.reshape(-1, 4) + return 1 - np.sum(q0 * q1, axis=1) ** 2 + def quat_distance_mineucl(q0, q1): - # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 - # http://users.cecs.anu.edu.au/~trumpf/pubs/Hartley_Trumpf_Dai_Li.pdf - q0 = q0.reshape(-1,4) - q1 = q1.reshape(-1,4) - diff0 = ((q0 - q1)**2).sum(axis=1) - diff1 = ((q0 + q1)**2).sum(axis=1) - return np.minimum(diff0, diff1) + # https://link.springer.com/article/10.1007%2Fs10851-009-0161-2 + # http://users.cecs.anu.edu.au/~trumpf/pubs/Hartley_Trumpf_Dai_Li.pdf + q0 = q0.reshape(-1, 4) + q1 = q1.reshape(-1, 4) + diff0 = ((q0 - q1) ** 2).sum(axis=1) + diff1 = ((q0 + q1) ** 2).sum(axis=1) + return np.minimum(diff0, diff1) + def quat_slerp_space(q0, q1, num=100, endpoint=True): - q0 = q0.ravel() - q1 = q1.ravel() - dot = q0.dot(q1) - if dot < 0: - q1 *= -1 - dot *= -1 - t = np.linspace(0, 1, num=num, endpoint=endpoint, dtype=q0.dtype) - t = t.reshape((-1,1)) - if dot > 0.9995: - ret = q0 + t * (q1 - q0) - return ret - dot = np.clip(dot, -1, 1) - theta0 = np.arccos(dot) - theta = theta0 * t - s0 = np.cos(theta) - dot * np.sin(theta) / np.sin(theta0) - s1 = np.sin(theta) / np.sin(theta0) - return (s0 * q0) + (s1 * q1) + q0 = q0.ravel() + q1 = q1.ravel() + dot = q0.dot(q1) + if dot < 0: + q1 *= -1 + dot *= -1 + t = np.linspace(0, 1, num=num, endpoint=endpoint, dtype=q0.dtype) + t = t.reshape((-1, 1)) + if dot > 0.9995: + ret = q0 + t * (q1 - q0) + return ret + dot = np.clip(dot, -1, 1) + theta0 = np.arccos(dot) + theta = theta0 * t + s0 = np.cos(theta) - dot * np.sin(theta) / np.sin(theta0) + s1 = np.sin(theta) / np.sin(theta0) + return (s0 * q0) + (s1 * q1) + def cart_to_spherical(x): - shape = x.shape - x = x.reshape(-1,3) - y = np.empty_like(x) - y[:,0] = np.linalg.norm(x, axis=1) # r - y[:,1] = np.arccos(x[:,2] / y[:,0]) # theta - y[:,2] = np.arctan2(x[:,1], x[:,0]) # phi - return y.reshape(shape) + shape = x.shape + x = x.reshape(-1, 3) + y = np.empty_like(x) + y[:, 0] = np.linalg.norm(x, axis=1) # r + y[:, 1] = np.arccos(x[:, 2] / y[:, 0]) # theta + y[:, 2] = np.arctan2(x[:, 1], x[:, 0]) # phi + return y.reshape(shape) + def spherical_to_cart(x): - shape = x.shape - x = x.reshape(-1,3) - y = np.empty_like(x) - y[:,0] = x[:,0] * np.sin(x[:,1]) * np.cos(x[:,2]) - y[:,1] = x[:,0] * np.sin(x[:,1]) * np.sin(x[:,2]) - y[:,2] = x[:,0] * np.cos(x[:,1]) - return y.reshape(shape) + shape = x.shape + x = x.reshape(-1, 3) + y = np.empty_like(x) + y[:, 0] = x[:, 0] * np.sin(x[:, 1]) * np.cos(x[:, 2]) + y[:, 1] = x[:, 0] * np.sin(x[:, 1]) * np.sin(x[:, 2]) + y[:, 2] = x[:, 0] * np.cos(x[:, 1]) + return y.reshape(shape) + def spherical_random(r=1, n=1): - # http://mathworld.wolfram.com/SpherePointPicking.html - # https://math.stackexchange.com/questions/1585975/how-to-generate-random-points-on-a-sphere - x = np.empty((n,3)) - x[:,0] = r - x[:,1] = 2 * np.pi * np.random.uniform(0,1, size=(n,)) - x[:,2] = np.arccos(2 * np.random.uniform(0,1, size=(n,)) - 1) - return x.squeeze() - -def color_pcl(pcl, K, im, color_axis=0, as_int=True, invalid_color=[0,0,0]): - uvd = K @ pcl.T - uvd /= uvd[2] - uvd = np.round(uvd).astype(np.int32) - mask = np.logical_and(uvd[0] >= 0, uvd[1] >= 0) - color = np.empty((pcl.shape[0], 3), dtype=im.dtype) - if color_axis == 0: - mask = np.logical_and(mask, uvd[0] < im.shape[2]) - mask = np.logical_and(mask, uvd[1] < im.shape[1]) - uvd = uvd[:,mask] - color[mask,:] = im[:,uvd[1],uvd[0]].T - elif color_axis == 2: - mask = np.logical_and(mask, uvd[0] < im.shape[1]) - mask = np.logical_and(mask, uvd[1] < im.shape[0]) - uvd = uvd[:,mask] - color[mask,:] = im[uvd[1],uvd[0], :] - else: - raise Exception('invalid color_axis') - color[np.logical_not(mask),:3] = invalid_color - if as_int: - color = (255.0 * color).astype(np.int32) - return color + # http://mathworld.wolfram.com/SpherePointPicking.html + # https://math.stackexchange.com/questions/1585975/how-to-generate-random-points-on-a-sphere + x = np.empty((n, 3)) + x[:, 0] = r + x[:, 1] = 2 * np.pi * np.random.uniform(0, 1, size=(n,)) + x[:, 2] = np.arccos(2 * np.random.uniform(0, 1, size=(n,)) - 1) + return x.squeeze() + + +def color_pcl(pcl, K, im, color_axis=0, as_int=True, invalid_color=[0, 0, 0]): + uvd = K @ pcl.T + uvd /= uvd[2] + uvd = np.round(uvd).astype(np.int32) + mask = np.logical_and(uvd[0] >= 0, uvd[1] >= 0) + color = np.empty((pcl.shape[0], 3), dtype=im.dtype) + if color_axis == 0: + mask = np.logical_and(mask, uvd[0] < im.shape[2]) + mask = np.logical_and(mask, uvd[1] < im.shape[1]) + uvd = uvd[:, mask] + color[mask, :] = im[:, uvd[1], uvd[0]].T + elif color_axis == 2: + mask = np.logical_and(mask, uvd[0] < im.shape[1]) + mask = np.logical_and(mask, uvd[1] < im.shape[0]) + uvd = uvd[:, mask] + color[mask, :] = im[uvd[1], uvd[0], :] + else: + raise Exception('invalid color_axis') + color[np.logical_not(mask), :3] = invalid_color + if as_int: + color = (255.0 * color).astype(np.int32) + return color + def center_pcl(pcl, robust=False, copy=False, axis=1): - if copy: - pcl = pcl.copy() - if robust: - mu = np.median(pcl, axis=axis, keepdims=True) - else: - mu = np.mean(pcl, axis=axis, keepdims=True) - return pcl - mu + if copy: + pcl = pcl.copy() + if robust: + mu = np.median(pcl, axis=axis, keepdims=True) + else: + mu = np.mean(pcl, axis=axis, keepdims=True) + return pcl - mu + def to_homogeneous(x): - # return np.hstack((x, np.ones((x.shape[0],1),dtype=x.dtype))) - return np.concatenate((x, np.ones((*x.shape[:-1],1),dtype=x.dtype)), axis=-1) + # return np.hstack((x, np.ones((x.shape[0],1),dtype=x.dtype))) + return np.concatenate((x, np.ones((*x.shape[:-1], 1), dtype=x.dtype)), axis=-1) + def from_homogeneous(x): - return x[:,:-1] / x[:,-1] + return x[:, :-1] / x[:, -1] + def project_uvn(uv, Ki=None): - if uv.shape[1] == 2: - uvn = to_homogeneous(uv) - else: - uvn = uv - if uvn.shape[1] != 3: - raise Exception('uv should have shape Nx2 or Nx3') - if Ki is None: - return uvn - else: - return uvn @ Ki.T - -def project_uvd(uv, depth, K=np.eye(3), R=np.eye(3), t=np.zeros((3,1)), ignore_negative_depth=True, return_uvn=False): - Ki = np.linalg.inv(K) - - if ignore_negative_depth: - mask = depth >= 0 - uv = uv[mask,:] - d = depth[mask] - else: - d = depth.ravel() - - uv1 = to_homogeneous(uv) - - uvn1 = uv1 @ Ki.T - xyz = d.reshape(-1,1) * uvn1 - xyz = (xyz - t.reshape((1,3))) @ R - - if return_uvn: - return xyz, uvn1 - else: - return xyz - -def project_depth(depth, K, R=np.eye(3,3), t=np.zeros((3,1)), ignore_negative_depth=True, return_uvn=False): - u, v = np.meshgrid(range(depth.shape[1]), range(depth.shape[0])) - uv = np.hstack((u.reshape(-1,1), v.reshape(-1,1))) - return project_uvd(uv, depth.ravel(), K, R, t, ignore_negative_depth, return_uvn) - - -def project_xyz(xyz, K=np.eye(3), R=np.eye(3,3), t=np.zeros((3,1))): - uvd = K @ (R @ xyz.T + t.reshape((3,1))) - uvd[:2] /= uvd[2] - return uvd[:2].T, uvd[2] + if uv.shape[1] == 2: + uvn = to_homogeneous(uv) + else: + uvn = uv + if uvn.shape[1] != 3: + raise Exception('uv should have shape Nx2 or Nx3') + if Ki is None: + return uvn + else: + return uvn @ Ki.T -def relative_motion(R0, t0, R1, t1, Rt_from_global=True): - t0 = t0.reshape((3,1)) - t1 = t1.reshape((3,1)) - if Rt_from_global: - Rr = R1 @ R0.T - tr = t1 - Rr @ t0 - else: - Rr = R1.T @ R0 - tr = R1.T @ (t0 - t1) - return Rr, tr.ravel() +def project_uvd(uv, depth, K=np.eye(3), R=np.eye(3), t=np.zeros((3, 1)), ignore_negative_depth=True, return_uvn=False): + Ki = np.linalg.inv(K) + + if ignore_negative_depth: + mask = depth >= 0 + uv = uv[mask, :] + d = depth[mask] + else: + d = depth.ravel() + uv1 = to_homogeneous(uv) -def translation_to_cameracenter(R, t): - t = t.reshape(-1,3,1) - R = R.reshape(-1,3,3) - C = -R.transpose(0,2,1) @ t - return C.squeeze() + uvn1 = uv1 @ Ki.T + xyz = d.reshape(-1, 1) * uvn1 + xyz = (xyz - t.reshape((1, 3))) @ R -def cameracenter_to_translation(R, C): - C = C.reshape(-1,3,1) - R = R.reshape(-1,3,3) - t = -R @ C - return t.squeeze() + if return_uvn: + return xyz, uvn1 + else: + return xyz -def decompose_projection_matrix(P, return_t=True): - if P.shape[0] != 3 or P.shape[1] != 4: - raise Exception('P has to be 3x4') - M = P[:, :3] - C = -np.linalg.inv(M) @ P[:, 3:] - R,K = np.linalg.qr(np.flipud(M).T) - K = np.flipud(K.T) - K = np.fliplr(K) - R = np.flipud(R.T) +def project_depth(depth, K, R=np.eye(3, 3), t=np.zeros((3, 1)), ignore_negative_depth=True, return_uvn=False): + u, v = np.meshgrid(range(depth.shape[1]), range(depth.shape[0])) + uv = np.hstack((u.reshape(-1, 1), v.reshape(-1, 1))) + return project_uvd(uv, depth.ravel(), K, R, t, ignore_negative_depth, return_uvn) - T = np.diag(np.sign(np.diag(K))) - K = K @ T - R = T @ R - if np.linalg.det(R) < 0: - R *= -1 +def project_xyz(xyz, K=np.eye(3), R=np.eye(3, 3), t=np.zeros((3, 1))): + uvd = K @ (R @ xyz.T + t.reshape((3, 1))) + uvd[:2] /= uvd[2] + return uvd[:2].T, uvd[2] - K /= K[2,2] - if return_t: - return K, R, cameracenter_to_translation(R, C) - else: - return K, R, C +def relative_motion(R0, t0, R1, t1, Rt_from_global=True): + t0 = t0.reshape((3, 1)) + t1 = t1.reshape((3, 1)) + if Rt_from_global: + Rr = R1 @ R0.T + tr = t1 - Rr @ t0 + else: + Rr = R1.T @ R0 + tr = R1.T @ (t0 - t1) + return Rr, tr.ravel() + + +def translation_to_cameracenter(R, t): + t = t.reshape(-1, 3, 1) + R = R.reshape(-1, 3, 3) + C = -R.transpose(0, 2, 1) @ t + return C.squeeze() + + +def cameracenter_to_translation(R, C): + C = C.reshape(-1, 3, 1) + R = R.reshape(-1, 3, 3) + t = -R @ C + return t.squeeze() + + +def decompose_projection_matrix(P, return_t=True): + if P.shape[0] != 3 or P.shape[1] != 4: + raise Exception('P has to be 3x4') + M = P[:, :3] + C = -np.linalg.inv(M) @ P[:, 3:] + + R, K = np.linalg.qr(np.flipud(M).T) + K = np.flipud(K.T) + K = np.fliplr(K) + R = np.flipud(R.T) + + T = np.diag(np.sign(np.diag(K))) + K = K @ T + R = T @ R + + if np.linalg.det(R) < 0: + R *= -1 + + K /= K[2, 2] + if return_t: + return K, R, cameracenter_to_translation(R, C) + else: + return K, R, C -def compose_projection_matrix(K=np.eye(3), R=np.eye(3,3), t=np.zeros((3,1))): - return K @ np.hstack((R, t.reshape((3,1)))) +def compose_projection_matrix(K=np.eye(3), R=np.eye(3, 3), t=np.zeros((3, 1))): + return K @ np.hstack((R, t.reshape((3, 1)))) def point_plane_distance(pts, plane): - pts = pts.reshape(-1,3) - return np.abs(np.sum(plane[:3] * pts, axis=1) + plane[3]) / np.linalg.norm(plane[:3]) + pts = pts.reshape(-1, 3) + return np.abs(np.sum(plane[:3] * pts, axis=1) + plane[3]) / np.linalg.norm(plane[:3]) + def fit_plane(pts): - pts = pts.reshape(-1,3) - center = np.mean(pts, axis=0) - A = pts - center - u, s, vh = np.linalg.svd(A, full_matrices=False) - # if pts.shape[0] > 100: - # import ipdb; ipdb.set_trace() - plane = np.array([*vh[2], -vh[2].dot(center)]) - return plane + pts = pts.reshape(-1, 3) + center = np.mean(pts, axis=0) + A = pts - center + u, s, vh = np.linalg.svd(A, full_matrices=False) + # if pts.shape[0] > 100: + # import ipdb; ipdb.set_trace() + plane = np.array([*vh[2], -vh[2].dot(center)]) + return plane + def tetrahedron(dtype=np.float32): - verts = np.array([ - (np.sqrt(8/9), 0, -1/3), (-np.sqrt(2/9), np.sqrt(2/3), -1/3), - (-np.sqrt(2/9), -np.sqrt(2/3), -1/3), (0, 0, 1)], dtype=dtype) - faces = np.array([(0,1,2), (0,2,3), (0,1,3), (1,2,3)], dtype=np.int32) - normals = -np.mean(verts, axis=0) + verts - normals /= np.linalg.norm(normals, axis=1).reshape(-1,1) - return verts, faces, normals + verts = np.array([ + (np.sqrt(8 / 9), 0, -1 / 3), (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3), + (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3), (0, 0, 1)], dtype=dtype) + faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)], dtype=np.int32) + normals = -np.mean(verts, axis=0) + verts + normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1) + return verts, faces, normals + def cube(dtype=np.float32): - verts = np.array([ - [-0.5,-0.5,-0.5], [-0.5,0.5,-0.5], [0.5,0.5,-0.5], [0.5,-0.5,-0.5], - [-0.5,-0.5,0.5], [-0.5,0.5,0.5], [0.5,0.5,0.5], [0.5,-0.5,0.5]], dtype=dtype) - faces = np.array([ - (0,1,2), (0,2,3), (4,5,6), (4,6,7), - (0,4,7), (0,7,3), (1,5,6), (1,6,2), - (3,2,6), (3,6,7), (0,1,5), (0,5,4)], dtype=np.int32) - normals = -np.mean(verts, axis=0) + verts - normals /= np.linalg.norm(normals, axis=1).reshape(-1,1) - return verts, faces, normals + verts = np.array([ + [-0.5, -0.5, -0.5], [-0.5, 0.5, -0.5], [0.5, 0.5, -0.5], [0.5, -0.5, -0.5], + [-0.5, -0.5, 0.5], [-0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5]], dtype=dtype) + faces = np.array([ + (0, 1, 2), (0, 2, 3), (4, 5, 6), (4, 6, 7), + (0, 4, 7), (0, 7, 3), (1, 5, 6), (1, 6, 2), + (3, 2, 6), (3, 6, 7), (0, 1, 5), (0, 5, 4)], dtype=np.int32) + normals = -np.mean(verts, axis=0) + verts + normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1) + return verts, faces, normals + def octahedron(dtype=np.float32): - verts = np.array([ - (+1,0,0), (0,+1,0), (0,0,+1), - (-1,0,0), (0,-1,0), (0,0,-1)], dtype=dtype) - faces = np.array([ - (0,1,2), (1,2,3), (3,2,4), (4,2,0), - (0,1,5), (1,5,3), (3,5,4), (4,5,0)], dtype=np.int32) - normals = -np.mean(verts, axis=0) + verts - normals /= np.linalg.norm(normals, axis=1).reshape(-1,1) - return verts, faces, normals + verts = np.array([ + (+1, 0, 0), (0, +1, 0), (0, 0, +1), + (-1, 0, 0), (0, -1, 0), (0, 0, -1)], dtype=dtype) + faces = np.array([ + (0, 1, 2), (1, 2, 3), (3, 2, 4), (4, 2, 0), + (0, 1, 5), (1, 5, 3), (3, 5, 4), (4, 5, 0)], dtype=np.int32) + normals = -np.mean(verts, axis=0) + verts + normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1) + return verts, faces, normals + def icosahedron(dtype=np.float32): - p = (1 + np.sqrt(5)) / 2 - verts = np.array([ - (-1,0,p), (1,0,p), (1,0,-p), (-1,0,-p), - (0,-p,1), (0,p,1), (0,p,-1), (0,-p,-1), - (-p,-1,0), (p,-1,0), (p,1,0), (-p,1,0) + p = (1 + np.sqrt(5)) / 2 + verts = np.array([ + (-1, 0, p), (1, 0, p), (1, 0, -p), (-1, 0, -p), + (0, -p, 1), (0, p, 1), (0, p, -1), (0, -p, -1), + (-p, -1, 0), (p, -1, 0), (p, 1, 0), (-p, 1, 0) ], dtype=dtype) - faces = np.array([ - (0,1,4), (0,1,5), (1,4,9), (1,9,10), (1,10,5), (0,4,8), (0,8,11), (0,11,5), - (5,6,11), (5,6,10), (4,7,8), (4,7,9), - (3,2,6), (3,2,7), (2,6,10), (2,10,9), (2,9,7), (3,6,11), (3,11,8), (3,8,7), + faces = np.array([ + (0, 1, 4), (0, 1, 5), (1, 4, 9), (1, 9, 10), (1, 10, 5), (0, 4, 8), (0, 8, 11), (0, 11, 5), + (5, 6, 11), (5, 6, 10), (4, 7, 8), (4, 7, 9), + (3, 2, 6), (3, 2, 7), (2, 6, 10), (2, 10, 9), (2, 9, 7), (3, 6, 11), (3, 11, 8), (3, 8, 7), ], dtype=np.int32) - normals = -np.mean(verts, axis=0) + verts - normals /= np.linalg.norm(normals, axis=1).reshape(-1,1) - return verts, faces, normals + normals = -np.mean(verts, axis=0) + verts + normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1) + return verts, faces, normals + def xyplane(dtype=np.float32, z=0, interleaved=False): - if interleaved: - eps = 1e-6 - verts = np.array([ - (-1,-1,z), (-1,1,z), (1,1,z), - (1-eps,1,z), (1-eps,-1,z), (-1-eps,-1,z)], dtype=dtype) - faces = np.array([(0,1,2), (3,4,5)], dtype=np.int32) - else: - verts = np.array([(-1,-1,z), (-1,1,z), (1,1,z), (1,-1,z)], dtype=dtype) - faces = np.array([(0,1,2), (0,2,3)], dtype=np.int32) - normals = np.zeros_like(verts) - normals[:,2] = -1 - return verts, faces, normals + if interleaved: + eps = 1e-6 + verts = np.array([ + (-1, -1, z), (-1, 1, z), (1, 1, z), + (1 - eps, 1, z), (1 - eps, -1, z), (-1 - eps, -1, z)], dtype=dtype) + faces = np.array([(0, 1, 2), (3, 4, 5)], dtype=np.int32) + else: + verts = np.array([(-1, -1, z), (-1, 1, z), (1, 1, z), (1, -1, z)], dtype=dtype) + faces = np.array([(0, 1, 2), (0, 2, 3)], dtype=np.int32) + normals = np.zeros_like(verts) + normals[:, 2] = -1 + return verts, faces, normals + def mesh_independent_verts(verts, faces, normals=None): - new_verts = [] - new_normals = [] - for f in faces: - new_verts.append(verts[f[0]]) - new_verts.append(verts[f[1]]) - new_verts.append(verts[f[2]]) - if normals is not None: - new_normals.append(normals[f[0]]) - new_normals.append(normals[f[1]]) - new_normals.append(normals[f[2]]) - new_verts = np.array(new_verts) - new_faces = np.arange(0, faces.size, dtype=faces.dtype).reshape(-1,3) - if normals is None: - return new_verts, new_faces - else: - new_normals = np.array(new_normals) - return new_verts, new_faces, new_normals + new_verts = [] + new_normals = [] + for f in faces: + new_verts.append(verts[f[0]]) + new_verts.append(verts[f[1]]) + new_verts.append(verts[f[2]]) + if normals is not None: + new_normals.append(normals[f[0]]) + new_normals.append(normals[f[1]]) + new_normals.append(normals[f[2]]) + new_verts = np.array(new_verts) + new_faces = np.arange(0, faces.size, dtype=faces.dtype).reshape(-1, 3) + if normals is None: + return new_verts, new_faces + else: + new_normals = np.array(new_normals) + return new_verts, new_faces, new_normals def stack_mesh(verts, faces): - n_verts = 0 - mfaces = [] - for idx, f in enumerate(faces): - mfaces.append(f + n_verts) - n_verts += verts[idx].shape[0] - verts = np.vstack(verts) - faces = np.vstack(mfaces) - return verts, faces + n_verts = 0 + mfaces = [] + for idx, f in enumerate(faces): + mfaces.append(f + n_verts) + n_verts += verts[idx].shape[0] + verts = np.vstack(verts) + faces = np.vstack(mfaces) + return verts, faces + def normalize_mesh(verts): - # all the verts have unit distance to the center (0,0,0) - return verts / np.linalg.norm(verts, axis=1, keepdims=True) + # all the verts have unit distance to the center (0,0,0) + return verts / np.linalg.norm(verts, axis=1, keepdims=True) def mesh_triangle_areas(verts, faces): - a = verts[faces[:,0]] - b = verts[faces[:,1]] - c = verts[faces[:,2]] - x = np.empty_like(a) - x = a - b - y = a - c - t = np.empty_like(a) - t[:,0] = (x[:,1] * y[:,2] - x[:,2] * y[:,1]); - t[:,1] = (x[:,2] * y[:,0] - x[:,0] * y[:,2]); - t[:,2] = (x[:,0] * y[:,1] - x[:,1] * y[:,0]); - return np.linalg.norm(t, axis=1) / 2 + a = verts[faces[:, 0]] + b = verts[faces[:, 1]] + c = verts[faces[:, 2]] + x = np.empty_like(a) + x = a - b + y = a - c + t = np.empty_like(a) + t[:, 0] = (x[:, 1] * y[:, 2] - x[:, 2] * y[:, 1]); + t[:, 1] = (x[:, 2] * y[:, 0] - x[:, 0] * y[:, 2]); + t[:, 2] = (x[:, 0] * y[:, 1] - x[:, 1] * y[:, 0]); + return np.linalg.norm(t, axis=1) / 2 + def subdivde_mesh(verts_in, faces_in, n=1): - for iter in range(n): - verts = [] - for v in verts_in: - verts.append(v) - faces = [] - verts_dict = {} - for f in faces_in: - f = np.sort(f) - i0,i1,i2 = f - v0,v1,v2 = verts_in[f] - - k = i0*len(verts_in)+i1 - if k in verts_dict: - i01 = verts_dict[k] - else: - i01 = len(verts) - verts_dict[k] = i01 - v01 = (v0 + v1) / 2 - verts.append(v01) - - k = i0*len(verts_in)+i2 - if k in verts_dict: - i02 = verts_dict[k] - else: - i02 = len(verts) - verts_dict[k] = i02 - v02 = (v0 + v2) / 2 - verts.append(v02) - - k = i1*len(verts_in)+i2 - if k in verts_dict: - i12 = verts_dict[k] - else: - i12 = len(verts) - verts_dict[k] = i12 - v12 = (v1 + v2) / 2 - verts.append(v12) - - faces.append((i0,i01,i02)) - faces.append((i01,i1,i12)) - faces.append((i12,i2,i02)) - faces.append((i01,i12,i02)) - - verts_in = np.array(verts, dtype=verts_in.dtype) - faces_in = np.array(faces, dtype=np.int32) - return verts_in, faces_in + for iter in range(n): + verts = [] + for v in verts_in: + verts.append(v) + faces = [] + verts_dict = {} + for f in faces_in: + f = np.sort(f) + i0, i1, i2 = f + v0, v1, v2 = verts_in[f] + + k = i0 * len(verts_in) + i1 + if k in verts_dict: + i01 = verts_dict[k] + else: + i01 = len(verts) + verts_dict[k] = i01 + v01 = (v0 + v1) / 2 + verts.append(v01) + + k = i0 * len(verts_in) + i2 + if k in verts_dict: + i02 = verts_dict[k] + else: + i02 = len(verts) + verts_dict[k] = i02 + v02 = (v0 + v2) / 2 + verts.append(v02) + + k = i1 * len(verts_in) + i2 + if k in verts_dict: + i12 = verts_dict[k] + else: + i12 = len(verts) + verts_dict[k] = i12 + v12 = (v1 + v2) / 2 + verts.append(v12) + + faces.append((i0, i01, i02)) + faces.append((i01, i1, i12)) + faces.append((i12, i2, i02)) + faces.append((i01, i12, i02)) + + verts_in = np.array(verts, dtype=verts_in.dtype) + faces_in = np.array(faces, dtype=np.int32) + return verts_in, faces_in def mesh_adjust_winding_order(verts, faces, normals): - n0 = normals[faces[:,0]] - n1 = normals[faces[:,1]] - n2 = normals[faces[:,2]] - fnormals = (n0 + n1 + n2) / 3 + n0 = normals[faces[:, 0]] + n1 = normals[faces[:, 1]] + n2 = normals[faces[:, 2]] + fnormals = (n0 + n1 + n2) / 3 - v0 = verts[faces[:,0]] - v1 = verts[faces[:,1]] - v2 = verts[faces[:,2]] + v0 = verts[faces[:, 0]] + v1 = verts[faces[:, 1]] + v2 = verts[faces[:, 2]] - e0 = v1 - v0 - e1 = v2 - v0 - fn = np.cross(e0, e1) + e0 = v1 - v0 + e1 = v2 - v0 + fn = np.cross(e0, e1) - dot = np.sum(fnormals * fn, axis=1) - ma = dot < 0 + dot = np.sum(fnormals * fn, axis=1) + ma = dot < 0 - nfaces = faces.copy() - nfaces[ma,1], nfaces[ma,2] = nfaces[ma,2], nfaces[ma,1] + nfaces = faces.copy() + nfaces[ma, 1], nfaces[ma, 2] = nfaces[ma, 2], nfaces[ma, 1] - return nfaces + return nfaces def pcl_to_shapecl(verts, colors=None, shape='cube', width=1.0): - if shape == 'tetrahedron': - cverts, cfaces, _ = tetrahedron() - elif shape == 'cube': - cverts, cfaces, _ = cube() - elif shape == 'octahedron': - cverts, cfaces, _ = octahedron() - elif shape == 'icosahedron': - cverts, cfaces, _ = icosahedron() - else: - raise Exception('invalid shape') - - sverts = np.tile(cverts, (verts.shape[0], 1)) - sverts *= width - sverts += np.repeat(verts, cverts.shape[0], axis=0) - - sfaces = np.tile(cfaces, (verts.shape[0], 1)) - sfoffset = cverts.shape[0] * np.arange(0, verts.shape[0]) - sfaces += np.repeat(sfoffset, cfaces.shape[0]).reshape(-1,1) - - if colors is not None: - scolors = np.repeat(colors, cverts.shape[0], axis=0) - else: - scolors = None - - return sverts, sfaces, scolors + if shape == 'tetrahedron': + cverts, cfaces, _ = tetrahedron() + elif shape == 'cube': + cverts, cfaces, _ = cube() + elif shape == 'octahedron': + cverts, cfaces, _ = octahedron() + elif shape == 'icosahedron': + cverts, cfaces, _ = icosahedron() + else: + raise Exception('invalid shape') + + sverts = np.tile(cverts, (verts.shape[0], 1)) + sverts *= width + sverts += np.repeat(verts, cverts.shape[0], axis=0) + + sfaces = np.tile(cfaces, (verts.shape[0], 1)) + sfoffset = cverts.shape[0] * np.arange(0, verts.shape[0]) + sfaces += np.repeat(sfoffset, cfaces.shape[0]).reshape(-1, 1) + + if colors is not None: + scolors = np.repeat(colors, cverts.shape[0], axis=0) + else: + scolors = None + + return sverts, sfaces, scolors diff --git a/co/gtimer.py b/co/gtimer.py index 5bad06b..78f55fc 100644 --- a/co/gtimer.py +++ b/co/gtimer.py @@ -2,31 +2,37 @@ import numpy as np from . import utils + class StopWatch(utils.StopWatch): - def __del__(self): - print('='*80) - print('gtimer:') - total = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.sum).items()]) - print(f' [total] {total}') - mean = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.mean).items()]) - print(f' [mean] {mean}') - median = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.median).items()]) - print(f' [median] {median}') - print('='*80) + def __del__(self): + print('=' * 80) + print('gtimer:') + total = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.sum).items()]) + print(f' [total] {total}') + mean = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.mean).items()]) + print(f' [mean] {mean}') + median = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.median).items()]) + print(f' [median] {median}') + print('=' * 80) + GTIMER = StopWatch() + def start(name): - GTIMER.start(name) + GTIMER.start(name) + + def stop(name): - GTIMER.stop(name) + GTIMER.stop(name) + class Ctx(object): - def __init__(self, name): - self.name = name + def __init__(self, name): + self.name = name - def __enter__(self): - start(self.name) + def __enter__(self): + start(self.name) - def __exit__(self, *args): - stop(self.name) + def __exit__(self, *args): + stop(self.name) diff --git a/co/io3d.py b/co/io3d.py index 54bc60a..cab59a4 100644 --- a/co/io3d.py +++ b/co/io3d.py @@ -2,266 +2,273 @@ import struct import numpy as np import collections -def _write_ply_point(fp, x,y,z, color=None, normal=None, binary=False): - args = [x,y,z] - if color is not None: - args += [int(color[0]), int(color[1]), int(color[2])] - if normal is not None: - args += [normal[0],normal[1],normal[2]] - if binary: - fmt = ' 1: - c = color[vidx] +def write_ply(path, verts, trias=None, color=None, normals=None, binary=False): + if verts.shape[1] != 3: + raise Exception('verts has to be of shape Nx3') + if trias is not None and trias.shape[1] != 3: + raise Exception('trias has to be of shape Nx3') + if color is not None and not callable(color) and not isinstance(color, np.ndarray) and color.shape[1] != 3: + raise Exception('color has to be of shape Nx3 or a callable') + + mode = 'wb' if binary else 'w' + with open(path, mode) as fp: + _write_ply_header_line(fp, "ply\n", binary) + if binary: + _write_ply_header_line(fp, "format binary_little_endian 1.0\n", binary) else: - c = color[0] - else: - c = None - if normals is None: - n = None - else: - n = normals[vidx] - _write_ply_point(fp, v[0],v[1],v[2], c, n, binary) + _write_ply_header_line(fp, "format ascii 1.0\n", binary) + _write_ply_header_line(fp, "element vertex %d\n" % (verts.shape[0]), binary) + _write_ply_header_line(fp, "property float32 x\n", binary) + _write_ply_header_line(fp, "property float32 y\n", binary) + _write_ply_header_line(fp, "property float32 z\n", binary) + if color is not None: + _write_ply_header_line(fp, "property uchar red\n", binary) + _write_ply_header_line(fp, "property uchar green\n", binary) + _write_ply_header_line(fp, "property uchar blue\n", binary) + if normals is not None: + _write_ply_header_line(fp, "property float32 nx\n", binary) + _write_ply_header_line(fp, "property float32 ny\n", binary) + _write_ply_header_line(fp, "property float32 nz\n", binary) + if trias is not None: + _write_ply_header_line(fp, "element face %d\n" % (trias.shape[0]), binary) + _write_ply_header_line(fp, "property list uchar int32 vertex_indices\n", binary) + _write_ply_header_line(fp, "end_header\n", binary) + + for vidx, v in enumerate(verts): + if color is not None: + if callable(color): + c = color(vidx) + elif color.shape[0] > 1: + c = color[vidx] + else: + c = color[0] + else: + c = None + if normals is None: + n = None + else: + n = normals[vidx] + _write_ply_point(fp, v[0], v[1], v[2], c, n, binary) + + if trias is not None: + for t in trias: + _write_ply_triangle(fp, t[0], t[1], t[2], binary) - if trias is not None: - for t in trias: - _write_ply_triangle(fp, t[0],t[1],t[2], binary) def faces_to_triangles(faces): - new_faces = [] - for f in faces: - if f[0] == 3: - new_faces.append([f[1], f[2], f[3]]) - elif f[0] == 4: - new_faces.append([f[1], f[2], f[3]]) - new_faces.append([f[3], f[4], f[1]]) - else: - raise Exception('unknown face count %d', f[0]) - return new_faces + new_faces = [] + for f in faces: + if f[0] == 3: + new_faces.append([f[1], f[2], f[3]]) + elif f[0] == 4: + new_faces.append([f[1], f[2], f[3]]) + new_faces.append([f[3], f[4], f[1]]) + else: + raise Exception('unknown face count %d', f[0]) + return new_faces + def read_ply(path): - with open(path, 'rb') as f: - # parse header - line = f.readline().decode().strip() - if line != 'ply': - raise Exception('Header error') - n_verts = 0 - n_faces = 0 - vert_types = {} - vert_bin_format = [] - vert_bin_len = 0 - vert_bin_cols = 0 - line = f.readline().decode() - parse_vertex_prop = False - while line.strip() != 'end_header': - if 'format' in line: - if 'ascii' in line: - binary = False - elif 'binary_little_endian' in line: - binary = True - else: - raise Exception('invalid ply format') - if 'element face' in line: - splits = line.strip().split(' ') - n_faces = int(splits[-1]) - parse_vertex_prop = False - if 'element camera' in line: + with open(path, 'rb') as f: + # parse header + line = f.readline().decode().strip() + if line != 'ply': + raise Exception('Header error') + n_verts = 0 + n_faces = 0 + vert_types = {} + vert_bin_format = [] + vert_bin_len = 0 + vert_bin_cols = 0 + line = f.readline().decode() parse_vertex_prop = False - if 'element vertex' in line: - splits = line.strip().split(' ') - n_verts = int(splits[-1]) - parse_vertex_prop = True - if parse_vertex_prop and 'property' in line: - prop = line.strip().split() - if prop[1] == 'float': - vert_bin_format.append('f4') - vert_bin_len += 4 - vert_bin_cols += 1 - elif prop[1] == 'uchar': - vert_bin_format.append('B') - vert_bin_len += 1 - vert_bin_cols += 1 - else: - raise Exception('invalid property') - vert_types[prop[2]] = len(vert_types) - line = f.readline().decode() + while line.strip() != 'end_header': + if 'format' in line: + if 'ascii' in line: + binary = False + elif 'binary_little_endian' in line: + binary = True + else: + raise Exception('invalid ply format') + if 'element face' in line: + splits = line.strip().split(' ') + n_faces = int(splits[-1]) + parse_vertex_prop = False + if 'element camera' in line: + parse_vertex_prop = False + if 'element vertex' in line: + splits = line.strip().split(' ') + n_verts = int(splits[-1]) + parse_vertex_prop = True + if parse_vertex_prop and 'property' in line: + prop = line.strip().split() + if prop[1] == 'float': + vert_bin_format.append('f4') + vert_bin_len += 4 + vert_bin_cols += 1 + elif prop[1] == 'uchar': + vert_bin_format.append('B') + vert_bin_len += 1 + vert_bin_cols += 1 + else: + raise Exception('invalid property') + vert_types[prop[2]] = len(vert_types) + line = f.readline().decode() - # parse content - if binary: - sz = n_verts * vert_bin_len - fmt = ','.join(vert_bin_format) - verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz)) - verts = verts[0].astype(vert_bin_cols*'f4,').view(dtype='f4').reshape((n_verts,-1)) - faces = [] - for idx in range(n_faces): - fmt = '= 2 and len(parts[1]) > 0: - tidx = int(parts[1]) - 1 - else: - tidx = -1 - if len(parts) >= 3 and len(parts[2]) > 0: - nidx = int(parts[2]) - 1 - else: - nidx = -1 - return vidx, tidx, nidx + parts = s.split('/') + vidx = int(parts[0]) - 1 + if len(parts) >= 2 and len(parts[1]) > 0: + tidx = int(parts[1]) - 1 + else: + tidx = -1 + if len(parts) >= 3 and len(parts[2]) > 0: + nidx = int(parts[2]) - 1 + else: + nidx = -1 + return vidx, tidx, nidx + def read_obj(path): - with open(path, 'r') as fp: - lines = fp.readlines() + with open(path, 'r') as fp: + lines = fp.readlines() + + verts = [] + colors = [] + fnorms = [] + fnorm_map = collections.defaultdict(list) + faces = [] + for line in lines: + line = line.strip() + if line.startswith('#') or len(line) == 0: + continue - verts = [] - colors = [] - fnorms = [] - fnorm_map = collections.defaultdict(list) - faces = [] - for line in lines: - line = line.strip() - if line.startswith('#') or len(line) == 0: - continue + parts = line.split() + if line.startswith('v '): + parts = parts[1:] + x, y, z = float(parts[0]), float(parts[1]), float(parts[2]) + if len(parts) == 4 or len(parts) == 7: + w = float(parts[3]) + x, y, z = x / w, y / w, z / w + verts.append((x, y, z)) + if len(parts) >= 6: + r, g, b = float(parts[-3]), float(parts[-2]), float(parts[-1]) + rgb.append((r, g, b)) - parts = line.split() - if line.startswith('v '): - parts = parts[1:] - x,y,z = float(parts[0]), float(parts[1]), float(parts[2]) - if len(parts) == 4 or len(parts) == 7: - w = float(parts[3]) - x,y,z = x/w, y/w, z/w - verts.append((x,y,z)) - if len(parts) >= 6: - r,g,b = float(parts[-3]), float(parts[-2]), float(parts[-1]) - rgb.append((r,g,b)) + elif line.startswith('vn '): + parts = parts[1:] + x, y, z = float(parts[0]), float(parts[1]), float(parts[2]) + fnorms.append((x, y, z)) - elif line.startswith('vn '): - parts = parts[1:] - x,y,z = float(parts[0]), float(parts[1]), float(parts[2]) - fnorms.append((x,y,z)) + elif line.startswith('f '): + parts = parts[1:] + if len(parts) != 3: + raise Exception('only triangle meshes supported atm') + vidx0, tidx0, nidx0 = _read_obj_split_f(parts[0]) + vidx1, tidx1, nidx1 = _read_obj_split_f(parts[1]) + vidx2, tidx2, nidx2 = _read_obj_split_f(parts[2]) - elif line.startswith('f '): - parts = parts[1:] - if len(parts) != 3: - raise Exception('only triangle meshes supported atm') - vidx0, tidx0, nidx0 = _read_obj_split_f(parts[0]) - vidx1, tidx1, nidx1 = _read_obj_split_f(parts[1]) - vidx2, tidx2, nidx2 = _read_obj_split_f(parts[2]) + faces.append((vidx0, vidx1, vidx2)) + if nidx0 >= 0: + fnorm_map[vidx0].append(nidx0) + if nidx1 >= 0: + fnorm_map[vidx1].append(nidx1) + if nidx2 >= 0: + fnorm_map[vidx2].append(nidx2) - faces.append((vidx0, vidx1, vidx2)) - if nidx0 >= 0: - fnorm_map[vidx0].append( nidx0 ) - if nidx1 >= 0: - fnorm_map[vidx1].append( nidx1 ) - if nidx2 >= 0: - fnorm_map[vidx2].append( nidx2 ) + verts = np.array(verts) + colors = np.array(colors) + fnorms = np.array(fnorms) + faces = np.array(faces) - verts = np.array(verts) - colors = np.array(colors) - fnorms = np.array(fnorms) - faces = np.array(faces) - - # face normals to vertex normals - norms = np.zeros_like(verts) - for vidx in fnorm_map.keys(): - ind = fnorm_map[vidx] - norms[vidx] = fnorms[ind].sum(axis=0) - N = np.linalg.norm(norms, axis=1, keepdims=True) - np.divide(norms, N, out=norms, where=N != 0) + # face normals to vertex normals + norms = np.zeros_like(verts) + for vidx in fnorm_map.keys(): + ind = fnorm_map[vidx] + norms[vidx] = fnorms[ind].sum(axis=0) + N = np.linalg.norm(norms, axis=1, keepdims=True) + np.divide(norms, N, out=norms, where=N != 0) - return verts, faces, colors, norms + return verts, faces, colors, norms diff --git a/co/metric.py b/co/metric.py index 26061da..de2ca80 100644 --- a/co/metric.py +++ b/co/metric.py @@ -1,248 +1,260 @@ import numpy as np from . import geometry + def _process_inputs(estimate, target, mask): - if estimate.shape != target.shape: - raise Exception('estimate and target have to be same shape') - if mask is None: - mask = np.ones(estimate.shape, dtype=np.bool) - else: - mask = mask != 0 - if estimate.shape != mask.shape: - raise Exception('estimate and mask have to be same shape') - return estimate, target, mask + if estimate.shape != target.shape: + raise Exception('estimate and target have to be same shape') + if mask is None: + mask = np.ones(estimate.shape, dtype=np.bool) + else: + mask = mask != 0 + if estimate.shape != mask.shape: + raise Exception('estimate and mask have to be same shape') + return estimate, target, mask + def mse(estimate, target, mask=None): - estimate, target, mask = _process_inputs(estimate, target, mask) - m = np.sum((estimate[mask] - target[mask])**2) / mask.sum() - return m + estimate, target, mask = _process_inputs(estimate, target, mask) + m = np.sum((estimate[mask] - target[mask]) ** 2) / mask.sum() + return m + def rmse(estimate, target, mask=None): - return np.sqrt(mse(estimate, target, mask)) + return np.sqrt(mse(estimate, target, mask)) + def mae(estimate, target, mask=None): - estimate, target, mask = _process_inputs(estimate, target, mask) - m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum() - return m + estimate, target, mask = _process_inputs(estimate, target, mask) + m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum() + return m + def outlier_fraction(estimate, target, mask=None, threshold=0): - estimate, target, mask = _process_inputs(estimate, target, mask) - diff = np.abs(estimate[mask] - target[mask]) - m = (diff > threshold).sum() / mask.sum() - return m + estimate, target, mask = _process_inputs(estimate, target, mask) + diff = np.abs(estimate[mask] - target[mask]) + m = (diff > threshold).sum() / mask.sum() + return m class Metric(object): - def __init__(self, str_prefix=''): - self.str_prefix = str_prefix - self.reset() + def __init__(self, str_prefix=''): + self.str_prefix = str_prefix + self.reset() + + def reset(self): + pass - def reset(self): - pass + def add(self, es, ta, ma=None): + pass - def add(self, es, ta, ma=None): - pass + def get(self): + return {} - def get(self): - return {} + def items(self): + return self.get().items() - def items(self): - return self.get().items() + def __str__(self): + return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()]) - def __str__(self): - return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()]) class MultipleMetric(Metric): - def __init__(self, *metrics, **kwargs): - self.metrics = [*metrics] - super().__init__(**kwargs) + def __init__(self, *metrics, **kwargs): + self.metrics = [*metrics] + super().__init__(**kwargs) + + def reset(self): + for m in self.metrics: + m.reset() - def reset(self): - for m in self.metrics: - m.reset() + def add(self, es, ta, ma=None): + for m in self.metrics: + m.add(es, ta, ma) - def add(self, es, ta, ma=None): - for m in self.metrics: - m.add(es, ta, ma) + def get(self): + ret = {} + for m in self.metrics: + vals = m.get() + for k in vals: + ret[k] = vals[k] + return ret - def get(self): - ret = {} - for m in self.metrics: - vals = m.get() - for k in vals: - ret[k] = vals[k] - return ret + def __str__(self): + return '\n'.join([str(m) for m in self.metrics]) - def __str__(self): - return '\n'.join([str(m) for m in self.metrics]) class BaseDistanceMetric(Metric): - def __init__(self, name='', **kwargs): - super().__init__(**kwargs) - self.name = name - - def reset(self): - self.dists = [] - - def add(self, es, ta, ma=None): - pass - - def get(self): - dists = np.hstack(self.dists) - return { - f'dist{self.name}_mean': float(np.mean(dists)), - f'dist{self.name}_std': float(np.std(dists)), - f'dist{self.name}_median': float(np.median(dists)), - f'dist{self.name}_q10': float(np.percentile(dists, 10)), - f'dist{self.name}_q90': float(np.percentile(dists, 90)), - f'dist{self.name}_min': float(np.min(dists)), - f'dist{self.name}_max': float(np.max(dists)), - } + def __init__(self, name='', **kwargs): + super().__init__(**kwargs) + self.name = name + + def reset(self): + self.dists = [] + + def add(self, es, ta, ma=None): + pass + + def get(self): + dists = np.hstack(self.dists) + return { + f'dist{self.name}_mean': float(np.mean(dists)), + f'dist{self.name}_std': float(np.std(dists)), + f'dist{self.name}_median': float(np.median(dists)), + f'dist{self.name}_q10': float(np.percentile(dists, 10)), + f'dist{self.name}_q90': float(np.percentile(dists, 90)), + f'dist{self.name}_min': float(np.min(dists)), + f'dist{self.name}_max': float(np.max(dists)), + } + class DistanceMetric(BaseDistanceMetric): - def __init__(self, vec_length, p=2, **kwargs): - super().__init__(name=f'{p}', **kwargs) - self.vec_length = vec_length - self.p = p - - def add(self, es, ta, ma=None): - if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2: - print(es.shape, ta.shape) - raise Exception('es and ta have to be of shape Nxdim') - if ma is not None: - es = es[ma != 0] - ta = ta[ma != 0] - dist = np.linalg.norm(es - ta, ord=self.p, axis=1) - self.dists.append( dist ) + def __init__(self, vec_length, p=2, **kwargs): + super().__init__(name=f'{p}', **kwargs) + self.vec_length = vec_length + self.p = p + + def add(self, es, ta, ma=None): + if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2: + print(es.shape, ta.shape) + raise Exception('es and ta have to be of shape Nxdim') + if ma is not None: + es = es[ma != 0] + ta = ta[ma != 0] + dist = np.linalg.norm(es - ta, ord=self.p, axis=1) + self.dists.append(dist) + class OutlierFractionMetric(DistanceMetric): - def __init__(self, thresholds, *args, **kwargs): - super().__init__(*args, **kwargs) - self.thresholds = thresholds - - def get(self): - dists = np.hstack(self.dists) - ret = {} - for t in self.thresholds: - ma = dists > t - ret[f'of{t}'] = float(ma.sum() / ma.size) - return ret + def __init__(self, thresholds, *args, **kwargs): + super().__init__(*args, **kwargs) + self.thresholds = thresholds + + def get(self): + dists = np.hstack(self.dists) + ret = {} + for t in self.thresholds: + ma = dists > t + ret[f'of{t}'] = float(ma.sum() / ma.size) + return ret + class RelativeDistanceMetric(BaseDistanceMetric): - def __init__(self, vec_length, p=2, **kwargs): - super().__init__(name=f'rel{p}', **kwargs) - self.vec_length = vec_length - self.p = p - - def add(self, es, ta, ma=None): - if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2: - raise Exception('es and ta have to be of shape Nxdim') - dist = np.linalg.norm(es - ta, ord=self.p, axis=1) - denom = np.linalg.norm(ta, ord=self.p, axis=1) - dist /= denom - if ma is not None: - dist = dist[ma != 0] - self.dists.append( dist ) + def __init__(self, vec_length, p=2, **kwargs): + super().__init__(name=f'rel{p}', **kwargs) + self.vec_length = vec_length + self.p = p + + def add(self, es, ta, ma=None): + if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2: + raise Exception('es and ta have to be of shape Nxdim') + dist = np.linalg.norm(es - ta, ord=self.p, axis=1) + denom = np.linalg.norm(ta, ord=self.p, axis=1) + dist /= denom + if ma is not None: + dist = dist[ma != 0] + self.dists.append(dist) + class RotmDistanceMetric(BaseDistanceMetric): - def __init__(self, type='identity', **kwargs): - super().__init__(name=type, **kwargs) - self.type = type - - def add(self, es, ta, ma=None): - if es.shape != ta.shape or es.shape[1] != 3 or es.shape[2] != 3 or es.ndim != 3: - print(es.shape, ta.shape) - raise Exception('es and ta have to be of shape Nx3x3') - if ma is not None: - raise Exception('mask is not implemented') - if self.type == 'identity': - self.dists.append( geometry.rotm_distance_identity(es, ta) ) - elif self.type == 'geodesic': - self.dists.append( geometry.rotm_distance_geodesic_unit_sphere(es, ta) ) - else: - raise Exception('invalid distance type') + def __init__(self, type='identity', **kwargs): + super().__init__(name=type, **kwargs) + self.type = type + + def add(self, es, ta, ma=None): + if es.shape != ta.shape or es.shape[1] != 3 or es.shape[2] != 3 or es.ndim != 3: + print(es.shape, ta.shape) + raise Exception('es and ta have to be of shape Nx3x3') + if ma is not None: + raise Exception('mask is not implemented') + if self.type == 'identity': + self.dists.append(geometry.rotm_distance_identity(es, ta)) + elif self.type == 'geodesic': + self.dists.append(geometry.rotm_distance_geodesic_unit_sphere(es, ta)) + else: + raise Exception('invalid distance type') + class QuaternionDistanceMetric(BaseDistanceMetric): - def __init__(self, type='angle', **kwargs): - super().__init__(name=type, **kwargs) - self.type = type - - def add(self, es, ta, ma=None): - if es.shape != ta.shape or es.shape[1] != 4 or es.ndim != 2: - print(es.shape, ta.shape) - raise Exception('es and ta have to be of shape Nx4') - if ma is not None: - raise Exception('mask is not implemented') - if self.type == 'angle': - self.dists.append( geometry.quat_distance_angle(es, ta) ) - elif self.type == 'mineucl': - self.dists.append( geometry.quat_distance_mineucl(es, ta) ) - elif self.type == 'normdiff': - self.dists.append( geometry.quat_distance_normdiff(es, ta) ) - else: - raise Exception('invalid distance type') + def __init__(self, type='angle', **kwargs): + super().__init__(name=type, **kwargs) + self.type = type + + def add(self, es, ta, ma=None): + if es.shape != ta.shape or es.shape[1] != 4 or es.ndim != 2: + print(es.shape, ta.shape) + raise Exception('es and ta have to be of shape Nx4') + if ma is not None: + raise Exception('mask is not implemented') + if self.type == 'angle': + self.dists.append(geometry.quat_distance_angle(es, ta)) + elif self.type == 'mineucl': + self.dists.append(geometry.quat_distance_mineucl(es, ta)) + elif self.type == 'normdiff': + self.dists.append(geometry.quat_distance_normdiff(es, ta)) + else: + raise Exception('invalid distance type') class BinaryAccuracyMetric(Metric): - def __init__(self, thresholds=np.linspace(0.0, 1.0, num=101, dtype=np.float64)[:-1], **kwargs): - self.thresholds = thresholds - super().__init__(**kwargs) - - def reset(self): - self.tps = [0 for wp in self.thresholds] - self.fps = [0 for wp in self.thresholds] - self.fns = [0 for wp in self.thresholds] - self.tns = [0 for wp in self.thresholds] - self.n_pos = 0 - self.n_neg = 0 - - def add(self, es, ta, ma=None): - if ma is not None: - raise Exception('mask is not implemented') - es = es.ravel() - ta = ta.ravel() - if es.shape[0] != ta.shape[0]: - raise Exception('invalid shape of es, or ta') - if es.min() < 0 or es.max() > 1: - raise Exception('estimate has wrong value range') - ta_p = (ta == 1) - ta_n = (ta == 0) - es_p = es[ta_p] - es_n = es[ta_n] - for idx, wp in enumerate(self.thresholds): - wp = np.asscalar(wp) - self.tps[idx] += (es_p > wp).sum() - self.fps[idx] += (es_n > wp).sum() - self.fns[idx] += (es_p <= wp).sum() - self.tns[idx] += (es_n <= wp).sum() - self.n_pos += ta_p.sum() - self.n_neg += ta_n.sum() - - def get(self): - tps = np.array(self.tps).astype(np.float32) - fps = np.array(self.fps).astype(np.float32) - fns = np.array(self.fns).astype(np.float32) - tns = np.array(self.tns).astype(np.float32) - wp = self.thresholds - - ret = {} - - precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps), where=tps + fps != 0) - recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps), where=tps + fns != 0) # tprs - fprs = np.divide(fps, fps + tns, out=np.zeros_like(tps), where=fps + tns != 0) - - precisions = np.r_[0, precisions, 1] - recalls = np.r_[1, recalls, 0] - fprs = np.r_[1, fprs, 0] - - ret['auc'] = float(-np.trapz(recalls, fprs)) - ret['prauc'] = float(-np.trapz(precisions, recalls)) - ret['ap'] = float(-(np.diff(recalls) * precisions[:-1]).sum()) - - accuracies = np.divide(tps + tns, tps + tns + fps + fns) - aacc = np.mean(accuracies) - for t in np.linspace(0,1,num=11)[1:-1]: - idx = np.argmin(np.abs(t - wp)) - ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx]) - - return ret + def __init__(self, thresholds=np.linspace(0.0, 1.0, num=101, dtype=np.float64)[:-1], **kwargs): + self.thresholds = thresholds + super().__init__(**kwargs) + + def reset(self): + self.tps = [0 for wp in self.thresholds] + self.fps = [0 for wp in self.thresholds] + self.fns = [0 for wp in self.thresholds] + self.tns = [0 for wp in self.thresholds] + self.n_pos = 0 + self.n_neg = 0 + + def add(self, es, ta, ma=None): + if ma is not None: + raise Exception('mask is not implemented') + es = es.ravel() + ta = ta.ravel() + if es.shape[0] != ta.shape[0]: + raise Exception('invalid shape of es, or ta') + if es.min() < 0 or es.max() > 1: + raise Exception('estimate has wrong value range') + ta_p = (ta == 1) + ta_n = (ta == 0) + es_p = es[ta_p] + es_n = es[ta_n] + for idx, wp in enumerate(self.thresholds): + wp = np.asscalar(wp) + self.tps[idx] += (es_p > wp).sum() + self.fps[idx] += (es_n > wp).sum() + self.fns[idx] += (es_p <= wp).sum() + self.tns[idx] += (es_n <= wp).sum() + self.n_pos += ta_p.sum() + self.n_neg += ta_n.sum() + + def get(self): + tps = np.array(self.tps).astype(np.float32) + fps = np.array(self.fps).astype(np.float32) + fns = np.array(self.fns).astype(np.float32) + tns = np.array(self.tns).astype(np.float32) + wp = self.thresholds + + ret = {} + + precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps), where=tps + fps != 0) + recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps), where=tps + fns != 0) # tprs + fprs = np.divide(fps, fps + tns, out=np.zeros_like(tps), where=fps + tns != 0) + + precisions = np.r_[0, precisions, 1] + recalls = np.r_[1, recalls, 0] + fprs = np.r_[1, fprs, 0] + + ret['auc'] = float(-np.trapz(recalls, fprs)) + ret['prauc'] = float(-np.trapz(precisions, recalls)) + ret['ap'] = float(-(np.diff(recalls) * precisions[:-1]).sum()) + + accuracies = np.divide(tps + tns, tps + tns + fps + fns) + aacc = np.mean(accuracies) + for t in np.linspace(0, 1, num=11)[1:-1]: + idx = np.argmin(np.abs(t - wp)) + ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx]) + + return ret diff --git a/co/plt.py b/co/plt.py index 200ba25..1706d15 100644 --- a/co/plt.py +++ b/co/plt.py @@ -6,94 +6,99 @@ import matplotlib.pyplot as plt import os import time + def save(path, remove_axis=False, dpi=300, fig=None): - if fig is None: - fig = plt.gcf() - dirname = os.path.dirname(path) - if dirname != '' and not os.path.exists(dirname): - os.makedirs(dirname) - if remove_axis: - for ax in fig.axes: - ax.axis('off') - ax.margins(0,0) - fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) - for ax in fig.axes: - ax.xaxis.set_major_locator(plt.NullLocator()) - ax.yaxis.set_major_locator(plt.NullLocator()) - fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0) + if fig is None: + fig = plt.gcf() + dirname = os.path.dirname(path) + if dirname != '' and not os.path.exists(dirname): + os.makedirs(dirname) + if remove_axis: + for ax in fig.axes: + ax.axis('off') + ax.margins(0, 0) + fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) + for ax in fig.axes: + ax.xaxis.set_major_locator(plt.NullLocator()) + ax.yaxis.set_major_locator(plt.NullLocator()) + fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0) + def color_map(im_, cmap='viridis', vmin=None, vmax=None): - cm = plt.get_cmap(cmap) - im = im_.copy() - if vmin is None: - vmin = np.nanmin(im) - if vmax is None: - vmax = np.nanmax(im) - mask = np.logical_not(np.isfinite(im)) - im[mask] = vmin - im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin) - im = cm(im) - im = im[...,:3] - for c in range(3): - im[mask, c] = 1 - return im + cm = plt.get_cmap(cmap) + im = im_.copy() + if vmin is None: + vmin = np.nanmin(im) + if vmax is None: + vmax = np.nanmax(im) + mask = np.logical_not(np.isfinite(im)) + im[mask] = vmin + im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin) + im = cm(im) + im = im[..., :3] + for c in range(3): + im[mask, c] = 1 + return im + def interactive_legend(leg=None, fig=None, all_axes=True): - if leg is None: - leg = plt.legend() - if fig is None: - fig = plt.gcf() - if all_axes: - axs = fig.get_axes() - else: - axs = [fig.gca()] + if leg is None: + leg = plt.legend() + if fig is None: + fig = plt.gcf() + if all_axes: + axs = fig.get_axes() + else: + axs = [fig.gca()] - # lined = dict() - # lines = ax.lines - # for legline, origline in zip(leg.get_lines(), ax.lines): - # legline.set_picker(5) - # lined[legline] = origline - lined = dict() - for lidx, legline in enumerate(leg.get_lines()): - legline.set_picker(5) - lined[legline] = [ax.lines[lidx] for ax in axs] + # lined = dict() + # lines = ax.lines + # for legline, origline in zip(leg.get_lines(), ax.lines): + # legline.set_picker(5) + # lined[legline] = origline + lined = dict() + for lidx, legline in enumerate(leg.get_lines()): + legline.set_picker(5) + lined[legline] = [ax.lines[lidx] for ax in axs] - def onpick(event): - if event.mouseevent.dblclick: - tmp = [(k,v) for k,v in lined.items()] - else: - tmp = [(event.artist, lined[event.artist])] + def onpick(event): + if event.mouseevent.dblclick: + tmp = [(k, v) for k, v in lined.items()] + else: + tmp = [(event.artist, lined[event.artist])] + + for legline, origline in tmp: + for ol in origline: + vis = not ol.get_visible() + ol.set_visible(vis) + if vis: + legline.set_alpha(1.0) + else: + legline.set_alpha(0.2) + fig.canvas.draw() - for legline, origline in tmp: - for ol in origline: - vis = not ol.get_visible() - ol.set_visible(vis) - if vis: - legline.set_alpha(1.0) - else: - legline.set_alpha(0.2) - fig.canvas.draw() + fig.canvas.mpl_connect('pick_event', onpick) - fig.canvas.mpl_connect('pick_event', onpick) def non_annoying_pause(interval, focus_figure=False): - # https://github.com/matplotlib/matplotlib/issues/11131 - backend = mpl.rcParams['backend'] - if backend in _interactive_bk: - figManager = _pylab_helpers.Gcf.get_active() - if figManager is not None: - canvas = figManager.canvas - if canvas.figure.stale: - canvas.draw() - if focus_figure: - plt.show(block=False) - canvas.start_event_loop(interval) - return - time.sleep(interval) + # https://github.com/matplotlib/matplotlib/issues/11131 + backend = mpl.rcParams['backend'] + if backend in _interactive_bk: + figManager = _pylab_helpers.Gcf.get_active() + if figManager is not None: + canvas = figManager.canvas + if canvas.figure.stale: + canvas.draw() + if focus_figure: + plt.show(block=False) + canvas.start_event_loop(interval) + return + time.sleep(interval) + def remove_all_ticks(fig=None): - if fig is None: - fig = plt.gcf() - for ax in fig.axes: - ax.axes.get_xaxis().set_visible(False) - ax.axes.get_yaxis().set_visible(False) + if fig is None: + fig = plt.gcf() + for ax in fig.axes: + ax.axes.get_xaxis().set_visible(False) + ax.axes.get_yaxis().set_visible(False) diff --git a/co/plt2d.py b/co/plt2d.py index de38aa1..e40122d 100644 --- a/co/plt2d.py +++ b/co/plt2d.py @@ -3,55 +3,60 @@ import matplotlib.pyplot as plt from . import geometry + def image_matrix(ims, bgval=0): - n = ims.shape[0] - m = int( np.ceil(np.sqrt(n)) ) - h = ims.shape[1] - w = ims.shape[2] - mat = np.empty((m*h, m*w), dtype=ims.dtype) - mat.fill(bgval) - idx = 0 - for r in range(m): - for c in range(m): - if idx < n: - mat[r*h:(r+1)*h, c*w:(c+1)*w] = ims[idx] - idx += 1 - return mat + n = ims.shape[0] + m = int(np.ceil(np.sqrt(n))) + h = ims.shape[1] + w = ims.shape[2] + mat = np.empty((m * h, m * w), dtype=ims.dtype) + mat.fill(bgval) + idx = 0 + for r in range(m): + for c in range(m): + if idx < n: + mat[r * h:(r + 1) * h, c * w:(c + 1) * w] = ims[idx] + idx += 1 + return mat + def image_cat(ims, vertical=False): - offx = [0] - offy = [0] - if vertical: - width = max([im.shape[1] for im in ims]) - offx += [0 for im in ims[:-1]] - offy += [im.shape[0] for im in ims[:-1]] - height = sum([im.shape[0] for im in ims]) - else: - height = max([im.shape[0] for im in ims]) - offx += [im.shape[1] for im in ims[:-1]] - offy += [0 for im in ims[:-1]] - width = sum([im.shape[1] for im in ims]) - offx = np.cumsum(offx) - offy = np.cumsum(offy) - - im = np.zeros((height,width,*ims[0].shape[2:]), dtype=ims[0].dtype) - for im0, ox, oy in zip(ims, offx, offy): - im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0 - - return im, offx, offy + offx = [0] + offy = [0] + if vertical: + width = max([im.shape[1] for im in ims]) + offx += [0 for im in ims[:-1]] + offy += [im.shape[0] for im in ims[:-1]] + height = sum([im.shape[0] for im in ims]) + else: + height = max([im.shape[0] for im in ims]) + offx += [im.shape[1] for im in ims[:-1]] + offy += [0 for im in ims[:-1]] + width = sum([im.shape[1] for im in ims]) + offx = np.cumsum(offx) + offy = np.cumsum(offy) + + im = np.zeros((height, width, *ims[0].shape[2:]), dtype=ims[0].dtype) + for im0, ox, oy in zip(ims, offx, offy): + im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0 + + return im, offx, offy + def line(li, h, w, ax=None, *args, **kwargs): - if ax is None: - ax = plt.gca() - xs = (-li[2] - li[1] * np.array((0, h-1))) / li[0] - ys = (-li[2] - li[0] * np.array((0, w-1))) / li[1] - pts = np.array([(0,ys[0]), (w-1, ys[1]), (xs[0], 0), (xs[1], h-1)]) - pts = pts[np.logical_and(np.logical_and(pts[:,0] >= 0, pts[:,0] < w), np.logical_and(pts[:,1] >= 0, pts[:,1] < h))] - ax.plot(pts[:,0], pts[:,1], *args, **kwargs) + if ax is None: + ax = plt.gca() + xs = (-li[2] - li[1] * np.array((0, h - 1))) / li[0] + ys = (-li[2] - li[0] * np.array((0, w - 1))) / li[1] + pts = np.array([(0, ys[0]), (w - 1, ys[1]), (xs[0], 0), (xs[1], h - 1)]) + pts = pts[ + np.logical_and(np.logical_and(pts[:, 0] >= 0, pts[:, 0] < w), np.logical_and(pts[:, 1] >= 0, pts[:, 1] < h))] + ax.plot(pts[:, 0], pts[:, 1], *args, **kwargs) + def depthshow(depth, *args, ax=None, **kwargs): - if ax is None: - ax = plt.gca() - d = depth.copy() - d[d < 0] = np.NaN - ax.imshow(d, *args, **kwargs) + if ax is None: + ax = plt.gca() + d = depth.copy() + d[d < 0] = np.NaN + ax.imshow(d, *args, **kwargs) diff --git a/co/plt3d.py b/co/plt3d.py index 8c3eb1c..401c9c4 100644 --- a/co/plt3d.py +++ b/co/plt3d.py @@ -4,35 +4,45 @@ from mpl_toolkits.mplot3d import Axes3D from . import geometry + def ax3d(fig=None): - if fig is None: - fig = plt.gcf() - return fig.add_subplot(111, projection='3d') - -def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1, label=None, **kwargs): - if ax is None: - ax = plt.gca() - C0 = geometry.translation_to_cameracenter(R, t).ravel() - C1 = C0 + R.T.dot( np.array([[-size],[-size],[3*size]], dtype=np.float32) ).ravel() - C2 = C0 + R.T.dot( np.array([[-size],[+size],[3*size]], dtype=np.float32) ).ravel() - C3 = C0 + R.T.dot( np.array([[+size],[+size],[3*size]], dtype=np.float32) ).ravel() - C4 = C0 + R.T.dot( np.array([[+size],[-size],[3*size]], dtype=np.float32) ).ravel() - - if marker_C != '': - ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs) - ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs) - ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs) - ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs) - ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs) - ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]], [C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs) + if fig is None: + fig = plt.gcf() + return fig.add_subplot(111, projection='3d') + + +def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1, + label=None, **kwargs): + if ax is None: + ax = plt.gca() + C0 = geometry.translation_to_cameracenter(R, t).ravel() + C1 = C0 + R.T.dot(np.array([[-size], [-size], [3 * size]], dtype=np.float32)).ravel() + C2 = C0 + R.T.dot(np.array([[-size], [+size], [3 * size]], dtype=np.float32)).ravel() + C3 = C0 + R.T.dot(np.array([[+size], [+size], [3 * size]], dtype=np.float32)).ravel() + C4 = C0 + R.T.dot(np.array([[+size], [-size], [3 * size]], dtype=np.float32)).ravel() + + if marker_C != '': + ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs) + ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, + linewidth=linewidth, **kwargs) + ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle, + linewidth=linewidth, **kwargs) + ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle, + linewidth=linewidth, **kwargs) + ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle, + linewidth=linewidth, **kwargs) + ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]], + [C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, + linewidth=linewidth, **kwargs) + def axis_equal(ax=None): - if ax is None: - ax = plt.gca() - extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz']) - sz = extents[:,1] - extents[:,0] - centers = np.mean(extents, axis=1) - maxsize = max(abs(sz)) - r = maxsize/2 - for ctr, dim in zip(centers, 'xyz'): - getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r) + if ax is None: + ax = plt.gca() + extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz']) + sz = extents[:, 1] - extents[:, 0] + centers = np.mean(extents, axis=1) + maxsize = max(abs(sz)) + r = maxsize / 2 + for ctr, dim in zip(centers, 'xyz'): + getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r) diff --git a/co/table.py b/co/table.py index 7b45aae..7af7f0f 100644 --- a/co/table.py +++ b/co/table.py @@ -3,443 +3,453 @@ import pandas as pd import enum import itertools + class Table(object): - def __init__(self, n_cols): - self.n_cols = n_cols - self.rows = [] - self.aligns = ['r' for c in range(n_cols)] - - def get_cell_align(self, r, c): - align = self.rows[r].cells[c].align - if align is None: - return self.aligns[c] - else: - return align - - def add_row(self, row): - if row.ncols() != self.n_cols: - raise Exception(f'row has invalid number of cols, {row.ncols()} vs. {self.n_cols}') - self.rows.append(row) - - def empty_row(self): - return Row.Empty(self.n_cols) - - def expand_rows(self, n_add_cols=1): - if n_add_cols < 0: raise Exception('n_add_cols has to be positive') - self.n_cols += n_add_cols - for row in self.rows: - row.cells.extend([Cell() for cidx in range(n_add_cols)]) - - def add_block(self, data, row=-1, col=0, fmt=None, expand=False): - if row < 0: row = len(self.rows) - while len(self.rows) < row + len(data): - self.add_row(self.empty_row()) - for r in range(len(data)): - cols = data[r] - if col + len(cols) > self.n_cols: - if expand: - self.expand_rows(col + len(cols) - self.n_cols) + def __init__(self, n_cols): + self.n_cols = n_cols + self.rows = [] + self.aligns = ['r' for c in range(n_cols)] + + def get_cell_align(self, r, c): + align = self.rows[r].cells[c].align + if align is None: + return self.aligns[c] else: - raise Exception('number of cols does not fit in table') - for c in range(len(cols)): - self.rows[row+r].cells[col+c] = Cell(data[r][c], fmt) + return align + + def add_row(self, row): + if row.ncols() != self.n_cols: + raise Exception(f'row has invalid number of cols, {row.ncols()} vs. {self.n_cols}') + self.rows.append(row) + + def empty_row(self): + return Row.Empty(self.n_cols) + + def expand_rows(self, n_add_cols=1): + if n_add_cols < 0: raise Exception('n_add_cols has to be positive') + self.n_cols += n_add_cols + for row in self.rows: + row.cells.extend([Cell() for cidx in range(n_add_cols)]) + + def add_block(self, data, row=-1, col=0, fmt=None, expand=False): + if row < 0: row = len(self.rows) + while len(self.rows) < row + len(data): + self.add_row(self.empty_row()) + for r in range(len(data)): + cols = data[r] + if col + len(cols) > self.n_cols: + if expand: + self.expand_rows(col + len(cols) - self.n_cols) + else: + raise Exception('number of cols does not fit in table') + for c in range(len(cols)): + self.rows[row + r].cells[col + c] = Cell(data[r][c], fmt) -class Row(object): - def __init__(self, cells, pre_separator=None, post_separator=None): - self.cells = cells - self.pre_separator = pre_separator - self.post_separator = post_separator - @classmethod - def Empty(cls, n_cols): - return Row([Cell() for c in range(n_cols)]) +class Row(object): + def __init__(self, cells, pre_separator=None, post_separator=None): + self.cells = cells + self.pre_separator = pre_separator + self.post_separator = post_separator - def add_cell(self, cell): - self.cells.append(cell) + @classmethod + def Empty(cls, n_cols): + return Row([Cell() for c in range(n_cols)]) - def ncols(self): - return sum([c.span for c in self.cells]) + def add_cell(self, cell): + self.cells.append(cell) + def ncols(self): + return sum([c.span for c in self.cells]) class Color(object): - def __init__(self, color=(0,0,0), fmt='rgb'): - if fmt == 'rgb': - self.color = color - elif fmt == 'RGB': - self.color = tuple(c / 255 for c in color) - else: - return Exception('invalid color format') + def __init__(self, color=(0, 0, 0), fmt='rgb'): + if fmt == 'rgb': + self.color = color + elif fmt == 'RGB': + self.color = tuple(c / 255 for c in color) + else: + return Exception('invalid color format') - def as_rgb(self): - return self.color + def as_rgb(self): + return self.color - def as_RGB(self): - return tuple(int(c * 255) for c in self.color) + def as_RGB(self): + return tuple(int(c * 255) for c in self.color) - @classmethod - def rgb(cls, r, g, b): - return Color(color=(r,g,b), fmt='rgb') + @classmethod + def rgb(cls, r, g, b): + return Color(color=(r, g, b), fmt='rgb') - @classmethod - def RGB(cls, r, g, b): - return Color(color=(r,g,b), fmt='RGB') + @classmethod + def RGB(cls, r, g, b): + return Color(color=(r, g, b), fmt='RGB') class CellFormat(object): - def __init__(self, fmt='%s', fgcolor=None, bgcolor=None, bold=False): - self.fmt = fmt - self.fgcolor = fgcolor - self.bgcolor = bgcolor - self.bold = bold + def __init__(self, fmt='%s', fgcolor=None, bgcolor=None, bold=False): + self.fmt = fmt + self.fgcolor = fgcolor + self.bgcolor = bgcolor + self.bold = bold + class Cell(object): - def __init__(self, data=None, fmt=None, span=1, align=None): - self.data = data - if fmt is None: - fmt = CellFormat() - self.fmt = fmt - self.span = span - self.align = align + def __init__(self, data=None, fmt=None, span=1, align=None): + self.data = data + if fmt is None: + fmt = CellFormat() + self.fmt = fmt + self.span = span + self.align = align + + def __str__(self): + return self.fmt.fmt % self.data - def __str__(self): - return self.fmt.fmt % self.data class Separator(enum.Enum): - HEAD = 1 - BOTTOM = 2 - INNER = 3 + HEAD = 1 + BOTTOM = 2 + INNER = 3 class Renderer(object): - def __init__(self): - pass + def __init__(self): + pass - def cell_str_len(self, cell): - return len(str(cell)) + def cell_str_len(self, cell): + return len(str(cell)) - def col_widths(self, table): - widths = [0 for c in range(table.n_cols)] - for row in table.rows: - cidx = 0 - for cell in row.cells: - if cell.span == 1: - strlen = self.cell_str_len(cell) - widths[cidx] = max(widths[cidx], strlen) - cidx += cell.span - return widths + def col_widths(self, table): + widths = [0 for c in range(table.n_cols)] + for row in table.rows: + cidx = 0 + for cell in row.cells: + if cell.span == 1: + strlen = self.cell_str_len(cell) + widths[cidx] = max(widths[cidx], strlen) + cidx += cell.span + return widths - def render(self, table): - raise NotImplementedError('not implemented') + def render(self, table): + raise NotImplementedError('not implemented') - def __call__(self, table): - return self.render(table) + def __call__(self, table): + return self.render(table) - def render_to_file_comment(self): - return '' + def render_to_file_comment(self): + return '' + + def render_to_file(self, path, table): + txt = self.render(table) + with open(path, 'w') as fp: + fp.write(txt) - def render_to_file(self, path, table): - txt = self.render(table) - with open(path, 'w') as fp: - fp.write(txt) class TerminalRenderer(Renderer): - def __init__(self, col_sep=' '): - super().__init__() - self.col_sep = col_sep - - def render_cell(self, table, row, col, widths): - cell = table.rows[row].cells[col] - str = cell.fmt.fmt % cell.data - str_width = len(str) - cell_width = sum([widths[idx] for idx in range(col, col+cell.span)]) - cell_width += len(self.col_sep) * (cell.span - 1) - if len(str) > cell_width: - str = str[:cell_width] - if cell.fmt.bold: - # str = sty.ef.bold + str + sty.rs.bold_dim - # str = sty.ef.bold + str + sty.rs.bold - pass - if cell.fmt.fgcolor is not None: - # color = cell.fmt.fgcolor.as_RGB() - # str = sty.fg(*color) + str + sty.rs.fg - pass - if str_width < cell_width: - n_ws = (cell_width - str_width) - if table.get_cell_align(row, col) == 'r': - str = ' '*n_ws + str - elif table.get_cell_align(row, col) == 'l': - str = str + ' '*n_ws - elif table.get_cell_align(row, col) == 'c': - n_ws1 = n_ws // 2 - n_ws0 = n_ws - n_ws1 - str = ' '*n_ws0 + str + ' '*n_ws1 - if cell.fmt.bgcolor is not None: - # color = cell.fmt.bgcolor.as_RGB() - # str = sty.bg(*color) + str + sty.rs.bg - pass - return str - - def render_separator(self, separator, tab, col_widths, total_width): - if separator == Separator.HEAD: - return '='*total_width - elif separator == Separator.INNER: - return '-'*total_width - elif separator == Separator.BOTTOM: - return '='*total_width - - def render(self, table): - widths = self.col_widths(table) - total_width = sum(widths) + len(self.col_sep) * (table.n_cols - 1) - lines = [] - for ridx, row in enumerate(table.rows): - if row.pre_separator is not None: - sepline = self.render_separator(row.pre_separator, table, widths, total_width) - if len(sepline) > 0: - lines.append(sepline) - line = [] - for cidx, cell in enumerate(row.cells): - line.append(self.render_cell(table, ridx, cidx, widths)) - lines.append(self.col_sep.join(line)) - if row.post_separator is not None: - sepline = self.render_separator(row.post_separator, table, widths, total_width) - if len(sepline) > 0: - lines.append(sepline) - return '\n'.join(lines) + def __init__(self, col_sep=' '): + super().__init__() + self.col_sep = col_sep + + def render_cell(self, table, row, col, widths): + cell = table.rows[row].cells[col] + str = cell.fmt.fmt % cell.data + str_width = len(str) + cell_width = sum([widths[idx] for idx in range(col, col + cell.span)]) + cell_width += len(self.col_sep) * (cell.span - 1) + if len(str) > cell_width: + str = str[:cell_width] + if cell.fmt.bold: + # str = sty.ef.bold + str + sty.rs.bold_dim + # str = sty.ef.bold + str + sty.rs.bold + pass + if cell.fmt.fgcolor is not None: + # color = cell.fmt.fgcolor.as_RGB() + # str = sty.fg(*color) + str + sty.rs.fg + pass + if str_width < cell_width: + n_ws = (cell_width - str_width) + if table.get_cell_align(row, col) == 'r': + str = ' ' * n_ws + str + elif table.get_cell_align(row, col) == 'l': + str = str + ' ' * n_ws + elif table.get_cell_align(row, col) == 'c': + n_ws1 = n_ws // 2 + n_ws0 = n_ws - n_ws1 + str = ' ' * n_ws0 + str + ' ' * n_ws1 + if cell.fmt.bgcolor is not None: + # color = cell.fmt.bgcolor.as_RGB() + # str = sty.bg(*color) + str + sty.rs.bg + pass + return str + + def render_separator(self, separator, tab, col_widths, total_width): + if separator == Separator.HEAD: + return '=' * total_width + elif separator == Separator.INNER: + return '-' * total_width + elif separator == Separator.BOTTOM: + return '=' * total_width + + def render(self, table): + widths = self.col_widths(table) + total_width = sum(widths) + len(self.col_sep) * (table.n_cols - 1) + lines = [] + for ridx, row in enumerate(table.rows): + if row.pre_separator is not None: + sepline = self.render_separator(row.pre_separator, table, widths, total_width) + if len(sepline) > 0: + lines.append(sepline) + line = [] + for cidx, cell in enumerate(row.cells): + line.append(self.render_cell(table, ridx, cidx, widths)) + lines.append(self.col_sep.join(line)) + if row.post_separator is not None: + sepline = self.render_separator(row.post_separator, table, widths, total_width) + if len(sepline) > 0: + lines.append(sepline) + return '\n'.join(lines) + class MarkdownRenderer(TerminalRenderer): - def __init__(self): - super().__init__(col_sep='|') - self.printed_color_warning = False - - def print_color_warning(self): - if not self.printed_color_warning: - print('[WARNING] MarkdownRenderer does not support color yet') - self.printed_color_warning = True - - def cell_str_len(self, cell): - strlen = len(str(cell)) - if cell.fmt.bold: - strlen += 4 - strlen = max(5, strlen) - return strlen - - def render_cell(self, table, row, col, widths): - cell = table.rows[row].cells[col] - str = cell.fmt.fmt % cell.data - if cell.fmt.bold: - str = f'**{str}**' - - str_width = len(str) - cell_width = sum([widths[idx] for idx in range(col, col+cell.span)]) - cell_width += len(self.col_sep) * (cell.span - 1) - if len(str) > cell_width: - str = str[:cell_width] - else: - n_ws = (cell_width - str_width) - if table.get_cell_align(row, col) == 'r': - str = ' '*n_ws + str - elif table.get_cell_align(row, col) == 'l': - str = str + ' '*n_ws - elif table.get_cell_align(row, col) == 'c': - n_ws1 = n_ws // 2 - n_ws0 = n_ws - n_ws1 - str = ' '*n_ws0 + str + ' '*n_ws1 - - if col == 0: str = self.col_sep + str - if col == table.n_cols - 1: str += self.col_sep - - if cell.fmt.fgcolor is not None: - self.print_color_warning() - if cell.fmt.bgcolor is not None: - self.print_color_warning() - return str - - def render_separator(self, separator, tab, widths, total_width): - sep = '' - if separator == Separator.INNER: - sep = self.col_sep - for idx, width in enumerate(widths): - csep = '-' * (width - 2) - if tab.get_cell_align(1, idx) == 'r': - csep = '-' + csep + ':' - elif tab.get_cell_align(1, idx) == 'l': - csep = ':' + csep + '-' - elif tab.get_cell_align(1, idx) == 'c': - csep = ':' + csep + ':' - sep += csep + self.col_sep - return sep + def __init__(self): + super().__init__(col_sep='|') + self.printed_color_warning = False + + def print_color_warning(self): + if not self.printed_color_warning: + print('[WARNING] MarkdownRenderer does not support color yet') + self.printed_color_warning = True + + def cell_str_len(self, cell): + strlen = len(str(cell)) + if cell.fmt.bold: + strlen += 4 + strlen = max(5, strlen) + return strlen + + def render_cell(self, table, row, col, widths): + cell = table.rows[row].cells[col] + str = cell.fmt.fmt % cell.data + if cell.fmt.bold: + str = f'**{str}**' + + str_width = len(str) + cell_width = sum([widths[idx] for idx in range(col, col + cell.span)]) + cell_width += len(self.col_sep) * (cell.span - 1) + if len(str) > cell_width: + str = str[:cell_width] + else: + n_ws = (cell_width - str_width) + if table.get_cell_align(row, col) == 'r': + str = ' ' * n_ws + str + elif table.get_cell_align(row, col) == 'l': + str = str + ' ' * n_ws + elif table.get_cell_align(row, col) == 'c': + n_ws1 = n_ws // 2 + n_ws0 = n_ws - n_ws1 + str = ' ' * n_ws0 + str + ' ' * n_ws1 + + if col == 0: str = self.col_sep + str + if col == table.n_cols - 1: str += self.col_sep + + if cell.fmt.fgcolor is not None: + self.print_color_warning() + if cell.fmt.bgcolor is not None: + self.print_color_warning() + return str + + def render_separator(self, separator, tab, widths, total_width): + sep = '' + if separator == Separator.INNER: + sep = self.col_sep + for idx, width in enumerate(widths): + csep = '-' * (width - 2) + if tab.get_cell_align(1, idx) == 'r': + csep = '-' + csep + ':' + elif tab.get_cell_align(1, idx) == 'l': + csep = ':' + csep + '-' + elif tab.get_cell_align(1, idx) == 'c': + csep = ':' + csep + ':' + sep += csep + self.col_sep + return sep class LatexRenderer(Renderer): - def __init__(self): - super().__init__() - - def render_cell(self, table, row, col): - cell = table.rows[row].cells[col] - str = cell.fmt.fmt % cell.data - if cell.fmt.bold: - str = '{\\bf '+ str + '}' - if cell.fmt.fgcolor is not None: - color = cell.fmt.fgcolor.as_rgb() - str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}' - if cell.fmt.bgcolor is not None: - color = cell.fmt.bgcolor.as_rgb() - str = f'\\cellcolor[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str - align = table.get_cell_align(row, col) - if cell.span != 1 or align != table.aligns[col]: - str = f'\\multicolumn{{{cell.span}}}{{{align}}}{{{str}}}' - return str - - def render_separator(self, separator): - if separator == Separator.HEAD: - return '\\toprule' - elif separator == Separator.INNER: - return '\\midrule' - elif separator == Separator.BOTTOM: - return '\\bottomrule' - - def render(self, table): - lines = ['\\begin{tabular}{' + ''.join(table.aligns) + '}'] - for ridx, row in enumerate(table.rows): - if row.pre_separator is not None: - lines.append(self.render_separator(row.pre_separator)) - line = [] - for cidx, cell in enumerate(row.cells): - line.append(self.render_cell(table, ridx, cidx)) - lines.append(' & '.join(line) + ' \\\\') - if row.post_separator is not None: - lines.append(self.render_separator(row.post_separator)) - lines.append('\\end{tabular}') - return '\n'.join(lines) + def __init__(self): + super().__init__() + + def render_cell(self, table, row, col): + cell = table.rows[row].cells[col] + str = cell.fmt.fmt % cell.data + if cell.fmt.bold: + str = '{\\bf ' + str + '}' + if cell.fmt.fgcolor is not None: + color = cell.fmt.fgcolor.as_rgb() + str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}' + if cell.fmt.bgcolor is not None: + color = cell.fmt.bgcolor.as_rgb() + str = f'\\cellcolor[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + align = table.get_cell_align(row, col) + if cell.span != 1 or align != table.aligns[col]: + str = f'\\multicolumn{{{cell.span}}}{{{align}}}{{{str}}}' + return str + + def render_separator(self, separator): + if separator == Separator.HEAD: + return '\\toprule' + elif separator == Separator.INNER: + return '\\midrule' + elif separator == Separator.BOTTOM: + return '\\bottomrule' + + def render(self, table): + lines = ['\\begin{tabular}{' + ''.join(table.aligns) + '}'] + for ridx, row in enumerate(table.rows): + if row.pre_separator is not None: + lines.append(self.render_separator(row.pre_separator)) + line = [] + for cidx, cell in enumerate(row.cells): + line.append(self.render_cell(table, ridx, cidx)) + lines.append(' & '.join(line) + ' \\\\') + if row.post_separator is not None: + lines.append(self.render_separator(row.post_separator)) + lines.append('\\end{tabular}') + return '\n'.join(lines) -class HtmlRenderer(Renderer): - def __init__(self, html_class='result_table'): - super().__init__() - self.html_class = html_class - - def render_cell(self, table, row, col): - cell = table.rows[row].cells[col] - str = cell.fmt.fmt % cell.data - styles = [] - if cell.fmt.bold: - styles.append('font-weight: bold;') - if cell.fmt.fgcolor is not None: - color = cell.fmt.fgcolor.as_RGB() - styles.append(f'color: rgb({color[0]},{color[1]},{color[2]});') - if cell.fmt.bgcolor is not None: - color = cell.fmt.bgcolor.as_RGB() - styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});') - align = table.get_cell_align(row, col) - if align == 'l': align = 'left' - elif align == 'r': align = 'right' - elif align == 'c': align = 'center' - else: raise Exception('invalid align') - styles.append(f'text-align: {align};') - row = table.rows[row] - if row.pre_separator is not None: - styles.append(f'border-top: {self.render_separator(row.pre_separator)};') - if row.post_separator is not None: - styles.append(f'border-bottom: {self.render_separator(row.post_separator)};') - style = ' '.join(styles) - str = f' {str}\n' - return str - - def render_separator(self, separator): - if separator == Separator.HEAD: - return '1.5pt solid black' - elif separator == Separator.INNER: - return '0.75pt solid black' - elif separator == Separator.BOTTOM: - return '1.5pt solid black' - - def render(self, table): - lines = [f''] - for ridx, row in enumerate(table.rows): - line = [f' \n'] - for cidx, cell in enumerate(row.cells): - line.append(self.render_cell(table, ridx, cidx)) - line.append(' \n') - lines.append(' '.join(line)) - lines.append('
') - return '\n'.join(lines) - - -def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'), best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]): - rnames = data[rowname].unique() - cnames = data[colname].unique() - tab = Table(1+len(cnames)) - - header = [Cell('', align='r')] - header.extend([Cell(h, align='r') for h in cnames]) - header = Row(header, pre_separator=Separator.HEAD, post_separator=Separator.INNER) - tab.add_row(header) - - for rname in rnames: - cells = [Cell(rname, align='l')] - for cname in cnames: - cdata = data[data[colname] == cname] - if cname in best_is_max: - bestval = cdata[valname].max() - val = cdata[cdata[rowname] == rname][valname].max() - else: - bestval = cdata[valname].min() - val = cdata[cdata[rowname] == rname][valname].min() - if val == bestval: - fmt = best_val_cell_fmt - else: - fmt = val_cell_fmt - cells.append(Cell(val, align='r', fmt=fmt)) - tab.add_row(Row(cells)) - tab.rows[-1].post_separator = Separator.BOTTOM - return tab +class HtmlRenderer(Renderer): + def __init__(self, html_class='result_table'): + super().__init__() + self.html_class = html_class + + def render_cell(self, table, row, col): + cell = table.rows[row].cells[col] + str = cell.fmt.fmt % cell.data + styles = [] + if cell.fmt.bold: + styles.append('font-weight: bold;') + if cell.fmt.fgcolor is not None: + color = cell.fmt.fgcolor.as_RGB() + styles.append(f'color: rgb({color[0]},{color[1]},{color[2]});') + if cell.fmt.bgcolor is not None: + color = cell.fmt.bgcolor.as_RGB() + styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});') + align = table.get_cell_align(row, col) + if align == 'l': + align = 'left' + elif align == 'r': + align = 'right' + elif align == 'c': + align = 'center' + else: + raise Exception('invalid align') + styles.append(f'text-align: {align};') + row = table.rows[row] + if row.pre_separator is not None: + styles.append(f'border-top: {self.render_separator(row.pre_separator)};') + if row.post_separator is not None: + styles.append(f'border-bottom: {self.render_separator(row.post_separator)};') + style = ' '.join(styles) + str = f' {str}\n' + return str + + def render_separator(self, separator): + if separator == Separator.HEAD: + return '1.5pt solid black' + elif separator == Separator.INNER: + return '0.75pt solid black' + elif separator == Separator.BOTTOM: + return '1.5pt solid black' + + def render(self, table): + lines = [f''] + for ridx, row in enumerate(table.rows): + line = [f' \n'] + for cidx, cell in enumerate(row.cells): + line.append(self.render_cell(table, ridx, cidx)) + line.append(' \n') + lines.append(' '.join(line)) + lines.append('
') + return '\n'.join(lines) + + +def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'), + best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]): + rnames = data[rowname].unique() + cnames = data[colname].unique() + tab = Table(1 + len(cnames)) + + header = [Cell('', align='r')] + header.extend([Cell(h, align='r') for h in cnames]) + header = Row(header, pre_separator=Separator.HEAD, post_separator=Separator.INNER) + tab.add_row(header) + + for rname in rnames: + cells = [Cell(rname, align='l')] + for cname in cnames: + cdata = data[data[colname] == cname] + if cname in best_is_max: + bestval = cdata[valname].max() + val = cdata[cdata[rowname] == rname][valname].max() + else: + bestval = cdata[valname].min() + val = cdata[cdata[rowname] == rname][valname].min() + if val == bestval: + fmt = best_val_cell_fmt + else: + fmt = val_cell_fmt + cells.append(Cell(val, align='r', fmt=fmt)) + tab.add_row(Row(cells)) + tab.rows[-1].post_separator = Separator.BOTTOM + return tab if __name__ == '__main__': - # df = pd.read_pickle('full.df') - # best_is_max = ['movF0.5', 'movF1.0'] - # tab = pandas_to_table(rowname='method', colname='metric', valname='val', data=df, best_is_max=best_is_max) - - # renderer = TerminalRenderer() - # print(renderer(tab)) - - tab = Table(7) - # header = Row([Cell('header', span=7, align='c')], pre_separator=Separator.HEAD, post_separator=Separator.INNER) - # tab.add_row(header) - # header2 = Row([Cell('thisisaverylongheader', span=4, align='c'), Cell('vals2', span=3, align='c')], post_separator=Separator.INNER) - # tab.add_row(header2) - tab.add_row(Row([Cell(f'c{c}') for c in range(7)])) - tab.rows[-1].post_separator = Separator.INNER - tab.add_block(np.arange(15*7).reshape(15,7)) - tab.rows[4].cells[2].fmt = CellFormat(bold=True) - tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2,0.6,0.1)) - tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7,0.1,0.5)) - tab.rows[5].cells[3].fmt = CellFormat(bold=True,bgcolor=Color.rgb(0.7,0.1,0.5),fgcolor=Color.rgb(0.1,0.1,0.1)) - tab.rows[-1].post_separator = Separator.BOTTOM - - renderer = TerminalRenderer() - print(renderer(tab)) - renderer = MarkdownRenderer() - print(renderer(tab)) - - # renderer = HtmlRenderer() - # html_tab = renderer(tab) - # print(html_tab) - # with open('test.html', 'w') as fp: - # fp.write(html_tab) - - # import latex - - # renderer = LatexRenderer() - # ltx_tab = renderer(tab) - # print(ltx_tab) - - # with open('test.tex', 'w') as fp: - # latex.write_doc_prefix(fp, document_class='article') - # fp.write('this is text that should appear before the table and should be long enough to wrap around.\n'*40) - # fp.write('\\begin{table}') - # fp.write(ltx_tab) - # fp.write('\\end{table}') - # fp.write('this is text that should appear after the table and should be long enough to wrap around.\n'*40) - # latex.write_doc_suffix(fp) + # df = pd.read_pickle('full.df') + # best_is_max = ['movF0.5', 'movF1.0'] + # tab = pandas_to_table(rowname='method', colname='metric', valname='val', data=df, best_is_max=best_is_max) + + # renderer = TerminalRenderer() + # print(renderer(tab)) + + tab = Table(7) + # header = Row([Cell('header', span=7, align='c')], pre_separator=Separator.HEAD, post_separator=Separator.INNER) + # tab.add_row(header) + # header2 = Row([Cell('thisisaverylongheader', span=4, align='c'), Cell('vals2', span=3, align='c')], post_separator=Separator.INNER) + # tab.add_row(header2) + tab.add_row(Row([Cell(f'c{c}') for c in range(7)])) + tab.rows[-1].post_separator = Separator.INNER + tab.add_block(np.arange(15 * 7).reshape(15, 7)) + tab.rows[4].cells[2].fmt = CellFormat(bold=True) + tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2, 0.6, 0.1)) + tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7, 0.1, 0.5)) + tab.rows[5].cells[3].fmt = CellFormat(bold=True, bgcolor=Color.rgb(0.7, 0.1, 0.5), fgcolor=Color.rgb(0.1, 0.1, 0.1)) + tab.rows[-1].post_separator = Separator.BOTTOM + + renderer = TerminalRenderer() + print(renderer(tab)) + renderer = MarkdownRenderer() + print(renderer(tab)) + + # renderer = HtmlRenderer() + # html_tab = renderer(tab) + # print(html_tab) + # with open('test.html', 'w') as fp: + # fp.write(html_tab) + + # import latex + + # renderer = LatexRenderer() + # ltx_tab = renderer(tab) + # print(ltx_tab) + + # with open('test.tex', 'w') as fp: + # latex.write_doc_prefix(fp, document_class='article') + # fp.write('this is text that should appear before the table and should be long enough to wrap around.\n'*40) + # fp.write('\\begin{table}') + # fp.write(ltx_tab) + # fp.write('\\end{table}') + # fp.write('this is text that should appear after the table and should be long enough to wrap around.\n'*40) + # latex.write_doc_suffix(fp) diff --git a/co/utils.py b/co/utils.py index c113e72..cc8a4ce 100644 --- a/co/utils.py +++ b/co/utils.py @@ -8,6 +8,7 @@ import re import pickle import subprocess + def str2bool(v): if v.lower() in ('yes', 'true', 't', 'y', '1'): return True @@ -16,71 +17,74 @@ def str2bool(v): else: raise argparse.ArgumentTypeError('Boolean value expected.') + class StopWatch(object): - def __init__(self): - self.timings = OrderedDict() - self.starts = {} + def __init__(self): + self.timings = OrderedDict() + self.starts = {} - def start(self, name): - self.starts[name] = time.time() + def start(self, name): + self.starts[name] = time.time() - def stop(self, name): - if name not in self.timings: - self.timings[name] = [] - self.timings[name].append(time.time() - self.starts[name]) + def stop(self, name): + if name not in self.timings: + self.timings[name] = [] + self.timings[name].append(time.time() - self.starts[name]) - def get(self, name=None, reduce=np.sum): - if name is not None: - return reduce(self.timings[name]) - else: - ret = {} - for k in self.timings: - ret[k] = reduce(self.timings[k]) - return ret + def get(self, name=None, reduce=np.sum): + if name is not None: + return reduce(self.timings[name]) + else: + ret = {} + for k in self.timings: + ret[k] = reduce(self.timings[k]) + return ret + + def __repr__(self): + return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()]) + + def __str__(self): + return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()]) - def __repr__(self): - return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) - def __str__(self): - return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) class ETA(object): - def __init__(self, length): - self.length = length - self.start_time = time.time() - self.current_idx = 0 - self.current_time = time.time() + def __init__(self, length): + self.length = length + self.start_time = time.time() + self.current_idx = 0 + self.current_time = time.time() - def update(self, idx): - self.current_idx = idx - self.current_time = time.time() + def update(self, idx): + self.current_idx = idx + self.current_time = time.time() - def get_elapsed_time(self): - return self.current_time - self.start_time + def get_elapsed_time(self): + return self.current_time - self.start_time - def get_item_time(self): - return self.get_elapsed_time() / (self.current_idx + 1) + def get_item_time(self): + return self.get_elapsed_time() / (self.current_idx + 1) - def get_remaining_time(self): - return self.get_item_time() * (self.length - self.current_idx + 1) + def get_remaining_time(self): + return self.get_item_time() * (self.length - self.current_idx + 1) - def format_time(self, seconds): - minutes, seconds = divmod(seconds, 60) - hours, minutes = divmod(minutes, 60) - hours = int(hours) - minutes = int(minutes) - return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}' + def format_time(self, seconds): + minutes, seconds = divmod(seconds, 60) + hours, minutes = divmod(minutes, 60) + hours = int(hours) + minutes = int(minutes) + return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}' - def get_elapsed_time_str(self): - return self.format_time(self.get_elapsed_time()) + def get_elapsed_time_str(self): + return self.format_time(self.get_elapsed_time()) - def get_remaining_time_str(self): - return self.format_time(self.get_remaining_time()) + def get_remaining_time_str(self): + return self.format_time(self.get_remaining_time()) -def git_hash(cwd=None): - ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - hash = ret.stdout - if hash is not None and 'fatal' not in hash.decode(): - return hash.decode().strip() - else: - return None +def git_hash(cwd=None): + ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + hash = ret.stdout + if hash is not None and 'fatal' not in hash.decode(): + return hash.decode().strip() + else: + return None diff --git a/data/commons.py b/data/commons.py index da88ca9..5bebb8c 100644 --- a/data/commons.py +++ b/data/commons.py @@ -4,107 +4,109 @@ import cv2 def get_patterns(path='syn', imsizes=[], crop=True): - pattern_size = imsizes[0] - if path == 'syn': - np.random.seed(42) - pattern = np.random.uniform(0,1, size=pattern_size) - pattern = (pattern < 0.1).astype(np.float32) - pattern.reshape(*imsizes[0]) - else: - pattern = cv2.imread(path) - pattern = pattern.astype(np.float32) - pattern /= 255 - - if pattern.ndim == 2: - pattern = np.stack([pattern for idx in range(3)], axis=2) - - if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]: - r0 = (pattern.shape[0] - pattern_size[0]) // 2 - c0 = (pattern.shape[1] - pattern_size[1]) // 2 - pattern = pattern[r0:r0+imsizes[0][0], c0:c0+imsizes[0][1]] - - patterns = [] - for imsize in imsizes: - pat = cv2.resize(pattern, (imsize[1],imsize[0]), interpolation=cv2.INTER_LINEAR) - patterns.append(pat) - - return patterns + pattern_size = imsizes[0] + if path == 'syn': + np.random.seed(42) + pattern = np.random.uniform(0, 1, size=pattern_size) + pattern = (pattern < 0.1).astype(np.float32) + pattern.reshape(*imsizes[0]) + else: + pattern = cv2.imread(path) + pattern = pattern.astype(np.float32) + pattern /= 255 -def get_rotation_matrix(v0, v1): - v0 = v0/np.linalg.norm(v0) - v1 = v1/np.linalg.norm(v1) - v = np.cross(v0,v1) - c = np.dot(v0,v1) - s = np.linalg.norm(v) - I = np.eye(3) - vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0) - k = np.matrix(vXStr) - r = I + k + k @ k * ((1 -c)/(s**2)) - return np.asarray(r.astype(np.float32)) + if pattern.ndim == 2: + pattern = np.stack([pattern for idx in range(3)], axis=2) + + if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]: + r0 = (pattern.shape[0] - pattern_size[0]) // 2 + c0 = (pattern.shape[1] - pattern_size[1]) // 2 + pattern = pattern[r0:r0 + imsizes[0][0], c0:c0 + imsizes[0][1]] + + patterns = [] + for imsize in imsizes: + pat = cv2.resize(pattern, (imsize[1], imsize[0]), interpolation=cv2.INTER_LINEAR) + patterns.append(pat) + + return patterns -def augment_image(img,rng,disp=None,grad=None,max_shift=64,max_blur=1.5,max_noise=10.0,max_sp_noise=0.001): +def get_rotation_matrix(v0, v1): + v0 = v0 / np.linalg.norm(v0) + v1 = v1 / np.linalg.norm(v1) + v = np.cross(v0, v1) + c = np.dot(v0, v1) + s = np.linalg.norm(v) + I = np.eye(3) + vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0) + k = np.matrix(vXStr) + r = I + k + k @ k * ((1 - c) / (s ** 2)) + return np.asarray(r.astype(np.float32)) + +def augment_image(img, rng, disp=None, grad=None, max_shift=64, max_blur=1.5, max_noise=10.0, max_sp_noise=0.001): # get min/max values of image min_val = np.min(img) max_val = np.max(img) - + # init augmented image img_aug = img - + # init disparity correction map disp_aug = disp grad_aug = grad # apply affine transformation - if max_shift>1: - + if max_shift > 1: + # affine parameters - rows,cols = img.shape + rows, cols = img.shape shear = 0 shift = 0 shear_correction = 0 - if rng.uniform(0,1)<0.75: shear = rng.uniform(-max_shift,max_shift) # shear with 75% probability - else: shift = rng.uniform(0,max_shift) # shift with 25% probability - if shear<0: shear_correction = -shear - + if rng.uniform(0, 1) < 0.75: + shear = rng.uniform(-max_shift, max_shift) # shear with 75% probability + else: + shift = rng.uniform(0, max_shift) # shift with 25% probability + if shear < 0: shear_correction = -shear + # affine transformation - a = shear/float(rows) - b = shift+shear_correction - + a = shear / float(rows) + b = shift + shear_correction + # warp image - T = np.float32([[1,a,b],[0,1,0]]) - img_aug = cv2.warpAffine(img_aug,T,(cols,rows)) + T = np.float32([[1, a, b], [0, 1, 0]]) + img_aug = cv2.warpAffine(img_aug, T, (cols, rows)) if grad is not None: - grad_aug = cv2.warpAffine(grad,T,(cols,rows)) - + grad_aug = cv2.warpAffine(grad, T, (cols, rows)) + # disparity correction map - col = a*np.array(range(rows))+b - disp_delta = np.tile(col,(cols,1)).transpose() + col = a * np.array(range(rows)) + b + disp_delta = np.tile(col, (cols, 1)).transpose() if disp is not None: - disp_aug = cv2.warpAffine(disp+disp_delta,T,(cols,rows)) + disp_aug = cv2.warpAffine(disp + disp_delta, T, (cols, rows)) # gaussian smoothing - if rng.uniform(0,1)<0.5: - img_aug = cv2.GaussianBlur(img_aug,(5,5),rng.uniform(0.2,max_blur)) - + if rng.uniform(0, 1) < 0.5: + img_aug = cv2.GaussianBlur(img_aug, (5, 5), rng.uniform(0.2, max_blur)) + # per-pixel gaussian noise - img_aug = img_aug + rng.randn(*img_aug.shape)*rng.uniform(0.0,max_noise)/255.0 + img_aug = img_aug + rng.randn(*img_aug.shape) * rng.uniform(0.0, max_noise) / 255.0 # salt-and-pepper noise - if rng.uniform(0,1)<0.5: - ratio=rng.uniform(0.0,max_sp_noise) + if rng.uniform(0, 1) < 0.5: + ratio = rng.uniform(0.0, max_sp_noise) img_shape = img_aug.shape img_aug = img_aug.flatten() - coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio)) + coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio)) img_aug[coord] = max_val - coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio)) + coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio)) img_aug[coord] = min_val img_aug = np.reshape(img_aug, img_shape) - + # clip intensities back to [0,1] - img_aug = np.maximum(img_aug,0.0) - img_aug = np.minimum(img_aug,1.0) - + img_aug = np.maximum(img_aug, 0.0) + img_aug = np.minimum(img_aug, 1.0) + # return image return img_aug, disp_aug, grad_aug diff --git a/data/create_syn_data.py b/data/create_syn_data.py index d7b83f4..c392a9e 100644 --- a/data/create_syn_data.py +++ b/data/create_syn_data.py @@ -10,261 +10,259 @@ import cv2 import os import collections import sys + sys.path.append('../') import renderer import co -from commons import get_patterns,get_rotation_matrix +from commons import get_patterns, get_rotation_matrix from lcn import lcn -def get_objs(shapenet_dir, obj_classes, num_perclass=100): - shapenet = {'chair': '03001627', - 'airplane': '02691156', - 'car': '02958343', - 'watercraft': '04530566'} - - obj_paths = [] - for cls in obj_classes: - if cls not in shapenet.keys(): - raise Exception('unknown class name') - ids = shapenet[cls] - obj_path = sorted(Path(f'{shapenet_dir}/{ids}').glob('**/models/*.obj')) - obj_paths += obj_path[:num_perclass] - print(f'found {len(obj_paths)} object paths') - - objs = [] - for obj_path in obj_paths: - print(f'load {obj_path}') - v, f, _, n = co.io3d.read_obj(obj_path) - diffs = v.max(axis=0) - v.min(axis=0) - v /= (0.5 * diffs.max()) - v -= (v.min(axis=0) + 1) - f = f.astype(np.int32) - objs.append((v,f,n)) - print(f'loaded {len(objs)} objects') - - return objs +def get_objs(shapenet_dir, obj_classes, num_perclass=100): + shapenet = {'chair': '03001627', + 'airplane': '02691156', + 'car': '02958343', + 'watercraft': '04530566'} + + obj_paths = [] + for cls in obj_classes: + if cls not in shapenet.keys(): + raise Exception('unknown class name') + ids = shapenet[cls] + obj_path = sorted(Path(f'{shapenet_dir}/{ids}').glob('**/models/*.obj')) + obj_paths += obj_path[:num_perclass] + print(f'found {len(obj_paths)} object paths') + + objs = [] + for obj_path in obj_paths: + print(f'load {obj_path}') + v, f, _, n = co.io3d.read_obj(obj_path) + diffs = v.max(axis=0) - v.min(axis=0) + v /= (0.5 * diffs.max()) + v -= (v.min(axis=0) + 1) + f = f.astype(np.int32) + objs.append((v, f, n)) + print(f'loaded {len(objs)} objects') + + return objs def get_mesh(rng, min_z=0): - # set up background board - verts, faces, normals, colors = [], [], [], [] - v, f, n = co.geometry.xyplane(z=0, interleaved=True) - v[:,2] += -v[:,2].min() + rng.uniform(2,7) - v[:,:2] *= 5e2 - v[:,2] = np.mean(v[:,2]) + (v[:,2] - np.mean(v[:,2])) * 5e2 - c = np.empty_like(v) - c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32) - verts.append(v) - faces.append(f) - normals.append(n) - colors.append(c) - - # randomly sample 4 foreground objects for each scene - for shape_idx in range(4): - v, f, n = objs[rng.randint(0,len(objs))] - v, f, n = v.copy(), f.copy(), n.copy() - - s = rng.uniform(0.25, 1) - v *= s - R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng)) - v = v @ R.T - n = n @ R.T - v[:,2] += -v[:,2].min() + min_z + rng.uniform(0.5, 3) - v[:,:2] += rng.uniform(-1, 1, size=(1,2)) - + # set up background board + verts, faces, normals, colors = [], [], [], [] + v, f, n = co.geometry.xyplane(z=0, interleaved=True) + v[:, 2] += -v[:, 2].min() + rng.uniform(2, 7) + v[:, :2] *= 5e2 + v[:, 2] = np.mean(v[:, 2]) + (v[:, 2] - np.mean(v[:, 2])) * 5e2 c = np.empty_like(v) - c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32) - - verts.append(v.astype(np.float32)) + c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32) + verts.append(v) faces.append(f) normals.append(n) colors.append(c) - verts, faces = co.geometry.stack_mesh(verts, faces) - normals = np.vstack(normals).astype(np.float32) - colors = np.vstack(colors).astype(np.float32) - return verts, faces, colors, normals + # randomly sample 4 foreground objects for each scene + for shape_idx in range(4): + v, f, n = objs[rng.randint(0, len(objs))] + v, f, n = v.copy(), f.copy(), n.copy() + + s = rng.uniform(0.25, 1) + v *= s + R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng)) + v = v @ R.T + n = n @ R.T + v[:, 2] += -v[:, 2].min() + min_z + rng.uniform(0.5, 3) + v[:, :2] += rng.uniform(-1, 1, size=(1, 2)) + + c = np.empty_like(v) + c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32) + + verts.append(v.astype(np.float32)) + faces.append(f) + normals.append(n) + colors.append(c) + + verts, faces = co.geometry.stack_mesh(verts, faces) + normals = np.vstack(normals).astype(np.float32) + colors = np.vstack(colors).astype(np.float32) + return verts, faces, colors, normals def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4): + tic = time.time() + rng = np.random.RandomState() + + rng.seed(idx) + + verts, faces, colors, normals = get_mesh(rng) + data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy()) + print(f'loading mesh for sample {idx + 1}/{n_samples} took {time.time() - tic}[s]') + + # let the camera point to the center + center = np.array([0, 0, 3], dtype=np.float32) + + basevec = np.array([-baseline, 0, 0], dtype=np.float32) + unit = np.array([0, 0, 1], dtype=np.float32) + + cam_x_ = rng.uniform(-0.2, 0.2) + cam_y_ = rng.uniform(-0.2, 0.2) + cam_z_ = rng.uniform(-0.2, 0.2) + + ret = collections.defaultdict(list) + blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1, 0.1), 0, 1) + + # capture the same static scene from different view points as a track + for ind in range(track_length): + + cam_x = cam_x_ + rng.uniform(-0.1, 0.1) + cam_y = cam_y_ + rng.uniform(-0.1, 0.1) + cam_z = cam_z_ + rng.uniform(-0.1, 0.1) + + tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32) + + if np.linalg.norm(tcam[0:2]) < 1e-9: + Rcam = np.eye(3, dtype=np.float32) + else: + Rcam = get_rotation_matrix(center, center - tcam) + + tproj = tcam + basevec + Rproj = Rcam + + ret['R'].append(Rcam) + ret['t'].append(tcam) + + cams = [] + projs = [] + + # render the scene at multiple scales + scales = [1, 0.5, 0.25, 0.125] + + for scale in scales: + fx = K[0, 0] * scale + fy = K[1, 1] * scale + px = K[0, 2] * scale + py = K[1, 2] * scale + im_height = imsize[0] * scale + im_width = imsize[1] * scale + cams.append(renderer.PyCamera(fx, fy, px, py, Rcam, tcam, im_width, im_height)) + projs.append(renderer.PyCamera(fx, fy, px, py, Rproj, tproj, im_width, im_height)) + + for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns): + fl = K[0, 0] / (2 ** s) - tic = time.time() - rng = np.random.RandomState() + shader = renderer.PyShader(0.5, 1.5, 0.0, 10) + pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu') + pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35) - rng.seed(idx) + # get the reflected laser pattern $R$ + im = pyrenderer.color().copy() + depth = pyrenderer.depth().copy() + disp = baseline * fl / depth + mask = depth > 0 + im = np.mean(im, axis=2) - verts, faces, colors, normals = get_mesh(rng) - data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy()) - print(f'loading mesh for sample {idx+1}/{n_samples} took {time.time()-tic}[s]') + # get the ambient image $A$ + ambient = pyrenderer.normal().copy() + ambient = np.mean(ambient, axis=2) + # get the noise free IR image $J$ + im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient + ret[f'ambient{s}'].append(ambient[None].astype(np.float32)) - # let the camera point to the center - center = np.array([0,0,3], dtype=np.float32) + # get the gradient magnitude of the ambient image $|\nabla A|$ + ambient = ambient.astype(np.float32) + sobelx = cv2.Sobel(ambient, cv2.CV_32F, 1, 0, ksize=5) + sobely = cv2.Sobel(ambient, cv2.CV_32F, 0, 1, ksize=5) + grad = np.sqrt(sobelx ** 2 + sobely ** 2) + grad = np.maximum(grad - 0.8, 0.0) # parameter - basevec = np.array([-baseline,0,0], dtype=np.float32) - unit = np.array([0,0,1],dtype=np.float32) + # get the local contract normalized grad LCN($|\nabla A|$) + grad_lcn, grad_std = lcn.normalize(grad, 5, 0.1) + grad_lcn = np.clip(grad_lcn, 0.0, 1.0) # parameter + ret[f'grad{s}'].append(grad_lcn[None].astype(np.float32)) - cam_x_ = rng.uniform(-0.2,0.2) - cam_y_ = rng.uniform(-0.2,0.2) - cam_z_ = rng.uniform(-0.2,0.2) + ret[f'im{s}'].append(im[None].astype(np.float32)) + ret[f'mask{s}'].append(mask[None].astype(np.float32)) + ret[f'disp{s}'].append(disp[None].astype(np.float32)) - ret = collections.defaultdict(list) - blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1,0.1), 0,1) + for key in ret.keys(): + ret[key] = np.stack(ret[key], axis=0) - # capture the same static scene from different view points as a track - for ind in range(track_length): + # save to files + out_dir = out_root / f'{idx:08d}' + out_dir.mkdir(exist_ok=True, parents=True) + for k, val in ret.items(): + for tidx in range(track_length): + v = val[tidx] + out_path = out_dir / f'{k}_{tidx}.npy' + np.save(out_path, v) + np.save(str(out_dir / 'blend_im.npy'), blend_im_rnd) - cam_x = cam_x_ + rng.uniform(-0.1,0.1) - cam_y = cam_y_ + rng.uniform(-0.1,0.1) - cam_z = cam_z_ + rng.uniform(-0.1,0.1) + print(f'create sample {idx + 1}/{n_samples} took {time.time() - tic}[s]') - tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32) - if np.linalg.norm(tcam[0:2])<1e-9: - Rcam = np.eye(3, dtype=np.float32) +if __name__ == '__main__': + + np.random.seed(42) + + # output directory + with open('../config.json') as fp: + config = json.load(fp) + data_root = Path(config['DATA_ROOT']) + shapenet_root = config['SHAPENET_ROOT'] + + data_type = 'syn' + out_root = data_root / f'{data_type}' + out_root.mkdir(parents=True, exist_ok=True) + + start = 0 + if len(sys.argv) >= 2 and isinstance(sys.argv[2], int): + start = sys.argv[2] else: - Rcam = get_rotation_matrix(center, center-tcam) - - tproj = tcam + basevec - Rproj = Rcam - - ret['R'].append(Rcam) - ret['t'].append(tcam) - - cams = [] - projs = [] - - # render the scene at multiple scales - scales = [1, 0.5, 0.25, 0.125] - - for scale in scales: - fx = K[0,0] * scale - fy = K[1,1] * scale - px = K[0,2] * scale - py = K[1,2] * scale - im_height = imsize[0] * scale - im_width = imsize[1] * scale - cams.append( renderer.PyCamera(fx,fy,px,py, Rcam, tcam, im_width, im_height) ) - projs.append( renderer.PyCamera(fx,fy,px,py, Rproj, tproj, im_width, im_height) ) - - - for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns): - fl = K[0,0] / (2**s) - - shader = renderer.PyShader(0.5,1.5,0.0,10) - pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu') - pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35) - - # get the reflected laser pattern $R$ - im = pyrenderer.color().copy() - depth = pyrenderer.depth().copy() - disp = baseline * fl / depth - mask = depth > 0 - im = np.mean(im, axis=2) - - # get the ambient image $A$ - ambient = pyrenderer.normal().copy() - ambient = np.mean(ambient, axis=2) - - # get the noise free IR image $J$ - im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient - ret[f'ambient{s}'].append( ambient[None].astype(np.float32) ) - - # get the gradient magnitude of the ambient image $|\nabla A|$ - ambient = ambient.astype(np.float32) - sobelx = cv2.Sobel(ambient,cv2.CV_32F,1,0,ksize=5) - sobely = cv2.Sobel(ambient,cv2.CV_32F,0,1,ksize=5) - grad = np.sqrt(sobelx**2 + sobely**2) - grad = np.maximum(grad-0.8,0.0) # parameter - - # get the local contract normalized grad LCN($|\nabla A|$) - grad_lcn, grad_std = lcn.normalize(grad,5,0.1) - grad_lcn = np.clip(grad_lcn,0.0,1.0) # parameter - ret[f'grad{s}'].append( grad_lcn[None].astype(np.float32)) - - ret[f'im{s}'].append( im[None].astype(np.float32)) - ret[f'mask{s}'].append(mask[None].astype(np.float32)) - ret[f'disp{s}'].append(disp[None].astype(np.float32)) - - for key in ret.keys(): - ret[key] = np.stack(ret[key], axis=0) - - # save to files - out_dir = out_root / f'{idx:08d}' - out_dir.mkdir(exist_ok=True, parents=True) - for k,val in ret.items(): - for tidx in range(track_length): - v = val[tidx] - out_path = out_dir / f'{k}_{tidx}.npy' - np.save(out_path, v) - np.save( str(out_dir /'blend_im.npy'), blend_im_rnd) - - print(f'create sample {idx+1}/{n_samples} took {time.time()-tic}[s]') - - - -if __name__=='__main__': - - np.random.seed(42) - - # output directory - with open('../config.json') as fp: - config = json.load(fp) - data_root = Path(config['DATA_ROOT']) - shapenet_root = config['SHAPENET_ROOT'] - - data_type = 'syn' - out_root = data_root / f'{data_type}' - out_root.mkdir(parents=True, exist_ok=True) - - start = 0 - if len(sys.argv) >= 2 and isinstance(sys.argv[2], int): - start = sys.argv[2] - else: - if sys.argv[2] == '--resume': - try: - start = max([int(dir) for dir in os.listdir(out_root) if str.isdigit(dir)]) or 0 - except: - pass - - # load shapenet models - obj_classes = ['chair'] - objs = get_objs(shapenet_root, obj_classes) - - # camera parameters - imsize = (488, 648) - imsizes = [(imsize[0]//(2**s), imsize[1]//(2**s)) for s in range(4)] - # K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32) - K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0 ,0, 1]], dtype=np.float32) - focal_lengths = [K[0,0]/(2**s) for s in range(4)] - baseline=0.075 - blend_im = 0.6 - noise = 0 - - # capture the same static scene from different view points as a track - track_length = 4 - - # load pattern image - pattern_path = './kinect_pattern.png' - pattern_crop = True - patterns = get_patterns(pattern_path, imsizes, pattern_crop) - - # write settings to file - settings = { - 'imsizes': imsizes, - 'patterns': patterns, - 'focal_lengths': focal_lengths, - 'baseline': baseline, - 'K': K, - } - out_path = out_root / f'settings.pkl' - print(f'write settings to {out_path}') - with open(str(out_path), 'wb') as f: - pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL) - - # start the job - n_samples = 2**10 + 2**13 - for idx in range(start, n_samples): - args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length) - create_data(*args) + if sys.argv[2] == '--resume': + try: + start = max([int(dir) for dir in os.listdir(out_root) if str.isdigit(dir)]) or 0 + except: + pass + + # load shapenet models + obj_classes = ['chair'] + objs = get_objs(shapenet_root, obj_classes) + + # camera parameters + imsize = (488, 648) + imsizes = [(imsize[0] // (2 ** s), imsize[1] // (2 ** s)) for s in range(4)] + # K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32) + K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0, 0, 1]], + dtype=np.float32) + focal_lengths = [K[0, 0] / (2 ** s) for s in range(4)] + baseline = 0.075 + blend_im = 0.6 + noise = 0 + + # capture the same static scene from different view points as a track + track_length = 4 + + # load pattern image + pattern_path = './kinect_pattern.png' + pattern_crop = True + patterns = get_patterns(pattern_path, imsizes, pattern_crop) + + # write settings to file + settings = { + 'imsizes': imsizes, + 'patterns': patterns, + 'focal_lengths': focal_lengths, + 'baseline': baseline, + 'K': K, + } + out_path = out_root / f'settings.pkl' + print(f'write settings to {out_path}') + with open(str(out_path), 'wb') as f: + pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL) + + # start the job + n_samples = 2 ** 10 + 2 ** 13 + for idx in range(start, n_samples): + args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length) + create_data(*args) diff --git a/data/dataset.py b/data/dataset.py index ee57bd6..5600cd0 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -21,128 +21,128 @@ from .commons import get_patterns, augment_image from mpl_toolkits.mplot3d import Axes3D + class TrackSynDataset(torchext.BaseDataset): - ''' - Load locally saved synthetic dataset - Please run ./create_syn_data.sh to generate the dataset - ''' - def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False): - super().__init__(train=train) - - self.settings_path = settings_path - self.sample_paths = sample_paths - self.data_aug = data_aug - self.train = train - self.track_length=track_length - assert(track_length<=4) - - with open(str(settings_path), 'rb') as f: - settings = pickle.load(f) - self.imsizes = settings['imsizes'] - self.patterns = settings['patterns'] - self.focal_lengths = settings['focal_lengths'] - self.baseline = settings['baseline'] - self.K = settings['K'] - - self.scale = len(self.imsizes) - - self.max_shift=0 - self.max_blur=0.5 - self.max_noise=3.0 - self.max_sp_noise=0.0005 - - def __len__(self): - return len(self.sample_paths) - - def __getitem__(self, idx): - if not self.train: - rng = self.get_rng(idx) - else: - rng = np.random.RandomState() - sample_path = self.sample_paths[idx] - - if self.train: - track_ind = np.random.permutation(4)[0:self.track_length] - else: - track_ind = [0] - - ret = {} - ret['id'] = idx - - # load imgs, at all scales - for sidx in range(len(self.imsizes)): - imgs = [] - ambs = [] - grads = [] - for tidx in track_ind: - imgs.append(np.load(os.path.join(sample_path,f'im{sidx}_{tidx}.npy'))) - ambs.append(np.load(os.path.join(sample_path,f'ambient{sidx}_{tidx}.npy'))) - grads.append(np.load(os.path.join(sample_path,f'grad{sidx}_{tidx}.npy'))) - ret[f'im{sidx}'] = np.stack(imgs, axis=0) - ret[f'ambient{sidx}'] = np.stack(ambs, axis=0) - ret[f'grad{sidx}'] = np.stack(grads, axis=0) - - # load disp and grad only at full resolution - disps = [] - R = [] - t = [] - for tidx in track_ind: - disps.append(np.load(os.path.join(sample_path,f'disp0_{tidx}.npy'))) - R.append(np.load(os.path.join(sample_path,f'R_{tidx}.npy'))) - t.append(np.load(os.path.join(sample_path,f't_{tidx}.npy'))) - ret[f'disp0'] = np.stack(disps, axis=0) - ret['R'] = np.stack(R, axis=0) - ret['t'] = np.stack(t, axis=0) - - blend_im = np.load(os.path.join(sample_path,'blend_im.npy')) - ret['blend_im'] = blend_im.astype(np.float32) - - #### apply data augmentation at different scales seperately, only work for max_shift=0 - if self.data_aug: - for sidx in range(len(self.imsizes)): - if sidx==0: - img = ret[f'im{sidx}'] - disp = ret[f'disp{sidx}'] - grad = ret[f'grad{sidx}'] - img_aug = np.zeros_like(img) - disp_aug = np.zeros_like(img) - grad_aug = np.zeros_like(img) - for i in range(img.shape[0]): - img_aug_, disp_aug_, grad_aug_ = augment_image(img[i,0],rng, - disp=disp[i,0],grad=grad[i,0], - max_shift=self.max_shift, max_blur=self.max_blur, - max_noise=self.max_noise, max_sp_noise=self.max_sp_noise) - img_aug[i] = img_aug_[None].astype(np.float32) - disp_aug[i] = disp_aug_[None].astype(np.float32) - grad_aug[i] = grad_aug_[None].astype(np.float32) - ret[f'im{sidx}'] = img_aug - ret[f'disp{sidx}'] = disp_aug - ret[f'grad{sidx}'] = grad_aug + ''' + Load locally saved synthetic dataset + Please run ./create_syn_data.sh to generate the dataset + ''' + + def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False): + super().__init__(train=train) + + self.settings_path = settings_path + self.sample_paths = sample_paths + self.data_aug = data_aug + self.train = train + self.track_length = track_length + assert (track_length <= 4) + + with open(str(settings_path), 'rb') as f: + settings = pickle.load(f) + self.imsizes = settings['imsizes'] + self.patterns = settings['patterns'] + self.focal_lengths = settings['focal_lengths'] + self.baseline = settings['baseline'] + self.K = settings['K'] + + self.scale = len(self.imsizes) + + self.max_shift = 0 + self.max_blur = 0.5 + self.max_noise = 3.0 + self.max_sp_noise = 0.0005 + + def __len__(self): + return len(self.sample_paths) + + def __getitem__(self, idx): + if not self.train: + rng = self.get_rng(idx) else: - img = ret[f'im{sidx}'] - img_aug = np.zeros_like(img) - for i in range(img.shape[0]): - img_aug_, _, _ = augment_image(img[i,0],rng, - max_shift=self.max_shift, max_blur=self.max_blur, - max_noise=self.max_noise, max_sp_noise=self.max_sp_noise) - img_aug[i] = img_aug_[None].astype(np.float32) - ret[f'im{sidx}'] = img_aug - - if len(track_ind)==1: - for key, val in ret.items(): - if key!='blend_im' and key!='id': - ret[key] = val[0] - + rng = np.random.RandomState() + sample_path = self.sample_paths[idx] - return ret - - def getK(self, sidx=0): - K = self.K.copy() / (2**sidx) - K[2,2] = 1 - return K + if self.train: + track_ind = np.random.permutation(4)[0:self.track_length] + else: + track_ind = [0] + + ret = {} + ret['id'] = idx + + # load imgs, at all scales + for sidx in range(len(self.imsizes)): + imgs = [] + ambs = [] + grads = [] + for tidx in track_ind: + imgs.append(np.load(os.path.join(sample_path, f'im{sidx}_{tidx}.npy'))) + ambs.append(np.load(os.path.join(sample_path, f'ambient{sidx}_{tidx}.npy'))) + grads.append(np.load(os.path.join(sample_path, f'grad{sidx}_{tidx}.npy'))) + ret[f'im{sidx}'] = np.stack(imgs, axis=0) + ret[f'ambient{sidx}'] = np.stack(ambs, axis=0) + ret[f'grad{sidx}'] = np.stack(grads, axis=0) + + # load disp and grad only at full resolution + disps = [] + R = [] + t = [] + for tidx in track_ind: + disps.append(np.load(os.path.join(sample_path, f'disp0_{tidx}.npy'))) + R.append(np.load(os.path.join(sample_path, f'R_{tidx}.npy'))) + t.append(np.load(os.path.join(sample_path, f't_{tidx}.npy'))) + ret[f'disp0'] = np.stack(disps, axis=0) + ret['R'] = np.stack(R, axis=0) + ret['t'] = np.stack(t, axis=0) + + blend_im = np.load(os.path.join(sample_path, 'blend_im.npy')) + ret['blend_im'] = blend_im.astype(np.float32) + + #### apply data augmentation at different scales seperately, only work for max_shift=0 + if self.data_aug: + for sidx in range(len(self.imsizes)): + if sidx == 0: + img = ret[f'im{sidx}'] + disp = ret[f'disp{sidx}'] + grad = ret[f'grad{sidx}'] + img_aug = np.zeros_like(img) + disp_aug = np.zeros_like(img) + grad_aug = np.zeros_like(img) + for i in range(img.shape[0]): + img_aug_, disp_aug_, grad_aug_ = augment_image(img[i, 0], rng, + disp=disp[i, 0], grad=grad[i, 0], + max_shift=self.max_shift, max_blur=self.max_blur, + max_noise=self.max_noise, + max_sp_noise=self.max_sp_noise) + img_aug[i] = img_aug_[None].astype(np.float32) + disp_aug[i] = disp_aug_[None].astype(np.float32) + grad_aug[i] = grad_aug_[None].astype(np.float32) + ret[f'im{sidx}'] = img_aug + ret[f'disp{sidx}'] = disp_aug + ret[f'grad{sidx}'] = grad_aug + else: + img = ret[f'im{sidx}'] + img_aug = np.zeros_like(img) + for i in range(img.shape[0]): + img_aug_, _, _ = augment_image(img[i, 0], rng, + max_shift=self.max_shift, max_blur=self.max_blur, + max_noise=self.max_noise, max_sp_noise=self.max_sp_noise) + img_aug[i] = img_aug_[None].astype(np.float32) + ret[f'im{sidx}'] = img_aug + + if len(track_ind) == 1: + for key, val in ret.items(): + if key != 'blend_im' and key != 'id': + ret[key] = val[0] + + return ret + + def getK(self, sidx=0): + K = self.K.copy() / (2 ** sidx) + K[2, 2] = 1 + return K - if __name__ == '__main__': - pass - + pass diff --git a/data/lcn/lcn.html b/data/lcn/lcn.html index 805940b..3e83f9b 100644 --- a/data/lcn/lcn.html +++ b/data/lcn/lcn.html @@ -2,7 +2,7 @@ - + Cython: lcn.pyx

Generated by Cython 0.29

- Yellow lines hint at Python interaction.
+ Yellow lines hint at Python interaction.
Click on a line that starts with a "+" to see the C code that Cython generated for it.

Raw output: lcn.c

-
+01: import numpy as np
-
  __pyx_t_1 = __Pyx_Import(__pyx_n_s_numpy, 0, -1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)
+
+
+01: import numpy as np
+
  __pyx_t_1 = __Pyx_Import(__pyx_n_s_numpy, 0, -1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
   if (PyDict_SetItem(__pyx_d, __pyx_n_s_np, __pyx_t_1) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
   __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
@@ -374,22 +380,39 @@ body.cython { font-family: courier; font-size: 12; }
   __Pyx_GOTREF(__pyx_t_1);
   if (PyDict_SetItem(__pyx_d, __pyx_n_s_test, __pyx_t_1) < 0) __PYX_ERR(0, 1, __pyx_L1_error)
   __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
-
 02: cimport cython
-
 03: 
-
 04: # use c square root function
-
 05: cdef extern from "math.h":
-
 06:     float sqrt(float x)
-
 07: 
-
 08: @cython.boundscheck(False)
-
 09: @cython.wraparound(False)
-
 10: @cython.cdivision(True)
-
 11: 
-
 12: # 3 parameters:
-
 13: # - float image
-
 14: # - kernel size (actually this is the radius, kernel is 2*k+1)
-
 15: # - small constant epsilon that is used to avoid division by zero
-
+16: def normalize(float[:, :] img, int kernel_size = 4, float epsilon = 0.01):
-
/* Python wrapper */
+
+
 02: cimport cython
+
 03: 
+
 04: # use c square root function
+
 05: cdef extern from "math.h":
+
 06:     float sqrt(float x)
+
 07: 
+
 08: @cython.boundscheck(False)
+
 09: @cython.wraparound(False)
+
 10: @cython.cdivision(True)
+
 11: 
+
 12: # 3 parameters:
+
 13: # - float image
+
 14: # - kernel size (actually this is the radius, kernel is 2*k+1)
+
 15: # - small constant epsilon that is used to avoid division by zero
+
+16: def normalize(float[:, :] img, int kernel_size = 4, float epsilon = 0.01):
+
/* Python wrapper */
 static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
 static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0};
 static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
@@ -434,7 +457,8 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
         }
       }
       if (unlikely(kw_args > 0)) {
-        if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") < 0)) __PYX_ERR(0, 16, __pyx_L3_error)
+        if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") < 0)) __PYX_ERR(0, 16, __pyx_L3_error)
       }
     } else {
       switch (PyTuple_GET_SIZE(__pyx_args)) {
@@ -447,21 +471,27 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
         default: goto __pyx_L5_argtuple_error;
       }
     }
-    __pyx_v_img = __Pyx_PyObject_to_MemoryviewSlice_dsds_float(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)
+    __pyx_v_img = __Pyx_PyObject_to_MemoryviewSlice_dsds_float(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)
     if (values[1]) {
-      __pyx_v_kernel_size = __Pyx_PyInt_As_int(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 16, __pyx_L3_error)
+      __pyx_v_kernel_size = __Pyx_PyInt_As_int(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 16, __pyx_L3_error)
     } else {
       __pyx_v_kernel_size = ((int)4);
     }
     if (values[2]) {
-      __pyx_v_epsilon = __pyx_PyFloat_AsFloat(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 16, __pyx_L3_error)
+      __pyx_v_epsilon = __pyx_PyFloat_AsFloat(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 16, __pyx_L3_error)
     } else {
       __pyx_v_epsilon = ((float)0.01);
     }
   }
   goto __pyx_L4_argument_unpacking_done;
   __pyx_L5_argtuple_error:;
-  __Pyx_RaiseArgtupleInvalid("normalize", 0, 1, 3, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 16, __pyx_L3_error)
+  __Pyx_RaiseArgtupleInvalid("normalize", 0, 1, 3, PyTuple_GET_SIZE(__pyx_args)); __PYX_ERR(0, 16, __pyx_L3_error)
   __pyx_L3_error:;
   __Pyx_AddTraceback("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
   __Pyx_RefNannyFinishContext();
@@ -515,27 +545,49 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
   return __pyx_r;
 }
 /* … */
-  __pyx_tuple__19 = PyTuple_Pack(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num); if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)
+  __pyx_tuple__19 = PyTuple_Pack(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num); if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_tuple__19);
   __Pyx_GIVEREF(__pyx_tuple__19);
 /* … */
   __pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
-  if (PyDict_SetItem(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) < 0) __PYX_ERR(0, 16, __pyx_L1_error)
+  if (PyDict_SetItem(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) < 0) __PYX_ERR(0, 16, __pyx_L1_error)
   __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
-  __pyx_codeobj__20 = (PyObject*)__Pyx_PyCode_New(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)
-
 17: 
-
 18:     # image dimensions
-
+19:     cdef Py_ssize_t M = img.shape[0]
-
  __pyx_v_M = (__pyx_v_img.shape[0]);
-
+20:     cdef Py_ssize_t N = img.shape[1]
-
  __pyx_v_N = (__pyx_v_img.shape[1]);
-
 21: 
-
 22:     # create outputs and output views
-
+23:     img_lcn = np.zeros((M, N), dtype=np.float32)
-
  __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_np); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)
+  __pyx_codeobj__20 = (PyObject*)__Pyx_PyCode_New(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)
+
+
 17: 
+
 18:     # image dimensions
+
+19:     cdef Py_ssize_t M = img.shape[0]
+
  __pyx_v_M = (__pyx_v_img.shape[0]);
+
+
+20:     cdef Py_ssize_t N = img.shape[1]
+
  __pyx_v_N = (__pyx_v_img.shape[1]);
+
+
 21: 
+
 22:     # create outputs and output views
+
+23:     img_lcn = np.zeros((M, N), dtype=np.float32)
+
  __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_np); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
-  __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_zeros); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)
+  __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_zeros); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_2);
   __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
   __pyx_t_1 = PyInt_FromSsize_t(__pyx_v_M); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)
@@ -559,22 +611,34 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
   __Pyx_GOTREF(__pyx_t_4);
   __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_np); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
-  __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_float32); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)
+  __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_float32); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_5);
   __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
-  if (PyDict_SetItem(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) < 0) __PYX_ERR(0, 23, __pyx_L1_error)
+  if (PyDict_SetItem(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) < 0) __PYX_ERR(0, 23, __pyx_L1_error)
   __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
-  __pyx_t_5 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_3, __pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)
+  __pyx_t_5 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_3, __pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_5);
   __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
   __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
   __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
   __pyx_v_img_lcn = __pyx_t_5;
   __pyx_t_5 = 0;
-
+24:     img_std = np.zeros((M, N), dtype=np.float32)
-
  __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_np); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)
+
+
+24:     img_std = np.zeros((M, N), dtype=np.float32)
+
  __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_np); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_5);
-  __pyx_t_4 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_zeros); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)
+  __pyx_t_4 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_zeros); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_4);
   __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
   __pyx_t_5 = PyInt_FromSsize_t(__pyx_v_M); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)
@@ -598,114 +662,236 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
   __Pyx_GOTREF(__pyx_t_2);
   __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_np); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_5);
-  __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_float32); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)
+  __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_float32); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
   __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
-  if (PyDict_SetItem(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) < 0) __PYX_ERR(0, 24, __pyx_L1_error)
+  if (PyDict_SetItem(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) < 0) __PYX_ERR(0, 24, __pyx_L1_error)
   __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
-  __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_4, __pyx_t_3, __pyx_t_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)
+  __pyx_t_1 = __Pyx_PyObject_Call(__pyx_t_4, __pyx_t_3, __pyx_t_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
   __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
   __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
   __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
   __pyx_v_img_std = __pyx_t_1;
   __pyx_t_1 = 0;
-
+25:     cdef float[:, :] img_lcn_view = img_lcn
-
  __pyx_t_6 = __Pyx_PyObject_to_MemoryviewSlice_dsds_float(__pyx_v_img_lcn, PyBUF_WRITABLE); if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)
+
+
+25:     cdef float[:, :] img_lcn_view = img_lcn
+
  __pyx_t_6 = __Pyx_PyObject_to_MemoryviewSlice_dsds_float(__pyx_v_img_lcn, PyBUF_WRITABLE); if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)
   __pyx_v_img_lcn_view = __pyx_t_6;
   __pyx_t_6.memview = NULL;
   __pyx_t_6.data = NULL;
-
+26:     cdef float[:, :] img_std_view = img_std
-
  __pyx_t_6 = __Pyx_PyObject_to_MemoryviewSlice_dsds_float(__pyx_v_img_std, PyBUF_WRITABLE); if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)
+
+
+26:     cdef float[:, :] img_std_view = img_std
+
  __pyx_t_6 = __Pyx_PyObject_to_MemoryviewSlice_dsds_float(__pyx_v_img_std, PyBUF_WRITABLE); if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)
   __pyx_v_img_std_view = __pyx_t_6;
   __pyx_t_6.memview = NULL;
   __pyx_t_6.data = NULL;
-
 27: 
-
 28:     # temporary c variables
-
 29:     cdef float tmp, mean, stddev
-
 30:     cdef Py_ssize_t m, n, i, j
-
+31:     cdef Py_ssize_t ks = kernel_size
-
  __pyx_v_ks = __pyx_v_kernel_size;
-
+32:     cdef float eps = epsilon
-
  __pyx_v_eps = __pyx_v_epsilon;
-
+33:     cdef float num = (ks*2+1)**2
-
  __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
-
 34: 
-
 35:     # for all pixels do
-
+36:     for m in range(ks,M-ks):
-
  __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
+
+
 27: 
+
 28:     # temporary c variables
+
 29:     cdef float tmp, mean, stddev
+
 30:     cdef Py_ssize_t m, n, i, j
+
+31:     cdef Py_ssize_t ks = kernel_size
+
  __pyx_v_ks = __pyx_v_kernel_size;
+
+
+32:     cdef float eps = epsilon
+
  __pyx_v_eps = __pyx_v_epsilon;
+
+
+33:     cdef float num = (ks*2+1)**2
+
  __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
+
+
 34: 
+
 35:     # for all pixels do
+
+36:     for m in range(ks,M-ks):
+
  __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
   __pyx_t_8 = __pyx_t_7;
   for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 < __pyx_t_8; __pyx_t_9+=1) {
     __pyx_v_m = __pyx_t_9;
-
+37:         for n in range(ks,N-ks):
-
    __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
+
+
+37:         for n in range(ks,N-ks):
+
    __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
     __pyx_t_11 = __pyx_t_10;
     for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 < __pyx_t_11; __pyx_t_12+=1) {
       __pyx_v_n = __pyx_t_12;
-
 38: 
-
 39:             # calculate mean
-
+40:             mean = 0;
-
      __pyx_v_mean = 0.0;
-
+41:             for i in range(-ks,ks+1):
-
      __pyx_t_13 = (__pyx_v_ks + 1);
+
+
 38: 
+
 39:             # calculate mean
+
+40:             mean = 0;
+
      __pyx_v_mean = 0.0;
+
+
+41:             for i in range(-ks,ks+1):
+
      __pyx_t_13 = (__pyx_v_ks + 1);
       __pyx_t_14 = __pyx_t_13;
       for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
         __pyx_v_i = __pyx_t_15;
-
+42:                 for j in range(-ks,ks+1):
-
        __pyx_t_16 = (__pyx_v_ks + 1);
+
+
+42:                 for j in range(-ks,ks+1):
+
        __pyx_t_16 = (__pyx_v_ks + 1);
         __pyx_t_17 = __pyx_t_16;
         for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
           __pyx_v_j = __pyx_t_18;
-
+43:                     mean += img[m+i, n+j]
-
          __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
+
+
+43:                     mean += img[m+i, n+j]
+
          __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
           __pyx_t_20 = (__pyx_v_n + __pyx_v_j);
           __pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) ))));
         }
       }
-
+44:             mean = mean/num
-
      __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
-
 45: 
-
 46:             # calculate std dev
-
+47:             stddev = 0;
-
      __pyx_v_stddev = 0.0;
-
+48:             for i in range(-ks,ks+1):
-
      __pyx_t_13 = (__pyx_v_ks + 1);
+
+
+44:             mean = mean/num
+
      __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
+
+
 45: 
+
 46:             # calculate std dev
+
+47:             stddev = 0;
+
      __pyx_v_stddev = 0.0;
+
+
+48:             for i in range(-ks,ks+1):
+
      __pyx_t_13 = (__pyx_v_ks + 1);
       __pyx_t_14 = __pyx_t_13;
       for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
         __pyx_v_i = __pyx_t_15;
-
+49:                 for j in range(-ks,ks+1):
-
        __pyx_t_16 = (__pyx_v_ks + 1);
+
+
+49:                 for j in range(-ks,ks+1):
+
        __pyx_t_16 = (__pyx_v_ks + 1);
         __pyx_t_17 = __pyx_t_16;
         for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
           __pyx_v_j = __pyx_t_18;
-
+50:                     stddev = stddev + (img[m+i, n+j]-mean)*(img[m+i, n+j]-mean)
-
          __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
+
+
+50:                     stddev = stddev + (img[m+i, n+j]-mean)*(img[m+i, n+j]-mean)
+
          __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
           __pyx_t_22 = (__pyx_v_n + __pyx_v_j);
           __pyx_t_23 = (__pyx_v_m + __pyx_v_i);
           __pyx_t_24 = (__pyx_v_n + __pyx_v_j);
           __pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean)));
         }
       }
-
+51:             stddev = sqrt(stddev/num)
-
      __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
-
 52: 
-
 53:             # compute normalized image (add epsilon) and std dev image
-
+54:             img_lcn_view[m, n] = (img[m, n]-mean)/(stddev+eps)
-
      __pyx_t_25 = __pyx_v_m;
+
+
+51:             stddev = sqrt(stddev/num)
+
      __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
+
+
 52: 
+
 53:             # compute normalized image (add epsilon) and std dev image
+
+54:             img_lcn_view[m, n] = (img[m, n]-mean)/(stddev+eps)
+
      __pyx_t_25 = __pyx_v_m;
       __pyx_t_26 = __pyx_v_n;
       __pyx_t_27 = __pyx_v_m;
       __pyx_t_28 = __pyx_v_n;
       *((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps));
-
+55:             img_std_view[m, n] = stddev
-
      __pyx_t_29 = __pyx_v_m;
+
+
+55:             img_std_view[m, n] = stddev
+
      __pyx_t_29 = __pyx_v_m;
       __pyx_t_30 = __pyx_v_n;
       *((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev;
     }
   }
-
 56: 
-
 57:     # return both
-
+58:     return img_lcn, img_std
-
  __Pyx_XDECREF(__pyx_r);
+
+
 56: 
+
 57:     # return both
+
+58:     return img_lcn, img_std
+
  __Pyx_XDECREF(__pyx_r);
   __pyx_t_1 = PyTuple_New(2); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)
   __Pyx_GOTREF(__pyx_t_1);
   __Pyx_INCREF(__pyx_v_img_lcn);
@@ -717,4 +903,7 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
   __pyx_r = __pyx_t_1;
   __pyx_t_1 = 0;
   goto __pyx_L0;
-
+
+
+ + diff --git a/data/lcn/setup.py b/data/lcn/setup.py index e1d4d0e..5c404c4 100644 --- a/data/lcn/setup.py +++ b/data/lcn/setup.py @@ -2,5 +2,5 @@ from distutils.core import setup from Cython.Build import cythonize setup( - ext_modules = cythonize("lcn.pyx",annotate=True) + ext_modules=cythonize("lcn.pyx", annotate=True) ) diff --git a/data/lcn/test_lcn.py b/data/lcn/test_lcn.py index eff0289..b9985c6 100644 --- a/data/lcn/test_lcn.py +++ b/data/lcn/test_lcn.py @@ -5,43 +5,43 @@ from scipy import misc # load and convert to float img = misc.imread('img.png') -img = img.astype(np.float32)/255.0 +img = img.astype(np.float32) / 255.0 # normalize -img_lcn, img_std = lcn.normalize(img,5,0.05) +img_lcn, img_std = lcn.normalize(img, 5, 0.05) # normalize to reasonable range between 0 and 1 -#img_lcn = img_lcn/3.0 -#img_lcn = np.maximum(img_lcn,0.0) -#img_lcn = np.minimum(img_lcn,1.0) +# img_lcn = img_lcn/3.0 +# img_lcn = np.maximum(img_lcn,0.0) +# img_lcn = np.minimum(img_lcn,1.0) # save to file -#misc.imsave('lcn2.png',img_lcn) +# misc.imsave('lcn2.png',img_lcn) -print ("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \ -(img.shape[0], img.shape[1], img.dtype, img.min(), img.max())) -print ("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \ -(img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max())) +print("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \ + (img.shape[0], img.shape[1], img.dtype, img.min(), img.max())) +print("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \ + (img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max())) # plot original image plt.figure(1) img_plot = plt.imshow(img) img_plot.set_cmap('gray') -plt.clim(0, 1) # fix range +plt.clim(0, 1) # fix range plt.tight_layout() # plot normalized image plt.figure(2) img_lcn_plot = plt.imshow(img_lcn) img_lcn_plot.set_cmap('gray') -#plt.clim(0, 1) # fix range +# plt.clim(0, 1) # fix range plt.tight_layout() # plot stddev image plt.figure(3) img_std_plot = plt.imshow(img_std) img_std_plot.set_cmap('gray') -#plt.clim(0, 0.1) # fix range +# plt.clim(0, 0.1) # fix range plt.tight_layout() plt.show() diff --git a/hyperdepth/hyperparam_search.py b/hyperdepth/hyperparam_search.py index 54e9475..cbef34e 100644 --- a/hyperdepth/hyperparam_search.py +++ b/hyperdepth/hyperparam_search.py @@ -11,28 +11,27 @@ import dataset def get_data(n, row_from, row_to, train): - imsizes = [(256,384)] - focal_lengths = [160] - dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train) - ims = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.uint8) - disps = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.float32) - for idx in range(n): - print(f'load sample {idx} train={train}') - sample = dset[idx] - ims[idx] = (sample['im0'][0,row_from:row_to] * 255).astype(np.uint8) - disps[idx] = sample['disp0'][0,row_from:row_to] - return ims, disps - + imsizes = [(256, 384)] + focal_lengths = [160] + dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train) + ims = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.uint8) + disps = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.float32) + for idx in range(n): + print(f'load sample {idx} train={train}') + sample = dset[idx] + ims[idx] = (sample['im0'][0, row_from:row_to] * 255).astype(np.uint8) + disps[idx] = sample['disp0'][0, row_from:row_to] + return ims, disps params = hd.TrainParams( - n_trees=4, - max_tree_depth=, - n_test_split_functions=50, - n_test_thresholds=10, - n_test_samples=4096, - min_samples_to_split=16, - min_samples_for_leaf=8) + n_trees=4, + max_tree_depth=, + n_test_split_functions=50, + n_test_thresholds=10, + n_test_samples=4096, + min_samples_to_split=16, + min_samples_for_leaf=8) n_disp_bins = 20 depth_switch = 0 @@ -45,21 +44,23 @@ n_test_samples = 32 train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True) test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False) -for tree_depth in [8,10,12,14,16]: - depth_switch = tree_depth - 4 +for tree_depth in [8, 10, 12, 14, 16]: + depth_switch = tree_depth - 4 - prefix = f'td{tree_depth}_ds{depth_switch}' - prefix = Path(f'./forests/{prefix}/') - prefix.mkdir(parents=True, exist_ok=True) + prefix = f'td{tree_depth}_ds{depth_switch}' + prefix = Path(f'./forests/{prefix}/') + prefix.mkdir(parents=True, exist_ok=True) - hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr')) + hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, + forest_prefix=str(prefix / 'fr')) - es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr')) + es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, + forest_prefix=str(prefix / 'fr')) - np.save(str(prefix / 'ta.npy'), test_disps) - np.save(str(prefix / 'es.npy'), es) + np.save(str(prefix / 'ta.npy'), test_disps) + np.save(str(prefix / 'es.npy'), es) - # plt.figure(); - # plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4); - # plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4); - # plt.show() + # plt.figure(); + # plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4); + # plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4); + # plt.show() diff --git a/hyperdepth/setup.py b/hyperdepth/setup.py index 00933cc..c40e2bb 100644 --- a/hyperdepth/setup.py +++ b/hyperdepth/setup.py @@ -8,7 +8,6 @@ import os this_dir = os.path.dirname(__file__) - extra_compile_args = ['-O3', '-std=c++11'] extra_link_args = [] @@ -22,24 +21,20 @@ library_dirs = [] libraries = ['m'] setup( - name="hyperdepth", - cmdclass= {'build_ext': build_ext}, - ext_modules=[ - Extension('hyperdepth', - sources, - extra_objects=extra_objects, - language='c++', - library_dirs=library_dirs, - libraries=libraries, - include_dirs=[ - np.get_include(), - ], - extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args - ) - ] + name="hyperdepth", + cmdclass={'build_ext': build_ext}, + ext_modules=[ + Extension('hyperdepth', + sources, + extra_objects=extra_objects, + language='c++', + library_dirs=library_dirs, + libraries=libraries, + include_dirs=[ + np.get_include(), + ], + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args + ) + ] ) - - - - diff --git a/hyperdepth/vis_eval.py b/hyperdepth/vis_eval.py index d6d9a92..75cd09d 100644 --- a/hyperdepth/vis_eval.py +++ b/hyperdepth/vis_eval.py @@ -6,10 +6,13 @@ orig = cv2.imread('disp_orig.png', cv2.IMREAD_ANYDEPTH).astype(np.float32) ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32) es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32) - plt.figure() -plt.subplot(2,2,1); plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma') -plt.subplot(2,2,2); plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma') -plt.subplot(2,2,3); plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma') -plt.subplot(2,2,4); plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma') +plt.subplot(2, 2, 1); +plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma') +plt.subplot(2, 2, 2); +plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma') +plt.subplot(2, 2, 3); +plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma') +plt.subplot(2, 2, 4); +plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma') plt.show() diff --git a/model/exp_synph.py b/model/exp_synph.py index 2e714c7..1574ef6 100644 --- a/model/exp_synph.py +++ b/model/exp_synph.py @@ -12,226 +12,263 @@ import torchext from model import networks from data import dataset -class Worker(torchext.Worker): - def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs): - super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs) - - self.ms = args.ms - self.pattern_path = args.pattern_path - self.lcn_radius = args.lcn_radius - self.dp_weight = args.dp_weight - self.data_type = args.data_type - - self.imsizes = [(480,640)] - for iter in range(3): - self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2))) - - with open('config.json') as fp: - config = json.load(fp) - data_root = Path(config['DATA_ROOT']) - self.settings_path = data_root / self.data_type / 'settings.pkl' - sample_paths = sorted((data_root / self.data_type).glob('0*/')) - - self.train_paths = sample_paths[2**10:] - self.test_paths = sample_paths[:2**8] - - # supervise the edge encoder with only 2**8 samples - self.train_edge = len(self.train_paths) - 2**8 - - self.lcn_in = networks.LCN(self.lcn_radius, 0.05) - self.disparity_loss = networks.DisparityLoss() - self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device)) - - # evaluate in the region where opencv Block Matching has valid values - self.eval_mask = np.zeros(self.imsizes[0]) - self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1 - self.eval_mask = self.eval_mask.astype(np.bool) - self.eval_h = self.imsizes[0][0]-2*13 - self.eval_w = self.imsizes[0][1]-13-140 - - def get_train_set(self): - train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=1) - - return train_set - - def get_test_sets(self): - test_sets = torchext.TestSets() - test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1) - test_sets.append('simple', test_set, test_frequency=1) - - # initialize photometric loss modules according to image sizes - self.losses = [] - for imsize, pat in zip(test_set.imsizes, test_set.patterns): - pat = pat.mean(axis=2) - pat = torch.from_numpy(pat[None][None].astype(np.float32)) - pat = pat.to(self.train_device) - self.lcn_in = self.lcn_in.to(self.train_device) - pat,_ = self.lcn_in(pat) - pat = torch.cat([pat for idx in range(3)], dim=1) - self.losses.append( networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) ) - - return test_sets - - def copy_data(self, data, device, requires_grad, train): - self.lcn_in = self.lcn_in.to(device) - - self.data = {} - for key, val in data.items(): - grad = 'im' in key and requires_grad - self.data[key] = val.to(device).requires_grad_(requires_grad=grad) - - # apply lcn to IR input - # concatenate the normalized IR input and the original IR image - if 'im' in key and 'blend' not in key: - im = self.data[key] - im_lcn,im_std = self.lcn_in(im) - im_cat = torch.cat((im_lcn, im), dim=1) - key_std = key.replace('im','std') - self.data[key]=im_cat - self.data[key_std] = im_std.to(device).detach() - - def net_forward(self, net, train): - out = net(self.data['im0']) - return out - - def loss_forward(self, out, train): - out, edge = out - if not(isinstance(out, tuple) or isinstance(out, list)): - out = [out] - if not(isinstance(edge, tuple) or isinstance(edge, list)): - edge = [edge] - - vals = [] - - # apply photometric loss - for s,l,o in zip(itertools.count(), self.losses, out): - val, pattern_proj = l(o, self.data[f'im{s}'][:,0:1,...], self.data[f'std{s}']) - if s == 0: - self.pattern_proj = pattern_proj.detach() - vals.append(val) - - # apply disparity loss - # 1-edge as ground truth edge if inversed - edge0 = 1-torch.sigmoid(edge[0]) - val = self.disparity_loss(out[0], edge0) - if self.dp_weight>0: - vals.append(val * self.dp_weight) - - # apply edge loss on a subset of training samples - for s,e in zip(itertools.count(), edge): - # inversed ground truth edge where 0 means edge - grad = self.data[f'grad{s}']<0.2 - grad = grad.to(torch.float32) - ids = self.data['id'] - mask = ids>self.train_edge - if mask.sum()>0: - val = self.edge_loss(e[mask], grad[mask]) - else: - val = torch.zeros_like(vals[0]) - if s == 0: - self.edge = e.detach() - self.edge = torch.sigmoid(self.edge) - self.edge_gt = grad.detach() - vals.append(val) - - return vals - - def numpy_in_out(self, output): - output, edge = output - if not(isinstance(output, tuple) or isinstance(output, list)): - output = [output] - es = output[0].detach().to('cpu').numpy() - gt = self.data['disp0'].to('cpu').numpy().astype(np.float32) - im = self.data['im0'][:,0:1,...].detach().to('cpu').numpy() - - ma = gt>0 - return es, gt, im, ma - - def write_img(self, out_path, es, gt, im, ma): - logging.info(f'write img {out_path}') - u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0])) - - diff = np.abs(es - gt) - - vmin, vmax = np.nanmin(gt), np.nanmax(gt) - vmin = vmin - 0.2*(vmax-vmin) - vmax = vmax + 0.2*(vmax-vmin) - - pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0] - im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0] - pattern_diff = np.abs(im_orig - pattern_proj) - - - fig = plt.figure(figsize=(16,16)) - es_ = co.cmap.color_depth_map(es, scale=vmax) - gt_ = co.cmap.color_depth_map(gt, scale=vmax) - diff_ = co.cmap.color_error_image(diff, BGR=True) - - # plot disparities, ground truth disparity is shown only for reference - ax = plt.subplot(3,3,1); plt.imshow(es_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}') - ax = plt.subplot(3,3,2); plt.imshow(gt_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}') - ax = plt.subplot(3,3,3); plt.imshow(diff_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Err. {diff.mean():.5f}') - - # plot edges - edge = self.edge.to('cpu').numpy()[0,0] - edge_gt = self.edge_gt.to('cpu').numpy()[0,0] - edge_err = np.abs(edge - edge_gt) - ax = plt.subplot(3,3,4); plt.imshow(edge, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}') - ax = plt.subplot(3,3,5); plt.imshow(edge_gt, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}') - ax = plt.subplot(3,3,6); plt.imshow(edge_err, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Err. {edge_err.mean():.5f}') - - # plot normalized IR input and warped pattern - ax = plt.subplot(3,3,7); plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}') - ax = plt.subplot(3,3,8); plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}') - im_std = self.data['std0'].to('cpu').numpy()[0,0] - ax = plt.subplot(3,3,9); plt.imshow(im_std, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}') - - plt.tight_layout() - plt.savefig(str(out_path)) - plt.close(fig) - - - def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]): - if batch_idx % 512 == 0: - out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png' - es, gt, im, ma = self.numpy_in_out(output) - self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0]) - - - def callback_test_start(self, epoch, set_idx): - self.metric = co.metric.MultipleMetric( - co.metric.DistanceMetric(vec_length=1), - co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5]) - ) - - def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]): - es, gt, im, ma = self.numpy_in_out(output) - - if batch_idx % 8 == 0: - out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png' - self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0]) - - es, gt, im, ma = self.crop_output(es, gt, im, ma) - - es = es.reshape(-1,1) - gt = gt.reshape(-1,1) - ma = ma.ravel() - self.metric.add(es, gt, ma) - - def callback_test_stop(self, epoch, set_idx, loss): - logging.info(f'{self.metric}') - for k, v in self.metric.items(): - self.metric_add_test(epoch, set_idx, k, v) - - def crop_output(self, es, gt, im, ma): - bs = es.shape[0] - es = np.reshape(es[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) - gt = np.reshape(gt[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) - im = np.reshape(im[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) - ma = np.reshape(ma[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) - return es, gt, im, ma +class Worker(torchext.Worker): + def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs): + super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, + train_batch_size=train_batch_size, test_batch_size=test_batch_size, + save_frequency=save_frequency, **kwargs) + + self.ms = args.ms + self.pattern_path = args.pattern_path + self.lcn_radius = args.lcn_radius + self.dp_weight = args.dp_weight + self.data_type = args.data_type + + self.imsizes = [(488, 648)] + for iter in range(3): + self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2))) + + with open('config.json') as fp: + config = json.load(fp) + data_root = Path(config['DATA_ROOT']) + self.settings_path = data_root / self.data_type / 'settings.pkl' + sample_paths = sorted((data_root / self.data_type).glob('0*/')) + + self.train_paths = sample_paths[2 ** 10:] + self.test_paths = sample_paths[:2 ** 8] + + # supervise the edge encoder with only 2**8 samples + self.train_edge = len(self.train_paths) - 2 ** 8 + + self.lcn_in = networks.LCN(self.lcn_radius, 0.05) + self.disparity_loss = networks.DisparityLoss() + self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device)) + + # evaluate in the region where opencv Block Matching has valid values + self.eval_mask = np.zeros(self.imsizes[0]) + self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1 + self.eval_mask = self.eval_mask.astype(np.bool) + self.eval_h = self.imsizes[0][0] - 2 * 13 + self.eval_w = self.imsizes[0][1] - 13 - 140 + + def get_train_set(self): + train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, + track_length=1) + + return train_set + + def get_test_sets(self): + test_sets = torchext.TestSets() + test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, + track_length=1) + test_sets.append('simple', test_set, test_frequency=1) + + # initialize photometric loss modules according to image sizes + self.losses = [] + for imsize, pat in zip(test_set.imsizes, test_set.patterns): + pat = pat.mean(axis=2) + pat = torch.from_numpy(pat[None][None].astype(np.float32)) + pat = pat.to(self.train_device) + self.lcn_in = self.lcn_in.to(self.train_device) + pat, _ = self.lcn_in(pat) + pat = torch.cat([pat for idx in range(3)], dim=1) + self.losses.append(networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat)) + + return test_sets + + def copy_data(self, data, device, requires_grad, train): + self.lcn_in = self.lcn_in.to(device) + + self.data = {} + for key, val in data.items(): + grad = 'im' in key and requires_grad + self.data[key] = val.to(device).requires_grad_(requires_grad=grad) + + # apply lcn to IR input + # concatenate the normalized IR input and the original IR image + if 'im' in key and 'blend' not in key: + im = self.data[key] + im_lcn, im_std = self.lcn_in(im) + im_cat = torch.cat((im_lcn, im), dim=1) + key_std = key.replace('im', 'std') + self.data[key] = im_cat + self.data[key_std] = im_std.to(device).detach() + + def net_forward(self, net, train): + out = net(self.data['im0']) + return out + + def loss_forward(self, out, train): + out, edge = out + if not (isinstance(out, tuple) or isinstance(out, list)): + out = [out] + if not (isinstance(edge, tuple) or isinstance(edge, list)): + edge = [edge] + + vals = [] + + # apply photometric loss + for s, l, o in zip(itertools.count(), self.losses, out): + val, pattern_proj = l(o, self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}']) + if s == 0: + self.pattern_proj = pattern_proj.detach() + vals.append(val) + + # apply disparity loss + # 1-edge as ground truth edge if inversed + edge0 = 1 - torch.sigmoid(edge[0]) + val = self.disparity_loss(out[0], edge0) + if self.dp_weight > 0: + vals.append(val * self.dp_weight) + + # apply edge loss on a subset of training samples + for s, e in zip(itertools.count(), edge): + # inversed ground truth edge where 0 means edge + grad = self.data[f'grad{s}'] < 0.2 + grad = grad.to(torch.float32) + ids = self.data['id'] + mask = ids > self.train_edge + if mask.sum() > 0: + val = self.edge_loss(e[mask], grad[mask]) + else: + val = torch.zeros_like(vals[0]) + if s == 0: + self.edge = e.detach() + self.edge = torch.sigmoid(self.edge) + self.edge_gt = grad.detach() + vals.append(val) + + return vals + + def numpy_in_out(self, output): + output, edge = output + if not (isinstance(output, tuple) or isinstance(output, list)): + output = [output] + es = output[0].detach().to('cpu').numpy() + gt = self.data['disp0'].to('cpu').numpy().astype(np.float32) + im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy() + + ma = gt > 0 + return es, gt, im, ma + + def write_img(self, out_path, es, gt, im, ma): + logging.info(f'write img {out_path}') + u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0])) + + diff = np.abs(es - gt) + + vmin, vmax = np.nanmin(gt), np.nanmax(gt) + vmin = vmin - 0.2 * (vmax - vmin) + vmax = vmax + 0.2 * (vmax - vmin) + + pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0] + im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0] + pattern_diff = np.abs(im_orig - pattern_proj) + + fig = plt.figure(figsize=(16, 16)) + es_ = co.cmap.color_depth_map(es, scale=vmax) + gt_ = co.cmap.color_depth_map(gt, scale=vmax) + diff_ = co.cmap.color_error_image(diff, BGR=True) + + # plot disparities, ground truth disparity is shown only for reference + ax = plt.subplot(3, 3, 1) + plt.imshow(es_[..., [2, 1, 0]]) + plt.xticks([]) + plt.yticks([]) + ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}') + ax = plt.subplot(3, 3, 2) + plt.imshow(gt_[..., [2, 1, 0]]) + plt.xticks([]) + plt.yticks([]) + ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}') + ax = plt.subplot(3, 3, 3) + plt.imshow(diff_[..., [2, 1, 0]]) + plt.xticks([]) + plt.yticks([]) + ax.set_title(f'Disparity Err. {diff.mean():.5f}') + + # plot edges + edge = self.edge.to('cpu').numpy()[0, 0] + edge_gt = self.edge_gt.to('cpu').numpy()[0, 0] + edge_err = np.abs(edge - edge_gt) + ax = plt.subplot(3, 3, 4); + plt.imshow(edge, cmap='gray'); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}') + ax = plt.subplot(3, 3, 5); + plt.imshow(edge_gt, cmap='gray'); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}') + ax = plt.subplot(3, 3, 6); + plt.imshow(edge_err, cmap='gray'); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'Edge Err. {edge_err.mean():.5f}') + + # plot normalized IR input and warped pattern + ax = plt.subplot(3, 3, 7); + plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray'); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}') + ax = plt.subplot(3, 3, 8); + plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray'); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}') + im_std = self.data['std0'].to('cpu').numpy()[0, 0] + ax = plt.subplot(3, 3, 9); + plt.imshow(im_std, cmap='gray'); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}') + + plt.tight_layout() + plt.savefig(str(out_path)) + plt.close(fig) + + def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]): + if batch_idx % 512 == 0: + out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png' + es, gt, im, ma = self.numpy_in_out(output) + self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0]) + + def callback_test_start(self, epoch, set_idx): + self.metric = co.metric.MultipleMetric( + co.metric.DistanceMetric(vec_length=1), + co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5]) + ) + + def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]): + es, gt, im, ma = self.numpy_in_out(output) + + if batch_idx % 8 == 0: + out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png' + self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0]) + + es, gt, im, ma = self.crop_output(es, gt, im, ma) + + es = es.reshape(-1, 1) + gt = gt.reshape(-1, 1) + ma = ma.ravel() + self.metric.add(es, gt, ma) + + def callback_test_stop(self, epoch, set_idx, loss): + logging.info(f'{self.metric}') + for k, v in self.metric.items(): + self.metric_add_test(epoch, set_idx, k, v) + + def crop_output(self, es, gt, im, ma): + bs = es.shape[0] + es = np.reshape(es[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) + gt = np.reshape(gt[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) + im = np.reshape(im[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) + ma = np.reshape(ma[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) + return es, gt, im, ma if __name__ == '__main__': - pass + pass diff --git a/model/exp_synphge.py b/model/exp_synphge.py index 9794450..eec320a 100644 --- a/model/exp_synphge.py +++ b/model/exp_synphge.py @@ -12,287 +12,324 @@ import torchext from model import networks from data import dataset + class Worker(torchext.Worker): - def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs): - super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs) - - self.ms = args.ms - self.pattern_path = args.pattern_path - self.lcn_radius = args.lcn_radius - self.dp_weight = args.dp_weight - self.ge_weight = args.ge_weight - self.track_length = args.track_length - self.data_type = args.data_type - assert(self.track_length>1) - - self.imsizes = [(480,640)] - for iter in range(3): - self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2))) - - with open('config.json') as fp: - config = json.load(fp) - data_root = Path(config['DATA_ROOT']) - self.settings_path = data_root / self.data_type / 'settings.pkl' - sample_paths = sorted((data_root / self.data_type).glob('0*/')) - - self.train_paths = sample_paths[2**10:] - self.test_paths = sample_paths[:2**8] - - # supervise the edge encoder with only 2**8 samples - self.train_edge = len(self.train_paths) - 2**8 - - self.lcn_in = networks.LCN(self.lcn_radius, 0.05) - self.disparity_loss = networks.DisparityLoss() - self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device)) - - # evaluate in the region where opencv Block Matching has valid values - self.eval_mask = np.zeros(self.imsizes[0]) - self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1 - self.eval_mask = self.eval_mask.astype(np.bool) - self.eval_h = self.imsizes[0][0]-2*13 - self.eval_w = self.imsizes[0][1]-13-140 - - - def get_train_set(self): - train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length) - return train_set - - def get_test_sets(self): - test_sets = torchext.TestSets() - test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1) - test_sets.append('simple', test_set, test_frequency=1) - - self.ph_losses = [] - self.ge_losses = [] - self.d2ds = [] - - self.lcn_in = self.lcn_in.to('cuda') - for sidx in range(len(test_set.imsizes)): - imsize = test_set.imsizes[sidx] - pat = test_set.patterns[sidx] - pat = pat.mean(axis=2) - pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda') - pat,_ = self.lcn_in(pat) - pat = torch.cat([pat for idx in range(3)], dim=1) - ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) - - K = test_set.getK(sidx) - Ki = np.linalg.inv(K) - K = torch.from_numpy(K) - Ki = torch.from_numpy(Ki) - ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1) - - self.ph_losses.append( ph_loss ) - self.ge_losses.append( ge_loss ) - - d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline)) - self.d2ds.append( d2d ) - - return test_sets - - def copy_data(self, data, device, requires_grad, train): - self.data = {} - - self.lcn_in = self.lcn_in.to(device) - for key, val in data.items(): - # from - # batch_size x track_length x ... - # to - # track_length x batch_size x ... - if len(val.shape)>2: - if train: - val = val.transpose(0,1) + def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs): + super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, + train_batch_size=train_batch_size, test_batch_size=test_batch_size, + save_frequency=save_frequency, **kwargs) + + self.ms = args.ms + self.pattern_path = args.pattern_path + self.lcn_radius = args.lcn_radius + self.dp_weight = args.dp_weight + self.ge_weight = args.ge_weight + self.track_length = args.track_length + self.data_type = args.data_type + assert (self.track_length > 1) + + self.imsizes = [(480, 640)] + for iter in range(3): + self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2))) + + with open('config.json') as fp: + config = json.load(fp) + data_root = Path(config['DATA_ROOT']) + self.settings_path = data_root / self.data_type / 'settings.pkl' + sample_paths = sorted((data_root / self.data_type).glob('0*/')) + + self.train_paths = sample_paths[2 ** 10:] + self.test_paths = sample_paths[:2 ** 8] + + # supervise the edge encoder with only 2**8 samples + self.train_edge = len(self.train_paths) - 2 ** 8 + + self.lcn_in = networks.LCN(self.lcn_radius, 0.05) + self.disparity_loss = networks.DisparityLoss() + self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device)) + + # evaluate in the region where opencv Block Matching has valid values + self.eval_mask = np.zeros(self.imsizes[0]) + self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1 + self.eval_mask = self.eval_mask.astype(np.bool) + self.eval_h = self.imsizes[0][0] - 2 * 13 + self.eval_w = self.imsizes[0][1] - 13 - 140 + + def get_train_set(self): + train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, + track_length=self.track_length) + return train_set + + def get_test_sets(self): + test_sets = torchext.TestSets() + test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, + track_length=1) + test_sets.append('simple', test_set, test_frequency=1) + + self.ph_losses = [] + self.ge_losses = [] + self.d2ds = [] + + self.lcn_in = self.lcn_in.to('cuda') + for sidx in range(len(test_set.imsizes)): + imsize = test_set.imsizes[sidx] + pat = test_set.patterns[sidx] + pat = pat.mean(axis=2) + pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda') + pat, _ = self.lcn_in(pat) + pat = torch.cat([pat for idx in range(3)], dim=1) + ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat) + + K = test_set.getK(sidx) + Ki = np.linalg.inv(K) + K = torch.from_numpy(K) + Ki = torch.from_numpy(Ki) + ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1) + + self.ph_losses.append(ph_loss) + self.ge_losses.append(ge_loss) + + d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline)) + self.d2ds.append(d2d) + + return test_sets + + def copy_data(self, data, device, requires_grad, train): + self.data = {} + + self.lcn_in = self.lcn_in.to(device) + for key, val in data.items(): + # from + # batch_size x track_length x ... + # to + # track_length x batch_size x ... + if len(val.shape) > 2: + if train: + val = val.transpose(0, 1) + else: + val = val.unsqueeze(0) + grad = 'im' in key and requires_grad + self.data[key] = val.to(device).requires_grad_(requires_grad=grad) + if 'im' in key and 'blend' not in key: + im = self.data[key] + tl = im.shape[0] + bs = im.shape[1] + im_lcn, im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:])) + key_std = key.replace('im', 'std') + self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device) + im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2) + self.data[key] = im_cat + + def net_forward(self, net, train): + im0 = self.data['im0'] + tl = im0.shape[0] + bs = im0.shape[1] + im0 = im0.view(-1, *im0.shape[2:]) + out, edge = net(im0) + if not (isinstance(out, tuple) or isinstance(out, list)): + out = out.view(tl, bs, *out.shape[1:]) + edge = edge.view(tl, bs, *out.shape[1:]) else: - val = val.unsqueeze(0) - grad = 'im' in key and requires_grad - self.data[key] = val.to(device).requires_grad_(requires_grad=grad) - if 'im' in key and 'blend' not in key: - im = self.data[key] - tl = im.shape[0] - bs = im.shape[1] - im_lcn,im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:])) - key_std = key.replace('im','std') - self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device) - im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2) - self.data[key] = im_cat - - def net_forward(self, net, train): - im0 = self.data['im0'] - tl = im0.shape[0] - bs = im0.shape[1] - im0 = im0.view(-1, *im0.shape[2:]) - out, edge = net(im0) - if not(isinstance(out, tuple) or isinstance(out, list)): - out = out.view(tl, bs, *out.shape[1:]) - edge = edge.view(tl, bs, *out.shape[1:]) - else: - out = [o.view(tl, bs, *o.shape[1:]) for o in out] - edge = [e.view(tl, bs, *e.shape[1:]) for e in edge] - return out, edge - - def loss_forward(self, out, train): - out, edge = out - if not(isinstance(out, tuple) or isinstance(out, list)): - out = [out] - vals = [] - diffs = [] - - # apply photometric loss - for s,l,o in zip(itertools.count(), self.ph_losses, out): - im = self.data[f'im{s}'] - im = im.view(-1, *im.shape[2:]) - o = o.view(-1, *o.shape[2:]) - std = self.data[f'std{s}'] - std = std.view(-1, *std.shape[2:]) - val, pattern_proj = l(o, im[:,0:1,...], std) - vals.append(val) - if s == 0: - self.pattern_proj = pattern_proj.detach() - - # apply disparity loss - # 1-edge as ground truth edge if inversed - edge0 = 1-torch.sigmoid(edge[0]) - edge0 = edge0.view(-1, *edge0.shape[2:]) - out0 = out[0].view(-1, *out[0].shape[2:]) - val = self.disparity_loss(out0, edge0) - if self.dp_weight>0: - vals.append(val * self.dp_weight) - - # apply edge loss on a subset of training samples - for s,e in zip(itertools.count(), edge): - # inversed ground truth edge where 0 means edge - grad = self.data[f'grad{s}']<0.2 - grad = grad.to(torch.float32) - ids = self.data['id'] - mask = ids>self.train_edge - if mask.sum()>0: - e = e[:,mask,:] - grad = grad[:,mask,:] - e = e.view(-1, *e.shape[2:]) - grad = grad.view(-1, *grad.shape[2:]) - val = self.edge_loss(e, grad) - else: - val = torch.zeros_like(vals[0]) - vals.append(val) - - if train is False: - return vals - - # apply geometric loss - R = self.data['R'] - t = self.data['t'] - ge_num = self.track_length * (self.track_length-1) / 2 - for sidx in range(len(out)): - d2d = self.d2ds[sidx] - depth = d2d(out[sidx]) - ge_loss = self.ge_losses[sidx] - imsize = self.imsizes[sidx] - for tidx0 in range(depth.shape[0]): - for tidx1 in range(tidx0+1, depth.shape[0]): - depth0 = depth[tidx0] - R0 = R[tidx0] - t0 = t[tidx0] - depth1 = depth[tidx1] - R1 = R[tidx1] - t1 = t[tidx1] - - val = ge_loss(depth0, depth1, R0, t0, R1, t1) - vals.append(val * self.ge_weight / ge_num) - - return vals - - def numpy_in_out(self, output): - output, edge = output - if not(isinstance(output, tuple) or isinstance(output, list)): - output = [output] - es = output[0].detach().to('cpu').numpy() - gt = self.data['disp0'].to('cpu').numpy().astype(np.float32) - im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy() - ma = gt>0 - return es, gt, im, ma - - def write_img(self, out_path, es, gt, im, ma): - logging.info(f'write img {out_path}') - u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0])) - - diff = np.abs(es - gt) - - vmin, vmax = np.nanmin(gt), np.nanmax(gt) - vmin = vmin - 0.2*(vmax-vmin) - vmax = vmax + 0.2*(vmax-vmin) - - pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0] - im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0,0] - pattern_diff = np.abs(im_orig - pattern_proj) - - fig = plt.figure(figsize=(16,16)) - es0 = co.cmap.color_depth_map(es[0], scale=vmax) - gt0 = co.cmap.color_depth_map(gt[0], scale=vmax) - diff0 = co.cmap.color_error_image(diff[0], BGR=True) - - # plot disparities, ground truth disparity is shown only for reference - ax = plt.subplot(3,3,1); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}') - ax = plt.subplot(3,3,2); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}') - ax = plt.subplot(3,3,3); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}') - - # plot disparities of the second frame in the track if exists - if es.shape[0]>=2: - es1 = co.cmap.color_depth_map(es[1], scale=vmax) - gt1 = co.cmap.color_depth_map(gt[1], scale=vmax) - diff1 = co.cmap.color_error_image(diff[1], BGR=True) - ax = plt.subplot(3,3,4); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}') - ax = plt.subplot(3,3,5); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}') - ax = plt.subplot(3,3,6); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}') - - # plot normalized IR inputs - ax = plt.subplot(3,3,7); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}') - if es.shape[0]>=2: - ax = plt.subplot(3,3,8); plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}') - - plt.tight_layout() - plt.savefig(str(out_path)) - plt.close(fig) - - def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): - if batch_idx % 512 == 0: - out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png' - es, gt, im, ma = self.numpy_in_out(output) - masks = [ m.detach().to('cpu').numpy() for m in masks ] - self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0]) - - def callback_test_start(self, epoch, set_idx): - self.metric = co.metric.MultipleMetric( - co.metric.DistanceMetric(vec_length=1), - co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5]) - ) - - def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks): - es, gt, im, ma = self.numpy_in_out(output) - - if batch_idx % 8 == 0: - out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png' - self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0]) - - es, gt, im, ma = self.crop_output(es, gt, im, ma) - - es = es.reshape(-1,1) - gt = gt.reshape(-1,1) - ma = ma.ravel() - self.metric.add(es, gt, ma) - - def callback_test_stop(self, epoch, set_idx, loss): - logging.info(f'{self.metric}') - for k, v in self.metric.items(): - self.metric_add_test(epoch, set_idx, k, v) - - def crop_output(self, es, gt, im, ma): - tl = es.shape[0] - bs = es.shape[1] - es = np.reshape(es[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w]) - gt = np.reshape(gt[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w]) - im = np.reshape(im[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w]) - ma = np.reshape(ma[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w]) - return es, gt, im, ma + out = [o.view(tl, bs, *o.shape[1:]) for o in out] + edge = [e.view(tl, bs, *e.shape[1:]) for e in edge] + return out, edge + + def loss_forward(self, out, train): + out, edge = out + if not (isinstance(out, tuple) or isinstance(out, list)): + out = [out] + vals = [] + diffs = [] + + # apply photometric loss + for s, l, o in zip(itertools.count(), self.ph_losses, out): + im = self.data[f'im{s}'] + im = im.view(-1, *im.shape[2:]) + o = o.view(-1, *o.shape[2:]) + std = self.data[f'std{s}'] + std = std.view(-1, *std.shape[2:]) + val, pattern_proj = l(o, im[:, 0:1, ...], std) + vals.append(val) + if s == 0: + self.pattern_proj = pattern_proj.detach() + + # apply disparity loss + # 1-edge as ground truth edge if inversed + edge0 = 1 - torch.sigmoid(edge[0]) + edge0 = edge0.view(-1, *edge0.shape[2:]) + out0 = out[0].view(-1, *out[0].shape[2:]) + val = self.disparity_loss(out0, edge0) + if self.dp_weight > 0: + vals.append(val * self.dp_weight) + + # apply edge loss on a subset of training samples + for s, e in zip(itertools.count(), edge): + # inversed ground truth edge where 0 means edge + grad = self.data[f'grad{s}'] < 0.2 + grad = grad.to(torch.float32) + ids = self.data['id'] + mask = ids > self.train_edge + if mask.sum() > 0: + e = e[:, mask, :] + grad = grad[:, mask, :] + e = e.view(-1, *e.shape[2:]) + grad = grad.view(-1, *grad.shape[2:]) + val = self.edge_loss(e, grad) + else: + val = torch.zeros_like(vals[0]) + vals.append(val) + + if train is False: + return vals + + # apply geometric loss + R = self.data['R'] + t = self.data['t'] + ge_num = self.track_length * (self.track_length - 1) / 2 + for sidx in range(len(out)): + d2d = self.d2ds[sidx] + depth = d2d(out[sidx]) + ge_loss = self.ge_losses[sidx] + imsize = self.imsizes[sidx] + for tidx0 in range(depth.shape[0]): + for tidx1 in range(tidx0 + 1, depth.shape[0]): + depth0 = depth[tidx0] + R0 = R[tidx0] + t0 = t[tidx0] + depth1 = depth[tidx1] + R1 = R[tidx1] + t1 = t[tidx1] + + val = ge_loss(depth0, depth1, R0, t0, R1, t1) + vals.append(val * self.ge_weight / ge_num) + + return vals + + def numpy_in_out(self, output): + output, edge = output + if not (isinstance(output, tuple) or isinstance(output, list)): + output = [output] + es = output[0].detach().to('cpu').numpy() + gt = self.data['disp0'].to('cpu').numpy().astype(np.float32) + im = self.data['im0'][:, :, 0:1, ...].detach().to('cpu').numpy() + ma = gt > 0 + return es, gt, im, ma + + def write_img(self, out_path, es, gt, im, ma): + logging.info(f'write img {out_path}') + u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0])) + + diff = np.abs(es - gt) + + vmin, vmax = np.nanmin(gt), np.nanmax(gt) + vmin = vmin - 0.2 * (vmax - vmin) + vmax = vmax + 0.2 * (vmax - vmin) + + pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0] + im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0, 0] + pattern_diff = np.abs(im_orig - pattern_proj) + + fig = plt.figure(figsize=(16, 16)) + es0 = co.cmap.color_depth_map(es[0], scale=vmax) + gt0 = co.cmap.color_depth_map(gt[0], scale=vmax) + diff0 = co.cmap.color_error_image(diff[0], BGR=True) + + # plot disparities, ground truth disparity is shown only for reference + ax = plt.subplot(3, 3, 1); + plt.imshow(es0[..., [2, 1, 0]]); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}') + ax = plt.subplot(3, 3, 2); + plt.imshow(gt0[..., [2, 1, 0]]); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}') + ax = plt.subplot(3, 3, 3); + plt.imshow(diff0[..., [2, 1, 0]]); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}') + + # plot disparities of the second frame in the track if exists + if es.shape[0] >= 2: + es1 = co.cmap.color_depth_map(es[1], scale=vmax) + gt1 = co.cmap.color_depth_map(gt[1], scale=vmax) + diff1 = co.cmap.color_error_image(diff[1], BGR=True) + ax = plt.subplot(3, 3, 4); + plt.imshow(es1[..., [2, 1, 0]]); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}') + ax = plt.subplot(3, 3, 5); + plt.imshow(gt1[..., [2, 1, 0]]); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}') + ax = plt.subplot(3, 3, 6); + plt.imshow(diff1[..., [2, 1, 0]]); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}') + + # plot normalized IR inputs + ax = plt.subplot(3, 3, 7); + plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}') + if es.shape[0] >= 2: + ax = plt.subplot(3, 3, 8); + plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray'); + plt.xticks([]); + plt.yticks([]); + ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}') + + plt.tight_layout() + plt.savefig(str(out_path)) + plt.close(fig) + + def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): + if batch_idx % 512 == 0: + out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png' + es, gt, im, ma = self.numpy_in_out(output) + masks = [m.detach().to('cpu').numpy() for m in masks] + self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0]) + + def callback_test_start(self, epoch, set_idx): + self.metric = co.metric.MultipleMetric( + co.metric.DistanceMetric(vec_length=1), + co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5]) + ) + + def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks): + es, gt, im, ma = self.numpy_in_out(output) + + if batch_idx % 8 == 0: + out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png' + self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0]) + + es, gt, im, ma = self.crop_output(es, gt, im, ma) + + es = es.reshape(-1, 1) + gt = gt.reshape(-1, 1) + ma = ma.ravel() + self.metric.add(es, gt, ma) + + def callback_test_stop(self, epoch, set_idx, loss): + logging.info(f'{self.metric}') + for k, v in self.metric.items(): + self.metric_add_test(epoch, set_idx, k, v) + + def crop_output(self, es, gt, im, ma): + tl = es.shape[0] + bs = es.shape[1] + es = np.reshape(es[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w]) + gt = np.reshape(gt[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w]) + im = np.reshape(im[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w]) + ma = np.reshape(ma[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w]) + return es, gt, im, ma + if __name__ == '__main__': - pass + pass diff --git a/model/networks.py b/model/networks.py index 4781706..cd21bae 100644 --- a/model/networks.py +++ b/model/networks.py @@ -8,559 +8,572 @@ import co class TimedModule(torch.nn.Module): - def __init__(self, mod_name): - super().__init__() - self.mod_name = mod_name + def __init__(self, mod_name): + super().__init__() + self.mod_name = mod_name - def tforward(self, *args, **kwargs): - raise Exception('not implemented') + def tforward(self, *args, **kwargs): + raise Exception('not implemented') - def forward(self, *args, **kwargs): - torch.cuda.synchronize() - with co.gtimer.Ctx(self.mod_name): - x = self.tforward(*args, **kwargs) - torch.cuda.synchronize() - return x + def forward(self, *args, **kwargs): + torch.cuda.synchronize() + with co.gtimer.Ctx(self.mod_name): + x = self.tforward(*args, **kwargs) + torch.cuda.synchronize() + return x class PosOutput(TimedModule): - def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0): - super().__init__(mod_name='PosOutput') - self.im_width = im_width - self.im_width = im_width - - if type == 'pos': - self.layer = torch.nn.Sequential( - torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1), - SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset) - ) - elif type == 'pos_row': - self.layer = torch.nn.Sequential( - MultiLinear(im_height, channels_in, 1), - SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset) - ) + def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0): + super().__init__(mod_name='PosOutput') + self.im_width = im_width + self.im_width = im_width + + if type == 'pos': + self.layer = torch.nn.Sequential( + torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1), + SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset) + ) + elif type == 'pos_row': + self.layer = torch.nn.Sequential( + MultiLinear(im_height, channels_in, 1), + SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset) + ) + + self.u_pos = None + + def tforward(self, x): + if self.u_pos is None: + self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1, 1, 1, -1) + self.u_pos = self.u_pos.to(x.device) + pos = self.layer(x) + disp = self.u_pos - pos + return disp - self.u_pos = None - def tforward(self, x): - if self.u_pos is None: - self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1,1,1,-1) - self.u_pos = self.u_pos.to(x.device) - pos = self.layer(x) - disp = self.u_pos - pos - return disp +class OutputLayerFactory(object): + ''' + Define type of output + type options: + linear: apply only conv channel, used for the edge decoder + disp: estimate the disparity + disp_row: independently estimate the disparity per row + pos: estimate the absolute location + pos_row: independently estimate the absolute location per row + ''' + def __init__(self, type='disp', params={}): + self.type = type + self.params = params -class OutputLayerFactory(object): - ''' - Define type of output - type options: - linear: apply only conv channel, used for the edge decoder - disp: estimate the disparity - disp_row: independently estimate the disparity per row - pos: estimate the absolute location - pos_row: independently estimate the absolute location per row - ''' - def __init__(self, type='disp', params={}): - self.type = type - self.params = params - - def __call__(self, channels_in, imsize): - - if self.type == 'linear': - return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1) - - elif self.type == 'disp': - return torch.nn.Sequential( - torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1), - SigmoidAffine(**self.params) - ) + def __call__(self, channels_in, imsize): - elif self.type == 'disp_row': - return torch.nn.Sequential( - MultiLinear(imsize[0], channels_in, 1), - SigmoidAffine(**self.params) - ) + if self.type == 'linear': + return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1) - elif self.type == 'pos' or self.type == 'pos_row': - return PosOutput(channels_in, **self.params) + elif self.type == 'disp': + return torch.nn.Sequential( + torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1), + SigmoidAffine(**self.params) + ) - else: - raise Exception('unknown output layer type') + elif self.type == 'disp_row': + return torch.nn.Sequential( + MultiLinear(imsize[0], channels_in, 1), + SigmoidAffine(**self.params) + ) + + elif self.type == 'pos' or self.type == 'pos_row': + return PosOutput(channels_in, **self.params) + + else: + raise Exception('unknown output layer type') class SigmoidAffine(TimedModule): - def __init__(self, alpha=1, beta=0, gamma=1, offset=0): - super().__init__(mod_name='SigmoidAffine') - self.alpha = alpha - self.beta = beta - self.gamma = gamma - self.offset = offset + def __init__(self, alpha=1, beta=0, gamma=1, offset=0): + super().__init__(mod_name='SigmoidAffine') + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.offset = offset - def tforward(self, x): - return torch.sigmoid(x/self.gamma - self.offset) * self.alpha + self.beta + def tforward(self, x): + return torch.sigmoid(x / self.gamma - self.offset) * self.alpha + self.beta class MultiLinear(TimedModule): - def __init__(self, n, channels_in, channels_out): - super().__init__(mod_name='MultiLinear') - self.channels_out = channels_out - self.mods = torch.nn.ModuleList() - for idx in range(n): - self.mods.append(torch.nn.Linear(channels_in, channels_out)) + def __init__(self, n, channels_in, channels_out): + super().__init__(mod_name='MultiLinear') + self.channels_out = channels_out + self.mods = torch.nn.ModuleList() + for idx in range(n): + self.mods.append(torch.nn.Linear(channels_in, channels_out)) + + def tforward(self, x): + x = x.permute(2, 0, 3, 1) # BxCxHxW => HxBxWxC + y = x.new_empty(*x.shape[:-1], self.channels_out) + for hidx in range(x.shape[0]): + y[hidx] = self.mods[hidx](x[hidx]) + y = y.permute(1, 3, 0, 2) # HxBxWxC => BxCxHxW + return y + - def tforward(self, x): - x = x.permute(2,0,3,1) # BxCxHxW => HxBxWxC - y = x.new_empty(*x.shape[:-1], self.channels_out) - for hidx in range(x.shape[0]): - y[hidx] = self.mods[hidx](x[hidx]) - y = y.permute(1,3,0,2) # HxBxWxC => BxCxHxW - return y +class DispNetS(TimedModule): + ''' + Disparity Decoder based on DispNetS + ''' + + def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, + channel_multiplier=1): + super(DispNetS, self).__init__(mod_name='DispNetS') + + self.output_ms = output_ms + self.coordconv = coordconv + + conv_planes = channel_multiplier * np.array([32, 64, 128, 256, 512, 512, 512]) + self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7) + self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5) + self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2]) + self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3]) + self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4]) + self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5]) + self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6]) + + upconv_planes = channel_multiplier * np.array([512, 512, 256, 128, 64, 32, 16]) + self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0]) + self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1]) + self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2]) + self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3]) + self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4]) + self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5]) + self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6]) + + self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0]) + self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1]) + self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2]) + self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3]) + self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4]) + self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5]) + self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6]) + + if isinstance(output_facs, list): + self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3]) + self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2]) + self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1]) + self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0]) + else: + self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3]) + self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2]) + self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1]) + self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0]) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d): + torch.nn.init.xavier_uniform_(m.weight, gain=0.1) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + + def downsample_conv(self, in_planes, out_planes, kernel_size=3): + if self.coordconv: + conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, + padding=(kernel_size - 1) // 2) + else: + conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, + padding=(kernel_size - 1) // 2) + return torch.nn.Sequential( + conv, + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size - 1) // 2), + torch.nn.ReLU(inplace=True) + ) + def conv(self, in_planes, out_planes): + return torch.nn.Sequential( + torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1), + torch.nn.ReLU(inplace=True) + ) + def upconv(self, in_planes, out_planes): + return torch.nn.Sequential( + torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1), + torch.nn.ReLU(inplace=True) + ) -class DispNetS(TimedModule): - ''' - Disparity Decoder based on DispNetS - ''' - def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, channel_multiplier=1): - super(DispNetS, self).__init__(mod_name='DispNetS') - - self.output_ms = output_ms - self.coordconv = coordconv - - conv_planes = channel_multiplier * np.array( [32, 64, 128, 256, 512, 512, 512] ) - self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7) - self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5) - self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2]) - self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3]) - self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4]) - self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5]) - self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6]) - - upconv_planes = channel_multiplier * np.array( [512, 512, 256, 128, 64, 32, 16] ) - self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0]) - self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1]) - self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2]) - self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3]) - self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4]) - self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5]) - self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6]) - - self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0]) - self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1]) - self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2]) - self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3]) - self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4]) - self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5]) - self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6]) - - if isinstance(output_facs, list): - self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3]) - self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2]) - self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1]) - self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0]) - else: - self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3]) - self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2]) - self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1]) - self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0]) - - - def init_weights(self): - for m in self.modules(): - if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d): - torch.nn.init.xavier_uniform_(m.weight, gain=0.1) - if m.bias is not None: - torch.nn.init.zeros_(m.bias) - - def downsample_conv(self, in_planes, out_planes, kernel_size=3): - if self.coordconv: - conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2) - else: - conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2) - return torch.nn.Sequential( - conv, - torch.nn.ReLU(inplace=True), - torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2), - torch.nn.ReLU(inplace=True) - ) - - def conv(self, in_planes, out_planes): - return torch.nn.Sequential( - torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1), - torch.nn.ReLU(inplace=True) - ) - - def upconv(self, in_planes, out_planes): - return torch.nn.Sequential( - torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1), - torch.nn.ReLU(inplace=True) - ) - - def crop_like(self, input, ref): - assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)) - return input[:, :, :ref.size(2), :ref.size(3)] - - def tforward(self, x): - out_conv1 = self.conv1(x) - out_conv2 = self.conv2(out_conv1) - out_conv3 = self.conv3(out_conv2) - out_conv4 = self.conv4(out_conv3) - out_conv5 = self.conv5(out_conv4) - out_conv6 = self.conv6(out_conv5) - out_conv7 = self.conv7(out_conv6) - - out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6) - concat7 = torch.cat((out_upconv7, out_conv6), 1) - out_iconv7 = self.iconv7(concat7) - - out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5) - concat6 = torch.cat((out_upconv6, out_conv5), 1) - out_iconv6 = self.iconv6(concat6) - - out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4) - concat5 = torch.cat((out_upconv5, out_conv4), 1) - out_iconv5 = self.iconv5(concat5) - - out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3) - concat4 = torch.cat((out_upconv4, out_conv3), 1) - out_iconv4 = self.iconv4(concat4) - disp4 = self.predict_disp4(out_iconv4) - - out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2) - disp4_up = self.crop_like(torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2) - concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1) - out_iconv3 = self.iconv3(concat3) - disp3 = self.predict_disp3(out_iconv3) - - out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1) - disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1) - concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) - out_iconv2 = self.iconv2(concat2) - disp2 = self.predict_disp2(out_iconv2) - - out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x) - disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x) - concat1 = torch.cat((out_upconv1, disp2_up), 1) - out_iconv1 = self.iconv1(concat1) - disp1 = self.predict_disp1(out_iconv1) - - if self.output_ms: - return disp1, disp2, disp3, disp4 - else: - return disp1 + def crop_like(self, input, ref): + assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)) + return input[:, :, :ref.size(2), :ref.size(3)] + + def tforward(self, x): + out_conv1 = self.conv1(x) + out_conv2 = self.conv2(out_conv1) + out_conv3 = self.conv3(out_conv2) + out_conv4 = self.conv4(out_conv3) + out_conv5 = self.conv5(out_conv4) + out_conv6 = self.conv6(out_conv5) + out_conv7 = self.conv7(out_conv6) + + out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6) + concat7 = torch.cat((out_upconv7, out_conv6), 1) + out_iconv7 = self.iconv7(concat7) + + out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5) + concat6 = torch.cat((out_upconv6, out_conv5), 1) + out_iconv6 = self.iconv6(concat6) + + out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4) + concat5 = torch.cat((out_upconv5, out_conv4), 1) + out_iconv5 = self.iconv5(concat5) + + out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3) + concat4 = torch.cat((out_upconv4, out_conv3), 1) + out_iconv4 = self.iconv4(concat4) + disp4 = self.predict_disp4(out_iconv4) + + out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2) + disp4_up = self.crop_like( + torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2) + concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1) + out_iconv3 = self.iconv3(concat3) + disp3 = self.predict_disp3(out_iconv3) + + out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1) + disp3_up = self.crop_like( + torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1) + concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) + out_iconv2 = self.iconv2(concat2) + disp2 = self.predict_disp2(out_iconv2) + + out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x) + disp2_up = self.crop_like( + torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x) + concat1 = torch.cat((out_upconv1, disp2_up), 1) + out_iconv1 = self.iconv1(concat1) + disp1 = self.predict_disp1(out_iconv1) + + if self.output_ms: + return disp1, disp2, disp3, disp4 + else: + return disp1 class DispNetShallow(DispNetS): - ''' - Edge Decoder based on DispNetS with fewer layers - ''' - def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False): - super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init) - self.mod_name = 'DispNetShallow' - conv_planes = [32, 64, 128, 256, 512, 512, 512] - upconv_planes = [512, 512, 256, 128, 64, 32, 16] - self.iconv3 = self.conv(upconv_planes[4] + conv_planes[1], upconv_planes[4]) - - def tforward(self, x): - out_conv1 = self.conv1(x) - out_conv2 = self.conv2(out_conv1) - out_conv3 = self.conv3(out_conv2) - - out_upconv3 = self.crop_like(self.upconv3(out_conv3), out_conv2) - concat3 = torch.cat((out_upconv3, out_conv2), 1) - out_iconv3 = self.iconv3(concat3) - disp3 = self.predict_disp3(out_iconv3) - - out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1) - disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1) - concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) - out_iconv2 = self.iconv2(concat2) - disp2 = self.predict_disp2(out_iconv2) - - out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x) - disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x) - concat1 = torch.cat((out_upconv1, disp2_up), 1) - out_iconv1 = self.iconv1(concat1) - disp1 = self.predict_disp1(out_iconv1) - - if self.output_ms: - return disp1, disp2, disp3 - else: - return disp1 + ''' + Edge Decoder based on DispNetS with fewer layers + ''' + + def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False): + super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init) + self.mod_name = 'DispNetShallow' + conv_planes = [32, 64, 128, 256, 512, 512, 512] + upconv_planes = [512, 512, 256, 128, 64, 32, 16] + self.iconv3 = self.conv(upconv_planes[4] + conv_planes[1], upconv_planes[4]) + + def tforward(self, x): + out_conv1 = self.conv1(x) + out_conv2 = self.conv2(out_conv1) + out_conv3 = self.conv3(out_conv2) + + out_upconv3 = self.crop_like(self.upconv3(out_conv3), out_conv2) + concat3 = torch.cat((out_upconv3, out_conv2), 1) + out_iconv3 = self.iconv3(concat3) + disp3 = self.predict_disp3(out_iconv3) + + out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1) + disp3_up = self.crop_like( + torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1) + concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) + out_iconv2 = self.iconv2(concat2) + disp2 = self.predict_disp2(out_iconv2) + + out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x) + disp2_up = self.crop_like( + torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x) + concat1 = torch.cat((out_upconv1, disp2_up), 1) + out_iconv1 = self.iconv1(concat1) + disp1 = self.predict_disp1(out_iconv1) + + if self.output_ms: + return disp1, disp2, disp3 + else: + return disp1 class DispEdgeDecoders(TimedModule): - ''' - Disparity Decoder and Edge Decoder - ''' - def __init__(self, *args, max_disp=128, **kwargs): - super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders') + ''' + Disparity Decoder and Edge Decoder + ''' - output_facs = [OutputLayerFactory( type='disp', params={ 'alpha': max_disp/(2**s), 'beta': 0, 'gamma': 1, 'offset': 3}) for s in range(4)] - self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs) + def __init__(self, *args, max_disp=128, **kwargs): + super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders') - output_facs = [OutputLayerFactory( type='linear' ) for s in range(4)] - self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs) + output_facs = [ + OutputLayerFactory(type='disp', params={'alpha': max_disp / (2 ** s), 'beta': 0, 'gamma': 1, 'offset': 3}) + for s in range(4)] + self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs) - def tforward(self, x): - disp = self.disp_decoder(x) - edge = self.edge_decoder(x) - return disp, edge + output_facs = [OutputLayerFactory(type='linear') for s in range(4)] + self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs) + + def tforward(self, x): + disp = self.disp_decoder(x) + edge = self.edge_decoder(x) + return disp, edge class DispToDepth(TimedModule): - def __init__(self, focal_length, baseline): - super().__init__(mod_name='DispToDepth') - self.baseline_focal_length = baseline * focal_length + def __init__(self, focal_length, baseline): + super().__init__(mod_name='DispToDepth') + self.baseline_focal_length = baseline * focal_length - def tforward(self, disp): - disp = torch.nn.functional.relu(disp) + 1e-12 - depth = self.baseline_focal_length / disp - return depth + def tforward(self, disp): + disp = torch.nn.functional.relu(disp) + 1e-12 + depth = self.baseline_focal_length / disp + return depth class PosToDepth(DispToDepth): - def __init__(self, focal_length, baseline, im_height, im_width): - super().__init__(focal_length, baseline) - self.mod_name = 'PosToDepth' - - self.im_height = im_height - self.im_width = im_width - self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1,1,1,-1) + def __init__(self, focal_length, baseline, im_height, im_width): + super().__init__(focal_length, baseline) + self.mod_name = 'PosToDepth' - def tforward(self, pos): - self.u_pos = self.u_pos.to(pos.device) - disp = self.u_pos - pos - return super().forward(disp) + self.im_height = im_height + self.im_width = im_width + self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1, 1, 1, -1) + def tforward(self, pos): + self.u_pos = self.u_pos.to(pos.device) + disp = self.u_pos - pos + return super().forward(disp) class RectifiedPatternSimilarityLoss(TimedModule): - ''' - Photometric Loss - ''' - def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5): - super().__init__(mod_name='RectifiedPatternSimilarityLoss') - self.im_height = im_height - self.im_width = im_width - self.pattern = pattern.mean(dim=1, keepdim=True).contiguous() - - u, v = np.meshgrid(range(im_width), range(im_height)) - uv0 = np.stack((u,v), axis=2).reshape(-1,1) - uv0 = uv0.astype(np.float32).reshape(1,-1,2) - self.uv0 = torch.from_numpy(uv0) - - self.loss_type = loss_type - self.loss_eps = loss_eps - - def tforward(self, disp0, im, std=None): - self.pattern = self.pattern.to(disp0.device) - self.uv0 = self.uv0.to(disp0.device) - - uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:]) - uv1 = torch.empty_like(uv0) - uv1[...,0] = uv0[...,0] - disp0.contiguous().view(disp0.shape[0],-1) - uv1[...,1] = uv0[...,1] - - uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5) - uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5) - uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone() - pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:]) - pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border') - mask = torch.ones_like(im) - if std is not None: - mask = mask*std - - diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps) - val = (mask*diff).sum() / mask.sum() - return val, pattern_proj + ''' + Photometric Loss + ''' + + def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5): + super().__init__(mod_name='RectifiedPatternSimilarityLoss') + self.im_height = im_height + self.im_width = im_width + self.pattern = pattern.mean(dim=1, keepdim=True).contiguous() + + u, v = np.meshgrid(range(im_width), range(im_height)) + uv0 = np.stack((u, v), axis=2).reshape(-1, 1) + uv0 = uv0.astype(np.float32).reshape(1, -1, 2) + self.uv0 = torch.from_numpy(uv0) + + self.loss_type = loss_type + self.loss_eps = loss_eps + + def tforward(self, disp0, im, std=None): + self.pattern = self.pattern.to(disp0.device) + self.uv0 = self.uv0.to(disp0.device) + + uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:]) + uv1 = torch.empty_like(uv0) + uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1) + uv1[..., 1] = uv0[..., 1] + + uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5) + uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5) + uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone() + pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:]) + pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border') + mask = torch.ones_like(im) + if std is not None: + mask = mask * std + + diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps) + val = (mask * diff).sum() / mask.sum() + return val, pattern_proj -class DisparityLoss(TimedModule): - ''' - Disparity Loss - ''' - def __init__(self): - super().__init__(mod_name='DisparityLoss') - self.sobel = SobelFilter(norm=False) - - #if not edge_gt: - self.b0=0.0503428816795 - self.b1=1.07274045944 - #else: - # self.b0=0.0587115108967 - # self.b1=1.51931190491 - - def tforward(self, disp, edge=None): - self.sobel=self.sobel.to(disp.device) - - if edge is not None: - grad = self.sobel(disp) - grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8) - pdf = (1-edge)/self.b0 * torch.exp(-torch.abs(grad)/self.b0) + \ - edge/self.b1 * torch.exp(-torch.abs(grad)/self.b1) - val = torch.mean(-torch.log(pdf.clamp(min=1e-4))) - else: - # on qifeng's data we don't have ambient info - # therefore we supress edge everywhere - grad = self.sobel(disp) - grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8) - grad= torch.clamp(grad, 0, 1.0) - val = torch.mean(grad) - - return val +class DisparityLoss(TimedModule): + ''' + Disparity Loss + ''' + + def __init__(self): + super().__init__(mod_name='DisparityLoss') + self.sobel = SobelFilter(norm=False) + + # if not edge_gt: + self.b0 = 0.0503428816795 + self.b1 = 1.07274045944 + # else: + # self.b0=0.0587115108967 + # self.b1=1.51931190491 + + def tforward(self, disp, edge=None): + self.sobel = self.sobel.to(disp.device) + + if edge is not None: + grad = self.sobel(disp) + grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8) + pdf = (1 - edge) / self.b0 * torch.exp(-torch.abs(grad) / self.b0) + \ + edge / self.b1 * torch.exp(-torch.abs(grad) / self.b1) + val = torch.mean(-torch.log(pdf.clamp(min=1e-4))) + else: + # on qifeng's data we don't have ambient info + # therefore we supress edge everywhere + grad = self.sobel(disp) + grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8) + grad = torch.clamp(grad, 0, 1.0) + val = torch.mean(grad) + + return val class ProjectionBaseLoss(TimedModule): - ''' - Base module of the Geometric Loss - ''' - def __init__(self, K, Ki, im_height, im_width): - super().__init__(mod_name='ProjectionBaseLoss') + ''' + Base module of the Geometric Loss + ''' - self.K = K.view(-1,3,3) + def __init__(self, K, Ki, im_height, im_width): + super().__init__(mod_name='ProjectionBaseLoss') - self.im_height = im_height - self.im_width = im_width + self.K = K.view(-1, 3, 3) - u, v = np.meshgrid(range(im_width), range(im_height)) - uv = np.stack((u,v,np.ones_like(u)), axis=2).reshape(-1,3) + self.im_height = im_height + self.im_width = im_width - ray = uv @ Ki.numpy().T + u, v = np.meshgrid(range(im_width), range(im_height)) + uv = np.stack((u, v, np.ones_like(u)), axis=2).reshape(-1, 3) - ray = ray.reshape(1,-1,3).astype(np.float32) - self.ray = torch.from_numpy(ray) + ray = uv @ Ki.numpy().T - def transform(self, xyz, R=None, t=None): - if t is not None: - bs = xyz.shape[0] - xyz = xyz - t.reshape(bs,1,3) - if R is not None: - xyz = torch.bmm(xyz, R) - return xyz + ray = ray.reshape(1, -1, 3).astype(np.float32) + self.ray = torch.from_numpy(ray) - def unproject(self, depth, R=None, t=None): - self.ray = self.ray.to(depth.device) - bs = depth.shape[0] + def transform(self, xyz, R=None, t=None): + if t is not None: + bs = xyz.shape[0] + xyz = xyz - t.reshape(bs, 1, 3) + if R is not None: + xyz = torch.bmm(xyz, R) + return xyz - xyz = depth.reshape(bs,-1,1) * self.ray - xyz = self.transform(xyz, R, t) - return xyz + def unproject(self, depth, R=None, t=None): + self.ray = self.ray.to(depth.device) + bs = depth.shape[0] - def project(self, xyz, R, t): - self.K = self.K.to(xyz.device) - bs = xyz.shape[0] + xyz = depth.reshape(bs, -1, 1) * self.ray + xyz = self.transform(xyz, R, t) + return xyz - xyz = torch.bmm(xyz, R.transpose(1,2)) - xyz = xyz + t.reshape(bs,1,3) + def project(self, xyz, R, t): + self.K = self.K.to(xyz.device) + bs = xyz.shape[0] - Kt = self.K.transpose(1,2).expand(bs,-1,-1) - uv = torch.bmm(xyz, Kt) + xyz = torch.bmm(xyz, R.transpose(1, 2)) + xyz = xyz + t.reshape(bs, 1, 3) - d = uv[:,:,2:3] + Kt = self.K.transpose(1, 2).expand(bs, -1, -1) + uv = torch.bmm(xyz, Kt) - # avoid division by zero - uv = uv[:,:,:2] / (torch.nn.functional.relu(d) + 1e-12) - return uv, d + d = uv[:, :, 2:3] + # avoid division by zero + uv = uv[:, :, :2] / (torch.nn.functional.relu(d) + 1e-12) + return uv, d - def tforward(self, depth0, R0, t0, R1, t1): - xyz = self.unproject(depth0, R0, t0) - return self.project(xyz, R1, t1) + def tforward(self, depth0, R0, t0, R1, t1): + xyz = self.unproject(depth0, R0, t0) + return self.project(xyz, R1, t1) class ProjectionDepthSimilarityLoss(ProjectionBaseLoss): - ''' - Geometric Loss - ''' - def __init__(self, *args, clamp=-1): - super().__init__(*args) - self.mod_name = 'ProjectionDepthSimilarityLoss' - self.clamp = clamp + ''' + Geometric Loss + ''' - def fwd(self, depth0, depth1, R0, t0, R1, t1): - uv1, d1 = super().tforward(depth0, R0, t0, R1, t1) + def __init__(self, *args, clamp=-1): + super().__init__(*args) + self.mod_name = 'ProjectionDepthSimilarityLoss' + self.clamp = clamp - uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5) - uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5) - uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone() + def fwd(self, depth0, depth1, R0, t0, R1, t1): + uv1, d1 = super().tforward(depth0, R0, t0, R1, t1) - depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border') + uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5) + uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5) + uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone() - diff = torch.abs(d1.view(-1) - depth10.view(-1)) + depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border') - if self.clamp > 0: - diff = torch.clamp(diff, 0, self.clamp) + diff = torch.abs(d1.view(-1) - depth10.view(-1)) - # return diff without clamping for debugging - return diff.mean() + if self.clamp > 0: + diff = torch.clamp(diff, 0, self.clamp) - def tforward(self, depth0, depth1, R0, t0, R1, t1): - l0 = self.fwd(depth0, depth1, R0, t0, R1, t1) - l1 = self.fwd(depth1, depth0, R1, t1, R0, t0) - return l0+l1 + # return diff without clamping for debugging + return diff.mean() + def tforward(self, depth0, depth1, R0, t0, R1, t1): + l0 = self.fwd(depth0, depth1, R0, t0, R1, t1) + l1 = self.fwd(depth1, depth0, R1, t1, R0, t0) + return l0 + l1 class LCN(TimedModule): - ''' - Local Contract Normalization - ''' - def __init__(self, radius, epsilon): - super().__init__(mod_name='LCN') - self.box_conv = torch.nn.Sequential( - torch.nn.ReflectionPad2d(radius), - torch.nn.Conv2d(1, 1, kernel_size=2*radius+1, bias=False) - ) - self.box_conv[1].weight.requires_grad=False - self.box_conv[1].weight.fill_(1.) - - self.epsilon = epsilon - self.radius = radius + ''' + Local Contract Normalization + ''' + + def __init__(self, radius, epsilon): + super().__init__(mod_name='LCN') + self.box_conv = torch.nn.Sequential( + torch.nn.ReflectionPad2d(radius), + torch.nn.Conv2d(1, 1, kernel_size=2 * radius + 1, bias=False) + ) + self.box_conv[1].weight.requires_grad = False + self.box_conv[1].weight.fill_(1.) - def tforward(self, data): - boxs = self.box_conv(data) + self.epsilon = epsilon + self.radius = radius - avgs = boxs / (2*self.radius+1)**2 - boxs_n2 = boxs**2 - boxs_2n = self.box_conv(data**2) + def tforward(self, data): + boxs = self.box_conv(data) - stds = torch.sqrt(boxs_2n / (2*self.radius+1)**2 - avgs**2 + 1e-6) - stds = stds + self.epsilon + avgs = boxs / (2 * self.radius + 1) ** 2 + boxs_n2 = boxs ** 2 + boxs_2n = self.box_conv(data ** 2) - return (data - avgs) / stds, stds + stds = torch.sqrt(boxs_2n / (2 * self.radius + 1) ** 2 - avgs ** 2 + 1e-6) + stds = stds + self.epsilon + return (data - avgs) / stds, stds class SobelFilter(TimedModule): - ''' - Sobel Filter - ''' - def __init__(self, norm=False): - super(SobelFilter, self).__init__(mod_name='SobelFilter') - kx = np.array([[-5, -4, 0, 4, 5], - [-8, -10, 0, 10, 8], - [-10, -20, 0, 20, 10], - [-8, -10, 0, 10, 8], - [-5, -4, 0, 4, 5]])/240.0 - ky = kx.copy().transpose(1,0) - - self.conv_x=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False) - self.conv_x.weight=torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0)) - - self.conv_y=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False) - self.conv_y.weight=torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0)) - - self.norm=norm - - def tforward(self,x): - x = F.pad(x, (2,2,2,2), "replicate") - gx = self.conv_x(x) - gy = self.conv_y(x) - if self.norm: - return torch.sqrt(gx**2 + gy**2 + 1e-8) - else: - return torch.cat((gx, gy), dim=1) - + ''' + Sobel Filter + ''' + + def __init__(self, norm=False): + super(SobelFilter, self).__init__(mod_name='SobelFilter') + kx = np.array([[-5, -4, 0, 4, 5], + [-8, -10, 0, 10, 8], + [-10, -20, 0, 20, 10], + [-8, -10, 0, 10, 8], + [-5, -4, 0, 4, 5]]) / 240.0 + ky = kx.copy().transpose(1, 0) + + self.conv_x = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False) + self.conv_x.weight = torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0)) + + self.conv_y = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False) + self.conv_y.weight = torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0)) + + self.norm = norm + + def tforward(self, x): + x = F.pad(x, (2, 2, 2, 2), "replicate") + gx = self.conv_x(x) + gy = self.conv_y(x) + if self.norm: + return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8) + else: + return torch.cat((gx, gy), dim=1) diff --git a/readme.md b/readme.md index f3ec6af..50f5e5a 100644 --- a/readme.md +++ b/readme.md @@ -6,7 +6,9 @@ This repository contains the code for the paper **[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)**
-[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/), [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/), and [Andreas Geiger](http://www.cvlibs.net/) +[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/) +, [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/), +and [Andreas Geiger](http://www.cvlibs.net/)
[CVPR 2019](http://cvpr2019.thecvf.com/) @@ -24,40 +26,45 @@ If you find this code useful for your research, please cite } ``` - ## Dependencies The network training/evaluation code is based on `Pytorch`. + ``` PyTorch>=1.1 Cuda>=10.0 ``` + Updated on 07.06.2021: The code is now compatible with the latest Pytorch version (1.8). The other python packages can be installed with `anaconda`: + ``` conda install --file requirements.txt ``` ### Structured Light Renderer -To train and evaluate our method in a controlled setting, we implemented an structured light renderer. -It can be used to render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable projector location. -To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`. -Afterwards, the renderer can be build by running `make` within the `renderer` directory. + +To train and evaluate our method in a controlled setting, we implemented an structured light renderer. It can be used to +render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable +projector location. To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`. Afterwards, +the renderer can be build by running `make` within the `renderer` directory. ### PyTorch Extensions -The network training/evaluation code is based on `PyTorch`. -We implemented some custom layers that need to be built in the `torchext` directory. -Simply change into this directory and run + +The network training/evaluation code is based on `PyTorch`. We implemented some custom layers that need to be built in +the `torchext` directory. Simply change into this directory and run ``` python setup.py build_ext --inplace ``` ### Baseline HyperDepth -As baseline we partially re-implemented the random forest based method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf). -The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`. -To build it change into the directory and run + +As baseline we partially re-implemented the random forest based +method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf) +. The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython` +. To build it change into the directory and run ``` python setup.py build_ext --inplace @@ -65,42 +72,59 @@ python setup.py build_ext --inplace ## Running - ### Creating Synthetic Data -To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by running + +To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and +correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by +running + ``` ./create_syn_data.sh ``` -If you are only interested in evaluating our pre-trained model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a validation set that contains a small amount of images. + +If you are only interested in evaluating our pre-trained +model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a +validation set that contains a small amount of images. ### Training Network -As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train the network on synthetic data for the first stage run +As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train +the network on synthetic data for the first stage run + ``` python train_val.py ``` -After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by running +After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by +running + ``` python train_val.py --loss phge ``` - ### Evaluating Network -To evaluate a specific checkpoint, e.g. the 50th epoch, one can run + +To evaluate a specific checkpoint, e.g. the 50th epoch, one can run + ``` python train_val.py --cmd retest --epoch 50 ``` ### Evaluating a Pre-trained Model -We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running: + +We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and +changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running: + ``` mkdir -p output mkdir -p output/exp_syn wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params python train_val.py --cmd retest --epoch 99 ``` -You can also download our validation set from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip). -## Acknowledgement +You can also download our validation set +from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip). + +## Acknowledgement + This work was supported by the Intel Network on Intelligent Systems. diff --git a/renderer/setup.py b/renderer/setup.py index 393f160..a696d39 100644 --- a/renderer/setup.py +++ b/renderer/setup.py @@ -10,7 +10,7 @@ import json this_dir = os.path.dirname(__file__) with open('../config.json') as fp: - config = json.load(fp) + config = json.load(fp) extra_compile_args = ['-O3', '-std=c++11'] @@ -20,7 +20,7 @@ cuda_lib = 'cudart' sources = ['cyrender.pyx'] extra_objects = [ - os.path.join(this_dir, 'render/render_cpu.cpp.o'), + os.path.join(this_dir, 'render/render_cpu.cpp.o'), ] library_dirs = [] libraries = ['m'] @@ -30,20 +30,20 @@ library_dirs.append(cuda_lib_dir) libraries.append(cuda_lib) setup( - name="cyrender", - cmdclass= {'build_ext': build_ext}, - ext_modules=[ - Extension('cyrender', - sources, - extra_objects=extra_objects, - language='c++', - library_dirs=library_dirs, - libraries=libraries, - include_dirs=[ - np.get_include(), - ], - extra_compile_args=extra_compile_args, - # extra_link_args=extra_link_args - ) - ] + name="cyrender", + cmdclass={'build_ext': build_ext}, + ext_modules=[ + Extension('cyrender', + sources, + extra_objects=extra_objects, + language='c++', + library_dirs=library_dirs, + libraries=libraries, + include_dirs=[ + np.get_include(), + ], + extra_compile_args=extra_compile_args, + # extra_link_args=extra_link_args + ) + ] ) diff --git a/torchext/dataset.py b/torchext/dataset.py index 2c606df..b4c2e1a 100644 --- a/torchext/dataset.py +++ b/torchext/dataset.py @@ -2,65 +2,65 @@ import torch import torch.utils.data import numpy as np + class TestSet(object): - def __init__(self, name, dset, test_frequency=1): - self.name = name - self.dset = dset - self.test_frequency = test_frequency + def __init__(self, name, dset, test_frequency=1): + self.name = name + self.dset = dset + self.test_frequency = test_frequency -class TestSets(list): - def append(self, name, dset, test_frequency=1): - super().append(TestSet(name, dset, test_frequency)) +class TestSets(list): + def append(self, name, dset, test_frequency=1): + super().append(TestSet(name, dset, test_frequency)) class MultiDataset(torch.utils.data.Dataset): - def __init__(self, *datasets): - self.current_epoch = 0 - - self.datasets = [] - self.cum_n_samples = [0] + def __init__(self, *datasets): + self.current_epoch = 0 - for dataset in datasets: - self.append(dataset) + self.datasets = [] + self.cum_n_samples = [0] - def append(self, dataset): - self.datasets.append(dataset) - self.__update_cum_n_samples(dataset) + for dataset in datasets: + self.append(dataset) - def __update_cum_n_samples(self, dataset): - n_samples = self.cum_n_samples[-1] + len(dataset) - self.cum_n_samples.append(n_samples) + def append(self, dataset): + self.datasets.append(dataset) + self.__update_cum_n_samples(dataset) - def dataset_updated(self): - self.cum_n_samples = [0] - for dset in self.datasets: - self.__update_cum_n_samples(dset) + def __update_cum_n_samples(self, dataset): + n_samples = self.cum_n_samples[-1] + len(dataset) + self.cum_n_samples.append(n_samples) - def __len__(self): - return self.cum_n_samples[-1] + def dataset_updated(self): + self.cum_n_samples = [0] + for dset in self.datasets: + self.__update_cum_n_samples(dset) - def __getitem__(self, idx): - didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1 - sidx = idx - self.cum_n_samples[didx] - return self.datasets[didx][sidx] + def __len__(self): + return self.cum_n_samples[-1] + def __getitem__(self, idx): + didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1 + sidx = idx - self.cum_n_samples[didx] + return self.datasets[didx][sidx] class BaseDataset(torch.utils.data.Dataset): - def __init__(self, train=True, fix_seed_per_epoch=False): - self.current_epoch = 0 - self.train = train - self.fix_seed_per_epoch = fix_seed_per_epoch - - def get_rng(self, idx): - rng = np.random.RandomState() - if self.train: - if self.fix_seed_per_epoch: - seed = 1 * len(self) + idx - else: - seed = (self.current_epoch + 1) * len(self) + idx - rng.seed(seed) - else: - rng.seed(idx) - return rng + def __init__(self, train=True, fix_seed_per_epoch=False): + self.current_epoch = 0 + self.train = train + self.fix_seed_per_epoch = fix_seed_per_epoch + + def get_rng(self, idx): + rng = np.random.RandomState() + if self.train: + if self.fix_seed_per_epoch: + seed = 1 * len(self) + idx + else: + seed = (self.current_epoch + 1) * len(self) + idx + rng.seed(seed) + else: + rng.seed(idx) + return rng diff --git a/torchext/functions.py b/torchext/functions.py index 73dc7ae..2885ae9 100644 --- a/torchext/functions.py +++ b/torchext/functions.py @@ -2,146 +2,151 @@ import torch from . import ext_cpu from . import ext_cuda + class NNFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, in0, in1): - args = (in0, in1) - if in0.is_cuda: - out = ext_cuda.nn_cuda(*args) - else: - out = ext_cpu.nn_cpu(*args) - return out + @staticmethod + def forward(ctx, in0, in1): + args = (in0, in1) + if in0.is_cuda: + out = ext_cuda.nn_cuda(*args) + else: + out = ext_cpu.nn_cpu(*args) + return out + + @staticmethod + def backward(ctx, grad_out): + return None, None - @staticmethod - def backward(ctx, grad_out): - return None, None def nn(in0, in1): - return NNFunction.apply(in0, in1) + return NNFunction.apply(in0, in1) class CrossCheckFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, in0, in1): - args = (in0, in1) - if in0.is_cuda: - out = ext_cuda.crosscheck_cuda(*args) - else: - out = ext_cpu.crosscheck_cpu(*args) - return out + @staticmethod + def forward(ctx, in0, in1): + args = (in0, in1) + if in0.is_cuda: + out = ext_cuda.crosscheck_cuda(*args) + else: + out = ext_cpu.crosscheck_cpu(*args) + return out + + @staticmethod + def backward(ctx, grad_out): + return None, None - @staticmethod - def backward(ctx, grad_out): - return None, None def crosscheck(in0, in1): - return CrossCheckFunction.apply(in0, in1) + return CrossCheckFunction.apply(in0, in1) + class ProjNNFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, xyz0, xyz1, K, patch_size): - args = (xyz0, xyz1, K, patch_size) - if xyz0.is_cuda: - out = ext_cuda.proj_nn_cuda(*args) - else: - out = ext_cpu.proj_nn_cpu(*args) - return out + @staticmethod + def forward(ctx, xyz0, xyz1, K, patch_size): + args = (xyz0, xyz1, K, patch_size) + if xyz0.is_cuda: + out = ext_cuda.proj_nn_cuda(*args) + else: + out = ext_cpu.proj_nn_cpu(*args) + return out - @staticmethod - def backward(ctx, grad_out): - return None, None, None, None + @staticmethod + def backward(ctx, grad_out): + return None, None, None, None -def proj_nn(xyz0, xyz1, K, patch_size): - return ProjNNFunction.apply(xyz0, xyz1, K, patch_size) +def proj_nn(xyz0, xyz1, K, patch_size): + return ProjNNFunction.apply(xyz0, xyz1, K, patch_size) class XCorrVolFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, in0, in1, n_disps, block_size): - args = (in0, in1, n_disps, block_size) - if in0.is_cuda: - out = ext_cuda.xcorrvol_cuda(*args) - else: - out = ext_cpu.xcorrvol_cpu(*args) - return out + @staticmethod + def forward(ctx, in0, in1, n_disps, block_size): + args = (in0, in1, n_disps, block_size) + if in0.is_cuda: + out = ext_cuda.xcorrvol_cuda(*args) + else: + out = ext_cpu.xcorrvol_cpu(*args) + return out + + @staticmethod + def backward(ctx, grad_out): + return None, None, None, None - @staticmethod - def backward(ctx, grad_out): - return None, None, None, None def xcorrvol(in0, in1, n_disps, block_size): - return XCorrVolFunction.apply(in0, in1, n_disps, block_size) + return XCorrVolFunction.apply(in0, in1, n_disps, block_size) +class PhotometricLossFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, es, ta, block_size, type, eps): + args = (es, ta, block_size, type, eps) + ctx.save_for_backward(es, ta) + ctx.block_size = block_size + ctx.type = type + ctx.eps = eps + if es.is_cuda: + out = ext_cuda.photometric_loss_forward(*args) + else: + out = ext_cpu.photometric_loss_forward(*args) + return out + + @staticmethod + def backward(ctx, grad_out): + es, ta = ctx.saved_tensors + block_size = ctx.block_size + type = ctx.type + eps = ctx.eps + args = (es, ta, grad_out.contiguous(), block_size, type, eps) + if grad_out.is_cuda: + grad_es = ext_cuda.photometric_loss_backward(*args) + else: + grad_es = ext_cpu.photometric_loss_backward(*args) + return grad_es, None, None, None, None -class PhotometricLossFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, es, ta, block_size, type, eps): - args = (es, ta, block_size, type, eps) - ctx.save_for_backward(es, ta) - ctx.block_size = block_size - ctx.type = type - ctx.eps = eps - if es.is_cuda: - out = ext_cuda.photometric_loss_forward(*args) - else: - out = ext_cpu.photometric_loss_forward(*args) - return out - - @staticmethod - def backward(ctx, grad_out): - es, ta = ctx.saved_tensors - block_size = ctx.block_size - type = ctx.type - eps = ctx.eps - args = (es, ta, grad_out.contiguous(), block_size, type, eps) - if grad_out.is_cuda: - grad_es = ext_cuda.photometric_loss_backward(*args) +def photometric_loss(es, ta, block_size, type='mse', eps=0.1): + type = type.lower() + if type == 'mse': + type = 0 + elif type == 'sad': + type = 1 + elif type == 'census_mse': + type = 2 + elif type == 'census_sad': + type = 3 else: - grad_es = ext_cpu.photometric_loss_backward(*args) - return grad_es, None, None, None, None + raise Exception('invalid loss type') + return PhotometricLossFunction.apply(es, ta, block_size, type, eps) -def photometric_loss(es, ta, block_size, type='mse', eps=0.1): - type = type.lower() - if type == 'mse': - type = 0 - elif type == 'sad': - type = 1 - elif type == 'census_mse': - type = 2 - elif type == 'census_sad': - type = 3 - else: - raise Exception('invalid loss type') - return PhotometricLossFunction.apply(es, ta, block_size, type, eps) def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1): - type = type.lower() - p = block_size // 2 - es_pad = torch.nn.functional.pad(es, (p,p,p,p), mode='replicate') - ta_pad = torch.nn.functional.pad(ta, (p,p,p,p), mode='replicate') - es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size) - ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size) - es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3]) - ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3]) - if type == 'mse': - ref = (es_uf - ta_uf)**2 - elif type == 'sad': - ref = torch.abs(es_uf - ta_uf) - elif type == 'census_mse' or type == 'census_sad': - des = es_uf - es.unsqueeze(2) - dta = ta_uf - ta.unsqueeze(2) - h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps)) - h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps)) - diff = h_des - h_dta - if type == 'census_mse': - ref = diff * diff - elif type == 'census_sad': - ref = torch.abs(diff) - else: - raise Exception('invalid loss type') - ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3]) - ref = torch.sum(ref, dim=1, keepdim=True) / block_size**2 - return ref + type = type.lower() + p = block_size // 2 + es_pad = torch.nn.functional.pad(es, (p, p, p, p), mode='replicate') + ta_pad = torch.nn.functional.pad(ta, (p, p, p, p), mode='replicate') + es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size) + ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size) + es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3]) + ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3]) + if type == 'mse': + ref = (es_uf - ta_uf) ** 2 + elif type == 'sad': + ref = torch.abs(es_uf - ta_uf) + elif type == 'census_mse' or type == 'census_sad': + des = es_uf - es.unsqueeze(2) + dta = ta_uf - ta.unsqueeze(2) + h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps)) + h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps)) + diff = h_des - h_dta + if type == 'census_mse': + ref = diff * diff + elif type == 'census_sad': + ref = torch.abs(diff) + else: + raise Exception('invalid loss type') + ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3]) + ref = torch.sum(ref, dim=1, keepdim=True) / block_size ** 2 + return ref diff --git a/torchext/modules.py b/torchext/modules.py index c7c1bfd..d740fcc 100644 --- a/torchext/modules.py +++ b/torchext/modules.py @@ -4,24 +4,26 @@ import numpy as np from .functions import * + class CoordConv2d(torch.nn.Module): - def __init__(self, channels_in, channels_out, kernel_size, stride, padding): - super().__init__() + def __init__(self, channels_in, channels_out, kernel_size, stride, padding): + super().__init__() - self.conv = torch.nn.Conv2d(channels_in+2, channels_out, kernel_size=kernel_size, padding=padding, stride=stride) + self.conv = torch.nn.Conv2d(channels_in + 2, channels_out, kernel_size=kernel_size, padding=padding, + stride=stride) - self.uv = None + self.uv = None - def forward(self, x): - if self.uv is None: - height, width = x.shape[2], x.shape[3] - u, v = np.meshgrid(range(width), range(height)) - u = 2 * u / (width - 1) - 1 - v = 2 * v / (height - 1) - 1 - uv = np.stack((u, v)).reshape(1, 2, height, width) - self.uv = torch.from_numpy( uv.astype(np.float32) ) - self.uv = self.uv.to(x.device) - uv = self.uv.expand(x.shape[0], *self.uv.shape[1:]) - xuv = torch.cat((x, uv), dim=1) - y = self.conv(xuv) - return y + def forward(self, x): + if self.uv is None: + height, width = x.shape[2], x.shape[3] + u, v = np.meshgrid(range(width), range(height)) + u = 2 * u / (width - 1) - 1 + v = 2 * v / (height - 1) - 1 + uv = np.stack((u, v)).reshape(1, 2, height, width) + self.uv = torch.from_numpy(uv.astype(np.float32)) + self.uv = self.uv.to(x.device) + uv = self.uv.expand(x.shape[0], *self.uv.shape[1:]) + xuv = torch.cat((x, uv), dim=1) + y = self.conv(xuv) + return y diff --git a/torchext/setup.py b/torchext/setup.py index 4d89756..43d8bc9 100644 --- a/torchext/setup.py +++ b/torchext/setup.py @@ -11,11 +11,12 @@ nvcc_args = [ ] setup( - name='ext', - ext_modules=[ - CppExtension('ext_cpu', ['ext/ext_cpu.cpp']), - CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'], extra_compile_args={'cxx': [], 'nvcc': nvcc_args}), - ], - cmdclass={'build_ext': BuildExtension}, - include_dirs=include_dirs + name='ext', + ext_modules=[ + CppExtension('ext_cpu', ['ext/ext_cpu.cpp']), + CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'], + extra_compile_args={'cxx': [], 'nvcc': nvcc_args}), + ], + cmdclass={'build_ext': BuildExtension}, + include_dirs=include_dirs ) diff --git a/torchext/worker.py b/torchext/worker.py index a668f93..4617a0a 100644 --- a/torchext/worker.py +++ b/torchext/worker.py @@ -17,512 +17,516 @@ from collections import OrderedDict class StopWatch(object): - def __init__(self): - self.timings = OrderedDict() - self.starts = {} + def __init__(self): + self.timings = OrderedDict() + self.starts = {} - def start(self, name): - self.starts[name] = time.time() + def start(self, name): + self.starts[name] = time.time() - def stop(self, name): - if name not in self.timings: - self.timings[name] = [] - self.timings[name].append(time.time() - self.starts[name]) + def stop(self, name): + if name not in self.timings: + self.timings[name] = [] + self.timings[name].append(time.time() - self.starts[name]) - def get(self, name=None, reduce=np.sum): - if name is not None: - return reduce(self.timings[name]) - else: - ret = {} - for k in self.timings: - ret[k] = reduce(self.timings[k]) - return ret + def get(self, name=None, reduce=np.sum): + if name is not None: + return reduce(self.timings[name]) + else: + ret = {} + for k in self.timings: + ret[k] = reduce(self.timings[k]) + return ret + + def __repr__(self): + return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()]) - def __repr__(self): - return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) - def __str__(self): - return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) + def __str__(self): + return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()]) class ETA(object): - def __init__(self, length): - self.length = length - self.start_time = time.time() - self.current_idx = 0 - self.current_time = time.time() + def __init__(self, length): + self.length = length + self.start_time = time.time() + self.current_idx = 0 + self.current_time = time.time() - def update(self, idx): - self.current_idx = idx - self.current_time = time.time() + def update(self, idx): + self.current_idx = idx + self.current_time = time.time() - def get_elapsed_time(self): - return self.current_time - self.start_time + def get_elapsed_time(self): + return self.current_time - self.start_time - def get_item_time(self): - return self.get_elapsed_time() / (self.current_idx + 1) + def get_item_time(self): + return self.get_elapsed_time() / (self.current_idx + 1) - def get_remaining_time(self): - return self.get_item_time() * (self.length - self.current_idx + 1) + def get_remaining_time(self): + return self.get_item_time() * (self.length - self.current_idx + 1) - def format_time(self, seconds): - minutes, seconds = divmod(seconds, 60) - hours, minutes = divmod(minutes, 60) - hours = int(hours) - minutes = int(minutes) - return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}' + def format_time(self, seconds): + minutes, seconds = divmod(seconds, 60) + hours, minutes = divmod(minutes, 60) + hours = int(hours) + minutes = int(minutes) + return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}' - def get_elapsed_time_str(self): - return self.format_time(self.get_elapsed_time()) + def get_elapsed_time_str(self): + return self.format_time(self.get_elapsed_time()) + + def get_remaining_time_str(self): + return self.format_time(self.get_remaining_time()) - def get_remaining_time_str(self): - return self.format_time(self.get_remaining_time()) class Worker(object): - def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1): - self.out_root = Path(out_root) - self.experiment_name = experiment_name - self.epochs = epochs - self.seed = seed - self.train_batch_size = train_batch_size - self.test_batch_size = test_batch_size - self.num_workers = num_workers - self.save_frequency = save_frequency - self.train_device = train_device - self.test_device = test_device - self.max_train_iter = max_train_iter - - self.errs_list=[] - - self.setup_experiment() - - def setup_experiment(self): - self.exp_out_root = self.out_root / self.experiment_name - self.exp_out_root.mkdir(parents=True, exist_ok=True) - - if logging.root: del logging.root.handlers[:] - logging.basicConfig( - level=logging.INFO, - handlers=[ - logging.FileHandler( str(self.exp_out_root / 'train.log') ), - logging.StreamHandler() - ], - format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s' - ) - - logging.info('='*80) - logging.info(f'Start of experiment: {self.experiment_name}') - logging.info(socket.gethostname()) - self.log_datetime() - logging.info('='*80) - - self.metric_path = self.exp_out_root / 'metrics.json' - if self.metric_path.exists(): - with open(str(self.metric_path), 'r') as fp: - self.metric_data = json.load(fp) - else: - self.metric_data = {} - - self.init_seed() - - def metric_add_train(self, epoch, key, val): - epoch = str(epoch) - key = str(key) - if epoch not in self.metric_data: - self.metric_data[epoch] = {} - if 'train' not in self.metric_data[epoch]: - self.metric_data[epoch]['train'] = {} - self.metric_data[epoch]['train'][key] = val - - def metric_add_test(self, epoch, set_idx, key, val): - epoch = str(epoch) - set_idx = str(set_idx) - key = str(key) - if epoch not in self.metric_data: - self.metric_data[epoch] = {} - if 'test' not in self.metric_data[epoch]: - self.metric_data[epoch]['test'] = {} - if set_idx not in self.metric_data[epoch]['test']: - self.metric_data[epoch]['test'][set_idx] = {} - self.metric_data[epoch]['test'][set_idx][key] = val - - def metric_save(self): - with open(str(self.metric_path), 'w') as fp: - json.dump(self.metric_data, fp, indent=2) - - def init_seed(self, seed=None): - if seed is not None: - self.seed = seed - logging.info(f'Set seed to {self.seed}') - np.random.seed(self.seed) - random.seed(self.seed) - torch.manual_seed(self.seed) - torch.cuda.manual_seed(self.seed) - - def log_datetime(self): - logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')) - - def mem_report(self): - for obj in gc.get_objects(): - if torch.is_tensor(obj): - print(type(obj), obj.shape) - - def get_net_path(self, epoch, root=None): - if root is None: - root = self.exp_out_root - return root / f'net_{epoch:04d}.params' - - def get_do_parser_cmds(self): - return ['retrain', 'resume', 'retest', 'test_init'] - - def get_do_parser(self): - parser = argparse.ArgumentParser() - parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds()) - parser.add_argument('--epoch', type=int, default=-1) - return parser - - def do_cmd(self, args, net, optimizer, scheduler=None): - if args.cmd == 'retrain': - self.train(net, optimizer, resume=False, scheduler=scheduler) - elif args.cmd == 'resume': - self.train(net, optimizer, resume=True, scheduler=scheduler) - elif args.cmd == 'retest': - self.retest(net, epoch=args.epoch) - elif args.cmd == 'test_init': - test_sets = self.get_test_sets() - self.test(-1, net, test_sets) - else: - raise Exception('invalid cmd') - - def do(self, net, optimizer, load_net_optimizer=None, scheduler=None): - parser = self.get_do_parser() - args, _ = parser.parse_known_args() - - if load_net_optimizer is not None and args.cmd not in ['schedule']: - net, optimizer = load_net_optimizer() - - self.do_cmd(args, net, optimizer, scheduler=scheduler) - - def retest(self, net, epoch=-1): - if epoch < 0: - epochs = range(self.epochs) - else: - epochs = [epoch] - - test_sets = self.get_test_sets() - - for epoch in epochs: - net_path = self.get_net_path(epoch) - if net_path.exists(): - state_dict = torch.load(str(net_path)) - net.load_state_dict(state_dict) - self.test(epoch, net, test_sets) - - def format_err_str(self, errs, div=1): - err = sum(errs) - if len(errs) > 1: - err_str = f'{err/div:0.4f}=' + '+'.join([f'{e/div:0.4f}' for e in errs]) - else: - err_str = f'{err/div:0.4f}' - return err_str - - def write_err_img(self): - err_img_path = self.exp_out_root / 'errs.png' - fig = plt.figure(figsize=(16,16)) - lines=[] - for idx,errs in enumerate(self.errs_list): - line,=plt.plot(range(len(errs)), errs, label=f'error{idx}') - lines.append(line) - plt.tight_layout() - plt.legend(handles=lines) - plt.savefig(str(err_img_path)) - plt.close(fig) - - - def callback_train_new_epoch(self, epoch, net, optimizer): - pass - - def train(self, net, optimizer, resume=False, scheduler=None): - logging.info('='*80) - logging.info('Start training') - self.log_datetime() - logging.info('='*80) - - train_set = self.get_train_set() - test_sets = self.get_test_sets() - - net = net.to(self.train_device) - - epoch = 0 - min_err = {ts.name: 1e9 for ts in test_sets} - - state_path = self.exp_out_root / 'state.dict' - if resume and state_path.exists(): - logging.info('='*80) - logging.info(f'Loading state from {state_path}') - logging.info('='*80) - state = torch.load(str(state_path)) - epoch = state['epoch'] + 1 - if 'min_err' in state: - min_err = state['min_err'] - - curr_state = net.state_dict() - curr_state.update(state['state_dict']) - net.load_state_dict(curr_state) - - - try: - optimizer.load_state_dict(state['optimizer']) - except: - logging.info('Warning: cannot load optimizer from state_dict') - pass - if 'cpu_rng_state' in state: - torch.set_rng_state(state['cpu_rng_state']) - if 'gpu_rng_state' in state: - torch.cuda.set_rng_state(state['gpu_rng_state']) + def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, + num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1): + self.out_root = Path(out_root) + self.experiment_name = experiment_name + self.epochs = epochs + self.seed = seed + self.train_batch_size = train_batch_size + self.test_batch_size = test_batch_size + self.num_workers = num_workers + self.save_frequency = save_frequency + self.train_device = train_device + self.test_device = test_device + self.max_train_iter = max_train_iter + + self.errs_list = [] + + self.setup_experiment() + + def setup_experiment(self): + self.exp_out_root = self.out_root / self.experiment_name + self.exp_out_root.mkdir(parents=True, exist_ok=True) + + if logging.root: del logging.root.handlers[:] + logging.basicConfig( + level=logging.INFO, + handlers=[ + logging.FileHandler(str(self.exp_out_root / 'train.log')), + logging.StreamHandler() + ], + format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s' + ) + + logging.info('=' * 80) + logging.info(f'Start of experiment: {self.experiment_name}') + logging.info(socket.gethostname()) + self.log_datetime() + logging.info('=' * 80) + + self.metric_path = self.exp_out_root / 'metrics.json' + if self.metric_path.exists(): + with open(str(self.metric_path), 'r') as fp: + self.metric_data = json.load(fp) + else: + self.metric_data = {} + + self.init_seed() + + def metric_add_train(self, epoch, key, val): + epoch = str(epoch) + key = str(key) + if epoch not in self.metric_data: + self.metric_data[epoch] = {} + if 'train' not in self.metric_data[epoch]: + self.metric_data[epoch]['train'] = {} + self.metric_data[epoch]['train'][key] = val + + def metric_add_test(self, epoch, set_idx, key, val): + epoch = str(epoch) + set_idx = str(set_idx) + key = str(key) + if epoch not in self.metric_data: + self.metric_data[epoch] = {} + if 'test' not in self.metric_data[epoch]: + self.metric_data[epoch]['test'] = {} + if set_idx not in self.metric_data[epoch]['test']: + self.metric_data[epoch]['test'][set_idx] = {} + self.metric_data[epoch]['test'][set_idx][key] = val + + def metric_save(self): + with open(str(self.metric_path), 'w') as fp: + json.dump(self.metric_data, fp, indent=2) + + def init_seed(self, seed=None): + if seed is not None: + self.seed = seed + logging.info(f'Set seed to {self.seed}') + np.random.seed(self.seed) + random.seed(self.seed) + torch.manual_seed(self.seed) + torch.cuda.manual_seed(self.seed) + + def log_datetime(self): + logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')) + + def mem_report(self): + for obj in gc.get_objects(): + if torch.is_tensor(obj): + print(type(obj), obj.shape) + + def get_net_path(self, epoch, root=None): + if root is None: + root = self.exp_out_root + return root / f'net_{epoch:04d}.params' + + def get_do_parser_cmds(self): + return ['retrain', 'resume', 'retest', 'test_init'] + + def get_do_parser(self): + parser = argparse.ArgumentParser() + parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds()) + parser.add_argument('--epoch', type=int, default=-1) + return parser + + def do_cmd(self, args, net, optimizer, scheduler=None): + if args.cmd == 'retrain': + self.train(net, optimizer, resume=False, scheduler=scheduler) + elif args.cmd == 'resume': + self.train(net, optimizer, resume=True, scheduler=scheduler) + elif args.cmd == 'retest': + self.retest(net, epoch=args.epoch) + elif args.cmd == 'test_init': + test_sets = self.get_test_sets() + self.test(-1, net, test_sets) + else: + raise Exception('invalid cmd') + + def do(self, net, optimizer, load_net_optimizer=None, scheduler=None): + parser = self.get_do_parser() + args, _ = parser.parse_known_args() + + if load_net_optimizer is not None and args.cmd not in ['schedule']: + net, optimizer = load_net_optimizer() - for epoch in range(epoch, self.epochs): - self.callback_train_new_epoch(epoch, net, optimizer) + self.do_cmd(args, net, optimizer, scheduler=scheduler) + + def retest(self, net, epoch=-1): + if epoch < 0: + epochs = range(self.epochs) + else: + epochs = [epoch] + + test_sets = self.get_test_sets() + + for epoch in epochs: + net_path = self.get_net_path(epoch) + if net_path.exists(): + state_dict = torch.load(str(net_path)) + net.load_state_dict(state_dict) + self.test(epoch, net, test_sets) + + def format_err_str(self, errs, div=1): + err = sum(errs) + if len(errs) > 1: + err_str = f'{err / div:0.4f}=' + '+'.join([f'{e / div:0.4f}' for e in errs]) + else: + err_str = f'{err / div:0.4f}' + return err_str + + def write_err_img(self): + err_img_path = self.exp_out_root / 'errs.png' + fig = plt.figure(figsize=(16, 16)) + lines = [] + for idx, errs in enumerate(self.errs_list): + line, = plt.plot(range(len(errs)), errs, label=f'error{idx}') + lines.append(line) + plt.tight_layout() + plt.legend(handles=lines) + plt.savefig(str(err_img_path)) + plt.close(fig) + + def callback_train_new_epoch(self, epoch, net, optimizer): + pass - # train epoch - self.train_epoch(epoch, net, optimizer, train_set) + def train(self, net, optimizer, resume=False, scheduler=None): + logging.info('=' * 80) + logging.info('Start training') + self.log_datetime() + logging.info('=' * 80) - # test epoch - errs = self.test(epoch, net, test_sets) + train_set = self.get_train_set() + test_sets = self.get_test_sets() - if (epoch + 1) % self.save_frequency == 0: net = net.to(self.train_device) - # store state - state_dict = { - 'epoch': epoch, - 'min_err': min_err, - 'state_dict': net.state_dict(), - 'optimizer': optimizer.state_dict(), - 'cpu_rng_state': torch.get_rng_state(), - 'gpu_rng_state': torch.cuda.get_rng_state(), - } - logging.info(f'save state to {state_path}') + epoch = 0 + min_err = {ts.name: 1e9 for ts in test_sets} + state_path = self.exp_out_root / 'state.dict' - torch.save(state_dict, str(state_path)) - - for test_set_name in errs: - err = sum(errs[test_set_name]) - if err < min_err[test_set_name]: - min_err[test_set_name] = err - state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict' - logging.info(f'save state to {state_path}') - torch.save(state_dict, str(state_path)) - - # store network - net_path = self.get_net_path(epoch) - logging.info(f'save network to {net_path}') - torch.save(net.state_dict(), str(net_path)) - - if scheduler is not None: - scheduler.step() - - logging.info('='*80) - logging.info('Finished training') - self.log_datetime() - logging.info('='*80) - - def get_train_set(self): - # returns train_set - raise NotImplementedError() - - def get_test_sets(self): - # returns test_sets - raise NotImplementedError() - - def copy_data(self, data, device, requires_grad, train): - raise NotImplementedError() - - def net_forward(self, net, train): - raise NotImplementedError() - - def loss_forward(self, output, train): - raise NotImplementedError() - - def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): - # err = False - # for name, param in net.named_parameters(): - # if not torch.isfinite(param.grad).all(): - # print(name) - # err = True - # if err: - # import ipdb; ipdb.set_trace() - pass - - def callback_train_start(self, epoch): - pass - - def callback_train_stop(self, epoch, loss): - pass - - def train_epoch(self, epoch, net, optimizer, dset): - self.callback_train_start(epoch) - stopwatch = StopWatch() - - logging.info('='*80) - logging.info('Train epoch %d' % epoch) - - dset.current_epoch = epoch - train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False) - - net = net.to(self.train_device) - net.train() - - mean_loss = None - - n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader) - bar = ETA(length=n_batches) - - stopwatch.start('total') - stopwatch.start('data') - for batch_idx, data in enumerate(train_loader): - if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break - self.copy_data(data, device=self.train_device, requires_grad=True, train=True) - stopwatch.stop('data') - - optimizer.zero_grad() - - stopwatch.start('forward') - output = self.net_forward(net, train=True) - if 'cuda' in self.train_device: torch.cuda.synchronize() - stopwatch.stop('forward') - - stopwatch.start('loss') - errs = self.loss_forward(output, train=True) - if isinstance(errs, dict): - masks = errs['masks'] - errs = errs['errs'] - else: - masks = [] - if not isinstance(errs, list) and not isinstance(errs, tuple): - errs = [errs] - err = sum(errs) - if 'cuda' in self.train_device: torch.cuda.synchronize() - stopwatch.stop('loss') - - stopwatch.start('backward') - err.backward() - self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks) - if 'cuda' in self.train_device: torch.cuda.synchronize() - stopwatch.stop('backward') - - stopwatch.start('optimizer') - optimizer.step() - if 'cuda' in self.train_device: torch.cuda.synchronize() - stopwatch.stop('optimizer') - - bar.update(batch_idx) - if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0: - err_str = self.format_err_str(errs) - logging.info(f'train e{epoch}: {batch_idx+1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}') - #self.write_err_img() - - - if mean_loss is None: - mean_loss = [0 for e in errs] - for erridx, err in enumerate(errs): - mean_loss[erridx] += err.item() - - stopwatch.start('data') - stopwatch.stop('total') - logging.info('timings: %s' % stopwatch) - - mean_loss = [l / len(train_loader) for l in mean_loss] - self.callback_train_stop(epoch, mean_loss) - self.metric_add_train(epoch, 'loss', mean_loss) - - # save metrics - self.metric_save() - - err_str = self.format_err_str(mean_loss) - logging.info(f'avg train_loss={err_str}') - return mean_loss - - def callback_test_start(self, epoch, set_idx): - pass - - def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks): - pass - - def callback_test_stop(self, epoch, set_idx, loss): - pass - - def test(self, epoch, net, test_sets): - errs = {} - for test_set_idx, test_set in enumerate(test_sets): - if (epoch + 1) % test_set.test_frequency == 0: - logging.info('='*80) - logging.info(f'testing set {test_set.name}') - err = self.test_epoch(epoch, test_set_idx, net, test_set.dset) - errs[test_set.name] = err - return errs - - def test_epoch(self, epoch, set_idx, net, dset): - logging.info('-'*80) - logging.info('Test epoch %d' % epoch) - dset.current_epoch = epoch - test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False) - - net = net.to(self.test_device) - net.eval() - - with torch.no_grad(): - mean_loss = None - - self.callback_test_start(epoch, set_idx) - - bar = ETA(length=len(test_loader)) - stopwatch = StopWatch() - stopwatch.start('total') - stopwatch.start('data') - for batch_idx, data in enumerate(test_loader): - # if batch_idx == 10: break - self.copy_data(data, device=self.test_device, requires_grad=False, train=False) - stopwatch.stop('data') - - stopwatch.start('forward') - output = self.net_forward(net, train=False) - if 'cuda' in self.test_device: torch.cuda.synchronize() - stopwatch.stop('forward') - - stopwatch.start('loss') - errs = self.loss_forward(output, train=False) - if isinstance(errs, dict): - masks = errs['masks'] - errs = errs['errs'] - else: - masks = [] - if not isinstance(errs, list) and not isinstance(errs, tuple): - errs = [errs] + if resume and state_path.exists(): + logging.info('=' * 80) + logging.info(f'Loading state from {state_path}') + logging.info('=' * 80) + state = torch.load(str(state_path)) + epoch = state['epoch'] + 1 + if 'min_err' in state: + min_err = state['min_err'] + + curr_state = net.state_dict() + curr_state.update(state['state_dict']) + net.load_state_dict(curr_state) + + try: + optimizer.load_state_dict(state['optimizer']) + except: + logging.info('Warning: cannot load optimizer from state_dict') + pass + if 'cpu_rng_state' in state: + torch.set_rng_state(state['cpu_rng_state']) + if 'gpu_rng_state' in state: + torch.cuda.set_rng_state(state['gpu_rng_state']) + + for epoch in range(epoch, self.epochs): + self.callback_train_new_epoch(epoch, net, optimizer) + + # train epoch + self.train_epoch(epoch, net, optimizer, train_set) + + # test epoch + errs = self.test(epoch, net, test_sets) + + if (epoch + 1) % self.save_frequency == 0: + net = net.to(self.train_device) + + # store state + state_dict = { + 'epoch': epoch, + 'min_err': min_err, + 'state_dict': net.state_dict(), + 'optimizer': optimizer.state_dict(), + 'cpu_rng_state': torch.get_rng_state(), + 'gpu_rng_state': torch.cuda.get_rng_state(), + } + logging.info(f'save state to {state_path}') + state_path = self.exp_out_root / 'state.dict' + torch.save(state_dict, str(state_path)) + + for test_set_name in errs: + err = sum(errs[test_set_name]) + if err < min_err[test_set_name]: + min_err[test_set_name] = err + state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict' + logging.info(f'save state to {state_path}') + torch.save(state_dict, str(state_path)) + + # store network + net_path = self.get_net_path(epoch) + logging.info(f'save network to {net_path}') + torch.save(net.state_dict(), str(net_path)) + + if scheduler is not None: + scheduler.step() + + logging.info('=' * 80) + logging.info('Finished training') + self.log_datetime() + logging.info('=' * 80) + + def get_train_set(self): + # returns train_set + raise NotImplementedError() + + def get_test_sets(self): + # returns test_sets + raise NotImplementedError() + + def copy_data(self, data, device, requires_grad, train): + raise NotImplementedError() + + def net_forward(self, net, train): + raise NotImplementedError() + + def loss_forward(self, output, train): + raise NotImplementedError() + + def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): + # err = False + # for name, param in net.named_parameters(): + # if not torch.isfinite(param.grad).all(): + # print(name) + # err = True + # if err: + # import ipdb; ipdb.set_trace() + pass + + def callback_train_start(self, epoch): + pass + + def callback_train_stop(self, epoch, loss): + pass - bar.update(batch_idx) - if batch_idx % 25 == 0: - err_str = self.format_err_str(errs) - logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}') + def train_epoch(self, epoch, net, optimizer, dset): + self.callback_train_start(epoch) + stopwatch = StopWatch() - if mean_loss is None: - mean_loss = [0 for e in errs] - for erridx, err in enumerate(errs): - mean_loss[erridx] += err.item() - stopwatch.stop('loss') + logging.info('=' * 80) + logging.info('Train epoch %d' % epoch) - self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks) + dset.current_epoch = epoch + train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, + num_workers=self.num_workers, drop_last=True, pin_memory=False) + net = net.to(self.train_device) + net.train() + + mean_loss = None + + n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader) + bar = ETA(length=n_batches) + + stopwatch.start('total') stopwatch.start('data') - stopwatch.stop('total') - logging.info('timings: %s' % stopwatch) + for batch_idx, data in enumerate(train_loader): + if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break + self.copy_data(data, device=self.train_device, requires_grad=True, train=True) + stopwatch.stop('data') + + optimizer.zero_grad() + + stopwatch.start('forward') + output = self.net_forward(net, train=True) + if 'cuda' in self.train_device: torch.cuda.synchronize() + stopwatch.stop('forward') + + stopwatch.start('loss') + errs = self.loss_forward(output, train=True) + if isinstance(errs, dict): + masks = errs['masks'] + errs = errs['errs'] + else: + masks = [] + if not isinstance(errs, list) and not isinstance(errs, tuple): + errs = [errs] + err = sum(errs) + if 'cuda' in self.train_device: torch.cuda.synchronize() + stopwatch.stop('loss') + + stopwatch.start('backward') + err.backward() + self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks) + if 'cuda' in self.train_device: torch.cuda.synchronize() + stopwatch.stop('backward') + + stopwatch.start('optimizer') + optimizer.step() + if 'cuda' in self.train_device: torch.cuda.synchronize() + stopwatch.stop('optimizer') + + bar.update(batch_idx) + if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0: + err_str = self.format_err_str(errs) + logging.info( + f'train e{epoch}: {batch_idx + 1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}') + # self.write_err_img() + + if mean_loss is None: + mean_loss = [0 for e in errs] + for erridx, err in enumerate(errs): + mean_loss[erridx] += err.item() + + stopwatch.start('data') + stopwatch.stop('total') + logging.info('timings: %s' % stopwatch) + + mean_loss = [l / len(train_loader) for l in mean_loss] + self.callback_train_stop(epoch, mean_loss) + self.metric_add_train(epoch, 'loss', mean_loss) + + # save metrics + self.metric_save() + + err_str = self.format_err_str(mean_loss) + logging.info(f'avg train_loss={err_str}') + return mean_loss + + def callback_test_start(self, epoch, set_idx): + pass - mean_loss = [l / len(test_loader) for l in mean_loss] - self.callback_test_stop(epoch, set_idx, mean_loss) - self.metric_add_test(epoch, set_idx, 'loss', mean_loss) + def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks): + pass - # save metrics - self.metric_save() + def callback_test_stop(self, epoch, set_idx, loss): + pass - err_str = self.format_err_str(mean_loss) - logging.info(f'test epoch {epoch}: avg test_loss={err_str}') - return mean_loss + def test(self, epoch, net, test_sets): + errs = {} + for test_set_idx, test_set in enumerate(test_sets): + if (epoch + 1) % test_set.test_frequency == 0: + logging.info('=' * 80) + logging.info(f'testing set {test_set.name}') + err = self.test_epoch(epoch, test_set_idx, net, test_set.dset) + errs[test_set.name] = err + return errs + + def test_epoch(self, epoch, set_idx, net, dset): + logging.info('-' * 80) + logging.info('Test epoch %d' % epoch) + dset.current_epoch = epoch + test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, + num_workers=self.num_workers, drop_last=False, pin_memory=False) + + net = net.to(self.test_device) + net.eval() + + with torch.no_grad(): + mean_loss = None + + self.callback_test_start(epoch, set_idx) + + bar = ETA(length=len(test_loader)) + stopwatch = StopWatch() + stopwatch.start('total') + stopwatch.start('data') + for batch_idx, data in enumerate(test_loader): + # if batch_idx == 10: break + self.copy_data(data, device=self.test_device, requires_grad=False, train=False) + stopwatch.stop('data') + + stopwatch.start('forward') + output = self.net_forward(net, train=False) + if 'cuda' in self.test_device: torch.cuda.synchronize() + stopwatch.stop('forward') + + stopwatch.start('loss') + errs = self.loss_forward(output, train=False) + if isinstance(errs, dict): + masks = errs['masks'] + errs = errs['errs'] + else: + masks = [] + if not isinstance(errs, list) and not isinstance(errs, tuple): + errs = [errs] + + bar.update(batch_idx) + if batch_idx % 25 == 0: + err_str = self.format_err_str(errs) + logging.info( + f'test e{epoch}: {batch_idx + 1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}') + + if mean_loss is None: + mean_loss = [0 for e in errs] + for erridx, err in enumerate(errs): + mean_loss[erridx] += err.item() + stopwatch.stop('loss') + + self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks) + + stopwatch.start('data') + stopwatch.stop('total') + logging.info('timings: %s' % stopwatch) + + mean_loss = [l / len(test_loader) for l in mean_loss] + self.callback_test_stop(epoch, set_idx, mean_loss) + self.metric_add_test(epoch, set_idx, 'loss', mean_loss) + + # save metrics + self.metric_save() + + err_str = self.format_err_str(mean_loss) + logging.info(f'test epoch {epoch}: avg test_loss={err_str}') + return mean_loss diff --git a/train_val.py b/train_val.py index 9c92610..17f14f2 100644 --- a/train_val.py +++ b/train_val.py @@ -5,25 +5,24 @@ from model import exp_synphge from model import networks from co.args import parse_args - # parse args args = parse_args() # loss types -if args.loss=='ph': - worker = exp_synph.Worker(args) -elif args.loss=='phge': - worker = exp_synphge.Worker(args) +if args.loss == 'ph': + worker = exp_synph.Worker(args) +elif args.loss == 'phge': + worker = exp_synphge.Worker(args) # concatenation of original image and lcn image -channels_in=2 +channels_in = 2 # set up network -net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes, output_ms=worker.ms) +net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes, + output_ms=worker.ms) # optimizer optimizer = torch.optim.Adam(net.parameters(), lr=1e-4) # start the work worker.do(net, optimizer) -