Reformat $EVERYTHING

master
CptCaptain 3 years ago
parent 56f2aa7d5d
commit 43df77fb9b
  1. 1
      co/__init__.py
  2. 5
      co/args.py
  3. 39
      co/cmap.py
  4. 584
      co/geometry.py
  5. 16
      co/gtimer.py
  6. 53
      co/io3d.py
  7. 30
      co/metric.py
  8. 11
      co/plt.py
  9. 23
      co/plt2d.py
  10. 34
      co/plt3d.py
  11. 66
      co/table.py
  12. 10
      co/utils.py
  13. 66
      data/commons.py
  14. 106
      data/create_syn_data.py
  15. 50
      data/dataset.py
  16. 395
      data/lcn/lcn.html
  17. 2
      data/lcn/setup.py
  18. 24
      data/lcn/test_lcn.py
  19. 19
      hyperdepth/hyperparam_search.py
  20. 7
      hyperdepth/setup.py
  21. 13
      hyperdepth/vis_eval.py
  22. 153
      model/exp_synph.py
  23. 163
      model/exp_synphge.py
  24. 161
      model/networks.py
  25. 64
      readme.md
  26. 2
      renderer/setup.py
  27. 4
      torchext/dataset.py
  28. 19
      torchext/functions.py
  29. 6
      torchext/modules.py
  30. 3
      torchext/setup.py
  31. 64
      torchext/worker.py
  32. 11
      train_val.py

@ -7,6 +7,7 @@
# set matplotlib backend depending on env
import os
import matplotlib
if os.name == 'posix' and "DISPLAY" not in os.environ:
matplotlib.use('Agg')

@ -12,7 +12,7 @@ def parse_args():
parser.add_argument('--loss',
help='Train with \'ph\' for the first stage without geometric loss, \
train with \'phge\' for the second stage with geometric loss',
default='ph', choices=['ph','phge'], type=str)
default='ph', choices=['ph', 'phge'], type=str)
parser.add_argument('--data_type',
default='syn', choices=['syn'], type=str)
#
@ -66,6 +66,3 @@ def parse_args():
def get_exp_name(args):
name = f"exp_{args.data_type}"
return name

@ -1,19 +1,20 @@
import numpy as np
_color_map_errors = np.array([
[149, 54, 49], #0: log2(x) = -infinity
[180, 117, 69], #0.0625: log2(x) = -4
[209, 173, 116], #0.125: log2(x) = -3
[233, 217, 171], #0.25: log2(x) = -2
[248, 243, 224], #0.5: log2(x) = -1
[144, 224, 254], #1.0: log2(x) = 0
[97, 174, 253], #2.0: log2(x) = 1
[67, 109, 244], #4.0: log2(x) = 2
[39, 48, 215], #8.0: log2(x) = 3
[38, 0, 165], #16.0: log2(x) = 4
[38, 0, 165] #inf: log2(x) = inf
[149, 54, 49], # 0: log2(x) = -infinity
[180, 117, 69], # 0.0625: log2(x) = -4
[209, 173, 116], # 0.125: log2(x) = -3
[233, 217, 171], # 0.25: log2(x) = -2
[248, 243, 224], # 0.5: log2(x) = -1
[144, 224, 254], # 1.0: log2(x) = 0
[97, 174, 253], # 2.0: log2(x) = 1
[67, 109, 244], # 4.0: log2(x) = 2
[39, 48, 215], # 8.0: log2(x) = 3
[38, 0, 165], # 16.0: log2(x) = 4
[38, 0, 165] # inf: log2(x) = inf
]).astype(float)
def color_error_image(errors, scale=1, mask=None, BGR=True):
"""
Color an input error map.
@ -32,16 +33,18 @@ def color_error_image(errors, scale=1, mask=None, BGR=True):
errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9)
i0 = np.floor(errors_color_indices).astype(int)
f1 = errors_color_indices - i0.astype(float)
colored_errors_flat = _color_map_errors[i0, :] * (1-f1).reshape(-1,1) + _color_map_errors[i0+1, :] * f1.reshape(-1,1)
colored_errors_flat = _color_map_errors[i0, :] * (1 - f1).reshape(-1, 1) + _color_map_errors[i0 + 1,
:] * f1.reshape(-1, 1)
if mask is not None:
colored_errors_flat[mask.flatten() == 0] = 255
if not BGR:
colored_errors_flat = colored_errors_flat[:,[2,1,0]]
colored_errors_flat = colored_errors_flat[:, [2, 1, 0]]
return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int)
_color_map_depths = np.array([
[0, 0, 0], # 0.000
[0, 0, 255], # 0.114
@ -65,6 +68,7 @@ _color_map_bincenters = np.array([
2.000, # doesn't make a difference, just strictly higher than 1
])
def color_depth_map(depths, scale=None):
"""
Color an input depth map.
@ -82,12 +86,13 @@ def color_depth_map(depths, scale=None):
values = np.clip(depths.flatten() / scale, 0, 1)
# for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value?
lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1,-1)) * np.arange(0,9)).max(axis=1)
lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1, -1)) * np.arange(0, 9)).max(axis=1)
lower_bin_value = _color_map_bincenters[lower_bin]
higher_bin_value = _color_map_bincenters[lower_bin + 1]
alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
colors = _color_map_depths[lower_bin] * (1-alphas).reshape(-1,1) + _color_map_depths[lower_bin + 1] * alphas.reshape(-1,1)
colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[
lower_bin + 1] * alphas.reshape(-1, 1)
return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)
#from utils.debug import save_color_numpy
#save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000))
# from utils.debug import save_color_numpy
# save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000))

File diff suppressed because it is too large Load Diff

@ -2,25 +2,31 @@ import numpy as np
from . import utils
class StopWatch(utils.StopWatch):
def __del__(self):
print('='*80)
print('=' * 80)
print('gtimer:')
total = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.sum).items()])
total = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.sum).items()])
print(f' [total] {total}')
mean = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.mean).items()])
mean = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.mean).items()])
print(f' [mean] {mean}')
median = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.median).items()])
median = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.median).items()])
print(f' [median] {median}')
print('='*80)
print('=' * 80)
GTIMER = StopWatch()
def start(name):
GTIMER.start(name)
def stop(name):
GTIMER.stop(name)
class Ctx(object):
def __init__(self, name):
self.name = name

@ -2,12 +2,13 @@ import struct
import numpy as np
import collections
def _write_ply_point(fp, x,y,z, color=None, normal=None, binary=False):
args = [x,y,z]
def _write_ply_point(fp, x, y, z, color=None, normal=None, binary=False):
args = [x, y, z]
if color is not None:
args += [int(color[0]), int(color[1]), int(color[2])]
if normal is not None:
args += [normal[0],normal[1],normal[2]]
args += [normal[0], normal[1], normal[2]]
if binary:
fmt = '<fff'
if color is not None:
@ -24,11 +25,13 @@ def _write_ply_point(fp, x,y,z, color=None, normal=None, binary=False):
fmt += '\n'
fp.write(fmt % tuple(args))
def _write_ply_triangle(fp, i0,i1,i2, binary):
def _write_ply_triangle(fp, i0, i1, i2, binary):
if binary:
fp.write(struct.pack('<Biii', 3,i0,i1,i2))
fp.write(struct.pack('<Biii', 3, i0, i1, i2))
else:
fp.write('3 %d %d %d\n' % (i0,i1,i2))
fp.write('3 %d %d %d\n' % (i0, i1, i2))
def _write_ply_header_line(fp, str, binary):
if binary:
@ -36,6 +39,7 @@ def _write_ply_header_line(fp, str, binary):
else:
fp.write(str)
def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
if verts.shape[1] != 3:
raise Exception('verts has to be of shape Nx3')
@ -82,11 +86,12 @@ def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
n = None
else:
n = normals[vidx]
_write_ply_point(fp, v[0],v[1],v[2], c, n, binary)
_write_ply_point(fp, v[0], v[1], v[2], c, n, binary)
if trias is not None:
for t in trias:
_write_ply_triangle(fp, t[0],t[1],t[2], binary)
_write_ply_triangle(fp, t[0], t[1], t[2], binary)
def faces_to_triangles(faces):
new_faces = []
@ -100,6 +105,7 @@ def faces_to_triangles(faces):
raise Exception('unknown face count %d', f[0])
return new_faces
def read_ply(path):
with open(path, 'rb') as f:
# parse header
@ -152,7 +158,7 @@ def read_ply(path):
sz = n_verts * vert_bin_len
fmt = ','.join(vert_bin_format)
verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz))
verts = verts[0].astype(vert_bin_cols*'f4,').view(dtype='f4').reshape((n_verts,-1))
verts = verts[0].astype(vert_bin_cols * 'f4,').view(dtype='f4').reshape((n_verts, -1))
faces = []
for idx in range(n_faces):
fmt = '<Biii'
@ -172,21 +178,21 @@ def read_ply(path):
for idx in range(n_faces):
splits = f.readline().decode().strip().split(' ')
n_face_verts = int(splits[0])
vals = [int(v) for v in splits[0:n_face_verts+1]]
vals = [int(v) for v in splits[0:n_face_verts + 1]]
faces.append(vals)
faces = faces_to_triangles(faces)
faces = np.array(faces, dtype=np.int32)
xyz = None
if 'x' in vert_types and 'y' in vert_types and 'z' in vert_types:
xyz = verts[:,[vert_types['x'], vert_types['y'], vert_types['z']]]
xyz = verts[:, [vert_types['x'], vert_types['y'], vert_types['z']]]
colors = None
if 'red' in vert_types and 'green' in vert_types and 'blue' in vert_types:
colors = verts[:,[vert_types['red'], vert_types['green'], vert_types['blue']]]
colors = verts[:, [vert_types['red'], vert_types['green'], vert_types['blue']]]
colors /= 255
normals = None
if 'nx' in vert_types and 'ny' in vert_types and 'nz' in vert_types:
normals = verts[:,[vert_types['nx'], vert_types['ny'], vert_types['nz']]]
normals = verts[:, [vert_types['nx'], vert_types['ny'], vert_types['nz']]]
return xyz, faces, colors, normals
@ -204,6 +210,7 @@ def _read_obj_split_f(s):
nidx = -1
return vidx, tidx, nidx
def read_obj(path):
with open(path, 'r') as fp:
lines = fp.readlines()
@ -221,19 +228,19 @@ def read_obj(path):
parts = line.split()
if line.startswith('v '):
parts = parts[1:]
x,y,z = float(parts[0]), float(parts[1]), float(parts[2])
x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
if len(parts) == 4 or len(parts) == 7:
w = float(parts[3])
x,y,z = x/w, y/w, z/w
verts.append((x,y,z))
x, y, z = x / w, y / w, z / w
verts.append((x, y, z))
if len(parts) >= 6:
r,g,b = float(parts[-3]), float(parts[-2]), float(parts[-1])
rgb.append((r,g,b))
r, g, b = float(parts[-3]), float(parts[-2]), float(parts[-1])
rgb.append((r, g, b))
elif line.startswith('vn '):
parts = parts[1:]
x,y,z = float(parts[0]), float(parts[1]), float(parts[2])
fnorms.append((x,y,z))
x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
fnorms.append((x, y, z))
elif line.startswith('f '):
parts = parts[1:]
@ -245,11 +252,11 @@ def read_obj(path):
faces.append((vidx0, vidx1, vidx2))
if nidx0 >= 0:
fnorm_map[vidx0].append( nidx0 )
fnorm_map[vidx0].append(nidx0)
if nidx1 >= 0:
fnorm_map[vidx1].append( nidx1 )
fnorm_map[vidx1].append(nidx1)
if nidx2 >= 0:
fnorm_map[vidx2].append( nidx2 )
fnorm_map[vidx2].append(nidx2)
verts = np.array(verts)
colors = np.array(colors)

