Reformat $EVERYTHING

master · CptCaptain · 3 years ago
parent 56f2aa7d5d · commit 43df77fb9b
32 changed files (lines changed per file):

   1. co/__init__.py (3)
   2. co/args.py (5)
   3. co/cmap.py (59)
   4. co/geometry.py (1336)
   5. co/gtimer.py (42)
   6. co/io3d.py (469)
   7. co/metric.py (426)
   8. co/plt.py (161)
   9. co/plt2d.py (95)
  10. co/plt3d.py (68)
  11. co/table.py (804)
  12. co/utils.py (108)
  13. data/commons.py (120)
  14. data/create_syn_data.py (454)
  15. data/dataset.py (238)
  16. data/lcn/lcn.html (395)
  17. data/lcn/setup.py (2)
  18. data/lcn/test_lcn.py (26)
  19. hyperdepth/hyperparam_search.py (65)
  20. hyperdepth/setup.py (37)
  21. hyperdepth/vis_eval.py (13)
  22. model/exp_synph.py (475)
  23. model/exp_synphge.py (597)
  24. model/networks.py (953)
  25. readme.md (64)
  26. renderer/setup.py (36)
  27. torchext/dataset.py (92)
  28. torchext/functions.py (233)
  29. torchext/modules.py (36)
  30. torchext/setup.py (15)
  31. torchext/worker.py (960)
  32. train_val.py (15)

The hunks shown below contain whitespace and line-wrapping changes only (PEP 8 style); each file is listed in its reformatted form.

co/__init__.py
@@ -7,8 +7,9 @@
# set matplotlib backend depending on env
import os
import matplotlib

if os.name == 'posix' and "DISPLAY" not in os.environ:
    matplotlib.use('Agg')

from . import geometry
from . import plt

co/args.py
@@ -12,7 +12,7 @@ def parse_args():
    parser.add_argument('--loss',
                        help='Train with \'ph\' for the first stage without geometric loss, \
                              train with \'phge\' for the second stage with geometric loss',
                        default='ph', choices=['ph', 'phge'], type=str)
    parser.add_argument('--data_type',
                        default='syn', choices=['syn'], type=str)
    #

@@ -66,6 +66,3 @@ def parse_args():
def get_exp_name(args):
    name = f"exp_{args.data_type}"
    return name

co/cmap.py
@@ -1,19 +1,20 @@
import numpy as np

_color_map_errors = np.array([
    [149, 54, 49],  # 0: log2(x) = -infinity
    [180, 117, 69],  # 0.0625: log2(x) = -4
    [209, 173, 116],  # 0.125: log2(x) = -3
    [233, 217, 171],  # 0.25: log2(x) = -2
    [248, 243, 224],  # 0.5: log2(x) = -1
    [144, 224, 254],  # 1.0: log2(x) = 0
    [97, 174, 253],  # 2.0: log2(x) = 1
    [67, 109, 244],  # 4.0: log2(x) = 2
    [39, 48, 215],  # 8.0: log2(x) = 3
    [38, 0, 165],  # 16.0: log2(x) = 4
    [38, 0, 165]  # inf: log2(x) = inf
]).astype(float)


def color_error_image(errors, scale=1, mask=None, BGR=True):
    """
    Color an input error map.

@@ -32,26 +33,28 @@ def color_error_image(errors, scale=1, mask=None, BGR=True):
    errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9)
    i0 = np.floor(errors_color_indices).astype(int)
    f1 = errors_color_indices - i0.astype(float)
    colored_errors_flat = _color_map_errors[i0, :] * (1 - f1).reshape(-1, 1) + _color_map_errors[i0 + 1,
                                                     :] * f1.reshape(-1, 1)

    if mask is not None:
        colored_errors_flat[mask.flatten() == 0] = 255

    if not BGR:
        colored_errors_flat = colored_errors_flat[:, [2, 1, 0]]

    return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int)


_color_map_depths = np.array([
    [0, 0, 0],  # 0.000
    [0, 0, 255],  # 0.114
    [255, 0, 0],  # 0.299
    [255, 0, 255],  # 0.413
    [0, 255, 0],  # 0.587
    [0, 255, 255],  # 0.701
    [255, 255, 0],  # 0.886
    [255, 255, 255],  # 1.000
    [255, 255, 255],  # 1.000
]).astype(float)

_color_map_bincenters = np.array([
    0.0,

@@ -62,9 +65,10 @@ _color_map_bincenters = np.array([
    0.701,
    0.886,
    1.000,
    2.000,  # doesn't make a difference, just strictly higher than 1
])


def color_depth_map(depths, scale=None):
    """
    Color an input depth map.

@@ -82,12 +86,13 @@ def color_depth_map(depths, scale=None):
    values = np.clip(depths.flatten() / scale, 0, 1)
    # for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value?
    lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1, -1)) * np.arange(0, 9)).max(axis=1)
    lower_bin_value = _color_map_bincenters[lower_bin]
    higher_bin_value = _color_map_bincenters[lower_bin + 1]
    alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
    colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[
        lower_bin + 1] * alphas.reshape(-1, 1)
    return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)

# from utils.debug import save_color_numpy
# save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000))
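
A minimal usage sketch for the color-mapping helper above, assuming the repository root is on PYTHONPATH so the co package imports (color_error_image works the same way, but it still returns np.int, an alias newer NumPy versions no longer provide):

import numpy as np
from co import cmap

# fake depth map; values above `scale` saturate at the last color bin
depth = np.random.uniform(0.0, 1200.0, size=(120, 160))
depth_rgb = cmap.color_depth_map(depth, scale=1000.0)
print(depth_rgb.shape, depth_rgb.dtype)  # (120, 160, 3) uint8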

co/geometry.py: file diff suppressed because it is too large.

co/gtimer.py
@@ -2,31 +2,37 @@ import numpy as np
from . import utils


class StopWatch(utils.StopWatch):
    def __del__(self):
        print('=' * 80)
        print('gtimer:')
        total = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.sum).items()])
        print(f'  [total]  {total}')
        mean = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.mean).items()])
        print(f'  [mean]   {mean}')
        median = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.median).items()])
        print(f'  [median] {median}')
        print('=' * 80)


GTIMER = StopWatch()


def start(name):
    GTIMER.start(name)


def stop(name):
    GTIMER.stop(name)


class Ctx(object):
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        start(self.name)

    def __exit__(self, *args):
        stop(self.name)
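
A short usage sketch for the global timer, assuming co.utils.StopWatch provides the start/stop/get(reduce=...) interface the wrapper above relies on; the summary is printed when GTIMER is destroyed at interpreter exit:

import time
from co import gtimer

# context-manager style
with gtimer.Ctx('data_loading'):
    time.sleep(0.1)

# explicit start/stop style
gtimer.start('forward')
time.sleep(0.05)
gtimer.stop('forward')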

co/io3d.py
@@ -2,266 +2,273 @@ import struct
import numpy as np
import collections


def _write_ply_point(fp, x, y, z, color=None, normal=None, binary=False):
    args = [x, y, z]
    if color is not None:
        args += [int(color[0]), int(color[1]), int(color[2])]
    if normal is not None:
        args += [normal[0], normal[1], normal[2]]
    if binary:
        fmt = '<fff'
        if color is not None:
            fmt = fmt + 'BBB'
        if normal is not None:
            fmt = fmt + 'fff'
        fp.write(struct.pack(fmt, *args))
    else:
        fmt = '%f %f %f'
        if color is not None:
            fmt = fmt + ' %d %d %d'
        if normal is not None:
            fmt = fmt + ' %f %f %f'
        fmt += '\n'
        fp.write(fmt % tuple(args))


def _write_ply_triangle(fp, i0, i1, i2, binary):
    if binary:
        fp.write(struct.pack('<Biii', 3, i0, i1, i2))
    else:
        fp.write('3 %d %d %d\n' % (i0, i1, i2))


def _write_ply_header_line(fp, str, binary):
    if binary:
        fp.write(str.encode())
    else:
        fp.write(str)


def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
    if verts.shape[1] != 3:
        raise Exception('verts has to be of shape Nx3')
    if trias is not None and trias.shape[1] != 3:
        raise Exception('trias has to be of shape Nx3')
    if color is not None and not callable(color) and not isinstance(color, np.ndarray) and color.shape[1] != 3:
        raise Exception('color has to be of shape Nx3 or a callable')

    mode = 'wb' if binary else 'w'
    with open(path, mode) as fp:
        _write_ply_header_line(fp, "ply\n", binary)
        if binary:
            _write_ply_header_line(fp, "format binary_little_endian 1.0\n", binary)
        else:
            _write_ply_header_line(fp, "format ascii 1.0\n", binary)
        _write_ply_header_line(fp, "element vertex %d\n" % (verts.shape[0]), binary)
        _write_ply_header_line(fp, "property float32 x\n", binary)
        _write_ply_header_line(fp, "property float32 y\n", binary)
        _write_ply_header_line(fp, "property float32 z\n", binary)
        if color is not None:
            _write_ply_header_line(fp, "property uchar red\n", binary)
            _write_ply_header_line(fp, "property uchar green\n", binary)
            _write_ply_header_line(fp, "property uchar blue\n", binary)
        if normals is not None:
            _write_ply_header_line(fp, "property float32 nx\n", binary)
            _write_ply_header_line(fp, "property float32 ny\n", binary)
            _write_ply_header_line(fp, "property float32 nz\n", binary)
        if trias is not None:
            _write_ply_header_line(fp, "element face %d\n" % (trias.shape[0]), binary)
            _write_ply_header_line(fp, "property list uchar int32 vertex_indices\n", binary)
        _write_ply_header_line(fp, "end_header\n", binary)

        for vidx, v in enumerate(verts):
            if color is not None:
                if callable(color):
                    c = color(vidx)
                elif color.shape[0] > 1:
                    c = color[vidx]
                else:
                    c = color[0]
            else:
                c = None
            if normals is None:
                n = None
            else:
                n = normals[vidx]
            _write_ply_point(fp, v[0], v[1], v[2], c, n, binary)

        if trias is not None:
            for t in trias:
                _write_ply_triangle(fp, t[0], t[1], t[2], binary)


def faces_to_triangles(faces):
    new_faces = []
    for f in faces:
        if f[0] == 3:
            new_faces.append([f[1], f[2], f[3]])
        elif f[0] == 4:
            new_faces.append([f[1], f[2], f[3]])
            new_faces.append([f[3], f[4], f[1]])
        else:
            raise Exception('unknown face count %d', f[0])
    return new_faces


def read_ply(path):
    with open(path, 'rb') as f:
        # parse header
        line = f.readline().decode().strip()
        if line != 'ply':
            raise Exception('Header error')
        n_verts = 0
        n_faces = 0
        vert_types = {}
        vert_bin_format = []
        vert_bin_len = 0
        vert_bin_cols = 0
        line = f.readline().decode()
        parse_vertex_prop = False
        while line.strip() != 'end_header':
            if 'format' in line:
                if 'ascii' in line:
                    binary = False
                elif 'binary_little_endian' in line:
                    binary = True
                else:
                    raise Exception('invalid ply format')
            if 'element face' in line:
                splits = line.strip().split(' ')
                n_faces = int(splits[-1])
                parse_vertex_prop = False
            if 'element camera' in line:
                parse_vertex_prop = False
            if 'element vertex' in line:
                splits = line.strip().split(' ')
                n_verts = int(splits[-1])
                parse_vertex_prop = True
            if parse_vertex_prop and 'property' in line:
                prop = line.strip().split()
                if prop[1] == 'float':
                    vert_bin_format.append('f4')
                    vert_bin_len += 4
                    vert_bin_cols += 1
                elif prop[1] == 'uchar':
                    vert_bin_format.append('B')
                    vert_bin_len += 1
                    vert_bin_cols += 1
                else:
                    raise Exception('invalid property')
                vert_types[prop[2]] = len(vert_types)
            line = f.readline().decode()

        # parse content
        if binary:
            sz = n_verts * vert_bin_len
            fmt = ','.join(vert_bin_format)
            verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz))
            verts = verts[0].astype(vert_bin_cols * 'f4,').view(dtype='f4').reshape((n_verts, -1))
            faces = []
            for idx in range(n_faces):
                fmt = '<Biii'
                length = struct.calcsize(fmt)
                dat = f.read(length)
                vals = struct.unpack(fmt, dat)
                faces.append(vals)
            faces = faces_to_triangles(faces)
            faces = np.array(faces, dtype=np.int32)
        else:
            verts = []
            for idx in range(n_verts):
                vals = [float(v) for v in f.readline().decode().strip().split(' ')]
                verts.append(vals)
            verts = np.array(verts, dtype=np.float32)
            faces = []
            for idx in range(n_faces):
                splits = f.readline().decode().strip().split(' ')
                n_face_verts = int(splits[0])
                vals = [int(v) for v in splits[0:n_face_verts + 1]]
                faces.append(vals)
            faces = faces_to_triangles(faces)
            faces = np.array(faces, dtype=np.int32)

    xyz = None
    if 'x' in vert_types and 'y' in vert_types and 'z' in vert_types:
        xyz = verts[:, [vert_types['x'], vert_types['y'], vert_types['z']]]
    colors = None
    if 'red' in vert_types and 'green' in vert_types and 'blue' in vert_types:
        colors = verts[:, [vert_types['red'], vert_types['green'], vert_types['blue']]]
        colors /= 255
    normals = None
    if 'nx' in vert_types and 'ny' in vert_types and 'nz' in vert_types:
        normals = verts[:, [vert_types['nx'], vert_types['ny'], vert_types['nz']]]
    return xyz, faces, colors, normals


def _read_obj_split_f(s):
    parts = s.split('/')
    vidx = int(parts[0]) - 1
    if len(parts) >= 2 and len(parts[1]) > 0:
        tidx = int(parts[1]) - 1
    else:
        tidx = -1
    if len(parts) >= 3 and len(parts[2]) > 0:
        nidx = int(parts[2]) - 1
    else:
        nidx = -1
    return vidx, tidx, nidx


def read_obj(path):
    with open(path, 'r') as fp:
        lines = fp.readlines()

    verts = []
    colors = []
    fnorms = []
    fnorm_map = collections.defaultdict(list)
    faces = []
    for line in lines:
        line = line.strip()
        if line.startswith('#') or len(line) == 0:
            continue
        parts = line.split()
        if line.startswith('v '):
            parts = parts[1:]
            x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
            if len(parts) == 4 or len(parts) == 7:
                w = float(parts[3])
                x, y, z = x / w, y / w, z / w
            verts.append((x, y, z))
            if len(parts) >= 6:
                r, g, b = float(parts[-3]), float(parts[-2]), float(parts[-1])
                rgb.append((r, g, b))
        elif line.startswith('vn '):
            parts = parts[1:]
            x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
            fnorms.append((x, y, z))
        elif line.startswith('f '):
            parts = parts[1:]
            if len(parts) != 3:
                raise Exception('only triangle meshes supported atm')
            vidx0, tidx0, nidx0 = _read_obj_split_f(parts[0])
            vidx1, tidx1, nidx1 = _read_obj_split_f(parts[1])
            vidx2, tidx2, nidx2 = _read_obj_split_f(parts[2])
            faces.append((vidx0, vidx1, vidx2))
            if nidx0 >= 0:
                fnorm_map[vidx0].append(nidx0)
            if nidx1 >= 0:
                fnorm_map[vidx1].append(nidx1)
            if nidx2 >= 0:
                fnorm_map[vidx2].append(nidx2)

    verts = np.array(verts)
    colors = np.array(colors)
    fnorms = np.array(fnorms)
    faces = np.array(faces)

    # face normals to vertex normals
    norms = np.zeros_like(verts)
    for vidx in fnorm_map.keys():
        ind = fnorm_map[vidx]
        norms[vidx] = fnorms[ind].sum(axis=0)
    N = np.linalg.norm(norms, axis=1, keepdims=True)
    np.divide(norms, N, out=norms, where=N != 0)
    return verts, faces, colors, norms
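
A small sketch of how the PLY writer above can be called, with hypothetical output file names; both the ASCII and the binary code paths are exercised (read_ply and read_obj are the matching loaders):

import numpy as np
from co import io3d

# a single triangle with per-vertex colors
verts = np.array([[0.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])
trias = np.array([[0, 1, 2]])
colors = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]])

io3d.write_ply('tri_ascii.ply', verts, trias=trias, color=colors, binary=False)
io3d.write_ply('tri_binary.ply', verts, trias=trias, color=colors, binary=True)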

co/metric.py
@@ -1,248 +1,260 @@
import numpy as np

from . import geometry


def _process_inputs(estimate, target, mask):
    if estimate.shape != target.shape:
        raise Exception('estimate and target have to be same shape')
    if mask is None:
        mask = np.ones(estimate.shape, dtype=np.bool)
    else:
        mask = mask != 0
    if estimate.shape != mask.shape:
        raise Exception('estimate and mask have to be same shape')
    return estimate, target, mask


def mse(estimate, target, mask=None):
    estimate, target, mask = _process_inputs(estimate, target, mask)
    m = np.sum((estimate[mask] - target[mask]) ** 2) / mask.sum()
    return m


def rmse(estimate, target, mask=None):
    return np.sqrt(mse(estimate, target, mask))


def mae(estimate, target, mask=None):
    estimate, target, mask = _process_inputs(estimate, target, mask)
    m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
    return m


def outlier_fraction(estimate, target, mask=None, threshold=0):
    estimate, target, mask = _process_inputs(estimate, target, mask)
    diff = np.abs(estimate[mask] - target[mask])
    m = (diff > threshold).sum() / mask.sum()
    return m


class Metric(object):
    def __init__(self, str_prefix=''):
        self.str_prefix = str_prefix
        self.reset()

    def reset(self):
        pass

    def add(self, es, ta, ma=None):
        pass

    def get(self):
        return {}

    def items(self):
        return self.get().items()

    def __str__(self):
        return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])


class MultipleMetric(Metric):
    def __init__(self, *metrics, **kwargs):
        self.metrics = [*metrics]
        super().__init__(**kwargs)

    def reset(self):
        for m in self.metrics:
            m.reset()

    def add(self, es, ta, ma=None):
        for m in self.metrics:
            m.add(es, ta, ma)

    def get(self):
        ret = {}
        for m in self.metrics:
            vals = m.get()
            for k in vals:
                ret[k] = vals[k]
        return ret

    def __str__(self):
        return '\n'.join([str(m) for m in self.metrics])


class BaseDistanceMetric(Metric):
    def __init__(self, name='', **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reset(self):
        self.dists = []

    def add(self, es, ta, ma=None):
        pass

    def get(self):
        dists = np.hstack(self.dists)
        return {
            f'dist{self.name}_mean': float(np.mean(dists)),
            f'dist{self.name}_std': float(np.std(dists)),
            f'dist{self.name}_median': float(np.median(dists)),
            f'dist{self.name}_q10': float(np.percentile(dists, 10)),
            f'dist{self.name}_q90': float(np.percentile(dists, 90)),
            f'dist{self.name}_min': float(np.min(dists)),
            f'dist{self.name}_max': float(np.max(dists)),
        }


class DistanceMetric(BaseDistanceMetric):
    def __init__(self, vec_length, p=2, **kwargs):
        super().__init__(name=f'{p}', **kwargs)
        self.vec_length = vec_length
        self.p = p

    def add(self, es, ta, ma=None):
        if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
            print(es.shape, ta.shape)
            raise Exception('es and ta have to be of shape Nxdim')
        if ma is not None:
            es = es[ma != 0]
            ta = ta[ma != 0]
        dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
        self.dists.append(dist)


class OutlierFractionMetric(DistanceMetric):
    def __init__(self, thresholds, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.thresholds = thresholds

    def get(self):
        dists = np.hstack(self.dists)
        ret = {}
        for t in self.thresholds:
            ma = dists > t
            ret[f'of{t}'] = float(ma.sum() / ma.size)
        return ret


class RelativeDistanceMetric(BaseDistanceMetric):
    def __init__(self, vec_length, p=2, **kwargs):
        super().__init__(name=f'rel{p}', **kwargs)
        self.vec_length = vec_length
        self.p = p

    def add(self, es, ta, ma=None):
        if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
            raise Exception('es and ta have to be of shape Nxdim')
        dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
        denom = np.linalg.norm(ta, ord=self.p, axis=1)
        dist /= denom
        if ma is not None:
            dist = dist[ma != 0]
        self.dists.append(dist)


class RotmDistanceMetric(BaseDistanceMetric):
    def __init__(self, type='identity', **kwargs):
        super().__init__(name=type, **kwargs)
        self.type = type

    def add(self, es, ta, ma=None):
        if es.shape != ta.shape or es.shape[1] != 3 or es.shape[2] != 3 or es.ndim != 3:
            print(es.shape, ta.shape)
            raise Exception('es and ta have to be of shape Nx3x3')
        if ma is not None:
            raise Exception('mask is not implemented')
        if self.type == 'identity':
            self.dists.append(geometry.rotm_distance_identity(es, ta))
        elif self.type == 'geodesic':
            self.dists.append(geometry.rotm_distance_geodesic_unit_sphere(es, ta))
        else:
            raise Exception('invalid distance type')


class QuaternionDistanceMetric(BaseDistanceMetric):
    def __init__(self, type='angle', **kwargs):
        super().__init__(name=type, **kwargs)
        self.type = type

    def add(self, es, ta, ma=None):
        if es.shape != ta.shape or es.shape[1] != 4 or es.ndim != 2:
            print(es.shape, ta.shape)
            raise Exception('es and ta have to be of shape Nx4')
        if ma is not None:
            raise Exception('mask is not implemented')
        if self.type == 'angle':
            self.dists.append(geometry.quat_distance_angle(es, ta))
        elif self.type == 'mineucl':
            self.dists.append(geometry.quat_distance_mineucl(es, ta))
        elif self.type == 'normdiff':
            self.dists.append(geometry.quat_distance_normdiff(es, ta))
        else:
            raise Exception('invalid distance type')


class BinaryAccuracyMetric(Metric):
    def __init__(self, thresholds=np.linspace(0.0, 1.0, num=101, dtype=np.float64)[:-1], **kwargs):
        self.thresholds = thresholds
        super().__init__(**kwargs)

    def reset(self):
        self.tps = [0 for wp in self.thresholds]
        self.fps = [0 for wp in self.thresholds]
        self.fns = [0 for wp in self.thresholds]
        self.tns = [0 for wp in self.thresholds]
        self.n_pos = 0
        self.n_neg = 0

    def add(self, es, ta, ma=None):
        if ma is not None:
            raise Exception('mask is not implemented')
        es = es.ravel()
        ta = ta.ravel()
        if es.shape[0] != ta.shape[0]:
            raise Exception('invalid shape of es, or ta')
        if es.min() < 0 or es.max() > 1:
            raise Exception('estimate has wrong value range')
        ta_p = (ta == 1)
        ta_n = (ta == 0)
        es_p = es[ta_p]
        es_n = es[ta_n]
        for idx, wp in enumerate(self.thresholds):
            wp = np.asscalar(wp)
            self.tps[idx] += (es_p > wp).sum()
            self.fps[idx] += (es_n > wp).sum()
            self.fns[idx] += (es_p <= wp).sum()
            self.tns[idx] += (es_n <= wp).sum()
        self.n_pos += ta_p.sum()
        self.n_neg += ta_n.sum()

    def get(self):
        tps = np.array(self.tps).astype(np.float32)
        fps = np.array(self.fps).astype(np.float32)
        fns = np.array(self.fns).astype(np.float32)
        tns = np.array(self.tns).astype(np.float32)
        wp = self.thresholds

        ret = {}

        precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps), where=tps + fps != 0)
        recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps), where=tps + fns != 0)  # tprs
        fprs = np.divide(fps, fps + tns, out=np.zeros_like(tps), where=fps + tns != 0)

        precisions = np.r_[0, precisions, 1]
        recalls = np.r_[1, recalls, 0]
        fprs = np.r_[1, fprs, 0]

        ret['auc'] = float(-np.trapz(recalls, fprs))
        ret['prauc'] = float(-np.trapz(precisions, recalls))
        ret['ap'] = float(-(np.diff(recalls) * precisions[:-1]).sum())

        accuracies = np.divide(tps + tns, tps + tns + fps + fns)
        aacc = np.mean(accuracies)
        for t in np.linspace(0, 1, num=11)[1:-1]:
            idx = np.argmin(np.abs(t - wp))
            ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx])

        return ret
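
A minimal sketch of how the metric classes above compose, assuming the co package is importable; DistanceMetric collects per-sample L2 errors and OutlierFractionMetric reuses the same distances for threshold-based outlier rates:

import numpy as np
from co import metric

est = np.random.rand(1000, 3)               # estimated 3D points
ta = est + 0.01 * np.random.randn(1000, 3)  # noisy reference points

m = metric.MultipleMetric(
    metric.DistanceMetric(vec_length=3, p=2),
    metric.OutlierFractionMetric(thresholds=[0.01, 0.05], vec_length=3, p=2),
)
m.add(est, ta)
print(str(m))  # dist2_mean/std/median/... and of0.01/of0.05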

co/plt.py
@@ -6,94 +6,99 @@ import matplotlib.pyplot as plt
import os
import time


def save(path, remove_axis=False, dpi=300, fig=None):
    if fig is None:
        fig = plt.gcf()
    dirname = os.path.dirname(path)
    if dirname != '' and not os.path.exists(dirname):
        os.makedirs(dirname)
    if remove_axis:
        for ax in fig.axes:
            ax.axis('off')
            ax.margins(0, 0)
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        for ax in fig.axes:
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())
    fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)


def color_map(im_, cmap='viridis', vmin=None, vmax=None):
    cm = plt.get_cmap(cmap)
    im = im_.copy()
    if vmin is None:
        vmin = np.nanmin(im)
    if vmax is None:
        vmax = np.nanmax(im)
    mask = np.logical_not(np.isfinite(im))
    im[mask] = vmin
    im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin)
    im = cm(im)
    im = im[..., :3]
    for c in range(3):
        im[mask, c] = 1
    return im


def interactive_legend(leg=None, fig=None, all_axes=True):
    if leg is None:
        leg = plt.legend()
    if fig is None:
        fig = plt.gcf()
    if all_axes:
        axs = fig.get_axes()
    else:
        axs = [fig.gca()]

    # lined = dict()
    # lines = ax.lines
    # for legline, origline in zip(leg.get_lines(), ax.lines):
    #   legline.set_picker(5)
    #   lined[legline] = origline
    lined = dict()
    for lidx, legline in enumerate(leg.get_lines()):
        legline.set_picker(5)
        lined[legline] = [ax.lines[lidx] for ax in axs]

    def onpick(event):
        if event.mouseevent.dblclick:
            tmp = [(k, v) for k, v in lined.items()]
        else:
            tmp = [(event.artist, lined[event.artist])]

        for legline, origline in tmp:
            for ol in origline:
                vis = not ol.get_visible()
                ol.set_visible(vis)
                if vis:
                    legline.set_alpha(1.0)
                else:
                    legline.set_alpha(0.2)
        fig.canvas.draw()

    fig.canvas.mpl_connect('pick_event', onpick)


def non_annoying_pause(interval, focus_figure=False):
    # https://github.com/matplotlib/matplotlib/issues/11131
    backend = mpl.rcParams['backend']
    if backend in _interactive_bk:
        figManager = _pylab_helpers.Gcf.get_active()
        if figManager is not None:
            canvas = figManager.canvas
            if canvas.figure.stale:
                canvas.draw()
            if focus_figure:
                plt.show(block=False)
            canvas.start_event_loop(interval)
            return
    time.sleep(interval)


def remove_all_ticks(fig=None):
    if fig is None:
        fig = plt.gcf()
    for ax in fig.axes:
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
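
A brief sketch using color_map and save from above, with a hypothetical output file name; co/__init__.py already switches matplotlib to the Agg backend when no display is available:

import numpy as np
import matplotlib.pyplot as plt
from co import plt as coplt

depth = np.random.rand(120, 160)
depth[0, 0] = np.nan                  # non-finite values are masked to white
rgb = coplt.color_map(depth, cmap='viridis')

plt.figure()
plt.imshow(rgb)
coplt.save('depth_vis.png', remove_axis=True)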

co/plt2d.py
@@ -3,55 +3,60 @@ import matplotlib.pyplot as plt
from . import geometry


def image_matrix(ims, bgval=0):
    n = ims.shape[0]
    m = int(np.ceil(np.sqrt(n)))
    h = ims.shape[1]
    w = ims.shape[2]
    mat = np.empty((m * h, m * w), dtype=ims.dtype)
    mat.fill(bgval)
    idx = 0
    for r in range(m):
        for c in range(m):
            if idx < n:
                mat[r * h:(r + 1) * h, c * w:(c + 1) * w] = ims[idx]
            idx += 1
    return mat


def image_cat(ims, vertical=False):
    offx = [0]
    offy = [0]
    if vertical:
        width = max([im.shape[1] for im in ims])
        offx += [0 for im in ims[:-1]]
        offy += [im.shape[0] for im in ims[:-1]]
        height = sum([im.shape[0] for im in ims])
    else:
        height = max([im.shape[0] for im in ims])
        offx += [im.shape[1] for im in ims[:-1]]
        offy += [0 for im in ims[:-1]]
        width = sum([im.shape[1] for im in ims])
    offx = np.cumsum(offx)
    offy = np.cumsum(offy)

    im = np.zeros((height, width, *ims[0].shape[2:]), dtype=ims[0].dtype)
    for im0, ox, oy in zip(ims, offx, offy):
        im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0
    return im, offx, offy


def line(li, h, w, ax=None, *args, **kwargs):
    if ax is None:
        ax = plt.gca()
    xs = (-li[2] - li[1] * np.array((0, h - 1))) / li[0]
    ys = (-li[2] - li[0] * np.array((0, w - 1))) / li[1]
    pts = np.array([(0, ys[0]), (w - 1, ys[1]), (xs[0], 0), (xs[1], h - 1)])
    pts = pts[
        np.logical_and(np.logical_and(pts[:, 0] >= 0, pts[:, 0] < w), np.logical_and(pts[:, 1] >= 0, pts[:, 1] < h))]
    ax.plot(pts[:, 0], pts[:, 1], *args, **kwargs)


def depthshow(depth, *args, ax=None, **kwargs):
    if ax is None:
        ax = plt.gca()
    d = depth.copy()
    d[d < 0] = np.NaN
    ax.imshow(d, *args, **kwargs)
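
A quick sketch of the tiling helpers above on random data, assuming the co package is importable:

import numpy as np
from co import plt2d

ims = np.random.rand(5, 32, 48)                       # five 32x48 grayscale images
grid = plt2d.image_matrix(ims)                        # tiled into a 3x3 grid, shape (96, 144)
row, offx, offy = plt2d.image_cat([ims[0], ims[1]])   # side by side, shape (32, 96)
print(grid.shape, row.shape, offx, offy)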

co/plt3d.py
@@ -4,35 +4,45 @@ from mpl_toolkits.mplot3d import Axes3D
from . import geometry


def ax3d(fig=None):
    if fig is None:
        fig = plt.gcf()
    return fig.add_subplot(111, projection='3d')


def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1,
                label=None, **kwargs):
    if ax is None:
        ax = plt.gca()
    C0 = geometry.translation_to_cameracenter(R, t).ravel()
    C1 = C0 + R.T.dot(np.array([[-size], [-size], [3 * size]], dtype=np.float32)).ravel()
    C2 = C0 + R.T.dot(np.array([[-size], [+size], [3 * size]], dtype=np.float32)).ravel()
    C3 = C0 + R.T.dot(np.array([[+size], [+size], [3 * size]], dtype=np.float32)).ravel()
    C4 = C0 + R.T.dot(np.array([[+size], [-size], [3 * size]], dtype=np.float32)).ravel()

    if marker_C != '':
        ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs)
    ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
            linewidth=linewidth, **kwargs)
    ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle,
            linewidth=linewidth, **kwargs)
    ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle,
            linewidth=linewidth, **kwargs)
    ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle,
            linewidth=linewidth, **kwargs)
    ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]],
            [C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
            linewidth=linewidth, **kwargs)


def axis_equal(ax=None):
    if ax is None:
        ax = plt.gca()
    extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
    sz = extents[:, 1] - extents[:, 0]
    centers = np.mean(extents, axis=1)
    maxsize = max(abs(sz))
    r = maxsize / 2
    for ctr, dim in zip(centers, 'xyz'):
        getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r)
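
A usage sketch for plot_camera and axis_equal, assuming co.geometry.translation_to_cameracenter (from the suppressed geometry.py diff) is available as the code above expects:

import numpy as np
import matplotlib.pyplot as plt
from co import plt3d

ax = plt3d.ax3d()
# identity pose at the origin and a second camera shifted along x
plt3d.plot_camera(ax=ax, R=np.eye(3), t=np.zeros((3,)), size=25, color='b', label='cam0')
plt3d.plot_camera(ax=ax, R=np.eye(3), t=np.array([100.0, 0.0, 0.0]), size=25, color='r', label='cam1')
plt3d.axis_equal(ax)
plt.show()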

co/table.py
@@ -3,443 +3,453 @@ import pandas as pd
import enum
import itertools


class Table(object):
    def __init__(self, n_cols):
        self.n_cols = n_cols
        self.rows = []
        self.aligns = ['r' for c in range(n_cols)]

    def get_cell_align(self, r, c):
        align = self.rows[r].cells[c].align
        if align is None:
            return self.aligns[c]
        else:
            return align

    def add_row(self, row):
        if row.ncols() != self.n_cols:
            raise Exception(f'row has invalid number of cols, {row.ncols()} vs. {self.n_cols}')
        self.rows.append(row)

    def empty_row(self):
        return Row.Empty(self.n_cols)

    def expand_rows(self, n_add_cols=1):
        if n_add_cols < 0: raise Exception('n_add_cols has to be positive')
        self.n_cols += n_add_cols
        for row in self.rows:
            row.cells.extend([Cell() for cidx in range(n_add_cols)])

    def add_block(self, data, row=-1, col=0, fmt=None, expand=False):
        if row < 0: row = len(self.rows)
        while len(self.rows) < row + len(data):
            self.add_row(self.empty_row())
        for r in range(len(data)):
            cols = data[r]
            if col + len(cols) > self.n_cols:
                if expand:
                    self.expand_rows(col + len(cols) - self.n_cols)
                else:
                    raise Exception('number of cols does not fit in table')
            for c in range(len(cols)):
                self.rows[row + r].cells[col + c] = Cell(data[r][c], fmt)


class Row(object):
    def __init__(self, cells, pre_separator=None, post_separator=None):
        self.cells = cells
        self.pre_separator = pre_separator
        self.post_separator = post_separator

    @classmethod
    def Empty(cls, n_cols):
        return Row([Cell() for c in range(n_cols)])

    def add_cell(self, cell):
        self.cells.append(cell)

    def ncols(self):
        return sum([c.span for c in self.cells])


class Color(object):
    def __init__(self, color=(0, 0, 0), fmt='rgb'):
        if fmt == 'rgb':
            self.color = color
        elif fmt == 'RGB':
            self.color = tuple(c / 255 for c in color)
        else:
            return Exception('invalid color format')

    def as_rgb(self):
        return self.color

    def as_RGB(self):
        return tuple(int(c * 255) for c in self.color)

    @classmethod
    def rgb(cls, r, g, b):
        return Color(color=(r, g, b), fmt='rgb')

    @classmethod
    def RGB(cls, r, g, b):
        return Color(color=(r, g, b), fmt='RGB')


class CellFormat(object):
    def __init__(self, fmt='%s', fgcolor=None, bgcolor=None, bold=False):
        self.fmt = fmt
        self.fgcolor = fgcolor
        self.bgcolor = bgcolor
        self.bold = bold


class Cell(object):
    def __init__(self, data=None, fmt=None, span=1, align=None):
        self.data = data
        if fmt is None:
            fmt = CellFormat()
        self.fmt = fmt
        self.span = span
        self.align = align

    def __str__(self):
        return self.fmt.fmt % self.data


class Separator(enum.Enum):
    HEAD = 1
    BOTTOM = 2
    INNER = 3


class Renderer(object):
    def __init__(self):
        pass

    def cell_str_len(self, cell):
        return len(str(cell))

    def col_widths(self, table):
        widths = [0 for c in range(table.n_cols)]
        for row in table.rows:
            cidx = 0
            for cell in row.cells:
                if cell.span == 1:
                    strlen = self.cell_str_len(cell)
                    widths[cidx] = max(widths[cidx], strlen)
                cidx += cell.span
        return widths

    def render(self, table):
        raise NotImplementedError('not implemented')

    def __call__(self, table):
        return self.render(table)

    def render_to_file_comment(self):
        return ''

    def render_to_file(self, path, table):
        txt = self.render(table)
        with open(path, 'w') as fp:
            fp.write(txt)


class TerminalRenderer(Renderer):
    def __init__(self, col_sep=' '):
        super().__init__()
        self.col_sep = col_sep

    def render_cell(self, table, row, col, widths):
        cell = table.rows[row].cells[col]
        str = cell.fmt.fmt % cell.data
        str_width = len(str)
        cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
        cell_width += len(self.col_sep) * (cell.span - 1)
        if len(str) > cell_width:
            str = str[:cell_width]
        if cell.fmt.bold:
            # str = sty.ef.bold + str + sty.rs.bold_dim
            # str = sty.ef.bold + str + sty.rs.bold
            pass
        if cell.fmt.fgcolor is not None:
            # color = cell.fmt.fgcolor.as_RGB()
            # str = sty.fg(*color) + str + sty.rs.fg
            pass
        if str_width < cell_width:
            n_ws = (cell_width - str_width)
            if table.get_cell_align(row, col) == 'r':
                str = ' ' * n_ws + str
            elif table.get_cell_align(row, col) == 'l':
                str = str + ' ' * n_ws
            elif table.get_cell_align(row, col) == 'c':
                n_ws1 = n_ws // 2
                n_ws0 = n_ws - n_ws1
                str = ' ' * n_ws0 + str + ' ' * n_ws1
        if cell.fmt.bgcolor is not None:
            # color = cell.fmt.bgcolor.as_RGB()
            # str = sty.bg(*color) + str + sty.rs.bg
            pass
        return str

    def render_separator(self, separator, tab, col_widths, total_width):
        if separator == Separator.HEAD:
            return '=' * total_width
        elif separator == Separator.INNER:
            return '-' * total_width
        elif separator == Separator.BOTTOM:
            return '=' * total_width

    def render(self, table):
        widths = self.col_widths(table)
        total_width = sum(widths) + len(self.col_sep) * (table.n_cols - 1)
        lines = []
        for ridx, row in enumerate(table.rows):
            if row.pre_separator is not None:
                sepline = self.render_separator(row.pre_separator, table, widths, total_width)
                if len(sepline) > 0:
                    lines.append(sepline)
            line = []
            for cidx, cell in enumerate(row.cells):
                line.append(self.render_cell(table, ridx, cidx, widths))
            lines.append(self.col_sep.join(line))
            if row.post_separator is not None:
                sepline = self.render_separator(row.post_separator, table, widths, total_width)
                if len(sepline) > 0:
                    lines.append(sepline)
        return '\n'.join(lines)


class MarkdownRenderer(TerminalRenderer):
    def __init__(self):
        super().__init__(col_sep='|')
        self.printed_color_warning = False

    def print_color_warning(self):
        if not self.printed_color_warning:
            print('[WARNING] MarkdownRenderer does not support color yet')
            self.printed_color_warning = True

    def cell_str_len(self, cell):
        strlen = len(str(cell))
        if cell.fmt.bold:
            strlen += 4
        strlen = max(5, strlen)
        return strlen

    def render_cell(self, table, row, col, widths):
        cell = table.rows[row].cells[col]
        str = cell.fmt.fmt % cell.data
        if cell.fmt.bold:
            str = f'**{str}**'
        str_width = len(str)
        cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
        cell_width += len(self.col_sep) * (cell.span - 1)
        if len(str) > cell_width:
str = str[:cell_width] str = str[:cell_width]
else: else:
n_ws = (cell_width - str_width) n_ws = (cell_width - str_width)
if table.get_cell_align(row, col) == 'r': if table.get_cell_align(row, col) == 'r':
str = ' '*n_ws + str str = ' ' * n_ws + str
elif table.get_cell_align(row, col) == 'l': elif table.get_cell_align(row, col) == 'l':
str = str + ' '*n_ws str = str + ' ' * n_ws
elif table.get_cell_align(row, col) == 'c': elif table.get_cell_align(row, col) == 'c':
n_ws1 = n_ws // 2 n_ws1 = n_ws // 2
n_ws0 = n_ws - n_ws1 n_ws0 = n_ws - n_ws1
str = ' '*n_ws0 + str + ' '*n_ws1 str = ' ' * n_ws0 + str + ' ' * n_ws1
if col == 0: str = self.col_sep + str if col == 0: str = self.col_sep + str
if col == table.n_cols - 1: str += self.col_sep if col == table.n_cols - 1: str += self.col_sep
if cell.fmt.fgcolor is not None: if cell.fmt.fgcolor is not None:
self.print_color_warning() self.print_color_warning()
if cell.fmt.bgcolor is not None: if cell.fmt.bgcolor is not None:
self.print_color_warning() self.print_color_warning()
return str return str
def render_separator(self, separator, tab, widths, total_width): def render_separator(self, separator, tab, widths, total_width):
sep = '' sep = ''
if separator == Separator.INNER: if separator == Separator.INNER:
sep = self.col_sep sep = self.col_sep
for idx, width in enumerate(widths): for idx, width in enumerate(widths):
csep = '-' * (width - 2) csep = '-' * (width - 2)
if tab.get_cell_align(1, idx) == 'r': if tab.get_cell_align(1, idx) == 'r':
csep = '-' + csep + ':' csep = '-' + csep + ':'
elif tab.get_cell_align(1, idx) == 'l': elif tab.get_cell_align(1, idx) == 'l':
csep = ':' + csep + '-' csep = ':' + csep + '-'
elif tab.get_cell_align(1, idx) == 'c': elif tab.get_cell_align(1, idx) == 'c':
csep = ':' + csep + ':' csep = ':' + csep + ':'
sep += csep + self.col_sep sep += csep + self.col_sep
return sep return sep
class LatexRenderer(Renderer): class LatexRenderer(Renderer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def render_cell(self, table, row, col): def render_cell(self, table, row, col):
cell = table.rows[row].cells[col] cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data str = cell.fmt.fmt % cell.data
if cell.fmt.bold: if cell.fmt.bold:
str = '{\\bf '+ str + '}' str = '{\\bf ' + str + '}'
if cell.fmt.fgcolor is not None: if cell.fmt.fgcolor is not None:
color = cell.fmt.fgcolor.as_rgb() color = cell.fmt.fgcolor.as_rgb()
str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}' str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}'
if cell.fmt.bgcolor is not None: if cell.fmt.bgcolor is not None:
color = cell.fmt.bgcolor.as_rgb() color = cell.fmt.bgcolor.as_rgb()
str = f'\\cellcolor[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str str = f'\\cellcolor[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str
align = table.get_cell_align(row, col) align = table.get_cell_align(row, col)
if cell.span != 1 or align != table.aligns[col]: if cell.span != 1 or align != table.aligns[col]:
str = f'\\multicolumn{{{cell.span}}}{{{align}}}{{{str}}}' str = f'\\multicolumn{{{cell.span}}}{{{align}}}{{{str}}}'
return str return str
def render_separator(self, separator): def render_separator(self, separator):
if separator == Separator.HEAD: if separator == Separator.HEAD:
return '\\toprule' return '\\toprule'
elif separator == Separator.INNER: elif separator == Separator.INNER:
return '\\midrule' return '\\midrule'
elif separator == Separator.BOTTOM: elif separator == Separator.BOTTOM:
return '\\bottomrule' return '\\bottomrule'
def render(self, table): def render(self, table):
lines = ['\\begin{tabular}{' + ''.join(table.aligns) + '}'] lines = ['\\begin{tabular}{' + ''.join(table.aligns) + '}']
for ridx, row in enumerate(table.rows): for ridx, row in enumerate(table.rows):
if row.pre_separator is not None: if row.pre_separator is not None:
lines.append(self.render_separator(row.pre_separator)) lines.append(self.render_separator(row.pre_separator))
line = [] line = []
for cidx, cell in enumerate(row.cells): for cidx, cell in enumerate(row.cells):
line.append(self.render_cell(table, ridx, cidx)) line.append(self.render_cell(table, ridx, cidx))
lines.append(' & '.join(line) + ' \\\\') lines.append(' & '.join(line) + ' \\\\')
if row.post_separator is not None: if row.post_separator is not None:
lines.append(self.render_separator(row.post_separator)) lines.append(self.render_separator(row.post_separator))
lines.append('\\end{tabular}') lines.append('\\end{tabular}')
return '\n'.join(lines) return '\n'.join(lines)
class HtmlRenderer(Renderer):
def __init__(self, html_class='result_table'):
super().__init__()
self.html_class = html_class
def render_cell(self, table, row, col):
cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data
styles = []
if cell.fmt.bold:
styles.append('font-weight: bold;')
if cell.fmt.fgcolor is not None:
color = cell.fmt.fgcolor.as_RGB()
styles.append(f'color: rgb({color[0]},{color[1]},{color[2]});')
if cell.fmt.bgcolor is not None:
color = cell.fmt.bgcolor.as_RGB()
styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});')
align = table.get_cell_align(row, col)
if align == 'l':
align = 'left'
elif align == 'r':
align = 'right'
elif align == 'c':
align = 'center'
else:
raise Exception('invalid align')
styles.append(f'text-align: {align};')
row = table.rows[row]
if row.pre_separator is not None:
styles.append(f'border-top: {self.render_separator(row.pre_separator)};')
if row.post_separator is not None:
styles.append(f'border-bottom: {self.render_separator(row.post_separator)};')
style = ' '.join(styles)
str = f' <td style="{style}" colspan="{cell.span}">{str}</td>\n'
return str
def render_separator(self, separator):
if separator == Separator.HEAD:
return '1.5pt solid black'
elif separator == Separator.INNER:
return '0.75pt solid black'
elif separator == Separator.BOTTOM:
return '1.5pt solid black'
def render(self, table):
lines = [f'<table width="100%" style="border-collapse: collapse" class={self.html_class}>']
for ridx, row in enumerate(table.rows):
line = [f' <tr>\n']
for cidx, cell in enumerate(row.cells):
line.append(self.render_cell(table, ridx, cidx))
line.append(' </tr>\n')
lines.append(' '.join(line))
lines.append('</table>')
return '\n'.join(lines)
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'),
best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
rnames = data[rowname].unique()
cnames = data[colname].unique()
tab = Table(1 + len(cnames))
header = [Cell('', align='r')]
header.extend([Cell(h, align='r') for h in cnames])
header = Row(header, pre_separator=Separator.HEAD, post_separator=Separator.INNER)
tab.add_row(header)
for rname in rnames:
cells = [Cell(rname, align='l')]
for cname in cnames:
cdata = data[data[colname] == cname]
if cname in best_is_max:
bestval = cdata[valname].max()
val = cdata[cdata[rowname] == rname][valname].max()
else:
bestval = cdata[valname].min()
val = cdata[cdata[rowname] == rname][valname].min()
if val == bestval:
fmt = best_val_cell_fmt
else:
fmt = val_cell_fmt
cells.append(Cell(val, align='r', fmt=fmt))
tab.add_row(Row(cells))
tab.rows[-1].post_separator = Separator.BOTTOM
return tab
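A hedged usage sketch for pandas_to_table: it expects a long-format DataFrame with one row per (row name, column name) pair and bolds the best value per column. The method and metric names below are invented, and the co.table import path is an assumption.

import pandas as pd
from co.table import pandas_to_table, TerminalRenderer, MarkdownRenderer  # assumed import path

df = pd.DataFrame({
    'method': ['ours', 'ours', 'baseline', 'baseline'],
    'metric': ['rmse', 'acc', 'rmse', 'acc'],
    'val':    [0.12,   0.93,  0.20,       0.88],
})
# 'acc' is a higher-is-better metric, so it goes into best_is_max
tab = pandas_to_table(rowname='method', colname='metric', valname='val',
                      data=df, best_is_max=['acc'])
print(TerminalRenderer()(tab))   # aligned plain-text table
print(MarkdownRenderer()(tab))   # markdown table with the best value per column in bold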
if __name__ == '__main__': if __name__ == '__main__':
# df = pd.read_pickle('full.df') # df = pd.read_pickle('full.df')
# best_is_max = ['movF0.5', 'movF1.0'] # best_is_max = ['movF0.5', 'movF1.0']
# tab = pandas_to_table(rowname='method', colname='metric', valname='val', data=df, best_is_max=best_is_max) # tab = pandas_to_table(rowname='method', colname='metric', valname='val', data=df, best_is_max=best_is_max)
# renderer = TerminalRenderer() # renderer = TerminalRenderer()
# print(renderer(tab)) # print(renderer(tab))
tab = Table(7) tab = Table(7)
# header = Row([Cell('header', span=7, align='c')], pre_separator=Separator.HEAD, post_separator=Separator.INNER) # header = Row([Cell('header', span=7, align='c')], pre_separator=Separator.HEAD, post_separator=Separator.INNER)
# tab.add_row(header) # tab.add_row(header)
# header2 = Row([Cell('thisisaverylongheader', span=4, align='c'), Cell('vals2', span=3, align='c')], post_separator=Separator.INNER) # header2 = Row([Cell('thisisaverylongheader', span=4, align='c'), Cell('vals2', span=3, align='c')], post_separator=Separator.INNER)
# tab.add_row(header2) # tab.add_row(header2)
tab.add_row(Row([Cell(f'c{c}') for c in range(7)])) tab.add_row(Row([Cell(f'c{c}') for c in range(7)]))
tab.rows[-1].post_separator = Separator.INNER tab.rows[-1].post_separator = Separator.INNER
tab.add_block(np.arange(15*7).reshape(15,7)) tab.add_block(np.arange(15 * 7).reshape(15, 7))
tab.rows[4].cells[2].fmt = CellFormat(bold=True) tab.rows[4].cells[2].fmt = CellFormat(bold=True)
tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2,0.6,0.1)) tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2, 0.6, 0.1))
tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7,0.1,0.5)) tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7, 0.1, 0.5))
tab.rows[5].cells[3].fmt = CellFormat(bold=True,bgcolor=Color.rgb(0.7,0.1,0.5),fgcolor=Color.rgb(0.1,0.1,0.1)) tab.rows[5].cells[3].fmt = CellFormat(bold=True, bgcolor=Color.rgb(0.7, 0.1, 0.5), fgcolor=Color.rgb(0.1, 0.1, 0.1))
tab.rows[-1].post_separator = Separator.BOTTOM tab.rows[-1].post_separator = Separator.BOTTOM
renderer = TerminalRenderer() renderer = TerminalRenderer()
print(renderer(tab)) print(renderer(tab))
renderer = MarkdownRenderer() renderer = MarkdownRenderer()
print(renderer(tab)) print(renderer(tab))
# renderer = HtmlRenderer() # renderer = HtmlRenderer()
# html_tab = renderer(tab) # html_tab = renderer(tab)
# print(html_tab) # print(html_tab)
# with open('test.html', 'w') as fp: # with open('test.html', 'w') as fp:
# fp.write(html_tab) # fp.write(html_tab)
# import latex # import latex
# renderer = LatexRenderer() # renderer = LatexRenderer()
# ltx_tab = renderer(tab) # ltx_tab = renderer(tab)
# print(ltx_tab) # print(ltx_tab)
# with open('test.tex', 'w') as fp: # with open('test.tex', 'w') as fp:
# latex.write_doc_prefix(fp, document_class='article') # latex.write_doc_prefix(fp, document_class='article')
# fp.write('this is text that should appear before the table and should be long enough to wrap around.\n'*40) # fp.write('this is text that should appear before the table and should be long enough to wrap around.\n'*40)
# fp.write('\\begin{table}') # fp.write('\\begin{table}')
# fp.write(ltx_tab) # fp.write(ltx_tab)
# fp.write('\\end{table}') # fp.write('\\end{table}')
# fp.write('this is text that should appear after the table and should be long enough to wrap around.\n'*40) # fp.write('this is text that should appear after the table and should be long enough to wrap around.\n'*40)
# latex.write_doc_suffix(fp) # latex.write_doc_suffix(fp)

@ -8,6 +8,7 @@ import re
import pickle import pickle
import subprocess import subprocess
def str2bool(v): def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'): if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True return True
@ -16,71 +17,74 @@ def str2bool(v):
else: else:
raise argparse.ArgumentTypeError('Boolean value expected.') raise argparse.ArgumentTypeError('Boolean value expected.')
class StopWatch(object): class StopWatch(object):
def __init__(self): def __init__(self):
self.timings = OrderedDict() self.timings = OrderedDict()
self.starts = {} self.starts = {}
def start(self, name): def start(self, name):
self.starts[name] = time.time() self.starts[name] = time.time()
def stop(self, name): def stop(self, name):
if name not in self.timings: if name not in self.timings:
self.timings[name] = [] self.timings[name] = []
self.timings[name].append(time.time() - self.starts[name]) self.timings[name].append(time.time() - self.starts[name])
def get(self, name=None, reduce=np.sum): def get(self, name=None, reduce=np.sum):
if name is not None: if name is not None:
return reduce(self.timings[name]) return reduce(self.timings[name])
else: else:
ret = {} ret = {}
for k in self.timings: for k in self.timings:
ret[k] = reduce(self.timings[k]) ret[k] = reduce(self.timings[k])
return ret return ret
def __repr__(self):
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])

def __str__(self):
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
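A minimal usage sketch for StopWatch, assuming it is importable from co.utils as the file listing suggests: paired start/stop calls accumulate per-name timings, and get() reduces them with np.sum by default.

import time
import numpy as np
from co.utils import StopWatch  # assumed import path

sw = StopWatch()
for _ in range(3):
    sw.start('data')
    time.sleep(0.01)   # stand-in for data loading
    sw.stop('data')
    sw.start('forward')
    time.sleep(0.02)   # stand-in for a forward pass
    sw.stop('forward')
print(sw)                                   # e.g. 'data: 0.03...[s], forward: 0.06...[s]'
print(sw.get('forward', reduce=np.mean))    # mean time of a single forward pass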
class ETA(object): class ETA(object):
def __init__(self, length): def __init__(self, length):
self.length = length self.length = length
self.start_time = time.time() self.start_time = time.time()
self.current_idx = 0 self.current_idx = 0
self.current_time = time.time() self.current_time = time.time()
def update(self, idx): def update(self, idx):
self.current_idx = idx self.current_idx = idx
self.current_time = time.time() self.current_time = time.time()
def get_elapsed_time(self): def get_elapsed_time(self):
return self.current_time - self.start_time return self.current_time - self.start_time
def get_item_time(self): def get_item_time(self):
return self.get_elapsed_time() / (self.current_idx + 1) return self.get_elapsed_time() / (self.current_idx + 1)
def get_remaining_time(self): def get_remaining_time(self):
return self.get_item_time() * (self.length - self.current_idx + 1) return self.get_item_time() * (self.length - self.current_idx + 1)
def format_time(self, seconds): def format_time(self, seconds):
minutes, seconds = divmod(seconds, 60) minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60) hours, minutes = divmod(minutes, 60)
hours = int(hours) hours = int(hours)
minutes = int(minutes) minutes = int(minutes)
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}' return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
def get_elapsed_time_str(self): def get_elapsed_time_str(self):
return self.format_time(self.get_elapsed_time()) return self.format_time(self.get_elapsed_time())
def get_remaining_time_str(self): def get_remaining_time_str(self):
return self.format_time(self.get_remaining_time()) return self.format_time(self.get_remaining_time())
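Sketch of the ETA helper above (same assumed co.utils import): it extrapolates the remaining time of a fixed-length loop from the average per-item time so far.

import time
from co.utils import ETA  # assumed import path

eta = ETA(length=200)
for idx in range(200):
    time.sleep(0.005)    # stand-in for real work
    eta.update(idx)
    if idx % 50 == 0:
        print(f'{idx:3d}  elapsed {eta.get_elapsed_time_str()}  remaining {eta.get_remaining_time_str()}')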
def git_hash(cwd=None):
ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
hash = ret.stdout
if hash is not None and 'fatal' not in hash.decode():
return hash.decode().strip()
else:
return None

@ -4,47 +4,47 @@ import cv2
def get_patterns(path='syn', imsizes=[], crop=True): def get_patterns(path='syn', imsizes=[], crop=True):
pattern_size = imsizes[0] pattern_size = imsizes[0]
if path == 'syn': if path == 'syn':
np.random.seed(42) np.random.seed(42)
pattern = np.random.uniform(0,1, size=pattern_size) pattern = np.random.uniform(0, 1, size=pattern_size)
pattern = (pattern < 0.1).astype(np.float32) pattern = (pattern < 0.1).astype(np.float32)
pattern.reshape(*imsizes[0]) pattern.reshape(*imsizes[0])
else: else:
pattern = cv2.imread(path) pattern = cv2.imread(path)
pattern = pattern.astype(np.float32) pattern = pattern.astype(np.float32)
pattern /= 255 pattern /= 255
if pattern.ndim == 2:
pattern = np.stack([pattern for idx in range(3)], axis=2)

if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]:
r0 = (pattern.shape[0] - pattern_size[0]) // 2
c0 = (pattern.shape[1] - pattern_size[1]) // 2
pattern = pattern[r0:r0 + imsizes[0][0], c0:c0 + imsizes[0][1]]

patterns = []
for imsize in imsizes:
pat = cv2.resize(pattern, (imsize[1], imsize[0]), interpolation=cv2.INTER_LINEAR)
patterns.append(pat)

return patterns


def get_rotation_matrix(v0, v1):
v0 = v0 / np.linalg.norm(v0)
v1 = v1 / np.linalg.norm(v1)
v = np.cross(v0, v1)
c = np.dot(v0, v1)
s = np.linalg.norm(v)
I = np.eye(3)
vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0)
k = np.matrix(vXStr)
r = I + k + k @ k * ((1 - c) / (s ** 2))
return np.asarray(r.astype(np.float32))


def augment_image(img, rng, disp=None, grad=None, max_shift=64, max_blur=1.5, max_noise=10.0, max_sp_noise=0.001):
# get min/max values of image # get min/max values of image
min_val = np.min(img) min_val = np.min(img)
max_val = np.max(img) max_val = np.max(img)
@ -57,54 +57,56 @@ def augment_image(img,rng,disp=None,grad=None,max_shift=64,max_blur=1.5,max_nois
grad_aug = grad grad_aug = grad
# apply affine transformation # apply affine transformation
if max_shift>1: if max_shift > 1:
# affine parameters # affine parameters
rows,cols = img.shape rows, cols = img.shape
shear = 0 shear = 0
shift = 0 shift = 0
shear_correction = 0 shear_correction = 0
if rng.uniform(0,1)<0.75: shear = rng.uniform(-max_shift,max_shift) # shear with 75% probability if rng.uniform(0, 1) < 0.75:
else: shift = rng.uniform(0,max_shift) # shift with 25% probability shear = rng.uniform(-max_shift, max_shift) # shear with 75% probability
if shear<0: shear_correction = -shear else:
shift = rng.uniform(0, max_shift) # shift with 25% probability
if shear < 0: shear_correction = -shear
# affine transformation # affine transformation
a = shear/float(rows) a = shear / float(rows)
b = shift+shear_correction b = shift + shear_correction
# warp image # warp image
T = np.float32([[1,a,b],[0,1,0]]) T = np.float32([[1, a, b], [0, 1, 0]])
img_aug = cv2.warpAffine(img_aug,T,(cols,rows)) img_aug = cv2.warpAffine(img_aug, T, (cols, rows))
if grad is not None: if grad is not None:
grad_aug = cv2.warpAffine(grad,T,(cols,rows)) grad_aug = cv2.warpAffine(grad, T, (cols, rows))
# disparity correction map # disparity correction map
col = a*np.array(range(rows))+b col = a * np.array(range(rows)) + b
disp_delta = np.tile(col,(cols,1)).transpose() disp_delta = np.tile(col, (cols, 1)).transpose()
if disp is not None: if disp is not None:
disp_aug = cv2.warpAffine(disp+disp_delta,T,(cols,rows)) disp_aug = cv2.warpAffine(disp + disp_delta, T, (cols, rows))
# gaussian smoothing # gaussian smoothing
if rng.uniform(0,1)<0.5: if rng.uniform(0, 1) < 0.5:
img_aug = cv2.GaussianBlur(img_aug,(5,5),rng.uniform(0.2,max_blur)) img_aug = cv2.GaussianBlur(img_aug, (5, 5), rng.uniform(0.2, max_blur))
# per-pixel gaussian noise # per-pixel gaussian noise
img_aug = img_aug + rng.randn(*img_aug.shape)*rng.uniform(0.0,max_noise)/255.0 img_aug = img_aug + rng.randn(*img_aug.shape) * rng.uniform(0.0, max_noise) / 255.0
# salt-and-pepper noise # salt-and-pepper noise
if rng.uniform(0,1)<0.5: if rng.uniform(0, 1) < 0.5:
ratio=rng.uniform(0.0,max_sp_noise) ratio = rng.uniform(0.0, max_sp_noise)
img_shape = img_aug.shape img_shape = img_aug.shape
img_aug = img_aug.flatten() img_aug = img_aug.flatten()
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio)) coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
img_aug[coord] = max_val img_aug[coord] = max_val
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio)) coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
img_aug[coord] = min_val img_aug[coord] = min_val
img_aug = np.reshape(img_aug, img_shape) img_aug = np.reshape(img_aug, img_shape)
# clip intensities back to [0,1] # clip intensities back to [0,1]
img_aug = np.maximum(img_aug,0.0) img_aug = np.maximum(img_aug, 0.0)
img_aug = np.minimum(img_aug,1.0) img_aug = np.minimum(img_aug, 1.0)
# return image # return image
return img_aug, disp_aug, grad_aug return img_aug, disp_aug, grad_aug
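A small sanity-check sketch for the two helpers above, importing them from commons as create_syn_data.py does; the input shapes and values are made up for illustration.

import numpy as np
from commons import augment_image, get_rotation_matrix  # assumed import, as in create_syn_data.py

rng = np.random.RandomState(0)
img = rng.uniform(0, 1, size=(128, 160)).astype(np.float32)    # stand-in IR image in [0, 1]
disp = rng.uniform(0, 64, size=(128, 160)).astype(np.float32)
grad = np.zeros_like(img)

img_aug, disp_aug, grad_aug = augment_image(img, rng, disp=disp, grad=grad,
                                            max_shift=64, max_blur=1.5,
                                            max_noise=10.0, max_sp_noise=0.001)
# the augmented image keeps its shape and stays clipped to [0, 1]
assert img_aug.shape == img.shape and img_aug.min() >= 0.0 and img_aug.max() <= 1.0

# get_rotation_matrix(v0, v1) returns R such that R @ (v0/|v0|) is close to v1/|v1|
v0 = np.array([0.0, 0.0, 1.0])
v1 = np.array([0.0, 1.0, 1.0])
R = get_rotation_matrix(v0, v1)
assert np.allclose(R @ v0, v1 / np.linalg.norm(v1), atol=1e-4)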

@ -10,261 +10,259 @@ import cv2
import os import os
import collections import collections
import sys import sys
sys.path.append('../') sys.path.append('../')
import renderer import renderer
import co import co
from commons import get_patterns,get_rotation_matrix from commons import get_patterns, get_rotation_matrix
from lcn import lcn from lcn import lcn
def get_objs(shapenet_dir, obj_classes, num_perclass=100):
shapenet = {'chair': '03001627',
'airplane': '02691156',
'car': '02958343',
'watercraft': '04530566'}

obj_paths = []
for cls in obj_classes:
if cls not in shapenet.keys():
raise Exception('unknown class name')
ids = shapenet[cls]
obj_path = sorted(Path(f'{shapenet_dir}/{ids}').glob('**/models/*.obj'))
obj_paths += obj_path[:num_perclass]
print(f'found {len(obj_paths)} object paths')

objs = []
for obj_path in obj_paths:
print(f'load {obj_path}')
v, f, _, n = co.io3d.read_obj(obj_path)
diffs = v.max(axis=0) - v.min(axis=0)
v /= (0.5 * diffs.max())
v -= (v.min(axis=0) + 1)
f = f.astype(np.int32)
objs.append((v, f, n))
print(f'loaded {len(objs)} objects')

return objs


def get_mesh(rng, min_z=0):
# set up background board
verts, faces, normals, colors = [], [], [], []
v, f, n = co.geometry.xyplane(z=0, interleaved=True)
v[:, 2] += -v[:, 2].min() + rng.uniform(2, 7)
v[:, :2] *= 5e2
v[:, 2] = np.mean(v[:, 2]) + (v[:, 2] - np.mean(v[:, 2])) * 5e2
c = np.empty_like(v)
c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
verts.append(v.astype(np.float32))
faces.append(f)
normals.append(n)
colors.append(c)

# randomly sample 4 foreground objects for each scene
for shape_idx in range(4):
v, f, n = objs[rng.randint(0, len(objs))]
v, f, n = v.copy(), f.copy(), n.copy()
s = rng.uniform(0.25, 1)
v *= s
R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng))
v = v @ R.T
n = n @ R.T
v[:, 2] += -v[:, 2].min() + min_z + rng.uniform(0.5, 3)
v[:, :2] += rng.uniform(-1, 1, size=(1, 2))
c = np.empty_like(v)
c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
verts.append(v.astype(np.float32))
faces.append(f)
normals.append(n)
colors.append(c)

verts, faces = co.geometry.stack_mesh(verts, faces)
normals = np.vstack(normals).astype(np.float32)
colors = np.vstack(colors).astype(np.float32)
return verts, faces, colors, normals
def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4): def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4):
tic = time.time()
rng = np.random.RandomState()
rng.seed(idx)
verts, faces, colors, normals = get_mesh(rng)
data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
print(f'loading mesh for sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
# let the camera point to the center
center = np.array([0, 0, 3], dtype=np.float32)
basevec = np.array([-baseline, 0, 0], dtype=np.float32)
unit = np.array([0, 0, 1], dtype=np.float32)
cam_x_ = rng.uniform(-0.2, 0.2)
cam_y_ = rng.uniform(-0.2, 0.2)
cam_z_ = rng.uniform(-0.2, 0.2)
ret = collections.defaultdict(list)
blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1, 0.1), 0, 1)
# capture the same static scene from different view points as a track
for ind in range(track_length):
cam_x = cam_x_ + rng.uniform(-0.1, 0.1)
cam_y = cam_y_ + rng.uniform(-0.1, 0.1)
cam_z = cam_z_ + rng.uniform(-0.1, 0.1)
tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32)
if np.linalg.norm(tcam[0:2]) < 1e-9:
Rcam = np.eye(3, dtype=np.float32)
else:
Rcam = get_rotation_matrix(center, center - tcam)
tproj = tcam + basevec
Rproj = Rcam
ret['R'].append(Rcam)
ret['t'].append(tcam)
cams = []
projs = []
# render the scene at multiple scales
scales = [1, 0.5, 0.25, 0.125]
for scale in scales:
fx = K[0, 0] * scale
fy = K[1, 1] * scale
px = K[0, 2] * scale
py = K[1, 2] * scale
im_height = imsize[0] * scale
im_width = imsize[1] * scale
cams.append(renderer.PyCamera(fx, fy, px, py, Rcam, tcam, im_width, im_height))
projs.append(renderer.PyCamera(fx, fy, px, py, Rproj, tproj, im_width, im_height))
for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns):
fl = K[0, 0] / (2 ** s)
shader = renderer.PyShader(0.5, 1.5, 0.0, 10)
pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu')
pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35)

# get the reflected laser pattern $R$
im = pyrenderer.color().copy()
depth = pyrenderer.depth().copy()
disp = baseline * fl / depth
mask = depth > 0
im = np.mean(im, axis=2)

# get the ambient image $A$
ambient = pyrenderer.normal().copy()
ambient = np.mean(ambient, axis=2)

# get the noise free IR image $J$
im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient
ret[f'ambient{s}'].append(ambient[None].astype(np.float32))

# get the gradient magnitude of the ambient image $|\nabla A|$
ambient = ambient.astype(np.float32)
sobelx = cv2.Sobel(ambient, cv2.CV_32F, 1, 0, ksize=5)
sobely = cv2.Sobel(ambient, cv2.CV_32F, 0, 1, ksize=5)
grad = np.sqrt(sobelx ** 2 + sobely ** 2)
grad = np.maximum(grad - 0.8, 0.0)  # parameter

# get the local contrast normalized grad LCN($|\nabla A|$)
grad_lcn, grad_std = lcn.normalize(grad, 5, 0.1)
grad_lcn = np.clip(grad_lcn, 0.0, 1.0)  # parameter

ret[f'grad{s}'].append(grad_lcn[None].astype(np.float32))
ret[f'im{s}'].append(im[None].astype(np.float32))
ret[f'mask{s}'].append(mask[None].astype(np.float32))
ret[f'disp{s}'].append(disp[None].astype(np.float32))

for key in ret.keys():
ret[key] = np.stack(ret[key], axis=0)

# save to files
out_dir = out_root / f'{idx:08d}'
out_dir.mkdir(exist_ok=True, parents=True)
for k, val in ret.items():
for tidx in range(track_length):
v = val[tidx]
out_path = out_dir / f'{k}_{tidx}.npy'
np.save(out_path, v)
np.save(str(out_dir / 'blend_im.npy'), blend_im_rnd)

print(f'create sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')


if __name__ == '__main__':
np.random.seed(42)

# output directory
with open('../config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
shapenet_root = config['SHAPENET_ROOT']

data_type = 'syn'
out_root = data_root / f'{data_type}'
out_root.mkdir(parents=True, exist_ok=True)

start = 0
if len(sys.argv) >= 2 and isinstance(sys.argv[2], int):
start = sys.argv[2]
else:
if sys.argv[2] == '--resume':
try:
start = max([int(dir) for dir in os.listdir(out_root) if str.isdigit(dir)]) or 0
except:
pass

# load shapenet models
obj_classes = ['chair']
objs = get_objs(shapenet_root, obj_classes)

# camera parameters
imsize = (488, 648)
imsizes = [(imsize[0] // (2 ** s), imsize[1] // (2 ** s)) for s in range(4)]
# K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32)
K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0, 0, 1]],
dtype=np.float32)
focal_lengths = [K[0, 0] / (2 ** s) for s in range(4)]
baseline = 0.075
blend_im = 0.6
noise = 0

# capture the same static scene from different view points as a track
track_length = 4

# load pattern image
pattern_path = './kinect_pattern.png'
pattern_crop = True
patterns = get_patterns(pattern_path, imsizes, pattern_crop)

# write settings to file
settings = {
'imsizes': imsizes,
'patterns': patterns,
'focal_lengths': focal_lengths,
'baseline': baseline,
'K': K,
}
out_path = out_root / f'settings.pkl'
print(f'write settings to {out_path}')
with open(str(out_path), 'wb') as f:
pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL)

# start the job
n_samples = 2 ** 10 + 2 ** 13
for idx in range(start, n_samples):
args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length)
create_data(*args)
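The start/--resume handling above is fragile: sys.argv entries are always strings (so the isinstance check never fires) and index 2 assumes two extra arguments. A more defensive variant could look like the sketch below; it is an illustration only, not code from the repository, and it reads the option from sys.argv[1].

import os
import sys

def parse_start(out_root):
    # no extra argument: start from sample 0
    if len(sys.argv) < 2:
        return 0
    arg = sys.argv[1]
    if arg == '--resume':
        # resume from the highest sample directory that was already written
        done = [int(d) for d in os.listdir(out_root) if d.isdigit()]
        return max(done) if done else 0
    # otherwise interpret the argument as an explicit start index
    return int(arg)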

@ -21,128 +21,128 @@ from .commons import get_patterns, augment_image
from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d import Axes3D
class TrackSynDataset(torchext.BaseDataset): class TrackSynDataset(torchext.BaseDataset):
''' '''
Load locally saved synthetic dataset Load locally saved synthetic dataset
Please run ./create_syn_data.sh to generate the dataset Please run ./create_syn_data.sh to generate the dataset
''' '''
def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False):
super().__init__(train=train)

self.settings_path = settings_path
self.sample_paths = sample_paths
self.data_aug = data_aug
self.train = train
self.track_length = track_length
assert (track_length <= 4)

with open(str(settings_path), 'rb') as f:
settings = pickle.load(f)
self.imsizes = settings['imsizes']
self.patterns = settings['patterns']
self.focal_lengths = settings['focal_lengths']
self.baseline = settings['baseline']
self.K = settings['K']

self.scale = len(self.imsizes)

self.max_shift = 0
self.max_blur = 0.5
self.max_noise = 3.0
self.max_sp_noise = 0.0005

def __len__(self):
return len(self.sample_paths)

def __getitem__(self, idx):
if not self.train:
rng = self.get_rng(idx)
else:
rng = np.random.RandomState()
sample_path = self.sample_paths[idx]
if self.train:
track_ind = np.random.permutation(4)[0:self.track_length]
else:
track_ind = [0]
ret = {}
ret['id'] = idx
# load imgs, at all scales
for sidx in range(len(self.imsizes)):
imgs = []
ambs = []
grads = []
for tidx in track_ind:
imgs.append(np.load(os.path.join(sample_path, f'im{sidx}_{tidx}.npy')))
ambs.append(np.load(os.path.join(sample_path, f'ambient{sidx}_{tidx}.npy')))
grads.append(np.load(os.path.join(sample_path, f'grad{sidx}_{tidx}.npy')))
ret[f'im{sidx}'] = np.stack(imgs, axis=0)
ret[f'ambient{sidx}'] = np.stack(ambs, axis=0)
ret[f'grad{sidx}'] = np.stack(grads, axis=0)
# load disp and grad only at full resolution
disps = []
R = []
t = []
for tidx in track_ind:
disps.append(np.load(os.path.join(sample_path, f'disp0_{tidx}.npy')))
R.append(np.load(os.path.join(sample_path, f'R_{tidx}.npy')))
t.append(np.load(os.path.join(sample_path, f't_{tidx}.npy')))
ret[f'disp0'] = np.stack(disps, axis=0)
ret['R'] = np.stack(R, axis=0)
ret['t'] = np.stack(t, axis=0)
blend_im = np.load(os.path.join(sample_path, 'blend_im.npy'))
ret['blend_im'] = blend_im.astype(np.float32)
#### apply data augmentation at the different scales separately; this only works for max_shift=0
if self.data_aug:
for sidx in range(len(self.imsizes)):
if sidx == 0:
img = ret[f'im{sidx}']
disp = ret[f'disp{sidx}']
grad = ret[f'grad{sidx}']
img_aug = np.zeros_like(img)
disp_aug = np.zeros_like(img)
grad_aug = np.zeros_like(img)
for i in range(img.shape[0]):
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i, 0], rng,
disp=disp[i, 0], grad=grad[i, 0],
max_shift=self.max_shift, max_blur=self.max_blur,
max_noise=self.max_noise,
max_sp_noise=self.max_sp_noise)
img_aug[i] = img_aug_[None].astype(np.float32)
disp_aug[i] = disp_aug_[None].astype(np.float32)
grad_aug[i] = grad_aug_[None].astype(np.float32)
ret[f'im{sidx}'] = img_aug
ret[f'disp{sidx}'] = disp_aug
ret[f'grad{sidx}'] = grad_aug
else:
img = ret[f'im{sidx}']
img_aug = np.zeros_like(img)
for i in range(img.shape[0]):
img_aug_, _, _ = augment_image(img[i, 0], rng,
max_shift=self.max_shift, max_blur=self.max_blur,
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
img_aug[i] = img_aug_[None].astype(np.float32)
ret[f'im{sidx}'] = img_aug
if len(track_ind) == 1:
for key, val in ret.items():
if key != 'blend_im' and key != 'id':
ret[key] = val[0]
return ret
def getK(self, sidx=0):
K = self.K.copy() / (2 ** sidx)
K[2, 2] = 1
return K
if __name__ == '__main__': if __name__ == '__main__':
pass pass
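A hedged usage sketch for TrackSynDataset: the data root below is a placeholder, the import path is an assumption, and the settings.pkl plus per-sample directories are the ones written by create_syn_data.py.

from pathlib import Path
import torch
from data.dataset import TrackSynDataset  # assumed import path; inside this file the class can be used directly

data_root = Path('/path/to/DATA_ROOT/syn')   # placeholder, see config.json
sample_paths = sorted(p for p in data_root.iterdir() if p.is_dir())

ds = TrackSynDataset(data_root / 'settings.pkl', sample_paths,
                     track_length=2, train=True, data_aug=True)
sample = ds[0]
print(sample['im0'].shape, sample['disp0'].shape)   # (track_length, 1, H, W) each
print(ds.getK(sidx=1))                              # intrinsics scaled to half resolution

loader = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=True, num_workers=0)
batch = next(iter(loader))                          # dict of batched torch tensors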

@ -2,7 +2,7 @@
<!-- Generated by Cython 0.29 --> <!-- Generated by Cython 0.29 -->
<html> <html>
<head> <head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" /> <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
<title>Cython: lcn.pyx</title> <title>Cython: lcn.pyx</title>
<style type="text/css"> <style type="text/css">
@ -355,17 +355,23 @@ body.cython { font-family: courier; font-size: 12; }
.cython .vi { color: #19177C } /* Name.Variable.Instance */ .cython .vi { color: #19177C } /* Name.Variable.Instance */
.cython .vm { color: #19177C } /* Name.Variable.Magic */ .cython .vm { color: #19177C } /* Name.Variable.Magic */
.cython .il { color: #666666 } /* Literal.Number.Integer.Long */ .cython .il { color: #666666 } /* Literal.Number.Integer.Long */
</style> </style>
</head> </head>
<body class="cython"> <body class="cython">
<p><span style="border-bottom: solid 1px grey;">Generated by Cython 0.29</span></p> <p><span style="border-bottom: solid 1px grey;">Generated by Cython 0.29</span></p>
<p> <p>
<span style="background-color: #FFFF00">Yellow lines</span> hint at Python interaction.<br /> <span style="background-color: #FFFF00">Yellow lines</span> hint at Python interaction.<br/>
Click on a line that starts with a "<code>+</code>" to see the C code that Cython generated for it. Click on a line that starts with a "<code>+</code>" to see the C code that Cython generated for it.
</p> </p>
<p>Raw output: <a href="lcn.c">lcn.c</a></p> <p>Raw output: <a href="lcn.c">lcn.c</a></p>
<div class="cython"><pre class="cython line score-16" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre> <div class="cython">
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span> <pre class="cython line score-16"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span
class="k">as</span> <span class="nn">np</span></pre>
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_np, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_np, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
@ -374,22 +380,39 @@ body.cython { font-family: courier; font-size: 12; }
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_test, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_test, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
</pre><pre class="cython line score-0">&#xA0;<span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre> </pre>
<pre class="cython line score-0">&#xA0;<span class="">03</span>: </pre> <pre class="cython line score-0">&#xA0;<span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">04</span>: <span class="c"># use c square root function</span></pre> <pre class="cython line score-0">&#xA0;<span class="">03</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">05</span>: <span class="k">cdef</span> <span class="kr">extern</span> <span class="k">from</span> <span class="s">&quot;math.h&quot;</span><span class="p">:</span></pre> <pre class="cython line score-0">&#xA0;<span class="">04</span>: <span class="c"># use c square root function</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre> <pre class="cython line score-0">&#xA0;<span class="">05</span>: <span class="k">cdef</span> <span
<pre class="cython line score-0">&#xA0;<span class="">07</span>: </pre> class="kr">extern</span> <span class="k">from</span> <span class="s">&quot;math.h&quot;</span><span
<pre class="cython line score-0">&#xA0;<span class="">08</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre> class="p">:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">09</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre> <pre class="cython line score-0">&#xA0;<span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span
<pre class="cython line score-0">&#xA0;<span class="">10</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span class="p">)</span></pre> class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">11</span>: </pre> <pre class="cython line score-0">&#xA0;<span class="">07</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">12</span>: <span class="c"># 3 parameters:</span></pre> <pre class="cython line score-0">&#xA0;<span class="">08</span>: <span class="nd">@cython</span><span
<pre class="cython line score-0">&#xA0;<span class="">13</span>: <span class="c"># - float image</span></pre> class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span
<pre class="cython line score-0">&#xA0;<span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre> class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre> <pre class="cython line score-0">&#xA0;<span class="">09</span>: <span class="nd">@cython</span><span
<pre class="cython line score-67" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">):</span></pre> class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span
<pre class='cython code score-67 '>/* Python wrapper */ class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">10</span>: <span class="nd">@cython</span><span
class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span
class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">11</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">12</span>: <span class="c"># 3 parameters:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">13</span>: <span class="c"># - float image</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre>
<pre class="cython line score-67"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span
class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span
class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span
class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span
class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span
class="p">):</span></pre>
<pre class='cython code score-67 '>/* Python wrapper */
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0}; static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0};
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
@ -434,7 +457,8 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
} }
} }
if (unlikely(kw_args &gt; 0)) { if (unlikely(kw_args &gt; 0)) {
if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") &lt; 0)) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span> if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") &lt; 0)) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} }
} else { } else {
switch (<span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)) { switch (<span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)) {
@ -447,21 +471,27 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
default: goto __pyx_L5_argtuple_error; default: goto __pyx_L5_argtuple_error;
} }
} }
__pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span> __pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
if (values[1]) { if (values[1]) {
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) &amp;&amp; <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span> __pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) &amp;&amp; <span
class='py_c_api'>PyErr_Occurred</span>())) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} else { } else {
__pyx_v_kernel_size = ((int)4); __pyx_v_kernel_size = ((int)4);
} }
if (values[2]) { if (values[2]) {
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) &amp;&amp; <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span> __pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) &amp;&amp; <span
class='py_c_api'>PyErr_Occurred</span>())) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} else { } else {
__pyx_v_epsilon = ((float)0.01); __pyx_v_epsilon = ((float)0.01);
} }
} }
goto __pyx_L4_argument_unpacking_done; goto __pyx_L4_argument_unpacking_done;
__pyx_L5_argtuple_error:; __pyx_L5_argtuple_error:;
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span> <span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
__pyx_L3_error:; __pyx_L3_error:;
<span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename); <span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>(); <span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
@ -515,27 +545,49 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
return __pyx_r; return __pyx_r;
} }
/* … */ /* … */
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span> __pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span
class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19);
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19); <span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19);
/* … */ /* … */
__pyx_t_1 = PyCFunction_NewEx(&amp;__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span> __pyx_t_1 = PyCFunction_NewEx(&amp;__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span> __pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span
</pre><pre class="cython line score-0">&#xA0;<span class="">17</span>: </pre> class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<pre class="cython line score-0">&#xA0;<span class="">18</span>: <span class="c"># image dimensions</span></pre> </pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre> <pre class="cython line score-0">&#xA0;<span class="">17</span>: </pre>
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]); <pre class="cython line score-0">&#xA0;<span class="">18</span>: <span class="c"># image dimensions</span></pre>
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre> <pre class="cython line score-0"
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]); onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
</pre><pre class="cython line score-0">&#xA0;<span class="">21</span>: </pre> class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
<pre class="cython line score-0">&#xA0;<span class="">22</span>: <span class="c"># create outputs and output views</span></pre> class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
<pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre> class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span> <pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">21</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
<pre class="cython line score-46"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span
class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
@ -559,22 +611,34 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
__pyx_v_img_lcn = __pyx_t_5; __pyx_v_img_lcn = __pyx_t_5;
__pyx_t_5 = 0; __pyx_t_5 = 0;
</pre><pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre> </pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span> <pre class="cython line score-46"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span
class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
__pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
@ -598,114 +662,236 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
__pyx_v_img_std = __pyx_t_1; __pyx_v_img_std = __pyx_t_1;
__pyx_t_1 = 0; __pyx_t_1 = 0;
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span class="n">img_lcn</span></pre> </pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span> <pre class="cython line score-2"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span
class="n">img_lcn</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
__pyx_v_img_lcn_view = __pyx_t_6; __pyx_v_img_lcn_view = __pyx_t_6;
__pyx_t_6.memview = NULL; __pyx_t_6.memview = NULL;
__pyx_t_6.data = NULL; __pyx_t_6.data = NULL;
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span class="n">img_std</span></pre> </pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span> <pre class="cython line score-2"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span
class="n">img_std</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
__pyx_v_img_std_view = __pyx_t_6; __pyx_v_img_std_view = __pyx_t_6;
__pyx_t_6.memview = NULL; __pyx_t_6.memview = NULL;
__pyx_t_6.data = NULL; __pyx_t_6.data = NULL;
</pre><pre class="cython line score-0">&#xA0;<span class="">27</span>: </pre> </pre>
<pre class="cython line score-0">&#xA0;<span class="">28</span>: <span class="c"># temporary c variables</span></pre> <pre class="cython line score-0">&#xA0;<span class="">27</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span class="nf">stddev</span></pre> <pre class="cython line score-0">&#xA0;<span class="">28</span>: <span class="c"># temporary c variables</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre> <pre class="cython line score-0">&#xA0;<span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre> class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size; class="nf">stddev</span></pre>
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre> <pre class="cython line score-0">&#xA0;<span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon; class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span class="o">**</span><span class="mf">2</span></pre> class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2); <pre class="cython line score-0"
</pre><pre class="cython line score-0">&#xA0;<span class="">34</span>: </pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class="cython line score-0">&#xA0;<span class="">35</span>: <span class="c"># for all pixels do</span></pre> class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre> class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks); <pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span
class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span
class="o">**</span><span class="mf">2</span></pre>
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">34</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">35</span>: <span
class="c"># for all pixels do</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span
class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span
class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
__pyx_t_8 = __pyx_t_7; __pyx_t_8 = __pyx_t_7;
for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 &lt; __pyx_t_8; __pyx_t_9+=1) { for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 &lt; __pyx_t_8; __pyx_t_9+=1) {
__pyx_v_m = __pyx_t_9; __pyx_v_m = __pyx_t_9;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span
class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
__pyx_t_11 = __pyx_t_10; __pyx_t_11 = __pyx_t_10;
for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 &lt; __pyx_t_11; __pyx_t_12+=1) { for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 &lt; __pyx_t_11; __pyx_t_12+=1) {
__pyx_v_n = __pyx_t_12; __pyx_v_n = __pyx_t_12;
</pre><pre class="cython line score-0">&#xA0;<span class="">38</span>: </pre> </pre>
<pre class="cython line score-0">&#xA0;<span class="">39</span>: <span class="c"># calculate mean</span></pre> <pre class="cython line score-0">&#xA0;<span class="">38</span>: </pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre> <pre class="cython line score-0">&#xA0;<span class="">39</span>: <span class="c"># calculate mean</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = 0.0; <pre class="cython line score-0"
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1); class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span
class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
__pyx_t_14 = __pyx_t_13; __pyx_t_14 = __pyx_t_13;
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) { for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) {
__pyx_v_i = __pyx_t_15; __pyx_v_i = __pyx_t_15;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
__pyx_t_17 = __pyx_t_16; __pyx_t_17 = __pyx_t_16;
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) { for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) {
__pyx_v_j = __pyx_t_18; __pyx_v_j = __pyx_t_18;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span
class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
__pyx_t_20 = (__pyx_v_n + __pyx_v_j); __pyx_t_20 = (__pyx_v_n + __pyx_v_j);
__pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) )))); __pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) ))));
} }
} }
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span class="n">mean</span><span class="o">/</span><span class="n">num</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num); <pre class="cython line score-0"
</pre><pre class="cython line score-0">&#xA0;<span class="">45</span>: </pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class="cython line score-0">&#xA0;<span class="">46</span>: <span class="c"># calculate std dev</span></pre> class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre> class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0; <pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1); <pre class="cython line score-0">&#xA0;<span class="">45</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">46</span>: <span
class="c"># calculate std dev</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
__pyx_t_14 = __pyx_t_13; __pyx_t_14 = __pyx_t_13;
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) { for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) {
__pyx_v_i = __pyx_t_15; __pyx_v_i = __pyx_t_15;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
__pyx_t_17 = __pyx_t_16; __pyx_t_17 = __pyx_t_16;
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) { for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) {
__pyx_v_j = __pyx_t_18; __pyx_v_j = __pyx_t_18;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span
class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span
class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span
class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span
class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
__pyx_t_22 = (__pyx_v_n + __pyx_v_j); __pyx_t_22 = (__pyx_v_n + __pyx_v_j);
__pyx_t_23 = (__pyx_v_m + __pyx_v_i); __pyx_t_23 = (__pyx_v_m + __pyx_v_i);
__pyx_t_24 = (__pyx_v_n + __pyx_v_j); __pyx_t_24 = (__pyx_v_n + __pyx_v_j);
__pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean))); __pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean)));
} }
} }
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span class="n">num</span><span class="p">)</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num)); <pre class="cython line score-0"
</pre><pre class="cython line score-0">&#xA0;<span class="">52</span>: </pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class="cython line score-0">&#xA0;<span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre> class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span class="o">+</span><span class="n">eps</span><span class="p">)</span></pre> class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m; class="n">num</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
</pre>
<pre class="cython line score-0">&#xA0;<span class="">52</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span
class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span
class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
__pyx_t_26 = __pyx_v_n; __pyx_t_26 = __pyx_v_n;
__pyx_t_27 = __pyx_v_m; __pyx_t_27 = __pyx_v_m;
__pyx_t_28 = __pyx_v_n; __pyx_t_28 = __pyx_v_n;
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps)); *((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps));
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">stddev</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m; <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
class="n">stddev</span></pre>
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
__pyx_t_30 = __pyx_v_n; __pyx_t_30 = __pyx_v_n;
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev; *((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev;
} }
} }
</pre><pre class="cython line score-0">&#xA0;<span class="">56</span>: </pre> </pre>
<pre class="cython line score-0">&#xA0;<span class="">57</span>: <span class="c"># return both</span></pre> <pre class="cython line score-0">&#xA0;<span class="">56</span>: </pre>
<pre class="cython line score-10" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span class="n">img_std</span></pre> <pre class="cython line score-0">&#xA0;<span class="">57</span>: <span class="c"># return both</span></pre>
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r); <pre class="cython line score-10"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span
class="n">img_std</span></pre>
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
__pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span> __pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_lcn); <span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_lcn);
@ -717,4 +903,7 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
__pyx_r = __pyx_t_1; __pyx_r = __pyx_t_1;
__pyx_t_1 = 0; __pyx_t_1 = 0;
goto __pyx_L0; goto __pyx_L0;
</pre></div></body></html> </pre>
</div>
</body>
</html>
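For reference, a minimal NumPy/SciPy sketch of the same local contrast normalization, useful as a cross-check of the Cython kernel above. This is an assumption-laden sketch and not part of the repository: it uses scipy.ndimage.uniform_filter, and it normalizes every pixel, whereas the Cython loop leaves a kernel_size-wide border untouched.

import numpy as np
from scipy.ndimage import uniform_filter

def normalize_reference(img, kernel_size=4, epsilon=0.01):
    # local mean and local mean of squares over a (2*kernel_size+1)^2 window
    img = img.astype(np.float32)
    size = 2 * kernel_size + 1
    mean = uniform_filter(img, size=size)
    mean_sq = uniform_filter(img * img, size=size)
    # clamp tiny negative variances caused by floating-point rounding
    stddev = np.sqrt(np.maximum(mean_sq - mean * mean, 0.0))
    return (img - mean) / (stddev + epsilon), stddev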

@ -2,5 +2,5 @@ from distutils.core import setup
from Cython.Build import cythonize from Cython.Build import cythonize
setup( setup(
ext_modules = cythonize("lcn.pyx",annotate=True) ext_modules=cythonize("lcn.pyx", annotate=True)
) )
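With annotate=True in this setup.py, a typical in-place build (standard distutils/Cython workflow; the exact invocation may differ per environment) is python setup.py build_ext --inplace, which compiles lcn.pyx and regenerates the lcn.html annotation shown above.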

@ -5,43 +5,43 @@ from scipy import misc
# load and convert to float # load and convert to float
img = misc.imread('img.png') img = misc.imread('img.png')
img = img.astype(np.float32)/255.0 img = img.astype(np.float32) / 255.0
# normalize # normalize
img_lcn, img_std = lcn.normalize(img,5,0.05) img_lcn, img_std = lcn.normalize(img, 5, 0.05)
# normalize to reasonable range between 0 and 1 # normalize to reasonable range between 0 and 1
#img_lcn = img_lcn/3.0 # img_lcn = img_lcn/3.0
#img_lcn = np.maximum(img_lcn,0.0) # img_lcn = np.maximum(img_lcn,0.0)
#img_lcn = np.minimum(img_lcn,1.0) # img_lcn = np.minimum(img_lcn,1.0)
# save to file # save to file
#misc.imsave('lcn2.png',img_lcn) # misc.imsave('lcn2.png',img_lcn)
print ("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \ print("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \
(img.shape[0], img.shape[1], img.dtype, img.min(), img.max())) (img.shape[0], img.shape[1], img.dtype, img.min(), img.max()))
print ("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \ print("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \
(img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max())) (img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max()))
# plot original image # plot original image
plt.figure(1) plt.figure(1)
img_plot = plt.imshow(img) img_plot = plt.imshow(img)
img_plot.set_cmap('gray') img_plot.set_cmap('gray')
plt.clim(0, 1) # fix range plt.clim(0, 1) # fix range
plt.tight_layout() plt.tight_layout()
# plot normalized image # plot normalized image
plt.figure(2) plt.figure(2)
img_lcn_plot = plt.imshow(img_lcn) img_lcn_plot = plt.imshow(img_lcn)
img_lcn_plot.set_cmap('gray') img_lcn_plot.set_cmap('gray')
#plt.clim(0, 1) # fix range # plt.clim(0, 1) # fix range
plt.tight_layout() plt.tight_layout()
# plot stddev image # plot stddev image
plt.figure(3) plt.figure(3)
img_std_plot = plt.imshow(img_std) img_std_plot = plt.imshow(img_std)
img_std_plot.set_cmap('gray') img_std_plot.set_cmap('gray')
#plt.clim(0, 0.1) # fix range # plt.clim(0, 0.1) # fix range
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()
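Note that scipy.misc.imread and misc.imsave used above were removed in newer SciPy releases; a minimal substitute for the loading step, assuming the imageio package is available (it is not declared as a dependency here):

import numpy as np
import imageio

# drop-in replacement for misc.imread('img.png') in the test script above
img = imageio.imread('img.png')
img = np.asarray(img, dtype=np.float32) / 255.0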

@ -11,28 +11,27 @@ import dataset
def get_data(n, row_from, row_to, train): def get_data(n, row_from, row_to, train):
imsizes = [(256,384)] imsizes = [(256, 384)]
focal_lengths = [160] focal_lengths = [160]
dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train) dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)
ims = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.uint8) ims = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.uint8)
disps = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.float32) disps = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.float32)
for idx in range(n): for idx in range(n):
print(f'load sample {idx} train={train}') print(f'load sample {idx} train={train}')
sample = dset[idx] sample = dset[idx]
ims[idx] = (sample['im0'][0,row_from:row_to] * 255).astype(np.uint8) ims[idx] = (sample['im0'][0, row_from:row_to] * 255).astype(np.uint8)
disps[idx] = sample['disp0'][0,row_from:row_to] disps[idx] = sample['disp0'][0, row_from:row_to]
return ims, disps return ims, disps
params = hd.TrainParams( params = hd.TrainParams(
n_trees=4, n_trees=4,
max_tree_depth=, max_tree_depth=,
n_test_split_functions=50, n_test_split_functions=50,
n_test_thresholds=10, n_test_thresholds=10,
n_test_samples=4096, n_test_samples=4096,
min_samples_to_split=16, min_samples_to_split=16,
min_samples_for_leaf=8) min_samples_for_leaf=8)
n_disp_bins = 20 n_disp_bins = 20
depth_switch = 0 depth_switch = 0
@ -45,21 +44,23 @@ n_test_samples = 32
train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True) train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False) test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)
for tree_depth in [8,10,12,14,16]: for tree_depth in [8, 10, 12, 14, 16]:
depth_switch = tree_depth - 4 depth_switch = tree_depth - 4
prefix = f'td{tree_depth}_ds{depth_switch}' prefix = f'td{tree_depth}_ds{depth_switch}'
prefix = Path(f'./forests/{prefix}/') prefix = Path(f'./forests/{prefix}/')
prefix.mkdir(parents=True, exist_ok=True) prefix.mkdir(parents=True, exist_ok=True)
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr')) hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr')) es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
np.save(str(prefix / 'ta.npy'), test_disps) np.save(str(prefix / 'ta.npy'), test_disps)
np.save(str(prefix / 'es.npy'), es) np.save(str(prefix / 'es.npy'), es)
# plt.figure(); # plt.figure();
# plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4); # plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4);
# plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4); # plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4);
# plt.show() # plt.show()
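Not part of the search script itself, but a possible way to compare the saved runs afterwards, assuming the `forests/td*_ds*/{ta,es}.npy` layout written above:

```
import numpy as np
from pathlib import Path

for tree_depth in [8, 10, 12, 14, 16]:
    prefix = Path(f'./forests/td{tree_depth}_ds{tree_depth - 4}')
    ta = np.load(str(prefix / 'ta.npy'))  # ground-truth disparities
    es = np.load(str(prefix / 'es.npy'))  # forest estimates
    err = np.abs(es - ta)
    print(f'td={tree_depth}: mean abs err {err.mean():.4f}, '
          f'>1px outliers {(err > 1).mean():.4f}')
```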

@ -8,7 +8,6 @@ import os
this_dir = os.path.dirname(__file__) this_dir = os.path.dirname(__file__)
extra_compile_args = ['-O3', '-std=c++11'] extra_compile_args = ['-O3', '-std=c++11']
extra_link_args = [] extra_link_args = []
@ -22,24 +21,20 @@ library_dirs = []
libraries = ['m'] libraries = ['m']
setup( setup(
name="hyperdepth", name="hyperdepth",
cmdclass= {'build_ext': build_ext}, cmdclass={'build_ext': build_ext},
ext_modules=[ ext_modules=[
Extension('hyperdepth', Extension('hyperdepth',
sources, sources,
extra_objects=extra_objects, extra_objects=extra_objects,
language='c++', language='c++',
library_dirs=library_dirs, library_dirs=library_dirs,
libraries=libraries, libraries=libraries,
include_dirs=[ include_dirs=[
np.get_include(), np.get_include(),
], ],
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args extra_link_args=extra_link_args
) )
] ]
) )

@ -6,10 +6,13 @@ orig = cv2.imread('disp_orig.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32) ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32) es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
plt.figure() plt.figure()
plt.subplot(2,2,1); plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma') plt.subplot(2, 2, 1);
plt.subplot(2,2,2); plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma') plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,3); plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma') plt.subplot(2, 2, 2);
plt.subplot(2,2,4); plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma') plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 3);
plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 4);
plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
plt.show() plt.show()
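The division by 16 suggests the disparity PNGs are stored as 16-bit integers with four fractional bits; this is an assumption read off the code above, illustrated by a small round-trip:

```
import numpy as np

# assumed encoding: value_png = round(disparity * 16), hence the `/ 16` above
disp = np.array([0.0, 1.25, 3.9375], dtype=np.float32)
encoded = np.round(disp * 16).astype(np.uint16)
decoded = encoded.astype(np.float32) / 16
assert np.allclose(disp, decoded)
```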

@ -12,226 +12,263 @@ import torchext
from model import networks from model import networks
from data import dataset from data import dataset
class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
self.ms = args.ms
self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight
self.data_type = args.data_type
self.imsizes = [(480,640)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2)))
with open('config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2**10:]
self.test_paths = sample_paths[:2**8]
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2**8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0]-2*13
self.eval_w = self.imsizes[0][1]-13-140
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=1)
return train_set
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
# initialize photometric loss modules according to image sizes
self.losses = []
for imsize, pat in zip(test_set.imsizes, test_set.patterns):
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32))
pat = pat.to(self.train_device)
self.lcn_in = self.lcn_in.to(self.train_device)
pat,_ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
self.losses.append( networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) )
return test_sets
def copy_data(self, data, device, requires_grad, train):
self.lcn_in = self.lcn_in.to(device)
self.data = {}
for key, val in data.items():
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
# apply lcn to IR input
# concatenate the normalized IR input and the original IR image
if 'im' in key and 'blend' not in key:
im = self.data[key]
im_lcn,im_std = self.lcn_in(im)
im_cat = torch.cat((im_lcn, im), dim=1)
key_std = key.replace('im','std')
self.data[key]=im_cat
self.data[key_std] = im_std.to(device).detach()
def net_forward(self, net, train):
out = net(self.data['im0'])
return out
def loss_forward(self, out, train):
out, edge = out
if not(isinstance(out, tuple) or isinstance(out, list)):
out = [out]
if not(isinstance(edge, tuple) or isinstance(edge, list)):
edge = [edge]
vals = []
# apply photometric loss
for s,l,o in zip(itertools.count(), self.losses, out):
val, pattern_proj = l(o, self.data[f'im{s}'][:,0:1,...], self.data[f'std{s}'])
if s == 0:
self.pattern_proj = pattern_proj.detach()
vals.append(val)
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1-torch.sigmoid(edge[0])
val = self.disparity_loss(out[0], edge0)
if self.dp_weight>0:
vals.append(val * self.dp_weight)
# apply edge loss on a subset of training samples
for s,e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}']<0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids>self.train_edge
if mask.sum()>0:
val = self.edge_loss(e[mask], grad[mask])
else:
val = torch.zeros_like(vals[0])
if s == 0:
self.edge = e.detach()
self.edge = torch.sigmoid(self.edge)
self.edge_gt = grad.detach()
vals.append(val)
return vals
def numpy_in_out(self, output):
output, edge = output
if not(isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,0:1,...].detach().to('cpu').numpy()
ma = gt>0
return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma):
logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
diff = np.abs(es - gt)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2*(vmax-vmin)
vmax = vmax + 0.2*(vmax-vmin)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0]
pattern_diff = np.abs(im_orig - pattern_proj)
fig = plt.figure(figsize=(16,16))
es_ = co.cmap.color_depth_map(es, scale=vmax)
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
diff_ = co.cmap.color_error_image(diff, BGR=True)
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3,3,1); plt.imshow(es_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
ax = plt.subplot(3,3,2); plt.imshow(gt_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
ax = plt.subplot(3,3,3); plt.imshow(diff_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Err. {diff.mean():.5f}')
# plot edges
edge = self.edge.to('cpu').numpy()[0,0]
edge_gt = self.edge_gt.to('cpu').numpy()[0,0]
edge_err = np.abs(edge - edge_gt)
ax = plt.subplot(3,3,4); plt.imshow(edge, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
ax = plt.subplot(3,3,5); plt.imshow(edge_gt, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
ax = plt.subplot(3,3,6); plt.imshow(edge_err, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
# plot normalized IR input and warped pattern
ax = plt.subplot(3,3,7); plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
ax = plt.subplot(3,3,8); plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
im_std = self.data['std0'].to('cpu').numpy()[0,0]
ax = plt.subplot(3,3,9); plt.imshow(im_std, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
es, gt, im, ma = self.numpy_in_out(output)
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1,1)
gt = gt.reshape(-1,1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}')
for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
def crop_output(self, es, gt, im, ma):
bs = es.shape[0]
es = np.reshape(es[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
self.ms = args.ms
self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight
self.data_type = args.data_type
self.imsizes = [(488, 648)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2 ** 10:]
self.test_paths = sample_paths[:2 ** 8]
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2 ** 8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=1)
return train_set
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
# initialize photometric loss modules according to image sizes
self.losses = []
for imsize, pat in zip(test_set.imsizes, test_set.patterns):
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32))
pat = pat.to(self.train_device)
self.lcn_in = self.lcn_in.to(self.train_device)
pat, _ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
self.losses.append(networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat))
return test_sets
def copy_data(self, data, device, requires_grad, train):
self.lcn_in = self.lcn_in.to(device)
self.data = {}
for key, val in data.items():
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
# apply lcn to IR input
# concatenate the normalized IR input and the original IR image
if 'im' in key and 'blend' not in key:
im = self.data[key]
im_lcn, im_std = self.lcn_in(im)
im_cat = torch.cat((im_lcn, im), dim=1)
key_std = key.replace('im', 'std')
self.data[key] = im_cat
self.data[key_std] = im_std.to(device).detach()
def net_forward(self, net, train):
out = net(self.data['im0'])
return out
def loss_forward(self, out, train):
out, edge = out
if not (isinstance(out, tuple) or isinstance(out, list)):
out = [out]
if not (isinstance(edge, tuple) or isinstance(edge, list)):
edge = [edge]
vals = []
# apply photometric loss
for s, l, o in zip(itertools.count(), self.losses, out):
val, pattern_proj = l(o, self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
if s == 0:
self.pattern_proj = pattern_proj.detach()
vals.append(val)
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1 - torch.sigmoid(edge[0])
val = self.disparity_loss(out[0], edge0)
if self.dp_weight > 0:
vals.append(val * self.dp_weight)
# apply edge loss on a subset of training samples
for s, e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}'] < 0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids > self.train_edge
if mask.sum() > 0:
val = self.edge_loss(e[mask], grad[mask])
else:
val = torch.zeros_like(vals[0])
if s == 0:
self.edge = e.detach()
self.edge = torch.sigmoid(self.edge)
self.edge_gt = grad.detach()
vals.append(val)
return vals
def numpy_in_out(self, output):
output, edge = output
if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy()
ma = gt > 0
return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma):
logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
diff = np.abs(es - gt)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2 * (vmax - vmin)
vmax = vmax + 0.2 * (vmax - vmin)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0]
pattern_diff = np.abs(im_orig - pattern_proj)
fig = plt.figure(figsize=(16, 16))
es_ = co.cmap.color_depth_map(es, scale=vmax)
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
diff_ = co.cmap.color_error_image(diff, BGR=True)
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3, 3, 1)
plt.imshow(es_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
ax = plt.subplot(3, 3, 2)
plt.imshow(gt_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
ax = plt.subplot(3, 3, 3)
plt.imshow(diff_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Err. {diff.mean():.5f}')
# plot edges
edge = self.edge.to('cpu').numpy()[0, 0]
edge_gt = self.edge_gt.to('cpu').numpy()[0, 0]
edge_err = np.abs(edge - edge_gt)
ax = plt.subplot(3, 3, 4);
plt.imshow(edge, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
ax = plt.subplot(3, 3, 5);
plt.imshow(edge_gt, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
ax = plt.subplot(3, 3, 6);
plt.imshow(edge_err, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
# plot normalized IR input and warped pattern
ax = plt.subplot(3, 3, 7);
plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
ax = plt.subplot(3, 3, 8);
plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
im_std = self.data['std0'].to('cpu').numpy()[0, 0]
ax = plt.subplot(3, 3, 9);
plt.imshow(im_std, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
es, gt, im, ma = self.numpy_in_out(output)
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1, 1)
gt = gt.reshape(-1, 1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}')
for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
def crop_output(self, es, gt, im, ma):
bs = es.shape[0]
es = np.reshape(es[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
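As a toy illustration of `crop_output` above (shapes only, with a hypothetical 480x640 input and the same 13/140 margins): the rectangular boolean mask applied to the last two axes flattens the valid region, and the subsequent reshape restores it as an `eval_h x eval_w` image:

```
import numpy as np

bs, H, W = 2, 480, 640
eval_mask = np.zeros((H, W), dtype=bool)
eval_mask[13:H - 13, 140:W - 13] = True
eval_h, eval_w = H - 2 * 13, W - 13 - 140

es = np.random.rand(bs, 1, H, W).astype(np.float32)
cropped = es[:, :, eval_mask]                     # (bs, 1, eval_h * eval_w)
cropped = cropped.reshape(bs, 1, eval_h, eval_w)  # rectangular region restored
print(cropped.shape)                              # (2, 1, 454, 487)
```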
if __name__ == '__main__': if __name__ == '__main__':
pass pass

@ -12,287 +12,324 @@ import torchext
from model import networks from model import networks
from data import dataset from data import dataset
class Worker(torchext.Worker): class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs): def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs) super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
self.ms = args.ms save_frequency=save_frequency, **kwargs)
self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius self.ms = args.ms
self.dp_weight = args.dp_weight self.pattern_path = args.pattern_path
self.ge_weight = args.ge_weight self.lcn_radius = args.lcn_radius
self.track_length = args.track_length self.dp_weight = args.dp_weight
self.data_type = args.data_type self.ge_weight = args.ge_weight
assert(self.track_length>1) self.track_length = args.track_length
self.data_type = args.data_type
self.imsizes = [(480,640)] assert (self.track_length > 1)
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2))) self.imsizes = [(480, 640)]
for iter in range(3):
with open('config.json') as fp: self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
config = json.load(fp)
data_root = Path(config['DATA_ROOT']) with open('config.json') as fp:
self.settings_path = data_root / self.data_type / 'settings.pkl' config = json.load(fp)
sample_paths = sorted((data_root / self.data_type).glob('0*/')) data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl'
self.train_paths = sample_paths[2**10:] sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.test_paths = sample_paths[:2**8]
self.train_paths = sample_paths[2 ** 10:]
# supervise the edge encoder with only 2**8 samples self.test_paths = sample_paths[:2 ** 8]
self.train_edge = len(self.train_paths) - 2**8
# supervise the edge encoder with only 2**8 samples
self.lcn_in = networks.LCN(self.lcn_radius, 0.05) self.train_edge = len(self.train_paths) - 2 ** 8
self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device)) self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
# evaluate in the region where opencv Block Matching has valid values self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1 # evaluate in the region where opencv Block Matching has valid values
self.eval_mask = self.eval_mask.astype(np.bool) self.eval_mask = np.zeros(self.imsizes[0])
self.eval_h = self.imsizes[0][0]-2*13 self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_w = self.imsizes[0][1]-13-140 self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length) def get_train_set(self):
return train_set train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=self.track_length)
def get_test_sets(self): return train_set
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1) def get_test_sets(self):
test_sets.append('simple', test_set, test_frequency=1) test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
self.ph_losses = [] track_length=1)
self.ge_losses = [] test_sets.append('simple', test_set, test_frequency=1)
self.d2ds = []
self.ph_losses = []
self.lcn_in = self.lcn_in.to('cuda') self.ge_losses = []
for sidx in range(len(test_set.imsizes)): self.d2ds = []
imsize = test_set.imsizes[sidx]
pat = test_set.patterns[sidx] self.lcn_in = self.lcn_in.to('cuda')
pat = pat.mean(axis=2) for sidx in range(len(test_set.imsizes)):
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda') imsize = test_set.imsizes[sidx]
pat,_ = self.lcn_in(pat) pat = test_set.patterns[sidx]
pat = torch.cat([pat for idx in range(3)], dim=1) pat = pat.mean(axis=2)
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
pat, _ = self.lcn_in(pat)
K = test_set.getK(sidx) pat = torch.cat([pat for idx in range(3)], dim=1)
Ki = np.linalg.inv(K) ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat)
K = torch.from_numpy(K)
Ki = torch.from_numpy(Ki) K = test_set.getK(sidx)
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1) Ki = np.linalg.inv(K)
K = torch.from_numpy(K)
self.ph_losses.append( ph_loss ) Ki = torch.from_numpy(Ki)
self.ge_losses.append( ge_loss ) ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline)) self.ph_losses.append(ph_loss)
self.d2ds.append( d2d ) self.ge_losses.append(ge_loss)
return test_sets d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
self.d2ds.append(d2d)
def copy_data(self, data, device, requires_grad, train):
self.data = {} return test_sets
self.lcn_in = self.lcn_in.to(device) def copy_data(self, data, device, requires_grad, train):
for key, val in data.items(): self.data = {}
# from
# batch_size x track_length x ... self.lcn_in = self.lcn_in.to(device)
# to for key, val in data.items():
# track_length x batch_size x ... # from
if len(val.shape)>2: # batch_size x track_length x ...
if train: # to
val = val.transpose(0,1) # track_length x batch_size x ...
if len(val.shape) > 2:
if train:
val = val.transpose(0, 1)
else:
val = val.unsqueeze(0)
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
if 'im' in key and 'blend' not in key:
im = self.data[key]
tl = im.shape[0]
bs = im.shape[1]
im_lcn, im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
key_std = key.replace('im', 'std')
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
self.data[key] = im_cat
def net_forward(self, net, train):
im0 = self.data['im0']
tl = im0.shape[0]
bs = im0.shape[1]
im0 = im0.view(-1, *im0.shape[2:])
out, edge = net(im0)
if not (isinstance(out, tuple) or isinstance(out, list)):
out = out.view(tl, bs, *out.shape[1:])
edge = edge.view(tl, bs, *out.shape[1:])
else: else:
val = val.unsqueeze(0) out = [o.view(tl, bs, *o.shape[1:]) for o in out]
grad = 'im' in key and requires_grad edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
self.data[key] = val.to(device).requires_grad_(requires_grad=grad) return out, edge
if 'im' in key and 'blend' not in key:
im = self.data[key] def loss_forward(self, out, train):
tl = im.shape[0] out, edge = out
bs = im.shape[1] if not (isinstance(out, tuple) or isinstance(out, list)):
im_lcn,im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:])) out = [out]
key_std = key.replace('im','std') vals = []
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device) diffs = []
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
self.data[key] = im_cat # apply photometric loss
for s, l, o in zip(itertools.count(), self.ph_losses, out):
def net_forward(self, net, train): im = self.data[f'im{s}']
im0 = self.data['im0'] im = im.view(-1, *im.shape[2:])
tl = im0.shape[0] o = o.view(-1, *o.shape[2:])
bs = im0.shape[1] std = self.data[f'std{s}']
im0 = im0.view(-1, *im0.shape[2:]) std = std.view(-1, *std.shape[2:])
out, edge = net(im0) val, pattern_proj = l(o, im[:, 0:1, ...], std)
if not(isinstance(out, tuple) or isinstance(out, list)): vals.append(val)
out = out.view(tl, bs, *out.shape[1:]) if s == 0:
edge = edge.view(tl, bs, *out.shape[1:]) self.pattern_proj = pattern_proj.detach()
else:
out = [o.view(tl, bs, *o.shape[1:]) for o in out] # apply disparity loss
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge] # 1-edge as ground truth edge if inversed
return out, edge edge0 = 1 - torch.sigmoid(edge[0])
edge0 = edge0.view(-1, *edge0.shape[2:])
def loss_forward(self, out, train): out0 = out[0].view(-1, *out[0].shape[2:])
out, edge = out val = self.disparity_loss(out0, edge0)
if not(isinstance(out, tuple) or isinstance(out, list)): if self.dp_weight > 0:
out = [out] vals.append(val * self.dp_weight)
vals = []
diffs = [] # apply edge loss on a subset of training samples
for s, e in zip(itertools.count(), edge):
# apply photometric loss # inversed ground truth edge where 0 means edge
for s,l,o in zip(itertools.count(), self.ph_losses, out): grad = self.data[f'grad{s}'] < 0.2
im = self.data[f'im{s}'] grad = grad.to(torch.float32)
im = im.view(-1, *im.shape[2:]) ids = self.data['id']
o = o.view(-1, *o.shape[2:]) mask = ids > self.train_edge
std = self.data[f'std{s}'] if mask.sum() > 0:
std = std.view(-1, *std.shape[2:]) e = e[:, mask, :]
val, pattern_proj = l(o, im[:,0:1,...], std) grad = grad[:, mask, :]
vals.append(val) e = e.view(-1, *e.shape[2:])
if s == 0: grad = grad.view(-1, *grad.shape[2:])
self.pattern_proj = pattern_proj.detach() val = self.edge_loss(e, grad)
else:
# apply disparity loss val = torch.zeros_like(vals[0])
# 1-edge as ground truth edge if inversed vals.append(val)
edge0 = 1-torch.sigmoid(edge[0])
edge0 = edge0.view(-1, *edge0.shape[2:]) if train is False:
out0 = out[0].view(-1, *out[0].shape[2:]) return vals
val = self.disparity_loss(out0, edge0)
if self.dp_weight>0: # apply geometric loss
vals.append(val * self.dp_weight) R = self.data['R']
t = self.data['t']
# apply edge loss on a subset of training samples ge_num = self.track_length * (self.track_length - 1) / 2
for s,e in zip(itertools.count(), edge): for sidx in range(len(out)):
# inversed ground truth edge where 0 means edge d2d = self.d2ds[sidx]
grad = self.data[f'grad{s}']<0.2 depth = d2d(out[sidx])
grad = grad.to(torch.float32) ge_loss = self.ge_losses[sidx]
ids = self.data['id'] imsize = self.imsizes[sidx]
mask = ids>self.train_edge for tidx0 in range(depth.shape[0]):
if mask.sum()>0: for tidx1 in range(tidx0 + 1, depth.shape[0]):
e = e[:,mask,:] depth0 = depth[tidx0]
grad = grad[:,mask,:] R0 = R[tidx0]
e = e.view(-1, *e.shape[2:]) t0 = t[tidx0]
grad = grad.view(-1, *grad.shape[2:]) depth1 = depth[tidx1]
val = self.edge_loss(e, grad) R1 = R[tidx1]
else: t1 = t[tidx1]
val = torch.zeros_like(vals[0])
vals.append(val) val = ge_loss(depth0, depth1, R0, t0, R1, t1)
vals.append(val * self.ge_weight / ge_num)
if train is False:
return vals return vals
# apply geometric loss def numpy_in_out(self, output):
R = self.data['R'] output, edge = output
t = self.data['t'] if not (isinstance(output, tuple) or isinstance(output, list)):
ge_num = self.track_length * (self.track_length-1) / 2 output = [output]
for sidx in range(len(out)): es = output[0].detach().to('cpu').numpy()
d2d = self.d2ds[sidx] gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
depth = d2d(out[sidx]) im = self.data['im0'][:, :, 0:1, ...].detach().to('cpu').numpy()
ge_loss = self.ge_losses[sidx] ma = gt > 0
imsize = self.imsizes[sidx] return es, gt, im, ma
for tidx0 in range(depth.shape[0]):
for tidx1 in range(tidx0+1, depth.shape[0]): def write_img(self, out_path, es, gt, im, ma):
depth0 = depth[tidx0] logging.info(f'write img {out_path}')
R0 = R[tidx0] u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
t0 = t[tidx0]
depth1 = depth[tidx1] diff = np.abs(es - gt)
R1 = R[tidx1]
t1 = t[tidx1] vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2 * (vmax - vmin)
val = ge_loss(depth0, depth1, R0, t0, R1, t1) vmax = vmax + 0.2 * (vmax - vmin)
vals.append(val * self.ge_weight / ge_num)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
return vals im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0, 0]
pattern_diff = np.abs(im_orig - pattern_proj)
def numpy_in_out(self, output):
output, edge = output fig = plt.figure(figsize=(16, 16))
if not(isinstance(output, tuple) or isinstance(output, list)): es0 = co.cmap.color_depth_map(es[0], scale=vmax)
output = [output] gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
es = output[0].detach().to('cpu').numpy() diff0 = co.cmap.color_error_image(diff[0], BGR=True)
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy() # plot disparities, ground truth disparity is shown only for reference
ma = gt>0 ax = plt.subplot(3, 3, 1);
return es, gt, im, ma plt.imshow(es0[..., [2, 1, 0]]);
plt.xticks([]);
def write_img(self, out_path, es, gt, im, ma): plt.yticks([]);
logging.info(f'write img {out_path}') ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0])) ax = plt.subplot(3, 3, 2);
plt.imshow(gt0[..., [2, 1, 0]]);
diff = np.abs(es - gt) plt.xticks([]);
plt.yticks([]);
vmin, vmax = np.nanmin(gt), np.nanmax(gt) ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
vmin = vmin - 0.2*(vmax-vmin) ax = plt.subplot(3, 3, 3);
vmax = vmax + 0.2*(vmax-vmin) plt.imshow(diff0[..., [2, 1, 0]]);
plt.xticks([]);
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0] plt.yticks([]);
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0,0] ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
pattern_diff = np.abs(im_orig - pattern_proj)
# plot disparities of the second frame in the track if exists
fig = plt.figure(figsize=(16,16)) if es.shape[0] >= 2:
es0 = co.cmap.color_depth_map(es[0], scale=vmax) es1 = co.cmap.color_depth_map(es[1], scale=vmax)
gt0 = co.cmap.color_depth_map(gt[0], scale=vmax) gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
diff0 = co.cmap.color_error_image(diff[0], BGR=True) diff1 = co.cmap.color_error_image(diff[1], BGR=True)
ax = plt.subplot(3, 3, 4);
# plot disparities, ground truth disparity is shown only for reference plt.imshow(es1[..., [2, 1, 0]]);
ax = plt.subplot(3,3,1); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}') plt.xticks([]);
ax = plt.subplot(3,3,2); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}') plt.yticks([]);
ax = plt.subplot(3,3,3); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}') ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
ax = plt.subplot(3, 3, 5);
# plot disparities of the second frame in the track if exists plt.imshow(gt1[..., [2, 1, 0]]);
if es.shape[0]>=2: plt.xticks([]);
es1 = co.cmap.color_depth_map(es[1], scale=vmax) plt.yticks([]);
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax) ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
diff1 = co.cmap.color_error_image(diff[1], BGR=True) ax = plt.subplot(3, 3, 6);
ax = plt.subplot(3,3,4); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}') plt.imshow(diff1[..., [2, 1, 0]]);
ax = plt.subplot(3,3,5); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}') plt.xticks([]);
ax = plt.subplot(3,3,6); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}') plt.yticks([]);
ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
# plot normalized IR inputs
ax = plt.subplot(3,3,7); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}') # plot normalized IR inputs
if es.shape[0]>=2: ax = plt.subplot(3, 3, 7);
ax = plt.subplot(3,3,8); plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}') plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.tight_layout() plt.yticks([]);
plt.savefig(str(out_path)) ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
plt.close(fig) if es.shape[0] >= 2:
ax = plt.subplot(3, 3, 8);
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray');
if batch_idx % 512 == 0: plt.xticks([]);
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png' plt.yticks([]);
es, gt, im, ma = self.numpy_in_out(output) ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
masks = [ m.detach().to('cpu').numpy() for m in masks ]
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0]) plt.tight_layout()
plt.savefig(str(out_path))
def callback_test_start(self, epoch, set_idx): plt.close(fig)
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1), def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5]) if batch_idx % 512 == 0:
) out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks): masks = [m.detach().to('cpu').numpy() for m in masks]
es, gt, im, ma = self.numpy_in_out(output) self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
if batch_idx % 8 == 0: def callback_test_start(self, epoch, set_idx):
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png' self.metric = co.metric.MultipleMetric(
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0]) co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
es, gt, im, ma = self.crop_output(es, gt, im, ma) )
es = es.reshape(-1,1) def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
gt = gt.reshape(-1,1) es, gt, im, ma = self.numpy_in_out(output)
ma = ma.ravel()
self.metric.add(es, gt, ma) if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
def callback_test_stop(self, epoch, set_idx, loss): self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
logging.info(f'{self.metric}')
for k, v in self.metric.items(): es, gt, im, ma = self.crop_output(es, gt, im, ma)
self.metric_add_test(epoch, set_idx, k, v)
es = es.reshape(-1, 1)
def crop_output(self, es, gt, im, ma): gt = gt.reshape(-1, 1)
tl = es.shape[0] ma = ma.ravel()
bs = es.shape[1] self.metric.add(es, gt, ma)
es = np.reshape(es[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w]) def callback_test_stop(self, epoch, set_idx, loss):
im = np.reshape(im[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w]) logging.info(f'{self.metric}')
ma = np.reshape(ma[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w]) for k, v in self.metric.items():
return es, gt, im, ma self.metric_add_test(epoch, set_idx, k, v)
def crop_output(self, es, gt, im, ma):
tl = es.shape[0]
bs = es.shape[1]
es = np.reshape(es[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
if __name__ == '__main__': if __name__ == '__main__':
pass pass
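The geometric term in `loss_forward` above visits every unordered pair of frames in a track, which is why each pair's loss is scaled by `ge_weight / ge_num` with `ge_num = track_length * (track_length - 1) / 2`. A quick sanity check of that count:

```
import itertools

T = 4  # example track length
pairs = list(itertools.combinations(range(T), 2))
assert len(pairs) == T * (T - 1) // 2  # 6 pairs for T = 4
print(pairs)  # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
```

Dividing by `ge_num` keeps the summed geometric contribution comparable across different track lengths.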

File diff suppressed because it is too large

@ -6,7 +6,9 @@ This repository contains the code for the paper
**[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)** **[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)**
<br> <br>
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/), [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/), and [Andreas Geiger](http://www.cvlibs.net/) [Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/)
, [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/),
and [Andreas Geiger](http://www.cvlibs.net/)
<br> <br>
[CVPR 2019](http://cvpr2019.thecvf.com/) [CVPR 2019](http://cvpr2019.thecvf.com/)
@ -24,40 +26,45 @@ If you find this code useful for your research, please cite
} }
``` ```
## Dependencies ## Dependencies
The network training/evaluation code is based on `Pytorch`. The network training/evaluation code is based on `Pytorch`.
``` ```
PyTorch>=1.1 PyTorch>=1.1
Cuda>=10.0 Cuda>=10.0
``` ```
Updated on 07.06.2021: The code is now compatible with the latest Pytorch version (1.8). Updated on 07.06.2021: The code is now compatible with the latest Pytorch version (1.8).
The other python packages can be installed with `anaconda`: The other python packages can be installed with `anaconda`:
``` ```
conda install --file requirements.txt conda install --file requirements.txt
``` ```
### Structured Light Renderer ### Structured Light Renderer
To train and evaluate our method in a controlled setting, we implemented a structured light renderer. To train and evaluate our method in a controlled setting, we implemented a structured light renderer. It can be used to
It can be used to render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable projector location. render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable
To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`. projector location. To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`. Afterwards,
Afterwards, the renderer can be built by running `make` within the `renderer` directory. the renderer can be built by running `make` within the `renderer` directory.
### PyTorch Extensions ### PyTorch Extensions
The network training/evaluation code is based on `PyTorch`.
We implemented some custom layers that need to be built in the `torchext` directory. The network training/evaluation code is based on `PyTorch`. We implemented some custom layers that need to be built in
Simply change into this directory and run the `torchext` directory. Simply change into this directory and run
``` ```
python setup.py build_ext --inplace python setup.py build_ext --inplace
``` ```
### Baseline HyperDepth ### Baseline HyperDepth
As a baseline, we partially re-implemented the random forest based method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf).
The code resides in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`. As a baseline, we partially re-implemented the random forest based
To build it, change into the directory and run method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf)
. The code resides in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`
. To build it, change into the directory and run
``` ```
python setup.py build_ext --inplace python setup.py build_ext --inplace
@ -65,42 +72,59 @@ python setup.py build_ext --inplace
## Running ## Running
### Creating Synthetic Data ### Creating Synthetic Data
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by running
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and
correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by
running
``` ```
./create_syn_data.sh ./create_syn_data.sh
``` ```
If you are only interested in evaluating our pre-trained model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a validation set that contains a small number of images.
If you are only interested in evaluating our pre-trained
model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a
validation set that contains a small number of images.
### Training Network ### Training Network
As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train the network on synthetic data for the first stage run As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train
the network on synthetic data for the first stage run
``` ```
python train_val.py python train_val.py
``` ```
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by running After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by
running
``` ```
python train_val.py --loss phge python train_val.py --loss phge
``` ```
### Evaluating Network ### Evaluating Network
To evaluate a specific checkpoint, e.g. the 50th epoch, one can run To evaluate a specific checkpoint, e.g. the 50th epoch, one can run
``` ```
python train_val.py --cmd retest --epoch 50 python train_val.py --cmd retest --epoch 50
``` ```
### Evaluating a Pre-trained Model ### Evaluating a Pre-trained Model
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and
changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
``` ```
mkdir -p output mkdir -p output
mkdir -p output/exp_syn mkdir -p output/exp_syn
wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params
python train_val.py --cmd retest --epoch 99 python train_val.py --cmd retest --epoch 99
``` ```
You can also download our validation set from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
You can also download our validation set
from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
## Acknowledgement ## Acknowledgement
This work was supported by the Intel Network on Intelligent Systems. This work was supported by the Intel Network on Intelligent Systems.

@ -10,7 +10,7 @@ import json
this_dir = os.path.dirname(__file__) this_dir = os.path.dirname(__file__)
with open('../config.json') as fp: with open('../config.json') as fp:
config = json.load(fp) config = json.load(fp)
extra_compile_args = ['-O3', '-std=c++11'] extra_compile_args = ['-O3', '-std=c++11']
@ -20,7 +20,7 @@ cuda_lib = 'cudart'
sources = ['cyrender.pyx'] sources = ['cyrender.pyx']
extra_objects = [ extra_objects = [
os.path.join(this_dir, 'render/render_cpu.cpp.o'), os.path.join(this_dir, 'render/render_cpu.cpp.o'),
] ]
library_dirs = [] library_dirs = []
libraries = ['m'] libraries = ['m']
@ -30,20 +30,20 @@ library_dirs.append(cuda_lib_dir)
libraries.append(cuda_lib) libraries.append(cuda_lib)
setup( setup(
name="cyrender", name="cyrender",
cmdclass= {'build_ext': build_ext}, cmdclass={'build_ext': build_ext},
ext_modules=[ ext_modules=[
Extension('cyrender', Extension('cyrender',
sources, sources,
extra_objects=extra_objects, extra_objects=extra_objects,
language='c++', language='c++',
library_dirs=library_dirs, library_dirs=library_dirs,
libraries=libraries, libraries=libraries,
include_dirs=[ include_dirs=[
np.get_include(), np.get_include(),
], ],
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
# extra_link_args=extra_link_args # extra_link_args=extra_link_args
) )
] ]
) )

@ -2,65 +2,65 @@ import torch
import torch.utils.data import torch.utils.data
import numpy as np import numpy as np
class TestSet(object): class TestSet(object):
def __init__(self, name, dset, test_frequency=1): def __init__(self, name, dset, test_frequency=1):
self.name = name self.name = name
self.dset = dset self.dset = dset
self.test_frequency = test_frequency self.test_frequency = test_frequency
class TestSets(list):
def append(self, name, dset, test_frequency=1):
super().append(TestSet(name, dset, test_frequency))
class TestSets(list):
def append(self, name, dset, test_frequency=1):
super().append(TestSet(name, dset, test_frequency))
class MultiDataset(torch.utils.data.Dataset): class MultiDataset(torch.utils.data.Dataset):
def __init__(self, *datasets): def __init__(self, *datasets):
self.current_epoch = 0 self.current_epoch = 0
self.datasets = []
self.cum_n_samples = [0]
for dataset in datasets: self.datasets = []
self.append(dataset) self.cum_n_samples = [0]
def append(self, dataset): for dataset in datasets:
self.datasets.append(dataset) self.append(dataset)
self.__update_cum_n_samples(dataset)
def __update_cum_n_samples(self, dataset): def append(self, dataset):
n_samples = self.cum_n_samples[-1] + len(dataset) self.datasets.append(dataset)
self.cum_n_samples.append(n_samples) self.__update_cum_n_samples(dataset)
def dataset_updated(self): def __update_cum_n_samples(self, dataset):
self.cum_n_samples = [0] n_samples = self.cum_n_samples[-1] + len(dataset)
for dset in self.datasets: self.cum_n_samples.append(n_samples)
self.__update_cum_n_samples(dset)
def __len__(self): def dataset_updated(self):
return self.cum_n_samples[-1] self.cum_n_samples = [0]
for dset in self.datasets:
self.__update_cum_n_samples(dset)
def __getitem__(self, idx): def __len__(self):
didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1 return self.cum_n_samples[-1]
sidx = idx - self.cum_n_samples[didx]
return self.datasets[didx][sidx]
def __getitem__(self, idx):
didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
sidx = idx - self.cum_n_samples[didx]
return self.datasets[didx][sidx]
class BaseDataset(torch.utils.data.Dataset): class BaseDataset(torch.utils.data.Dataset):
def __init__(self, train=True, fix_seed_per_epoch=False): def __init__(self, train=True, fix_seed_per_epoch=False):
self.current_epoch = 0 self.current_epoch = 0
self.train = train self.train = train
self.fix_seed_per_epoch = fix_seed_per_epoch self.fix_seed_per_epoch = fix_seed_per_epoch
def get_rng(self, idx): def get_rng(self, idx):
rng = np.random.RandomState() rng = np.random.RandomState()
if self.train: if self.train:
if self.fix_seed_per_epoch: if self.fix_seed_per_epoch:
seed = 1 * len(self) + idx seed = 1 * len(self) + idx
else: else:
seed = (self.current_epoch + 1) * len(self) + idx seed = (self.current_epoch + 1) * len(self) + idx
rng.seed(seed) rng.seed(seed)
else: else:
rng.seed(idx) rng.seed(idx)
return rng return rng
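A toy illustration of `MultiDataset.__getitem__` above: `cum_n_samples` holds cumulative dataset sizes, `searchsorted` picks the dataset, and the remaining offset is the index within it (the sizes here are made up):

```
import numpy as np

cum = np.array([0, 5, 8, 12])  # three datasets with 5, 3 and 4 samples
for idx in [0, 4, 5, 11]:
    didx = np.searchsorted(cum, idx, side='right') - 1  # which dataset
    sidx = idx - cum[didx]                               # index within it
    print(idx, '->', didx, sidx)
# 0 -> 0 0, 4 -> 0 4, 5 -> 1 0, 11 -> 2 3
```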

@ -2,146 +2,151 @@ import torch
from . import ext_cpu from . import ext_cpu
from . import ext_cuda from . import ext_cuda
class NNFunction(torch.autograd.Function): class NNFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, in0, in1): def forward(ctx, in0, in1):
args = (in0, in1) args = (in0, in1)
if in0.is_cuda: if in0.is_cuda:
out = ext_cuda.nn_cuda(*args) out = ext_cuda.nn_cuda(*args)
else: else:
out = ext_cpu.nn_cpu(*args) out = ext_cpu.nn_cpu(*args)
return out return out
@staticmethod
def backward(ctx, grad_out):
return None, None
@staticmethod
def backward(ctx, grad_out):
return None, None
def nn(in0, in1): def nn(in0, in1):
return NNFunction.apply(in0, in1) return NNFunction.apply(in0, in1)
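For intuition only, a dense pure-PyTorch nearest-neighbour lookup in the spirit of the `nn` op above; the compiled `ext_cpu`/`ext_cuda` kernels are the authoritative implementation, and their exact input shapes and semantics may differ:

```
import torch

def nn_bruteforce(in0, in1):
    # in0: (N0, D), in1: (N1, D); for each row of in0, return the index of
    # its nearest row in in1 under squared L2 distance
    d = ((in0[:, None, :] - in1[None, :, :]) ** 2).sum(dim=2)  # (N0, N1)
    return torch.argmin(d, dim=1)

idx = nn_bruteforce(torch.rand(5, 3), torch.rand(8, 3))  # 5 indices into in1
```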
class CrossCheckFunction(torch.autograd.Function): class CrossCheckFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, in0, in1): def forward(ctx, in0, in1):
args = (in0, in1) args = (in0, in1)
if in0.is_cuda: if in0.is_cuda:
out = ext_cuda.crosscheck_cuda(*args) out = ext_cuda.crosscheck_cuda(*args)
else: else:
out = ext_cpu.crosscheck_cpu(*args) out = ext_cpu.crosscheck_cpu(*args)
return out return out
@staticmethod
def backward(ctx, grad_out):
return None, None
@staticmethod
def backward(ctx, grad_out):
return None, None
def crosscheck(in0, in1): def crosscheck(in0, in1):
return CrossCheckFunction.apply(in0, in1) return CrossCheckFunction.apply(in0, in1)
class ProjNNFunction(torch.autograd.Function): class ProjNNFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, xyz0, xyz1, K, patch_size): def forward(ctx, xyz0, xyz1, K, patch_size):
args = (xyz0, xyz1, K, patch_size) args = (xyz0, xyz1, K, patch_size)
if xyz0.is_cuda: if xyz0.is_cuda:
out = ext_cuda.proj_nn_cuda(*args) out = ext_cuda.proj_nn_cuda(*args)
else: else:
out = ext_cpu.proj_nn_cpu(*args) out = ext_cpu.proj_nn_cpu(*args)
return out return out
@staticmethod @staticmethod
def backward(ctx, grad_out): def backward(ctx, grad_out):
return None, None, None, None return None, None, None, None
def proj_nn(xyz0, xyz1, K, patch_size):
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
def proj_nn(xyz0, xyz1, K, patch_size):
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
class XCorrVolFunction(torch.autograd.Function): class XCorrVolFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, in0, in1, n_disps, block_size): def forward(ctx, in0, in1, n_disps, block_size):
args = (in0, in1, n_disps, block_size) args = (in0, in1, n_disps, block_size)
if in0.is_cuda: if in0.is_cuda:
out = ext_cuda.xcorrvol_cuda(*args) out = ext_cuda.xcorrvol_cuda(*args)
else: else:
out = ext_cpu.xcorrvol_cpu(*args) out = ext_cpu.xcorrvol_cpu(*args)
return out return out
@staticmethod
def backward(ctx, grad_out):
return None, None, None, None
@staticmethod
def backward(ctx, grad_out):
return None, None, None, None
def xcorrvol(in0, in1, n_disps, block_size): def xcorrvol(in0, in1, n_disps, block_size):
return XCorrVolFunction.apply(in0, in1, n_disps, block_size) return XCorrVolFunction.apply(in0, in1, n_disps, block_size)


class PhotometricLossFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, es, ta, block_size, type, eps):
        args = (es, ta, block_size, type, eps)
        ctx.save_for_backward(es, ta)
        ctx.block_size = block_size
        ctx.type = type
        ctx.eps = eps
        if es.is_cuda:
            out = ext_cuda.photometric_loss_forward(*args)
        else:
            out = ext_cpu.photometric_loss_forward(*args)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        es, ta = ctx.saved_tensors
        block_size = ctx.block_size
        type = ctx.type
        eps = ctx.eps
        args = (es, ta, grad_out.contiguous(), block_size, type, eps)
        if grad_out.is_cuda:
            grad_es = ext_cuda.photometric_loss_backward(*args)
        else:
            grad_es = ext_cpu.photometric_loss_backward(*args)
        return grad_es, None, None, None, None


def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
    type = type.lower()
    if type == 'mse':
        type = 0
    elif type == 'sad':
        type = 1
    elif type == 'census_mse':
        type = 2
    elif type == 'census_sad':
        type = 3
    else:
        raise Exception('invalid loss type')
    return PhotometricLossFunction.apply(es, ta, block_size, type, eps)


def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
    type = type.lower()
    p = block_size // 2
    es_pad = torch.nn.functional.pad(es, (p, p, p, p), mode='replicate')
    ta_pad = torch.nn.functional.pad(ta, (p, p, p, p), mode='replicate')
    es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
    ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
    es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
    ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
    if type == 'mse':
        ref = (es_uf - ta_uf) ** 2
    elif type == 'sad':
        ref = torch.abs(es_uf - ta_uf)
    elif type == 'census_mse' or type == 'census_sad':
        des = es_uf - es.unsqueeze(2)
        dta = ta_uf - ta.unsqueeze(2)
        h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
        h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
        diff = h_des - h_dta
        if type == 'census_mse':
            ref = diff * diff
        elif type == 'census_sad':
            ref = torch.abs(diff)
    else:
        raise Exception('invalid loss type')
    ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
    ref = torch.sum(ref, dim=1, keepdim=True) / block_size ** 2
    return ref
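
A minimal usage sketch for the photometric losses, assuming 4D NCHW float tensors; the tensor shapes and the block_size value below are illustrative, not values used in this repository. photometric_loss dispatches to the compiled extension, while photometric_loss_pytorch is the pure-PyTorch reference defined above:

import torch

es = torch.rand(2, 1, 64, 64)   # hypothetical estimated image
ta = torch.rand(2, 1, 64, 64)   # hypothetical target image

# pure-PyTorch reference; returns a per-pixel cost map of shape (2, 1, 64, 64)
cost = photometric_loss_pytorch(es, ta, block_size=9, type='census_sad', eps=0.1)

# the extension-backed version takes the same arguments (requires ext_cpu/ext_cuda to be built)
# cost = photometric_loss(es, ta, block_size=9, type='census_sad', eps=0.1)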

@ -4,24 +4,26 @@ import numpy as np
from .functions import *


class CoordConv2d(torch.nn.Module):
    def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels_in + 2, channels_out, kernel_size=kernel_size, padding=padding,
                                    stride=stride)
        self.uv = None

    def forward(self, x):
        if self.uv is None:
            height, width = x.shape[2], x.shape[3]
            u, v = np.meshgrid(range(width), range(height))
            u = 2 * u / (width - 1) - 1
            v = 2 * v / (height - 1) - 1
            uv = np.stack((u, v)).reshape(1, 2, height, width)
            self.uv = torch.from_numpy(uv.astype(np.float32))
        self.uv = self.uv.to(x.device)
        uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
        xuv = torch.cat((x, uv), dim=1)
        y = self.conv(xuv)
        return y
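
A short usage sketch for CoordConv2d, which appends normalized u/v coordinate channels before a regular convolution, so channels_in refers to the data channels only. The shapes below are assumptions for illustration:

import torch

coord_conv = CoordConv2d(channels_in=16, channels_out=32, kernel_size=3, stride=1, padding=1)
x = torch.rand(2, 16, 48, 64)   # hypothetical NCHW feature map
y = coord_conv(x)               # -> shape (2, 32, 48, 64)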

@ -11,11 +11,12 @@ nvcc_args = [
]

setup(
    name='ext',
    ext_modules=[
        CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
        CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'],
                      extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
    ],
    cmdclass={'build_ext': BuildExtension},
    include_dirs=include_dirs
)
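
The ext_cpu and ext_cuda modules used throughout torchext come from this setup script; a sketch of building and importing them, assuming the standard setuptools invocation rather than a command documented in this repository:

# from the torchext/ directory (sketch):
#   python setup.py build_ext --inplace
import ext_cpu        # CPU implementations of the custom ops
# import ext_cuda     # CUDA implementations, only available if built with nvcc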

@ -17,512 +17,516 @@ from collections import OrderedDict
class StopWatch(object):
    def __init__(self):
        self.timings = OrderedDict()
        self.starts = {}

    def start(self, name):
        self.starts[name] = time.time()

    def stop(self, name):
        if name not in self.timings:
            self.timings[name] = []
        self.timings[name].append(time.time() - self.starts[name])

    def get(self, name=None, reduce=np.sum):
        if name is not None:
            return reduce(self.timings[name])
        else:
            ret = {}
            for k in self.timings:
                ret[k] = reduce(self.timings[k])
            return ret

    def __repr__(self):
        return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])

    def __str__(self):
        return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])


class ETA(object):
    def __init__(self, length):
        self.length = length
        self.start_time = time.time()
        self.current_idx = 0
        self.current_time = time.time()

    def update(self, idx):
        self.current_idx = idx
        self.current_time = time.time()

    def get_elapsed_time(self):
        return self.current_time - self.start_time

    def get_item_time(self):
        return self.get_elapsed_time() / (self.current_idx + 1)

    def get_remaining_time(self):
        return self.get_item_time() * (self.length - self.current_idx + 1)

    def format_time(self, seconds):
        minutes, seconds = divmod(seconds, 60)
        hours, minutes = divmod(minutes, 60)
        hours = int(hours)
        minutes = int(minutes)
        return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'

    def get_elapsed_time_str(self):
        return self.format_time(self.get_elapsed_time())

    def get_remaining_time_str(self):
        return self.format_time(self.get_remaining_time())


class Worker(object):
    def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
                 num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
        self.out_root = Path(out_root)
        self.experiment_name = experiment_name
        self.epochs = epochs
        self.seed = seed
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.save_frequency = save_frequency
        self.train_device = train_device
        self.test_device = test_device
        self.max_train_iter = max_train_iter

        self.errs_list = []

        self.setup_experiment()

    def setup_experiment(self):
        self.exp_out_root = self.out_root / self.experiment_name
        self.exp_out_root.mkdir(parents=True, exist_ok=True)

        if logging.root: del logging.root.handlers[:]
        logging.basicConfig(
            level=logging.INFO,
            handlers=[
                logging.FileHandler(str(self.exp_out_root / 'train.log')),
                logging.StreamHandler()
            ],
            format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
        )

        logging.info('=' * 80)
        logging.info(f'Start of experiment: {self.experiment_name}')
        logging.info(socket.gethostname())
        self.log_datetime()
        logging.info('=' * 80)

        self.metric_path = self.exp_out_root / 'metrics.json'
        if self.metric_path.exists():
            with open(str(self.metric_path), 'r') as fp:
                self.metric_data = json.load(fp)
        else:
            self.metric_data = {}

        self.init_seed()

    def metric_add_train(self, epoch, key, val):
        epoch = str(epoch)
        key = str(key)
        if epoch not in self.metric_data:
            self.metric_data[epoch] = {}
        if 'train' not in self.metric_data[epoch]:
            self.metric_data[epoch]['train'] = {}
        self.metric_data[epoch]['train'][key] = val

    def metric_add_test(self, epoch, set_idx, key, val):
        epoch = str(epoch)
        set_idx = str(set_idx)
        key = str(key)
        if epoch not in self.metric_data:
            self.metric_data[epoch] = {}
        if 'test' not in self.metric_data[epoch]:
            self.metric_data[epoch]['test'] = {}
        if set_idx not in self.metric_data[epoch]['test']:
            self.metric_data[epoch]['test'][set_idx] = {}
        self.metric_data[epoch]['test'][set_idx][key] = val

    def metric_save(self):
        with open(str(self.metric_path), 'w') as fp:
            json.dump(self.metric_data, fp, indent=2)

    def init_seed(self, seed=None):
        if seed is not None:
            self.seed = seed
        logging.info(f'Set seed to {self.seed}')
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

    def log_datetime(self):
        logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

    def mem_report(self):
        for obj in gc.get_objects():
            if torch.is_tensor(obj):
                print(type(obj), obj.shape)

    def get_net_path(self, epoch, root=None):
        if root is None:
            root = self.exp_out_root
        return root / f'net_{epoch:04d}.params'

    def get_do_parser_cmds(self):
        return ['retrain', 'resume', 'retest', 'test_init']

    def get_do_parser(self):
        parser = argparse.ArgumentParser()
        parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
        parser.add_argument('--epoch', type=int, default=-1)
        return parser

    def do_cmd(self, args, net, optimizer, scheduler=None):
        if args.cmd == 'retrain':
            self.train(net, optimizer, resume=False, scheduler=scheduler)
        elif args.cmd == 'resume':
            self.train(net, optimizer, resume=True, scheduler=scheduler)
        elif args.cmd == 'retest':
            self.retest(net, epoch=args.epoch)
        elif args.cmd == 'test_init':
            test_sets = self.get_test_sets()
            self.test(-1, net, test_sets)
        else:
            raise Exception('invalid cmd')

    def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
        parser = self.get_do_parser()
        args, _ = parser.parse_known_args()

        if load_net_optimizer is not None and args.cmd not in ['schedule']:
            net, optimizer = load_net_optimizer()

        self.do_cmd(args, net, optimizer, scheduler=scheduler)

    def retest(self, net, epoch=-1):
        if epoch < 0:
            epochs = range(self.epochs)
        else:
            epochs = [epoch]

        test_sets = self.get_test_sets()

        for epoch in epochs:
            net_path = self.get_net_path(epoch)
            if net_path.exists():
                state_dict = torch.load(str(net_path))
                net.load_state_dict(state_dict)
                self.test(epoch, net, test_sets)

    def format_err_str(self, errs, div=1):
        err = sum(errs)
        if len(errs) > 1:
            err_str = f'{err / div:0.4f}=' + '+'.join([f'{e / div:0.4f}' for e in errs])
        else:
            err_str = f'{err / div:0.4f}'
        return err_str

    def write_err_img(self):
        err_img_path = self.exp_out_root / 'errs.png'
        fig = plt.figure(figsize=(16, 16))
        lines = []
        for idx, errs in enumerate(self.errs_list):
            line, = plt.plot(range(len(errs)), errs, label=f'error{idx}')
            lines.append(line)
        plt.tight_layout()
        plt.legend(handles=lines)
        plt.savefig(str(err_img_path))
        plt.close(fig)

    def callback_train_new_epoch(self, epoch, net, optimizer):
        pass

    def train(self, net, optimizer, resume=False, scheduler=None):
        logging.info('=' * 80)
        logging.info('Start training')
        self.log_datetime()
        logging.info('=' * 80)

        train_set = self.get_train_set()
        test_sets = self.get_test_sets()

        net = net.to(self.train_device)

        epoch = 0
        min_err = {ts.name: 1e9 for ts in test_sets}

        state_path = self.exp_out_root / 'state.dict'
        if resume and state_path.exists():
            logging.info('=' * 80)
            logging.info(f'Loading state from {state_path}')
            logging.info('=' * 80)

            state = torch.load(str(state_path))
            epoch = state['epoch'] + 1
            if 'min_err' in state:
                min_err = state['min_err']

            curr_state = net.state_dict()
            curr_state.update(state['state_dict'])
            net.load_state_dict(curr_state)

            try:
                optimizer.load_state_dict(state['optimizer'])
            except:
                logging.info('Warning: cannot load optimizer from state_dict')
                pass

            if 'cpu_rng_state' in state:
                torch.set_rng_state(state['cpu_rng_state'])
            if 'gpu_rng_state' in state:
                torch.cuda.set_rng_state(state['gpu_rng_state'])

        for epoch in range(epoch, self.epochs):
            self.callback_train_new_epoch(epoch, net, optimizer)

            # train epoch
            self.train_epoch(epoch, net, optimizer, train_set)

            # test epoch
            errs = self.test(epoch, net, test_sets)

            if (epoch + 1) % self.save_frequency == 0:
                net = net.to(self.train_device)

                # store state
                state_dict = {
                    'epoch': epoch,
                    'min_err': min_err,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'cpu_rng_state': torch.get_rng_state(),
                    'gpu_rng_state': torch.cuda.get_rng_state(),
                }
                logging.info(f'save state to {state_path}')
                state_path = self.exp_out_root / 'state.dict'
                torch.save(state_dict, str(state_path))

                for test_set_name in errs:
                    err = sum(errs[test_set_name])
                    if err < min_err[test_set_name]:
                        min_err[test_set_name] = err
                        state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
                        logging.info(f'save state to {state_path}')
                        torch.save(state_dict, str(state_path))

                # store network
                net_path = self.get_net_path(epoch)
                logging.info(f'save network to {net_path}')
                torch.save(net.state_dict(), str(net_path))

            if scheduler is not None:
                scheduler.step()

        logging.info('=' * 80)
        logging.info('Finished training')
        self.log_datetime()
        logging.info('=' * 80)

    def get_train_set(self):
        # returns train_set
        raise NotImplementedError()

    def get_test_sets(self):
        # returns test_sets
        raise NotImplementedError()

    def copy_data(self, data, device, requires_grad, train):
        raise NotImplementedError()

    def net_forward(self, net, train):
        raise NotImplementedError()

    def loss_forward(self, output, train):
        raise NotImplementedError()

    def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
        # err = False
        # for name, param in net.named_parameters():
        #     if not torch.isfinite(param.grad).all():
        #         print(name)
        #         err = True
        # if err:
        #     import ipdb; ipdb.set_trace()
        pass

    def callback_train_start(self, epoch):
        pass

    def callback_train_stop(self, epoch, loss):
        pass

    def train_epoch(self, epoch, net, optimizer, dset):
        self.callback_train_start(epoch)
        stopwatch = StopWatch()

        logging.info('=' * 80)
        logging.info('Train epoch %d' % epoch)

        dset.current_epoch = epoch
        train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True,
                                                   num_workers=self.num_workers, drop_last=True, pin_memory=False)

        net = net.to(self.train_device)
        net.train()

        mean_loss = None

        n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
        bar = ETA(length=n_batches)

        stopwatch.start('total')
        stopwatch.start('data')
        for batch_idx, data in enumerate(train_loader):
            if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
            self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
            stopwatch.stop('data')

            optimizer.zero_grad()

            stopwatch.start('forward')
            output = self.net_forward(net, train=True)
            if 'cuda' in self.train_device: torch.cuda.synchronize()
            stopwatch.stop('forward')

            stopwatch.start('loss')
            errs = self.loss_forward(output, train=True)
            if isinstance(errs, dict):
                masks = errs['masks']
                errs = errs['errs']
            else:
                masks = []
            if not isinstance(errs, list) and not isinstance(errs, tuple):
                errs = [errs]
            err = sum(errs)
            if 'cuda' in self.train_device: torch.cuda.synchronize()
            stopwatch.stop('loss')

            stopwatch.start('backward')
            err.backward()
            self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
            if 'cuda' in self.train_device: torch.cuda.synchronize()
            stopwatch.stop('backward')

            stopwatch.start('optimizer')
            optimizer.step()
            if 'cuda' in self.train_device: torch.cuda.synchronize()
            stopwatch.stop('optimizer')

            bar.update(batch_idx)
            if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
                err_str = self.format_err_str(errs)
                logging.info(
                    f'train e{epoch}: {batch_idx + 1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
                # self.write_err_img()

            if mean_loss is None:
                mean_loss = [0 for e in errs]
            for erridx, err in enumerate(errs):
                mean_loss[erridx] += err.item()

            stopwatch.start('data')
        stopwatch.stop('total')
        logging.info('timings: %s' % stopwatch)

        mean_loss = [l / len(train_loader) for l in mean_loss]
        self.callback_train_stop(epoch, mean_loss)
        self.metric_add_train(epoch, 'loss', mean_loss)

        # save metrics
        self.metric_save()

        err_str = self.format_err_str(mean_loss)
        logging.info(f'avg train_loss={err_str}')
        return mean_loss

    def callback_test_start(self, epoch, set_idx):
        pass

    def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
        pass

    def callback_test_stop(self, epoch, set_idx, loss):
        pass

    def test(self, epoch, net, test_sets):
        errs = {}
        for test_set_idx, test_set in enumerate(test_sets):
            if (epoch + 1) % test_set.test_frequency == 0:
                logging.info('=' * 80)
                logging.info(f'testing set {test_set.name}')
                err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
                errs[test_set.name] = err
        return errs

    def test_epoch(self, epoch, set_idx, net, dset):
        logging.info('-' * 80)
        logging.info('Test epoch %d' % epoch)
        dset.current_epoch = epoch
        test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False,
                                                  num_workers=self.num_workers, drop_last=False, pin_memory=False)

        net = net.to(self.test_device)
        net.eval()

        with torch.no_grad():
            mean_loss = None

            self.callback_test_start(epoch, set_idx)

            bar = ETA(length=len(test_loader))
            stopwatch = StopWatch()
            stopwatch.start('total')
            stopwatch.start('data')
            for batch_idx, data in enumerate(test_loader):
                # if batch_idx == 10: break
                self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
                stopwatch.stop('data')

                stopwatch.start('forward')
                output = self.net_forward(net, train=False)
                if 'cuda' in self.test_device: torch.cuda.synchronize()
                stopwatch.stop('forward')

                stopwatch.start('loss')
                errs = self.loss_forward(output, train=False)
                if isinstance(errs, dict):
                    masks = errs['masks']
                    errs = errs['errs']
                else:
                    masks = []
                if not isinstance(errs, list) and not isinstance(errs, tuple):
                    errs = [errs]

                bar.update(batch_idx)
                if batch_idx % 25 == 0:
                    err_str = self.format_err_str(errs)
                    logging.info(
                        f'test e{epoch}: {batch_idx + 1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')

                if mean_loss is None:
                    mean_loss = [0 for e in errs]
                for erridx, err in enumerate(errs):
                    mean_loss[erridx] += err.item()
                stopwatch.stop('loss')

                self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)

                stopwatch.start('data')
            stopwatch.stop('total')
        logging.info('timings: %s' % stopwatch)

        mean_loss = [l / len(test_loader) for l in mean_loss]
        self.callback_test_stop(epoch, set_idx, mean_loss)
        self.metric_add_test(epoch, set_idx, 'loss', mean_loss)

        # save metrics
        self.metric_save()

        err_str = self.format_err_str(mean_loss)
        logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
        return mean_loss
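
Worker is an abstract training/testing driver: get_train_set, get_test_sets, copy_data, net_forward and loss_forward must be supplied by a subclass, which the experiment workers in model/ do in this repository. A minimal sketch of such a subclass; the dataset, field names and loss below are placeholders for illustration, not names from this repository:

class MyWorker(Worker):
    def get_train_set(self):
        return my_train_dset                        # any torch.utils.data.Dataset

    def get_test_sets(self):
        # each entry needs .name, .dset and .test_frequency (see test() above)
        return [my_test_set]

    def copy_data(self, data, device, requires_grad, train):
        # stash the current batch on the worker so net_forward/loss_forward can use it
        self.data = {k: v.to(device) for k, v in data.items()}

    def net_forward(self, net, train):
        return net(self.data['im'])                 # 'im' is a hypothetical batch key

    def loss_forward(self, output, train):
        # a single tensor is fine; the framework wraps it into a list
        return torch.nn.functional.mse_loss(output, self.data['gt'])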

@ -5,25 +5,24 @@ from model import exp_synphge
from model import networks
from co.args import parse_args

# parse args
args = parse_args()

# loss types
if args.loss == 'ph':
    worker = exp_synph.Worker(args)
elif args.loss == 'phge':
    worker = exp_synphge.Worker(args)

# concatenation of original image and lcn image
channels_in = 2

# set up network
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes,
                                output_ms=worker.ms)

# optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# start the work
worker.do(net, optimizer)
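
Worker.do() additionally parses --cmd and --epoch from the command line (see get_do_parser above), and parse_args supplies --loss, so a typical invocation looks like the following sketch, assuming the script is run from the repository root:

    python train_val.py --loss ph --cmd retrain
    python train_val.py --loss phge --cmd resume
    python train_val.py --loss ph --cmd retest --epoch 9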
