Reformat $EVERYTHING

master
CptCaptain 3 years ago
parent 56f2aa7d5d
commit 43df77fb9b
  1. 1
      co/__init__.py
  2. 5
      co/args.py
  3. 39
      co/cmap.py
  4. 584
      co/geometry.py
  5. 16
      co/gtimer.py
  6. 53
      co/io3d.py
  7. 30
      co/metric.py
  8. 11
      co/plt.py
  9. 23
      co/plt2d.py
  10. 34
      co/plt3d.py
  11. 66
      co/table.py
  12. 10
      co/utils.py
  13. 66
      data/commons.py
  14. 106
      data/create_syn_data.py
  15. 50
      data/dataset.py
  16. 395
      data/lcn/lcn.html
  17. 2
      data/lcn/setup.py
  18. 24
      data/lcn/test_lcn.py
  19. 19
      hyperdepth/hyperparam_search.py
  20. 7
      hyperdepth/setup.py
  21. 13
      hyperdepth/vis_eval.py
  22. 153
      model/exp_synph.py
  23. 163
      model/exp_synphge.py
  24. 161
      model/networks.py
  25. 64
      readme.md
  26. 2
      renderer/setup.py
  27. 4
      torchext/dataset.py
  28. 19
      torchext/functions.py
  29. 6
      torchext/modules.py
  30. 3
      torchext/setup.py
  31. 64
      torchext/worker.py
  32. 11
      train_val.py

@ -7,6 +7,7 @@
# set matplotlib backend depending on env # set matplotlib backend depending on env
import os import os
import matplotlib import matplotlib
if os.name == 'posix' and "DISPLAY" not in os.environ: if os.name == 'posix' and "DISPLAY" not in os.environ:
matplotlib.use('Agg') matplotlib.use('Agg')

@ -12,7 +12,7 @@ def parse_args():
parser.add_argument('--loss', parser.add_argument('--loss',
help='Train with \'ph\' for the first stage without geometric loss, \ help='Train with \'ph\' for the first stage without geometric loss, \
train with \'phge\' for the second stage with geometric loss', train with \'phge\' for the second stage with geometric loss',
default='ph', choices=['ph','phge'], type=str) default='ph', choices=['ph', 'phge'], type=str)
parser.add_argument('--data_type', parser.add_argument('--data_type',
default='syn', choices=['syn'], type=str) default='syn', choices=['syn'], type=str)
# #
@ -66,6 +66,3 @@ def parse_args():
def get_exp_name(args): def get_exp_name(args):
name = f"exp_{args.data_type}" name = f"exp_{args.data_type}"
return name return name

@ -1,19 +1,20 @@
import numpy as np import numpy as np
_color_map_errors = np.array([ _color_map_errors = np.array([
[149, 54, 49], #0: log2(x) = -infinity [149, 54, 49], # 0: log2(x) = -infinity
[180, 117, 69], #0.0625: log2(x) = -4 [180, 117, 69], # 0.0625: log2(x) = -4
[209, 173, 116], #0.125: log2(x) = -3 [209, 173, 116], # 0.125: log2(x) = -3
[233, 217, 171], #0.25: log2(x) = -2 [233, 217, 171], # 0.25: log2(x) = -2
[248, 243, 224], #0.5: log2(x) = -1 [248, 243, 224], # 0.5: log2(x) = -1
[144, 224, 254], #1.0: log2(x) = 0 [144, 224, 254], # 1.0: log2(x) = 0
[97, 174, 253], #2.0: log2(x) = 1 [97, 174, 253], # 2.0: log2(x) = 1
[67, 109, 244], #4.0: log2(x) = 2 [67, 109, 244], # 4.0: log2(x) = 2
[39, 48, 215], #8.0: log2(x) = 3 [39, 48, 215], # 8.0: log2(x) = 3
[38, 0, 165], #16.0: log2(x) = 4 [38, 0, 165], # 16.0: log2(x) = 4
[38, 0, 165] #inf: log2(x) = inf [38, 0, 165] # inf: log2(x) = inf
]).astype(float) ]).astype(float)
def color_error_image(errors, scale=1, mask=None, BGR=True): def color_error_image(errors, scale=1, mask=None, BGR=True):
""" """
Color an input error map. Color an input error map.
@ -32,16 +33,18 @@ def color_error_image(errors, scale=1, mask=None, BGR=True):
errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9) errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9)
i0 = np.floor(errors_color_indices).astype(int) i0 = np.floor(errors_color_indices).astype(int)
f1 = errors_color_indices - i0.astype(float) f1 = errors_color_indices - i0.astype(float)
colored_errors_flat = _color_map_errors[i0, :] * (1-f1).reshape(-1,1) + _color_map_errors[i0+1, :] * f1.reshape(-1,1) colored_errors_flat = _color_map_errors[i0, :] * (1 - f1).reshape(-1, 1) + _color_map_errors[i0 + 1,
:] * f1.reshape(-1, 1)
if mask is not None: if mask is not None:
colored_errors_flat[mask.flatten() == 0] = 255 colored_errors_flat[mask.flatten() == 0] = 255
if not BGR: if not BGR:
colored_errors_flat = colored_errors_flat[:,[2,1,0]] colored_errors_flat = colored_errors_flat[:, [2, 1, 0]]
return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int) return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int)
_color_map_depths = np.array([ _color_map_depths = np.array([
[0, 0, 0], # 0.000 [0, 0, 0], # 0.000
[0, 0, 255], # 0.114 [0, 0, 255], # 0.114
@ -65,6 +68,7 @@ _color_map_bincenters = np.array([
2.000, # doesn't make a difference, just strictly higher than 1 2.000, # doesn't make a difference, just strictly higher than 1
]) ])
def color_depth_map(depths, scale=None): def color_depth_map(depths, scale=None):
""" """
Color an input depth map. Color an input depth map.
@ -82,12 +86,13 @@ def color_depth_map(depths, scale=None):
values = np.clip(depths.flatten() / scale, 0, 1) values = np.clip(depths.flatten() / scale, 0, 1)
# for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value? # for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value?
lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1,-1)) * np.arange(0,9)).max(axis=1) lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1, -1)) * np.arange(0, 9)).max(axis=1)
lower_bin_value = _color_map_bincenters[lower_bin] lower_bin_value = _color_map_bincenters[lower_bin]
higher_bin_value = _color_map_bincenters[lower_bin + 1] higher_bin_value = _color_map_bincenters[lower_bin + 1]
alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value) alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
colors = _color_map_depths[lower_bin] * (1-alphas).reshape(-1,1) + _color_map_depths[lower_bin + 1] * alphas.reshape(-1,1) colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[
lower_bin + 1] * alphas.reshape(-1, 1)
return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8) return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)
#from utils.debug import save_color_numpy # from utils.debug import save_color_numpy
#save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000)) # save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000))

File diff suppressed because it is too large Load Diff

@ -2,25 +2,31 @@ import numpy as np
from . import utils from . import utils
class StopWatch(utils.StopWatch): class StopWatch(utils.StopWatch):
def __del__(self): def __del__(self):
print('='*80) print('=' * 80)
print('gtimer:') print('gtimer:')
total = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.sum).items()]) total = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.sum).items()])
print(f' [total] {total}') print(f' [total] {total}')
mean = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.mean).items()]) mean = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.mean).items()])
print(f' [mean] {mean}') print(f' [mean] {mean}')
median = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.median).items()]) median = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.median).items()])
print(f' [median] {median}') print(f' [median] {median}')
print('='*80) print('=' * 80)
GTIMER = StopWatch() GTIMER = StopWatch()
def start(name): def start(name):
GTIMER.start(name) GTIMER.start(name)
def stop(name): def stop(name):
GTIMER.stop(name) GTIMER.stop(name)
class Ctx(object): class Ctx(object):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name

@ -2,12 +2,13 @@ import struct
import numpy as np import numpy as np
import collections import collections
def _write_ply_point(fp, x,y,z, color=None, normal=None, binary=False):
args = [x,y,z] def _write_ply_point(fp, x, y, z, color=None, normal=None, binary=False):
args = [x, y, z]
if color is not None: if color is not None:
args += [int(color[0]), int(color[1]), int(color[2])] args += [int(color[0]), int(color[1]), int(color[2])]
if normal is not None: if normal is not None:
args += [normal[0],normal[1],normal[2]] args += [normal[0], normal[1], normal[2]]
if binary: if binary:
fmt = '<fff' fmt = '<fff'
if color is not None: if color is not None:
@ -24,11 +25,13 @@ def _write_ply_point(fp, x,y,z, color=None, normal=None, binary=False):
fmt += '\n' fmt += '\n'
fp.write(fmt % tuple(args)) fp.write(fmt % tuple(args))
def _write_ply_triangle(fp, i0,i1,i2, binary):
def _write_ply_triangle(fp, i0, i1, i2, binary):
if binary: if binary:
fp.write(struct.pack('<Biii', 3,i0,i1,i2)) fp.write(struct.pack('<Biii', 3, i0, i1, i2))
else: else:
fp.write('3 %d %d %d\n' % (i0,i1,i2)) fp.write('3 %d %d %d\n' % (i0, i1, i2))
def _write_ply_header_line(fp, str, binary): def _write_ply_header_line(fp, str, binary):
if binary: if binary:
@ -36,6 +39,7 @@ def _write_ply_header_line(fp, str, binary):
else: else:
fp.write(str) fp.write(str)
def write_ply(path, verts, trias=None, color=None, normals=None, binary=False): def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
if verts.shape[1] != 3: if verts.shape[1] != 3:
raise Exception('verts has to be of shape Nx3') raise Exception('verts has to be of shape Nx3')
@ -82,11 +86,12 @@ def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
n = None n = None
else: else:
n = normals[vidx] n = normals[vidx]
_write_ply_point(fp, v[0],v[1],v[2], c, n, binary) _write_ply_point(fp, v[0], v[1], v[2], c, n, binary)
if trias is not None: if trias is not None:
for t in trias: for t in trias:
_write_ply_triangle(fp, t[0],t[1],t[2], binary) _write_ply_triangle(fp, t[0], t[1], t[2], binary)
def faces_to_triangles(faces): def faces_to_triangles(faces):
new_faces = [] new_faces = []
@ -100,6 +105,7 @@ def faces_to_triangles(faces):
raise Exception('unknown face count %d', f[0]) raise Exception('unknown face count %d', f[0])
return new_faces return new_faces
def read_ply(path): def read_ply(path):
with open(path, 'rb') as f: with open(path, 'rb') as f:
# parse header # parse header
@ -152,7 +158,7 @@ def read_ply(path):
sz = n_verts * vert_bin_len sz = n_verts * vert_bin_len
fmt = ','.join(vert_bin_format) fmt = ','.join(vert_bin_format)
verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz)) verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz))
verts = verts[0].astype(vert_bin_cols*'f4,').view(dtype='f4').reshape((n_verts,-1)) verts = verts[0].astype(vert_bin_cols * 'f4,').view(dtype='f4').reshape((n_verts, -1))
faces = [] faces = []
for idx in range(n_faces): for idx in range(n_faces):
fmt = '<Biii' fmt = '<Biii'
@ -172,21 +178,21 @@ def read_ply(path):
for idx in range(n_faces): for idx in range(n_faces):
splits = f.readline().decode().strip().split(' ') splits = f.readline().decode().strip().split(' ')
n_face_verts = int(splits[0]) n_face_verts = int(splits[0])
vals = [int(v) for v in splits[0:n_face_verts+1]] vals = [int(v) for v in splits[0:n_face_verts + 1]]
faces.append(vals) faces.append(vals)
faces = faces_to_triangles(faces) faces = faces_to_triangles(faces)
faces = np.array(faces, dtype=np.int32) faces = np.array(faces, dtype=np.int32)
xyz = None xyz = None
if 'x' in vert_types and 'y' in vert_types and 'z' in vert_types: if 'x' in vert_types and 'y' in vert_types and 'z' in vert_types:
xyz = verts[:,[vert_types['x'], vert_types['y'], vert_types['z']]] xyz = verts[:, [vert_types['x'], vert_types['y'], vert_types['z']]]
colors = None colors = None
if 'red' in vert_types and 'green' in vert_types and 'blue' in vert_types: if 'red' in vert_types and 'green' in vert_types and 'blue' in vert_types:
colors = verts[:,[vert_types['red'], vert_types['green'], vert_types['blue']]] colors = verts[:, [vert_types['red'], vert_types['green'], vert_types['blue']]]
colors /= 255 colors /= 255
normals = None normals = None
if 'nx' in vert_types and 'ny' in vert_types and 'nz' in vert_types: if 'nx' in vert_types and 'ny' in vert_types and 'nz' in vert_types:
normals = verts[:,[vert_types['nx'], vert_types['ny'], vert_types['nz']]] normals = verts[:, [vert_types['nx'], vert_types['ny'], vert_types['nz']]]
return xyz, faces, colors, normals return xyz, faces, colors, normals
@ -204,6 +210,7 @@ def _read_obj_split_f(s):
nidx = -1 nidx = -1
return vidx, tidx, nidx return vidx, tidx, nidx
def read_obj(path): def read_obj(path):
with open(path, 'r') as fp: with open(path, 'r') as fp:
lines = fp.readlines() lines = fp.readlines()
@ -221,19 +228,19 @@ def read_obj(path):
parts = line.split() parts = line.split()
if line.startswith('v '): if line.startswith('v '):
parts = parts[1:] parts = parts[1:]
x,y,z = float(parts[0]), float(parts[1]), float(parts[2]) x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
if len(parts) == 4 or len(parts) == 7: if len(parts) == 4 or len(parts) == 7:
w = float(parts[3]) w = float(parts[3])
x,y,z = x/w, y/w, z/w x, y, z = x / w, y / w, z / w
verts.append((x,y,z)) verts.append((x, y, z))
if len(parts) >= 6: if len(parts) >= 6:
r,g,b = float(parts[-3]), float(parts[-2]), float(parts[-1]) r, g, b = float(parts[-3]), float(parts[-2]), float(parts[-1])
rgb.append((r,g,b)) rgb.append((r, g, b))
elif line.startswith('vn '): elif line.startswith('vn '):
parts = parts[1:] parts = parts[1:]
x,y,z = float(parts[0]), float(parts[1]), float(parts[2]) x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
fnorms.append((x,y,z)) fnorms.append((x, y, z))
elif line.startswith('f '): elif line.startswith('f '):
parts = parts[1:] parts = parts[1:]
@ -245,11 +252,11 @@ def read_obj(path):
faces.append((vidx0, vidx1, vidx2)) faces.append((vidx0, vidx1, vidx2))
if nidx0 >= 0: if nidx0 >= 0:
fnorm_map[vidx0].append( nidx0 ) fnorm_map[vidx0].append(nidx0)
if nidx1 >= 0: if nidx1 >= 0:
fnorm_map[vidx1].append( nidx1 ) fnorm_map[vidx1].append(nidx1)
if nidx2 >= 0: if nidx2 >= 0:
fnorm_map[vidx2].append( nidx2 ) fnorm_map[vidx2].append(nidx2)
verts = np.array(verts) verts = np.array(verts)
colors = np.array(colors) colors = np.array(colors)

