Reformat $EVERYTHING
This commit is contained in:
parent
56f2aa7d5d
commit
43df77fb9b
@ -7,6 +7,7 @@
|
|||||||
# set matplotlib backend depending on env
|
# set matplotlib backend depending on env
|
||||||
import os
|
import os
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
||||||
if os.name == 'posix' and "DISPLAY" not in os.environ:
|
if os.name == 'posix' and "DISPLAY" not in os.environ:
|
||||||
matplotlib.use('Agg')
|
matplotlib.use('Agg')
|
||||||
|
|
||||||
|
@ -66,6 +66,3 @@ def parse_args():
|
|||||||
def get_exp_name(args):
|
def get_exp_name(args):
|
||||||
name = f"exp_{args.data_type}"
|
name = f"exp_{args.data_type}"
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ _color_map_errors = np.array([
|
|||||||
[38, 0, 165] # inf: log2(x) = inf
|
[38, 0, 165] # inf: log2(x) = inf
|
||||||
]).astype(float)
|
]).astype(float)
|
||||||
|
|
||||||
|
|
||||||
def color_error_image(errors, scale=1, mask=None, BGR=True):
|
def color_error_image(errors, scale=1, mask=None, BGR=True):
|
||||||
"""
|
"""
|
||||||
Color an input error map.
|
Color an input error map.
|
||||||
@ -32,7 +33,8 @@ def color_error_image(errors, scale=1, mask=None, BGR=True):
|
|||||||
errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9)
|
errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9)
|
||||||
i0 = np.floor(errors_color_indices).astype(int)
|
i0 = np.floor(errors_color_indices).astype(int)
|
||||||
f1 = errors_color_indices - i0.astype(float)
|
f1 = errors_color_indices - i0.astype(float)
|
||||||
colored_errors_flat = _color_map_errors[i0, :] * (1-f1).reshape(-1,1) + _color_map_errors[i0+1, :] * f1.reshape(-1,1)
|
colored_errors_flat = _color_map_errors[i0, :] * (1 - f1).reshape(-1, 1) + _color_map_errors[i0 + 1,
|
||||||
|
:] * f1.reshape(-1, 1)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
colored_errors_flat[mask.flatten() == 0] = 255
|
colored_errors_flat[mask.flatten() == 0] = 255
|
||||||
@ -42,6 +44,7 @@ def color_error_image(errors, scale=1, mask=None, BGR=True):
|
|||||||
|
|
||||||
return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int)
|
return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int)
|
||||||
|
|
||||||
|
|
||||||
_color_map_depths = np.array([
|
_color_map_depths = np.array([
|
||||||
[0, 0, 0], # 0.000
|
[0, 0, 0], # 0.000
|
||||||
[0, 0, 255], # 0.114
|
[0, 0, 255], # 0.114
|
||||||
@ -65,6 +68,7 @@ _color_map_bincenters = np.array([
|
|||||||
2.000, # doesn't make a difference, just strictly higher than 1
|
2.000, # doesn't make a difference, just strictly higher than 1
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
def color_depth_map(depths, scale=None):
|
def color_depth_map(depths, scale=None):
|
||||||
"""
|
"""
|
||||||
Color an input depth map.
|
Color an input depth map.
|
||||||
@ -86,7 +90,8 @@ def color_depth_map(depths, scale=None):
|
|||||||
lower_bin_value = _color_map_bincenters[lower_bin]
|
lower_bin_value = _color_map_bincenters[lower_bin]
|
||||||
higher_bin_value = _color_map_bincenters[lower_bin + 1]
|
higher_bin_value = _color_map_bincenters[lower_bin + 1]
|
||||||
alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
|
alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
|
||||||
colors = _color_map_depths[lower_bin] * (1-alphas).reshape(-1,1) + _color_map_depths[lower_bin + 1] * alphas.reshape(-1,1)
|
colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[
|
||||||
|
lower_bin + 1] * alphas.reshape(-1, 1)
|
||||||
return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)
|
return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)
|
||||||
|
|
||||||
# from utils.debug import save_color_numpy
|
# from utils.debug import save_color_numpy
|
||||||
|
@ -2,6 +2,7 @@ import numpy as np
|
|||||||
import scipy.spatial
|
import scipy.spatial
|
||||||
import scipy.linalg
|
import scipy.linalg
|
||||||
|
|
||||||
|
|
||||||
def nullspace(A, atol=1e-13, rtol=0):
|
def nullspace(A, atol=1e-13, rtol=0):
|
||||||
u, s, vh = np.linalg.svd(A)
|
u, s, vh = np.linalg.svd(A)
|
||||||
tol = max(atol, rtol * s[0])
|
tol = max(atol, rtol * s[0])
|
||||||
@ -9,10 +10,12 @@ def nullspace(A, atol=1e-13, rtol=0):
|
|||||||
ns = vh[nnz:].conj().T
|
ns = vh[nnz:].conj().T
|
||||||
return ns
|
return ns
|
||||||
|
|
||||||
|
|
||||||
def nearest_orthogonal_matrix(R):
|
def nearest_orthogonal_matrix(R):
|
||||||
U, S, Vt = np.linalg.svd(R)
|
U, S, Vt = np.linalg.svd(R)
|
||||||
return U @ np.eye(3, dtype=R.dtype) @ Vt
|
return U @ np.eye(3, dtype=R.dtype) @ Vt
|
||||||
|
|
||||||
|
|
||||||
def power_iters(A, n_iters=10):
|
def power_iters(A, n_iters=10):
|
||||||
b = np.random.uniform(-1, 1, size=(A.shape[0], A.shape[1], 1))
|
b = np.random.uniform(-1, 1, size=(A.shape[0], A.shape[1], 1))
|
||||||
for iter in range(n_iters):
|
for iter in range(n_iters):
|
||||||
@ -20,6 +23,7 @@ def power_iters(A, n_iters=10):
|
|||||||
b = b / np.linalg.norm(b, axis=1, keepdims=True)
|
b = b / np.linalg.norm(b, axis=1, keepdims=True)
|
||||||
return b
|
return b
|
||||||
|
|
||||||
|
|
||||||
def rayleigh_quotient(A, b):
|
def rayleigh_quotient(A, b):
|
||||||
return (b.transpose(0, 2, 1) @ A @ b) / (b.transpose(0, 2, 1) @ b)
|
return (b.transpose(0, 2, 1) @ A @ b) / (b.transpose(0, 2, 1) @ b)
|
||||||
|
|
||||||
@ -38,9 +42,11 @@ def cross_prod_mat(x):
|
|||||||
X[:, 2, 2] = 0
|
X[:, 2, 2] = 0
|
||||||
return X.squeeze()
|
return X.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def hat_operator(x):
|
def hat_operator(x):
|
||||||
return cross_prod_mat(x)
|
return cross_prod_mat(x)
|
||||||
|
|
||||||
|
|
||||||
def vee_operator(X):
|
def vee_operator(X):
|
||||||
X = X.reshape(-1, 3, 3)
|
X = X.reshape(-1, 3, 3)
|
||||||
x = np.empty((X.shape[0], 3), dtype=X.dtype)
|
x = np.empty((X.shape[0], 3), dtype=X.dtype)
|
||||||
@ -61,6 +67,7 @@ def rot_x(x, dtype=np.float32):
|
|||||||
R[:, 2, 2] = np.cos(x).ravel()
|
R[:, 2, 2] = np.cos(x).ravel()
|
||||||
return R.squeeze()
|
return R.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def rot_y(y, dtype=np.float32):
|
def rot_y(y, dtype=np.float32):
|
||||||
y = np.array(y, copy=False)
|
y = np.array(y, copy=False)
|
||||||
y = y.reshape(-1, 1)
|
y = y.reshape(-1, 1)
|
||||||
@ -72,6 +79,7 @@ def rot_y(y, dtype=np.float32):
|
|||||||
R[:, 2, 2] = np.cos(y).ravel()
|
R[:, 2, 2] = np.cos(y).ravel()
|
||||||
return R.squeeze()
|
return R.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def rot_z(z, dtype=np.float32):
|
def rot_z(z, dtype=np.float32):
|
||||||
z = np.array(z, copy=False)
|
z = np.array(z, copy=False)
|
||||||
z = z.reshape(-1, 1)
|
z = z.reshape(-1, 1)
|
||||||
@ -83,6 +91,7 @@ def rot_z(z, dtype=np.float32):
|
|||||||
R[:, 2, 2] = 1
|
R[:, 2, 2] = 1
|
||||||
return R.squeeze()
|
return R.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def xyz_from_rotm(R):
|
def xyz_from_rotm(R):
|
||||||
R = R.reshape(-1, 3, 3)
|
R = R.reshape(-1, 3, 3)
|
||||||
xyz = np.empty((R.shape[0], 3), dtype=R.dtype)
|
xyz = np.empty((R.shape[0], 3), dtype=R.dtype)
|
||||||
@ -102,6 +111,7 @@ def xyz_from_rotm(R):
|
|||||||
xyz[bidx, 2] = 0
|
xyz[bidx, 2] = 0
|
||||||
return xyz.squeeze()
|
return xyz.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def zyx_from_rotm(R):
|
def zyx_from_rotm(R):
|
||||||
R = R.reshape(-1, 3, 3)
|
R = R.reshape(-1, 3, 3)
|
||||||
zyx = np.empty((R.shape[0], 3), dtype=R.dtype)
|
zyx = np.empty((R.shape[0], 3), dtype=R.dtype)
|
||||||
@ -121,14 +131,17 @@ def zyx_from_rotm(R):
|
|||||||
zyx[bidx, 2] = 0
|
zyx[bidx, 2] = 0
|
||||||
return zyx.squeeze()
|
return zyx.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def rotm_from_xyz(xyz):
|
def rotm_from_xyz(xyz):
|
||||||
xyz = np.array(xyz, copy=False).reshape(-1, 3)
|
xyz = np.array(xyz, copy=False).reshape(-1, 3)
|
||||||
return (rot_x(xyz[:, 0]) @ rot_y(xyz[:, 1]) @ rot_z(xyz[:, 2])).squeeze()
|
return (rot_x(xyz[:, 0]) @ rot_y(xyz[:, 1]) @ rot_z(xyz[:, 2])).squeeze()
|
||||||
|
|
||||||
|
|
||||||
def rotm_from_zyx(zyx):
|
def rotm_from_zyx(zyx):
|
||||||
zyx = np.array(zyx, copy=False).reshape(-1, 3)
|
zyx = np.array(zyx, copy=False).reshape(-1, 3)
|
||||||
return (rot_z(zyx[:, 0]) @ rot_y(zyx[:, 1]) @ rot_x(zyx[:, 2])).squeeze()
|
return (rot_z(zyx[:, 0]) @ rot_y(zyx[:, 1]) @ rot_x(zyx[:, 2])).squeeze()
|
||||||
|
|
||||||
|
|
||||||
def rotm_from_quat(q):
|
def rotm_from_quat(q):
|
||||||
q = q.reshape(-1, 4)
|
q = q.reshape(-1, 4)
|
||||||
w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
|
w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
|
||||||
@ -140,6 +153,7 @@ def rotm_from_quat(q):
|
|||||||
R = R.transpose((2, 0, 1))
|
R = R.transpose((2, 0, 1))
|
||||||
return R.squeeze()
|
return R.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def rotm_from_axisangle(a):
|
def rotm_from_axisangle(a):
|
||||||
# exponential
|
# exponential
|
||||||
a = a.reshape(-1, 3)
|
a = a.reshape(-1, 3)
|
||||||
@ -150,6 +164,7 @@ def rotm_from_axisangle(a):
|
|||||||
R = np.eye(3, dtype=a.dtype) + np.sin(phi) * A + (1 - np.cos(phi)) * A @ A
|
R = np.eye(3, dtype=a.dtype) + np.sin(phi) * A + (1 - np.cos(phi)) * A @ A
|
||||||
return R.squeeze()
|
return R.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def rotm_from_lookat(dir, up=None):
|
def rotm_from_lookat(dir, up=None):
|
||||||
dir = dir.reshape(-1, 3)
|
dir = dir.reshape(-1, 3)
|
||||||
if up is None:
|
if up is None:
|
||||||
@ -175,6 +190,7 @@ def rotm_from_lookat(dir, up=None):
|
|||||||
R[:, 2, 2] = dir[:, 2]
|
R[:, 2, 2] = dir[:, 2]
|
||||||
return R.transpose(0, 2, 1).squeeze()
|
return R.transpose(0, 2, 1).squeeze()
|
||||||
|
|
||||||
|
|
||||||
def rotm_distance_identity(R0, R1):
|
def rotm_distance_identity(R0, R1):
|
||||||
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
||||||
# in [0, 2*sqrt(2)]
|
# in [0, 2*sqrt(2)]
|
||||||
@ -183,6 +199,7 @@ def rotm_distance_identity(R0, R1):
|
|||||||
dists = np.linalg.norm(np.eye(3, dtype=R0.dtype) - R0 @ R1.transpose(0, 2, 1), axis=(1, 2))
|
dists = np.linalg.norm(np.eye(3, dtype=R0.dtype) - R0 @ R1.transpose(0, 2, 1), axis=(1, 2))
|
||||||
return dists.squeeze()
|
return dists.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def rotm_distance_geodesic(R0, R1):
|
def rotm_distance_geodesic(R0, R1):
|
||||||
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
||||||
# in [0, pi)
|
# in [0, pi)
|
||||||
@ -195,7 +212,6 @@ def rotm_distance_geodesic(R0, R1):
|
|||||||
return dists.squeeze()
|
return dists.squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def axisangle_from_rotm(R):
|
def axisangle_from_rotm(R):
|
||||||
# logarithm of rotation matrix
|
# logarithm of rotation matrix
|
||||||
# R = R.reshape(-1,3,3)
|
# R = R.reshape(-1,3,3)
|
||||||
@ -219,6 +235,7 @@ def axisangle_from_rotm(R):
|
|||||||
np.divide(omega, r, out=aa, where=r != 0)
|
np.divide(omega, r, out=aa, where=r != 0)
|
||||||
return aa.squeeze()
|
return aa.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def axisangle_from_quat(q):
|
def axisangle_from_quat(q):
|
||||||
q = q.reshape(-1, 4)
|
q = q.reshape(-1, 4)
|
||||||
phi = 2 * np.arccos(q[:, 0])
|
phi = 2 * np.arccos(q[:, 0])
|
||||||
@ -231,6 +248,7 @@ def axisangle_from_quat(q):
|
|||||||
aa = a.astype(q.dtype)
|
aa = a.astype(q.dtype)
|
||||||
return aa.squeeze()
|
return aa.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def axisangle_apply(aa, x):
|
def axisangle_apply(aa, x):
|
||||||
# working only with single aa and single x at the moment
|
# working only with single aa and single x at the moment
|
||||||
xshape = x.shape
|
xshape = x.shape
|
||||||
@ -247,10 +265,12 @@ def exp_so3(R):
|
|||||||
w = axisangle_from_rotm(R)
|
w = axisangle_from_rotm(R)
|
||||||
return w
|
return w
|
||||||
|
|
||||||
|
|
||||||
def log_so3(w):
|
def log_so3(w):
|
||||||
R = rotm_from_axisangle(w)
|
R = rotm_from_axisangle(w)
|
||||||
return R
|
return R
|
||||||
|
|
||||||
|
|
||||||
def exp_se3(R, t):
|
def exp_se3(R, t):
|
||||||
R = R.reshape(-1, 3, 3)
|
R = R.reshape(-1, 3, 3)
|
||||||
t = t.reshape(-1, 3)
|
t = t.reshape(-1, 3)
|
||||||
@ -269,6 +289,7 @@ def exp_se3(R, t):
|
|||||||
|
|
||||||
return v.squeeze()
|
return v.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def log_se3(v):
|
def log_se3(v):
|
||||||
# v = (u, w)
|
# v = (u, w)
|
||||||
v = v.reshape(-1, 6)
|
v = v.reshape(-1, 6)
|
||||||
@ -298,6 +319,7 @@ def quat_from_rotm(R):
|
|||||||
q /= np.linalg.norm(q, axis=1, keepdims=True)
|
q /= np.linalg.norm(q, axis=1, keepdims=True)
|
||||||
return q.squeeze()
|
return q.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def quat_from_axisangle(a):
|
def quat_from_axisangle(a):
|
||||||
a = a.reshape(-1, 3)
|
a = a.reshape(-1, 3)
|
||||||
phi = np.linalg.norm(a, axis=1)
|
phi = np.linalg.norm(a, axis=1)
|
||||||
@ -311,17 +333,20 @@ def quat_from_axisangle(a):
|
|||||||
q /= np.linalg.norm(q, axis=1).reshape(-1, 1)
|
q /= np.linalg.norm(q, axis=1).reshape(-1, 1)
|
||||||
return q.squeeze()
|
return q.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def quat_identity(n=1, dtype=np.float32):
|
def quat_identity(n=1, dtype=np.float32):
|
||||||
q = np.zeros((n, 4), dtype=dtype)
|
q = np.zeros((n, 4), dtype=dtype)
|
||||||
q[:, 0] = 1
|
q[:, 0] = 1
|
||||||
return q.squeeze()
|
return q.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def quat_conjugate(q):
|
def quat_conjugate(q):
|
||||||
shape = q.shape
|
shape = q.shape
|
||||||
q = q.reshape(-1, 4).copy()
|
q = q.reshape(-1, 4).copy()
|
||||||
q[:, 1:] *= -1
|
q[:, 1:] *= -1
|
||||||
return q.reshape(shape)
|
return q.reshape(shape)
|
||||||
|
|
||||||
|
|
||||||
def quat_product(q1, q2):
|
def quat_product(q1, q2):
|
||||||
# q1 . q2 is equivalent to R(q1) @ R(q2)
|
# q1 . q2 is equivalent to R(q1) @ R(q2)
|
||||||
shape = q1.shape
|
shape = q1.shape
|
||||||
@ -335,6 +360,7 @@ def quat_product(q1, q2):
|
|||||||
q[:, 3] = a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2
|
q[:, 3] = a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2
|
||||||
return q.squeeze()
|
return q.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def quat_apply(q, x):
|
def quat_apply(q, x):
|
||||||
xshape = x.shape
|
xshape = x.shape
|
||||||
x = x.reshape(-1, 3)
|
x = x.reshape(-1, 3)
|
||||||
@ -367,6 +393,7 @@ def quat_random(rng=None, n=1):
|
|||||||
q /= np.linalg.norm(q, axis=1, keepdims=True)
|
q /= np.linalg.norm(q, axis=1, keepdims=True)
|
||||||
return q.squeeze()
|
return q.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def quat_distance_angle(q0, q1):
|
def quat_distance_angle(q0, q1):
|
||||||
# https://math.stackexchange.com/questions/90081/quaternion-distance
|
# https://math.stackexchange.com/questions/90081/quaternion-distance
|
||||||
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
||||||
@ -375,6 +402,7 @@ def quat_distance_angle(q0, q1):
|
|||||||
dists = np.arccos(np.clip(2 * np.sum(q0 * q1, axis=1) ** 2 - 1, -1, 1))
|
dists = np.arccos(np.clip(2 * np.sum(q0 * q1, axis=1) ** 2 - 1, -1, 1))
|
||||||
return dists
|
return dists
|
||||||
|
|
||||||
|
|
||||||
def quat_distance_normdiff(q0, q1):
|
def quat_distance_normdiff(q0, q1):
|
||||||
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
||||||
# \phi_4
|
# \phi_4
|
||||||
@ -383,6 +411,7 @@ def quat_distance_normdiff(q0, q1):
|
|||||||
q1 = q1.reshape(-1, 4)
|
q1 = q1.reshape(-1, 4)
|
||||||
return 1 - np.sum(q0 * q1, axis=1) ** 2
|
return 1 - np.sum(q0 * q1, axis=1) ** 2
|
||||||
|
|
||||||
|
|
||||||
def quat_distance_mineucl(q0, q1):
|
def quat_distance_mineucl(q0, q1):
|
||||||
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
||||||
# http://users.cecs.anu.edu.au/~trumpf/pubs/Hartley_Trumpf_Dai_Li.pdf
|
# http://users.cecs.anu.edu.au/~trumpf/pubs/Hartley_Trumpf_Dai_Li.pdf
|
||||||
@ -392,6 +421,7 @@ def quat_distance_mineucl(q0, q1):
|
|||||||
diff1 = ((q0 + q1) ** 2).sum(axis=1)
|
diff1 = ((q0 + q1) ** 2).sum(axis=1)
|
||||||
return np.minimum(diff0, diff1)
|
return np.minimum(diff0, diff1)
|
||||||
|
|
||||||
|
|
||||||
def quat_slerp_space(q0, q1, num=100, endpoint=True):
|
def quat_slerp_space(q0, q1, num=100, endpoint=True):
|
||||||
q0 = q0.ravel()
|
q0 = q0.ravel()
|
||||||
q1 = q1.ravel()
|
q1 = q1.ravel()
|
||||||
@ -411,6 +441,7 @@ def quat_slerp_space(q0, q1, num=100, endpoint=True):
|
|||||||
s1 = np.sin(theta) / np.sin(theta0)
|
s1 = np.sin(theta) / np.sin(theta0)
|
||||||
return (s0 * q0) + (s1 * q1)
|
return (s0 * q0) + (s1 * q1)
|
||||||
|
|
||||||
|
|
||||||
def cart_to_spherical(x):
|
def cart_to_spherical(x):
|
||||||
shape = x.shape
|
shape = x.shape
|
||||||
x = x.reshape(-1, 3)
|
x = x.reshape(-1, 3)
|
||||||
@ -420,6 +451,7 @@ def cart_to_spherical(x):
|
|||||||
y[:, 2] = np.arctan2(x[:, 1], x[:, 0]) # phi
|
y[:, 2] = np.arctan2(x[:, 1], x[:, 0]) # phi
|
||||||
return y.reshape(shape)
|
return y.reshape(shape)
|
||||||
|
|
||||||
|
|
||||||
def spherical_to_cart(x):
|
def spherical_to_cart(x):
|
||||||
shape = x.shape
|
shape = x.shape
|
||||||
x = x.reshape(-1, 3)
|
x = x.reshape(-1, 3)
|
||||||
@ -429,6 +461,7 @@ def spherical_to_cart(x):
|
|||||||
y[:, 2] = x[:, 0] * np.cos(x[:, 1])
|
y[:, 2] = x[:, 0] * np.cos(x[:, 1])
|
||||||
return y.reshape(shape)
|
return y.reshape(shape)
|
||||||
|
|
||||||
|
|
||||||
def spherical_random(r=1, n=1):
|
def spherical_random(r=1, n=1):
|
||||||
# http://mathworld.wolfram.com/SpherePointPicking.html
|
# http://mathworld.wolfram.com/SpherePointPicking.html
|
||||||
# https://math.stackexchange.com/questions/1585975/how-to-generate-random-points-on-a-sphere
|
# https://math.stackexchange.com/questions/1585975/how-to-generate-random-points-on-a-sphere
|
||||||
@ -438,6 +471,7 @@ def spherical_random(r=1, n=1):
|
|||||||
x[:, 2] = np.arccos(2 * np.random.uniform(0, 1, size=(n,)) - 1)
|
x[:, 2] = np.arccos(2 * np.random.uniform(0, 1, size=(n,)) - 1)
|
||||||
return x.squeeze()
|
return x.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def color_pcl(pcl, K, im, color_axis=0, as_int=True, invalid_color=[0, 0, 0]):
|
def color_pcl(pcl, K, im, color_axis=0, as_int=True, invalid_color=[0, 0, 0]):
|
||||||
uvd = K @ pcl.T
|
uvd = K @ pcl.T
|
||||||
uvd /= uvd[2]
|
uvd /= uvd[2]
|
||||||
@ -461,6 +495,7 @@ def color_pcl(pcl, K, im, color_axis=0, as_int=True, invalid_color=[0,0,0]):
|
|||||||
color = (255.0 * color).astype(np.int32)
|
color = (255.0 * color).astype(np.int32)
|
||||||
return color
|
return color
|
||||||
|
|
||||||
|
|
||||||
def center_pcl(pcl, robust=False, copy=False, axis=1):
|
def center_pcl(pcl, robust=False, copy=False, axis=1):
|
||||||
if copy:
|
if copy:
|
||||||
pcl = pcl.copy()
|
pcl = pcl.copy()
|
||||||
@ -470,13 +505,16 @@ def center_pcl(pcl, robust=False, copy=False, axis=1):
|
|||||||
mu = np.mean(pcl, axis=axis, keepdims=True)
|
mu = np.mean(pcl, axis=axis, keepdims=True)
|
||||||
return pcl - mu
|
return pcl - mu
|
||||||
|
|
||||||
|
|
||||||
def to_homogeneous(x):
|
def to_homogeneous(x):
|
||||||
# return np.hstack((x, np.ones((x.shape[0],1),dtype=x.dtype)))
|
# return np.hstack((x, np.ones((x.shape[0],1),dtype=x.dtype)))
|
||||||
return np.concatenate((x, np.ones((*x.shape[:-1], 1), dtype=x.dtype)), axis=-1)
|
return np.concatenate((x, np.ones((*x.shape[:-1], 1), dtype=x.dtype)), axis=-1)
|
||||||
|
|
||||||
|
|
||||||
def from_homogeneous(x):
|
def from_homogeneous(x):
|
||||||
return x[:, :-1] / x[:, -1]
|
return x[:, :-1] / x[:, -1]
|
||||||
|
|
||||||
|
|
||||||
def project_uvn(uv, Ki=None):
|
def project_uvn(uv, Ki=None):
|
||||||
if uv.shape[1] == 2:
|
if uv.shape[1] == 2:
|
||||||
uvn = to_homogeneous(uv)
|
uvn = to_homogeneous(uv)
|
||||||
@ -489,6 +527,7 @@ def project_uvn(uv, Ki=None):
|
|||||||
else:
|
else:
|
||||||
return uvn @ Ki.T
|
return uvn @ Ki.T
|
||||||
|
|
||||||
|
|
||||||
def project_uvd(uv, depth, K=np.eye(3), R=np.eye(3), t=np.zeros((3, 1)), ignore_negative_depth=True, return_uvn=False):
|
def project_uvd(uv, depth, K=np.eye(3), R=np.eye(3), t=np.zeros((3, 1)), ignore_negative_depth=True, return_uvn=False):
|
||||||
Ki = np.linalg.inv(K)
|
Ki = np.linalg.inv(K)
|
||||||
|
|
||||||
@ -510,6 +549,7 @@ def project_uvd(uv, depth, K=np.eye(3), R=np.eye(3), t=np.zeros((3,1)), ignore_n
|
|||||||
else:
|
else:
|
||||||
return xyz
|
return xyz
|
||||||
|
|
||||||
|
|
||||||
def project_depth(depth, K, R=np.eye(3, 3), t=np.zeros((3, 1)), ignore_negative_depth=True, return_uvn=False):
|
def project_depth(depth, K, R=np.eye(3, 3), t=np.zeros((3, 1)), ignore_negative_depth=True, return_uvn=False):
|
||||||
u, v = np.meshgrid(range(depth.shape[1]), range(depth.shape[0]))
|
u, v = np.meshgrid(range(depth.shape[1]), range(depth.shape[0]))
|
||||||
uv = np.hstack((u.reshape(-1, 1), v.reshape(-1, 1)))
|
uv = np.hstack((u.reshape(-1, 1), v.reshape(-1, 1)))
|
||||||
@ -540,12 +580,14 @@ def translation_to_cameracenter(R, t):
|
|||||||
C = -R.transpose(0, 2, 1) @ t
|
C = -R.transpose(0, 2, 1) @ t
|
||||||
return C.squeeze()
|
return C.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def cameracenter_to_translation(R, C):
|
def cameracenter_to_translation(R, C):
|
||||||
C = C.reshape(-1, 3, 1)
|
C = C.reshape(-1, 3, 1)
|
||||||
R = R.reshape(-1, 3, 3)
|
R = R.reshape(-1, 3, 3)
|
||||||
t = -R @ C
|
t = -R @ C
|
||||||
return t.squeeze()
|
return t.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def decompose_projection_matrix(P, return_t=True):
|
def decompose_projection_matrix(P, return_t=True):
|
||||||
if P.shape[0] != 3 or P.shape[1] != 4:
|
if P.shape[0] != 3 or P.shape[1] != 4:
|
||||||
raise Exception('P has to be 3x4')
|
raise Exception('P has to be 3x4')
|
||||||
@ -575,11 +617,11 @@ def compose_projection_matrix(K=np.eye(3), R=np.eye(3,3), t=np.zeros((3,1))):
|
|||||||
return K @ np.hstack((R, t.reshape((3, 1))))
|
return K @ np.hstack((R, t.reshape((3, 1))))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def point_plane_distance(pts, plane):
|
def point_plane_distance(pts, plane):
|
||||||
pts = pts.reshape(-1, 3)
|
pts = pts.reshape(-1, 3)
|
||||||
return np.abs(np.sum(plane[:3] * pts, axis=1) + plane[3]) / np.linalg.norm(plane[:3])
|
return np.abs(np.sum(plane[:3] * pts, axis=1) + plane[3]) / np.linalg.norm(plane[:3])
|
||||||
|
|
||||||
|
|
||||||
def fit_plane(pts):
|
def fit_plane(pts):
|
||||||
pts = pts.reshape(-1, 3)
|
pts = pts.reshape(-1, 3)
|
||||||
center = np.mean(pts, axis=0)
|
center = np.mean(pts, axis=0)
|
||||||
@ -590,6 +632,7 @@ def fit_plane(pts):
|
|||||||
plane = np.array([*vh[2], -vh[2].dot(center)])
|
plane = np.array([*vh[2], -vh[2].dot(center)])
|
||||||
return plane
|
return plane
|
||||||
|
|
||||||
|
|
||||||
def tetrahedron(dtype=np.float32):
|
def tetrahedron(dtype=np.float32):
|
||||||
verts = np.array([
|
verts = np.array([
|
||||||
(np.sqrt(8 / 9), 0, -1 / 3), (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),
|
(np.sqrt(8 / 9), 0, -1 / 3), (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),
|
||||||
@ -599,6 +642,7 @@ def tetrahedron(dtype=np.float32):
|
|||||||
normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1)
|
normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1)
|
||||||
return verts, faces, normals
|
return verts, faces, normals
|
||||||
|
|
||||||
|
|
||||||
def cube(dtype=np.float32):
|
def cube(dtype=np.float32):
|
||||||
verts = np.array([
|
verts = np.array([
|
||||||
[-0.5, -0.5, -0.5], [-0.5, 0.5, -0.5], [0.5, 0.5, -0.5], [0.5, -0.5, -0.5],
|
[-0.5, -0.5, -0.5], [-0.5, 0.5, -0.5], [0.5, 0.5, -0.5], [0.5, -0.5, -0.5],
|
||||||
@ -611,6 +655,7 @@ def cube(dtype=np.float32):
|
|||||||
normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1)
|
normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1)
|
||||||
return verts, faces, normals
|
return verts, faces, normals
|
||||||
|
|
||||||
|
|
||||||
def octahedron(dtype=np.float32):
|
def octahedron(dtype=np.float32):
|
||||||
verts = np.array([
|
verts = np.array([
|
||||||
(+1, 0, 0), (0, +1, 0), (0, 0, +1),
|
(+1, 0, 0), (0, +1, 0), (0, 0, +1),
|
||||||
@ -622,6 +667,7 @@ def octahedron(dtype=np.float32):
|
|||||||
normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1)
|
normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1)
|
||||||
return verts, faces, normals
|
return verts, faces, normals
|
||||||
|
|
||||||
|
|
||||||
def icosahedron(dtype=np.float32):
|
def icosahedron(dtype=np.float32):
|
||||||
p = (1 + np.sqrt(5)) / 2
|
p = (1 + np.sqrt(5)) / 2
|
||||||
verts = np.array([
|
verts = np.array([
|
||||||
@ -638,6 +684,7 @@ def icosahedron(dtype=np.float32):
|
|||||||
normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1)
|
normals /= np.linalg.norm(normals, axis=1).reshape(-1, 1)
|
||||||
return verts, faces, normals
|
return verts, faces, normals
|
||||||
|
|
||||||
|
|
||||||
def xyplane(dtype=np.float32, z=0, interleaved=False):
|
def xyplane(dtype=np.float32, z=0, interleaved=False):
|
||||||
if interleaved:
|
if interleaved:
|
||||||
eps = 1e-6
|
eps = 1e-6
|
||||||
@ -652,6 +699,7 @@ def xyplane(dtype=np.float32, z=0, interleaved=False):
|
|||||||
normals[:, 2] = -1
|
normals[:, 2] = -1
|
||||||
return verts, faces, normals
|
return verts, faces, normals
|
||||||
|
|
||||||
|
|
||||||
def mesh_independent_verts(verts, faces, normals=None):
|
def mesh_independent_verts(verts, faces, normals=None):
|
||||||
new_verts = []
|
new_verts = []
|
||||||
new_normals = []
|
new_normals = []
|
||||||
@ -682,6 +730,7 @@ def stack_mesh(verts, faces):
|
|||||||
faces = np.vstack(mfaces)
|
faces = np.vstack(mfaces)
|
||||||
return verts, faces
|
return verts, faces
|
||||||
|
|
||||||
|
|
||||||
def normalize_mesh(verts):
|
def normalize_mesh(verts):
|
||||||
# all the verts have unit distance to the center (0,0,0)
|
# all the verts have unit distance to the center (0,0,0)
|
||||||
return verts / np.linalg.norm(verts, axis=1, keepdims=True)
|
return verts / np.linalg.norm(verts, axis=1, keepdims=True)
|
||||||
@ -700,6 +749,7 @@ def mesh_triangle_areas(verts, faces):
|
|||||||
t[:, 2] = (x[:, 0] * y[:, 1] - x[:, 1] * y[:, 0]);
|
t[:, 2] = (x[:, 0] * y[:, 1] - x[:, 1] * y[:, 0]);
|
||||||
return np.linalg.norm(t, axis=1) / 2
|
return np.linalg.norm(t, axis=1) / 2
|
||||||
|
|
||||||
|
|
||||||
def subdivde_mesh(verts_in, faces_in, n=1):
|
def subdivde_mesh(verts_in, faces_in, n=1):
|
||||||
for iter in range(n):
|
for iter in range(n):
|
||||||
verts = []
|
verts = []
|
||||||
|
@ -2,6 +2,7 @@ import numpy as np
|
|||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
||||||
|
|
||||||
class StopWatch(utils.StopWatch):
|
class StopWatch(utils.StopWatch):
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
print('=' * 80)
|
print('=' * 80)
|
||||||
@ -14,13 +15,18 @@ class StopWatch(utils.StopWatch):
|
|||||||
print(f' [median] {median}')
|
print(f' [median] {median}')
|
||||||
print('=' * 80)
|
print('=' * 80)
|
||||||
|
|
||||||
|
|
||||||
GTIMER = StopWatch()
|
GTIMER = StopWatch()
|
||||||
|
|
||||||
|
|
||||||
def start(name):
|
def start(name):
|
||||||
GTIMER.start(name)
|
GTIMER.start(name)
|
||||||
|
|
||||||
|
|
||||||
def stop(name):
|
def stop(name):
|
||||||
GTIMER.stop(name)
|
GTIMER.stop(name)
|
||||||
|
|
||||||
|
|
||||||
class Ctx(object):
|
class Ctx(object):
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
@ -2,6 +2,7 @@ import struct
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
|
|
||||||
def _write_ply_point(fp, x, y, z, color=None, normal=None, binary=False):
|
def _write_ply_point(fp, x, y, z, color=None, normal=None, binary=False):
|
||||||
args = [x, y, z]
|
args = [x, y, z]
|
||||||
if color is not None:
|
if color is not None:
|
||||||
@ -24,18 +25,21 @@ def _write_ply_point(fp, x,y,z, color=None, normal=None, binary=False):
|
|||||||
fmt += '\n'
|
fmt += '\n'
|
||||||
fp.write(fmt % tuple(args))
|
fp.write(fmt % tuple(args))
|
||||||
|
|
||||||
|
|
||||||
def _write_ply_triangle(fp, i0, i1, i2, binary):
|
def _write_ply_triangle(fp, i0, i1, i2, binary):
|
||||||
if binary:
|
if binary:
|
||||||
fp.write(struct.pack('<Biii', 3, i0, i1, i2))
|
fp.write(struct.pack('<Biii', 3, i0, i1, i2))
|
||||||
else:
|
else:
|
||||||
fp.write('3 %d %d %d\n' % (i0, i1, i2))
|
fp.write('3 %d %d %d\n' % (i0, i1, i2))
|
||||||
|
|
||||||
|
|
||||||
def _write_ply_header_line(fp, str, binary):
|
def _write_ply_header_line(fp, str, binary):
|
||||||
if binary:
|
if binary:
|
||||||
fp.write(str.encode())
|
fp.write(str.encode())
|
||||||
else:
|
else:
|
||||||
fp.write(str)
|
fp.write(str)
|
||||||
|
|
||||||
|
|
||||||
def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
|
def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
|
||||||
if verts.shape[1] != 3:
|
if verts.shape[1] != 3:
|
||||||
raise Exception('verts has to be of shape Nx3')
|
raise Exception('verts has to be of shape Nx3')
|
||||||
@ -88,6 +92,7 @@ def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
|
|||||||
for t in trias:
|
for t in trias:
|
||||||
_write_ply_triangle(fp, t[0], t[1], t[2], binary)
|
_write_ply_triangle(fp, t[0], t[1], t[2], binary)
|
||||||
|
|
||||||
|
|
||||||
def faces_to_triangles(faces):
|
def faces_to_triangles(faces):
|
||||||
new_faces = []
|
new_faces = []
|
||||||
for f in faces:
|
for f in faces:
|
||||||
@ -100,6 +105,7 @@ def faces_to_triangles(faces):
|
|||||||
raise Exception('unknown face count %d', f[0])
|
raise Exception('unknown face count %d', f[0])
|
||||||
return new_faces
|
return new_faces
|
||||||
|
|
||||||
|
|
||||||
def read_ply(path):
|
def read_ply(path):
|
||||||
with open(path, 'rb') as f:
|
with open(path, 'rb') as f:
|
||||||
# parse header
|
# parse header
|
||||||
@ -204,6 +210,7 @@ def _read_obj_split_f(s):
|
|||||||
nidx = -1
|
nidx = -1
|
||||||
return vidx, tidx, nidx
|
return vidx, tidx, nidx
|
||||||
|
|
||||||
|
|
||||||
def read_obj(path):
|
def read_obj(path):
|
||||||
with open(path, 'r') as fp:
|
with open(path, 'r') as fp:
|
||||||
lines = fp.readlines()
|
lines = fp.readlines()
|
||||||
|
12
co/metric.py
12
co/metric.py
@ -1,6 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from . import geometry
|
from . import geometry
|
||||||
|
|
||||||
|
|
||||||
def _process_inputs(estimate, target, mask):
|
def _process_inputs(estimate, target, mask):
|
||||||
if estimate.shape != target.shape:
|
if estimate.shape != target.shape:
|
||||||
raise Exception('estimate and target have to be same shape')
|
raise Exception('estimate and target have to be same shape')
|
||||||
@ -12,19 +13,23 @@ def _process_inputs(estimate, target, mask):
|
|||||||
raise Exception('estimate and mask have to be same shape')
|
raise Exception('estimate and mask have to be same shape')
|
||||||
return estimate, target, mask
|
return estimate, target, mask
|
||||||
|
|
||||||
|
|
||||||
def mse(estimate, target, mask=None):
|
def mse(estimate, target, mask=None):
|
||||||
estimate, target, mask = _process_inputs(estimate, target, mask)
|
estimate, target, mask = _process_inputs(estimate, target, mask)
|
||||||
m = np.sum((estimate[mask] - target[mask]) ** 2) / mask.sum()
|
m = np.sum((estimate[mask] - target[mask]) ** 2) / mask.sum()
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
def rmse(estimate, target, mask=None):
|
def rmse(estimate, target, mask=None):
|
||||||
return np.sqrt(mse(estimate, target, mask))
|
return np.sqrt(mse(estimate, target, mask))
|
||||||
|
|
||||||
|
|
||||||
def mae(estimate, target, mask=None):
|
def mae(estimate, target, mask=None):
|
||||||
estimate, target, mask = _process_inputs(estimate, target, mask)
|
estimate, target, mask = _process_inputs(estimate, target, mask)
|
||||||
m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
|
m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
def outlier_fraction(estimate, target, mask=None, threshold=0):
|
def outlier_fraction(estimate, target, mask=None, threshold=0):
|
||||||
estimate, target, mask = _process_inputs(estimate, target, mask)
|
estimate, target, mask = _process_inputs(estimate, target, mask)
|
||||||
diff = np.abs(estimate[mask] - target[mask])
|
diff = np.abs(estimate[mask] - target[mask])
|
||||||
@ -52,6 +57,7 @@ class Metric(object):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])
|
return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])
|
||||||
|
|
||||||
|
|
||||||
class MultipleMetric(Metric):
|
class MultipleMetric(Metric):
|
||||||
def __init__(self, *metrics, **kwargs):
|
def __init__(self, *metrics, **kwargs):
|
||||||
self.metrics = [*metrics]
|
self.metrics = [*metrics]
|
||||||
@ -76,6 +82,7 @@ class MultipleMetric(Metric):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return '\n'.join([str(m) for m in self.metrics])
|
return '\n'.join([str(m) for m in self.metrics])
|
||||||
|
|
||||||
|
|
||||||
class BaseDistanceMetric(Metric):
|
class BaseDistanceMetric(Metric):
|
||||||
def __init__(self, name='', **kwargs):
|
def __init__(self, name='', **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@ -99,6 +106,7 @@ class BaseDistanceMetric(Metric):
|
|||||||
f'dist{self.name}_max': float(np.max(dists)),
|
f'dist{self.name}_max': float(np.max(dists)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class DistanceMetric(BaseDistanceMetric):
|
class DistanceMetric(BaseDistanceMetric):
|
||||||
def __init__(self, vec_length, p=2, **kwargs):
|
def __init__(self, vec_length, p=2, **kwargs):
|
||||||
super().__init__(name=f'{p}', **kwargs)
|
super().__init__(name=f'{p}', **kwargs)
|
||||||
@ -115,6 +123,7 @@ class DistanceMetric(BaseDistanceMetric):
|
|||||||
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
|
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
|
||||||
self.dists.append(dist)
|
self.dists.append(dist)
|
||||||
|
|
||||||
|
|
||||||
class OutlierFractionMetric(DistanceMetric):
|
class OutlierFractionMetric(DistanceMetric):
|
||||||
def __init__(self, thresholds, *args, **kwargs):
|
def __init__(self, thresholds, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -128,6 +137,7 @@ class OutlierFractionMetric(DistanceMetric):
|
|||||||
ret[f'of{t}'] = float(ma.sum() / ma.size)
|
ret[f'of{t}'] = float(ma.sum() / ma.size)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class RelativeDistanceMetric(BaseDistanceMetric):
|
class RelativeDistanceMetric(BaseDistanceMetric):
|
||||||
def __init__(self, vec_length, p=2, **kwargs):
|
def __init__(self, vec_length, p=2, **kwargs):
|
||||||
super().__init__(name=f'rel{p}', **kwargs)
|
super().__init__(name=f'rel{p}', **kwargs)
|
||||||
@ -144,6 +154,7 @@ class RelativeDistanceMetric(BaseDistanceMetric):
|
|||||||
dist = dist[ma != 0]
|
dist = dist[ma != 0]
|
||||||
self.dists.append(dist)
|
self.dists.append(dist)
|
||||||
|
|
||||||
|
|
||||||
class RotmDistanceMetric(BaseDistanceMetric):
|
class RotmDistanceMetric(BaseDistanceMetric):
|
||||||
def __init__(self, type='identity', **kwargs):
|
def __init__(self, type='identity', **kwargs):
|
||||||
super().__init__(name=type, **kwargs)
|
super().__init__(name=type, **kwargs)
|
||||||
@ -162,6 +173,7 @@ class RotmDistanceMetric(BaseDistanceMetric):
|
|||||||
else:
|
else:
|
||||||
raise Exception('invalid distance type')
|
raise Exception('invalid distance type')
|
||||||
|
|
||||||
|
|
||||||
class QuaternionDistanceMetric(BaseDistanceMetric):
|
class QuaternionDistanceMetric(BaseDistanceMetric):
|
||||||
def __init__(self, type='angle', **kwargs):
|
def __init__(self, type='angle', **kwargs):
|
||||||
super().__init__(name=type, **kwargs)
|
super().__init__(name=type, **kwargs)
|
||||||
|
@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
def save(path, remove_axis=False, dpi=300, fig=None):
|
def save(path, remove_axis=False, dpi=300, fig=None):
|
||||||
if fig is None:
|
if fig is None:
|
||||||
fig = plt.gcf()
|
fig = plt.gcf()
|
||||||
@ -22,6 +23,7 @@ def save(path, remove_axis=False, dpi=300, fig=None):
|
|||||||
ax.yaxis.set_major_locator(plt.NullLocator())
|
ax.yaxis.set_major_locator(plt.NullLocator())
|
||||||
fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
|
fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
|
||||||
|
|
||||||
|
|
||||||
def color_map(im_, cmap='viridis', vmin=None, vmax=None):
|
def color_map(im_, cmap='viridis', vmin=None, vmax=None):
|
||||||
cm = plt.get_cmap(cmap)
|
cm = plt.get_cmap(cmap)
|
||||||
im = im_.copy()
|
im = im_.copy()
|
||||||
@ -38,6 +40,7 @@ def color_map(im_, cmap='viridis', vmin=None, vmax=None):
|
|||||||
im[mask, c] = 1
|
im[mask, c] = 1
|
||||||
return im
|
return im
|
||||||
|
|
||||||
|
|
||||||
def interactive_legend(leg=None, fig=None, all_axes=True):
|
def interactive_legend(leg=None, fig=None, all_axes=True):
|
||||||
if leg is None:
|
if leg is None:
|
||||||
leg = plt.legend()
|
leg = plt.legend()
|
||||||
@ -76,6 +79,7 @@ def interactive_legend(leg=None, fig=None, all_axes=True):
|
|||||||
|
|
||||||
fig.canvas.mpl_connect('pick_event', onpick)
|
fig.canvas.mpl_connect('pick_event', onpick)
|
||||||
|
|
||||||
|
|
||||||
def non_annoying_pause(interval, focus_figure=False):
|
def non_annoying_pause(interval, focus_figure=False):
|
||||||
# https://github.com/matplotlib/matplotlib/issues/11131
|
# https://github.com/matplotlib/matplotlib/issues/11131
|
||||||
backend = mpl.rcParams['backend']
|
backend = mpl.rcParams['backend']
|
||||||
@ -91,6 +95,7 @@ def non_annoying_pause(interval, focus_figure=False):
|
|||||||
return
|
return
|
||||||
time.sleep(interval)
|
time.sleep(interval)
|
||||||
|
|
||||||
|
|
||||||
def remove_all_ticks(fig=None):
|
def remove_all_ticks(fig=None):
|
||||||
if fig is None:
|
if fig is None:
|
||||||
fig = plt.gcf()
|
fig = plt.gcf()
|
||||||
|
@ -3,6 +3,7 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
from . import geometry
|
from . import geometry
|
||||||
|
|
||||||
|
|
||||||
def image_matrix(ims, bgval=0):
|
def image_matrix(ims, bgval=0):
|
||||||
n = ims.shape[0]
|
n = ims.shape[0]
|
||||||
m = int(np.ceil(np.sqrt(n)))
|
m = int(np.ceil(np.sqrt(n)))
|
||||||
@ -18,6 +19,7 @@ def image_matrix(ims, bgval=0):
|
|||||||
idx += 1
|
idx += 1
|
||||||
return mat
|
return mat
|
||||||
|
|
||||||
|
|
||||||
def image_cat(ims, vertical=False):
|
def image_cat(ims, vertical=False):
|
||||||
offx = [0]
|
offx = [0]
|
||||||
offy = [0]
|
offy = [0]
|
||||||
@ -40,15 +42,18 @@ def image_cat(ims, vertical=False):
|
|||||||
|
|
||||||
return im, offx, offy
|
return im, offx, offy
|
||||||
|
|
||||||
|
|
||||||
def line(li, h, w, ax=None, *args, **kwargs):
|
def line(li, h, w, ax=None, *args, **kwargs):
|
||||||
if ax is None:
|
if ax is None:
|
||||||
ax = plt.gca()
|
ax = plt.gca()
|
||||||
xs = (-li[2] - li[1] * np.array((0, h - 1))) / li[0]
|
xs = (-li[2] - li[1] * np.array((0, h - 1))) / li[0]
|
||||||
ys = (-li[2] - li[0] * np.array((0, w - 1))) / li[1]
|
ys = (-li[2] - li[0] * np.array((0, w - 1))) / li[1]
|
||||||
pts = np.array([(0, ys[0]), (w - 1, ys[1]), (xs[0], 0), (xs[1], h - 1)])
|
pts = np.array([(0, ys[0]), (w - 1, ys[1]), (xs[0], 0), (xs[1], h - 1)])
|
||||||
pts = pts[np.logical_and(np.logical_and(pts[:,0] >= 0, pts[:,0] < w), np.logical_and(pts[:,1] >= 0, pts[:,1] < h))]
|
pts = pts[
|
||||||
|
np.logical_and(np.logical_and(pts[:, 0] >= 0, pts[:, 0] < w), np.logical_and(pts[:, 1] >= 0, pts[:, 1] < h))]
|
||||||
ax.plot(pts[:, 0], pts[:, 1], *args, **kwargs)
|
ax.plot(pts[:, 0], pts[:, 1], *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def depthshow(depth, *args, ax=None, **kwargs):
|
def depthshow(depth, *args, ax=None, **kwargs):
|
||||||
if ax is None:
|
if ax is None:
|
||||||
ax = plt.gca()
|
ax = plt.gca()
|
||||||
|
22
co/plt3d.py
22
co/plt3d.py
@ -4,12 +4,15 @@ from mpl_toolkits.mplot3d import Axes3D
|
|||||||
|
|
||||||
from . import geometry
|
from . import geometry
|
||||||
|
|
||||||
|
|
||||||
def ax3d(fig=None):
|
def ax3d(fig=None):
|
||||||
if fig is None:
|
if fig is None:
|
||||||
fig = plt.gcf()
|
fig = plt.gcf()
|
||||||
return fig.add_subplot(111, projection='3d')
|
return fig.add_subplot(111, projection='3d')
|
||||||
|
|
||||||
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1, label=None, **kwargs):
|
|
||||||
|
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1,
|
||||||
|
label=None, **kwargs):
|
||||||
if ax is None:
|
if ax is None:
|
||||||
ax = plt.gca()
|
ax = plt.gca()
|
||||||
C0 = geometry.translation_to_cameracenter(R, t).ravel()
|
C0 = geometry.translation_to_cameracenter(R, t).ravel()
|
||||||
@ -20,11 +23,18 @@ def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', c
|
|||||||
|
|
||||||
if marker_C != '':
|
if marker_C != '':
|
||||||
ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs)
|
ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs)
|
||||||
ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
|
||||||
ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
linewidth=linewidth, **kwargs)
|
||||||
ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle,
|
||||||
ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
linewidth=linewidth, **kwargs)
|
||||||
ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]], [C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle,
|
||||||
|
linewidth=linewidth, **kwargs)
|
||||||
|
ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle,
|
||||||
|
linewidth=linewidth, **kwargs)
|
||||||
|
ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]],
|
||||||
|
[C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
|
||||||
|
linewidth=linewidth, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def axis_equal(ax=None):
|
def axis_equal(ax=None):
|
||||||
if ax is None:
|
if ax is None:
|
||||||
|
24
co/table.py
24
co/table.py
@ -3,6 +3,7 @@ import pandas as pd
|
|||||||
import enum
|
import enum
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
|
|
||||||
class Table(object):
|
class Table(object):
|
||||||
def __init__(self, n_cols):
|
def __init__(self, n_cols):
|
||||||
self.n_cols = n_cols
|
self.n_cols = n_cols
|
||||||
@ -44,6 +45,7 @@ class Table(object):
|
|||||||
for c in range(len(cols)):
|
for c in range(len(cols)):
|
||||||
self.rows[row + r].cells[col + c] = Cell(data[r][c], fmt)
|
self.rows[row + r].cells[col + c] = Cell(data[r][c], fmt)
|
||||||
|
|
||||||
|
|
||||||
class Row(object):
|
class Row(object):
|
||||||
def __init__(self, cells, pre_separator=None, post_separator=None):
|
def __init__(self, cells, pre_separator=None, post_separator=None):
|
||||||
self.cells = cells
|
self.cells = cells
|
||||||
@ -61,7 +63,6 @@ class Row(object):
|
|||||||
return sum([c.span for c in self.cells])
|
return sum([c.span for c in self.cells])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Color(object):
|
class Color(object):
|
||||||
def __init__(self, color=(0, 0, 0), fmt='rgb'):
|
def __init__(self, color=(0, 0, 0), fmt='rgb'):
|
||||||
if fmt == 'rgb':
|
if fmt == 'rgb':
|
||||||
@ -93,6 +94,7 @@ class CellFormat(object):
|
|||||||
self.bgcolor = bgcolor
|
self.bgcolor = bgcolor
|
||||||
self.bold = bold
|
self.bold = bold
|
||||||
|
|
||||||
|
|
||||||
class Cell(object):
|
class Cell(object):
|
||||||
def __init__(self, data=None, fmt=None, span=1, align=None):
|
def __init__(self, data=None, fmt=None, span=1, align=None):
|
||||||
self.data = data
|
self.data = data
|
||||||
@ -105,6 +107,7 @@ class Cell(object):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.fmt.fmt % self.data
|
return self.fmt.fmt % self.data
|
||||||
|
|
||||||
|
|
||||||
class Separator(enum.Enum):
|
class Separator(enum.Enum):
|
||||||
HEAD = 1
|
HEAD = 1
|
||||||
BOTTOM = 2
|
BOTTOM = 2
|
||||||
@ -143,6 +146,7 @@ class Renderer(object):
|
|||||||
with open(path, 'w') as fp:
|
with open(path, 'w') as fp:
|
||||||
fp.write(txt)
|
fp.write(txt)
|
||||||
|
|
||||||
|
|
||||||
class TerminalRenderer(Renderer):
|
class TerminalRenderer(Renderer):
|
||||||
def __init__(self, col_sep=' '):
|
def __init__(self, col_sep=' '):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -207,6 +211,7 @@ class TerminalRenderer(Renderer):
|
|||||||
lines.append(sepline)
|
lines.append(sepline)
|
||||||
return '\n'.join(lines)
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
|
||||||
class MarkdownRenderer(TerminalRenderer):
|
class MarkdownRenderer(TerminalRenderer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(col_sep='|')
|
super().__init__(col_sep='|')
|
||||||
@ -313,6 +318,7 @@ class LatexRenderer(Renderer):
|
|||||||
lines.append('\\end{tabular}')
|
lines.append('\\end{tabular}')
|
||||||
return '\n'.join(lines)
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
|
||||||
class HtmlRenderer(Renderer):
|
class HtmlRenderer(Renderer):
|
||||||
def __init__(self, html_class='result_table'):
|
def __init__(self, html_class='result_table'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -331,10 +337,14 @@ class HtmlRenderer(Renderer):
|
|||||||
color = cell.fmt.bgcolor.as_RGB()
|
color = cell.fmt.bgcolor.as_RGB()
|
||||||
styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});')
|
styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});')
|
||||||
align = table.get_cell_align(row, col)
|
align = table.get_cell_align(row, col)
|
||||||
if align == 'l': align = 'left'
|
if align == 'l':
|
||||||
elif align == 'r': align = 'right'
|
align = 'left'
|
||||||
elif align == 'c': align = 'center'
|
elif align == 'r':
|
||||||
else: raise Exception('invalid align')
|
align = 'right'
|
||||||
|
elif align == 'c':
|
||||||
|
align = 'center'
|
||||||
|
else:
|
||||||
|
raise Exception('invalid align')
|
||||||
styles.append(f'text-align: {align};')
|
styles.append(f'text-align: {align};')
|
||||||
row = table.rows[row]
|
row = table.rows[row]
|
||||||
if row.pre_separator is not None:
|
if row.pre_separator is not None:
|
||||||
@ -365,7 +375,8 @@ class HtmlRenderer(Renderer):
|
|||||||
return '\n'.join(lines)
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
|
||||||
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'), best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
|
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'),
|
||||||
|
best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
|
||||||
rnames = data[rowname].unique()
|
rnames = data[rowname].unique()
|
||||||
cnames = data[colname].unique()
|
cnames = data[colname].unique()
|
||||||
tab = Table(1 + len(cnames))
|
tab = Table(1 + len(cnames))
|
||||||
@ -395,7 +406,6 @@ def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt
|
|||||||
return tab
|
return tab
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# df = pd.read_pickle('full.df')
|
# df = pd.read_pickle('full.df')
|
||||||
# best_is_max = ['movF0.5', 'movF1.0']
|
# best_is_max = ['movF0.5', 'movF1.0']
|
||||||
|
@ -8,6 +8,7 @@ import re
|
|||||||
import pickle
|
import pickle
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
|
|
||||||
def str2bool(v):
|
def str2bool(v):
|
||||||
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
||||||
return True
|
return True
|
||||||
@ -16,6 +17,7 @@ def str2bool(v):
|
|||||||
else:
|
else:
|
||||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||||
|
|
||||||
|
|
||||||
class StopWatch(object):
|
class StopWatch(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.timings = OrderedDict()
|
self.timings = OrderedDict()
|
||||||
@ -40,9 +42,11 @@ class StopWatch(object):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
||||||
|
|
||||||
|
|
||||||
class ETA(object):
|
class ETA(object):
|
||||||
def __init__(self, length):
|
def __init__(self, length):
|
||||||
self.length = length
|
self.length = length
|
||||||
@ -76,6 +80,7 @@ class ETA(object):
|
|||||||
def get_remaining_time_str(self):
|
def get_remaining_time_str(self):
|
||||||
return self.format_time(self.get_remaining_time())
|
return self.format_time(self.get_remaining_time())
|
||||||
|
|
||||||
|
|
||||||
def git_hash(cwd=None):
|
def git_hash(cwd=None):
|
||||||
ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||||
hash = ret.stdout
|
hash = ret.stdout
|
||||||
@ -83,4 +88,3 @@ def git_hash(cwd=None):
|
|||||||
return hash.decode().strip()
|
return hash.decode().strip()
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ def get_patterns(path='syn', imsizes=[], crop=True):
|
|||||||
|
|
||||||
return patterns
|
return patterns
|
||||||
|
|
||||||
|
|
||||||
def get_rotation_matrix(v0, v1):
|
def get_rotation_matrix(v0, v1):
|
||||||
v0 = v0 / np.linalg.norm(v0)
|
v0 = v0 / np.linalg.norm(v0)
|
||||||
v1 = v1 / np.linalg.norm(v1)
|
v1 = v1 / np.linalg.norm(v1)
|
||||||
@ -44,7 +45,6 @@ def get_rotation_matrix(v0, v1):
|
|||||||
|
|
||||||
|
|
||||||
def augment_image(img, rng, disp=None, grad=None, max_shift=64, max_blur=1.5, max_noise=10.0, max_sp_noise=0.001):
|
def augment_image(img, rng, disp=None, grad=None, max_shift=64, max_blur=1.5, max_noise=10.0, max_sp_noise=0.001):
|
||||||
|
|
||||||
# get min/max values of image
|
# get min/max values of image
|
||||||
min_val = np.min(img)
|
min_val = np.min(img)
|
||||||
max_val = np.max(img)
|
max_val = np.max(img)
|
||||||
@ -64,8 +64,10 @@ def augment_image(img,rng,disp=None,grad=None,max_shift=64,max_blur=1.5,max_nois
|
|||||||
shear = 0
|
shear = 0
|
||||||
shift = 0
|
shift = 0
|
||||||
shear_correction = 0
|
shear_correction = 0
|
||||||
if rng.uniform(0,1)<0.75: shear = rng.uniform(-max_shift,max_shift) # shear with 75% probability
|
if rng.uniform(0, 1) < 0.75:
|
||||||
else: shift = rng.uniform(0,max_shift) # shift with 25% probability
|
shear = rng.uniform(-max_shift, max_shift) # shear with 75% probability
|
||||||
|
else:
|
||||||
|
shift = rng.uniform(0, max_shift) # shift with 25% probability
|
||||||
if shear < 0: shear_correction = -shear
|
if shear < 0: shear_correction = -shear
|
||||||
|
|
||||||
# affine transformation
|
# affine transformation
|
||||||
|
@ -10,14 +10,15 @@ import cv2
|
|||||||
import os
|
import os
|
||||||
import collections
|
import collections
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append('../')
|
sys.path.append('../')
|
||||||
import renderer
|
import renderer
|
||||||
import co
|
import co
|
||||||
from commons import get_patterns, get_rotation_matrix
|
from commons import get_patterns, get_rotation_matrix
|
||||||
from lcn import lcn
|
from lcn import lcn
|
||||||
|
|
||||||
def get_objs(shapenet_dir, obj_classes, num_perclass=100):
|
|
||||||
|
|
||||||
|
def get_objs(shapenet_dir, obj_classes, num_perclass=100):
|
||||||
shapenet = {'chair': '03001627',
|
shapenet = {'chair': '03001627',
|
||||||
'airplane': '02691156',
|
'airplane': '02691156',
|
||||||
'car': '02958343',
|
'car': '02958343',
|
||||||
@ -88,7 +89,6 @@ def get_mesh(rng, min_z=0):
|
|||||||
|
|
||||||
|
|
||||||
def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4):
|
def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4):
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
rng = np.random.RandomState()
|
rng = np.random.RandomState()
|
||||||
|
|
||||||
@ -98,7 +98,6 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
|
|||||||
data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
|
data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
|
||||||
print(f'loading mesh for sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
|
print(f'loading mesh for sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
|
||||||
|
|
||||||
|
|
||||||
# let the camera point to the center
|
# let the camera point to the center
|
||||||
center = np.array([0, 0, 3], dtype=np.float32)
|
center = np.array([0, 0, 3], dtype=np.float32)
|
||||||
|
|
||||||
@ -148,7 +147,6 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
|
|||||||
cams.append(renderer.PyCamera(fx, fy, px, py, Rcam, tcam, im_width, im_height))
|
cams.append(renderer.PyCamera(fx, fy, px, py, Rcam, tcam, im_width, im_height))
|
||||||
projs.append(renderer.PyCamera(fx, fy, px, py, Rproj, tproj, im_width, im_height))
|
projs.append(renderer.PyCamera(fx, fy, px, py, Rproj, tproj, im_width, im_height))
|
||||||
|
|
||||||
|
|
||||||
for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns):
|
for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns):
|
||||||
fl = K[0, 0] / (2 ** s)
|
fl = K[0, 0] / (2 ** s)
|
||||||
|
|
||||||
@ -203,7 +201,6 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
|
|||||||
print(f'create sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
|
print(f'create sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
np.random.seed(42)
|
np.random.seed(42)
|
||||||
@ -236,7 +233,8 @@ if __name__=='__main__':
|
|||||||
imsize = (488, 648)
|
imsize = (488, 648)
|
||||||
imsizes = [(imsize[0] // (2 ** s), imsize[1] // (2 ** s)) for s in range(4)]
|
imsizes = [(imsize[0] // (2 ** s), imsize[1] // (2 ** s)) for s in range(4)]
|
||||||
# K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32)
|
# K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32)
|
||||||
K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0 ,0, 1]], dtype=np.float32)
|
K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0, 0, 1]],
|
||||||
|
dtype=np.float32)
|
||||||
focal_lengths = [K[0, 0] / (2 ** s) for s in range(4)]
|
focal_lengths = [K[0, 0] / (2 ** s) for s in range(4)]
|
||||||
baseline = 0.075
|
baseline = 0.075
|
||||||
blend_im = 0.6
|
blend_im = 0.6
|
||||||
|
@ -21,11 +21,13 @@ from .commons import get_patterns, augment_image
|
|||||||
|
|
||||||
from mpl_toolkits.mplot3d import Axes3D
|
from mpl_toolkits.mplot3d import Axes3D
|
||||||
|
|
||||||
|
|
||||||
class TrackSynDataset(torchext.BaseDataset):
|
class TrackSynDataset(torchext.BaseDataset):
|
||||||
'''
|
'''
|
||||||
Load locally saved synthetic dataset
|
Load locally saved synthetic dataset
|
||||||
Please run ./create_syn_data.sh to generate the dataset
|
Please run ./create_syn_data.sh to generate the dataset
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False):
|
def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False):
|
||||||
super().__init__(train=train)
|
super().__init__(train=train)
|
||||||
|
|
||||||
@ -111,7 +113,8 @@ class TrackSynDataset(torchext.BaseDataset):
|
|||||||
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i, 0], rng,
|
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i, 0], rng,
|
||||||
disp=disp[i, 0], grad=grad[i, 0],
|
disp=disp[i, 0], grad=grad[i, 0],
|
||||||
max_shift=self.max_shift, max_blur=self.max_blur,
|
max_shift=self.max_shift, max_blur=self.max_blur,
|
||||||
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
|
max_noise=self.max_noise,
|
||||||
|
max_sp_noise=self.max_sp_noise)
|
||||||
img_aug[i] = img_aug_[None].astype(np.float32)
|
img_aug[i] = img_aug_[None].astype(np.float32)
|
||||||
disp_aug[i] = disp_aug_[None].astype(np.float32)
|
disp_aug[i] = disp_aug_[None].astype(np.float32)
|
||||||
grad_aug[i] = grad_aug_[None].astype(np.float32)
|
grad_aug[i] = grad_aug_[None].astype(np.float32)
|
||||||
@ -133,7 +136,6 @@ class TrackSynDataset(torchext.BaseDataset):
|
|||||||
if key != 'blend_im' and key != 'id':
|
if key != 'blend_im' and key != 'id':
|
||||||
ret[key] = val[0]
|
ret[key] = val[0]
|
||||||
|
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def getK(self, sidx=0):
|
def getK(self, sidx=0):
|
||||||
@ -142,7 +144,5 @@ class TrackSynDataset(torchext.BaseDataset):
|
|||||||
return K
|
return K
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -355,6 +355,7 @@ body.cython { font-family: courier; font-size: 12; }
|
|||||||
.cython .vi { color: #19177C } /* Name.Variable.Instance */
|
.cython .vi { color: #19177C } /* Name.Variable.Instance */
|
||||||
.cython .vm { color: #19177C } /* Name.Variable.Magic */
|
.cython .vm { color: #19177C } /* Name.Variable.Magic */
|
||||||
.cython .il { color: #666666 } /* Literal.Number.Integer.Long */
|
.cython .il { color: #666666 } /* Literal.Number.Integer.Long */
|
||||||
|
|
||||||
</style>
|
</style>
|
||||||
</head>
|
</head>
|
||||||
<body class="cython">
|
<body class="cython">
|
||||||
@ -364,8 +365,13 @@ body.cython { font-family: courier; font-size: 12; }
|
|||||||
Click on a line that starts with a "<code>+</code>" to see the C code that Cython generated for it.
|
Click on a line that starts with a "<code>+</code>" to see the C code that Cython generated for it.
|
||||||
</p>
|
</p>
|
||||||
<p>Raw output: <a href="lcn.c">lcn.c</a></p>
|
<p>Raw output: <a href="lcn.c">lcn.c</a></p>
|
||||||
<div class="cython"><pre class="cython line score-16" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre>
|
<div class="cython">
|
||||||
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
|
<pre class="cython line score-16"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span
|
||||||
|
class="k">as</span> <span class="nn">np</span></pre>
|
||||||
|
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_np, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
|
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_np, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
@ -374,21 +380,38 @@ body.cython { font-family: courier; font-size: 12; }
|
|||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_test, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
|
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_test, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
</pre><pre class="cython line score-0"> <span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">03</span>: </pre>
|
<pre class="cython line score-0"> <span class="">03</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">04</span>: <span class="c"># use c square root function</span></pre>
|
<pre class="cython line score-0"> <span class="">04</span>: <span class="c"># use c square root function</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">05</span>: <span class="k">cdef</span> <span class="kr">extern</span> <span class="k">from</span> <span class="s">"math.h"</span><span class="p">:</span></pre>
|
<pre class="cython line score-0"> <span class="">05</span>: <span class="k">cdef</span> <span
|
||||||
<pre class="cython line score-0"> <span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre>
|
class="kr">extern</span> <span class="k">from</span> <span class="s">"math.h"</span><span
|
||||||
|
class="p">:</span></pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span
|
||||||
|
class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">07</span>: </pre>
|
<pre class="cython line score-0"> <span class="">07</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">08</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre>
|
<pre class="cython line score-0"> <span class="">08</span>: <span class="nd">@cython</span><span
|
||||||
<pre class="cython line score-0"> <span class="">09</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre>
|
class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span
|
||||||
<pre class="cython line score-0"> <span class="">10</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span class="p">)</span></pre>
|
class="bp">False</span><span class="p">)</span></pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">09</span>: <span class="nd">@cython</span><span
|
||||||
|
class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span
|
||||||
|
class="bp">False</span><span class="p">)</span></pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">10</span>: <span class="nd">@cython</span><span
|
||||||
|
class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span
|
||||||
|
class="p">)</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">11</span>: </pre>
|
<pre class="cython line score-0"> <span class="">11</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">12</span>: <span class="c"># 3 parameters:</span></pre>
|
<pre class="cython line score-0"> <span class="">12</span>: <span class="c"># 3 parameters:</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">13</span>: <span class="c"># - float image</span></pre>
|
<pre class="cython line score-0"> <span class="">13</span>: <span class="c"># - float image</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre>
|
<pre class="cython line score-0"> <span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre>
|
<pre class="cython line score-0"> <span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre>
|
||||||
<pre class="cython line score-67" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">):</span></pre>
|
<pre class="cython line score-67"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span
|
||||||
|
class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span
|
||||||
|
class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span
|
||||||
|
class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span
|
||||||
|
class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span
|
||||||
|
class="p">):</span></pre>
|
||||||
<pre class='cython code score-67 '>/* Python wrapper */
|
<pre class='cython code score-67 '>/* Python wrapper */
|
||||||
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
|
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
|
||||||
static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0};
|
static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0};
|
||||||
@ -434,7 +457,8 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (unlikely(kw_args > 0)) {
|
if (unlikely(kw_args > 0)) {
|
||||||
if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") < 0)) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") < 0)) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
switch (<span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)) {
|
switch (<span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)) {
|
||||||
@ -447,21 +471,27 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
|
|||||||
default: goto __pyx_L5_argtuple_error;
|
default: goto __pyx_L5_argtuple_error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
|
__pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||||
if (values[1]) {
|
if (values[1]) {
|
||||||
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) && <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) && <span
|
||||||
|
class='py_c_api'>PyErr_Occurred</span>())) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||||
} else {
|
} else {
|
||||||
__pyx_v_kernel_size = ((int)4);
|
__pyx_v_kernel_size = ((int)4);
|
||||||
}
|
}
|
||||||
if (values[2]) {
|
if (values[2]) {
|
||||||
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) && <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) && <span
|
||||||
|
class='py_c_api'>PyErr_Occurred</span>())) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||||
} else {
|
} else {
|
||||||
__pyx_v_epsilon = ((float)0.01);
|
__pyx_v_epsilon = ((float)0.01);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
goto __pyx_L4_argument_unpacking_done;
|
goto __pyx_L4_argument_unpacking_done;
|
||||||
__pyx_L5_argtuple_error:;
|
__pyx_L5_argtuple_error:;
|
||||||
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||||
__pyx_L3_error:;
|
__pyx_L3_error:;
|
||||||
<span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
|
<span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
|
||||||
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
|
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
|
||||||
@ -515,27 +545,49 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
|
|||||||
return __pyx_r;
|
return __pyx_r;
|
||||||
}
|
}
|
||||||
/* … */
|
/* … */
|
||||||
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19);
|
||||||
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19);
|
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19);
|
||||||
/* … */
|
/* … */
|
||||||
__pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
__pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
|
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) < 0) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span
|
||||||
</pre><pre class="cython line score-0"> <span class="">17</span>: </pre>
|
class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||||
|
</pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">17</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">18</span>: <span class="c"># image dimensions</span></pre>
|
<pre class="cython line score-0"> <span class="">18</span>: <span class="c"># image dimensions</span></pre>
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
|
||||||
|
class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
|
||||||
|
class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
|
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
|
||||||
|
class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
|
||||||
|
class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
|
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
|
||||||
</pre><pre class="cython line score-0"> <span class="">21</span>: </pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">21</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
|
<pre class="cython line score-0"> <span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
|
||||||
<pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre>
|
<pre class="cython line score-46"
|
||||||
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span
|
||||||
|
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
|
||||||
|
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
|
||||||
|
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
|
||||||
|
class="n">float32</span><span class="p">)</span></pre>
|
||||||
|
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
__pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
__pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
@ -559,22 +611,34 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
|
|||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
|
||||||
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) < 0) <span class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
|
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) < 0) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
||||||
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
|
||||||
__pyx_v_img_lcn = __pyx_t_5;
|
__pyx_v_img_lcn = __pyx_t_5;
|
||||||
__pyx_t_5 = 0;
|
__pyx_t_5 = 0;
|
||||||
</pre><pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
<pre class="cython line score-46"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span
|
||||||
|
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
|
||||||
|
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
|
||||||
|
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
|
||||||
|
class="n">float32</span><span class="p">)</span></pre>
|
||||||
|
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||||
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
||||||
__pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
__pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
@ -598,87 +662,187 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
|
|||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
|
||||||
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||||
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
||||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
|
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) < 0) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
|
||||||
__pyx_v_img_std = __pyx_t_1;
|
__pyx_v_img_std = __pyx_t_1;
|
||||||
__pyx_t_1 = 0;
|
__pyx_t_1 = 0;
|
||||||
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span class="n">img_lcn</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
|
<pre class="cython line score-2"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
|
||||||
|
class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span
|
||||||
|
class="n">img_lcn</span></pre>
|
||||||
|
<pre class='cython code score-2 '> __pyx_t_6 = <span
|
||||||
|
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
|
||||||
__pyx_v_img_lcn_view = __pyx_t_6;
|
__pyx_v_img_lcn_view = __pyx_t_6;
|
||||||
__pyx_t_6.memview = NULL;
|
__pyx_t_6.memview = NULL;
|
||||||
__pyx_t_6.data = NULL;
|
__pyx_t_6.data = NULL;
|
||||||
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span class="n">img_std</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
|
<pre class="cython line score-2"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
|
||||||
|
class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span
|
||||||
|
class="n">img_std</span></pre>
|
||||||
|
<pre class='cython code score-2 '> __pyx_t_6 = <span
|
||||||
|
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
|
||||||
__pyx_v_img_std_view = __pyx_t_6;
|
__pyx_v_img_std_view = __pyx_t_6;
|
||||||
__pyx_t_6.memview = NULL;
|
__pyx_t_6.memview = NULL;
|
||||||
__pyx_t_6.data = NULL;
|
__pyx_t_6.data = NULL;
|
||||||
</pre><pre class="cython line score-0"> <span class="">27</span>: </pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">27</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">28</span>: <span class="c"># temporary c variables</span></pre>
|
<pre class="cython line score-0"> <span class="">28</span>: <span class="c"># temporary c variables</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span class="nf">stddev</span></pre>
|
<pre class="cython line score-0"> <span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
|
||||||
<pre class="cython line score-0"> <span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
|
class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
|
class="nf">stddev</span></pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
|
||||||
|
class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span
|
||||||
|
class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
|
||||||
|
class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
|
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
|
||||||
|
class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
|
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span class="o">**</span><span class="mf">2</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
|
||||||
|
class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span
|
||||||
|
class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span
|
||||||
|
class="o">**</span><span class="mf">2</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
|
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
|
||||||
</pre><pre class="cython line score-0"> <span class="">34</span>: </pre>
|
</pre>
|
||||||
<pre class="cython line score-0"> <span class="">35</span>: <span class="c"># for all pixels do</span></pre>
|
<pre class="cython line score-0"> <span class="">34</span>: </pre>
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
|
<pre class="cython line score-0"> <span class="">35</span>: <span
|
||||||
|
class="c"># for all pixels do</span></pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span
|
||||||
|
class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span
|
||||||
|
class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
|
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
|
||||||
__pyx_t_8 = __pyx_t_7;
|
__pyx_t_8 = __pyx_t_7;
|
||||||
for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 < __pyx_t_8; __pyx_t_9+=1) {
|
for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 < __pyx_t_8; __pyx_t_9+=1) {
|
||||||
__pyx_v_m = __pyx_t_9;
|
__pyx_v_m = __pyx_t_9;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span
|
||||||
|
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span
|
||||||
|
class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
|
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
|
||||||
__pyx_t_11 = __pyx_t_10;
|
__pyx_t_11 = __pyx_t_10;
|
||||||
for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 < __pyx_t_11; __pyx_t_12+=1) {
|
for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 < __pyx_t_11; __pyx_t_12+=1) {
|
||||||
__pyx_v_n = __pyx_t_12;
|
__pyx_v_n = __pyx_t_12;
|
||||||
</pre><pre class="cython line score-0"> <span class="">38</span>: </pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">38</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">39</span>: <span class="c"># calculate mean</span></pre>
|
<pre class="cython line score-0"> <span class="">39</span>: <span class="c"># calculate mean</span></pre>
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre>
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span
|
||||||
|
class="mf">0</span><span class="p">;</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
|
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span
|
||||||
|
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
|
||||||
|
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
|
||||||
|
class="mf">1</span><span class="p">):</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
|
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
|
||||||
__pyx_t_14 = __pyx_t_13;
|
__pyx_t_14 = __pyx_t_13;
|
||||||
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
|
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
|
||||||
__pyx_v_i = __pyx_t_15;
|
__pyx_v_i = __pyx_t_15;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span
|
||||||
|
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
|
||||||
|
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
|
||||||
|
class="mf">1</span><span class="p">):</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
|
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
|
||||||
__pyx_t_17 = __pyx_t_16;
|
__pyx_t_17 = __pyx_t_16;
|
||||||
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
|
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
|
||||||
__pyx_v_j = __pyx_t_18;
|
__pyx_v_j = __pyx_t_18;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span
|
||||||
|
class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
|
||||||
|
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
|
||||||
|
class="p">]</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
|
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
|
||||||
__pyx_t_20 = (__pyx_v_n + __pyx_v_j);
|
__pyx_t_20 = (__pyx_v_n + __pyx_v_j);
|
||||||
__pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) ))));
|
__pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) ))));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span
|
||||||
|
class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
|
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
|
||||||
</pre><pre class="cython line score-0"> <span class="">45</span>: </pre>
|
</pre>
|
||||||
<pre class="cython line score-0"> <span class="">46</span>: <span class="c"># calculate std dev</span></pre>
|
<pre class="cython line score-0"> <span class="">45</span>: </pre>
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre>
|
<pre class="cython line score-0"> <span class="">46</span>: <span
|
||||||
|
class="c"># calculate std dev</span></pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span
|
||||||
|
class="mf">0</span><span class="p">;</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
|
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span
|
||||||
|
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
|
||||||
|
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
|
||||||
|
class="mf">1</span><span class="p">):</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
|
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
|
||||||
__pyx_t_14 = __pyx_t_13;
|
__pyx_t_14 = __pyx_t_13;
|
||||||
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
|
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
|
||||||
__pyx_v_i = __pyx_t_15;
|
__pyx_v_i = __pyx_t_15;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span
|
||||||
|
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
|
||||||
|
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
|
||||||
|
class="mf">1</span><span class="p">):</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
|
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
|
||||||
__pyx_t_17 = __pyx_t_16;
|
__pyx_t_17 = __pyx_t_16;
|
||||||
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
|
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
|
||||||
__pyx_v_j = __pyx_t_18;
|
__pyx_v_j = __pyx_t_18;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span
|
||||||
|
class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span
|
||||||
|
class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
|
||||||
|
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
|
||||||
|
class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span
|
||||||
|
class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span
|
||||||
|
class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span
|
||||||
|
class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
|
||||||
|
class="p">)</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
|
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
|
||||||
__pyx_t_22 = (__pyx_v_n + __pyx_v_j);
|
__pyx_t_22 = (__pyx_v_n + __pyx_v_j);
|
||||||
__pyx_t_23 = (__pyx_v_m + __pyx_v_i);
|
__pyx_t_23 = (__pyx_v_m + __pyx_v_i);
|
||||||
@ -686,25 +850,47 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
|
|||||||
__pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean)));
|
__pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span class="n">num</span><span class="p">)</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span
|
||||||
|
class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span
|
||||||
|
class="n">num</span><span class="p">)</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
|
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
|
||||||
</pre><pre class="cython line score-0"> <span class="">52</span>: </pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">52</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
|
<pre class="cython line score-0"> <span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span
|
||||||
|
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
|
||||||
|
class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span
|
||||||
|
class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
|
||||||
|
class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span
|
||||||
|
class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
|
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
|
||||||
__pyx_t_26 = __pyx_v_n;
|
__pyx_t_26 = __pyx_v_n;
|
||||||
__pyx_t_27 = __pyx_v_m;
|
__pyx_t_27 = __pyx_v_m;
|
||||||
__pyx_t_28 = __pyx_v_n;
|
__pyx_t_28 = __pyx_v_n;
|
||||||
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps));
|
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps));
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">stddev</span></pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span
|
||||||
|
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
|
||||||
|
class="n">stddev</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
|
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
|
||||||
__pyx_t_30 = __pyx_v_n;
|
__pyx_t_30 = __pyx_v_n;
|
||||||
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev;
|
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</pre><pre class="cython line score-0"> <span class="">56</span>: </pre>
|
</pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">56</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">57</span>: <span class="c"># return both</span></pre>
|
<pre class="cython line score-0"> <span class="">57</span>: <span class="c"># return both</span></pre>
|
||||||
<pre class="cython line score-10" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span class="n">img_std</span></pre>
|
<pre class="cython line score-10"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span
|
||||||
|
class="n">img_std</span></pre>
|
||||||
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
|
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
|
||||||
__pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span>
|
__pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
@ -717,4 +903,7 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
|
|||||||
__pyx_r = __pyx_t_1;
|
__pyx_r = __pyx_t_1;
|
||||||
__pyx_t_1 = 0;
|
__pyx_t_1 = 0;
|
||||||
goto __pyx_L0;
|
goto __pyx_L0;
|
||||||
</pre></div></body></html>
|
</pre>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
@ -24,7 +24,6 @@ def get_data(n, row_from, row_to, train):
|
|||||||
return ims, disps
|
return ims, disps
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
params = hd.TrainParams(
|
params = hd.TrainParams(
|
||||||
n_trees=4,
|
n_trees=4,
|
||||||
max_tree_depth=,
|
max_tree_depth=,
|
||||||
@ -52,9 +51,11 @@ for tree_depth in [8,10,12,14,16]:
|
|||||||
prefix = Path(f'./forests/{prefix}/')
|
prefix = Path(f'./forests/{prefix}/')
|
||||||
prefix.mkdir(parents=True, exist_ok=True)
|
prefix.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
|
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
|
||||||
|
forest_prefix=str(prefix / 'fr'))
|
||||||
|
|
||||||
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
|
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
|
||||||
|
forest_prefix=str(prefix / 'fr'))
|
||||||
|
|
||||||
np.save(str(prefix / 'ta.npy'), test_disps)
|
np.save(str(prefix / 'ta.npy'), test_disps)
|
||||||
np.save(str(prefix / 'es.npy'), es)
|
np.save(str(prefix / 'es.npy'), es)
|
||||||
|
@ -8,7 +8,6 @@ import os
|
|||||||
|
|
||||||
this_dir = os.path.dirname(__file__)
|
this_dir = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
|
||||||
extra_compile_args = ['-O3', '-std=c++11']
|
extra_compile_args = ['-O3', '-std=c++11']
|
||||||
extra_link_args = []
|
extra_link_args = []
|
||||||
|
|
||||||
@ -39,7 +38,3 @@ setup(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,10 +6,13 @@ orig = cv2.imread('disp_orig.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
|||||||
ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
||||||
es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
plt.figure()
|
plt.figure()
|
||||||
plt.subplot(2,2,1); plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
|
plt.subplot(2, 2, 1);
|
||||||
plt.subplot(2,2,2); plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
|
plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
|
||||||
plt.subplot(2,2,3); plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
|
plt.subplot(2, 2, 2);
|
||||||
plt.subplot(2,2,4); plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
|
plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
|
||||||
|
plt.subplot(2, 2, 3);
|
||||||
|
plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
|
||||||
|
plt.subplot(2, 2, 4);
|
||||||
|
plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
@ -12,9 +12,12 @@ import torchext
|
|||||||
from model import networks
|
from model import networks
|
||||||
from data import dataset
|
from data import dataset
|
||||||
|
|
||||||
|
|
||||||
class Worker(torchext.Worker):
|
class Worker(torchext.Worker):
|
||||||
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
|
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
|
||||||
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
|
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
|
||||||
|
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
|
||||||
|
save_frequency=save_frequency, **kwargs)
|
||||||
|
|
||||||
self.ms = args.ms
|
self.ms = args.ms
|
||||||
self.pattern_path = args.pattern_path
|
self.pattern_path = args.pattern_path
|
||||||
@ -22,7 +25,7 @@ class Worker(torchext.Worker):
|
|||||||
self.dp_weight = args.dp_weight
|
self.dp_weight = args.dp_weight
|
||||||
self.data_type = args.data_type
|
self.data_type = args.data_type
|
||||||
|
|
||||||
self.imsizes = [(480,640)]
|
self.imsizes = [(488, 648)]
|
||||||
for iter in range(3):
|
for iter in range(3):
|
||||||
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
|
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
|
||||||
|
|
||||||
@ -50,13 +53,15 @@ class Worker(torchext.Worker):
|
|||||||
self.eval_w = self.imsizes[0][1] - 13 - 140
|
self.eval_w = self.imsizes[0][1] - 13 - 140
|
||||||
|
|
||||||
def get_train_set(self):
|
def get_train_set(self):
|
||||||
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=1)
|
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
|
||||||
|
track_length=1)
|
||||||
|
|
||||||
return train_set
|
return train_set
|
||||||
|
|
||||||
def get_test_sets(self):
|
def get_test_sets(self):
|
||||||
test_sets = torchext.TestSets()
|
test_sets = torchext.TestSets()
|
||||||
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
|
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
|
||||||
|
track_length=1)
|
||||||
test_sets.append('simple', test_set, test_frequency=1)
|
test_sets.append('simple', test_set, test_frequency=1)
|
||||||
|
|
||||||
# initialize photometric loss modules according to image sizes
|
# initialize photometric loss modules according to image sizes
|
||||||
@ -161,43 +166,76 @@ class Worker(torchext.Worker):
|
|||||||
im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0]
|
im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0]
|
||||||
pattern_diff = np.abs(im_orig - pattern_proj)
|
pattern_diff = np.abs(im_orig - pattern_proj)
|
||||||
|
|
||||||
|
|
||||||
fig = plt.figure(figsize=(16, 16))
|
fig = plt.figure(figsize=(16, 16))
|
||||||
es_ = co.cmap.color_depth_map(es, scale=vmax)
|
es_ = co.cmap.color_depth_map(es, scale=vmax)
|
||||||
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
|
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
|
||||||
diff_ = co.cmap.color_error_image(diff, BGR=True)
|
diff_ = co.cmap.color_error_image(diff, BGR=True)
|
||||||
|
|
||||||
# plot disparities, ground truth disparity is shown only for reference
|
# plot disparities, ground truth disparity is shown only for reference
|
||||||
ax = plt.subplot(3,3,1); plt.imshow(es_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
|
ax = plt.subplot(3, 3, 1)
|
||||||
ax = plt.subplot(3,3,2); plt.imshow(gt_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
|
plt.imshow(es_[..., [2, 1, 0]])
|
||||||
ax = plt.subplot(3,3,3); plt.imshow(diff_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Err. {diff.mean():.5f}')
|
plt.xticks([])
|
||||||
|
plt.yticks([])
|
||||||
|
ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 2)
|
||||||
|
plt.imshow(gt_[..., [2, 1, 0]])
|
||||||
|
plt.xticks([])
|
||||||
|
plt.yticks([])
|
||||||
|
ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 3)
|
||||||
|
plt.imshow(diff_[..., [2, 1, 0]])
|
||||||
|
plt.xticks([])
|
||||||
|
plt.yticks([])
|
||||||
|
ax.set_title(f'Disparity Err. {diff.mean():.5f}')
|
||||||
|
|
||||||
# plot edges
|
# plot edges
|
||||||
edge = self.edge.to('cpu').numpy()[0, 0]
|
edge = self.edge.to('cpu').numpy()[0, 0]
|
||||||
edge_gt = self.edge_gt.to('cpu').numpy()[0, 0]
|
edge_gt = self.edge_gt.to('cpu').numpy()[0, 0]
|
||||||
edge_err = np.abs(edge - edge_gt)
|
edge_err = np.abs(edge - edge_gt)
|
||||||
ax = plt.subplot(3,3,4); plt.imshow(edge, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
|
ax = plt.subplot(3, 3, 4);
|
||||||
ax = plt.subplot(3,3,5); plt.imshow(edge_gt, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
|
plt.imshow(edge, cmap='gray');
|
||||||
ax = plt.subplot(3,3,6); plt.imshow(edge_err, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
|
||||||
|
ax = plt.subplot(3, 3, 5);
|
||||||
|
plt.imshow(edge_gt, cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
|
||||||
|
ax = plt.subplot(3, 3, 6);
|
||||||
|
plt.imshow(edge_err, cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
|
||||||
|
|
||||||
# plot normalized IR input and warped pattern
|
# plot normalized IR input and warped pattern
|
||||||
ax = plt.subplot(3,3,7); plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
|
ax = plt.subplot(3, 3, 7);
|
||||||
ax = plt.subplot(3,3,8); plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
|
plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
|
||||||
|
ax = plt.subplot(3, 3, 8);
|
||||||
|
plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
|
||||||
im_std = self.data['std0'].to('cpu').numpy()[0, 0]
|
im_std = self.data['std0'].to('cpu').numpy()[0, 0]
|
||||||
ax = plt.subplot(3,3,9); plt.imshow(im_std, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
|
ax = plt.subplot(3, 3, 9);
|
||||||
|
plt.imshow(im_std, cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(str(out_path))
|
plt.savefig(str(out_path))
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
|
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
|
||||||
if batch_idx % 512 == 0:
|
if batch_idx % 512 == 0:
|
||||||
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
|
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
|
||||||
es, gt, im, ma = self.numpy_in_out(output)
|
es, gt, im, ma = self.numpy_in_out(output)
|
||||||
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
|
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
|
||||||
|
|
||||||
|
|
||||||
def callback_test_start(self, epoch, set_idx):
|
def callback_test_start(self, epoch, set_idx):
|
||||||
self.metric = co.metric.MultipleMetric(
|
self.metric = co.metric.MultipleMetric(
|
||||||
co.metric.DistanceMetric(vec_length=1),
|
co.metric.DistanceMetric(vec_length=1),
|
||||||
@ -232,6 +270,5 @@ class Worker(torchext.Worker):
|
|||||||
return es, gt, im, ma
|
return es, gt, im, ma
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
pass
|
||||||
|
@ -12,9 +12,12 @@ import torchext
|
|||||||
from model import networks
|
from model import networks
|
||||||
from data import dataset
|
from data import dataset
|
||||||
|
|
||||||
|
|
||||||
class Worker(torchext.Worker):
|
class Worker(torchext.Worker):
|
||||||
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
|
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
|
||||||
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
|
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
|
||||||
|
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
|
||||||
|
save_frequency=save_frequency, **kwargs)
|
||||||
|
|
||||||
self.ms = args.ms
|
self.ms = args.ms
|
||||||
self.pattern_path = args.pattern_path
|
self.pattern_path = args.pattern_path
|
||||||
@ -52,14 +55,15 @@ class Worker(torchext.Worker):
|
|||||||
self.eval_h = self.imsizes[0][0] - 2 * 13
|
self.eval_h = self.imsizes[0][0] - 2 * 13
|
||||||
self.eval_w = self.imsizes[0][1] - 13 - 140
|
self.eval_w = self.imsizes[0][1] - 13 - 140
|
||||||
|
|
||||||
|
|
||||||
def get_train_set(self):
|
def get_train_set(self):
|
||||||
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length)
|
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
|
||||||
|
track_length=self.track_length)
|
||||||
return train_set
|
return train_set
|
||||||
|
|
||||||
def get_test_sets(self):
|
def get_test_sets(self):
|
||||||
test_sets = torchext.TestSets()
|
test_sets = torchext.TestSets()
|
||||||
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
|
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
|
||||||
|
track_length=1)
|
||||||
test_sets.append('simple', test_set, test_frequency=1)
|
test_sets.append('simple', test_set, test_frequency=1)
|
||||||
|
|
||||||
self.ph_losses = []
|
self.ph_losses = []
|
||||||
@ -231,23 +235,55 @@ class Worker(torchext.Worker):
|
|||||||
diff0 = co.cmap.color_error_image(diff[0], BGR=True)
|
diff0 = co.cmap.color_error_image(diff[0], BGR=True)
|
||||||
|
|
||||||
# plot disparities, ground truth disparity is shown only for reference
|
# plot disparities, ground truth disparity is shown only for reference
|
||||||
ax = plt.subplot(3,3,1); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
|
ax = plt.subplot(3, 3, 1);
|
||||||
ax = plt.subplot(3,3,2); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
|
plt.imshow(es0[..., [2, 1, 0]]);
|
||||||
ax = plt.subplot(3,3,3); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 2);
|
||||||
|
plt.imshow(gt0[..., [2, 1, 0]]);
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 3);
|
||||||
|
plt.imshow(diff0[..., [2, 1, 0]]);
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
|
||||||
|
|
||||||
# plot disparities of the second frame in the track if exists
|
# plot disparities of the second frame in the track if exists
|
||||||
if es.shape[0] >= 2:
|
if es.shape[0] >= 2:
|
||||||
es1 = co.cmap.color_depth_map(es[1], scale=vmax)
|
es1 = co.cmap.color_depth_map(es[1], scale=vmax)
|
||||||
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
|
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
|
||||||
diff1 = co.cmap.color_error_image(diff[1], BGR=True)
|
diff1 = co.cmap.color_error_image(diff[1], BGR=True)
|
||||||
ax = plt.subplot(3,3,4); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
|
ax = plt.subplot(3, 3, 4);
|
||||||
ax = plt.subplot(3,3,5); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
|
plt.imshow(es1[..., [2, 1, 0]]);
|
||||||
ax = plt.subplot(3,3,6); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 5);
|
||||||
|
plt.imshow(gt1[..., [2, 1, 0]]);
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 6);
|
||||||
|
plt.imshow(diff1[..., [2, 1, 0]]);
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
|
||||||
|
|
||||||
# plot normalized IR inputs
|
# plot normalized IR inputs
|
||||||
ax = plt.subplot(3,3,7); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
|
ax = plt.subplot(3, 3, 7);
|
||||||
|
plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
|
||||||
if es.shape[0] >= 2:
|
if es.shape[0] >= 2:
|
||||||
ax = plt.subplot(3,3,8); plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
|
ax = plt.subplot(3, 3, 8);
|
||||||
|
plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(str(out_path))
|
plt.savefig(str(out_path))
|
||||||
@ -294,5 +330,6 @@ class Worker(torchext.Worker):
|
|||||||
ma = np.reshape(ma[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
|
ma = np.reshape(ma[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
|
||||||
return es, gt, im, ma
|
return es, gt, im, ma
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
pass
|
||||||
|
@ -61,6 +61,7 @@ class OutputLayerFactory(object):
|
|||||||
pos: estimate the absolute location
|
pos: estimate the absolute location
|
||||||
pos_row: independently estimate the absolute location per row
|
pos_row: independently estimate the absolute location per row
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, type='disp', params={}):
|
def __init__(self, type='disp', params={}):
|
||||||
self.type = type
|
self.type = type
|
||||||
self.params = params
|
self.params = params
|
||||||
@ -118,12 +119,13 @@ class MultiLinear(TimedModule):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DispNetS(TimedModule):
|
class DispNetS(TimedModule):
|
||||||
'''
|
'''
|
||||||
Disparity Decoder based on DispNetS
|
Disparity Decoder based on DispNetS
|
||||||
'''
|
'''
|
||||||
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, channel_multiplier=1):
|
|
||||||
|
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False,
|
||||||
|
channel_multiplier=1):
|
||||||
super(DispNetS, self).__init__(mod_name='DispNetS')
|
super(DispNetS, self).__init__(mod_name='DispNetS')
|
||||||
|
|
||||||
self.output_ms = output_ms
|
self.output_ms = output_ms
|
||||||
@ -166,7 +168,6 @@ class DispNetS(TimedModule):
|
|||||||
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
|
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
|
||||||
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
|
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
|
||||||
|
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
|
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
|
||||||
@ -176,9 +177,11 @@ class DispNetS(TimedModule):
|
|||||||
|
|
||||||
def downsample_conv(self, in_planes, out_planes, kernel_size=3):
|
def downsample_conv(self, in_planes, out_planes, kernel_size=3):
|
||||||
if self.coordconv:
|
if self.coordconv:
|
||||||
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
|
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
|
||||||
|
padding=(kernel_size - 1) // 2)
|
||||||
else:
|
else:
|
||||||
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
|
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
|
||||||
|
padding=(kernel_size - 1) // 2)
|
||||||
return torch.nn.Sequential(
|
return torch.nn.Sequential(
|
||||||
conv,
|
conv,
|
||||||
torch.nn.ReLU(inplace=True),
|
torch.nn.ReLU(inplace=True),
|
||||||
@ -229,19 +232,22 @@ class DispNetS(TimedModule):
|
|||||||
disp4 = self.predict_disp4(out_iconv4)
|
disp4 = self.predict_disp4(out_iconv4)
|
||||||
|
|
||||||
out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
|
out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
|
||||||
disp4_up = self.crop_like(torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
|
disp4_up = self.crop_like(
|
||||||
|
torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
|
||||||
concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
|
concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
|
||||||
out_iconv3 = self.iconv3(concat3)
|
out_iconv3 = self.iconv3(concat3)
|
||||||
disp3 = self.predict_disp3(out_iconv3)
|
disp3 = self.predict_disp3(out_iconv3)
|
||||||
|
|
||||||
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
||||||
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
disp3_up = self.crop_like(
|
||||||
|
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
||||||
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
|
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
|
||||||
out_iconv2 = self.iconv2(concat2)
|
out_iconv2 = self.iconv2(concat2)
|
||||||
disp2 = self.predict_disp2(out_iconv2)
|
disp2 = self.predict_disp2(out_iconv2)
|
||||||
|
|
||||||
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
|
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
|
||||||
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
|
disp2_up = self.crop_like(
|
||||||
|
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
|
||||||
concat1 = torch.cat((out_upconv1, disp2_up), 1)
|
concat1 = torch.cat((out_upconv1, disp2_up), 1)
|
||||||
out_iconv1 = self.iconv1(concat1)
|
out_iconv1 = self.iconv1(concat1)
|
||||||
disp1 = self.predict_disp1(out_iconv1)
|
disp1 = self.predict_disp1(out_iconv1)
|
||||||
@ -256,6 +262,7 @@ class DispNetShallow(DispNetS):
|
|||||||
'''
|
'''
|
||||||
Edge Decoder based on DispNetS with fewer layers
|
Edge Decoder based on DispNetS with fewer layers
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
|
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
|
||||||
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
|
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
|
||||||
self.mod_name = 'DispNetShallow'
|
self.mod_name = 'DispNetShallow'
|
||||||
@ -274,13 +281,15 @@ class DispNetShallow(DispNetS):
|
|||||||
disp3 = self.predict_disp3(out_iconv3)
|
disp3 = self.predict_disp3(out_iconv3)
|
||||||
|
|
||||||
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
||||||
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
disp3_up = self.crop_like(
|
||||||
|
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
||||||
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
|
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
|
||||||
out_iconv2 = self.iconv2(concat2)
|
out_iconv2 = self.iconv2(concat2)
|
||||||
disp2 = self.predict_disp2(out_iconv2)
|
disp2 = self.predict_disp2(out_iconv2)
|
||||||
|
|
||||||
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
|
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
|
||||||
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
|
disp2_up = self.crop_like(
|
||||||
|
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
|
||||||
concat1 = torch.cat((out_upconv1, disp2_up), 1)
|
concat1 = torch.cat((out_upconv1, disp2_up), 1)
|
||||||
out_iconv1 = self.iconv1(concat1)
|
out_iconv1 = self.iconv1(concat1)
|
||||||
disp1 = self.predict_disp1(out_iconv1)
|
disp1 = self.predict_disp1(out_iconv1)
|
||||||
@ -295,10 +304,13 @@ class DispEdgeDecoders(TimedModule):
|
|||||||
'''
|
'''
|
||||||
Disparity Decoder and Edge Decoder
|
Disparity Decoder and Edge Decoder
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, *args, max_disp=128, **kwargs):
|
def __init__(self, *args, max_disp=128, **kwargs):
|
||||||
super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')
|
super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')
|
||||||
|
|
||||||
output_facs = [OutputLayerFactory( type='disp', params={ 'alpha': max_disp/(2**s), 'beta': 0, 'gamma': 1, 'offset': 3}) for s in range(4)]
|
output_facs = [
|
||||||
|
OutputLayerFactory(type='disp', params={'alpha': max_disp / (2 ** s), 'beta': 0, 'gamma': 1, 'offset': 3})
|
||||||
|
for s in range(4)]
|
||||||
self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)
|
self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)
|
||||||
|
|
||||||
output_facs = [OutputLayerFactory(type='linear') for s in range(4)]
|
output_facs = [OutputLayerFactory(type='linear') for s in range(4)]
|
||||||
@ -336,11 +348,11 @@ class PosToDepth(DispToDepth):
|
|||||||
return super().forward(disp)
|
return super().forward(disp)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RectifiedPatternSimilarityLoss(TimedModule):
|
class RectifiedPatternSimilarityLoss(TimedModule):
|
||||||
'''
|
'''
|
||||||
Photometric Loss
|
Photometric Loss
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
|
def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
|
||||||
super().__init__(mod_name='RectifiedPatternSimilarityLoss')
|
super().__init__(mod_name='RectifiedPatternSimilarityLoss')
|
||||||
self.im_height = im_height
|
self.im_height = im_height
|
||||||
@ -377,10 +389,12 @@ class RectifiedPatternSimilarityLoss(TimedModule):
|
|||||||
val = (mask * diff).sum() / mask.sum()
|
val = (mask * diff).sum() / mask.sum()
|
||||||
return val, pattern_proj
|
return val, pattern_proj
|
||||||
|
|
||||||
|
|
||||||
class DisparityLoss(TimedModule):
|
class DisparityLoss(TimedModule):
|
||||||
'''
|
'''
|
||||||
Disparity Loss
|
Disparity Loss
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(mod_name='DisparityLoss')
|
super().__init__(mod_name='DisparityLoss')
|
||||||
self.sobel = SobelFilter(norm=False)
|
self.sobel = SobelFilter(norm=False)
|
||||||
@ -412,11 +426,11 @@ class DisparityLoss(TimedModule):
|
|||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectionBaseLoss(TimedModule):
|
class ProjectionBaseLoss(TimedModule):
|
||||||
'''
|
'''
|
||||||
Base module of the Geometric Loss
|
Base module of the Geometric Loss
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, K, Ki, im_height, im_width):
|
def __init__(self, K, Ki, im_height, im_width):
|
||||||
super().__init__(mod_name='ProjectionBaseLoss')
|
super().__init__(mod_name='ProjectionBaseLoss')
|
||||||
|
|
||||||
@ -465,7 +479,6 @@ class ProjectionBaseLoss(TimedModule):
|
|||||||
uv = uv[:, :, :2] / (torch.nn.functional.relu(d) + 1e-12)
|
uv = uv[:, :, :2] / (torch.nn.functional.relu(d) + 1e-12)
|
||||||
return uv, d
|
return uv, d
|
||||||
|
|
||||||
|
|
||||||
def tforward(self, depth0, R0, t0, R1, t1):
|
def tforward(self, depth0, R0, t0, R1, t1):
|
||||||
xyz = self.unproject(depth0, R0, t0)
|
xyz = self.unproject(depth0, R0, t0)
|
||||||
return self.project(xyz, R1, t1)
|
return self.project(xyz, R1, t1)
|
||||||
@ -475,6 +488,7 @@ class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
|
|||||||
'''
|
'''
|
||||||
Geometric Loss
|
Geometric Loss
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, *args, clamp=-1):
|
def __init__(self, *args, clamp=-1):
|
||||||
super().__init__(*args)
|
super().__init__(*args)
|
||||||
self.mod_name = 'ProjectionDepthSimilarityLoss'
|
self.mod_name = 'ProjectionDepthSimilarityLoss'
|
||||||
@ -503,11 +517,11 @@ class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
|
|||||||
return l0 + l1
|
return l0 + l1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LCN(TimedModule):
|
class LCN(TimedModule):
|
||||||
'''
|
'''
|
||||||
Local Contract Normalization
|
Local Contract Normalization
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, radius, epsilon):
|
def __init__(self, radius, epsilon):
|
||||||
super().__init__(mod_name='LCN')
|
super().__init__(mod_name='LCN')
|
||||||
self.box_conv = torch.nn.Sequential(
|
self.box_conv = torch.nn.Sequential(
|
||||||
@ -533,11 +547,11 @@ class LCN(TimedModule):
|
|||||||
return (data - avgs) / stds, stds
|
return (data - avgs) / stds, stds
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SobelFilter(TimedModule):
|
class SobelFilter(TimedModule):
|
||||||
'''
|
'''
|
||||||
Sobel Filter
|
Sobel Filter
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, norm=False):
|
def __init__(self, norm=False):
|
||||||
super(SobelFilter, self).__init__(mod_name='SobelFilter')
|
super(SobelFilter, self).__init__(mod_name='SobelFilter')
|
||||||
kx = np.array([[-5, -4, 0, 4, 5],
|
kx = np.array([[-5, -4, 0, 4, 5],
|
||||||
@ -563,4 +577,3 @@ class SobelFilter(TimedModule):
|
|||||||
return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
|
return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
|
||||||
else:
|
else:
|
||||||
return torch.cat((gx, gy), dim=1)
|
return torch.cat((gx, gy), dim=1)
|
||||||
|
|
||||||
|
64
readme.md
64
readme.md
@ -6,7 +6,9 @@ This repository contains the code for the paper
|
|||||||
|
|
||||||
**[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)**
|
**[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)**
|
||||||
<br>
|
<br>
|
||||||
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/), [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/), and [Andreas Geiger](http://www.cvlibs.net/)
|
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/)
|
||||||
|
, [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/),
|
||||||
|
and [Andreas Geiger](http://www.cvlibs.net/)
|
||||||
<br>
|
<br>
|
||||||
[CVPR 2019](http://cvpr2019.thecvf.com/)
|
[CVPR 2019](http://cvpr2019.thecvf.com/)
|
||||||
|
|
||||||
@ -24,40 +26,45 @@ If you find this code useful for your research, please cite
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Dependencies
|
## Dependencies
|
||||||
|
|
||||||
The network training/evaluation code is based on `Pytorch`.
|
The network training/evaluation code is based on `Pytorch`.
|
||||||
|
|
||||||
```
|
```
|
||||||
PyTorch>=1.1
|
PyTorch>=1.1
|
||||||
Cuda>=10.0
|
Cuda>=10.0
|
||||||
```
|
```
|
||||||
|
|
||||||
Updated on 07.06.2021: The code is now compatible with the latest Pytorch version (1.8).
|
Updated on 07.06.2021: The code is now compatible with the latest Pytorch version (1.8).
|
||||||
|
|
||||||
The other python packages can be installed with `anaconda`:
|
The other python packages can be installed with `anaconda`:
|
||||||
|
|
||||||
```
|
```
|
||||||
conda install --file requirements.txt
|
conda install --file requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
### Structured Light Renderer
|
### Structured Light Renderer
|
||||||
To train and evaluate our method in a controlled setting, we implemented an structured light renderer.
|
|
||||||
It can be used to render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable projector location.
|
To train and evaluate our method in a controlled setting, we implemented an structured light renderer. It can be used to
|
||||||
To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`.
|
render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable
|
||||||
Afterwards, the renderer can be build by running `make` within the `renderer` directory.
|
projector location. To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`. Afterwards,
|
||||||
|
the renderer can be build by running `make` within the `renderer` directory.
|
||||||
|
|
||||||
### PyTorch Extensions
|
### PyTorch Extensions
|
||||||
The network training/evaluation code is based on `PyTorch`.
|
|
||||||
We implemented some custom layers that need to be built in the `torchext` directory.
|
The network training/evaluation code is based on `PyTorch`. We implemented some custom layers that need to be built in
|
||||||
Simply change into this directory and run
|
the `torchext` directory. Simply change into this directory and run
|
||||||
|
|
||||||
```
|
```
|
||||||
python setup.py build_ext --inplace
|
python setup.py build_ext --inplace
|
||||||
```
|
```
|
||||||
|
|
||||||
### Baseline HyperDepth
|
### Baseline HyperDepth
|
||||||
As baseline we partially re-implemented the random forest based method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf).
|
|
||||||
The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`.
|
As baseline we partially re-implemented the random forest based
|
||||||
To build it change into the directory and run
|
method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf)
|
||||||
|
. The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`
|
||||||
|
. To build it change into the directory and run
|
||||||
|
|
||||||
```
|
```
|
||||||
python setup.py build_ext --inplace
|
python setup.py build_ext --inplace
|
||||||
@ -65,42 +72,59 @@ python setup.py build_ext --inplace
|
|||||||
|
|
||||||
## Running
|
## Running
|
||||||
|
|
||||||
|
|
||||||
### Creating Synthetic Data
|
### Creating Synthetic Data
|
||||||
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by running
|
|
||||||
|
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and
|
||||||
|
correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by
|
||||||
|
running
|
||||||
|
|
||||||
```
|
```
|
||||||
./create_syn_data.sh
|
./create_syn_data.sh
|
||||||
```
|
```
|
||||||
If you are only interested in evaluating our pre-trained model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a validation set that contains a small amount of images.
|
|
||||||
|
If you are only interested in evaluating our pre-trained
|
||||||
|
model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a
|
||||||
|
validation set that contains a small amount of images.
|
||||||
|
|
||||||
### Training Network
|
### Training Network
|
||||||
|
|
||||||
As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train the network on synthetic data for the first stage run
|
As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train
|
||||||
|
the network on synthetic data for the first stage run
|
||||||
|
|
||||||
```
|
```
|
||||||
python train_val.py
|
python train_val.py
|
||||||
```
|
```
|
||||||
|
|
||||||
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by running
|
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by
|
||||||
|
running
|
||||||
|
|
||||||
```
|
```
|
||||||
python train_val.py --loss phge
|
python train_val.py --loss phge
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
### Evaluating Network
|
### Evaluating Network
|
||||||
|
|
||||||
To evaluate a specific checkpoint, e.g. the 50th epoch, one can run
|
To evaluate a specific checkpoint, e.g. the 50th epoch, one can run
|
||||||
|
|
||||||
```
|
```
|
||||||
python train_val.py --cmd retest --epoch 50
|
python train_val.py --cmd retest --epoch 50
|
||||||
```
|
```
|
||||||
|
|
||||||
### Evaluating a Pre-trained Model
|
### Evaluating a Pre-trained Model
|
||||||
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
|
|
||||||
|
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and
|
||||||
|
changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
|
||||||
|
|
||||||
```
|
```
|
||||||
mkdir -p output
|
mkdir -p output
|
||||||
mkdir -p output/exp_syn
|
mkdir -p output/exp_syn
|
||||||
wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params
|
wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params
|
||||||
python train_val.py --cmd retest --epoch 99
|
python train_val.py --cmd retest --epoch 99
|
||||||
```
|
```
|
||||||
You can also download our validation set from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
|
|
||||||
|
You can also download our validation set
|
||||||
|
from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
|
||||||
|
|
||||||
## Acknowledgement
|
## Acknowledgement
|
||||||
|
|
||||||
This work was supported by the Intel Network on Intelligent Systems.
|
This work was supported by the Intel Network on Intelligent Systems.
|
||||||
|
@ -2,18 +2,19 @@ import torch
|
|||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class TestSet(object):
|
class TestSet(object):
|
||||||
def __init__(self, name, dset, test_frequency=1):
|
def __init__(self, name, dset, test_frequency=1):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.dset = dset
|
self.dset = dset
|
||||||
self.test_frequency = test_frequency
|
self.test_frequency = test_frequency
|
||||||
|
|
||||||
|
|
||||||
class TestSets(list):
|
class TestSets(list):
|
||||||
def append(self, name, dset, test_frequency=1):
|
def append(self, name, dset, test_frequency=1):
|
||||||
super().append(TestSet(name, dset, test_frequency))
|
super().append(TestSet(name, dset, test_frequency))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MultiDataset(torch.utils.data.Dataset):
|
class MultiDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, *datasets):
|
def __init__(self, *datasets):
|
||||||
self.current_epoch = 0
|
self.current_epoch = 0
|
||||||
@ -46,7 +47,6 @@ class MultiDataset(torch.utils.data.Dataset):
|
|||||||
return self.datasets[didx][sidx]
|
return self.datasets[didx][sidx]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(torch.utils.data.Dataset):
|
class BaseDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, train=True, fix_seed_per_epoch=False):
|
def __init__(self, train=True, fix_seed_per_epoch=False):
|
||||||
self.current_epoch = 0
|
self.current_epoch = 0
|
||||||
|
@ -2,6 +2,7 @@ import torch
|
|||||||
from . import ext_cpu
|
from . import ext_cpu
|
||||||
from . import ext_cuda
|
from . import ext_cuda
|
||||||
|
|
||||||
|
|
||||||
class NNFunction(torch.autograd.Function):
|
class NNFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, in0, in1):
|
def forward(ctx, in0, in1):
|
||||||
@ -16,6 +17,7 @@ class NNFunction(torch.autograd.Function):
|
|||||||
def backward(ctx, grad_out):
|
def backward(ctx, grad_out):
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
def nn(in0, in1):
|
def nn(in0, in1):
|
||||||
return NNFunction.apply(in0, in1)
|
return NNFunction.apply(in0, in1)
|
||||||
|
|
||||||
@ -34,9 +36,11 @@ class CrossCheckFunction(torch.autograd.Function):
|
|||||||
def backward(ctx, grad_out):
|
def backward(ctx, grad_out):
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
def crosscheck(in0, in1):
|
def crosscheck(in0, in1):
|
||||||
return CrossCheckFunction.apply(in0, in1)
|
return CrossCheckFunction.apply(in0, in1)
|
||||||
|
|
||||||
|
|
||||||
class ProjNNFunction(torch.autograd.Function):
|
class ProjNNFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, xyz0, xyz1, K, patch_size):
|
def forward(ctx, xyz0, xyz1, K, patch_size):
|
||||||
@ -51,11 +55,11 @@ class ProjNNFunction(torch.autograd.Function):
|
|||||||
def backward(ctx, grad_out):
|
def backward(ctx, grad_out):
|
||||||
return None, None, None, None
|
return None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def proj_nn(xyz0, xyz1, K, patch_size):
|
def proj_nn(xyz0, xyz1, K, patch_size):
|
||||||
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
|
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class XCorrVolFunction(torch.autograd.Function):
|
class XCorrVolFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, in0, in1, n_disps, block_size):
|
def forward(ctx, in0, in1, n_disps, block_size):
|
||||||
@ -70,12 +74,11 @@ class XCorrVolFunction(torch.autograd.Function):
|
|||||||
def backward(ctx, grad_out):
|
def backward(ctx, grad_out):
|
||||||
return None, None, None, None
|
return None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def xcorrvol(in0, in1, n_disps, block_size):
|
def xcorrvol(in0, in1, n_disps, block_size):
|
||||||
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
|
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PhotometricLossFunction(torch.autograd.Function):
|
class PhotometricLossFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, es, ta, block_size, type, eps):
|
def forward(ctx, es, ta, block_size, type, eps):
|
||||||
@ -103,6 +106,7 @@ class PhotometricLossFunction(torch.autograd.Function):
|
|||||||
grad_es = ext_cpu.photometric_loss_backward(*args)
|
grad_es = ext_cpu.photometric_loss_backward(*args)
|
||||||
return grad_es, None, None, None, None
|
return grad_es, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
|
def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
|
||||||
type = type.lower()
|
type = type.lower()
|
||||||
if type == 'mse':
|
if type == 'mse':
|
||||||
@ -117,6 +121,7 @@ def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
|
|||||||
raise Exception('invalid loss type')
|
raise Exception('invalid loss type')
|
||||||
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
|
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
|
||||||
|
|
||||||
|
|
||||||
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
|
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
|
||||||
type = type.lower()
|
type = type.lower()
|
||||||
p = block_size // 2
|
p = block_size // 2
|
||||||
|
@ -4,11 +4,13 @@ import numpy as np
|
|||||||
|
|
||||||
from .functions import *
|
from .functions import *
|
||||||
|
|
||||||
|
|
||||||
class CoordConv2d(torch.nn.Module):
|
class CoordConv2d(torch.nn.Module):
|
||||||
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
|
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.conv = torch.nn.Conv2d(channels_in+2, channels_out, kernel_size=kernel_size, padding=padding, stride=stride)
|
self.conv = torch.nn.Conv2d(channels_in + 2, channels_out, kernel_size=kernel_size, padding=padding,
|
||||||
|
stride=stride)
|
||||||
|
|
||||||
self.uv = None
|
self.uv = None
|
||||||
|
|
||||||
|
@ -14,7 +14,8 @@ setup(
|
|||||||
name='ext',
|
name='ext',
|
||||||
ext_modules=[
|
ext_modules=[
|
||||||
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
|
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
|
||||||
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'], extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
|
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'],
|
||||||
|
extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
|
||||||
],
|
],
|
||||||
cmdclass={'build_ext': BuildExtension},
|
cmdclass={'build_ext': BuildExtension},
|
||||||
include_dirs=include_dirs
|
include_dirs=include_dirs
|
||||||
|
@ -40,6 +40,7 @@ class StopWatch(object):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
||||||
|
|
||||||
@ -77,8 +78,10 @@ class ETA(object):
|
|||||||
def get_remaining_time_str(self):
|
def get_remaining_time_str(self):
|
||||||
return self.format_time(self.get_remaining_time())
|
return self.format_time(self.get_remaining_time())
|
||||||
|
|
||||||
|
|
||||||
class Worker(object):
|
class Worker(object):
|
||||||
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
|
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
|
||||||
|
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
|
||||||
self.out_root = Path(out_root)
|
self.out_root = Path(out_root)
|
||||||
self.experiment_name = experiment_name
|
self.experiment_name = experiment_name
|
||||||
self.epochs = epochs
|
self.epochs = epochs
|
||||||
@ -237,7 +240,6 @@ class Worker(object):
|
|||||||
plt.savefig(str(err_img_path))
|
plt.savefig(str(err_img_path))
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
def callback_train_new_epoch(self, epoch, net, optimizer):
|
def callback_train_new_epoch(self, epoch, net, optimizer):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -269,7 +271,6 @@ class Worker(object):
|
|||||||
curr_state.update(state['state_dict'])
|
curr_state.update(state['state_dict'])
|
||||||
net.load_state_dict(curr_state)
|
net.load_state_dict(curr_state)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
optimizer.load_state_dict(state['optimizer'])
|
optimizer.load_state_dict(state['optimizer'])
|
||||||
except:
|
except:
|
||||||
@ -367,7 +368,8 @@ class Worker(object):
|
|||||||
logging.info('Train epoch %d' % epoch)
|
logging.info('Train epoch %d' % epoch)
|
||||||
|
|
||||||
dset.current_epoch = epoch
|
dset.current_epoch = epoch
|
||||||
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False)
|
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True,
|
||||||
|
num_workers=self.num_workers, drop_last=True, pin_memory=False)
|
||||||
|
|
||||||
net = net.to(self.train_device)
|
net = net.to(self.train_device)
|
||||||
net.train()
|
net.train()
|
||||||
@ -418,10 +420,10 @@ class Worker(object):
|
|||||||
bar.update(batch_idx)
|
bar.update(batch_idx)
|
||||||
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
|
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
|
||||||
err_str = self.format_err_str(errs)
|
err_str = self.format_err_str(errs)
|
||||||
logging.info(f'train e{epoch}: {batch_idx+1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
logging.info(
|
||||||
|
f'train e{epoch}: {batch_idx + 1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
||||||
# self.write_err_img()
|
# self.write_err_img()
|
||||||
|
|
||||||
|
|
||||||
if mean_loss is None:
|
if mean_loss is None:
|
||||||
mean_loss = [0 for e in errs]
|
mean_loss = [0 for e in errs]
|
||||||
for erridx, err in enumerate(errs):
|
for erridx, err in enumerate(errs):
|
||||||
@ -465,7 +467,8 @@ class Worker(object):
|
|||||||
logging.info('-' * 80)
|
logging.info('-' * 80)
|
||||||
logging.info('Test epoch %d' % epoch)
|
logging.info('Test epoch %d' % epoch)
|
||||||
dset.current_epoch = epoch
|
dset.current_epoch = epoch
|
||||||
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False)
|
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False,
|
||||||
|
num_workers=self.num_workers, drop_last=False, pin_memory=False)
|
||||||
|
|
||||||
net = net.to(self.test_device)
|
net = net.to(self.test_device)
|
||||||
net.eval()
|
net.eval()
|
||||||
@ -502,7 +505,8 @@ class Worker(object):
|
|||||||
bar.update(batch_idx)
|
bar.update(batch_idx)
|
||||||
if batch_idx % 25 == 0:
|
if batch_idx % 25 == 0:
|
||||||
err_str = self.format_err_str(errs)
|
err_str = self.format_err_str(errs)
|
||||||
logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
logging.info(
|
||||||
|
f'test e{epoch}: {batch_idx + 1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
||||||
|
|
||||||
if mean_loss is None:
|
if mean_loss is None:
|
||||||
mean_loss = [0 for e in errs]
|
mean_loss = [0 for e in errs]
|
||||||
|
@ -5,7 +5,6 @@ from model import exp_synphge
|
|||||||
from model import networks
|
from model import networks
|
||||||
from co.args import parse_args
|
from co.args import parse_args
|
||||||
|
|
||||||
|
|
||||||
# parse args
|
# parse args
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
@ -19,11 +18,11 @@ elif args.loss=='phge':
|
|||||||
channels_in = 2
|
channels_in = 2
|
||||||
|
|
||||||
# set up network
|
# set up network
|
||||||
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes, output_ms=worker.ms)
|
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes,
|
||||||
|
output_ms=worker.ms)
|
||||||
|
|
||||||
# optimizer
|
# optimizer
|
||||||
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
|
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
|
||||||
|
|
||||||
# start the work
|
# start the work
|
||||||
worker.do(net, optimizer)
|
worker.do(net, optimizer)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user