@ -1,6 +1,7 @@
import numpy as np
from . import geometry
def _process_inputs(estimate, target, mask):
if estimate.shape != target.shape:
raise Exception('estimate and target have to be same shape')
@ -12,19 +13,23 @@ def _process_inputs(estimate, target, mask):
raise Exception('estimate and mask have to be same shape')
return estimate, target, mask
def mse(estimate, target, mask=None):
estimate, target, mask = _process_inputs(estimate, target, mask)
m = np.sum((estimate[mask] - target[mask])**2) / mask.sum()
m = np.sum((estimate[mask] - target[mask]) ** 2) / mask.sum()
return m
def rmse(estimate, target, mask=None):
return np.sqrt(mse(estimate, target, mask))
def mae(estimate, target, mask=None):
estimate, target, mask = _process_inputs(estimate, target, mask)
m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
return m
def outlier_fraction(estimate, target, mask=None, threshold=0):
estimate, target, mask = _process_inputs(estimate, target, mask)
diff = np.abs(estimate[mask] - target[mask])
@ -52,6 +57,7 @@ class Metric(object):
def __str__(self):
return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])
class MultipleMetric(Metric):
def __init__(self, *metrics, **kwargs):
self.metrics = [*metrics]
@ -76,6 +82,7 @@ class MultipleMetric(Metric):
def __str__(self):
return '\n'.join([str(m) for m in self.metrics])
class BaseDistanceMetric(Metric):
def __init__(self, name='', **kwargs):
super().__init__(**kwargs)
@ -99,6 +106,7 @@ class BaseDistanceMetric(Metric):
f'dist{self.name}_max': float(np.max(dists)),
}
class DistanceMetric(BaseDistanceMetric):
def __init__(self, vec_length, p=2, **kwargs):
super().__init__(name=f'{p}', **kwargs)
@ -113,7 +121,8 @@ class DistanceMetric(BaseDistanceMetric):
es = es[ma != 0]
ta = ta[ma != 0]
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
self.dists.append( dist )
self.dists.append(dist)
class OutlierFractionMetric(DistanceMetric):
def __init__(self, thresholds, *args, **kwargs):
@ -128,6 +137,7 @@ class OutlierFractionMetric(DistanceMetric):
ret[f'of{t}'] = float(ma.sum() / ma.size)
return ret
class RelativeDistanceMetric(BaseDistanceMetric):
def __init__(self, vec_length, p=2, **kwargs):
super().__init__(name=f'rel{p}', **kwargs)
@ -142,7 +152,8 @@ class RelativeDistanceMetric(BaseDistanceMetric):
dist /= denom
if ma is not None:
dist = dist[ma != 0]
self.dists.append( dist )
self.dists.append(dist)
class RotmDistanceMetric(BaseDistanceMetric):
def __init__(self, type='identity', **kwargs):
@ -156,12 +167,13 @@ class RotmDistanceMetric(BaseDistanceMetric):
if ma is not None:
raise Exception('mask is not implemented')
if self.type == 'identity':
self.dists.append( geometry.rotm_distance_identity(es, ta) )
self.dists.append(geometry.rotm_distance_identity(es, ta))
elif self.type == 'geodesic':
self.dists.append( geometry.rotm_distance_geodesic_unit_sphere(es, ta) )
self.dists.append(geometry.rotm_distance_geodesic_unit_sphere(es, ta))
else:
raise Exception('invalid distance type')
class QuaternionDistanceMetric(BaseDistanceMetric):
def __init__(self, type='angle', **kwargs):
super().__init__(name=type, **kwargs)
@ -174,11 +186,11 @@ class QuaternionDistanceMetric(BaseDistanceMetric):
if ma is not None:
raise Exception('mask is not implemented')
if self.type == 'angle':
self.dists.append( geometry.quat_distance_angle(es, ta) )
self.dists.append(geometry.quat_distance_angle(es, ta))
elif self.type == 'mineucl':
self.dists.append( geometry.quat_distance_mineucl(es, ta) )
self.dists.append(geometry.quat_distance_mineucl(es, ta))
elif self.type == 'normdiff':
self.dists.append( geometry.quat_distance_normdiff(es, ta) )
self.dists.append(geometry.quat_distance_normdiff(es, ta))
else:
raise Exception('invalid distance type')
@ -241,7 +253,7 @@ class BinaryAccuracyMetric(Metric):
accuracies = np.divide(tps + tns, tps + tns + fps + fns)
aacc = np.mean(accuracies)
for t in np.linspace(0,1,num=11)[1:-1]:
for t in np.linspace(0, 1, num=11)[1:-1]:
idx = np.argmin(np.abs(t - wp))
ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx])

@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
import os
import time
def save(path, remove_axis=False, dpi=300, fig=None):
if fig is None:
fig = plt.gcf()
@ -15,13 +16,14 @@ def save(path, remove_axis=False, dpi=300, fig=None):
if remove_axis:
for ax in fig.axes:
ax.axis('off')
ax.margins(0,0)
ax.margins(0, 0)
fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
for ax in fig.axes:
ax.xaxis.set_major_locator(plt.NullLocator())
ax.yaxis.set_major_locator(plt.NullLocator())
fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
def color_map(im_, cmap='viridis', vmin=None, vmax=None):
cm = plt.get_cmap(cmap)
im = im_.copy()
@ -33,11 +35,12 @@ def color_map(im_, cmap='viridis', vmin=None, vmax=None):
im[mask] = vmin
im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin)
im = cm(im)
im = im[...,:3]
im = im[..., :3]
for c in range(3):
im[mask, c] = 1
return im
def interactive_legend(leg=None, fig=None, all_axes=True):
if leg is None:
leg = plt.legend()
@ -60,7 +63,7 @@ def interactive_legend(leg=None, fig=None, all_axes=True):
def onpick(event):
if event.mouseevent.dblclick:
tmp = [(k,v) for k,v in lined.items()]
tmp = [(k, v) for k, v in lined.items()]
else:
tmp = [(event.artist, lined[event.artist])]
@ -76,6 +79,7 @@ def interactive_legend(leg=None, fig=None, all_axes=True):
fig.canvas.mpl_connect('pick_event', onpick)
def non_annoying_pause(interval, focus_figure=False):
# https://github.com/matplotlib/matplotlib/issues/11131
backend = mpl.rcParams['backend']
@ -91,6 +95,7 @@ def non_annoying_pause(interval, focus_figure=False):
return
time.sleep(interval)
def remove_all_ticks(fig=None):
if fig is None:
fig = plt.gcf()

@ -3,21 +3,23 @@ import matplotlib.pyplot as plt
from . import geometry
def image_matrix(ims, bgval=0):
n = ims.shape[0]
m = int( np.ceil(np.sqrt(n)) )
m = int(np.ceil(np.sqrt(n)))
h = ims.shape[1]
w = ims.shape[2]
mat = np.empty((m*h, m*w), dtype=ims.dtype)
mat = np.empty((m * h, m * w), dtype=ims.dtype)
mat.fill(bgval)
idx = 0
for r in range(m):
for c in range(m):
if idx < n:
mat[r*h:(r+1)*h, c*w:(c+1)*w] = ims[idx]
mat[r * h:(r + 1) * h, c * w:(c + 1) * w] = ims[idx]
idx += 1
return mat
def image_cat(ims, vertical=False):
offx = [0]
offy = [0]
@ -34,20 +36,23 @@ def image_cat(ims, vertical=False):
offx = np.cumsum(offx)
offy = np.cumsum(offy)
im = np.zeros((height,width,*ims[0].shape[2:]), dtype=ims[0].dtype)
im = np.zeros((height, width, *ims[0].shape[2:]), dtype=ims[0].dtype)
for im0, ox, oy in zip(ims, offx, offy):
im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0
return im, offx, offy
def line(li, h, w, ax=None, *args, **kwargs):
if ax is None:
ax = plt.gca()
xs = (-li[2] - li[1] * np.array((0, h-1))) / li[0]
ys = (-li[2] - li[0] * np.array((0, w-1))) / li[1]
pts = np.array([(0,ys[0]), (w-1, ys[1]), (xs[0], 0), (xs[1], h-1)])
pts = pts[np.logical_and(np.logical_and(pts[:,0] >= 0, pts[:,0] < w), np.logical_and(pts[:,1] >= 0, pts[:,1] < h))]
ax.plot(pts[:,0], pts[:,1], *args, **kwargs)
xs = (-li[2] - li[1] * np.array((0, h - 1))) / li[0]
ys = (-li[2] - li[0] * np.array((0, w - 1))) / li[1]
pts = np.array([(0, ys[0]), (w - 1, ys[1]), (xs[0], 0), (xs[1], h - 1)])
pts = pts[
np.logical_and(np.logical_and(pts[:, 0] >= 0, pts[:, 0] < w), np.logical_and(pts[:, 1] >= 0, pts[:, 1] < h))]
ax.plot(pts[:, 0], pts[:, 1], *args, **kwargs)
def depthshow(depth, *args, ax=None, **kwargs):
if ax is None:

@ -4,35 +4,45 @@ from mpl_toolkits.mplot3d import Axes3D
from . import geometry
def ax3d(fig=None):
if fig is None:
fig = plt.gcf()
return fig.add_subplot(111, projection='3d')
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1, label=None, **kwargs):
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1,
label=None, **kwargs):
if ax is None:
ax = plt.gca()
C0 = geometry.translation_to_cameracenter(R, t).ravel()
C1 = C0 + R.T.dot( np.array([[-size],[-size],[3*size]], dtype=np.float32) ).ravel()
C2 = C0 + R.T.dot( np.array([[-size],[+size],[3*size]], dtype=np.float32) ).ravel()
C3 = C0 + R.T.dot( np.array([[+size],[+size],[3*size]], dtype=np.float32) ).ravel()
C4 = C0 + R.T.dot( np.array([[+size],[-size],[3*size]], dtype=np.float32) ).ravel()
C1 = C0 + R.T.dot(np.array([[-size], [-size], [3 * size]], dtype=np.float32)).ravel()
C2 = C0 + R.T.dot(np.array([[-size], [+size], [3 * size]], dtype=np.float32)).ravel()
C3 = C0 + R.T.dot(np.array([[+size], [+size], [3 * size]], dtype=np.float32)).ravel()
C4 = C0 + R.T.dot(np.array([[+size], [-size], [3 * size]], dtype=np.float32)).ravel()
if marker_C != '':
ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs)
ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]], [C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]],
[C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
def axis_equal(ax=None):
if ax is None:
ax = plt.gca()
extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
sz = extents[:,1] - extents[:,0]
sz = extents[:, 1] - extents[:, 0]
centers = np.mean(extents, axis=1)
maxsize = max(abs(sz))
r = maxsize/2
r = maxsize / 2
for ctr, dim in zip(centers, 'xyz'):
getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r)