@ -1,6 +1,7 @@
import numpy as np import numpy as np
from . import geometry from . import geometry
def _process_inputs(estimate, target, mask): def _process_inputs(estimate, target, mask):
if estimate.shape != target.shape: if estimate.shape != target.shape:
raise Exception('estimate and target have to be same shape') raise Exception('estimate and target have to be same shape')
@ -12,19 +13,23 @@ def _process_inputs(estimate, target, mask):
raise Exception('estimate and mask have to be same shape') raise Exception('estimate and mask have to be same shape')
return estimate, target, mask return estimate, target, mask
def mse(estimate, target, mask=None): def mse(estimate, target, mask=None):
estimate, target, mask = _process_inputs(estimate, target, mask) estimate, target, mask = _process_inputs(estimate, target, mask)
m = np.sum((estimate[mask] - target[mask])**2) / mask.sum() m = np.sum((estimate[mask] - target[mask]) ** 2) / mask.sum()
return m return m
def rmse(estimate, target, mask=None): def rmse(estimate, target, mask=None):
return np.sqrt(mse(estimate, target, mask)) return np.sqrt(mse(estimate, target, mask))
def mae(estimate, target, mask=None): def mae(estimate, target, mask=None):
estimate, target, mask = _process_inputs(estimate, target, mask) estimate, target, mask = _process_inputs(estimate, target, mask)
m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum() m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
return m return m
def outlier_fraction(estimate, target, mask=None, threshold=0): def outlier_fraction(estimate, target, mask=None, threshold=0):
estimate, target, mask = _process_inputs(estimate, target, mask) estimate, target, mask = _process_inputs(estimate, target, mask)
diff = np.abs(estimate[mask] - target[mask]) diff = np.abs(estimate[mask] - target[mask])
@ -52,6 +57,7 @@ class Metric(object):
def __str__(self): def __str__(self):
return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()]) return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])
class MultipleMetric(Metric): class MultipleMetric(Metric):
def __init__(self, *metrics, **kwargs): def __init__(self, *metrics, **kwargs):
self.metrics = [*metrics] self.metrics = [*metrics]
@ -76,6 +82,7 @@ class MultipleMetric(Metric):
def __str__(self): def __str__(self):
return '\n'.join([str(m) for m in self.metrics]) return '\n'.join([str(m) for m in self.metrics])
class BaseDistanceMetric(Metric): class BaseDistanceMetric(Metric):
def __init__(self, name='', **kwargs): def __init__(self, name='', **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -99,6 +106,7 @@ class BaseDistanceMetric(Metric):
f'dist{self.name}_max': float(np.max(dists)), f'dist{self.name}_max': float(np.max(dists)),
} }
class DistanceMetric(BaseDistanceMetric): class DistanceMetric(BaseDistanceMetric):
def __init__(self, vec_length, p=2, **kwargs): def __init__(self, vec_length, p=2, **kwargs):
super().__init__(name=f'{p}', **kwargs) super().__init__(name=f'{p}', **kwargs)
@ -113,7 +121,8 @@ class DistanceMetric(BaseDistanceMetric):
es = es[ma != 0] es = es[ma != 0]
ta = ta[ma != 0] ta = ta[ma != 0]
dist = np.linalg.norm(es - ta, ord=self.p, axis=1) dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
self.dists.append( dist ) self.dists.append(dist)
class OutlierFractionMetric(DistanceMetric): class OutlierFractionMetric(DistanceMetric):
def __init__(self, thresholds, *args, **kwargs): def __init__(self, thresholds, *args, **kwargs):
@ -128,6 +137,7 @@ class OutlierFractionMetric(DistanceMetric):
ret[f'of{t}'] = float(ma.sum() / ma.size) ret[f'of{t}'] = float(ma.sum() / ma.size)
return ret return ret
class RelativeDistanceMetric(BaseDistanceMetric): class RelativeDistanceMetric(BaseDistanceMetric):
def __init__(self, vec_length, p=2, **kwargs): def __init__(self, vec_length, p=2, **kwargs):
super().__init__(name=f'rel{p}', **kwargs) super().__init__(name=f'rel{p}', **kwargs)
@ -142,7 +152,8 @@ class RelativeDistanceMetric(BaseDistanceMetric):
dist /= denom dist /= denom
if ma is not None: if ma is not None:
dist = dist[ma != 0] dist = dist[ma != 0]
self.dists.append( dist ) self.dists.append(dist)
class RotmDistanceMetric(BaseDistanceMetric): class RotmDistanceMetric(BaseDistanceMetric):
def __init__(self, type='identity', **kwargs): def __init__(self, type='identity', **kwargs):
@ -156,12 +167,13 @@ class RotmDistanceMetric(BaseDistanceMetric):
if ma is not None: if ma is not None:
raise Exception('mask is not implemented') raise Exception('mask is not implemented')
if self.type == 'identity': if self.type == 'identity':
self.dists.append( geometry.rotm_distance_identity(es, ta) ) self.dists.append(geometry.rotm_distance_identity(es, ta))
elif self.type == 'geodesic': elif self.type == 'geodesic':
self.dists.append( geometry.rotm_distance_geodesic_unit_sphere(es, ta) ) self.dists.append(geometry.rotm_distance_geodesic_unit_sphere(es, ta))
else: else:
raise Exception('invalid distance type') raise Exception('invalid distance type')
class QuaternionDistanceMetric(BaseDistanceMetric): class QuaternionDistanceMetric(BaseDistanceMetric):
def __init__(self, type='angle', **kwargs): def __init__(self, type='angle', **kwargs):
super().__init__(name=type, **kwargs) super().__init__(name=type, **kwargs)
@ -174,11 +186,11 @@ class QuaternionDistanceMetric(BaseDistanceMetric):
if ma is not None: if ma is not None:
raise Exception('mask is not implemented') raise Exception('mask is not implemented')
if self.type == 'angle': if self.type == 'angle':
self.dists.append( geometry.quat_distance_angle(es, ta) ) self.dists.append(geometry.quat_distance_angle(es, ta))
elif self.type == 'mineucl': elif self.type == 'mineucl':
self.dists.append( geometry.quat_distance_mineucl(es, ta) ) self.dists.append(geometry.quat_distance_mineucl(es, ta))
elif self.type == 'normdiff': elif self.type == 'normdiff':
self.dists.append( geometry.quat_distance_normdiff(es, ta) ) self.dists.append(geometry.quat_distance_normdiff(es, ta))
else: else:
raise Exception('invalid distance type') raise Exception('invalid distance type')
@ -241,7 +253,7 @@ class BinaryAccuracyMetric(Metric):
accuracies = np.divide(tps + tns, tps + tns + fps + fns) accuracies = np.divide(tps + tns, tps + tns + fps + fns)
aacc = np.mean(accuracies) aacc = np.mean(accuracies)
for t in np.linspace(0,1,num=11)[1:-1]: for t in np.linspace(0, 1, num=11)[1:-1]:
idx = np.argmin(np.abs(t - wp)) idx = np.argmin(np.abs(t - wp))
ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx]) ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx])

@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
import os import os
import time import time
def save(path, remove_axis=False, dpi=300, fig=None): def save(path, remove_axis=False, dpi=300, fig=None):
if fig is None: if fig is None:
fig = plt.gcf() fig = plt.gcf()
@ -15,13 +16,14 @@ def save(path, remove_axis=False, dpi=300, fig=None):
if remove_axis: if remove_axis:
for ax in fig.axes: for ax in fig.axes:
ax.axis('off') ax.axis('off')
ax.margins(0,0) ax.margins(0, 0)
fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
for ax in fig.axes: for ax in fig.axes:
ax.xaxis.set_major_locator(plt.NullLocator()) ax.xaxis.set_major_locator(plt.NullLocator())
ax.yaxis.set_major_locator(plt.NullLocator()) ax.yaxis.set_major_locator(plt.NullLocator())
fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0) fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
def color_map(im_, cmap='viridis', vmin=None, vmax=None): def color_map(im_, cmap='viridis', vmin=None, vmax=None):
cm = plt.get_cmap(cmap) cm = plt.get_cmap(cmap)
im = im_.copy() im = im_.copy()
@ -33,11 +35,12 @@ def color_map(im_, cmap='viridis', vmin=None, vmax=None):
im[mask] = vmin im[mask] = vmin
im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin) im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin)
im = cm(im) im = cm(im)
im = im[...,:3] im = im[..., :3]
for c in range(3): for c in range(3):
im[mask, c] = 1 im[mask, c] = 1
return im return im
def interactive_legend(leg=None, fig=None, all_axes=True): def interactive_legend(leg=None, fig=None, all_axes=True):
if leg is None: if leg is None:
leg = plt.legend() leg = plt.legend()
@ -60,7 +63,7 @@ def interactive_legend(leg=None, fig=None, all_axes=True):
def onpick(event): def onpick(event):
if event.mouseevent.dblclick: if event.mouseevent.dblclick:
tmp = [(k,v) for k,v in lined.items()] tmp = [(k, v) for k, v in lined.items()]
else: else:
tmp = [(event.artist, lined[event.artist])] tmp = [(event.artist, lined[event.artist])]
@ -76,6 +79,7 @@ def interactive_legend(leg=None, fig=None, all_axes=True):
fig.canvas.mpl_connect('pick_event', onpick) fig.canvas.mpl_connect('pick_event', onpick)
def non_annoying_pause(interval, focus_figure=False): def non_annoying_pause(interval, focus_figure=False):
# https://github.com/matplotlib/matplotlib/issues/11131 # https://github.com/matplotlib/matplotlib/issues/11131
backend = mpl.rcParams['backend'] backend = mpl.rcParams['backend']
@ -91,6 +95,7 @@ def non_annoying_pause(interval, focus_figure=False):
return return
time.sleep(interval) time.sleep(interval)
def remove_all_ticks(fig=None): def remove_all_ticks(fig=None):
if fig is None: if fig is None:
fig = plt.gcf() fig = plt.gcf()

@ -3,21 +3,23 @@ import matplotlib.pyplot as plt
from . import geometry from . import geometry
def image_matrix(ims, bgval=0): def image_matrix(ims, bgval=0):
n = ims.shape[0] n = ims.shape[0]
m = int( np.ceil(np.sqrt(n)) ) m = int(np.ceil(np.sqrt(n)))
h = ims.shape[1] h = ims.shape[1]
w = ims.shape[2] w = ims.shape[2]
mat = np.empty((m*h, m*w), dtype=ims.dtype) mat = np.empty((m * h, m * w), dtype=ims.dtype)
mat.fill(bgval) mat.fill(bgval)
idx = 0 idx = 0
for r in range(m): for r in range(m):
for c in range(m): for c in range(m):
if idx < n: if idx < n:
mat[r*h:(r+1)*h, c*w:(c+1)*w] = ims[idx] mat[r * h:(r + 1) * h, c * w:(c + 1) * w] = ims[idx]
idx += 1 idx += 1
return mat return mat
def image_cat(ims, vertical=False): def image_cat(ims, vertical=False):
offx = [0] offx = [0]
offy = [0] offy = [0]
@ -34,20 +36,23 @@ def image_cat(ims, vertical=False):
offx = np.cumsum(offx) offx = np.cumsum(offx)
offy = np.cumsum(offy) offy = np.cumsum(offy)
im = np.zeros((height,width,*ims[0].shape[2:]), dtype=ims[0].dtype) im = np.zeros((height, width, *ims[0].shape[2:]), dtype=ims[0].dtype)
for im0, ox, oy in zip(ims, offx, offy): for im0, ox, oy in zip(ims, offx, offy):
im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0 im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0
return im, offx, offy return im, offx, offy
def line(li, h, w, ax=None, *args, **kwargs): def line(li, h, w, ax=None, *args, **kwargs):
if ax is None: if ax is None:
ax = plt.gca() ax = plt.gca()
xs = (-li[2] - li[1] * np.array((0, h-1))) / li[0] xs = (-li[2] - li[1] * np.array((0, h - 1))) / li[0]
ys = (-li[2] - li[0] * np.array((0, w-1))) / li[1] ys = (-li[2] - li[0] * np.array((0, w - 1))) / li[1]
pts = np.array([(0,ys[0]), (w-1, ys[1]), (xs[0], 0), (xs[1], h-1)]) pts = np.array([(0, ys[0]), (w - 1, ys[1]), (xs[0], 0), (xs[1], h - 1)])
pts = pts[np.logical_and(np.logical_and(pts[:,0] >= 0, pts[:,0] < w), np.logical_and(pts[:,1] >= 0, pts[:,1] < h))] pts = pts[
ax.plot(pts[:,0], pts[:,1], *args, **kwargs) np.logical_and(np.logical_and(pts[:, 0] >= 0, pts[:, 0] < w), np.logical_and(pts[:, 1] >= 0, pts[:, 1] < h))]
ax.plot(pts[:, 0], pts[:, 1], *args, **kwargs)
def depthshow(depth, *args, ax=None, **kwargs): def depthshow(depth, *args, ax=None, **kwargs):
if ax is None: if ax is None:

@ -4,35 +4,45 @@ from mpl_toolkits.mplot3d import Axes3D
from . import geometry from . import geometry
def ax3d(fig=None): def ax3d(fig=None):
if fig is None: if fig is None:
fig = plt.gcf() fig = plt.gcf()
return fig.add_subplot(111, projection='3d') return fig.add_subplot(111, projection='3d')
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1, label=None, **kwargs):
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1,
label=None, **kwargs):
if ax is None: if ax is None:
ax = plt.gca() ax = plt.gca()
C0 = geometry.translation_to_cameracenter(R, t).ravel() C0 = geometry.translation_to_cameracenter(R, t).ravel()
C1 = C0 + R.T.dot( np.array([[-size],[-size],[3*size]], dtype=np.float32) ).ravel() C1 = C0 + R.T.dot(np.array([[-size], [-size], [3 * size]], dtype=np.float32)).ravel()
C2 = C0 + R.T.dot( np.array([[-size],[+size],[3*size]], dtype=np.float32) ).ravel() C2 = C0 + R.T.dot(np.array([[-size], [+size], [3 * size]], dtype=np.float32)).ravel()
C3 = C0 + R.T.dot( np.array([[+size],[+size],[3*size]], dtype=np.float32) ).ravel() C3 = C0 + R.T.dot(np.array([[+size], [+size], [3 * size]], dtype=np.float32)).ravel()
C4 = C0 + R.T.dot( np.array([[+size],[-size],[3*size]], dtype=np.float32) ).ravel() C4 = C0 + R.T.dot(np.array([[+size], [-size], [3 * size]], dtype=np.float32)).ravel()
if marker_C != '': if marker_C != '':
ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs) ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs)
ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs) ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs) linewidth=linewidth, **kwargs)
ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs) ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle,
ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs) linewidth=linewidth, **kwargs)
ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]], [C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs) ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]],
[C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
def axis_equal(ax=None): def axis_equal(ax=None):
if ax is None: if ax is None:
ax = plt.gca() ax = plt.gca()
extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz']) extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
sz = extents[:,1] - extents[:,0] sz = extents[:, 1] - extents[:, 0]
centers = np.mean(extents, axis=1) centers = np.mean(extents, axis=1)
maxsize = max(abs(sz)) maxsize = max(abs(sz))
r = maxsize/2 r = maxsize / 2
for ctr, dim in zip(centers, 'xyz'): for ctr, dim in zip(centers, 'xyz'):
getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r) getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r)

@ -3,6 +3,7 @@ import pandas as pd
import enum import enum
import itertools import itertools
class Table(object): class Table(object):
def __init__(self, n_cols): def __init__(self, n_cols):
self.n_cols = n_cols self.n_cols = n_cols
@ -42,7 +43,8 @@ class Table(object):
else: else:
raise Exception('number of cols does not fit in table') raise Exception('number of cols does not fit in table')
for c in range(len(cols)): for c in range(len(cols)):
self.rows[row+r].cells[col+c] = Cell(data[r][c], fmt) self.rows[row + r].cells[col + c] = Cell(data[r][c], fmt)
class Row(object): class Row(object):
def __init__(self, cells, pre_separator=None, post_separator=None): def __init__(self, cells, pre_separator=None, post_separator=None):
@ -61,9 +63,8 @@ class Row(object):
return sum([c.span for c in self.cells]) return sum([c.span for c in self.cells])
class Color(object): class Color(object):
def __init__(self, color=(0,0,0), fmt='rgb'): def __init__(self, color=(0, 0, 0), fmt='rgb'):
if fmt == 'rgb': if fmt == 'rgb':
self.color = color self.color = color
elif fmt == 'RGB': elif fmt == 'RGB':
@ -79,11 +80,11 @@ class Color(object):
@classmethod @classmethod
def rgb(cls, r, g, b): def rgb(cls, r, g, b):
return Color(color=(r,g,b), fmt='rgb') return Color(color=(r, g, b), fmt='rgb')
@classmethod @classmethod
def RGB(cls, r, g, b): def RGB(cls, r, g, b):
return Color(color=(r,g,b), fmt='RGB') return Color(color=(r, g, b), fmt='RGB')
class CellFormat(object): class CellFormat(object):
@ -93,6 +94,7 @@ class CellFormat(object):
self.bgcolor = bgcolor self.bgcolor = bgcolor
self.bold = bold self.bold = bold
class Cell(object): class Cell(object):
def __init__(self, data=None, fmt=None, span=1, align=None): def __init__(self, data=None, fmt=None, span=1, align=None):
self.data = data self.data = data
@ -105,6 +107,7 @@ class Cell(object):
def __str__(self): def __str__(self):
return self.fmt.fmt % self.data return self.fmt.fmt % self.data
class Separator(enum.Enum): class Separator(enum.Enum):
HEAD = 1 HEAD = 1
BOTTOM = 2 BOTTOM = 2
@ -143,6 +146,7 @@ class Renderer(object):
with open(path, 'w') as fp: with open(path, 'w') as fp:
fp.write(txt) fp.write(txt)
class TerminalRenderer(Renderer): class TerminalRenderer(Renderer):
def __init__(self, col_sep=' '): def __init__(self, col_sep=' '):
super().__init__() super().__init__()
@ -152,7 +156,7 @@ class TerminalRenderer(Renderer):
cell = table.rows[row].cells[col] cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data str = cell.fmt.fmt % cell.data
str_width = len(str) str_width = len(str)
cell_width = sum([widths[idx] for idx in range(col, col+cell.span)]) cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
cell_width += len(self.col_sep) * (cell.span - 1) cell_width += len(self.col_sep) * (cell.span - 1)
if len(str) > cell_width: if len(str) > cell_width:
str = str[:cell_width] str = str[:cell_width]
@ -167,13 +171,13 @@ class TerminalRenderer(Renderer):
if str_width < cell_width: if str_width < cell_width:
n_ws = (cell_width - str_width) n_ws = (cell_width - str_width)
if table.get_cell_align(row, col) == 'r': if table.get_cell_align(row, col) == 'r':
str = ' '*n_ws + str str = ' ' * n_ws + str
elif table.get_cell_align(row, col) == 'l': elif table.get_cell_align(row, col) == 'l':
str = str + ' '*n_ws str = str + ' ' * n_ws
elif table.get_cell_align(row, col) == 'c': elif table.get_cell_align(row, col) == 'c':
n_ws1 = n_ws // 2 n_ws1 = n_ws // 2
n_ws0 = n_ws - n_ws1 n_ws0 = n_ws - n_ws1
str = ' '*n_ws0 + str + ' '*n_ws1 str = ' ' * n_ws0 + str + ' ' * n_ws1
if cell.fmt.bgcolor is not None: if cell.fmt.bgcolor is not None:
# color = cell.fmt.bgcolor.as_RGB() # color = cell.fmt.bgcolor.as_RGB()
# str = sty.bg(*color) + str + sty.rs.bg # str = sty.bg(*color) + str + sty.rs.bg
@ -182,11 +186,11 @@ class TerminalRenderer(Renderer):
def render_separator(self, separator, tab, col_widths, total_width): def render_separator(self, separator, tab, col_widths, total_width):
if separator == Separator.HEAD: if separator == Separator.HEAD:
return '='*total_width return '=' * total_width
elif separator == Separator.INNER: elif separator == Separator.INNER:
return '-'*total_width return '-' * total_width
elif separator == Separator.BOTTOM: elif separator == Separator.BOTTOM:
return '='*total_width return '=' * total_width
def render(self, table): def render(self, table):
widths = self.col_widths(table) widths = self.col_widths(table)
@ -207,6 +211,7 @@ class TerminalRenderer(Renderer):
lines.append(sepline) lines.append(sepline)
return '\n'.join(lines) return '\n'.join(lines)
class MarkdownRenderer(TerminalRenderer): class MarkdownRenderer(TerminalRenderer):
def __init__(self): def __init__(self):
super().__init__(col_sep='|') super().__init__(col_sep='|')
@ -231,20 +236,20 @@ class MarkdownRenderer(TerminalRenderer):
str = f'**{str}**' str = f'**{str}**'
str_width = len(str) str_width = len(str)
cell_width = sum([widths[idx] for idx in range(col, col+cell.span)]) cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
cell_width += len(self.col_sep) * (cell.span - 1) cell_width += len(self.col_sep) * (cell.span - 1)
if len(str) > cell_width: if len(str) > cell_width:
str = str[:cell_width] str = str[:cell_width]
else: else:
n_ws = (cell_width - str_width) n_ws = (cell_width - str_width)
if table.get_cell_align(row, col) == 'r': if table.get_cell_align(row, col) == 'r':
str = ' '*n_ws + str str = ' ' * n_ws + str
elif table.get_cell_align(row, col) == 'l': elif table.get_cell_align(row, col) == 'l':
str = str + ' '*n_ws str = str + ' ' * n_ws
elif table.get_cell_align(row, col) == 'c': elif table.get_cell_align(row, col) == 'c':
n_ws1 = n_ws // 2 n_ws1 = n_ws // 2
n_ws0 = n_ws - n_ws1 n_ws0 = n_ws - n_ws1
str = ' '*n_ws0 + str + ' '*n_ws1 str = ' ' * n_ws0 + str + ' ' * n_ws1
if col == 0: str = self.col_sep + str if col == 0: str = self.col_sep + str
if col == table.n_cols - 1: str += self.col_sep if col == table.n_cols - 1: str += self.col_sep
@ -279,7 +284,7 @@ class LatexRenderer(Renderer):
cell = table.rows[row].cells[col] cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data str = cell.fmt.fmt % cell.data
if cell.fmt.bold: if cell.fmt.bold:
str = '{\\bf '+ str + '}' str = '{\\bf ' + str + '}'
if cell.fmt.fgcolor is not None: if cell.fmt.fgcolor is not None:
color = cell.fmt.fgcolor.as_rgb() color = cell.fmt.fgcolor.as_rgb()
str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}' str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}'
@ -313,6 +318,7 @@ class LatexRenderer(Renderer):
lines.append('\\end{tabular}') lines.append('\\end{tabular}')
return '\n'.join(lines) return '\n'.join(lines)
class HtmlRenderer(Renderer): class HtmlRenderer(Renderer):
def __init__(self, html_class='result_table'): def __init__(self, html_class='result_table'):
super().__init__() super().__init__()
@ -331,10 +337,14 @@ class HtmlRenderer(Renderer):
color = cell.fmt.bgcolor.as_RGB() color = cell.fmt.bgcolor.as_RGB()
styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});') styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});')
align = table.get_cell_align(row, col) align = table.get_cell_align(row, col)
if align == 'l': align = 'left' if align == 'l':
elif align == 'r': align = 'right' align = 'left'
elif align == 'c': align = 'center' elif align == 'r':
else: raise Exception('invalid align') align = 'right'
elif align == 'c':
align = 'center'
else:
raise Exception('invalid align')
styles.append(f'text-align: {align};') styles.append(f'text-align: {align};')
row = table.rows[row] row = table.rows[row]
if row.pre_separator is not None: if row.pre_separator is not None:
@ -365,10 +375,11 @@ class HtmlRenderer(Renderer):
return '\n'.join(lines) return '\n'.join(lines)
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'), best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]): def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'),
best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
rnames = data[rowname].unique() rnames = data[rowname].unique()
cnames = data[colname].unique() cnames = data[colname].unique()
tab = Table(1+len(cnames)) tab = Table(1 + len(cnames))
header = [Cell('', align='r')] header = [Cell('', align='r')]
header.extend([Cell(h, align='r') for h in cnames]) header.extend([Cell(h, align='r') for h in cnames])
@ -395,7 +406,6 @@ def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt
return tab return tab
if __name__ == '__main__': if __name__ == '__main__':
# df = pd.read_pickle('full.df') # df = pd.read_pickle('full.df')
# best_is_max = ['movF0.5', 'movF1.0'] # best_is_max = ['movF0.5', 'movF1.0']
@ -411,11 +421,11 @@ if __name__ == '__main__':
# tab.add_row(header2) # tab.add_row(header2)
tab.add_row(Row([Cell(f'c{c}') for c in range(7)])) tab.add_row(Row([Cell(f'c{c}') for c in range(7)]))
tab.rows[-1].post_separator = Separator.INNER tab.rows[-1].post_separator = Separator.INNER
tab.add_block(np.arange(15*7).reshape(15,7)) tab.add_block(np.arange(15 * 7).reshape(15, 7))
tab.rows[4].cells[2].fmt = CellFormat(bold=True) tab.rows[4].cells[2].fmt = CellFormat(bold=True)
tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2,0.6,0.1)) tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2, 0.6, 0.1))
tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7,0.1,0.5)) tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7, 0.1, 0.5))
tab.rows[5].cells[3].fmt = CellFormat(bold=True,bgcolor=Color.rgb(0.7,0.1,0.5),fgcolor=Color.rgb(0.1,0.1,0.1)) tab.rows[5].cells[3].fmt = CellFormat(bold=True, bgcolor=Color.rgb(0.7, 0.1, 0.5), fgcolor=Color.rgb(0.1, 0.1, 0.1))
tab.rows[-1].post_separator = Separator.BOTTOM tab.rows[-1].post_separator = Separator.BOTTOM
renderer = TerminalRenderer() renderer = TerminalRenderer()

@ -8,6 +8,7 @@ import re
import pickle import pickle
import subprocess import subprocess
def str2bool(v): def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'): if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True return True
@ -16,6 +17,7 @@ def str2bool(v):
else: else:
raise argparse.ArgumentTypeError('Boolean value expected.') raise argparse.ArgumentTypeError('Boolean value expected.')
class StopWatch(object): class StopWatch(object):
def __init__(self): def __init__(self):
self.timings = OrderedDict() self.timings = OrderedDict()
@ -39,9 +41,11 @@ class StopWatch(object):
return ret return ret
def __repr__(self): def __repr__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
def __str__(self): def __str__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
class ETA(object): class ETA(object):
def __init__(self, length): def __init__(self, length):
@ -76,6 +80,7 @@ class ETA(object):
def get_remaining_time_str(self): def get_remaining_time_str(self):
return self.format_time(self.get_remaining_time()) return self.format_time(self.get_remaining_time())
def git_hash(cwd=None): def git_hash(cwd=None):
ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
hash = ret.stdout hash = ret.stdout
@ -83,4 +88,3 @@ def git_hash(cwd=None):
return hash.decode().strip() return hash.decode().strip()
else: else:
return None return None

@ -7,7 +7,7 @@ def get_patterns(path='syn', imsizes=[], crop=True):
pattern_size = imsizes[0] pattern_size = imsizes[0]
if path == 'syn': if path == 'syn':
np.random.seed(42) np.random.seed(42)
pattern = np.random.uniform(0,1, size=pattern_size) pattern = np.random.uniform(0, 1, size=pattern_size)
pattern = (pattern < 0.1).astype(np.float32) pattern = (pattern < 0.1).astype(np.float32)
pattern.reshape(*imsizes[0]) pattern.reshape(*imsizes[0])
else: else:
@ -21,30 +21,30 @@ def get_patterns(path='syn', imsizes=[], crop=True):
if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]: if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]:
r0 = (pattern.shape[0] - pattern_size[0]) // 2 r0 = (pattern.shape[0] - pattern_size[0]) // 2
c0 = (pattern.shape[1] - pattern_size[1]) // 2 c0 = (pattern.shape[1] - pattern_size[1]) // 2
pattern = pattern[r0:r0+imsizes[0][0], c0:c0+imsizes[0][1]] pattern = pattern[r0:r0 + imsizes[0][0], c0:c0 + imsizes[0][1]]
patterns = [] patterns = []
for imsize in imsizes: for imsize in imsizes:
pat = cv2.resize(pattern, (imsize[1],imsize[0]), interpolation=cv2.INTER_LINEAR) pat = cv2.resize(pattern, (imsize[1], imsize[0]), interpolation=cv2.INTER_LINEAR)
patterns.append(pat) patterns.append(pat)
return patterns return patterns
def get_rotation_matrix(v0, v1): def get_rotation_matrix(v0, v1):
v0 = v0/np.linalg.norm(v0) v0 = v0 / np.linalg.norm(v0)
v1 = v1/np.linalg.norm(v1) v1 = v1 / np.linalg.norm(v1)
v = np.cross(v0,v1) v = np.cross(v0, v1)
c = np.dot(v0,v1) c = np.dot(v0, v1)
s = np.linalg.norm(v) s = np.linalg.norm(v)
I = np.eye(3) I = np.eye(3)
vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0) vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0)
k = np.matrix(vXStr) k = np.matrix(vXStr)
r = I + k + k @ k * ((1 -c)/(s**2)) r = I + k + k @ k * ((1 - c) / (s ** 2))
return np.asarray(r.astype(np.float32)) return np.asarray(r.astype(np.float32))
def augment_image(img,rng,disp=None,grad=None,max_shift=64,max_blur=1.5,max_noise=10.0,max_sp_noise=0.001): def augment_image(img, rng, disp=None, grad=None, max_shift=64, max_blur=1.5, max_noise=10.0, max_sp_noise=0.001):
# get min/max values of image # get min/max values of image
min_val = np.min(img) min_val = np.min(img)
max_val = np.max(img) max_val = np.max(img)
@ -57,54 +57,56 @@ def augment_image(img,rng,disp=None,grad=None,max_shift=64,max_blur=1.5,max_nois
grad_aug = grad grad_aug = grad
# apply affine transformation # apply affine transformation
if max_shift>1: if max_shift > 1:
# affine parameters # affine parameters
rows,cols = img.shape rows, cols = img.shape
shear = 0 shear = 0
shift = 0 shift = 0
shear_correction = 0 shear_correction = 0
if rng.uniform(0,1)<0.75: shear = rng.uniform(-max_shift,max_shift) # shear with 75% probability if rng.uniform(0, 1) < 0.75:
else: shift = rng.uniform(0,max_shift) # shift with 25% probability shear = rng.uniform(-max_shift, max_shift) # shear with 75% probability
if shear<0: shear_correction = -shear else:
shift = rng.uniform(0, max_shift) # shift with 25% probability
if shear < 0: shear_correction = -shear
# affine transformation # affine transformation
a = shear/float(rows) a = shear / float(rows)
b = shift+shear_correction b = shift + shear_correction
# warp image # warp image
T = np.float32([[1,a,b],[0,1,0]]) T = np.float32([[1, a, b], [0, 1, 0]])
img_aug = cv2.warpAffine(img_aug,T,(cols,rows)) img_aug = cv2.warpAffine(img_aug, T, (cols, rows))
if grad is not None: if grad is not None:
grad_aug = cv2.warpAffine(grad,T,(cols,rows)) grad_aug = cv2.warpAffine(grad, T, (cols, rows))
# disparity correction map # disparity correction map
col = a*np.array(range(rows))+b col = a * np.array(range(rows)) + b
disp_delta = np.tile(col,(cols,1)).transpose() disp_delta = np.tile(col, (cols, 1)).transpose()
if disp is not None: if disp is not None:
disp_aug = cv2.warpAffine(disp+disp_delta,T,(cols,rows)) disp_aug = cv2.warpAffine(disp + disp_delta, T, (cols, rows))
# gaussian smoothing # gaussian smoothing
if rng.uniform(0,1)<0.5: if rng.uniform(0, 1) < 0.5:
img_aug = cv2.GaussianBlur(img_aug,(5,5),rng.uniform(0.2,max_blur)) img_aug = cv2.GaussianBlur(img_aug, (5, 5), rng.uniform(0.2, max_blur))
# per-pixel gaussian noise # per-pixel gaussian noise
img_aug = img_aug + rng.randn(*img_aug.shape)*rng.uniform(0.0,max_noise)/255.0 img_aug = img_aug + rng.randn(*img_aug.shape) * rng.uniform(0.0, max_noise) / 255.0
# salt-and-pepper noise # salt-and-pepper noise
if rng.uniform(0,1)<0.5: if rng.uniform(0, 1) < 0.5:
ratio=rng.uniform(0.0,max_sp_noise) ratio = rng.uniform(0.0, max_sp_noise)
img_shape = img_aug.shape img_shape = img_aug.shape
img_aug = img_aug.flatten() img_aug = img_aug.flatten()
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio)) coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
img_aug[coord] = max_val img_aug[coord] = max_val
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio)) coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
img_aug[coord] = min_val img_aug[coord] = min_val
img_aug = np.reshape(img_aug, img_shape) img_aug = np.reshape(img_aug, img_shape)
# clip intensities back to [0,1] # clip intensities back to [0,1]
img_aug = np.maximum(img_aug,0.0) img_aug = np.maximum(img_aug, 0.0)
img_aug = np.minimum(img_aug,1.0) img_aug = np.minimum(img_aug, 1.0)
# return image # return image
return img_aug, disp_aug, grad_aug return img_aug, disp_aug, grad_aug

@ -10,14 +10,15 @@ import cv2
import os import os
import collections import collections
import sys import sys
sys.path.append('../') sys.path.append('../')
import renderer import renderer
import co import co
from commons import get_patterns,get_rotation_matrix from commons import get_patterns, get_rotation_matrix
from lcn import lcn from lcn import lcn
def get_objs(shapenet_dir, obj_classes, num_perclass=100):
def get_objs(shapenet_dir, obj_classes, num_perclass=100):
shapenet = {'chair': '03001627', shapenet = {'chair': '03001627',
'airplane': '02691156', 'airplane': '02691156',
'car': '02958343', 'car': '02958343',
@ -40,7 +41,7 @@ def get_objs(shapenet_dir, obj_classes, num_perclass=100):
v /= (0.5 * diffs.max()) v /= (0.5 * diffs.max())
v -= (v.min(axis=0) + 1) v -= (v.min(axis=0) + 1)
f = f.astype(np.int32) f = f.astype(np.int32)
objs.append((v,f,n)) objs.append((v, f, n))
print(f'loaded {len(objs)} objects') print(f'loaded {len(objs)} objects')
return objs return objs
@ -50,11 +51,11 @@ def get_mesh(rng, min_z=0):
# set up background board # set up background board
verts, faces, normals, colors = [], [], [], [] verts, faces, normals, colors = [], [], [], []
v, f, n = co.geometry.xyplane(z=0, interleaved=True) v, f, n = co.geometry.xyplane(z=0, interleaved=True)
v[:,2] += -v[:,2].min() + rng.uniform(2,7) v[:, 2] += -v[:, 2].min() + rng.uniform(2, 7)
v[:,:2] *= 5e2 v[:, :2] *= 5e2
v[:,2] = np.mean(v[:,2]) + (v[:,2] - np.mean(v[:,2])) * 5e2 v[:, 2] = np.mean(v[:, 2]) + (v[:, 2] - np.mean(v[:, 2])) * 5e2
c = np.empty_like(v) c = np.empty_like(v)
c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32) c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
verts.append(v) verts.append(v)
faces.append(f) faces.append(f)
normals.append(n) normals.append(n)
@ -62,7 +63,7 @@ def get_mesh(rng, min_z=0):
# randomly sample 4 foreground objects for each scene # randomly sample 4 foreground objects for each scene
for shape_idx in range(4): for shape_idx in range(4):
v, f, n = objs[rng.randint(0,len(objs))] v, f, n = objs[rng.randint(0, len(objs))]
v, f, n = v.copy(), f.copy(), n.copy() v, f, n = v.copy(), f.copy(), n.copy()
s = rng.uniform(0.25, 1) s = rng.uniform(0.25, 1)
@ -70,11 +71,11 @@ def get_mesh(rng, min_z=0):
R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng)) R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng))
v = v @ R.T v = v @ R.T
n = n @ R.T n = n @ R.T
v[:,2] += -v[:,2].min() + min_z + rng.uniform(0.5, 3) v[:, 2] += -v[:, 2].min() + min_z + rng.uniform(0.5, 3)
v[:,:2] += rng.uniform(-1, 1, size=(1,2)) v[:, :2] += rng.uniform(-1, 1, size=(1, 2))
c = np.empty_like(v) c = np.empty_like(v)
c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32) c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
verts.append(v.astype(np.float32)) verts.append(v.astype(np.float32))
faces.append(f) faces.append(f)
@ -88,7 +89,6 @@ def get_mesh(rng, min_z=0):
def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4): def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4):
tic = time.time() tic = time.time()
rng = np.random.RandomState() rng = np.random.RandomState()
@ -96,35 +96,34 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
verts, faces, colors, normals = get_mesh(rng) verts, faces, colors, normals = get_mesh(rng)
data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy()) data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
print(f'loading mesh for sample {idx+1}/{n_samples} took {time.time()-tic}[s]') print(f'loading mesh for sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
# let the camera point to the center # let the camera point to the center
center = np.array([0,0,3], dtype=np.float32) center = np.array([0, 0, 3], dtype=np.float32)
basevec = np.array([-baseline,0,0], dtype=np.float32) basevec = np.array([-baseline, 0, 0], dtype=np.float32)
unit = np.array([0,0,1],dtype=np.float32) unit = np.array([0, 0, 1], dtype=np.float32)
cam_x_ = rng.uniform(-0.2,0.2) cam_x_ = rng.uniform(-0.2, 0.2)
cam_y_ = rng.uniform(-0.2,0.2) cam_y_ = rng.uniform(-0.2, 0.2)
cam_z_ = rng.uniform(-0.2,0.2) cam_z_ = rng.uniform(-0.2, 0.2)
ret = collections.defaultdict(list) ret = collections.defaultdict(list)
blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1,0.1), 0,1) blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1, 0.1), 0, 1)
# capture the same static scene from different view points as a track # capture the same static scene from different view points as a track
for ind in range(track_length): for ind in range(track_length):
cam_x = cam_x_ + rng.uniform(-0.1,0.1) cam_x = cam_x_ + rng.uniform(-0.1, 0.1)
cam_y = cam_y_ + rng.uniform(-0.1,0.1) cam_y = cam_y_ + rng.uniform(-0.1, 0.1)
cam_z = cam_z_ + rng.uniform(-0.1,0.1) cam_z = cam_z_ + rng.uniform(-0.1, 0.1)
tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32) tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32)
if np.linalg.norm(tcam[0:2])<1e-9: if np.linalg.norm(tcam[0:2]) < 1e-9:
Rcam = np.eye(3, dtype=np.float32) Rcam = np.eye(3, dtype=np.float32)
else: else:
Rcam = get_rotation_matrix(center, center-tcam) Rcam = get_rotation_matrix(center, center - tcam)
tproj = tcam + basevec tproj = tcam + basevec
Rproj = Rcam Rproj = Rcam
@ -139,20 +138,19 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
scales = [1, 0.5, 0.25, 0.125] scales = [1, 0.5, 0.25, 0.125]
for scale in scales: for scale in scales:
fx = K[0,0] * scale fx = K[0, 0] * scale
fy = K[1,1] * scale fy = K[1, 1] * scale
px = K[0,2] * scale px = K[0, 2] * scale
py = K[1,2] * scale py = K[1, 2] * scale
im_height = imsize[0] * scale im_height = imsize[0] * scale
im_width = imsize[1] * scale im_width = imsize[1] * scale
cams.append( renderer.PyCamera(fx,fy,px,py, Rcam, tcam, im_width, im_height) ) cams.append(renderer.PyCamera(fx, fy, px, py, Rcam, tcam, im_width, im_height))
projs.append( renderer.PyCamera(fx,fy,px,py, Rproj, tproj, im_width, im_height) ) projs.append(renderer.PyCamera(fx, fy, px, py, Rproj, tproj, im_width, im_height))
for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns): for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns):
fl = K[0,0] / (2**s) fl = K[0, 0] / (2 ** s)
shader = renderer.PyShader(0.5,1.5,0.0,10) shader = renderer.PyShader(0.5, 1.5, 0.0, 10)
pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu') pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu')
pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35) pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35)
@ -169,21 +167,21 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
# get the noise free IR image $J$ # get the noise free IR image $J$
im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient
ret[f'ambient{s}'].append( ambient[None].astype(np.float32) ) ret[f'ambient{s}'].append(ambient[None].astype(np.float32))
# get the gradient magnitude of the ambient image $|\nabla A|$ # get the gradient magnitude of the ambient image $|\nabla A|$
ambient = ambient.astype(np.float32) ambient = ambient.astype(np.float32)
sobelx = cv2.Sobel(ambient,cv2.CV_32F,1,0,ksize=5) sobelx = cv2.Sobel(ambient, cv2.CV_32F, 1, 0, ksize=5)
sobely = cv2.Sobel(ambient,cv2.CV_32F,0,1,ksize=5) sobely = cv2.Sobel(ambient, cv2.CV_32F, 0, 1, ksize=5)
grad = np.sqrt(sobelx**2 + sobely**2) grad = np.sqrt(sobelx ** 2 + sobely ** 2)
grad = np.maximum(grad-0.8,0.0) # parameter grad = np.maximum(grad - 0.8, 0.0) # parameter
# get the local contract normalized grad LCN($|\nabla A|$) # get the local contract normalized grad LCN($|\nabla A|$)
grad_lcn, grad_std = lcn.normalize(grad,5,0.1) grad_lcn, grad_std = lcn.normalize(grad, 5, 0.1)
grad_lcn = np.clip(grad_lcn,0.0,1.0) # parameter grad_lcn = np.clip(grad_lcn, 0.0, 1.0) # parameter
ret[f'grad{s}'].append( grad_lcn[None].astype(np.float32)) ret[f'grad{s}'].append(grad_lcn[None].astype(np.float32))
ret[f'im{s}'].append( im[None].astype(np.float32)) ret[f'im{s}'].append(im[None].astype(np.float32))
ret[f'mask{s}'].append(mask[None].astype(np.float32)) ret[f'mask{s}'].append(mask[None].astype(np.float32))
ret[f'disp{s}'].append(disp[None].astype(np.float32)) ret[f'disp{s}'].append(disp[None].astype(np.float32))
@ -193,18 +191,17 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
# save to files # save to files
out_dir = out_root / f'{idx:08d}' out_dir = out_root / f'{idx:08d}'
out_dir.mkdir(exist_ok=True, parents=True) out_dir.mkdir(exist_ok=True, parents=True)
for k,val in ret.items(): for k, val in ret.items():
for tidx in range(track_length): for tidx in range(track_length):
v = val[tidx] v = val[tidx]
out_path = out_dir / f'{k}_{tidx}.npy' out_path = out_dir / f'{k}_{tidx}.npy'
np.save(out_path, v) np.save(out_path, v)
np.save( str(out_dir /'blend_im.npy'), blend_im_rnd) np.save(str(out_dir / 'blend_im.npy'), blend_im_rnd)
print(f'create sample {idx+1}/{n_samples} took {time.time()-tic}[s]')
print(f'create sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
if __name__=='__main__': if __name__ == '__main__':
np.random.seed(42) np.random.seed(42)
@ -234,11 +231,12 @@ if __name__=='__main__':
# camera parameters # camera parameters
imsize = (488, 648) imsize = (488, 648)
imsizes = [(imsize[0]//(2**s), imsize[1]//(2**s)) for s in range(4)] imsizes = [(imsize[0] // (2 ** s), imsize[1] // (2 ** s)) for s in range(4)]
# K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32) # K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32)
K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0 ,0, 1]], dtype=np.float32) K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0, 0, 1]],
focal_lengths = [K[0,0]/(2**s) for s in range(4)] dtype=np.float32)
baseline=0.075 focal_lengths = [K[0, 0] / (2 ** s) for s in range(4)]
baseline = 0.075
blend_im = 0.6 blend_im = 0.6
noise = 0 noise = 0
@ -264,7 +262,7 @@ if __name__=='__main__':
pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL) pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL)
# start the job # start the job
n_samples = 2**10 + 2**13 n_samples = 2 ** 10 + 2 ** 13
for idx in range(start, n_samples): for idx in range(start, n_samples):
args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length) args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length)
create_data(*args) create_data(*args)