@ -3,6 +3,7 @@ import pandas as pd
import enum
import itertools
class Table(object):
def __init__(self, n_cols):
self.n_cols = n_cols
@ -42,7 +43,8 @@ class Table(object):
else:
raise Exception('number of cols does not fit in table')
for c in range(len(cols)):
self.rows[row+r].cells[col+c] = Cell(data[r][c], fmt)
self.rows[row + r].cells[col + c] = Cell(data[r][c], fmt)
class Row(object):
def __init__(self, cells, pre_separator=None, post_separator=None):
@ -61,9 +63,8 @@ class Row(object):
return sum([c.span for c in self.cells])
class Color(object):
def __init__(self, color=(0,0,0), fmt='rgb'):
def __init__(self, color=(0, 0, 0), fmt='rgb'):
if fmt == 'rgb':
self.color = color
elif fmt == 'RGB':
@ -79,11 +80,11 @@ class Color(object):
@classmethod
def rgb(cls, r, g, b):
return Color(color=(r,g,b), fmt='rgb')
return Color(color=(r, g, b), fmt='rgb')
@classmethod
def RGB(cls, r, g, b):
return Color(color=(r,g,b), fmt='RGB')
return Color(color=(r, g, b), fmt='RGB')
class CellFormat(object):
@ -93,6 +94,7 @@ class CellFormat(object):
self.bgcolor = bgcolor
self.bold = bold
class Cell(object):
def __init__(self, data=None, fmt=None, span=1, align=None):
self.data = data
@ -105,6 +107,7 @@ class Cell(object):
def __str__(self):
return self.fmt.fmt % self.data
class Separator(enum.Enum):
HEAD = 1
BOTTOM = 2
@ -143,6 +146,7 @@ class Renderer(object):
with open(path, 'w') as fp:
fp.write(txt)
class TerminalRenderer(Renderer):
def __init__(self, col_sep=' '):
super().__init__()
@ -152,7 +156,7 @@ class TerminalRenderer(Renderer):
cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data
str_width = len(str)
cell_width = sum([widths[idx] for idx in range(col, col+cell.span)])
cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
cell_width += len(self.col_sep) * (cell.span - 1)
if len(str) > cell_width:
str = str[:cell_width]
@ -167,13 +171,13 @@ class TerminalRenderer(Renderer):
if str_width < cell_width:
n_ws = (cell_width - str_width)
if table.get_cell_align(row, col) == 'r':
str = ' '*n_ws + str
str = ' ' * n_ws + str
elif table.get_cell_align(row, col) == 'l':
str = str + ' '*n_ws
str = str + ' ' * n_ws
elif table.get_cell_align(row, col) == 'c':
n_ws1 = n_ws // 2
n_ws0 = n_ws - n_ws1
str = ' '*n_ws0 + str + ' '*n_ws1
str = ' ' * n_ws0 + str + ' ' * n_ws1
if cell.fmt.bgcolor is not None:
# color = cell.fmt.bgcolor.as_RGB()
# str = sty.bg(*color) + str + sty.rs.bg
@ -182,11 +186,11 @@ class TerminalRenderer(Renderer):
def render_separator(self, separator, tab, col_widths, total_width):
if separator == Separator.HEAD:
return '='*total_width
return '=' * total_width
elif separator == Separator.INNER:
return '-'*total_width
return '-' * total_width
elif separator == Separator.BOTTOM:
return '='*total_width
return '=' * total_width
def render(self, table):
widths = self.col_widths(table)
@ -207,6 +211,7 @@ class TerminalRenderer(Renderer):
lines.append(sepline)
return '\n'.join(lines)
class MarkdownRenderer(TerminalRenderer):
def __init__(self):
super().__init__(col_sep='|')
@ -231,20 +236,20 @@ class MarkdownRenderer(TerminalRenderer):
str = f'**{str}**'
str_width = len(str)
cell_width = sum([widths[idx] for idx in range(col, col+cell.span)])
cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
cell_width += len(self.col_sep) * (cell.span - 1)
if len(str) > cell_width:
str = str[:cell_width]
else:
n_ws = (cell_width - str_width)
if table.get_cell_align(row, col) == 'r':
str = ' '*n_ws + str
str = ' ' * n_ws + str
elif table.get_cell_align(row, col) == 'l':
str = str + ' '*n_ws
str = str + ' ' * n_ws
elif table.get_cell_align(row, col) == 'c':
n_ws1 = n_ws // 2
n_ws0 = n_ws - n_ws1
str = ' '*n_ws0 + str + ' '*n_ws1
str = ' ' * n_ws0 + str + ' ' * n_ws1
if col == 0: str = self.col_sep + str
if col == table.n_cols - 1: str += self.col_sep
@ -279,7 +284,7 @@ class LatexRenderer(Renderer):
cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data
if cell.fmt.bold:
str = '{\\bf '+ str + '}'
str = '{\\bf ' + str + '}'
if cell.fmt.fgcolor is not None:
color = cell.fmt.fgcolor.as_rgb()
str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}'
@ -313,6 +318,7 @@ class LatexRenderer(Renderer):
lines.append('\\end{tabular}')
return '\n'.join(lines)
class HtmlRenderer(Renderer):
def __init__(self, html_class='result_table'):
super().__init__()
@ -331,10 +337,14 @@ class HtmlRenderer(Renderer):
color = cell.fmt.bgcolor.as_RGB()
styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});')
align = table.get_cell_align(row, col)
if align == 'l': align = 'left'
elif align == 'r': align = 'right'
elif align == 'c': align = 'center'
else: raise Exception('invalid align')
if align == 'l':
align = 'left'
elif align == 'r':
align = 'right'
elif align == 'c':
align = 'center'
else:
raise Exception('invalid align')
styles.append(f'text-align: {align};')
row = table.rows[row]
if row.pre_separator is not None:
@ -365,10 +375,11 @@ class HtmlRenderer(Renderer):
return '\n'.join(lines)
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'), best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'),
best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
rnames = data[rowname].unique()
cnames = data[colname].unique()
tab = Table(1+len(cnames))
tab = Table(1 + len(cnames))
header = [Cell('', align='r')]
header.extend([Cell(h, align='r') for h in cnames])
@ -395,7 +406,6 @@ def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt
return tab
if __name__ == '__main__':
# df = pd.read_pickle('full.df')
# best_is_max = ['movF0.5', 'movF1.0']
@ -411,11 +421,11 @@ if __name__ == '__main__':
# tab.add_row(header2)
tab.add_row(Row([Cell(f'c{c}') for c in range(7)]))
tab.rows[-1].post_separator = Separator.INNER
tab.add_block(np.arange(15*7).reshape(15,7))
tab.add_block(np.arange(15 * 7).reshape(15, 7))
tab.rows[4].cells[2].fmt = CellFormat(bold=True)
tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2,0.6,0.1))
tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7,0.1,0.5))
tab.rows[5].cells[3].fmt = CellFormat(bold=True,bgcolor=Color.rgb(0.7,0.1,0.5),fgcolor=Color.rgb(0.1,0.1,0.1))
tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2, 0.6, 0.1))
tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7, 0.1, 0.5))
tab.rows[5].cells[3].fmt = CellFormat(bold=True, bgcolor=Color.rgb(0.7, 0.1, 0.5), fgcolor=Color.rgb(0.1, 0.1, 0.1))
tab.rows[-1].post_separator = Separator.BOTTOM
renderer = TerminalRenderer()

@ -8,6 +8,7 @@ import re
import pickle
import subprocess
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
@ -16,6 +17,7 @@ def str2bool(v):
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
class StopWatch(object):
def __init__(self):
self.timings = OrderedDict()
@ -39,9 +41,11 @@ class StopWatch(object):
return ret
def __repr__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
def __str__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
class ETA(object):
def __init__(self, length):
@ -76,6 +80,7 @@ class ETA(object):
def get_remaining_time_str(self):
return self.format_time(self.get_remaining_time())
def git_hash(cwd=None):
ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
hash = ret.stdout
@ -83,4 +88,3 @@ def git_hash(cwd=None):
return hash.decode().strip()
else:
return None

@ -7,7 +7,7 @@ def get_patterns(path='syn', imsizes=[], crop=True):
pattern_size = imsizes[0]
if path == 'syn':
np.random.seed(42)
pattern = np.random.uniform(0,1, size=pattern_size)
pattern = np.random.uniform(0, 1, size=pattern_size)
pattern = (pattern < 0.1).astype(np.float32)
pattern.reshape(*imsizes[0])
else:
@ -21,30 +21,30 @@ def get_patterns(path='syn', imsizes=[], crop=True):
if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]:
r0 = (pattern.shape[0] - pattern_size[0]) // 2
c0 = (pattern.shape[1] - pattern_size[1]) // 2
pattern = pattern[r0:r0+imsizes[0][0], c0:c0+imsizes[0][1]]
pattern = pattern[r0:r0 + imsizes[0][0], c0:c0 + imsizes[0][1]]
patterns = []
for imsize in imsizes:
pat = cv2.resize(pattern, (imsize[1],imsize[0]), interpolation=cv2.INTER_LINEAR)
pat = cv2.resize(pattern, (imsize[1], imsize[0]), interpolation=cv2.INTER_LINEAR)
patterns.append(pat)
return patterns
def get_rotation_matrix(v0, v1):
v0 = v0/np.linalg.norm(v0)
v1 = v1/np.linalg.norm(v1)
v = np.cross(v0,v1)
c = np.dot(v0,v1)
v0 = v0 / np.linalg.norm(v0)
v1 = v1 / np.linalg.norm(v1)
v = np.cross(v0, v1)
c = np.dot(v0, v1)
s = np.linalg.norm(v)
I = np.eye(3)
vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0)
k = np.matrix(vXStr)
r = I + k + k @ k * ((1 -c)/(s**2))
r = I + k + k @ k * ((1 - c) / (s ** 2))
return np.asarray(r.astype(np.float32))
def augment_image(img,rng,disp=None,grad=None,max_shift=64,max_blur=1.5,max_noise=10.0,max_sp_noise=0.001):
def augment_image(img, rng, disp=None, grad=None, max_shift=64, max_blur=1.5, max_noise=10.0, max_sp_noise=0.001):
# get min/max values of image
min_val = np.min(img)
max_val = np.max(img)
@ -57,54 +57,56 @@ def augment_image(img,rng,disp=None,grad=None,max_shift=64,max_blur=1.5,max_nois
grad_aug = grad
# apply affine transformation
if max_shift>1:
if max_shift > 1:
# affine parameters
rows,cols = img.shape
rows, cols = img.shape
shear = 0
shift = 0
shear_correction = 0
if rng.uniform(0,1)<0.75: shear = rng.uniform(-max_shift,max_shift) # shear with 75% probability
else: shift = rng.uniform(0,max_shift) # shift with 25% probability
if shear<0: shear_correction = -shear
if rng.uniform(0, 1) < 0.75:
shear = rng.uniform(-max_shift, max_shift) # shear with 75% probability
else:
shift = rng.uniform(0, max_shift) # shift with 25% probability
if shear < 0: shear_correction = -shear
# affine transformation
a = shear/float(rows)
b = shift+shear_correction
a = shear / float(rows)
b = shift + shear_correction
# warp image
T = np.float32([[1,a,b],[0,1,0]])
img_aug = cv2.warpAffine(img_aug,T,(cols,rows))
T = np.float32([[1, a, b], [0, 1, 0]])
img_aug = cv2.warpAffine(img_aug, T, (cols, rows))
if grad is not None:
grad_aug = cv2.warpAffine(grad,T,(cols,rows))
grad_aug = cv2.warpAffine(grad, T, (cols, rows))
# disparity correction map
col = a*np.array(range(rows))+b
disp_delta = np.tile(col,(cols,1)).transpose()
col = a * np.array(range(rows)) + b
disp_delta = np.tile(col, (cols, 1)).transpose()
if disp is not None:
disp_aug = cv2.warpAffine(disp+disp_delta,T,(cols,rows))
disp_aug = cv2.warpAffine(disp + disp_delta, T, (cols, rows))
# gaussian smoothing
if rng.uniform(0,1)<0.5:
img_aug = cv2.GaussianBlur(img_aug,(5,5),rng.uniform(0.2,max_blur))
if rng.uniform(0, 1) < 0.5:
img_aug = cv2.GaussianBlur(img_aug, (5, 5), rng.uniform(0.2, max_blur))
# per-pixel gaussian noise
img_aug = img_aug + rng.randn(*img_aug.shape)*rng.uniform(0.0,max_noise)/255.0
img_aug = img_aug + rng.randn(*img_aug.shape) * rng.uniform(0.0, max_noise) / 255.0
# salt-and-pepper noise
if rng.uniform(0,1)<0.5:
ratio=rng.uniform(0.0,max_sp_noise)
if rng.uniform(0, 1) < 0.5:
ratio = rng.uniform(0.0, max_sp_noise)
img_shape = img_aug.shape
img_aug = img_aug.flatten()
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio))
coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
img_aug[coord] = max_val
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio))
coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
img_aug[coord] = min_val
img_aug = np.reshape(img_aug, img_shape)
# clip intensities back to [0,1]
img_aug = np.maximum(img_aug,0.0)
img_aug = np.minimum(img_aug,1.0)
img_aug = np.maximum(img_aug, 0.0)
img_aug = np.minimum(img_aug, 1.0)
# return image
return img_aug, disp_aug, grad_aug

@ -10,14 +10,15 @@ import cv2
import os
import collections
import sys
sys.path.append('../')
import renderer
import co
from commons import get_patterns,get_rotation_matrix
from commons import get_patterns, get_rotation_matrix
from lcn import lcn
def get_objs(shapenet_dir, obj_classes, num_perclass=100):
def get_objs(shapenet_dir, obj_classes, num_perclass=100):
shapenet = {'chair': '03001627',
'airplane': '02691156',
'car': '02958343',
@ -40,7 +41,7 @@ def get_objs(shapenet_dir, obj_classes, num_perclass=100):
v /= (0.5 * diffs.max())
v -= (v.min(axis=0) + 1)
f = f.astype(np.int32)
objs.append((v,f,n))
objs.append((v, f, n))
print(f'loaded {len(objs)} objects')
return objs
@ -50,11 +51,11 @@ def get_mesh(rng, min_z=0):
# set up background board
verts, faces, normals, colors = [], [], [], []
v, f, n = co.geometry.xyplane(z=0, interleaved=True)
v[:,2] += -v[:,2].min() + rng.uniform(2,7)
v[:,:2] *= 5e2
v[:,2] = np.mean(v[:,2]) + (v[:,2] - np.mean(v[:,2])) * 5e2
v[:, 2] += -v[:, 2].min() + rng.uniform(2, 7)
v[:, :2] *= 5e2
v[:, 2] = np.mean(v[:, 2]) + (v[:, 2] - np.mean(v[:, 2])) * 5e2
c = np.empty_like(v)
c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32)
c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
verts.append(v)
faces.append(f)
normals.append(n)
@ -62,7 +63,7 @@ def get_mesh(rng, min_z=0):
# randomly sample 4 foreground objects for each scene
for shape_idx in range(4):
v, f, n = objs[rng.randint(0,len(objs))]
v, f, n = objs[rng.randint(0, len(objs))]
v, f, n = v.copy(), f.copy(), n.copy()
s = rng.uniform(0.25, 1)
@ -70,11 +71,11 @@ def get_mesh(rng, min_z=0):
R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng))
v = v @ R.T
n = n @ R.T
v[:,2] += -v[:,2].min() + min_z + rng.uniform(0.5, 3)
v[:,:2] += rng.uniform(-1, 1, size=(1,2))
v[:, 2] += -v[:, 2].min() + min_z + rng.uniform(0.5, 3)
v[:, :2] += rng.uniform(-1, 1, size=(1, 2))
c = np.empty_like(v)
c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32)
c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
verts.append(v.astype(np.float32))
faces.append(f)
@ -88,7 +89,6 @@ def get_mesh(rng, min_z=0):
def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4):
tic = time.time()
rng = np.random.RandomState()
@ -96,35 +96,34 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
verts, faces, colors, normals = get_mesh(rng)
data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
print(f'loading mesh for sample {idx+1}/{n_samples} took {time.time()-tic}[s]')
print(f'loading mesh for sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
# let the camera point to the center
center = np.array([0,0,3], dtype=np.float32)
center = np.array([0, 0, 3], dtype=np.float32)
basevec = np.array([-baseline,0,0], dtype=np.float32)
unit = np.array([0,0,1],dtype=np.float32)
basevec = np.array([-baseline, 0, 0], dtype=np.float32)
unit = np.array([0, 0, 1], dtype=np.float32)
cam_x_ = rng.uniform(-0.2,0.2)
cam_y_ = rng.uniform(-0.2,0.2)
cam_z_ = rng.uniform(-0.2,0.2)
cam_x_ = rng.uniform(-0.2, 0.2)
cam_y_ = rng.uniform(-0.2, 0.2)
cam_z_ = rng.uniform(-0.2, 0.2)
ret = collections.defaultdict(list)
blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1,0.1), 0,1)
blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1, 0.1), 0, 1)
# capture the same static scene from different view points as a track
for ind in range(track_length):
cam_x = cam_x_ + rng.uniform(-0.1,0.1)
cam_y = cam_y_ + rng.uniform(-0.1,0.1)
cam_z = cam_z_ + rng.uniform(-0.1,0.1)
cam_x = cam_x_ + rng.uniform(-0.1, 0.1)
cam_y = cam_y_ + rng.uniform(-0.1, 0.1)
cam_z = cam_z_ + rng.uniform(-0.1, 0.1)
tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32)
if np.linalg.norm(tcam[0:2])<1e-9:
if np.linalg.norm(tcam[0:2]) < 1e-9:
Rcam = np.eye(3, dtype=np.float32)
else:
Rcam = get_rotation_matrix(center, center-tcam)
Rcam = get_rotation_matrix(center, center - tcam)
tproj = tcam + basevec
Rproj = Rcam
@ -139,20 +138,19 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
scales = [1, 0.5, 0.25, 0.125]
for scale in scales:
fx = K[0,0] * scale
fy = K[1,1] * scale
px = K[0,2] * scale
py = K[1,2] * scale
fx = K[0, 0] * scale
fy = K[1, 1] * scale
px = K[0, 2] * scale
py = K[1, 2] * scale
im_height = imsize[0] * scale
im_width = imsize[1] * scale
cams.append( renderer.PyCamera(fx,fy,px,py, Rcam, tcam, im_width, im_height) )
projs.append( renderer.PyCamera(fx,fy,px,py, Rproj, tproj, im_width, im_height) )
cams.append(renderer.PyCamera(fx, fy, px, py, Rcam, tcam, im_width, im_height))
projs.append(renderer.PyCamera(fx, fy, px, py, Rproj, tproj, im_width, im_height))
for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns):
fl = K[0,0] / (2**s)
fl = K[0, 0] / (2 ** s)
shader = renderer.PyShader(0.5,1.5,0.0,10)
shader = renderer.PyShader(0.5, 1.5, 0.0, 10)
pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu')
pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35)
@ -169,21 +167,21 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
# get the noise free IR image $J$
im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient
ret[f'ambient{s}'].append( ambient[None].astype(np.float32) )
ret[f'ambient{s}'].append(ambient[None].astype(np.float32))
# get the gradient magnitude of the ambient image $|\nabla A|$
ambient = ambient.astype(np.float32)
sobelx = cv2.Sobel(ambient,cv2.CV_32F,1,0,ksize=5)
sobely = cv2.Sobel(ambient,cv2.CV_32F,0,1,ksize=5)
grad = np.sqrt(sobelx**2 + sobely**2)
grad = np.maximum(grad-0.8,0.0) # parameter
sobelx = cv2.Sobel(ambient, cv2.CV_32F, 1, 0, ksize=5)
sobely = cv2.Sobel(ambient, cv2.CV_32F, 0, 1, ksize=5)
grad = np.sqrt(sobelx ** 2 + sobely ** 2)
grad = np.maximum(grad - 0.8, 0.0) # parameter
# get the local contract normalized grad LCN($|\nabla A|$)
grad_lcn, grad_std = lcn.normalize(grad,5,0.1)
grad_lcn = np.clip(grad_lcn,0.0,1.0) # parameter
ret[f'grad{s}'].append( grad_lcn[None].astype(np.float32))
grad_lcn, grad_std = lcn.normalize(grad, 5, 0.1)
grad_lcn = np.clip(grad_lcn, 0.0, 1.0) # parameter
ret[f'grad{s}'].append(grad_lcn[None].astype(np.float32))
ret[f'im{s}'].append( im[None].astype(np.float32))
ret[f'im{s}'].append(im[None].astype(np.float32))
ret[f'mask{s}'].append(mask[None].astype(np.float32))
ret[f'disp{s}'].append(disp[None].astype(np.float32))
@ -193,18 +191,17 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
# save to files
out_dir = out_root / f'{idx:08d}'
out_dir.mkdir(exist_ok=True, parents=True)
for k,val in ret.items():
for k, val in ret.items():
for tidx in range(track_length):
v = val[tidx]
out_path = out_dir / f'{k}_{tidx}.npy'
np.save(out_path, v)
np.save( str(out_dir /'blend_im.npy'), blend_im_rnd)
print(f'create sample {idx+1}/{n_samples} took {time.time()-tic}[s]')
np.save(str(out_dir / 'blend_im.npy'), blend_im_rnd)
print(f'create sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
if __name__=='__main__':
if __name__ == '__main__':
np.random.seed(42)
@ -234,11 +231,12 @@ if __name__=='__main__':
# camera parameters
imsize = (488, 648)
imsizes = [(imsize[0]//(2**s), imsize[1]//(2**s)) for s in range(4)]
imsizes = [(imsize[0] // (2 ** s), imsize[1] // (2 ** s)) for s in range(4)]
# K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32)
K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0 ,0, 1]], dtype=np.float32)
focal_lengths = [K[0,0]/(2**s) for s in range(4)]
baseline=0.075
K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0, 0, 1]],
dtype=np.float32)
focal_lengths = [K[0, 0] / (2 ** s) for s in range(4)]
baseline = 0.075
blend_im = 0.6
noise = 0
@ -264,7 +262,7 @@ if __name__=='__main__':
pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL)
# start the job
n_samples = 2**10 + 2**13
n_samples = 2 ** 10 + 2 ** 13
for idx in range(start, n_samples):
args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length)
create_data(*args)

@ -21,11 +21,13 @@ from .commons import get_patterns, augment_image
from mpl_toolkits.mplot3d import Axes3D
class TrackSynDataset(torchext.BaseDataset):
'''
Load locally saved synthetic dataset
Please run ./create_syn_data.sh to generate the dataset
'''
def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False):
super().__init__(train=train)
@ -33,8 +35,8 @@ class TrackSynDataset(torchext.BaseDataset):
self.sample_paths = sample_paths
self.data_aug = data_aug
self.train = train
self.track_length=track_length
assert(track_length<=4)
self.track_length = track_length
assert (track_length <= 4)
with open(str(settings_path), 'rb') as f:
settings = pickle.load(f)
@ -46,10 +48,10 @@ class TrackSynDataset(torchext.BaseDataset):
self.scale = len(self.imsizes)
self.max_shift=0
self.max_blur=0.5
self.max_noise=3.0
self.max_sp_noise=0.0005
self.max_shift = 0
self.max_blur = 0.5
self.max_noise = 3.0
self.max_sp_noise = 0.0005
def __len__(self):
return len(self.sample_paths)
@ -75,9 +77,9 @@ class TrackSynDataset(torchext.BaseDataset):
ambs = []
grads = []
for tidx in track_ind:
imgs.append(np.load(os.path.join(sample_path,f'im{sidx}_{tidx}.npy')))
ambs.append(np.load(os.path.join(sample_path,f'ambient{sidx}_{tidx}.npy')))
grads.append(np.load(os.path.join(sample_path,f'grad{sidx}_{tidx}.npy')))
imgs.append(np.load(os.path.join(sample_path, f'im{sidx}_{tidx}.npy')))
ambs.append(np.load(os.path.join(sample_path, f'ambient{sidx}_{tidx}.npy')))
grads.append(np.load(os.path.join(sample_path, f'grad{sidx}_{tidx}.npy')))
ret[f'im{sidx}'] = np.stack(imgs, axis=0)
ret[f'ambient{sidx}'] = np.stack(ambs, axis=0)
ret[f'grad{sidx}'] = np.stack(grads, axis=0)
@ -87,20 +89,20 @@ class TrackSynDataset(torchext.BaseDataset):
R = []
t = []
for tidx in track_ind:
disps.append(np.load(os.path.join(sample_path,f'disp0_{tidx}.npy')))
R.append(np.load(os.path.join(sample_path,f'R_{tidx}.npy')))
t.append(np.load(os.path.join(sample_path,f't_{tidx}.npy')))
disps.append(np.load(os.path.join(sample_path, f'disp0_{tidx}.npy')))
R.append(np.load(os.path.join(sample_path, f'R_{tidx}.npy')))
t.append(np.load(os.path.join(sample_path, f't_{tidx}.npy')))
ret[f'disp0'] = np.stack(disps, axis=0)
ret['R'] = np.stack(R, axis=0)
ret['t'] = np.stack(t, axis=0)
blend_im = np.load(os.path.join(sample_path,'blend_im.npy'))
blend_im = np.load(os.path.join(sample_path, 'blend_im.npy'))
ret['blend_im'] = blend_im.astype(np.float32)
#### apply data augmentation at different scales seperately, only work for max_shift=0
if self.data_aug:
for sidx in range(len(self.imsizes)):
if sidx==0:
if sidx == 0:
img = ret[f'im{sidx}']
disp = ret[f'disp{sidx}']
grad = ret[f'grad{sidx}']
@ -108,10 +110,11 @@ class TrackSynDataset(torchext.BaseDataset):
disp_aug = np.zeros_like(img)
grad_aug = np.zeros_like(img)
for i in range(img.shape[0]):
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i,0],rng,
disp=disp[i,0],grad=grad[i,0],
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i, 0], rng,
disp=disp[i, 0], grad=grad[i, 0],
max_shift=self.max_shift, max_blur=self.max_blur,
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
max_noise=self.max_noise,
max_sp_noise=self.max_sp_noise)
img_aug[i] = img_aug_[None].astype(np.float32)
disp_aug[i] = disp_aug_[None].astype(np.float32)
grad_aug[i] = grad_aug_[None].astype(np.float32)
@ -122,27 +125,24 @@ class TrackSynDataset(torchext.BaseDataset):
img = ret[f'im{sidx}']
img_aug = np.zeros_like(img)
for i in range(img.shape[0]):
img_aug_, _, _ = augment_image(img[i,0],rng,
img_aug_, _, _ = augment_image(img[i, 0], rng,
max_shift=self.max_shift, max_blur=self.max_blur,
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
img_aug[i] = img_aug_[None].astype(np.float32)
ret[f'im{sidx}'] = img_aug
if len(track_ind)==1:
if len(track_ind) == 1:
for key, val in ret.items():
if key!='blend_im' and key!='id':
if key != 'blend_im' and key != 'id':
ret[key] = val[0]
return ret
def getK(self, sidx=0):
K = self.K.copy() / (2**sidx)
K[2,2] = 1
K = self.K.copy() / (2 ** sidx)
K[2, 2] = 1
return K
if __name__ == '__main__':
pass

@ -2,7 +2,7 @@
<!-- Generated by Cython 0.29 -->
<html>
<head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
<title>Cython: lcn.pyx</title>
<style type="text/css">
@ -355,17 +355,23 @@ body.cython { font-family: courier; font-size: 12; }
.cython .vi { color: #19177C } /* Name.Variable.Instance */
.cython .vm { color: #19177C } /* Name.Variable.Magic */
.cython .il { color: #666666 } /* Literal.Number.Integer.Long */
</style>
</head>
<body class="cython">
<p><span style="border-bottom: solid 1px grey;">Generated by Cython 0.29</span></p>
<p>
<span style="background-color: #FFFF00">Yellow lines</span> hint at Python interaction.<br />
<span style="background-color: #FFFF00">Yellow lines</span> hint at Python interaction.<br/>
Click on a line that starts with a "<code>+</code>" to see the C code that Cython generated for it.
</p>
<p>Raw output: <a href="lcn.c">lcn.c</a></p>
<div class="cython"><pre class="cython line score-16" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre>
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
<div class="cython">
<pre class="cython line score-16"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span
class="k">as</span> <span class="nn">np</span></pre>
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_np, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
@ -374,22 +380,39 @@ body.cython { font-family: courier; font-size: 12; }
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_test, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
</pre><pre class="cython line score-0">&#xA0;<span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">03</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">04</span>: <span class="c"># use c square root function</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">05</span>: <span class="k">cdef</span> <span class="kr">extern</span> <span class="k">from</span> <span class="s">&quot;math.h&quot;</span><span class="p">:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">07</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">08</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">09</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">10</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">11</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">12</span>: <span class="c"># 3 parameters:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">13</span>: <span class="c"># - float image</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre>
<pre class="cython line score-67" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">):</span></pre>
<pre class='cython code score-67 '>/* Python wrapper */
</pre>
<pre class="cython line score-0">&#xA0;<span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">03</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">04</span>: <span class="c"># use c square root function</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">05</span>: <span class="k">cdef</span> <span
class="kr">extern</span> <span class="k">from</span> <span class="s">&quot;math.h&quot;</span><span
class="p">:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span
class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">07</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">08</span>: <span class="nd">@cython</span><span
class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span
class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">09</span>: <span class="nd">@cython</span><span
class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span
class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">10</span>: <span class="nd">@cython</span><span
class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span
class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">11</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">12</span>: <span class="c"># 3 parameters:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">13</span>: <span class="c"># - float image</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre>
<pre class="cython line score-67"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span
class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span
class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span
class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span
class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span
class="p">):</span></pre>
<pre class='cython code score-67 '>/* Python wrapper */
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0};
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
@ -434,7 +457,8 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
}
}
if (unlikely(kw_args &gt; 0)) {
if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") &lt; 0)) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") &lt; 0)) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
}
} else {
switch (<span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)) {
@ -447,21 +471,27 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
default: goto __pyx_L5_argtuple_error;
}
}
__pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
__pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
if (values[1]) {
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) &amp;&amp; <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) &amp;&amp; <span
class='py_c_api'>PyErr_Occurred</span>())) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} else {
__pyx_v_kernel_size = ((int)4);
}
if (values[2]) {
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) &amp;&amp; <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) &amp;&amp; <span
class='py_c_api'>PyErr_Occurred</span>())) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} else {
__pyx_v_epsilon = ((float)0.01);
}
}
goto __pyx_L4_argument_unpacking_done;
__pyx_L5_argtuple_error:;
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
__pyx_L3_error:;
<span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
@ -515,27 +545,49 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
return __pyx_r;
}
/* … */
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span
class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19);
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19);
/* … */
__pyx_t_1 = PyCFunction_NewEx(&amp;__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
</pre><pre class="cython line score-0">&#xA0;<span class="">17</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">18</span>: <span class="c"># image dimensions</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
</pre><pre class="cython line score-0">&#xA0;<span class="">21</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
<pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span
class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
</pre>
<pre class="cython line score-0">&#xA0;<span class="">17</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">18</span>: <span class="c"># image dimensions</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">21</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
<pre class="cython line score-46"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span
class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
@ -559,22 +611,34 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
__pyx_v_img_lcn = __pyx_t_5;
__pyx_t_5 = 0;
</pre><pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
</pre>
<pre class="cython line score-46"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span
class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
__pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
@ -598,114 +662,236 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
__pyx_v_img_std = __pyx_t_1;
__pyx_t_1 = 0;
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span class="n">img_lcn</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
</pre>
<pre class="cython line score-2"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span
class="n">img_lcn</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
__pyx_v_img_lcn_view = __pyx_t_6;
__pyx_t_6.memview = NULL;
__pyx_t_6.data = NULL;
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span class="n">img_std</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
</pre>
<pre class="cython line score-2"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span
class="n">img_std</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
__pyx_v_img_std_view = __pyx_t_6;
__pyx_t_6.memview = NULL;
__pyx_t_6.data = NULL;
</pre><pre class="cython line score-0">&#xA0;<span class="">27</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">28</span>: <span class="c"># temporary c variables</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span class="nf">stddev</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span class="o">**</span><span class="mf">2</span></pre>
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
</pre><pre class="cython line score-0">&#xA0;<span class="">34</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">35</span>: <span class="c"># for all pixels do</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">27</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">28</span>: <span class="c"># temporary c variables</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span
class="nf">stddev</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span
class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span
class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span
class="o">**</span><span class="mf">2</span></pre>
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">34</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">35</span>: <span
class="c"># for all pixels do</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span
class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span
class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
__pyx_t_8 = __pyx_t_7;
for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 &lt; __pyx_t_8; __pyx_t_9+=1) {
__pyx_v_m = __pyx_t_9;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span
class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
__pyx_t_11 = __pyx_t_10;
for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 &lt; __pyx_t_11; __pyx_t_12+=1) {
__pyx_v_n = __pyx_t_12;
</pre><pre class="cython line score-0">&#xA0;<span class="">38</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">39</span>: <span class="c"># calculate mean</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">38</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">39</span>: <span class="c"># calculate mean</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span
class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
__pyx_t_14 = __pyx_t_13;
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) {
__pyx_v_i = __pyx_t_15;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
__pyx_t_17 = __pyx_t_16;
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) {
__pyx_v_j = __pyx_t_18;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span
class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
__pyx_t_20 = (__pyx_v_n + __pyx_v_j);
__pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) ))));
}
}
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
</pre><pre class="cython line score-0">&#xA0;<span class="">45</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">46</span>: <span class="c"># calculate std dev</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span
class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">45</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">46</span>: <span
class="c"># calculate std dev</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
__pyx_t_14 = __pyx_t_13;
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) {
__pyx_v_i = __pyx_t_15;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
__pyx_t_17 = __pyx_t_16;
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) {
__pyx_v_j = __pyx_t_18;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span
class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span
class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span
class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span
class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
__pyx_t_22 = (__pyx_v_n + __pyx_v_j);
__pyx_t_23 = (__pyx_v_m + __pyx_v_i);
__pyx_t_24 = (__pyx_v_n + __pyx_v_j);
__pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean)));
}
}
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span class="n">num</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
</pre><pre class="cython line score-0">&#xA0;<span class="">52</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span
class="n">num</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
</pre>
<pre class="cython line score-0">&#xA0;<span class="">52</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span
class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span
class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
__pyx_t_26 = __pyx_v_n;
__pyx_t_27 = __pyx_v_m;
__pyx_t_28 = __pyx_v_n;
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps));
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">stddev</span></pre>
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
class="n">stddev</span></pre>
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
__pyx_t_30 = __pyx_v_n;
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev;
}
}
</pre><pre class="cython line score-0">&#xA0;<span class="">56</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">57</span>: <span class="c"># return both</span></pre>
<pre class="cython line score-10" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span class="n">img_std</span></pre>
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">56</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">57</span>: <span class="c"># return both</span></pre>
<pre class="cython line score-10"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span
class="n">img_std</span></pre>
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
__pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_lcn);
@ -717,4 +903,7 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
__pyx_r = __pyx_t_1;
__pyx_t_1 = 0;
goto __pyx_L0;
</pre></div></body></html>
</pre>
</div>
</body>
</html>