@ -21,11 +21,13 @@ from .commons import get_patterns, augment_image
from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d import Axes3D
class TrackSynDataset(torchext.BaseDataset): class TrackSynDataset(torchext.BaseDataset):
''' '''
Load locally saved synthetic dataset Load locally saved synthetic dataset
Please run ./create_syn_data.sh to generate the dataset Please run ./create_syn_data.sh to generate the dataset
''' '''
def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False): def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False):
super().__init__(train=train) super().__init__(train=train)
@ -33,8 +35,8 @@ class TrackSynDataset(torchext.BaseDataset):
self.sample_paths = sample_paths self.sample_paths = sample_paths
self.data_aug = data_aug self.data_aug = data_aug
self.train = train self.train = train
self.track_length=track_length self.track_length = track_length
assert(track_length<=4) assert (track_length <= 4)
with open(str(settings_path), 'rb') as f: with open(str(settings_path), 'rb') as f:
settings = pickle.load(f) settings = pickle.load(f)
@ -46,10 +48,10 @@ class TrackSynDataset(torchext.BaseDataset):
self.scale = len(self.imsizes) self.scale = len(self.imsizes)
self.max_shift=0 self.max_shift = 0
self.max_blur=0.5 self.max_blur = 0.5
self.max_noise=3.0 self.max_noise = 3.0
self.max_sp_noise=0.0005 self.max_sp_noise = 0.0005
def __len__(self): def __len__(self):
return len(self.sample_paths) return len(self.sample_paths)
@ -75,9 +77,9 @@ class TrackSynDataset(torchext.BaseDataset):
ambs = [] ambs = []
grads = [] grads = []
for tidx in track_ind: for tidx in track_ind:
imgs.append(np.load(os.path.join(sample_path,f'im{sidx}_{tidx}.npy'))) imgs.append(np.load(os.path.join(sample_path, f'im{sidx}_{tidx}.npy')))
ambs.append(np.load(os.path.join(sample_path,f'ambient{sidx}_{tidx}.npy'))) ambs.append(np.load(os.path.join(sample_path, f'ambient{sidx}_{tidx}.npy')))
grads.append(np.load(os.path.join(sample_path,f'grad{sidx}_{tidx}.npy'))) grads.append(np.load(os.path.join(sample_path, f'grad{sidx}_{tidx}.npy')))
ret[f'im{sidx}'] = np.stack(imgs, axis=0) ret[f'im{sidx}'] = np.stack(imgs, axis=0)
ret[f'ambient{sidx}'] = np.stack(ambs, axis=0) ret[f'ambient{sidx}'] = np.stack(ambs, axis=0)
ret[f'grad{sidx}'] = np.stack(grads, axis=0) ret[f'grad{sidx}'] = np.stack(grads, axis=0)
@ -87,20 +89,20 @@ class TrackSynDataset(torchext.BaseDataset):
R = [] R = []
t = [] t = []
for tidx in track_ind: for tidx in track_ind:
disps.append(np.load(os.path.join(sample_path,f'disp0_{tidx}.npy'))) disps.append(np.load(os.path.join(sample_path, f'disp0_{tidx}.npy')))
R.append(np.load(os.path.join(sample_path,f'R_{tidx}.npy'))) R.append(np.load(os.path.join(sample_path, f'R_{tidx}.npy')))
t.append(np.load(os.path.join(sample_path,f't_{tidx}.npy'))) t.append(np.load(os.path.join(sample_path, f't_{tidx}.npy')))
ret[f'disp0'] = np.stack(disps, axis=0) ret[f'disp0'] = np.stack(disps, axis=0)
ret['R'] = np.stack(R, axis=0) ret['R'] = np.stack(R, axis=0)
ret['t'] = np.stack(t, axis=0) ret['t'] = np.stack(t, axis=0)
blend_im = np.load(os.path.join(sample_path,'blend_im.npy')) blend_im = np.load(os.path.join(sample_path, 'blend_im.npy'))
ret['blend_im'] = blend_im.astype(np.float32) ret['blend_im'] = blend_im.astype(np.float32)
#### apply data augmentation at different scales seperately, only work for max_shift=0 #### apply data augmentation at different scales seperately, only work for max_shift=0
if self.data_aug: if self.data_aug:
for sidx in range(len(self.imsizes)): for sidx in range(len(self.imsizes)):
if sidx==0: if sidx == 0:
img = ret[f'im{sidx}'] img = ret[f'im{sidx}']
disp = ret[f'disp{sidx}'] disp = ret[f'disp{sidx}']
grad = ret[f'grad{sidx}'] grad = ret[f'grad{sidx}']
@ -108,10 +110,11 @@ class TrackSynDataset(torchext.BaseDataset):
disp_aug = np.zeros_like(img) disp_aug = np.zeros_like(img)
grad_aug = np.zeros_like(img) grad_aug = np.zeros_like(img)
for i in range(img.shape[0]): for i in range(img.shape[0]):
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i,0],rng, img_aug_, disp_aug_, grad_aug_ = augment_image(img[i, 0], rng,
disp=disp[i,0],grad=grad[i,0], disp=disp[i, 0], grad=grad[i, 0],
max_shift=self.max_shift, max_blur=self.max_blur, max_shift=self.max_shift, max_blur=self.max_blur,
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise) max_noise=self.max_noise,
max_sp_noise=self.max_sp_noise)
img_aug[i] = img_aug_[None].astype(np.float32) img_aug[i] = img_aug_[None].astype(np.float32)
disp_aug[i] = disp_aug_[None].astype(np.float32) disp_aug[i] = disp_aug_[None].astype(np.float32)
grad_aug[i] = grad_aug_[None].astype(np.float32) grad_aug[i] = grad_aug_[None].astype(np.float32)
@ -122,27 +125,24 @@ class TrackSynDataset(torchext.BaseDataset):
img = ret[f'im{sidx}'] img = ret[f'im{sidx}']
img_aug = np.zeros_like(img) img_aug = np.zeros_like(img)
for i in range(img.shape[0]): for i in range(img.shape[0]):
img_aug_, _, _ = augment_image(img[i,0],rng, img_aug_, _, _ = augment_image(img[i, 0], rng,
max_shift=self.max_shift, max_blur=self.max_blur, max_shift=self.max_shift, max_blur=self.max_blur,
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise) max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
img_aug[i] = img_aug_[None].astype(np.float32) img_aug[i] = img_aug_[None].astype(np.float32)
ret[f'im{sidx}'] = img_aug ret[f'im{sidx}'] = img_aug
if len(track_ind)==1: if len(track_ind) == 1:
for key, val in ret.items(): for key, val in ret.items():
if key!='blend_im' and key!='id': if key != 'blend_im' and key != 'id':
ret[key] = val[0] ret[key] = val[0]
return ret return ret
def getK(self, sidx=0): def getK(self, sidx=0):
K = self.K.copy() / (2**sidx) K = self.K.copy() / (2 ** sidx)
K[2,2] = 1 K[2, 2] = 1
return K return K
if __name__ == '__main__': if __name__ == '__main__':
pass pass

@ -2,7 +2,7 @@
<!-- Generated by Cython 0.29 --> <!-- Generated by Cython 0.29 -->
<html> <html>
<head> <head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" /> <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
<title>Cython: lcn.pyx</title> <title>Cython: lcn.pyx</title>
<style type="text/css"> <style type="text/css">
@ -355,17 +355,23 @@ body.cython { font-family: courier; font-size: 12; }
.cython .vi { color: #19177C } /* Name.Variable.Instance */ .cython .vi { color: #19177C } /* Name.Variable.Instance */
.cython .vm { color: #19177C } /* Name.Variable.Magic */ .cython .vm { color: #19177C } /* Name.Variable.Magic */
.cython .il { color: #666666 } /* Literal.Number.Integer.Long */ .cython .il { color: #666666 } /* Literal.Number.Integer.Long */
</style> </style>
</head> </head>
<body class="cython"> <body class="cython">
<p><span style="border-bottom: solid 1px grey;">Generated by Cython 0.29</span></p> <p><span style="border-bottom: solid 1px grey;">Generated by Cython 0.29</span></p>
<p> <p>
<span style="background-color: #FFFF00">Yellow lines</span> hint at Python interaction.<br /> <span style="background-color: #FFFF00">Yellow lines</span> hint at Python interaction.<br/>
Click on a line that starts with a "<code>+</code>" to see the C code that Cython generated for it. Click on a line that starts with a "<code>+</code>" to see the C code that Cython generated for it.
</p> </p>
<p>Raw output: <a href="lcn.c">lcn.c</a></p> <p>Raw output: <a href="lcn.c">lcn.c</a></p>
<div class="cython"><pre class="cython line score-16" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre> <div class="cython">
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span> <pre class="cython line score-16"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span
class="k">as</span> <span class="nn">np</span></pre>
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_np, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_np, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
@ -374,22 +380,39 @@ body.cython { font-family: courier; font-size: 12; }
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_test, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_test, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
</pre><pre class="cython line score-0">&#xA0;<span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre> </pre>
<pre class="cython line score-0">&#xA0;<span class="">03</span>: </pre> <pre class="cython line score-0">&#xA0;<span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">04</span>: <span class="c"># use c square root function</span></pre> <pre class="cython line score-0">&#xA0;<span class="">03</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">05</span>: <span class="k">cdef</span> <span class="kr">extern</span> <span class="k">from</span> <span class="s">&quot;math.h&quot;</span><span class="p">:</span></pre> <pre class="cython line score-0">&#xA0;<span class="">04</span>: <span class="c"># use c square root function</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre> <pre class="cython line score-0">&#xA0;<span class="">05</span>: <span class="k">cdef</span> <span
<pre class="cython line score-0">&#xA0;<span class="">07</span>: </pre> class="kr">extern</span> <span class="k">from</span> <span class="s">&quot;math.h&quot;</span><span
<pre class="cython line score-0">&#xA0;<span class="">08</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre> class="p">:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">09</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre> <pre class="cython line score-0">&#xA0;<span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span
<pre class="cython line score-0">&#xA0;<span class="">10</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span class="p">)</span></pre> class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">11</span>: </pre> <pre class="cython line score-0">&#xA0;<span class="">07</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">12</span>: <span class="c"># 3 parameters:</span></pre> <pre class="cython line score-0">&#xA0;<span class="">08</span>: <span class="nd">@cython</span><span
<pre class="cython line score-0">&#xA0;<span class="">13</span>: <span class="c"># - float image</span></pre> class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span
<pre class="cython line score-0">&#xA0;<span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre> class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre> <pre class="cython line score-0">&#xA0;<span class="">09</span>: <span class="nd">@cython</span><span
<pre class="cython line score-67" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">):</span></pre> class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span
<pre class='cython code score-67 '>/* Python wrapper */ class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">10</span>: <span class="nd">@cython</span><span
class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span
class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">11</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">12</span>: <span class="c"># 3 parameters:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">13</span>: <span class="c"># - float image</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre>
<pre class="cython line score-67"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span
class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span
class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span
class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span
class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span
class="p">):</span></pre>
<pre class='cython code score-67 '>/* Python wrapper */
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0}; static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0};
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
@ -434,7 +457,8 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
} }
} }
if (unlikely(kw_args &gt; 0)) { if (unlikely(kw_args &gt; 0)) {
if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") &lt; 0)) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span> if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") &lt; 0)) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} }
} else { } else {
switch (<span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)) { switch (<span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)) {
@ -447,21 +471,27 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
default: goto __pyx_L5_argtuple_error; default: goto __pyx_L5_argtuple_error;
} }
} }
__pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span> __pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
if (values[1]) { if (values[1]) {
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) &amp;&amp; <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span> __pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) &amp;&amp; <span
class='py_c_api'>PyErr_Occurred</span>())) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} else { } else {
__pyx_v_kernel_size = ((int)4); __pyx_v_kernel_size = ((int)4);
} }
if (values[2]) { if (values[2]) {
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) &amp;&amp; <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span> __pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) &amp;&amp; <span
class='py_c_api'>PyErr_Occurred</span>())) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} else { } else {
__pyx_v_epsilon = ((float)0.01); __pyx_v_epsilon = ((float)0.01);
} }
} }
goto __pyx_L4_argument_unpacking_done; goto __pyx_L4_argument_unpacking_done;
__pyx_L5_argtuple_error:; __pyx_L5_argtuple_error:;
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span> <span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
__pyx_L3_error:; __pyx_L3_error:;
<span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename); <span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>(); <span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
@ -515,27 +545,49 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
return __pyx_r; return __pyx_r;
} }
/* … */ /* … */
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span> __pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span
class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19);
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19); <span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19);
/* … */ /* … */
__pyx_t_1 = PyCFunction_NewEx(&amp;__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span> __pyx_t_1 = PyCFunction_NewEx(&amp;__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span> __pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span
</pre><pre class="cython line score-0">&#xA0;<span class="">17</span>: </pre> class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<pre class="cython line score-0">&#xA0;<span class="">18</span>: <span class="c"># image dimensions</span></pre> </pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre> <pre class="cython line score-0">&#xA0;<span class="">17</span>: </pre>
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]); <pre class="cython line score-0">&#xA0;<span class="">18</span>: <span class="c"># image dimensions</span></pre>
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre> <pre class="cython line score-0"
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]); onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
</pre><pre class="cython line score-0">&#xA0;<span class="">21</span>: </pre> class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
<pre class="cython line score-0">&#xA0;<span class="">22</span>: <span class="c"># create outputs and output views</span></pre> class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
<pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre> class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span> <pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">21</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
<pre class="cython line score-46"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span
class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
@ -559,22 +611,34 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
__pyx_v_img_lcn = __pyx_t_5; __pyx_v_img_lcn = __pyx_t_5;
__pyx_t_5 = 0; __pyx_t_5 = 0;
</pre><pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre> </pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span> <pre class="cython line score-46"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span
class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
__pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
@ -598,114 +662,236 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
__pyx_v_img_std = __pyx_t_1; __pyx_v_img_std = __pyx_t_1;
__pyx_t_1 = 0; __pyx_t_1 = 0;
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span class="n">img_lcn</span></pre> </pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span> <pre class="cython line score-2"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span
class="n">img_lcn</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
__pyx_v_img_lcn_view = __pyx_t_6; __pyx_v_img_lcn_view = __pyx_t_6;
__pyx_t_6.memview = NULL; __pyx_t_6.memview = NULL;
__pyx_t_6.data = NULL; __pyx_t_6.data = NULL;
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span class="n">img_std</span></pre> </pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span> <pre class="cython line score-2"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span
class="n">img_std</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
__pyx_v_img_std_view = __pyx_t_6; __pyx_v_img_std_view = __pyx_t_6;
__pyx_t_6.memview = NULL; __pyx_t_6.memview = NULL;
__pyx_t_6.data = NULL; __pyx_t_6.data = NULL;
</pre><pre class="cython line score-0">&#xA0;<span class="">27</span>: </pre> </pre>
<pre class="cython line score-0">&#xA0;<span class="">28</span>: <span class="c"># temporary c variables</span></pre> <pre class="cython line score-0">&#xA0;<span class="">27</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span class="nf">stddev</span></pre> <pre class="cython line score-0">&#xA0;<span class="">28</span>: <span class="c"># temporary c variables</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre> <pre class="cython line score-0">&#xA0;<span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre> class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size; class="nf">stddev</span></pre>
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre> <pre class="cython line score-0">&#xA0;<span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon; class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span class="o">**</span><span class="mf">2</span></pre> class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2); <pre class="cython line score-0"
</pre><pre class="cython line score-0">&#xA0;<span class="">34</span>: </pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class="cython line score-0">&#xA0;<span class="">35</span>: <span class="c"># for all pixels do</span></pre> class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre> class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks); <pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span
class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span
class="o">**</span><span class="mf">2</span></pre>
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">34</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">35</span>: <span
class="c"># for all pixels do</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span
class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span
class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
__pyx_t_8 = __pyx_t_7; __pyx_t_8 = __pyx_t_7;
for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 &lt; __pyx_t_8; __pyx_t_9+=1) { for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 &lt; __pyx_t_8; __pyx_t_9+=1) {
__pyx_v_m = __pyx_t_9; __pyx_v_m = __pyx_t_9;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span
class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
__pyx_t_11 = __pyx_t_10; __pyx_t_11 = __pyx_t_10;
for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 &lt; __pyx_t_11; __pyx_t_12+=1) { for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 &lt; __pyx_t_11; __pyx_t_12+=1) {
__pyx_v_n = __pyx_t_12; __pyx_v_n = __pyx_t_12;
</pre><pre class="cython line score-0">&#xA0;<span class="">38</span>: </pre> </pre>
<pre class="cython line score-0">&#xA0;<span class="">39</span>: <span class="c"># calculate mean</span></pre> <pre class="cython line score-0">&#xA0;<span class="">38</span>: </pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre> <pre class="cython line score-0">&#xA0;<span class="">39</span>: <span class="c"># calculate mean</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = 0.0; <pre class="cython line score-0"
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1); class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span
class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
__pyx_t_14 = __pyx_t_13; __pyx_t_14 = __pyx_t_13;
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) { for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) {
__pyx_v_i = __pyx_t_15; __pyx_v_i = __pyx_t_15;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
__pyx_t_17 = __pyx_t_16; __pyx_t_17 = __pyx_t_16;
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) { for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) {
__pyx_v_j = __pyx_t_18; __pyx_v_j = __pyx_t_18;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span
class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
__pyx_t_20 = (__pyx_v_n + __pyx_v_j); __pyx_t_20 = (__pyx_v_n + __pyx_v_j);
__pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) )))); __pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) ))));
} }
} }
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span class="n">mean</span><span class="o">/</span><span class="n">num</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num); <pre class="cython line score-0"
</pre><pre class="cython line score-0">&#xA0;<span class="">45</span>: </pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class="cython line score-0">&#xA0;<span class="">46</span>: <span class="c"># calculate std dev</span></pre> class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre> class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0; <pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1); <pre class="cython line score-0">&#xA0;<span class="">45</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">46</span>: <span
class="c"># calculate std dev</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
__pyx_t_14 = __pyx_t_13; __pyx_t_14 = __pyx_t_13;
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) { for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) {
__pyx_v_i = __pyx_t_15; __pyx_v_i = __pyx_t_15;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
__pyx_t_17 = __pyx_t_16; __pyx_t_17 = __pyx_t_16;
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) { for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) {
__pyx_v_j = __pyx_t_18; __pyx_v_j = __pyx_t_18;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span
class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span
class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span
class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span
class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
__pyx_t_22 = (__pyx_v_n + __pyx_v_j); __pyx_t_22 = (__pyx_v_n + __pyx_v_j);
__pyx_t_23 = (__pyx_v_m + __pyx_v_i); __pyx_t_23 = (__pyx_v_m + __pyx_v_i);
__pyx_t_24 = (__pyx_v_n + __pyx_v_j); __pyx_t_24 = (__pyx_v_n + __pyx_v_j);
__pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean))); __pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean)));
} }
} }
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span class="n">num</span><span class="p">)</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num)); <pre class="cython line score-0"
</pre><pre class="cython line score-0">&#xA0;<span class="">52</span>: </pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class="cython line score-0">&#xA0;<span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre> class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span class="o">+</span><span class="n">eps</span><span class="p">)</span></pre> class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m; class="n">num</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
</pre>
<pre class="cython line score-0">&#xA0;<span class="">52</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span
class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span
class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
__pyx_t_26 = __pyx_v_n; __pyx_t_26 = __pyx_v_n;
__pyx_t_27 = __pyx_v_m; __pyx_t_27 = __pyx_v_m;
__pyx_t_28 = __pyx_v_n; __pyx_t_28 = __pyx_v_n;
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps)); *((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps));
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">stddev</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m; <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
class="n">stddev</span></pre>
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
__pyx_t_30 = __pyx_v_n; __pyx_t_30 = __pyx_v_n;
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev; *((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev;
} }
} }
</pre><pre class="cython line score-0">&#xA0;<span class="">56</span>: </pre> </pre>
<pre class="cython line score-0">&#xA0;<span class="">57</span>: <span class="c"># return both</span></pre> <pre class="cython line score-0">&#xA0;<span class="">56</span>: </pre>
<pre class="cython line score-10" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span class="n">img_std</span></pre> <pre class="cython line score-0">&#xA0;<span class="">57</span>: <span class="c"># return both</span></pre>
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r); <pre class="cython line score-10"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span
class="n">img_std</span></pre>
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
__pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span> __pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_lcn); <span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_lcn);
@ -717,4 +903,7 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
__pyx_r = __pyx_t_1; __pyx_r = __pyx_t_1;
__pyx_t_1 = 0; __pyx_t_1 = 0;
goto __pyx_L0; goto __pyx_L0;
</pre></div></body></html> </pre>
</div>
</body>
</html>