@ -2,5 +2,5 @@ from distutils.core import setup
from Cython.Build import cythonize
setup(
ext_modules = cythonize("lcn.pyx",annotate=True)
ext_modules=cythonize("lcn.pyx", annotate=True)
)

@ -5,23 +5,23 @@ from scipy import misc
# load and convert to float
img = misc.imread('img.png')
img = img.astype(np.float32)/255.0
img = img.astype(np.float32) / 255.0
# normalize
img_lcn, img_std = lcn.normalize(img,5,0.05)
img_lcn, img_std = lcn.normalize(img, 5, 0.05)
# normalize to reasonable range between 0 and 1
#img_lcn = img_lcn/3.0
#img_lcn = np.maximum(img_lcn,0.0)
#img_lcn = np.minimum(img_lcn,1.0)
# img_lcn = img_lcn/3.0
# img_lcn = np.maximum(img_lcn,0.0)
# img_lcn = np.minimum(img_lcn,1.0)
# save to file
#misc.imsave('lcn2.png',img_lcn)
# misc.imsave('lcn2.png',img_lcn)
print ("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \
(img.shape[0], img.shape[1], img.dtype, img.min(), img.max()))
print ("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \
(img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max()))
print("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \
(img.shape[0], img.shape[1], img.dtype, img.min(), img.max()))
print("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \
(img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max()))
# plot original image
plt.figure(1)
@ -34,14 +34,14 @@ plt.tight_layout()
plt.figure(2)
img_lcn_plot = plt.imshow(img_lcn)
img_lcn_plot.set_cmap('gray')
#plt.clim(0, 1) # fix range
# plt.clim(0, 1) # fix range
plt.tight_layout()
# plot stddev image
plt.figure(3)
img_std_plot = plt.imshow(img_std)
img_std_plot.set_cmap('gray')
#plt.clim(0, 0.1) # fix range
# plt.clim(0, 0.1) # fix range
plt.tight_layout()
plt.show()

@ -11,20 +11,19 @@ import dataset
def get_data(n, row_from, row_to, train):
imsizes = [(256,384)]
imsizes = [(256, 384)]
focal_lengths = [160]
dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)
ims = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.uint8)
disps = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.float32)
ims = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.uint8)
disps = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.float32)
for idx in range(n):
print(f'load sample {idx} train={train}')
sample = dset[idx]
ims[idx] = (sample['im0'][0,row_from:row_to] * 255).astype(np.uint8)
disps[idx] = sample['disp0'][0,row_from:row_to]
ims[idx] = (sample['im0'][0, row_from:row_to] * 255).astype(np.uint8)
disps[idx] = sample['disp0'][0, row_from:row_to]
return ims, disps
params = hd.TrainParams(
n_trees=4,
max_tree_depth=,
@ -45,16 +44,18 @@ n_test_samples = 32
train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)
for tree_depth in [8,10,12,14,16]:
for tree_depth in [8, 10, 12, 14, 16]:
depth_switch = tree_depth - 4
prefix = f'td{tree_depth}_ds{depth_switch}'
prefix = Path(f'./forests/{prefix}/')
prefix.mkdir(parents=True, exist_ok=True)
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
np.save(str(prefix / 'ta.npy'), test_disps)
np.save(str(prefix / 'es.npy'), es)

@ -8,7 +8,6 @@ import os
this_dir = os.path.dirname(__file__)
extra_compile_args = ['-O3', '-std=c++11']
extra_link_args = []
@ -23,7 +22,7 @@ libraries = ['m']
setup(
name="hyperdepth",
cmdclass= {'build_ext': build_ext},
cmdclass={'build_ext': build_ext},
ext_modules=[
Extension('hyperdepth',
sources,
@ -39,7 +38,3 @@ setup(
)
]
)

@ -6,10 +6,13 @@ orig = cv2.imread('disp_orig.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
plt.figure()
plt.subplot(2,2,1); plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,2); plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,3); plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,4); plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
plt.subplot(2, 2, 1);
plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 2);
plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 3);
plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 4);
plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
plt.show()