@ -2,5 +2,5 @@ from distutils.core import setup
from Cython.Build import cythonize from Cython.Build import cythonize
setup( setup(
ext_modules = cythonize("lcn.pyx",annotate=True) ext_modules=cythonize("lcn.pyx", annotate=True)
) )

@ -5,23 +5,23 @@ from scipy import misc
# load and convert to float # load and convert to float
img = misc.imread('img.png') img = misc.imread('img.png')
img = img.astype(np.float32)/255.0 img = img.astype(np.float32) / 255.0
# normalize # normalize
img_lcn, img_std = lcn.normalize(img,5,0.05) img_lcn, img_std = lcn.normalize(img, 5, 0.05)
# normalize to reasonable range between 0 and 1 # normalize to reasonable range between 0 and 1
#img_lcn = img_lcn/3.0 # img_lcn = img_lcn/3.0
#img_lcn = np.maximum(img_lcn,0.0) # img_lcn = np.maximum(img_lcn,0.0)
#img_lcn = np.minimum(img_lcn,1.0) # img_lcn = np.minimum(img_lcn,1.0)
# save to file # save to file
#misc.imsave('lcn2.png',img_lcn) # misc.imsave('lcn2.png',img_lcn)
print ("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \ print("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \
(img.shape[0], img.shape[1], img.dtype, img.min(), img.max())) (img.shape[0], img.shape[1], img.dtype, img.min(), img.max()))
print ("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \ print("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \
(img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max())) (img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max()))
# plot original image # plot original image
plt.figure(1) plt.figure(1)
@ -34,14 +34,14 @@ plt.tight_layout()
plt.figure(2) plt.figure(2)
img_lcn_plot = plt.imshow(img_lcn) img_lcn_plot = plt.imshow(img_lcn)
img_lcn_plot.set_cmap('gray') img_lcn_plot.set_cmap('gray')
#plt.clim(0, 1) # fix range # plt.clim(0, 1) # fix range
plt.tight_layout() plt.tight_layout()
# plot stddev image # plot stddev image
plt.figure(3) plt.figure(3)
img_std_plot = plt.imshow(img_std) img_std_plot = plt.imshow(img_std)
img_std_plot.set_cmap('gray') img_std_plot.set_cmap('gray')
#plt.clim(0, 0.1) # fix range # plt.clim(0, 0.1) # fix range
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()

@ -11,20 +11,19 @@ import dataset
def get_data(n, row_from, row_to, train): def get_data(n, row_from, row_to, train):
imsizes = [(256,384)] imsizes = [(256, 384)]
focal_lengths = [160] focal_lengths = [160]
dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train) dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)
ims = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.uint8) ims = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.uint8)
disps = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.float32) disps = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.float32)
for idx in range(n): for idx in range(n):
print(f'load sample {idx} train={train}') print(f'load sample {idx} train={train}')
sample = dset[idx] sample = dset[idx]
ims[idx] = (sample['im0'][0,row_from:row_to] * 255).astype(np.uint8) ims[idx] = (sample['im0'][0, row_from:row_to] * 255).astype(np.uint8)
disps[idx] = sample['disp0'][0,row_from:row_to] disps[idx] = sample['disp0'][0, row_from:row_to]
return ims, disps return ims, disps
params = hd.TrainParams( params = hd.TrainParams(
n_trees=4, n_trees=4,
max_tree_depth=, max_tree_depth=,
@ -45,16 +44,18 @@ n_test_samples = 32
train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True) train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False) test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)
for tree_depth in [8,10,12,14,16]: for tree_depth in [8, 10, 12, 14, 16]:
depth_switch = tree_depth - 4 depth_switch = tree_depth - 4
prefix = f'td{tree_depth}_ds{depth_switch}' prefix = f'td{tree_depth}_ds{depth_switch}'
prefix = Path(f'./forests/{prefix}/') prefix = Path(f'./forests/{prefix}/')
prefix.mkdir(parents=True, exist_ok=True) prefix.mkdir(parents=True, exist_ok=True)
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr')) hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr')) es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
np.save(str(prefix / 'ta.npy'), test_disps) np.save(str(prefix / 'ta.npy'), test_disps)
np.save(str(prefix / 'es.npy'), es) np.save(str(prefix / 'es.npy'), es)

@ -8,7 +8,6 @@ import os
this_dir = os.path.dirname(__file__) this_dir = os.path.dirname(__file__)
extra_compile_args = ['-O3', '-std=c++11'] extra_compile_args = ['-O3', '-std=c++11']
extra_link_args = [] extra_link_args = []
@ -23,7 +22,7 @@ libraries = ['m']
setup( setup(
name="hyperdepth", name="hyperdepth",
cmdclass= {'build_ext': build_ext}, cmdclass={'build_ext': build_ext},
ext_modules=[ ext_modules=[
Extension('hyperdepth', Extension('hyperdepth',
sources, sources,
@ -39,7 +38,3 @@ setup(
) )
] ]
) )

@ -6,10 +6,13 @@ orig = cv2.imread('disp_orig.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32) ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32) es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
plt.figure() plt.figure()
plt.subplot(2,2,1); plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma') plt.subplot(2, 2, 1);
plt.subplot(2,2,2); plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma') plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,3); plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma') plt.subplot(2, 2, 2);
plt.subplot(2,2,4); plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma') plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 3);
plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 4);
plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
plt.show() plt.show()

@ -12,9 +12,12 @@ import torchext
from model import networks from model import networks
from data import dataset from data import dataset
class Worker(torchext.Worker): class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs): def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs) super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
self.ms = args.ms self.ms = args.ms
self.pattern_path = args.pattern_path self.pattern_path = args.pattern_path
@ -22,9 +25,9 @@ class Worker(torchext.Worker):
self.dp_weight = args.dp_weight self.dp_weight = args.dp_weight
self.data_type = args.data_type self.data_type = args.data_type
self.imsizes = [(480,640)] self.imsizes = [(488, 648)]
for iter in range(3): for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2))) self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp: with open('config.json') as fp:
config = json.load(fp) config = json.load(fp)
@ -32,11 +35,11 @@ class Worker(torchext.Worker):
self.settings_path = data_root / self.data_type / 'settings.pkl' self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/')) sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2**10:] self.train_paths = sample_paths[2 ** 10:]
self.test_paths = sample_paths[:2**8] self.test_paths = sample_paths[:2 ** 8]
# supervise the edge encoder with only 2**8 samples # supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2**8 self.train_edge = len(self.train_paths) - 2 ** 8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05) self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss() self.disparity_loss = networks.DisparityLoss()
@ -44,19 +47,21 @@ class Worker(torchext.Worker):
# evaluate in the region where opencv Block Matching has valid values # evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0]) self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1 self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_mask = self.eval_mask.astype(np.bool) self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0]-2*13 self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1]-13-140 self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self): def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=1) train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=1)
return train_set return train_set
def get_test_sets(self): def get_test_sets(self):
test_sets = torchext.TestSets() test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1) test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
track_length=1)
test_sets.append('simple', test_set, test_frequency=1) test_sets.append('simple', test_set, test_frequency=1)
# initialize photometric loss modules according to image sizes # initialize photometric loss modules according to image sizes
@ -66,9 +71,9 @@ class Worker(torchext.Worker):
pat = torch.from_numpy(pat[None][None].astype(np.float32)) pat = torch.from_numpy(pat[None][None].astype(np.float32))
pat = pat.to(self.train_device) pat = pat.to(self.train_device)
self.lcn_in = self.lcn_in.to(self.train_device) self.lcn_in = self.lcn_in.to(self.train_device)
pat,_ = self.lcn_in(pat) pat, _ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1) pat = torch.cat([pat for idx in range(3)], dim=1)
self.losses.append( networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) ) self.losses.append(networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat))
return test_sets return test_sets
@ -84,10 +89,10 @@ class Worker(torchext.Worker):
# concatenate the normalized IR input and the original IR image # concatenate the normalized IR input and the original IR image
if 'im' in key and 'blend' not in key: if 'im' in key and 'blend' not in key:
im = self.data[key] im = self.data[key]
im_lcn,im_std = self.lcn_in(im) im_lcn, im_std = self.lcn_in(im)
im_cat = torch.cat((im_lcn, im), dim=1) im_cat = torch.cat((im_lcn, im), dim=1)
key_std = key.replace('im','std') key_std = key.replace('im', 'std')
self.data[key]=im_cat self.data[key] = im_cat
self.data[key_std] = im_std.to(device).detach() self.data[key_std] = im_std.to(device).detach()
def net_forward(self, net, train): def net_forward(self, net, train):
@ -96,35 +101,35 @@ class Worker(torchext.Worker):
def loss_forward(self, out, train): def loss_forward(self, out, train):
out, edge = out out, edge = out
if not(isinstance(out, tuple) or isinstance(out, list)): if not (isinstance(out, tuple) or isinstance(out, list)):
out = [out] out = [out]
if not(isinstance(edge, tuple) or isinstance(edge, list)): if not (isinstance(edge, tuple) or isinstance(edge, list)):
edge = [edge] edge = [edge]
vals = [] vals = []
# apply photometric loss # apply photometric loss
for s,l,o in zip(itertools.count(), self.losses, out): for s, l, o in zip(itertools.count(), self.losses, out):
val, pattern_proj = l(o, self.data[f'im{s}'][:,0:1,...], self.data[f'std{s}']) val, pattern_proj = l(o, self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
if s == 0: if s == 0:
self.pattern_proj = pattern_proj.detach() self.pattern_proj = pattern_proj.detach()
vals.append(val) vals.append(val)
# apply disparity loss # apply disparity loss
# 1-edge as ground truth edge if inversed # 1-edge as ground truth edge if inversed
edge0 = 1-torch.sigmoid(edge[0]) edge0 = 1 - torch.sigmoid(edge[0])
val = self.disparity_loss(out[0], edge0) val = self.disparity_loss(out[0], edge0)
if self.dp_weight>0: if self.dp_weight > 0:
vals.append(val * self.dp_weight) vals.append(val * self.dp_weight)
# apply edge loss on a subset of training samples # apply edge loss on a subset of training samples
for s,e in zip(itertools.count(), edge): for s, e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge # inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}']<0.2 grad = self.data[f'grad{s}'] < 0.2
grad = grad.to(torch.float32) grad = grad.to(torch.float32)
ids = self.data['id'] ids = self.data['id']
mask = ids>self.train_edge mask = ids > self.train_edge
if mask.sum()>0: if mask.sum() > 0:
val = self.edge_loss(e[mask], grad[mask]) val = self.edge_loss(e[mask], grad[mask])
else: else:
val = torch.zeros_like(vals[0]) val = torch.zeros_like(vals[0])
@ -138,13 +143,13 @@ class Worker(torchext.Worker):
def numpy_in_out(self, output): def numpy_in_out(self, output):
output, edge = output output, edge = output
if not(isinstance(output, tuple) or isinstance(output, list)): if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output] output = [output]
es = output[0].detach().to('cpu').numpy() es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32) gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,0:1,...].detach().to('cpu').numpy() im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy()
ma = gt>0 ma = gt > 0
return es, gt, im, ma return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma): def write_img(self, out_path, es, gt, im, ma):
@ -154,49 +159,82 @@ class Worker(torchext.Worker):
diff = np.abs(es - gt) diff = np.abs(es - gt)
vmin, vmax = np.nanmin(gt), np.nanmax(gt) vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2*(vmax-vmin) vmin = vmin - 0.2 * (vmax - vmin)
vmax = vmax + 0.2*(vmax-vmin) vmax = vmax + 0.2 * (vmax - vmin)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0] pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0] im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0]
pattern_diff = np.abs(im_orig - pattern_proj) pattern_diff = np.abs(im_orig - pattern_proj)
fig = plt.figure(figsize=(16, 16))
fig = plt.figure(figsize=(16,16))
es_ = co.cmap.color_depth_map(es, scale=vmax) es_ = co.cmap.color_depth_map(es, scale=vmax)
gt_ = co.cmap.color_depth_map(gt, scale=vmax) gt_ = co.cmap.color_depth_map(gt, scale=vmax)
diff_ = co.cmap.color_error_image(diff, BGR=True) diff_ = co.cmap.color_error_image(diff, BGR=True)
# plot disparities, ground truth disparity is shown only for reference # plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3,3,1); plt.imshow(es_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}') ax = plt.subplot(3, 3, 1)
ax = plt.subplot(3,3,2); plt.imshow(gt_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}') plt.imshow(es_[..., [2, 1, 0]])
ax = plt.subplot(3,3,3); plt.imshow(diff_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Err. {diff.mean():.5f}') plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
ax = plt.subplot(3, 3, 2)
plt.imshow(gt_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
ax = plt.subplot(3, 3, 3)
plt.imshow(diff_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Err. {diff.mean():.5f}')
# plot edges # plot edges
edge = self.edge.to('cpu').numpy()[0,0] edge = self.edge.to('cpu').numpy()[0, 0]
edge_gt = self.edge_gt.to('cpu').numpy()[0,0] edge_gt = self.edge_gt.to('cpu').numpy()[0, 0]
edge_err = np.abs(edge - edge_gt) edge_err = np.abs(edge - edge_gt)
ax = plt.subplot(3,3,4); plt.imshow(edge, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}') ax = plt.subplot(3, 3, 4);
ax = plt.subplot(3,3,5); plt.imshow(edge_gt, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}') plt.imshow(edge, cmap='gray');
ax = plt.subplot(3,3,6); plt.imshow(edge_err, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Err. {edge_err.mean():.5f}') plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
ax = plt.subplot(3, 3, 5);
plt.imshow(edge_gt, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
ax = plt.subplot(3, 3, 6);
plt.imshow(edge_err, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
# plot normalized IR input and warped pattern # plot normalized IR input and warped pattern
ax = plt.subplot(3,3,7); plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}') ax = plt.subplot(3, 3, 7);
ax = plt.subplot(3,3,8); plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}') plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray');
im_std = self.data['std0'].to('cpu').numpy()[0,0] plt.xticks([]);
ax = plt.subplot(3,3,9); plt.imshow(im_std, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}') plt.yticks([]);
ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
ax = plt.subplot(3, 3, 8);
plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
im_std = self.data['std0'].to('cpu').numpy()[0, 0]
ax = plt.subplot(3, 3, 9);
plt.imshow(im_std, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
plt.tight_layout() plt.tight_layout()
plt.savefig(str(out_path)) plt.savefig(str(out_path))
plt.close(fig) plt.close(fig)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]): def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
if batch_idx % 512 == 0: if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png' out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output) es, gt, im, ma = self.numpy_in_out(output)
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0]) self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
def callback_test_start(self, epoch, set_idx): def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric( self.metric = co.metric.MultipleMetric(
@ -209,12 +247,12 @@ class Worker(torchext.Worker):
if batch_idx % 8 == 0: if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png' out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0]) self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
es, gt, im, ma = self.crop_output(es, gt, im, ma) es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1,1) es = es.reshape(-1, 1)
gt = gt.reshape(-1,1) gt = gt.reshape(-1, 1)
ma = ma.ravel() ma = ma.ravel()
self.metric.add(es, gt, ma) self.metric.add(es, gt, ma)
@ -225,13 +263,12 @@ class Worker(torchext.Worker):
def crop_output(self, es, gt, im, ma): def crop_output(self, es, gt, im, ma):
bs = es.shape[0] bs = es.shape[0]
es = np.reshape(es[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) es = np.reshape(es[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) gt = np.reshape(gt[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) im = np.reshape(im[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w]) ma = np.reshape(ma[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma return es, gt, im, ma
if __name__ == '__main__': if __name__ == '__main__':
pass pass

@ -12,9 +12,12 @@ import torchext
from model import networks from model import networks
from data import dataset from data import dataset
class Worker(torchext.Worker): class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs): def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs) super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
self.ms = args.ms self.ms = args.ms
self.pattern_path = args.pattern_path self.pattern_path = args.pattern_path
@ -23,11 +26,11 @@ class Worker(torchext.Worker):
self.ge_weight = args.ge_weight self.ge_weight = args.ge_weight
self.track_length = args.track_length self.track_length = args.track_length
self.data_type = args.data_type self.data_type = args.data_type
assert(self.track_length>1) assert (self.track_length > 1)
self.imsizes = [(480,640)] self.imsizes = [(480, 640)]
for iter in range(3): for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2))) self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp: with open('config.json') as fp:
config = json.load(fp) config = json.load(fp)
@ -35,11 +38,11 @@ class Worker(torchext.Worker):
self.settings_path = data_root / self.data_type / 'settings.pkl' self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/')) sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2**10:] self.train_paths = sample_paths[2 ** 10:]
self.test_paths = sample_paths[:2**8] self.test_paths = sample_paths[:2 ** 8]
# supervise the edge encoder with only 2**8 samples # supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2**8 self.train_edge = len(self.train_paths) - 2 ** 8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05) self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss() self.disparity_loss = networks.DisparityLoss()
@ -47,19 +50,20 @@ class Worker(torchext.Worker):
# evaluate in the region where opencv Block Matching has valid values # evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0]) self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1 self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_mask = self.eval_mask.astype(np.bool) self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0]-2*13 self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1]-13-140 self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self): def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length) train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=self.track_length)
return train_set return train_set
def get_test_sets(self): def get_test_sets(self):
test_sets = torchext.TestSets() test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1) test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
track_length=1)
test_sets.append('simple', test_set, test_frequency=1) test_sets.append('simple', test_set, test_frequency=1)
self.ph_losses = [] self.ph_losses = []
@ -72,9 +76,9 @@ class Worker(torchext.Worker):
pat = test_set.patterns[sidx] pat = test_set.patterns[sidx]
pat = pat.mean(axis=2) pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda') pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
pat,_ = self.lcn_in(pat) pat, _ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1) pat = torch.cat([pat for idx in range(3)], dim=1)
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat)
K = test_set.getK(sidx) K = test_set.getK(sidx)
Ki = np.linalg.inv(K) Ki = np.linalg.inv(K)
@ -82,11 +86,11 @@ class Worker(torchext.Worker):
Ki = torch.from_numpy(Ki) Ki = torch.from_numpy(Ki)
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1) ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
self.ph_losses.append( ph_loss ) self.ph_losses.append(ph_loss)
self.ge_losses.append( ge_loss ) self.ge_losses.append(ge_loss)
d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline)) d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
self.d2ds.append( d2d ) self.d2ds.append(d2d)
return test_sets return test_sets
@ -99,9 +103,9 @@ class Worker(torchext.Worker):
# batch_size x track_length x ... # batch_size x track_length x ...
# to # to
# track_length x batch_size x ... # track_length x batch_size x ...
if len(val.shape)>2: if len(val.shape) > 2:
if train: if train:
val = val.transpose(0,1) val = val.transpose(0, 1)
else: else:
val = val.unsqueeze(0) val = val.unsqueeze(0)
grad = 'im' in key and requires_grad grad = 'im' in key and requires_grad
@ -110,8 +114,8 @@ class Worker(torchext.Worker):
im = self.data[key] im = self.data[key]
tl = im.shape[0] tl = im.shape[0]
bs = im.shape[1] bs = im.shape[1]
im_lcn,im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:])) im_lcn, im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
key_std = key.replace('im','std') key_std = key.replace('im', 'std')
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device) self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2) im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
self.data[key] = im_cat self.data[key] = im_cat
@ -122,7 +126,7 @@ class Worker(torchext.Worker):
bs = im0.shape[1] bs = im0.shape[1]
im0 = im0.view(-1, *im0.shape[2:]) im0 = im0.view(-1, *im0.shape[2:])
out, edge = net(im0) out, edge = net(im0)
if not(isinstance(out, tuple) or isinstance(out, list)): if not (isinstance(out, tuple) or isinstance(out, list)):
out = out.view(tl, bs, *out.shape[1:]) out = out.view(tl, bs, *out.shape[1:])
edge = edge.view(tl, bs, *out.shape[1:]) edge = edge.view(tl, bs, *out.shape[1:])
else: else:
@ -132,42 +136,42 @@ class Worker(torchext.Worker):
def loss_forward(self, out, train): def loss_forward(self, out, train):
out, edge = out out, edge = out
if not(isinstance(out, tuple) or isinstance(out, list)): if not (isinstance(out, tuple) or isinstance(out, list)):
out = [out] out = [out]
vals = [] vals = []
diffs = [] diffs = []
# apply photometric loss # apply photometric loss
for s,l,o in zip(itertools.count(), self.ph_losses, out): for s, l, o in zip(itertools.count(), self.ph_losses, out):
im = self.data[f'im{s}'] im = self.data[f'im{s}']
im = im.view(-1, *im.shape[2:]) im = im.view(-1, *im.shape[2:])
o = o.view(-1, *o.shape[2:]) o = o.view(-1, *o.shape[2:])
std = self.data[f'std{s}'] std = self.data[f'std{s}']
std = std.view(-1, *std.shape[2:]) std = std.view(-1, *std.shape[2:])
val, pattern_proj = l(o, im[:,0:1,...], std) val, pattern_proj = l(o, im[:, 0:1, ...], std)
vals.append(val) vals.append(val)
if s == 0: if s == 0:
self.pattern_proj = pattern_proj.detach() self.pattern_proj = pattern_proj.detach()
# apply disparity loss # apply disparity loss
# 1-edge as ground truth edge if inversed # 1-edge as ground truth edge if inversed
edge0 = 1-torch.sigmoid(edge[0]) edge0 = 1 - torch.sigmoid(edge[0])
edge0 = edge0.view(-1, *edge0.shape[2:]) edge0 = edge0.view(-1, *edge0.shape[2:])
out0 = out[0].view(-1, *out[0].shape[2:]) out0 = out[0].view(-1, *out[0].shape[2:])
val = self.disparity_loss(out0, edge0) val = self.disparity_loss(out0, edge0)
if self.dp_weight>0: if self.dp_weight > 0:
vals.append(val * self.dp_weight) vals.append(val * self.dp_weight)
# apply edge loss on a subset of training samples # apply edge loss on a subset of training samples
for s,e in zip(itertools.count(), edge): for s, e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge # inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}']<0.2 grad = self.data[f'grad{s}'] < 0.2
grad = grad.to(torch.float32) grad = grad.to(torch.float32)
ids = self.data['id'] ids = self.data['id']
mask = ids>self.train_edge mask = ids > self.train_edge
if mask.sum()>0: if mask.sum() > 0:
e = e[:,mask,:] e = e[:, mask, :]
grad = grad[:,mask,:] grad = grad[:, mask, :]
e = e.view(-1, *e.shape[2:]) e = e.view(-1, *e.shape[2:])
grad = grad.view(-1, *grad.shape[2:]) grad = grad.view(-1, *grad.shape[2:])
val = self.edge_loss(e, grad) val = self.edge_loss(e, grad)
@ -181,14 +185,14 @@ class Worker(torchext.Worker):
# apply geometric loss # apply geometric loss
R = self.data['R'] R = self.data['R']
t = self.data['t'] t = self.data['t']
ge_num = self.track_length * (self.track_length-1) / 2 ge_num = self.track_length * (self.track_length - 1) / 2
for sidx in range(len(out)): for sidx in range(len(out)):
d2d = self.d2ds[sidx] d2d = self.d2ds[sidx]
depth = d2d(out[sidx]) depth = d2d(out[sidx])
ge_loss = self.ge_losses[sidx] ge_loss = self.ge_losses[sidx]
imsize = self.imsizes[sidx] imsize = self.imsizes[sidx]
for tidx0 in range(depth.shape[0]): for tidx0 in range(depth.shape[0]):
for tidx1 in range(tidx0+1, depth.shape[0]): for tidx1 in range(tidx0 + 1, depth.shape[0]):
depth0 = depth[tidx0] depth0 = depth[tidx0]
R0 = R[tidx0] R0 = R[tidx0]
t0 = t[tidx0] t0 = t[tidx0]
@ -203,12 +207,12 @@ class Worker(torchext.Worker):
def numpy_in_out(self, output): def numpy_in_out(self, output):
output, edge = output output, edge = output
if not(isinstance(output, tuple) or isinstance(output, list)): if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output] output = [output]
es = output[0].detach().to('cpu').numpy() es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32) gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy() im = self.data['im0'][:, :, 0:1, ...].detach().to('cpu').numpy()
ma = gt>0 ma = gt > 0
return es, gt, im, ma return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma): def write_img(self, out_path, es, gt, im, ma):
@ -218,36 +222,68 @@ class Worker(torchext.Worker):
diff = np.abs(es - gt) diff = np.abs(es - gt)
vmin, vmax = np.nanmin(gt), np.nanmax(gt) vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2*(vmax-vmin) vmin = vmin - 0.2 * (vmax - vmin)
vmax = vmax + 0.2*(vmax-vmin) vmax = vmax + 0.2 * (vmax - vmin)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0] pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0,0] im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0, 0]
pattern_diff = np.abs(im_orig - pattern_proj) pattern_diff = np.abs(im_orig - pattern_proj)
fig = plt.figure(figsize=(16,16)) fig = plt.figure(figsize=(16, 16))
es0 = co.cmap.color_depth_map(es[0], scale=vmax) es0 = co.cmap.color_depth_map(es[0], scale=vmax)
gt0 = co.cmap.color_depth_map(gt[0], scale=vmax) gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
diff0 = co.cmap.color_error_image(diff[0], BGR=True) diff0 = co.cmap.color_error_image(diff[0], BGR=True)
# plot disparities, ground truth disparity is shown only for reference # plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3,3,1); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}') ax = plt.subplot(3, 3, 1);
ax = plt.subplot(3,3,2); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}') plt.imshow(es0[..., [2, 1, 0]]);
ax = plt.subplot(3,3,3); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}') plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
ax = plt.subplot(3, 3, 2);
plt.imshow(gt0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
ax = plt.subplot(3, 3, 3);
plt.imshow(diff0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
# plot disparities of the second frame in the track if exists # plot disparities of the second frame in the track if exists
if es.shape[0]>=2: if es.shape[0] >= 2:
es1 = co.cmap.color_depth_map(es[1], scale=vmax) es1 = co.cmap.color_depth_map(es[1], scale=vmax)
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax) gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
diff1 = co.cmap.color_error_image(diff[1], BGR=True) diff1 = co.cmap.color_error_image(diff[1], BGR=True)
ax = plt.subplot(3,3,4); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}') ax = plt.subplot(3, 3, 4);
ax = plt.subplot(3,3,5); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}') plt.imshow(es1[..., [2, 1, 0]]);
ax = plt.subplot(3,3,6); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}') plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
ax = plt.subplot(3, 3, 5);
plt.imshow(gt1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
ax = plt.subplot(3, 3, 6);
plt.imshow(diff1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
# plot normalized IR inputs # plot normalized IR inputs
ax = plt.subplot(3,3,7); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}') ax = plt.subplot(3, 3, 7);
if es.shape[0]>=2: plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray');
ax = plt.subplot(3,3,8); plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}') plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
if es.shape[0] >= 2:
ax = plt.subplot(3, 3, 8);
plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
plt.tight_layout() plt.tight_layout()
plt.savefig(str(out_path)) plt.savefig(str(out_path))
@ -257,8 +293,8 @@ class Worker(torchext.Worker):
if batch_idx % 512 == 0: if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png' out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output) es, gt, im, ma = self.numpy_in_out(output)
masks = [ m.detach().to('cpu').numpy() for m in masks ] masks = [m.detach().to('cpu').numpy() for m in masks]
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0]) self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
def callback_test_start(self, epoch, set_idx): def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric( self.metric = co.metric.MultipleMetric(
@ -271,12 +307,12 @@ class Worker(torchext.Worker):
if batch_idx % 8 == 0: if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png' out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0]) self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
es, gt, im, ma = self.crop_output(es, gt, im, ma) es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1,1) es = es.reshape(-1, 1)
gt = gt.reshape(-1,1) gt = gt.reshape(-1, 1)
ma = ma.ravel() ma = ma.ravel()
self.metric.add(es, gt, ma) self.metric.add(es, gt, ma)
@ -288,11 +324,12 @@ class Worker(torchext.Worker):
def crop_output(self, es, gt, im, ma): def crop_output(self, es, gt, im, ma):
tl = es.shape[0] tl = es.shape[0]
bs = es.shape[1] bs = es.shape[1]
es = np.reshape(es[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w]) es = np.reshape(es[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w]) gt = np.reshape(gt[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w]) im = np.reshape(im[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w]) ma = np.reshape(ma[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma return es, gt, im, ma
if __name__ == '__main__': if __name__ == '__main__':
pass pass

@ -44,7 +44,7 @@ class PosOutput(TimedModule):
def tforward(self, x): def tforward(self, x):
if self.u_pos is None: if self.u_pos is None:
self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1,1,1,-1) self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1, 1, 1, -1)
self.u_pos = self.u_pos.to(x.device) self.u_pos = self.u_pos.to(x.device)
pos = self.layer(x) pos = self.layer(x)
disp = self.u_pos - pos disp = self.u_pos - pos
@ -61,6 +61,7 @@ class OutputLayerFactory(object):
pos: estimate the absolute location pos: estimate the absolute location
pos_row: independently estimate the absolute location per row pos_row: independently estimate the absolute location per row
''' '''
def __init__(self, type='disp', params={}): def __init__(self, type='disp', params={}):
self.type = type self.type = type
self.params = params self.params = params
@ -98,7 +99,7 @@ class SigmoidAffine(TimedModule):
self.offset = offset self.offset = offset
def tforward(self, x): def tforward(self, x):
return torch.sigmoid(x/self.gamma - self.offset) * self.alpha + self.beta return torch.sigmoid(x / self.gamma - self.offset) * self.alpha + self.beta
class MultiLinear(TimedModule): class MultiLinear(TimedModule):
@ -110,26 +111,27 @@ class MultiLinear(TimedModule):
self.mods.append(torch.nn.Linear(channels_in, channels_out)) self.mods.append(torch.nn.Linear(channels_in, channels_out))
def tforward(self, x): def tforward(self, x):
x = x.permute(2,0,3,1) # BxCxHxW => HxBxWxC x = x.permute(2, 0, 3, 1) # BxCxHxW => HxBxWxC
y = x.new_empty(*x.shape[:-1], self.channels_out) y = x.new_empty(*x.shape[:-1], self.channels_out)
for hidx in range(x.shape[0]): for hidx in range(x.shape[0]):
y[hidx] = self.mods[hidx](x[hidx]) y[hidx] = self.mods[hidx](x[hidx])
y = y.permute(1,3,0,2) # HxBxWxC => BxCxHxW y = y.permute(1, 3, 0, 2) # HxBxWxC => BxCxHxW
return y return y
class DispNetS(TimedModule): class DispNetS(TimedModule):
''' '''
Disparity Decoder based on DispNetS Disparity Decoder based on DispNetS
''' '''
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, channel_multiplier=1):
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False,
channel_multiplier=1):
super(DispNetS, self).__init__(mod_name='DispNetS') super(DispNetS, self).__init__(mod_name='DispNetS')
self.output_ms = output_ms self.output_ms = output_ms
self.coordconv = coordconv self.coordconv = coordconv
conv_planes = channel_multiplier * np.array( [32, 64, 128, 256, 512, 512, 512] ) conv_planes = channel_multiplier * np.array([32, 64, 128, 256, 512, 512, 512])
self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7) self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5) self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2]) self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
@ -138,7 +140,7 @@ class DispNetS(TimedModule):
self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5]) self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6]) self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
upconv_planes = channel_multiplier * np.array( [512, 512, 256, 128, 64, 32, 16] ) upconv_planes = channel_multiplier * np.array([512, 512, 256, 128, 64, 32, 16])
self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0]) self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0])
self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1]) self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1])
self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2]) self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2])
@ -166,7 +168,6 @@ class DispNetS(TimedModule):
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1]) self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0]) self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
def init_weights(self): def init_weights(self):
for m in self.modules(): for m in self.modules():
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d): if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
@ -176,13 +177,15 @@ class DispNetS(TimedModule):
def downsample_conv(self, in_planes, out_planes, kernel_size=3): def downsample_conv(self, in_planes, out_planes, kernel_size=3):
if self.coordconv: if self.coordconv:
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2) conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
padding=(kernel_size - 1) // 2)
else: else:
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2) conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
padding=(kernel_size - 1) // 2)
return torch.nn.Sequential( return torch.nn.Sequential(
conv, conv,
torch.nn.ReLU(inplace=True), torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2), torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size - 1) // 2),
torch.nn.ReLU(inplace=True) torch.nn.ReLU(inplace=True)
) )
@ -199,7 +202,7 @@ class DispNetS(TimedModule):
) )
def crop_like(self, input, ref): def crop_like(self, input, ref):
assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)) assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
return input[:, :, :ref.size(2), :ref.size(3)] return input[:, :, :ref.size(2), :ref.size(3)]
def tforward(self, x): def tforward(self, x):
@ -229,19 +232,22 @@ class DispNetS(TimedModule):
disp4 = self.predict_disp4(out_iconv4) disp4 = self.predict_disp4(out_iconv4)
out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2) out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
disp4_up = self.crop_like(torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2) disp4_up = self.crop_like(
torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1) concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
out_iconv3 = self.iconv3(concat3) out_iconv3 = self.iconv3(concat3)
disp3 = self.predict_disp3(out_iconv3) disp3 = self.predict_disp3(out_iconv3)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1) out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1) disp3_up = self.crop_like(
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2) out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2) disp2 = self.predict_disp2(out_iconv2)
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x) out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x) disp2_up = self.crop_like(
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1) concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1) out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1) disp1 = self.predict_disp1(out_iconv1)
@ -256,6 +262,7 @@ class DispNetShallow(DispNetS):
''' '''
Edge Decoder based on DispNetS with fewer layers Edge Decoder based on DispNetS with fewer layers
''' '''
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False): def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init) super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
self.mod_name = 'DispNetShallow' self.mod_name = 'DispNetShallow'
@ -274,13 +281,15 @@ class DispNetShallow(DispNetS):
disp3 = self.predict_disp3(out_iconv3) disp3 = self.predict_disp3(out_iconv3)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1) out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1) disp3_up = self.crop_like(
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2) out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2) disp2 = self.predict_disp2(out_iconv2)
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x) out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x) disp2_up = self.crop_like(
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1) concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1) out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1) disp1 = self.predict_disp1(out_iconv1)
@ -295,13 +304,16 @@ class DispEdgeDecoders(TimedModule):
''' '''
Disparity Decoder and Edge Decoder Disparity Decoder and Edge Decoder
''' '''
def __init__(self, *args, max_disp=128, **kwargs): def __init__(self, *args, max_disp=128, **kwargs):
super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders') super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')
output_facs = [OutputLayerFactory( type='disp', params={ 'alpha': max_disp/(2**s), 'beta': 0, 'gamma': 1, 'offset': 3}) for s in range(4)] output_facs = [
OutputLayerFactory(type='disp', params={'alpha': max_disp / (2 ** s), 'beta': 0, 'gamma': 1, 'offset': 3})
for s in range(4)]
self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs) self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)
output_facs = [OutputLayerFactory( type='linear' ) for s in range(4)] output_facs = [OutputLayerFactory(type='linear') for s in range(4)]
self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs) self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs)
def tforward(self, x): def tforward(self, x):
@ -328,7 +340,7 @@ class PosToDepth(DispToDepth):
self.im_height = im_height self.im_height = im_height
self.im_width = im_width self.im_width = im_width
self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1,1,1,-1) self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1, 1, 1, -1)
def tforward(self, pos): def tforward(self, pos):
self.u_pos = self.u_pos.to(pos.device) self.u_pos = self.u_pos.to(pos.device)
@ -336,11 +348,11 @@ class PosToDepth(DispToDepth):
return super().forward(disp) return super().forward(disp)
class RectifiedPatternSimilarityLoss(TimedModule): class RectifiedPatternSimilarityLoss(TimedModule):
''' '''
Photometric Loss Photometric Loss
''' '''
def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5): def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
super().__init__(mod_name='RectifiedPatternSimilarityLoss') super().__init__(mod_name='RectifiedPatternSimilarityLoss')
self.im_height = im_height self.im_height = im_height
@ -348,8 +360,8 @@ class RectifiedPatternSimilarityLoss(TimedModule):
self.pattern = pattern.mean(dim=1, keepdim=True).contiguous() self.pattern = pattern.mean(dim=1, keepdim=True).contiguous()
u, v = np.meshgrid(range(im_width), range(im_height)) u, v = np.meshgrid(range(im_width), range(im_height))
uv0 = np.stack((u,v), axis=2).reshape(-1,1) uv0 = np.stack((u, v), axis=2).reshape(-1, 1)
uv0 = uv0.astype(np.float32).reshape(1,-1,2) uv0 = uv0.astype(np.float32).reshape(1, -1, 2)
self.uv0 = torch.from_numpy(uv0) self.uv0 = torch.from_numpy(uv0)
self.loss_type = loss_type self.loss_type = loss_type
@ -361,82 +373,84 @@ class RectifiedPatternSimilarityLoss(TimedModule):
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:]) uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
uv1 = torch.empty_like(uv0) uv1 = torch.empty_like(uv0)
uv1[...,0] = uv0[...,0] - disp0.contiguous().view(disp0.shape[0],-1) uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1)
uv1[...,1] = uv0[...,1] uv1[..., 1] = uv0[..., 1]
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5) uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5) uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone() uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:]) pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border') pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border')
mask = torch.ones_like(im) mask = torch.ones_like(im)
if std is not None: if std is not None:
mask = mask*std mask = mask * std
diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps) diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
val = (mask*diff).sum() / mask.sum() val = (mask * diff).sum() / mask.sum()
return val, pattern_proj return val, pattern_proj
class DisparityLoss(TimedModule): class DisparityLoss(TimedModule):
''' '''
Disparity Loss Disparity Loss
''' '''
def __init__(self): def __init__(self):
super().__init__(mod_name='DisparityLoss') super().__init__(mod_name='DisparityLoss')
self.sobel = SobelFilter(norm=False) self.sobel = SobelFilter(norm=False)
#if not edge_gt: # if not edge_gt:
self.b0=0.0503428816795 self.b0 = 0.0503428816795
self.b1=1.07274045944 self.b1 = 1.07274045944
#else: # else:
# self.b0=0.0587115108967 # self.b0=0.0587115108967
# self.b1=1.51931190491 # self.b1=1.51931190491
def tforward(self, disp, edge=None): def tforward(self, disp, edge=None):
self.sobel=self.sobel.to(disp.device) self.sobel = self.sobel.to(disp.device)
if edge is not None: if edge is not None:
grad = self.sobel(disp) grad = self.sobel(disp)
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8) grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
pdf = (1-edge)/self.b0 * torch.exp(-torch.abs(grad)/self.b0) + \ pdf = (1 - edge) / self.b0 * torch.exp(-torch.abs(grad) / self.b0) + \
edge/self.b1 * torch.exp(-torch.abs(grad)/self.b1) edge / self.b1 * torch.exp(-torch.abs(grad) / self.b1)
val = torch.mean(-torch.log(pdf.clamp(min=1e-4))) val = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
else: else:
# on qifeng's data we don't have ambient info # on qifeng's data we don't have ambient info
# therefore we supress edge everywhere # therefore we supress edge everywhere
grad = self.sobel(disp) grad = self.sobel(disp)
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8) grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
grad= torch.clamp(grad, 0, 1.0) grad = torch.clamp(grad, 0, 1.0)
val = torch.mean(grad) val = torch.mean(grad)
return val return val
class ProjectionBaseLoss(TimedModule): class ProjectionBaseLoss(TimedModule):
''' '''
Base module of the Geometric Loss Base module of the Geometric Loss
''' '''
def __init__(self, K, Ki, im_height, im_width): def __init__(self, K, Ki, im_height, im_width):
super().__init__(mod_name='ProjectionBaseLoss') super().__init__(mod_name='ProjectionBaseLoss')
self.K = K.view(-1,3,3) self.K = K.view(-1, 3, 3)
self.im_height = im_height self.im_height = im_height
self.im_width = im_width self.im_width = im_width
u, v = np.meshgrid(range(im_width), range(im_height)) u, v = np.meshgrid(range(im_width), range(im_height))
uv = np.stack((u,v,np.ones_like(u)), axis=2).reshape(-1,3) uv = np.stack((u, v, np.ones_like(u)), axis=2).reshape(-1, 3)
ray = uv @ Ki.numpy().T ray = uv @ Ki.numpy().T
ray = ray.reshape(1,-1,3).astype(np.float32) ray = ray.reshape(1, -1, 3).astype(np.float32)
self.ray = torch.from_numpy(ray) self.ray = torch.from_numpy(ray)
def transform(self, xyz, R=None, t=None): def transform(self, xyz, R=None, t=None):
if t is not None: if t is not None:
bs = xyz.shape[0] bs = xyz.shape[0]
xyz = xyz - t.reshape(bs,1,3) xyz = xyz - t.reshape(bs, 1, 3)
if R is not None: if R is not None:
xyz = torch.bmm(xyz, R) xyz = torch.bmm(xyz, R)
return xyz return xyz
@ -445,7 +459,7 @@ class ProjectionBaseLoss(TimedModule):
self.ray = self.ray.to(depth.device) self.ray = self.ray.to(depth.device)
bs = depth.shape[0] bs = depth.shape[0]
xyz = depth.reshape(bs,-1,1) * self.ray xyz = depth.reshape(bs, -1, 1) * self.ray
xyz = self.transform(xyz, R, t) xyz = self.transform(xyz, R, t)
return xyz return xyz
@ -453,19 +467,18 @@ class ProjectionBaseLoss(TimedModule):
self.K = self.K.to(xyz.device) self.K = self.K.to(xyz.device)
bs = xyz.shape[0] bs = xyz.shape[0]
xyz = torch.bmm(xyz, R.transpose(1,2)) xyz = torch.bmm(xyz, R.transpose(1, 2))
xyz = xyz + t.reshape(bs,1,3) xyz = xyz + t.reshape(bs, 1, 3)
Kt = self.K.transpose(1,2).expand(bs,-1,-1) Kt = self.K.transpose(1, 2).expand(bs, -1, -1)
uv = torch.bmm(xyz, Kt) uv = torch.bmm(xyz, Kt)
d = uv[:,:,2:3] d = uv[:, :, 2:3]
# avoid division by zero # avoid division by zero
uv = uv[:,:,:2] / (torch.nn.functional.relu(d) + 1e-12) uv = uv[:, :, :2] / (torch.nn.functional.relu(d) + 1e-12)
return uv, d return uv, d
def tforward(self, depth0, R0, t0, R1, t1): def tforward(self, depth0, R0, t0, R1, t1):
xyz = self.unproject(depth0, R0, t0) xyz = self.unproject(depth0, R0, t0)
return self.project(xyz, R1, t1) return self.project(xyz, R1, t1)
@ -475,6 +488,7 @@ class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
''' '''
Geometric Loss Geometric Loss
''' '''
def __init__(self, *args, clamp=-1): def __init__(self, *args, clamp=-1):
super().__init__(*args) super().__init__(*args)
self.mod_name = 'ProjectionDepthSimilarityLoss' self.mod_name = 'ProjectionDepthSimilarityLoss'
@ -483,8 +497,8 @@ class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
def fwd(self, depth0, depth1, R0, t0, R1, t1): def fwd(self, depth0, depth1, R0, t0, R1, t1):
uv1, d1 = super().tforward(depth0, R0, t0, R1, t1) uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5) uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5) uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone() uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border') depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border')
@ -500,21 +514,21 @@ class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
def tforward(self, depth0, depth1, R0, t0, R1, t1): def tforward(self, depth0, depth1, R0, t0, R1, t1):
l0 = self.fwd(depth0, depth1, R0, t0, R1, t1) l0 = self.fwd(depth0, depth1, R0, t0, R1, t1)
l1 = self.fwd(depth1, depth0, R1, t1, R0, t0) l1 = self.fwd(depth1, depth0, R1, t1, R0, t0)
return l0+l1 return l0 + l1
class LCN(TimedModule): class LCN(TimedModule):
''' '''
Local Contract Normalization Local Contract Normalization
''' '''
def __init__(self, radius, epsilon): def __init__(self, radius, epsilon):
super().__init__(mod_name='LCN') super().__init__(mod_name='LCN')
self.box_conv = torch.nn.Sequential( self.box_conv = torch.nn.Sequential(
torch.nn.ReflectionPad2d(radius), torch.nn.ReflectionPad2d(radius),
torch.nn.Conv2d(1, 1, kernel_size=2*radius+1, bias=False) torch.nn.Conv2d(1, 1, kernel_size=2 * radius + 1, bias=False)
) )
self.box_conv[1].weight.requires_grad=False self.box_conv[1].weight.requires_grad = False
self.box_conv[1].weight.fill_(1.) self.box_conv[1].weight.fill_(1.)
self.epsilon = epsilon self.epsilon = epsilon
@ -523,44 +537,43 @@ class LCN(TimedModule):
def tforward(self, data): def tforward(self, data):
boxs = self.box_conv(data) boxs = self.box_conv(data)
avgs = boxs / (2*self.radius+1)**2 avgs = boxs / (2 * self.radius + 1) ** 2
boxs_n2 = boxs**2 boxs_n2 = boxs ** 2
boxs_2n = self.box_conv(data**2) boxs_2n = self.box_conv(data ** 2)
stds = torch.sqrt(boxs_2n / (2*self.radius+1)**2 - avgs**2 + 1e-6) stds = torch.sqrt(boxs_2n / (2 * self.radius + 1) ** 2 - avgs ** 2 + 1e-6)
stds = stds + self.epsilon stds = stds + self.epsilon
return (data - avgs) / stds, stds return (data - avgs) / stds, stds
class SobelFilter(TimedModule): class SobelFilter(TimedModule):
''' '''
Sobel Filter Sobel Filter
''' '''
def __init__(self, norm=False): def __init__(self, norm=False):
super(SobelFilter, self).__init__(mod_name='SobelFilter') super(SobelFilter, self).__init__(mod_name='SobelFilter')
kx = np.array([[-5, -4, 0, 4, 5], kx = np.array([[-5, -4, 0, 4, 5],
[-8, -10, 0, 10, 8], [-8, -10, 0, 10, 8],
[-10, -20, 0, 20, 10], [-10, -20, 0, 20, 10],
[-8, -10, 0, 10, 8], [-8, -10, 0, 10, 8],
[-5, -4, 0, 4, 5]])/240.0 [-5, -4, 0, 4, 5]]) / 240.0
ky = kx.copy().transpose(1,0) ky = kx.copy().transpose(1, 0)
self.conv_x=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False) self.conv_x = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_x.weight=torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0)) self.conv_x.weight = torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))
self.conv_y=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False) self.conv_y = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_y.weight=torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0)) self.conv_y.weight = torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))
self.norm=norm self.norm = norm
def tforward(self,x): def tforward(self, x):
x = F.pad(x, (2,2,2,2), "replicate") x = F.pad(x, (2, 2, 2, 2), "replicate")
gx = self.conv_x(x) gx = self.conv_x(x)
gy = self.conv_y(x) gy = self.conv_y(x)
if self.norm: if self.norm:
return torch.sqrt(gx**2 + gy**2 + 1e-8) return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
else: else:
return torch.cat((gx, gy), dim=1) return torch.cat((gx, gy), dim=1)