@ -12,9 +12,12 @@ import torchext
from model import networks
from data import dataset
class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
self.ms = args.ms
self.pattern_path = args.pattern_path
@ -22,9 +25,9 @@ class Worker(torchext.Worker):
self.dp_weight = args.dp_weight
self.data_type = args.data_type
self.imsizes = [(480,640)]
self.imsizes = [(488, 648)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2)))
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp:
config = json.load(fp)
@ -32,11 +35,11 @@ class Worker(torchext.Worker):
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2**10:]
self.test_paths = sample_paths[:2**8]
self.train_paths = sample_paths[2 ** 10:]
self.test_paths = sample_paths[:2 ** 8]
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2**8
self.train_edge = len(self.train_paths) - 2 ** 8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
@ -44,19 +47,21 @@ class Worker(torchext.Worker):
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1
self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0]-2*13
self.eval_w = self.imsizes[0][1]-13-140
self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=1)
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=1)
return train_set
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
# initialize photometric loss modules according to image sizes
@ -66,9 +71,9 @@ class Worker(torchext.Worker):
pat = torch.from_numpy(pat[None][None].astype(np.float32))
pat = pat.to(self.train_device)
self.lcn_in = self.lcn_in.to(self.train_device)
pat,_ = self.lcn_in(pat)
pat, _ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
self.losses.append( networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) )
self.losses.append(networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat))
return test_sets
@ -84,10 +89,10 @@ class Worker(torchext.Worker):
# concatenate the normalized IR input and the original IR image
if 'im' in key and 'blend' not in key:
im = self.data[key]
im_lcn,im_std = self.lcn_in(im)
im_lcn, im_std = self.lcn_in(im)
im_cat = torch.cat((im_lcn, im), dim=1)
key_std = key.replace('im','std')
self.data[key]=im_cat
key_std = key.replace('im', 'std')
self.data[key] = im_cat
self.data[key_std] = im_std.to(device).detach()
def net_forward(self, net, train):
@ -96,35 +101,35 @@ class Worker(torchext.Worker):
def loss_forward(self, out, train):
out, edge = out
if not(isinstance(out, tuple) or isinstance(out, list)):
if not (isinstance(out, tuple) or isinstance(out, list)):
out = [out]
if not(isinstance(edge, tuple) or isinstance(edge, list)):
if not (isinstance(edge, tuple) or isinstance(edge, list)):
edge = [edge]
vals = []
# apply photometric loss
for s,l,o in zip(itertools.count(), self.losses, out):
val, pattern_proj = l(o, self.data[f'im{s}'][:,0:1,...], self.data[f'std{s}'])
for s, l, o in zip(itertools.count(), self.losses, out):
val, pattern_proj = l(o, self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
if s == 0:
self.pattern_proj = pattern_proj.detach()
vals.append(val)
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1-torch.sigmoid(edge[0])
edge0 = 1 - torch.sigmoid(edge[0])
val = self.disparity_loss(out[0], edge0)
if self.dp_weight>0:
if self.dp_weight > 0:
vals.append(val * self.dp_weight)
# apply edge loss on a subset of training samples
for s,e in zip(itertools.count(), edge):
for s, e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}']<0.2
grad = self.data[f'grad{s}'] < 0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids>self.train_edge
if mask.sum()>0:
mask = ids > self.train_edge
if mask.sum() > 0:
val = self.edge_loss(e[mask], grad[mask])
else:
val = torch.zeros_like(vals[0])
@ -138,13 +143,13 @@ class Worker(torchext.Worker):
def numpy_in_out(self, output):
output, edge = output
if not(isinstance(output, tuple) or isinstance(output, list)):
if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,0:1,...].detach().to('cpu').numpy()
im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy()
ma = gt>0
ma = gt > 0
return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma):
@ -154,49 +159,82 @@ class Worker(torchext.Worker):
diff = np.abs(es - gt)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2*(vmax-vmin)
vmax = vmax + 0.2*(vmax-vmin)
vmin = vmin - 0.2 * (vmax - vmin)
vmax = vmax + 0.2 * (vmax - vmin)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0]
pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0]
pattern_diff = np.abs(im_orig - pattern_proj)
fig = plt.figure(figsize=(16,16))
fig = plt.figure(figsize=(16, 16))
es_ = co.cmap.color_depth_map(es, scale=vmax)
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
diff_ = co.cmap.color_error_image(diff, BGR=True)
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3,3,1); plt.imshow(es_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
ax = plt.subplot(3,3,2); plt.imshow(gt_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
ax = plt.subplot(3,3,3); plt.imshow(diff_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Err. {diff.mean():.5f}')
ax = plt.subplot(3, 3, 1)
plt.imshow(es_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
ax = plt.subplot(3, 3, 2)
plt.imshow(gt_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
ax = plt.subplot(3, 3, 3)
plt.imshow(diff_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Err. {diff.mean():.5f}')
# plot edges
edge = self.edge.to('cpu').numpy()[0,0]
edge_gt = self.edge_gt.to('cpu').numpy()[0,0]
edge = self.edge.to('cpu').numpy()[0, 0]
edge_gt = self.edge_gt.to('cpu').numpy()[0, 0]
edge_err = np.abs(edge - edge_gt)
ax = plt.subplot(3,3,4); plt.imshow(edge, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
ax = plt.subplot(3,3,5); plt.imshow(edge_gt, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
ax = plt.subplot(3,3,6); plt.imshow(edge_err, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
ax = plt.subplot(3, 3, 4);
plt.imshow(edge, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
ax = plt.subplot(3, 3, 5);
plt.imshow(edge_gt, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
ax = plt.subplot(3, 3, 6);
plt.imshow(edge_err, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
# plot normalized IR input and warped pattern
ax = plt.subplot(3,3,7); plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
ax = plt.subplot(3,3,8); plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
im_std = self.data['std0'].to('cpu').numpy()[0,0]
ax = plt.subplot(3,3,9); plt.imshow(im_std, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
ax = plt.subplot(3, 3, 7);
plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
ax = plt.subplot(3, 3, 8);
plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
im_std = self.data['std0'].to('cpu').numpy()[0, 0]
ax = plt.subplot(3, 3, 9);
plt.imshow(im_std, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
@ -209,12 +247,12 @@ class Worker(torchext.Worker):
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1,1)
gt = gt.reshape(-1,1)
es = es.reshape(-1, 1)
gt = gt.reshape(-1, 1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
@ -225,13 +263,12 @@ class Worker(torchext.Worker):
def crop_output(self, es, gt, im, ma):
bs = es.shape[0]
es = np.reshape(es[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
es = np.reshape(es[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
if __name__ == '__main__':
pass

@ -12,9 +12,12 @@ import torchext
from model import networks
from data import dataset
class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
self.ms = args.ms
self.pattern_path = args.pattern_path
@ -23,11 +26,11 @@ class Worker(torchext.Worker):
self.ge_weight = args.ge_weight
self.track_length = args.track_length
self.data_type = args.data_type
assert(self.track_length>1)
assert (self.track_length > 1)
self.imsizes = [(480,640)]
self.imsizes = [(480, 640)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2)))
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp:
config = json.load(fp)
@ -35,11 +38,11 @@ class Worker(torchext.Worker):
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2**10:]
self.test_paths = sample_paths[:2**8]
self.train_paths = sample_paths[2 ** 10:]
self.test_paths = sample_paths[:2 ** 8]
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2**8
self.train_edge = len(self.train_paths) - 2 ** 8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
@ -47,19 +50,20 @@ class Worker(torchext.Worker):
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1
self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0]-2*13
self.eval_w = self.imsizes[0][1]-13-140
self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length)
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=self.track_length)
return train_set
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
self.ph_losses = []
@ -72,9 +76,9 @@ class Worker(torchext.Worker):
pat = test_set.patterns[sidx]
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
pat,_ = self.lcn_in(pat)
pat, _ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat)
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat)
K = test_set.getK(sidx)
Ki = np.linalg.inv(K)
@ -82,11 +86,11 @@ class Worker(torchext.Worker):
Ki = torch.from_numpy(Ki)
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
self.ph_losses.append( ph_loss )
self.ge_losses.append( ge_loss )
self.ph_losses.append(ph_loss)
self.ge_losses.append(ge_loss)
d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
self.d2ds.append( d2d )
self.d2ds.append(d2d)
return test_sets
@ -99,9 +103,9 @@ class Worker(torchext.Worker):
# batch_size x track_length x ...
# to
# track_length x batch_size x ...
if len(val.shape)>2:
if len(val.shape) > 2:
if train:
val = val.transpose(0,1)
val = val.transpose(0, 1)
else:
val = val.unsqueeze(0)
grad = 'im' in key and requires_grad
@ -110,8 +114,8 @@ class Worker(torchext.Worker):
im = self.data[key]
tl = im.shape[0]
bs = im.shape[1]
im_lcn,im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
key_std = key.replace('im','std')
im_lcn, im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
key_std = key.replace('im', 'std')
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
self.data[key] = im_cat
@ -122,7 +126,7 @@ class Worker(torchext.Worker):
bs = im0.shape[1]
im0 = im0.view(-1, *im0.shape[2:])
out, edge = net(im0)
if not(isinstance(out, tuple) or isinstance(out, list)):
if not (isinstance(out, tuple) or isinstance(out, list)):
out = out.view(tl, bs, *out.shape[1:])
edge = edge.view(tl, bs, *out.shape[1:])
else:
@ -132,42 +136,42 @@ class Worker(torchext.Worker):
def loss_forward(self, out, train):
out, edge = out
if not(isinstance(out, tuple) or isinstance(out, list)):
if not (isinstance(out, tuple) or isinstance(out, list)):
out = [out]
vals = []
diffs = []
# apply photometric loss
for s,l,o in zip(itertools.count(), self.ph_losses, out):
for s, l, o in zip(itertools.count(), self.ph_losses, out):
im = self.data[f'im{s}']
im = im.view(-1, *im.shape[2:])
o = o.view(-1, *o.shape[2:])
std = self.data[f'std{s}']
std = std.view(-1, *std.shape[2:])
val, pattern_proj = l(o, im[:,0:1,...], std)
val, pattern_proj = l(o, im[:, 0:1, ...], std)
vals.append(val)
if s == 0:
self.pattern_proj = pattern_proj.detach()
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1-torch.sigmoid(edge[0])
edge0 = 1 - torch.sigmoid(edge[0])
edge0 = edge0.view(-1, *edge0.shape[2:])
out0 = out[0].view(-1, *out[0].shape[2:])
val = self.disparity_loss(out0, edge0)
if self.dp_weight>0:
if self.dp_weight > 0:
vals.append(val * self.dp_weight)
# apply edge loss on a subset of training samples
for s,e in zip(itertools.count(), edge):
for s, e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}']<0.2
grad = self.data[f'grad{s}'] < 0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids>self.train_edge
if mask.sum()>0:
e = e[:,mask,:]
grad = grad[:,mask,:]
mask = ids > self.train_edge
if mask.sum() > 0:
e = e[:, mask, :]
grad = grad[:, mask, :]
e = e.view(-1, *e.shape[2:])
grad = grad.view(-1, *grad.shape[2:])
val = self.edge_loss(e, grad)
@ -181,14 +185,14 @@ class Worker(torchext.Worker):
# apply geometric loss
R = self.data['R']
t = self.data['t']
ge_num = self.track_length * (self.track_length-1) / 2
ge_num = self.track_length * (self.track_length - 1) / 2
for sidx in range(len(out)):
d2d = self.d2ds[sidx]
depth = d2d(out[sidx])
ge_loss = self.ge_losses[sidx]
imsize = self.imsizes[sidx]
for tidx0 in range(depth.shape[0]):
for tidx1 in range(tidx0+1, depth.shape[0]):
for tidx1 in range(tidx0 + 1, depth.shape[0]):
depth0 = depth[tidx0]
R0 = R[tidx0]
t0 = t[tidx0]
@ -203,12 +207,12 @@ class Worker(torchext.Worker):
def numpy_in_out(self, output):
output, edge = output
if not(isinstance(output, tuple) or isinstance(output, list)):
if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy()
ma = gt>0
im = self.data['im0'][:, :, 0:1, ...].detach().to('cpu').numpy()
ma = gt > 0
return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma):
@ -218,36 +222,68 @@ class Worker(torchext.Worker):
diff = np.abs(es - gt)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2*(vmax-vmin)
vmax = vmax + 0.2*(vmax-vmin)
vmin = vmin - 0.2 * (vmax - vmin)
vmax = vmax + 0.2 * (vmax - vmin)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0,0]
pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0, 0]
pattern_diff = np.abs(im_orig - pattern_proj)
fig = plt.figure(figsize=(16,16))
fig = plt.figure(figsize=(16, 16))
es0 = co.cmap.color_depth_map(es[0], scale=vmax)
gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
diff0 = co.cmap.color_error_image(diff[0], BGR=True)
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3,3,1); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
ax = plt.subplot(3,3,2); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
ax = plt.subplot(3,3,3); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
ax = plt.subplot(3, 3, 1);
plt.imshow(es0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
ax = plt.subplot(3, 3, 2);
plt.imshow(gt0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
ax = plt.subplot(3, 3, 3);
plt.imshow(diff0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
# plot disparities of the second frame in the track if exists
if es.shape[0]>=2:
if es.shape[0] >= 2:
es1 = co.cmap.color_depth_map(es[1], scale=vmax)
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
diff1 = co.cmap.color_error_image(diff[1], BGR=True)
ax = plt.subplot(3,3,4); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
ax = plt.subplot(3,3,5); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
ax = plt.subplot(3,3,6); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
ax = plt.subplot(3, 3, 4);
plt.imshow(es1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
ax = plt.subplot(3, 3, 5);
plt.imshow(gt1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
ax = plt.subplot(3, 3, 6);
plt.imshow(diff1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
# plot normalized IR inputs
ax = plt.subplot(3,3,7); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
if es.shape[0]>=2:
ax = plt.subplot(3,3,8); plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
ax = plt.subplot(3, 3, 7);
plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
if es.shape[0] >= 2:
ax = plt.subplot(3, 3, 8);
plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
plt.tight_layout()
plt.savefig(str(out_path))
@ -257,8 +293,8 @@ class Worker(torchext.Worker):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
masks = [ m.detach().to('cpu').numpy() for m in masks ]
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0])
masks = [m.detach().to('cpu').numpy() for m in masks]
self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
@ -271,12 +307,12 @@ class Worker(torchext.Worker):
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0])
self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1,1)
gt = gt.reshape(-1,1)
es = es.reshape(-1, 1)
gt = gt.reshape(-1, 1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
@ -288,11 +324,12 @@ class Worker(torchext.Worker):
def crop_output(self, es, gt, im, ma):
tl = es.shape[0]
bs = es.shape[1]
es = np.reshape(es[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
es = np.reshape(es[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
if __name__ == '__main__':
pass

@ -44,7 +44,7 @@ class PosOutput(TimedModule):
def tforward(self, x):
if self.u_pos is None:
self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1,1,1,-1)
self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1, 1, 1, -1)
self.u_pos = self.u_pos.to(x.device)
pos = self.layer(x)
disp = self.u_pos - pos
@ -61,6 +61,7 @@ class OutputLayerFactory(object):
pos: estimate the absolute location
pos_row: independently estimate the absolute location per row
'''
def __init__(self, type='disp', params={}):
self.type = type
self.params = params
@ -98,7 +99,7 @@ class SigmoidAffine(TimedModule):
self.offset = offset
def tforward(self, x):
return torch.sigmoid(x/self.gamma - self.offset) * self.alpha + self.beta
return torch.sigmoid(x / self.gamma - self.offset) * self.alpha + self.beta
class MultiLinear(TimedModule):
@ -110,26 +111,27 @@ class MultiLinear(TimedModule):
self.mods.append(torch.nn.Linear(channels_in, channels_out))
def tforward(self, x):
x = x.permute(2,0,3,1) # BxCxHxW => HxBxWxC
x = x.permute(2, 0, 3, 1) # BxCxHxW => HxBxWxC
y = x.new_empty(*x.shape[:-1], self.channels_out)
for hidx in range(x.shape[0]):
y[hidx] = self.mods[hidx](x[hidx])
y = y.permute(1,3,0,2) # HxBxWxC => BxCxHxW
y = y.permute(1, 3, 0, 2) # HxBxWxC => BxCxHxW
return y
class DispNetS(TimedModule):
'''
Disparity Decoder based on DispNetS
'''
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, channel_multiplier=1):
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False,
channel_multiplier=1):
super(DispNetS, self).__init__(mod_name='DispNetS')
self.output_ms = output_ms
self.coordconv = coordconv
conv_planes = channel_multiplier * np.array( [32, 64, 128, 256, 512, 512, 512] )
conv_planes = channel_multiplier * np.array([32, 64, 128, 256, 512, 512, 512])
self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
@ -138,7 +140,7 @@ class DispNetS(TimedModule):
self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
upconv_planes = channel_multiplier * np.array( [512, 512, 256, 128, 64, 32, 16] )
upconv_planes = channel_multiplier * np.array([512, 512, 256, 128, 64, 32, 16])
self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0])
self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1])
self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2])
@ -166,7 +168,6 @@ class DispNetS(TimedModule):
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
def init_weights(self):
for m in self.modules():
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
@ -176,13 +177,15 @@ class DispNetS(TimedModule):
def downsample_conv(self, in_planes, out_planes, kernel_size=3):
if self.coordconv:
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
padding=(kernel_size - 1) // 2)
else:
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
padding=(kernel_size - 1) // 2)
return torch.nn.Sequential(
conv,
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2),
torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size - 1) // 2),
torch.nn.ReLU(inplace=True)
)
@ -199,7 +202,7 @@ class DispNetS(TimedModule):
)
def crop_like(self, input, ref):
assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
return input[:, :, :ref.size(2), :ref.size(3)]
def tforward(self, x):
@ -229,19 +232,22 @@ class DispNetS(TimedModule):
disp4 = self.predict_disp4(out_iconv4)
out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
disp4_up = self.crop_like(torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
disp4_up = self.crop_like(
torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
out_iconv3 = self.iconv3(concat3)
disp3 = self.predict_disp3(out_iconv3)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
disp3_up = self.crop_like(
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2)
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
disp2_up = self.crop_like(
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1)
@ -256,6 +262,7 @@ class DispNetShallow(DispNetS):
'''
Edge Decoder based on DispNetS with fewer layers
'''
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
self.mod_name = 'DispNetShallow'
@ -274,13 +281,15 @@ class DispNetShallow(DispNetS):
disp3 = self.predict_disp3(out_iconv3)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
disp3_up = self.crop_like(
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2)
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
disp2_up = self.crop_like(
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1)
@ -295,13 +304,16 @@ class DispEdgeDecoders(TimedModule):
'''
Disparity Decoder and Edge Decoder
'''
def __init__(self, *args, max_disp=128, **kwargs):
super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')
output_facs = [OutputLayerFactory( type='disp', params={ 'alpha': max_disp/(2**s), 'beta': 0, 'gamma': 1, 'offset': 3}) for s in range(4)]
output_facs = [
OutputLayerFactory(type='disp', params={'alpha': max_disp / (2 ** s), 'beta': 0, 'gamma': 1, 'offset': 3})
for s in range(4)]
self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)
output_facs = [OutputLayerFactory( type='linear' ) for s in range(4)]
output_facs = [OutputLayerFactory(type='linear') for s in range(4)]
self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs)
def tforward(self, x):
@ -328,7 +340,7 @@ class PosToDepth(DispToDepth):
self.im_height = im_height
self.im_width = im_width
self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1,1,1,-1)
self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1, 1, 1, -1)
def tforward(self, pos):
self.u_pos = self.u_pos.to(pos.device)
@ -336,11 +348,11 @@ class PosToDepth(DispToDepth):
return super().forward(disp)
class RectifiedPatternSimilarityLoss(TimedModule):
'''
Photometric Loss
'''
def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
super().__init__(mod_name='RectifiedPatternSimilarityLoss')
self.im_height = im_height
@ -348,8 +360,8 @@ class RectifiedPatternSimilarityLoss(TimedModule):
self.pattern = pattern.mean(dim=1, keepdim=True).contiguous()
u, v = np.meshgrid(range(im_width), range(im_height))
uv0 = np.stack((u,v), axis=2).reshape(-1,1)
uv0 = uv0.astype(np.float32).reshape(1,-1,2)
uv0 = np.stack((u, v), axis=2).reshape(-1, 1)
uv0 = uv0.astype(np.float32).reshape(1, -1, 2)
self.uv0 = torch.from_numpy(uv0)
self.loss_type = loss_type
@ -361,82 +373,84 @@ class RectifiedPatternSimilarityLoss(TimedModule):
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
uv1 = torch.empty_like(uv0)
uv1[...,0] = uv0[...,0] - disp0.contiguous().view(disp0.shape[0],-1)
uv1[...,1] = uv0[...,1]
uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1)
uv1[..., 1] = uv0[..., 1]
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border')
mask = torch.ones_like(im)
if std is not None:
mask = mask*std
mask = mask * std
diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
val = (mask*diff).sum() / mask.sum()
val = (mask * diff).sum() / mask.sum()
return val, pattern_proj
class DisparityLoss(TimedModule):
'''
Disparity Loss
'''
def __init__(self):
super().__init__(mod_name='DisparityLoss')
self.sobel = SobelFilter(norm=False)
#if not edge_gt:
self.b0=0.0503428816795
self.b1=1.07274045944
#else:
# if not edge_gt:
self.b0 = 0.0503428816795
self.b1 = 1.07274045944
# else:
# self.b0=0.0587115108967
# self.b1=1.51931190491
def tforward(self, disp, edge=None):
self.sobel=self.sobel.to(disp.device)
self.sobel = self.sobel.to(disp.device)
if edge is not None:
grad = self.sobel(disp)
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8)
pdf = (1-edge)/self.b0 * torch.exp(-torch.abs(grad)/self.b0) + \
edge/self.b1 * torch.exp(-torch.abs(grad)/self.b1)
grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
pdf = (1 - edge) / self.b0 * torch.exp(-torch.abs(grad) / self.b0) + \
edge / self.b1 * torch.exp(-torch.abs(grad) / self.b1)
val = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
else:
# on qifeng's data we don't have ambient info
# therefore we supress edge everywhere
grad = self.sobel(disp)
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8)
grad= torch.clamp(grad, 0, 1.0)
grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
grad = torch.clamp(grad, 0, 1.0)
val = torch.mean(grad)
return val
class ProjectionBaseLoss(TimedModule):
'''
Base module of the Geometric Loss
'''
def __init__(self, K, Ki, im_height, im_width):
super().__init__(mod_name='ProjectionBaseLoss')
self.K = K.view(-1,3,3)
self.K = K.view(-1, 3, 3)
self.im_height = im_height
self.im_width = im_width
u, v = np.meshgrid(range(im_width), range(im_height))
uv = np.stack((u,v,np.ones_like(u)), axis=2).reshape(-1,3)
uv = np.stack((u, v, np.ones_like(u)), axis=2).reshape(-1, 3)
ray = uv @ Ki.numpy().T
ray = ray.reshape(1,-1,3).astype(np.float32)
ray = ray.reshape(1, -1, 3).astype(np.float32)
self.ray = torch.from_numpy(ray)
def transform(self, xyz, R=None, t=None):
if t is not None:
bs = xyz.shape[0]
xyz = xyz - t.reshape(bs,1,3)
xyz = xyz - t.reshape(bs, 1, 3)
if R is not None:
xyz = torch.bmm(xyz, R)
return xyz
@ -445,7 +459,7 @@ class ProjectionBaseLoss(TimedModule):
self.ray = self.ray.to(depth.device)
bs = depth.shape[0]
xyz = depth.reshape(bs,-1,1) * self.ray
xyz = depth.reshape(bs, -1, 1) * self.ray
xyz = self.transform(xyz, R, t)
return xyz
@ -453,19 +467,18 @@ class ProjectionBaseLoss(TimedModule):
self.K = self.K.to(xyz.device)
bs = xyz.shape[0]
xyz = torch.bmm(xyz, R.transpose(1,2))
xyz = xyz + t.reshape(bs,1,3)
xyz = torch.bmm(xyz, R.transpose(1, 2))
xyz = xyz + t.reshape(bs, 1, 3)
Kt = self.K.transpose(1,2).expand(bs,-1,-1)
Kt = self.K.transpose(1, 2).expand(bs, -1, -1)
uv = torch.bmm(xyz, Kt)
d = uv[:,:,2:3]
d = uv[:, :, 2:3]
# avoid division by zero
uv = uv[:,:,:2] / (torch.nn.functional.relu(d) + 1e-12)
uv = uv[:, :, :2] / (torch.nn.functional.relu(d) + 1e-12)
return uv, d
def tforward(self, depth0, R0, t0, R1, t1):
xyz = self.unproject(depth0, R0, t0)
return self.project(xyz, R1, t1)
@ -475,6 +488,7 @@ class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
'''
Geometric Loss
'''
def __init__(self, *args, clamp=-1):
super().__init__(*args)
self.mod_name = 'ProjectionDepthSimilarityLoss'
@ -483,8 +497,8 @@ class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
def fwd(self, depth0, depth1, R0, t0, R1, t1):
uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border')
@ -500,21 +514,21 @@ class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
def tforward(self, depth0, depth1, R0, t0, R1, t1):
l0 = self.fwd(depth0, depth1, R0, t0, R1, t1)
l1 = self.fwd(depth1, depth0, R1, t1, R0, t0)
return l0+l1
return l0 + l1
class LCN(TimedModule):
'''
Local Contract Normalization
'''
def __init__(self, radius, epsilon):
super().__init__(mod_name='LCN')
self.box_conv = torch.nn.Sequential(
torch.nn.ReflectionPad2d(radius),
torch.nn.Conv2d(1, 1, kernel_size=2*radius+1, bias=False)
torch.nn.Conv2d(1, 1, kernel_size=2 * radius + 1, bias=False)
)
self.box_conv[1].weight.requires_grad=False
self.box_conv[1].weight.requires_grad = False
self.box_conv[1].weight.fill_(1.)
self.epsilon = epsilon
@ -523,44 +537,43 @@ class LCN(TimedModule):
def tforward(self, data):
boxs = self.box_conv(data)
avgs = boxs / (2*self.radius+1)**2
boxs_n2 = boxs**2
boxs_2n = self.box_conv(data**2)
avgs = boxs / (2 * self.radius + 1) ** 2
boxs_n2 = boxs ** 2
boxs_2n = self.box_conv(data ** 2)
stds = torch.sqrt(boxs_2n / (2*self.radius+1)**2 - avgs**2 + 1e-6)
stds = torch.sqrt(boxs_2n / (2 * self.radius + 1) ** 2 - avgs ** 2 + 1e-6)
stds = stds + self.epsilon
return (data - avgs) / stds, stds
class SobelFilter(TimedModule):
'''
Sobel Filter
'''
def __init__(self, norm=False):
super(SobelFilter, self).__init__(mod_name='SobelFilter')
kx = np.array([[-5, -4, 0, 4, 5],
[-8, -10, 0, 10, 8],
[-10, -20, 0, 20, 10],
[-8, -10, 0, 10, 8],
[-5, -4, 0, 4, 5]])/240.0
ky = kx.copy().transpose(1,0)
[-5, -4, 0, 4, 5]]) / 240.0
ky = kx.copy().transpose(1, 0)
self.conv_x=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_x.weight=torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))
self.conv_x = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_x.weight = torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))
self.conv_y=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_y.weight=torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))
self.conv_y = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_y.weight = torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))
self.norm=norm
self.norm = norm
def tforward(self,x):
x = F.pad(x, (2,2,2,2), "replicate")
def tforward(self, x):
x = F.pad(x, (2, 2, 2, 2), "replicate")
gx = self.conv_x(x)
gy = self.conv_y(x)
if self.norm:
return torch.sqrt(gx**2 + gy**2 + 1e-8)
return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
else:
return torch.cat((gx, gy), dim=1)

@ -6,7 +6,9 @@ This repository contains the code for the paper
**[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)**
<br>
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/), [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/), and [Andreas Geiger](http://www.cvlibs.net/)
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/)
, [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/),
and [Andreas Geiger](http://www.cvlibs.net/)
<br>
[CVPR 2019](http://cvpr2019.thecvf.com/)
@ -24,40 +26,45 @@ If you find this code useful for your research, please cite
}
```
## Dependencies
The network training/evaluation code is based on `Pytorch`.
```
PyTorch>=1.1
Cuda>=10.0
```
Updated on 07.06.2021: The code is now compatible with the latest Pytorch version (1.8).
The other python packages can be installed with `anaconda`:
```
conda install --file requirements.txt
```
### Structured Light Renderer
To train and evaluate our method in a controlled setting, we implemented an structured light renderer.
It can be used to render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable projector location.
To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`.
Afterwards, the renderer can be build by running `make` within the `renderer` directory.
To train and evaluate our method in a controlled setting, we implemented an structured light renderer. It can be used to
render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable
projector location. To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`. Afterwards,
the renderer can be build by running `make` within the `renderer` directory.
### PyTorch Extensions
The network training/evaluation code is based on `PyTorch`.
We implemented some custom layers that need to be built in the `torchext` directory.
Simply change into this directory and run
The network training/evaluation code is based on `PyTorch`. We implemented some custom layers that need to be built in
the `torchext` directory. Simply change into this directory and run
```
python setup.py build_ext --inplace
```
### Baseline HyperDepth
As baseline we partially re-implemented the random forest based method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf).
The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`.
To build it change into the directory and run
As baseline we partially re-implemented the random forest based
method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf)
. The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`
. To build it change into the directory and run
```
python setup.py build_ext --inplace
@ -65,42 +72,59 @@ python setup.py build_ext --inplace
## Running
### Creating Synthetic Data
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by running
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and
correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by
running
```
./create_syn_data.sh
```
If you are only interested in evaluating our pre-trained model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a validation set that contains a small amount of images.
If you are only interested in evaluating our pre-trained
model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a
validation set that contains a small amount of images.
### Training Network
As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train the network on synthetic data for the first stage run
As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train
the network on synthetic data for the first stage run
```
python train_val.py
```
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by running
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by
running
```
python train_val.py --loss phge
```
### Evaluating Network
To evaluate a specific checkpoint, e.g. the 50th epoch, one can run
```
python train_val.py --cmd retest --epoch 50
```
### Evaluating a Pre-trained Model
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and
changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
```
mkdir -p output
mkdir -p output/exp_syn
wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params
python train_val.py --cmd retest --epoch 99
```
You can also download our validation set from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
You can also download our validation set
from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
## Acknowledgement
This work was supported by the Intel Network on Intelligent Systems.

@ -31,7 +31,7 @@ libraries.append(cuda_lib)
setup(
name="cyrender",
cmdclass= {'build_ext': build_ext},
cmdclass={'build_ext': build_ext},
ext_modules=[
Extension('cyrender',
sources,

@ -2,18 +2,19 @@ import torch
import torch.utils.data
import numpy as np
class TestSet(object):
def __init__(self, name, dset, test_frequency=1):
self.name = name
self.dset = dset
self.test_frequency = test_frequency
class TestSets(list):
def append(self, name, dset, test_frequency=1):
super().append(TestSet(name, dset, test_frequency))
class MultiDataset(torch.utils.data.Dataset):
def __init__(self, *datasets):
self.current_epoch = 0
@ -46,7 +47,6 @@ class MultiDataset(torch.utils.data.Dataset):
return self.datasets[didx][sidx]
class BaseDataset(torch.utils.data.Dataset):
def __init__(self, train=True, fix_seed_per_epoch=False):
self.current_epoch = 0

@ -2,6 +2,7 @@ import torch
from . import ext_cpu
from . import ext_cuda
class NNFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, in0, in1):
@ -16,6 +17,7 @@ class NNFunction(torch.autograd.Function):
def backward(ctx, grad_out):
return None, None
def nn(in0, in1):
return NNFunction.apply(in0, in1)
@ -34,9 +36,11 @@ class CrossCheckFunction(torch.autograd.Function):
def backward(ctx, grad_out):
return None, None
def crosscheck(in0, in1):
return CrossCheckFunction.apply(in0, in1)
class ProjNNFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz0, xyz1, K, patch_size):
@ -51,11 +55,11 @@ class ProjNNFunction(torch.autograd.Function):
def backward(ctx, grad_out):
return None, None, None, None
def proj_nn(xyz0, xyz1, K, patch_size):
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
class XCorrVolFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, in0, in1, n_disps, block_size):
@ -70,12 +74,11 @@ class XCorrVolFunction(torch.autograd.Function):
def backward(ctx, grad_out):
return None, None, None, None
def xcorrvol(in0, in1, n_disps, block_size):
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
class PhotometricLossFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, es, ta, block_size, type, eps):
@ -103,6 +106,7 @@ class PhotometricLossFunction(torch.autograd.Function):
grad_es = ext_cpu.photometric_loss_backward(*args)
return grad_es, None, None, None, None
def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
type = type.lower()
if type == 'mse':
@ -117,17 +121,18 @@ def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
raise Exception('invalid loss type')
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
type = type.lower()
p = block_size // 2
es_pad = torch.nn.functional.pad(es, (p,p,p,p), mode='replicate')
ta_pad = torch.nn.functional.pad(ta, (p,p,p,p), mode='replicate')
es_pad = torch.nn.functional.pad(es, (p, p, p, p), mode='replicate')
ta_pad = torch.nn.functional.pad(ta, (p, p, p, p), mode='replicate')
es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
if type == 'mse':
ref = (es_uf - ta_uf)**2
ref = (es_uf - ta_uf) ** 2
elif type == 'sad':
ref = torch.abs(es_uf - ta_uf)
elif type == 'census_mse' or type == 'census_sad':
@ -143,5 +148,5 @@ def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
else:
raise Exception('invalid loss type')
ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
ref = torch.sum(ref, dim=1, keepdim=True) / block_size**2
ref = torch.sum(ref, dim=1, keepdim=True) / block_size ** 2
return ref

@ -4,11 +4,13 @@ import numpy as np
from .functions import *
class CoordConv2d(torch.nn.Module):
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
super().__init__()
self.conv = torch.nn.Conv2d(channels_in+2, channels_out, kernel_size=kernel_size, padding=padding, stride=stride)
self.conv = torch.nn.Conv2d(channels_in + 2, channels_out, kernel_size=kernel_size, padding=padding,
stride=stride)
self.uv = None
@ -19,7 +21,7 @@ class CoordConv2d(torch.nn.Module):
u = 2 * u / (width - 1) - 1
v = 2 * v / (height - 1) - 1
uv = np.stack((u, v)).reshape(1, 2, height, width)
self.uv = torch.from_numpy( uv.astype(np.float32) )
self.uv = torch.from_numpy(uv.astype(np.float32))
self.uv = self.uv.to(x.device)
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
xuv = torch.cat((x, uv), dim=1)

@ -14,7 +14,8 @@ setup(
name='ext',
ext_modules=[
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'], extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'],
extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
],
cmdclass={'build_ext': BuildExtension},
include_dirs=include_dirs

@ -39,9 +39,10 @@ class StopWatch(object):
return ret
def __repr__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
def __str__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
class ETA(object):
@ -77,8 +78,10 @@ class ETA(object):
def get_remaining_time_str(self):
return self.format_time(self.get_remaining_time())
class Worker(object):
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
self.out_root = Path(out_root)
self.experiment_name = experiment_name
self.epochs = epochs
@ -91,7 +94,7 @@ class Worker(object):
self.test_device = test_device
self.max_train_iter = max_train_iter
self.errs_list=[]
self.errs_list = []
self.setup_experiment()
@ -103,17 +106,17 @@ class Worker(object):
logging.basicConfig(
level=logging.INFO,
handlers=[
logging.FileHandler( str(self.exp_out_root / 'train.log') ),
logging.FileHandler(str(self.exp_out_root / 'train.log')),
logging.StreamHandler()
],
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
)
logging.info('='*80)
logging.info('=' * 80)
logging.info(f'Start of experiment: {self.experiment_name}')
logging.info(socket.gethostname())
self.log_datetime()
logging.info('='*80)
logging.info('=' * 80)
self.metric_path = self.exp_out_root / 'metrics.json'
if self.metric_path.exists():
@ -220,32 +223,31 @@ class Worker(object):
def format_err_str(self, errs, div=1):
err = sum(errs)
if len(errs) > 1:
err_str = f'{err/div:0.4f}=' + '+'.join([f'{e/div:0.4f}' for e in errs])
err_str = f'{err / div:0.4f}=' + '+'.join([f'{e / div:0.4f}' for e in errs])
else:
err_str = f'{err/div:0.4f}'
err_str = f'{err / div:0.4f}'
return err_str
def write_err_img(self):
err_img_path = self.exp_out_root / 'errs.png'
fig = plt.figure(figsize=(16,16))
lines=[]
for idx,errs in enumerate(self.errs_list):
line,=plt.plot(range(len(errs)), errs, label=f'error{idx}')
fig = plt.figure(figsize=(16, 16))
lines = []
for idx, errs in enumerate(self.errs_list):
line, = plt.plot(range(len(errs)), errs, label=f'error{idx}')
lines.append(line)
plt.tight_layout()
plt.legend(handles=lines)
plt.savefig(str(err_img_path))
plt.close(fig)
def callback_train_new_epoch(self, epoch, net, optimizer):
pass
def train(self, net, optimizer, resume=False, scheduler=None):
logging.info('='*80)
logging.info('=' * 80)
logging.info('Start training')
self.log_datetime()
logging.info('='*80)
logging.info('=' * 80)
train_set = self.get_train_set()
test_sets = self.get_test_sets()
@ -257,9 +259,9 @@ class Worker(object):
state_path = self.exp_out_root / 'state.dict'
if resume and state_path.exists():
logging.info('='*80)
logging.info('=' * 80)
logging.info(f'Loading state from {state_path}')
logging.info('='*80)
logging.info('=' * 80)
state = torch.load(str(state_path))
epoch = state['epoch'] + 1
if 'min_err' in state:
@ -269,7 +271,6 @@ class Worker(object):
curr_state.update(state['state_dict'])
net.load_state_dict(curr_state)
try:
optimizer.load_state_dict(state['optimizer'])
except:
@ -321,10 +322,10 @@ class Worker(object):
if scheduler is not None:
scheduler.step()
logging.info('='*80)
logging.info('=' * 80)
logging.info('Finished training')
self.log_datetime()
logging.info('='*80)
logging.info('=' * 80)
def get_train_set(self):
# returns train_set
@ -363,11 +364,12 @@ class Worker(object):
self.callback_train_start(epoch)
stopwatch = StopWatch()
logging.info('='*80)
logging.info('=' * 80)
logging.info('Train epoch %d' % epoch)
dset.current_epoch = epoch
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False)
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True,
num_workers=self.num_workers, drop_last=True, pin_memory=False)
net = net.to(self.train_device)
net.train()
@ -418,9 +420,9 @@ class Worker(object):
bar.update(batch_idx)
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
err_str = self.format_err_str(errs)
logging.info(f'train e{epoch}: {batch_idx+1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
#self.write_err_img()
logging.info(
f'train e{epoch}: {batch_idx + 1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
# self.write_err_img()
if mean_loss is None:
mean_loss = [0 for e in errs]
@ -455,17 +457,18 @@ class Worker(object):
errs = {}
for test_set_idx, test_set in enumerate(test_sets):
if (epoch + 1) % test_set.test_frequency == 0:
logging.info('='*80)
logging.info('=' * 80)
logging.info(f'testing set {test_set.name}')
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
errs[test_set.name] = err
return errs
def test_epoch(self, epoch, set_idx, net, dset):
logging.info('-'*80)
logging.info('-' * 80)
logging.info('Test epoch %d' % epoch)
dset.current_epoch = epoch
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False)
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False,
num_workers=self.num_workers, drop_last=False, pin_memory=False)
net = net.to(self.test_device)
net.eval()
@ -502,7 +505,8 @@ class Worker(object):
bar.update(batch_idx)
if batch_idx % 25 == 0:
err_str = self.format_err_str(errs)
logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
logging.info(
f'test e{epoch}: {batch_idx + 1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
if mean_loss is None:
mean_loss = [0 for e in errs]

@ -5,25 +5,24 @@ from model import exp_synphge
from model import networks
from co.args import parse_args
# parse args
args = parse_args()
# loss types
if args.loss=='ph':
if args.loss == 'ph':
worker = exp_synph.Worker(args)
elif args.loss=='phge':
elif args.loss == 'phge':
worker = exp_synphge.Worker(args)
# concatenation of original image and lcn image
channels_in=2
channels_in = 2
# set up network
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes, output_ms=worker.ms)
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes,
output_ms=worker.ms)
# optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
# start the work
worker.do(net, optimizer)

Loading…
Cancel
Save