@ -6,7 +6,9 @@ This repository contains the code for the paper
**[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)** **[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)**
<br> <br>
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/), [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/), and [Andreas Geiger](http://www.cvlibs.net/) [Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/)
, [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/),
and [Andreas Geiger](http://www.cvlibs.net/)
<br> <br>
[CVPR 2019](http://cvpr2019.thecvf.com/) [CVPR 2019](http://cvpr2019.thecvf.com/)
@ -24,40 +26,45 @@ If you find this code useful for your research, please cite
} }
``` ```
## Dependencies ## Dependencies
The network training/evaluation code is based on `Pytorch`. The network training/evaluation code is based on `Pytorch`.
``` ```
PyTorch>=1.1 PyTorch>=1.1
Cuda>=10.0 Cuda>=10.0
``` ```
Updated on 07.06.2021: The code is now compatible with the latest Pytorch version (1.8). Updated on 07.06.2021: The code is now compatible with the latest Pytorch version (1.8).
The other python packages can be installed with `anaconda`: The other python packages can be installed with `anaconda`:
``` ```
conda install --file requirements.txt conda install --file requirements.txt
``` ```
### Structured Light Renderer ### Structured Light Renderer
To train and evaluate our method in a controlled setting, we implemented an structured light renderer.
It can be used to render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable projector location. To train and evaluate our method in a controlled setting, we implemented an structured light renderer. It can be used to
To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`. render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable
Afterwards, the renderer can be build by running `make` within the `renderer` directory. projector location. To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`. Afterwards,
the renderer can be build by running `make` within the `renderer` directory.
### PyTorch Extensions ### PyTorch Extensions
The network training/evaluation code is based on `PyTorch`.
We implemented some custom layers that need to be built in the `torchext` directory. The network training/evaluation code is based on `PyTorch`. We implemented some custom layers that need to be built in
Simply change into this directory and run the `torchext` directory. Simply change into this directory and run
``` ```
python setup.py build_ext --inplace python setup.py build_ext --inplace
``` ```
### Baseline HyperDepth ### Baseline HyperDepth
As baseline we partially re-implemented the random forest based method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf).
The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`. As baseline we partially re-implemented the random forest based
To build it change into the directory and run method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf)
. The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`
. To build it change into the directory and run
``` ```
python setup.py build_ext --inplace python setup.py build_ext --inplace
@ -65,42 +72,59 @@ python setup.py build_ext --inplace
## Running ## Running
### Creating Synthetic Data ### Creating Synthetic Data
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by running
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and
correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by
running
``` ```
./create_syn_data.sh ./create_syn_data.sh
``` ```
If you are only interested in evaluating our pre-trained model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a validation set that contains a small amount of images.
If you are only interested in evaluating our pre-trained
model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a
validation set that contains a small amount of images.
### Training Network ### Training Network
As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train the network on synthetic data for the first stage run As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train
the network on synthetic data for the first stage run
``` ```
python train_val.py python train_val.py
``` ```
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by running After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by
running
``` ```
python train_val.py --loss phge python train_val.py --loss phge
``` ```
### Evaluating Network ### Evaluating Network
To evaluate a specific checkpoint, e.g. the 50th epoch, one can run To evaluate a specific checkpoint, e.g. the 50th epoch, one can run
``` ```
python train_val.py --cmd retest --epoch 50 python train_val.py --cmd retest --epoch 50
``` ```
### Evaluating a Pre-trained Model ### Evaluating a Pre-trained Model
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and
changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
``` ```
mkdir -p output mkdir -p output
mkdir -p output/exp_syn mkdir -p output/exp_syn
wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params
python train_val.py --cmd retest --epoch 99 python train_val.py --cmd retest --epoch 99
``` ```
You can also download our validation set from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
You can also download our validation set
from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
## Acknowledgement ## Acknowledgement
This work was supported by the Intel Network on Intelligent Systems. This work was supported by the Intel Network on Intelligent Systems.

@ -31,7 +31,7 @@ libraries.append(cuda_lib)
setup( setup(
name="cyrender", name="cyrender",
cmdclass= {'build_ext': build_ext}, cmdclass={'build_ext': build_ext},
ext_modules=[ ext_modules=[
Extension('cyrender', Extension('cyrender',
sources, sources,

@ -2,18 +2,19 @@ import torch
import torch.utils.data import torch.utils.data
import numpy as np import numpy as np
class TestSet(object): class TestSet(object):
def __init__(self, name, dset, test_frequency=1): def __init__(self, name, dset, test_frequency=1):
self.name = name self.name = name
self.dset = dset self.dset = dset
self.test_frequency = test_frequency self.test_frequency = test_frequency
class TestSets(list): class TestSets(list):
def append(self, name, dset, test_frequency=1): def append(self, name, dset, test_frequency=1):
super().append(TestSet(name, dset, test_frequency)) super().append(TestSet(name, dset, test_frequency))
class MultiDataset(torch.utils.data.Dataset): class MultiDataset(torch.utils.data.Dataset):
def __init__(self, *datasets): def __init__(self, *datasets):
self.current_epoch = 0 self.current_epoch = 0
@ -46,7 +47,6 @@ class MultiDataset(torch.utils.data.Dataset):
return self.datasets[didx][sidx] return self.datasets[didx][sidx]
class BaseDataset(torch.utils.data.Dataset): class BaseDataset(torch.utils.data.Dataset):
def __init__(self, train=True, fix_seed_per_epoch=False): def __init__(self, train=True, fix_seed_per_epoch=False):
self.current_epoch = 0 self.current_epoch = 0

@ -2,6 +2,7 @@ import torch
from . import ext_cpu from . import ext_cpu
from . import ext_cuda from . import ext_cuda
class NNFunction(torch.autograd.Function): class NNFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, in0, in1): def forward(ctx, in0, in1):
@ -16,6 +17,7 @@ class NNFunction(torch.autograd.Function):
def backward(ctx, grad_out): def backward(ctx, grad_out):
return None, None return None, None
def nn(in0, in1): def nn(in0, in1):
return NNFunction.apply(in0, in1) return NNFunction.apply(in0, in1)
@ -34,9 +36,11 @@ class CrossCheckFunction(torch.autograd.Function):
def backward(ctx, grad_out): def backward(ctx, grad_out):
return None, None return None, None
def crosscheck(in0, in1): def crosscheck(in0, in1):
return CrossCheckFunction.apply(in0, in1) return CrossCheckFunction.apply(in0, in1)
class ProjNNFunction(torch.autograd.Function): class ProjNNFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, xyz0, xyz1, K, patch_size): def forward(ctx, xyz0, xyz1, K, patch_size):
@ -51,11 +55,11 @@ class ProjNNFunction(torch.autograd.Function):
def backward(ctx, grad_out): def backward(ctx, grad_out):
return None, None, None, None return None, None, None, None
def proj_nn(xyz0, xyz1, K, patch_size): def proj_nn(xyz0, xyz1, K, patch_size):
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size) return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
class XCorrVolFunction(torch.autograd.Function): class XCorrVolFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, in0, in1, n_disps, block_size): def forward(ctx, in0, in1, n_disps, block_size):
@ -70,12 +74,11 @@ class XCorrVolFunction(torch.autograd.Function):
def backward(ctx, grad_out): def backward(ctx, grad_out):
return None, None, None, None return None, None, None, None
def xcorrvol(in0, in1, n_disps, block_size): def xcorrvol(in0, in1, n_disps, block_size):
return XCorrVolFunction.apply(in0, in1, n_disps, block_size) return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
class PhotometricLossFunction(torch.autograd.Function): class PhotometricLossFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, es, ta, block_size, type, eps): def forward(ctx, es, ta, block_size, type, eps):
@ -103,6 +106,7 @@ class PhotometricLossFunction(torch.autograd.Function):
grad_es = ext_cpu.photometric_loss_backward(*args) grad_es = ext_cpu.photometric_loss_backward(*args)
return grad_es, None, None, None, None return grad_es, None, None, None, None
def photometric_loss(es, ta, block_size, type='mse', eps=0.1): def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
type = type.lower() type = type.lower()
if type == 'mse': if type == 'mse':
@ -117,17 +121,18 @@ def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
raise Exception('invalid loss type') raise Exception('invalid loss type')
return PhotometricLossFunction.apply(es, ta, block_size, type, eps) return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1): def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
type = type.lower() type = type.lower()
p = block_size // 2 p = block_size // 2
es_pad = torch.nn.functional.pad(es, (p,p,p,p), mode='replicate') es_pad = torch.nn.functional.pad(es, (p, p, p, p), mode='replicate')
ta_pad = torch.nn.functional.pad(ta, (p,p,p,p), mode='replicate') ta_pad = torch.nn.functional.pad(ta, (p, p, p, p), mode='replicate')
es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size) es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size) ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3]) es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3]) ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
if type == 'mse': if type == 'mse':
ref = (es_uf - ta_uf)**2 ref = (es_uf - ta_uf) ** 2
elif type == 'sad': elif type == 'sad':
ref = torch.abs(es_uf - ta_uf) ref = torch.abs(es_uf - ta_uf)
elif type == 'census_mse' or type == 'census_sad': elif type == 'census_mse' or type == 'census_sad':
@ -143,5 +148,5 @@ def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
else: else:
raise Exception('invalid loss type') raise Exception('invalid loss type')
ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3]) ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
ref = torch.sum(ref, dim=1, keepdim=True) / block_size**2 ref = torch.sum(ref, dim=1, keepdim=True) / block_size ** 2
return ref return ref

@ -4,11 +4,13 @@ import numpy as np
from .functions import * from .functions import *
class CoordConv2d(torch.nn.Module): class CoordConv2d(torch.nn.Module):
def __init__(self, channels_in, channels_out, kernel_size, stride, padding): def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
super().__init__() super().__init__()
self.conv = torch.nn.Conv2d(channels_in+2, channels_out, kernel_size=kernel_size, padding=padding, stride=stride) self.conv = torch.nn.Conv2d(channels_in + 2, channels_out, kernel_size=kernel_size, padding=padding,
stride=stride)
self.uv = None self.uv = None
@ -19,7 +21,7 @@ class CoordConv2d(torch.nn.Module):
u = 2 * u / (width - 1) - 1 u = 2 * u / (width - 1) - 1
v = 2 * v / (height - 1) - 1 v = 2 * v / (height - 1) - 1
uv = np.stack((u, v)).reshape(1, 2, height, width) uv = np.stack((u, v)).reshape(1, 2, height, width)
self.uv = torch.from_numpy( uv.astype(np.float32) ) self.uv = torch.from_numpy(uv.astype(np.float32))
self.uv = self.uv.to(x.device) self.uv = self.uv.to(x.device)
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:]) uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
xuv = torch.cat((x, uv), dim=1) xuv = torch.cat((x, uv), dim=1)

@ -14,7 +14,8 @@ setup(
name='ext', name='ext',
ext_modules=[ ext_modules=[
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']), CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'], extra_compile_args={'cxx': [], 'nvcc': nvcc_args}), CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'],
extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
], ],
cmdclass={'build_ext': BuildExtension}, cmdclass={'build_ext': BuildExtension},
include_dirs=include_dirs include_dirs=include_dirs

@ -39,9 +39,10 @@ class StopWatch(object):
return ret return ret
def __repr__(self): def __repr__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
def __str__(self): def __str__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()]) return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
class ETA(object): class ETA(object):
@ -77,8 +78,10 @@ class ETA(object):
def get_remaining_time_str(self): def get_remaining_time_str(self):
return self.format_time(self.get_remaining_time()) return self.format_time(self.get_remaining_time())
class Worker(object): class Worker(object):
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1): def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
self.out_root = Path(out_root) self.out_root = Path(out_root)
self.experiment_name = experiment_name self.experiment_name = experiment_name
self.epochs = epochs self.epochs = epochs
@ -91,7 +94,7 @@ class Worker(object):
self.test_device = test_device self.test_device = test_device
self.max_train_iter = max_train_iter self.max_train_iter = max_train_iter
self.errs_list=[] self.errs_list = []
self.setup_experiment() self.setup_experiment()
@ -103,17 +106,17 @@ class Worker(object):
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
handlers=[ handlers=[
logging.FileHandler( str(self.exp_out_root / 'train.log') ), logging.FileHandler(str(self.exp_out_root / 'train.log')),
logging.StreamHandler() logging.StreamHandler()
], ],
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s' format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
) )
logging.info('='*80) logging.info('=' * 80)
logging.info(f'Start of experiment: {self.experiment_name}') logging.info(f'Start of experiment: {self.experiment_name}')
logging.info(socket.gethostname()) logging.info(socket.gethostname())
self.log_datetime() self.log_datetime()
logging.info('='*80) logging.info('=' * 80)
self.metric_path = self.exp_out_root / 'metrics.json' self.metric_path = self.exp_out_root / 'metrics.json'
if self.metric_path.exists(): if self.metric_path.exists():
@ -220,32 +223,31 @@ class Worker(object):
def format_err_str(self, errs, div=1): def format_err_str(self, errs, div=1):
err = sum(errs) err = sum(errs)
if len(errs) > 1: if len(errs) > 1:
err_str = f'{err/div:0.4f}=' + '+'.join([f'{e/div:0.4f}' for e in errs]) err_str = f'{err / div:0.4f}=' + '+'.join([f'{e / div:0.4f}' for e in errs])
else: else:
err_str = f'{err/div:0.4f}' err_str = f'{err / div:0.4f}'
return err_str return err_str
def write_err_img(self): def write_err_img(self):
err_img_path = self.exp_out_root / 'errs.png' err_img_path = self.exp_out_root / 'errs.png'
fig = plt.figure(figsize=(16,16)) fig = plt.figure(figsize=(16, 16))
lines=[] lines = []
for idx,errs in enumerate(self.errs_list): for idx, errs in enumerate(self.errs_list):
line,=plt.plot(range(len(errs)), errs, label=f'error{idx}') line, = plt.plot(range(len(errs)), errs, label=f'error{idx}')
lines.append(line) lines.append(line)
plt.tight_layout() plt.tight_layout()
plt.legend(handles=lines) plt.legend(handles=lines)
plt.savefig(str(err_img_path)) plt.savefig(str(err_img_path))
plt.close(fig) plt.close(fig)
def callback_train_new_epoch(self, epoch, net, optimizer): def callback_train_new_epoch(self, epoch, net, optimizer):
pass pass
def train(self, net, optimizer, resume=False, scheduler=None): def train(self, net, optimizer, resume=False, scheduler=None):
logging.info('='*80) logging.info('=' * 80)
logging.info('Start training') logging.info('Start training')
self.log_datetime() self.log_datetime()
logging.info('='*80) logging.info('=' * 80)
train_set = self.get_train_set() train_set = self.get_train_set()
test_sets = self.get_test_sets() test_sets = self.get_test_sets()
@ -257,9 +259,9 @@ class Worker(object):
state_path = self.exp_out_root / 'state.dict' state_path = self.exp_out_root / 'state.dict'
if resume and state_path.exists(): if resume and state_path.exists():
logging.info('='*80) logging.info('=' * 80)
logging.info(f'Loading state from {state_path}') logging.info(f'Loading state from {state_path}')
logging.info('='*80) logging.info('=' * 80)
state = torch.load(str(state_path)) state = torch.load(str(state_path))
epoch = state['epoch'] + 1 epoch = state['epoch'] + 1
if 'min_err' in state: if 'min_err' in state:
@ -269,7 +271,6 @@ class Worker(object):
curr_state.update(state['state_dict']) curr_state.update(state['state_dict'])
net.load_state_dict(curr_state) net.load_state_dict(curr_state)
try: try:
optimizer.load_state_dict(state['optimizer']) optimizer.load_state_dict(state['optimizer'])
except: except:
@ -321,10 +322,10 @@ class Worker(object):
if scheduler is not None: if scheduler is not None:
scheduler.step() scheduler.step()
logging.info('='*80) logging.info('=' * 80)
logging.info('Finished training') logging.info('Finished training')
self.log_datetime() self.log_datetime()
logging.info('='*80) logging.info('=' * 80)
def get_train_set(self): def get_train_set(self):
# returns train_set # returns train_set
@ -363,11 +364,12 @@ class Worker(object):
self.callback_train_start(epoch) self.callback_train_start(epoch)
stopwatch = StopWatch() stopwatch = StopWatch()
logging.info('='*80) logging.info('=' * 80)
logging.info('Train epoch %d' % epoch) logging.info('Train epoch %d' % epoch)
dset.current_epoch = epoch dset.current_epoch = epoch
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False) train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True,
num_workers=self.num_workers, drop_last=True, pin_memory=False)
net = net.to(self.train_device) net = net.to(self.train_device)
net.train() net.train()
@ -418,9 +420,9 @@ class Worker(object):
bar.update(batch_idx) bar.update(batch_idx)
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0: if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
err_str = self.format_err_str(errs) err_str = self.format_err_str(errs)
logging.info(f'train e{epoch}: {batch_idx+1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}') logging.info(
#self.write_err_img() f'train e{epoch}: {batch_idx + 1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
# self.write_err_img()
if mean_loss is None: if mean_loss is None:
mean_loss = [0 for e in errs] mean_loss = [0 for e in errs]
@ -455,17 +457,18 @@ class Worker(object):
errs = {} errs = {}
for test_set_idx, test_set in enumerate(test_sets): for test_set_idx, test_set in enumerate(test_sets):
if (epoch + 1) % test_set.test_frequency == 0: if (epoch + 1) % test_set.test_frequency == 0:
logging.info('='*80) logging.info('=' * 80)
logging.info(f'testing set {test_set.name}') logging.info(f'testing set {test_set.name}')
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset) err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
errs[test_set.name] = err errs[test_set.name] = err
return errs return errs
def test_epoch(self, epoch, set_idx, net, dset): def test_epoch(self, epoch, set_idx, net, dset):
logging.info('-'*80) logging.info('-' * 80)
logging.info('Test epoch %d' % epoch) logging.info('Test epoch %d' % epoch)
dset.current_epoch = epoch dset.current_epoch = epoch
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False) test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False,
num_workers=self.num_workers, drop_last=False, pin_memory=False)
net = net.to(self.test_device) net = net.to(self.test_device)
net.eval() net.eval()
@ -502,7 +505,8 @@ class Worker(object):
bar.update(batch_idx) bar.update(batch_idx)
if batch_idx % 25 == 0: if batch_idx % 25 == 0:
err_str = self.format_err_str(errs) err_str = self.format_err_str(errs)
logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}') logging.info(
f'test e{epoch}: {batch_idx + 1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
if mean_loss is None: if mean_loss is None:
mean_loss = [0 for e in errs] mean_loss = [0 for e in errs]

@ -5,25 +5,24 @@ from model import exp_synphge
from model import networks from model import networks
from co.args import parse_args from co.args import parse_args
# parse args # parse args
args = parse_args() args = parse_args()
# loss types # loss types
if args.loss=='ph': if args.loss == 'ph':
worker = exp_synph.Worker(args) worker = exp_synph.Worker(args)
elif args.loss=='phge': elif args.loss == 'phge':
worker = exp_synphge.Worker(args) worker = exp_synphge.Worker(args)
# concatenation of original image and lcn image # concatenation of original image and lcn image
channels_in=2 channels_in = 2
# set up network # set up network
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes, output_ms=worker.ms) net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes,
output_ms=worker.ms)
# optimizer # optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4) optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
# start the work # start the work
worker.do(net, optimizer) worker.do(net, optimizer)

Loading…
Cancel
Save