Reformat $EVERYTHING

CptCaptain 2021-11-15 16:53:30 +01:00
parent 56f2aa7d5d
commit 43df77fb9b
32 changed files with 4171 additions and 3749 deletions


@@ -7,8 +7,9 @@
# set matplotlib backend depending on env
import os
import matplotlib

if os.name == 'posix' and "DISPLAY" not in os.environ:
    matplotlib.use('Agg')

from . import geometry
from . import plt
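A minimal sketch of the behaviour this guard provides, assuming the snippet above is the package's __init__.py and that the package is importable as co (both assumptions; neither name appears in this view):

import os

os.environ.pop('DISPLAY', None)   # simulate a headless POSIX session (e.g. SSH without X forwarding)
import matplotlib
import co                         # import triggers the backend check above

print(matplotlib.get_backend())   # a non-interactive (Agg) backend is active, so plotting code won't crash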


@@ -12,14 +12,14 @@ def parse_args():
    parser.add_argument('--loss',
                        help='Train with \'ph\' for the first stage without geometric loss, \
                              train with \'phge\' for the second stage with geometric loss',
                        default='ph', choices=['ph', 'phge'], type=str)
    parser.add_argument('--data_type',
                        default='syn', choices=['syn'], type=str)
    #
    parser.add_argument('--cmd',
                        help='Start training or test',
                        default='resume', choices=['retrain', 'resume', 'retest', 'test_init'], type=str)
    parser.add_argument('--epoch',
                        help='If larger than -1, retest on the specified epoch',
                        default=-1, type=int)
    parser.add_argument('--epochs',
@@ -55,7 +55,7 @@ def parse_args():
    parser.add_argument('--blend_im',
                        help='Parameter for adding texture',
                        default=0.6, type=float)
    args = parser.parse_args()
    args.exp_name = get_exp_name(args)
@@ -66,6 +66,3 @@ def parse_args():
def get_exp_name(args):
    name = f"exp_{args.data_type}"
    return name
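For reference, a hedged sketch of how these flags are consumed. The module name config and the assumption that parse_args() returns the populated namespace are both illustrative; neither is shown in this hunk.

import sys

import config  # hypothetical module name for the argument-parsing file shown above

# second-stage training with geometric loss on the synthetic data
sys.argv = ['train.py', '--cmd', 'retrain', '--loss', 'phge', '--data_type', 'syn']
args = config.parse_args()   # assumes parse_args() returns the parsed namespace
print(args.exp_name)         # 'exp_syn' via get_exp_name for the default data type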


@@ -1,19 +1,20 @@
import numpy as np

_color_map_errors = np.array([
    [149, 54, 49],  # 0: log2(x) = -infinity
    [180, 117, 69],  # 0.0625: log2(x) = -4
    [209, 173, 116],  # 0.125: log2(x) = -3
    [233, 217, 171],  # 0.25: log2(x) = -2
    [248, 243, 224],  # 0.5: log2(x) = -1
    [144, 224, 254],  # 1.0: log2(x) = 0
    [97, 174, 253],  # 2.0: log2(x) = 1
    [67, 109, 244],  # 4.0: log2(x) = 2
    [39, 48, 215],  # 8.0: log2(x) = 3
    [38, 0, 165],  # 16.0: log2(x) = 4
    [38, 0, 165]  # inf: log2(x) = inf
]).astype(float)


def color_error_image(errors, scale=1, mask=None, BGR=True):
    """
    Color an input error map.
@@ -27,31 +28,33 @@ def color_error_image(errors, scale=1, mask=None, BGR=True):
    Returns:
        colored_errors -- HxWx3 numpy array visualizing the errors
    """
    errors_flat = errors.flatten()
    errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9)
    i0 = np.floor(errors_color_indices).astype(int)
    f1 = errors_color_indices - i0.astype(float)
    colored_errors_flat = _color_map_errors[i0, :] * (1 - f1).reshape(-1, 1) + _color_map_errors[i0 + 1, :] * f1.reshape(-1, 1)

    if mask is not None:
        colored_errors_flat[mask.flatten() == 0] = 255

    if not BGR:
        colored_errors_flat = colored_errors_flat[:, [2, 1, 0]]

    return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int)


_color_map_depths = np.array([
    [0, 0, 0],  # 0.000
    [0, 0, 255],  # 0.114
    [255, 0, 0],  # 0.299
    [255, 0, 255],  # 0.413
    [0, 255, 0],  # 0.587
    [0, 255, 255],  # 0.701
    [255, 255, 0],  # 0.886
    [255, 255, 255],  # 1.000
    [255, 255, 255],  # 1.000
]).astype(float)

_color_map_bincenters = np.array([
    0.0,
@@ -62,9 +65,10 @@ _color_map_bincenters = np.array([
    0.701,
    0.886,
    1.000,
    2.000,  # doesn't make a difference, just strictly higher than 1
])


def color_depth_map(depths, scale=None):
    """
    Color an input depth map.
@@ -82,12 +86,13 @@ def color_depth_map(depths, scale=None):
    values = np.clip(depths.flatten() / scale, 0, 1)
    # for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value?
    lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1, -1)) * np.arange(0, 9)).max(axis=1)
    lower_bin_value = _color_map_bincenters[lower_bin]
    higher_bin_value = _color_map_bincenters[lower_bin + 1]
    alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
    colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[lower_bin + 1] * alphas.reshape(-1, 1)
    return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)

# from utils.debug import save_color_numpy
# save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000))
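A small usage sketch for the two colorization helpers, assuming the module above is importable as co.cmap (the file name is not shown in this view):

import numpy as np

from co import cmap  # assumed import path

depth = np.random.uniform(0.5, 4.0, size=(120, 160))   # synthetic depth map in metres
depth_vis = cmap.color_depth_map(depth, scale=4.0)     # HxWx3 uint8 visualization

errors = np.abs(depth - 2.0)                           # some per-pixel error map
error_vis = cmap.color_error_image(errors, scale=0.5, BGR=False)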

File diff suppressed because it is too large


@@ -2,31 +2,37 @@ import numpy as np
from . import utils


class StopWatch(utils.StopWatch):
    def __del__(self):
        print('=' * 80)
        print('gtimer:')
        total = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.sum).items()])
        print(f' [total] {total}')
        mean = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.mean).items()])
        print(f' [mean] {mean}')
        median = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.median).items()])
        print(f' [median] {median}')
        print('=' * 80)


GTIMER = StopWatch()


def start(name):
    GTIMER.start(name)


def stop(name):
    GTIMER.stop(name)


class Ctx(object):
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        start(self.name)

    def __exit__(self, *args):
        stop(self.name)
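A brief usage sketch, assuming the module above is importable as co.gtimer and that utils.StopWatch provides the start/stop/get interface used here:

import time

from co import gtimer  # assumed import path

for _ in range(3):
    with gtimer.Ctx('forward'):   # times the body under the name 'forward'
        time.sleep(0.01)

gtimer.start('io')
time.sleep(0.02)
gtimer.stop('io')
# the global GTIMER prints per-name total/mean/median timings when it is destroyed at interpreter exit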


@@ -2,266 +2,273 @@ import struct
import numpy as np
import collections


def _write_ply_point(fp, x, y, z, color=None, normal=None, binary=False):
    args = [x, y, z]
    if color is not None:
        args += [int(color[0]), int(color[1]), int(color[2])]
    if normal is not None:
        args += [normal[0], normal[1], normal[2]]
    if binary:
        fmt = '<fff'
        if color is not None:
            fmt = fmt + 'BBB'
        if normal is not None:
            fmt = fmt + 'fff'
        fp.write(struct.pack(fmt, *args))
    else:
        fmt = '%f %f %f'
        if color is not None:
            fmt = fmt + ' %d %d %d'
        if normal is not None:
            fmt = fmt + ' %f %f %f'
        fmt += '\n'
        fp.write(fmt % tuple(args))


def _write_ply_triangle(fp, i0, i1, i2, binary):
    if binary:
        fp.write(struct.pack('<Biii', 3, i0, i1, i2))
    else:
        fp.write('3 %d %d %d\n' % (i0, i1, i2))


def _write_ply_header_line(fp, str, binary):
    if binary:
        fp.write(str.encode())
    else:
        fp.write(str)


def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
    if verts.shape[1] != 3:
        raise Exception('verts has to be of shape Nx3')
    if trias is not None and trias.shape[1] != 3:
        raise Exception('trias has to be of shape Nx3')
    if color is not None and not callable(color) and not isinstance(color, np.ndarray) and color.shape[1] != 3:
        raise Exception('color has to be of shape Nx3 or a callable')

    mode = 'wb' if binary else 'w'
    with open(path, mode) as fp:
        _write_ply_header_line(fp, "ply\n", binary)
        if binary:
            _write_ply_header_line(fp, "format binary_little_endian 1.0\n", binary)
        else:
            _write_ply_header_line(fp, "format ascii 1.0\n", binary)
        _write_ply_header_line(fp, "element vertex %d\n" % (verts.shape[0]), binary)
        _write_ply_header_line(fp, "property float32 x\n", binary)
        _write_ply_header_line(fp, "property float32 y\n", binary)
        _write_ply_header_line(fp, "property float32 z\n", binary)
        if color is not None:
            _write_ply_header_line(fp, "property uchar red\n", binary)
            _write_ply_header_line(fp, "property uchar green\n", binary)
            _write_ply_header_line(fp, "property uchar blue\n", binary)
        if normals is not None:
            _write_ply_header_line(fp, "property float32 nx\n", binary)
            _write_ply_header_line(fp, "property float32 ny\n", binary)
            _write_ply_header_line(fp, "property float32 nz\n", binary)
        if trias is not None:
            _write_ply_header_line(fp, "element face %d\n" % (trias.shape[0]), binary)
            _write_ply_header_line(fp, "property list uchar int32 vertex_indices\n", binary)
        _write_ply_header_line(fp, "end_header\n", binary)

        for vidx, v in enumerate(verts):
            if color is not None:
                if callable(color):
                    c = color(vidx)
                elif color.shape[0] > 1:
                    c = color[vidx]
                else:
                    c = color[0]
            else:
                c = None
            if normals is None:
                n = None
            else:
                n = normals[vidx]
            _write_ply_point(fp, v[0], v[1], v[2], c, n, binary)

        if trias is not None:
            for t in trias:
                _write_ply_triangle(fp, t[0], t[1], t[2], binary)


def faces_to_triangles(faces):
    new_faces = []
    for f in faces:
        if f[0] == 3:
            new_faces.append([f[1], f[2], f[3]])
        elif f[0] == 4:
            new_faces.append([f[1], f[2], f[3]])
            new_faces.append([f[3], f[4], f[1]])
        else:
            raise Exception('unknown face count %d', f[0])
    return new_faces


def read_ply(path):
    with open(path, 'rb') as f:
        # parse header
        line = f.readline().decode().strip()
        if line != 'ply':
            raise Exception('Header error')
        n_verts = 0
        n_faces = 0
        vert_types = {}
        vert_bin_format = []
        vert_bin_len = 0
        vert_bin_cols = 0
        line = f.readline().decode()
        parse_vertex_prop = False
        while line.strip() != 'end_header':
            if 'format' in line:
                if 'ascii' in line:
                    binary = False
                elif 'binary_little_endian' in line:
                    binary = True
                else:
                    raise Exception('invalid ply format')
            if 'element face' in line:
                splits = line.strip().split(' ')
                n_faces = int(splits[-1])
                parse_vertex_prop = False
            if 'element camera' in line:
                parse_vertex_prop = False
            if 'element vertex' in line:
                splits = line.strip().split(' ')
                n_verts = int(splits[-1])
                parse_vertex_prop = True
            if parse_vertex_prop and 'property' in line:
                prop = line.strip().split()
                if prop[1] == 'float':
                    vert_bin_format.append('f4')
                    vert_bin_len += 4
                    vert_bin_cols += 1
                elif prop[1] == 'uchar':
                    vert_bin_format.append('B')
                    vert_bin_len += 1
                    vert_bin_cols += 1
                else:
                    raise Exception('invalid property')
                vert_types[prop[2]] = len(vert_types)
            line = f.readline().decode()

        # parse content
        if binary:
            sz = n_verts * vert_bin_len
            fmt = ','.join(vert_bin_format)
            verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz))
            verts = verts[0].astype(vert_bin_cols * 'f4,').view(dtype='f4').reshape((n_verts, -1))
            faces = []
            for idx in range(n_faces):
                fmt = '<Biii'
                length = struct.calcsize(fmt)
                dat = f.read(length)
                vals = struct.unpack(fmt, dat)
                faces.append(vals)
            faces = faces_to_triangles(faces)
            faces = np.array(faces, dtype=np.int32)
        else:
            verts = []
            for idx in range(n_verts):
                vals = [float(v) for v in f.readline().decode().strip().split(' ')]
                verts.append(vals)
            verts = np.array(verts, dtype=np.float32)
            faces = []
            for idx in range(n_faces):
                splits = f.readline().decode().strip().split(' ')
                n_face_verts = int(splits[0])
                vals = [int(v) for v in splits[0:n_face_verts + 1]]
                faces.append(vals)
            faces = faces_to_triangles(faces)
            faces = np.array(faces, dtype=np.int32)

        xyz = None
        if 'x' in vert_types and 'y' in vert_types and 'z' in vert_types:
            xyz = verts[:, [vert_types['x'], vert_types['y'], vert_types['z']]]
        colors = None
        if 'red' in vert_types and 'green' in vert_types and 'blue' in vert_types:
            colors = verts[:, [vert_types['red'], vert_types['green'], vert_types['blue']]]
            colors /= 255
        normals = None
        if 'nx' in vert_types and 'ny' in vert_types and 'nz' in vert_types:
            normals = verts[:, [vert_types['nx'], vert_types['ny'], vert_types['nz']]]

        return xyz, faces, colors, normals


def _read_obj_split_f(s):
    parts = s.split('/')
    vidx = int(parts[0]) - 1
    if len(parts) >= 2 and len(parts[1]) > 0:
        tidx = int(parts[1]) - 1
    else:
        tidx = -1
    if len(parts) >= 3 and len(parts[2]) > 0:
        nidx = int(parts[2]) - 1
    else:
        nidx = -1
    return vidx, tidx, nidx


def read_obj(path):
    with open(path, 'r') as fp:
        lines = fp.readlines()

    verts = []
    colors = []
    fnorms = []
    fnorm_map = collections.defaultdict(list)
    faces = []
    for line in lines:
        line = line.strip()
        if line.startswith('#') or len(line) == 0:
            continue
        parts = line.split()
        if line.startswith('v '):
            parts = parts[1:]
            x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
            if len(parts) == 4 or len(parts) == 7:
                w = float(parts[3])
                x, y, z = x / w, y / w, z / w
            verts.append((x, y, z))
            if len(parts) >= 6:
                r, g, b = float(parts[-3]), float(parts[-2]), float(parts[-1])
                colors.append((r, g, b))
        elif line.startswith('vn '):
            parts = parts[1:]
            x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
            fnorms.append((x, y, z))
        elif line.startswith('f '):
            parts = parts[1:]
            if len(parts) != 3:
                raise Exception('only triangle meshes supported atm')
            vidx0, tidx0, nidx0 = _read_obj_split_f(parts[0])
            vidx1, tidx1, nidx1 = _read_obj_split_f(parts[1])
            vidx2, tidx2, nidx2 = _read_obj_split_f(parts[2])
            faces.append((vidx0, vidx1, vidx2))
            if nidx0 >= 0:
                fnorm_map[vidx0].append(nidx0)
            if nidx1 >= 0:
                fnorm_map[vidx1].append(nidx1)
            if nidx2 >= 0:
                fnorm_map[vidx2].append(nidx2)
    verts = np.array(verts)
    colors = np.array(colors)
    fnorms = np.array(fnorms)
    faces = np.array(faces)

    # face normals to vertex normals
    norms = np.zeros_like(verts)
    for vidx in fnorm_map.keys():
        ind = fnorm_map[vidx]
        norms[vidx] = fnorms[ind].sum(axis=0)
    N = np.linalg.norm(norms, axis=1, keepdims=True)
    np.divide(norms, N, out=norms, where=N != 0)
    return verts, faces, colors, norms
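A short write/read sketch, assuming the module above is importable as co.io3d:

import numpy as np

from co import io3d  # assumed import path

# write a random colored point cloud (no faces, no normals)
verts = np.random.rand(1000, 3).astype(np.float32)
colors = np.random.randint(0, 256, size=(1000, 3))
io3d.write_ply('points.ply', verts, color=colors, binary=True)

# read_ply returns (xyz, faces, colors, normals); note that its header parser only accepts
# 'float' and 'uchar' property types, so it targets PLY files written with those type names
# xyz, faces, colors, normals = io3d.read_ply('mesh.ply')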


@@ -1,248 +1,260 @@
import numpy as np

from . import geometry


def _process_inputs(estimate, target, mask):
    if estimate.shape != target.shape:
        raise Exception('estimate and target have to be same shape')
    if mask is None:
        mask = np.ones(estimate.shape, dtype=np.bool)
    else:
        mask = mask != 0
    if estimate.shape != mask.shape:
        raise Exception('estimate and mask have to be same shape')
    return estimate, target, mask


def mse(estimate, target, mask=None):
    estimate, target, mask = _process_inputs(estimate, target, mask)
    m = np.sum((estimate[mask] - target[mask]) ** 2) / mask.sum()
    return m


def rmse(estimate, target, mask=None):
    return np.sqrt(mse(estimate, target, mask))


def mae(estimate, target, mask=None):
    estimate, target, mask = _process_inputs(estimate, target, mask)
    m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
    return m


def outlier_fraction(estimate, target, mask=None, threshold=0):
    estimate, target, mask = _process_inputs(estimate, target, mask)
    diff = np.abs(estimate[mask] - target[mask])
    m = (diff > threshold).sum() / mask.sum()
    return m


class Metric(object):
    def __init__(self, str_prefix=''):
        self.str_prefix = str_prefix
        self.reset()

    def reset(self):
        pass

    def add(self, es, ta, ma=None):
        pass

    def get(self):
        return {}

    def items(self):
        return self.get().items()

    def __str__(self):
        return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])


class MultipleMetric(Metric):
    def __init__(self, *metrics, **kwargs):
        self.metrics = [*metrics]
        super().__init__(**kwargs)

    def reset(self):
        for m in self.metrics:
            m.reset()

    def add(self, es, ta, ma=None):
        for m in self.metrics:
            m.add(es, ta, ma)

    def get(self):
        ret = {}
        for m in self.metrics:
            vals = m.get()
            for k in vals:
                ret[k] = vals[k]
        return ret

    def __str__(self):
        return '\n'.join([str(m) for m in self.metrics])


class BaseDistanceMetric(Metric):
    def __init__(self, name='', **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reset(self):
        self.dists = []

    def add(self, es, ta, ma=None):
        pass

    def get(self):
        dists = np.hstack(self.dists)
        return {
            f'dist{self.name}_mean': float(np.mean(dists)),
            f'dist{self.name}_std': float(np.std(dists)),
            f'dist{self.name}_median': float(np.median(dists)),
            f'dist{self.name}_q10': float(np.percentile(dists, 10)),
            f'dist{self.name}_q90': float(np.percentile(dists, 90)),
            f'dist{self.name}_min': float(np.min(dists)),
            f'dist{self.name}_max': float(np.max(dists)),
        }


class DistanceMetric(BaseDistanceMetric):
    def __init__(self, vec_length, p=2, **kwargs):
        super().__init__(name=f'{p}', **kwargs)
        self.vec_length = vec_length
        self.p = p

    def add(self, es, ta, ma=None):
        if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
            print(es.shape, ta.shape)
            raise Exception('es and ta have to be of shape Nxdim')
        if ma is not None:
            es = es[ma != 0]
            ta = ta[ma != 0]
        dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
        self.dists.append(dist)


class OutlierFractionMetric(DistanceMetric):
    def __init__(self, thresholds, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.thresholds = thresholds

    def get(self):
        dists = np.hstack(self.dists)
        ret = {}
        for t in self.thresholds:
            ma = dists > t
            ret[f'of{t}'] = float(ma.sum() / ma.size)
        return ret


class RelativeDistanceMetric(BaseDistanceMetric):
    def __init__(self, vec_length, p=2, **kwargs):
        super().__init__(name=f'rel{p}', **kwargs)
        self.vec_length = vec_length
        self.p = p

    def add(self, es, ta, ma=None):
        if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
            raise Exception('es and ta have to be of shape Nxdim')
        dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
        denom = np.linalg.norm(ta, ord=self.p, axis=1)
        dist /= denom
        if ma is not None:
            dist = dist[ma != 0]
        self.dists.append(dist)


class RotmDistanceMetric(BaseDistanceMetric):
    def __init__(self, type='identity', **kwargs):
        super().__init__(name=type, **kwargs)
        self.type = type

    def add(self, es, ta, ma=None):
        if es.shape != ta.shape or es.shape[1] != 3 or es.shape[2] != 3 or es.ndim != 3:
            print(es.shape, ta.shape)
            raise Exception('es and ta have to be of shape Nx3x3')
        if ma is not None:
            raise Exception('mask is not implemented')
        if self.type == 'identity':
            self.dists.append(geometry.rotm_distance_identity(es, ta))
        elif self.type == 'geodesic':
            self.dists.append(geometry.rotm_distance_geodesic_unit_sphere(es, ta))
        else:
            raise Exception('invalid distance type')


class QuaternionDistanceMetric(BaseDistanceMetric):
    def __init__(self, type='angle', **kwargs):
        super().__init__(name=type, **kwargs)
        self.type = type

    def add(self, es, ta, ma=None):
        if es.shape != ta.shape or es.shape[1] != 4 or es.ndim != 2:
            print(es.shape, ta.shape)
            raise Exception('es and ta have to be of shape Nx4')
        if ma is not None:
            raise Exception('mask is not implemented')
        if self.type == 'angle':
            self.dists.append(geometry.quat_distance_angle(es, ta))
        elif self.type == 'mineucl':
            self.dists.append(geometry.quat_distance_mineucl(es, ta))
        elif self.type == 'normdiff':
            self.dists.append(geometry.quat_distance_normdiff(es, ta))
        else:
            raise Exception('invalid distance type')


class BinaryAccuracyMetric(Metric):
    def __init__(self, thresholds=np.linspace(0.0, 1.0, num=101, dtype=np.float64)[:-1], **kwargs):
        self.thresholds = thresholds
        super().__init__(**kwargs)

    def reset(self):
        self.tps = [0 for wp in self.thresholds]
        self.fps = [0 for wp in self.thresholds]
        self.fns = [0 for wp in self.thresholds]
        self.tns = [0 for wp in self.thresholds]
        self.n_pos = 0
        self.n_neg = 0

    def add(self, es, ta, ma=None):
        if ma is not None:
            raise Exception('mask is not implemented')
        es = es.ravel()
        ta = ta.ravel()
        if es.shape[0] != ta.shape[0]:
            raise Exception('invalid shape of es, or ta')
        if es.min() < 0 or es.max() > 1:
            raise Exception('estimate has wrong value range')
        ta_p = (ta == 1)
        ta_n = (ta == 0)
        es_p = es[ta_p]
        es_n = es[ta_n]
        for idx, wp in enumerate(self.thresholds):
            wp = np.asscalar(wp)
            self.tps[idx] += (es_p > wp).sum()
            self.fps[idx] += (es_n > wp).sum()
            self.fns[idx] += (es_p <= wp).sum()
            self.tns[idx] += (es_n <= wp).sum()
        self.n_pos += ta_p.sum()
        self.n_neg += ta_n.sum()

    def get(self):
        tps = np.array(self.tps).astype(np.float32)
        fps = np.array(self.fps).astype(np.float32)
        fns = np.array(self.fns).astype(np.float32)
        tns = np.array(self.tns).astype(np.float32)
        wp = self.thresholds

        ret = {}

        precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps), where=tps + fps != 0)
        recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps), where=tps + fns != 0)  # tprs
        fprs = np.divide(fps, fps + tns, out=np.zeros_like(tps), where=fps + tns != 0)

        precisions = np.r_[0, precisions, 1]
        recalls = np.r_[1, recalls, 0]
        fprs = np.r_[1, fprs, 0]

        ret['auc'] = float(-np.trapz(recalls, fprs))
        ret['prauc'] = float(-np.trapz(precisions, recalls))
        ret['ap'] = float(-(np.diff(recalls) * precisions[:-1]).sum())

        accuracies = np.divide(tps + tns, tps + tns + fps + fns)
        aacc = np.mean(accuracies)
        for t in np.linspace(0, 1, num=11)[1:-1]:
            idx = np.argmin(np.abs(t - wp))
            ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx])

        return ret
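A usage sketch for the metric classes, assuming the module above is importable as co.metric:

import numpy as np

from co import metric  # assumed import path

est = np.random.rand(1000, 3)                 # estimated 3D points
gt = est + 0.01 * np.random.randn(1000, 3)    # ground truth with a small perturbation

m = metric.MultipleMetric(
    metric.DistanceMetric(vec_length=3),                        # L2 distance statistics
    metric.OutlierFractionMetric([0.01, 0.05], vec_length=3),   # fraction of points off by > threshold
)
m.add(est, gt)
print(m)        # one summary line per wrapped metric
print(m.get())  # flat dict of scalar values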

co/plt.py

@@ -6,94 +6,99 @@ import matplotlib.pyplot as plt
import os
import time


def save(path, remove_axis=False, dpi=300, fig=None):
    if fig is None:
        fig = plt.gcf()
    dirname = os.path.dirname(path)
    if dirname != '' and not os.path.exists(dirname):
        os.makedirs(dirname)
    if remove_axis:
        for ax in fig.axes:
            ax.axis('off')
            ax.margins(0, 0)
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        for ax in fig.axes:
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())
    fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)


def color_map(im_, cmap='viridis', vmin=None, vmax=None):
    cm = plt.get_cmap(cmap)
    im = im_.copy()
    if vmin is None:
        vmin = np.nanmin(im)
    if vmax is None:
        vmax = np.nanmax(im)
    mask = np.logical_not(np.isfinite(im))
    im[mask] = vmin
    im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin)
    im = cm(im)
    im = im[..., :3]
    for c in range(3):
        im[mask, c] = 1
    return im


def interactive_legend(leg=None, fig=None, all_axes=True):
    if leg is None:
        leg = plt.legend()
    if fig is None:
        fig = plt.gcf()
    if all_axes:
        axs = fig.get_axes()
    else:
        axs = [fig.gca()]

    # lined = dict()
    # lines = ax.lines
    # for legline, origline in zip(leg.get_lines(), ax.lines):
    #     legline.set_picker(5)
    #     lined[legline] = origline
    lined = dict()
    for lidx, legline in enumerate(leg.get_lines()):
        legline.set_picker(5)
        lined[legline] = [ax.lines[lidx] for ax in axs]

    def onpick(event):
        if event.mouseevent.dblclick:
            tmp = [(k, v) for k, v in lined.items()]
        else:
            tmp = [(event.artist, lined[event.artist])]
        for legline, origline in tmp:
            for ol in origline:
                vis = not ol.get_visible()
                ol.set_visible(vis)
                if vis:
                    legline.set_alpha(1.0)
                else:
                    legline.set_alpha(0.2)
        fig.canvas.draw()

    fig.canvas.mpl_connect('pick_event', onpick)


def non_annoying_pause(interval, focus_figure=False):
    # https://github.com/matplotlib/matplotlib/issues/11131
    backend = mpl.rcParams['backend']
    if backend in _interactive_bk:
        figManager = _pylab_helpers.Gcf.get_active()
        if figManager is not None:
            canvas = figManager.canvas
            if canvas.figure.stale:
                canvas.draw()
            if focus_figure:
                plt.show(block=False)
            canvas.start_event_loop(interval)
            return
    time.sleep(interval)


def remove_all_ticks(fig=None):
    if fig is None:
        fig = plt.gcf()
    for ax in fig.axes:
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
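A small sketch of save and color_map from the co/plt.py module above, imported under an alias to avoid shadowing matplotlib.pyplot:

import numpy as np
import matplotlib.pyplot as plt

from co import plt as coplt  # alias for the module above

im = np.random.rand(64, 64)
plt.imshow(coplt.color_map(im, cmap='magma'))     # non-finite pixels are rendered white
coplt.save('out/random.png', remove_axis=True)    # creates 'out/' if needed and strips axes/ticks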


@@ -3,55 +3,60 @@ import matplotlib.pyplot as plt
from . import geometry


def image_matrix(ims, bgval=0):
    n = ims.shape[0]
    m = int(np.ceil(np.sqrt(n)))
    h = ims.shape[1]
    w = ims.shape[2]
    mat = np.empty((m * h, m * w), dtype=ims.dtype)
    mat.fill(bgval)
    idx = 0
    for r in range(m):
        for c in range(m):
            if idx < n:
                mat[r * h:(r + 1) * h, c * w:(c + 1) * w] = ims[idx]
                idx += 1
    return mat


def image_cat(ims, vertical=False):
    offx = [0]
    offy = [0]
    if vertical:
        width = max([im.shape[1] for im in ims])
        offx += [0 for im in ims[:-1]]
        offy += [im.shape[0] for im in ims[:-1]]
        height = sum([im.shape[0] for im in ims])
    else:
        height = max([im.shape[0] for im in ims])
        offx += [im.shape[1] for im in ims[:-1]]
        offy += [0 for im in ims[:-1]]
        width = sum([im.shape[1] for im in ims])
    offx = np.cumsum(offx)
    offy = np.cumsum(offy)
    im = np.zeros((height, width, *ims[0].shape[2:]), dtype=ims[0].dtype)
    for im0, ox, oy in zip(ims, offx, offy):
        im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0

    return im, offx, offy


def line(li, h, w, ax=None, *args, **kwargs):
    if ax is None:
        ax = plt.gca()
    xs = (-li[2] - li[1] * np.array((0, h - 1))) / li[0]
    ys = (-li[2] - li[0] * np.array((0, w - 1))) / li[1]
    pts = np.array([(0, ys[0]), (w - 1, ys[1]), (xs[0], 0), (xs[1], h - 1)])
    pts = pts[np.logical_and(np.logical_and(pts[:, 0] >= 0, pts[:, 0] < w), np.logical_and(pts[:, 1] >= 0, pts[:, 1] < h))]
    ax.plot(pts[:, 0], pts[:, 1], *args, **kwargs)


def depthshow(depth, *args, ax=None, **kwargs):
    if ax is None:
        ax = plt.gca()
    d = depth.copy()
    d[d < 0] = np.NaN
    ax.imshow(d, *args, **kwargs)
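A usage sketch for the image tiling helpers, assuming the module above is importable as co.plt2d:

import numpy as np

from co import plt2d  # assumed import path

ims = np.random.rand(5, 32, 32)                                  # five grayscale images
grid = plt2d.image_matrix(ims, bgval=0)                          # 3x3 tile grid, unused tiles filled with bgval
strip, offx, offy = plt2d.image_cat(list(ims), vertical=False)   # horizontal concatenation plus per-image offsets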


@@ -4,35 +4,45 @@ from mpl_toolkits.mplot3d import Axes3D
from . import geometry


def ax3d(fig=None):
    if fig is None:
        fig = plt.gcf()
    return fig.add_subplot(111, projection='3d')


def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1,
                label=None, **kwargs):
    if ax is None:
        ax = plt.gca()
    C0 = geometry.translation_to_cameracenter(R, t).ravel()
    C1 = C0 + R.T.dot(np.array([[-size], [-size], [3 * size]], dtype=np.float32)).ravel()
    C2 = C0 + R.T.dot(np.array([[-size], [+size], [3 * size]], dtype=np.float32)).ravel()
    C3 = C0 + R.T.dot(np.array([[+size], [+size], [3 * size]], dtype=np.float32)).ravel()
    C4 = C0 + R.T.dot(np.array([[+size], [-size], [3 * size]], dtype=np.float32)).ravel()

    if marker_C != '':
        ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs)
    ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
            linewidth=linewidth, **kwargs)
    ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle,
            linewidth=linewidth, **kwargs)
    ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle,
            linewidth=linewidth, **kwargs)
    ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle,
            linewidth=linewidth, **kwargs)
    ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]],
            [C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
            linewidth=linewidth, **kwargs)


def axis_equal(ax=None):
    if ax is None:
        ax = plt.gca()
    extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
    sz = extents[:, 1] - extents[:, 0]
    centers = np.mean(extents, axis=1)
    maxsize = max(abs(sz))
    r = maxsize / 2
    for ctr, dim in zip(centers, 'xyz'):
        getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r)
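A usage sketch for the camera-frustum plotting helpers, assuming the module is importable as co.plt3d and that co.geometry provides translation_to_cameracenter as used above:

import numpy as np
import matplotlib.pyplot as plt

from co import plt3d  # assumed import path

ax = plt3d.ax3d()
# two cameras with identity rotation, the second shifted along +x
plt3d.plot_camera(ax=ax, R=np.eye(3), t=np.zeros((3,)), size=0.1, color='b', label='cam0')
plt3d.plot_camera(ax=ax, R=np.eye(3), t=np.array([1.0, 0.0, 0.0]), size=0.1, color='r', label='cam1')
plt3d.axis_equal(ax)
plt.show()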


@ -3,443 +3,453 @@ import pandas as pd
import enum import enum
import itertools import itertools
class Table(object): class Table(object):
def __init__(self, n_cols): def __init__(self, n_cols):
self.n_cols = n_cols self.n_cols = n_cols
self.rows = [] self.rows = []
self.aligns = ['r' for c in range(n_cols)] self.aligns = ['r' for c in range(n_cols)]
def get_cell_align(self, r, c): def get_cell_align(self, r, c):
align = self.rows[r].cells[c].align align = self.rows[r].cells[c].align
if align is None: if align is None:
return self.aligns[c] return self.aligns[c]
else:
return align
def add_row(self, row):
if row.ncols() != self.n_cols:
raise Exception(f'row has invalid number of cols, {row.ncols()} vs. {self.n_cols}')
self.rows.append(row)
def empty_row(self):
return Row.Empty(self.n_cols)
def expand_rows(self, n_add_cols=1):
if n_add_cols < 0: raise Exception('n_add_cols has to be positive')
self.n_cols += n_add_cols
for row in self.rows:
row.cells.extend([Cell() for cidx in range(n_add_cols)])
def add_block(self, data, row=-1, col=0, fmt=None, expand=False):
if row < 0: row = len(self.rows)
while len(self.rows) < row + len(data):
self.add_row(self.empty_row())
for r in range(len(data)):
cols = data[r]
if col + len(cols) > self.n_cols:
if expand:
self.expand_rows(col + len(cols) - self.n_cols)
else: else:
raise Exception('number of cols does not fit in table') return align
for c in range(len(cols)):
self.rows[row+r].cells[col+c] = Cell(data[r][c], fmt) def add_row(self, row):
if row.ncols() != self.n_cols:
raise Exception(f'row has invalid number of cols, {row.ncols()} vs. {self.n_cols}')
self.rows.append(row)
def empty_row(self):
return Row.Empty(self.n_cols)
def expand_rows(self, n_add_cols=1):
if n_add_cols < 0: raise Exception('n_add_cols has to be positive')
self.n_cols += n_add_cols
for row in self.rows:
row.cells.extend([Cell() for cidx in range(n_add_cols)])
def add_block(self, data, row=-1, col=0, fmt=None, expand=False):
if row < 0: row = len(self.rows)
while len(self.rows) < row + len(data):
self.add_row(self.empty_row())
for r in range(len(data)):
cols = data[r]
if col + len(cols) > self.n_cols:
if expand:
self.expand_rows(col + len(cols) - self.n_cols)
else:
raise Exception('number of cols does not fit in table')
for c in range(len(cols)):
self.rows[row + r].cells[col + c] = Cell(data[r][c], fmt)
class Row(object): class Row(object):
def __init__(self, cells, pre_separator=None, post_separator=None): def __init__(self, cells, pre_separator=None, post_separator=None):
self.cells = cells self.cells = cells
self.pre_separator = pre_separator self.pre_separator = pre_separator
self.post_separator = post_separator self.post_separator = post_separator
@classmethod @classmethod
def Empty(cls, n_cols): def Empty(cls, n_cols):
return Row([Cell() for c in range(n_cols)]) return Row([Cell() for c in range(n_cols)])
def add_cell(self, cell): def add_cell(self, cell):
self.cells.append(cell) self.cells.append(cell)
def ncols(self):
return sum([c.span for c in self.cells])
def ncols(self):
return sum([c.span for c in self.cells])
class Color(object): class Color(object):
def __init__(self, color=(0,0,0), fmt='rgb'): def __init__(self, color=(0, 0, 0), fmt='rgb'):
if fmt == 'rgb': if fmt == 'rgb':
self.color = color self.color = color
elif fmt == 'RGB': elif fmt == 'RGB':
self.color = tuple(c / 255 for c in color) self.color = tuple(c / 255 for c in color)
else: else:
return Exception('invalid color format') return Exception('invalid color format')
def as_rgb(self): def as_rgb(self):
return self.color return self.color
def as_RGB(self): def as_RGB(self):
return tuple(int(c * 255) for c in self.color) return tuple(int(c * 255) for c in self.color)
@classmethod @classmethod
def rgb(cls, r, g, b): def rgb(cls, r, g, b):
return Color(color=(r,g,b), fmt='rgb') return Color(color=(r, g, b), fmt='rgb')
@classmethod @classmethod
def RGB(cls, r, g, b): def RGB(cls, r, g, b):
return Color(color=(r,g,b), fmt='RGB') return Color(color=(r, g, b), fmt='RGB')
class CellFormat(object): class CellFormat(object):
def __init__(self, fmt='%s', fgcolor=None, bgcolor=None, bold=False): def __init__(self, fmt='%s', fgcolor=None, bgcolor=None, bold=False):
self.fmt = fmt self.fmt = fmt
self.fgcolor = fgcolor self.fgcolor = fgcolor
self.bgcolor = bgcolor self.bgcolor = bgcolor
self.bold = bold self.bold = bold
class Cell(object): class Cell(object):
def __init__(self, data=None, fmt=None, span=1, align=None): def __init__(self, data=None, fmt=None, span=1, align=None):
self.data = data self.data = data
if fmt is None: if fmt is None:
fmt = CellFormat() fmt = CellFormat()
self.fmt = fmt self.fmt = fmt
self.span = span self.span = span
self.align = align self.align = align
def __str__(self):
return self.fmt.fmt % self.data
def __str__(self):
return self.fmt.fmt % self.data
class Separator(enum.Enum): class Separator(enum.Enum):
HEAD = 1 HEAD = 1
BOTTOM = 2 BOTTOM = 2
INNER = 3 INNER = 3
class Renderer(object): class Renderer(object):
def __init__(self): def __init__(self):
pass pass
def cell_str_len(self, cell): def cell_str_len(self, cell):
return len(str(cell)) return len(str(cell))
def col_widths(self, table): def col_widths(self, table):
widths = [0 for c in range(table.n_cols)] widths = [0 for c in range(table.n_cols)]
for row in table.rows: for row in table.rows:
cidx = 0 cidx = 0
for cell in row.cells: for cell in row.cells:
if cell.span == 1: if cell.span == 1:
strlen = self.cell_str_len(cell) strlen = self.cell_str_len(cell)
widths[cidx] = max(widths[cidx], strlen) widths[cidx] = max(widths[cidx], strlen)
cidx += cell.span cidx += cell.span
return widths return widths
def render(self, table): def render(self, table):
raise NotImplementedError('not implemented') raise NotImplementedError('not implemented')
def __call__(self, table): def __call__(self, table):
return self.render(table) return self.render(table)
def render_to_file_comment(self): def render_to_file_comment(self):
return '' return ''
def render_to_file(self, path, table):
txt = self.render(table)
with open(path, 'w') as fp:
fp.write(txt)
def render_to_file(self, path, table):
txt = self.render(table)
with open(path, 'w') as fp:
fp.write(txt)
class TerminalRenderer(Renderer):
    def __init__(self, col_sep=' '):
        super().__init__()
        self.col_sep = col_sep

    def render_cell(self, table, row, col, widths):
        cell = table.rows[row].cells[col]
        str = cell.fmt.fmt % cell.data
        str_width = len(str)
        cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
        cell_width += len(self.col_sep) * (cell.span - 1)
        if len(str) > cell_width:
            str = str[:cell_width]
        if cell.fmt.bold:
            # str = sty.ef.bold + str + sty.rs.bold_dim
            # str = sty.ef.bold + str + sty.rs.bold
            pass
        if cell.fmt.fgcolor is not None:
            # color = cell.fmt.fgcolor.as_RGB()
            # str = sty.fg(*color) + str + sty.rs.fg
            pass
        if str_width < cell_width:
            n_ws = (cell_width - str_width)
            if table.get_cell_align(row, col) == 'r':
                str = ' ' * n_ws + str
            elif table.get_cell_align(row, col) == 'l':
                str = str + ' ' * n_ws
            elif table.get_cell_align(row, col) == 'c':
                n_ws1 = n_ws // 2
                n_ws0 = n_ws - n_ws1
                str = ' ' * n_ws0 + str + ' ' * n_ws1
        if cell.fmt.bgcolor is not None:
            # color = cell.fmt.bgcolor.as_RGB()
            # str = sty.bg(*color) + str + sty.rs.bg
            pass
        return str

    def render_separator(self, separator, tab, col_widths, total_width):
        if separator == Separator.HEAD:
            return '=' * total_width
        elif separator == Separator.INNER:
            return '-' * total_width
        elif separator == Separator.BOTTOM:
            return '=' * total_width

    def render(self, table):
        widths = self.col_widths(table)
        total_width = sum(widths) + len(self.col_sep) * (table.n_cols - 1)
        lines = []
        for ridx, row in enumerate(table.rows):
            if row.pre_separator is not None:
                sepline = self.render_separator(row.pre_separator, table, widths, total_width)
                if len(sepline) > 0:
                    lines.append(sepline)
            line = []
            for cidx, cell in enumerate(row.cells):
                line.append(self.render_cell(table, ridx, cidx, widths))
            lines.append(self.col_sep.join(line))
            if row.post_separator is not None:
                sepline = self.render_separator(row.post_separator, table, widths, total_width)
                if len(sepline) > 0:
                    lines.append(sepline)
        return '\n'.join(lines)
class MarkdownRenderer(TerminalRenderer):
    def __init__(self):
        super().__init__(col_sep='|')
        self.printed_color_warning = False

    def print_color_warning(self):
        if not self.printed_color_warning:
            print('[WARNING] MarkdownRenderer does not support color yet')
            self.printed_color_warning = True

    def cell_str_len(self, cell):
        strlen = len(str(cell))
        if cell.fmt.bold:
            strlen += 4
        strlen = max(5, strlen)
        return strlen

    def render_cell(self, table, row, col, widths):
        cell = table.rows[row].cells[col]
        str = cell.fmt.fmt % cell.data
        if cell.fmt.bold:
            str = f'**{str}**'
        str_width = len(str)
        cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
        cell_width += len(self.col_sep) * (cell.span - 1)
        if len(str) > cell_width:
            str = str[:cell_width]
        else:
            n_ws = (cell_width - str_width)
            if table.get_cell_align(row, col) == 'r':
                str = ' ' * n_ws + str
            elif table.get_cell_align(row, col) == 'l':
                str = str + ' ' * n_ws
            elif table.get_cell_align(row, col) == 'c':
                n_ws1 = n_ws // 2
                n_ws0 = n_ws - n_ws1
                str = ' ' * n_ws0 + str + ' ' * n_ws1
        if col == 0: str = self.col_sep + str
        if col == table.n_cols - 1: str += self.col_sep
        if cell.fmt.fgcolor is not None:
            self.print_color_warning()
        if cell.fmt.bgcolor is not None:
            self.print_color_warning()
        return str

    def render_separator(self, separator, tab, widths, total_width):
        sep = ''
        if separator == Separator.INNER:
            sep = self.col_sep
            for idx, width in enumerate(widths):
                csep = '-' * (width - 2)
                if tab.get_cell_align(1, idx) == 'r':
                    csep = '-' + csep + ':'
                elif tab.get_cell_align(1, idx) == 'l':
                    csep = ':' + csep + '-'
                elif tab.get_cell_align(1, idx) == 'c':
                    csep = ':' + csep + ':'
                sep += csep + self.col_sep
        return sep
class LatexRenderer(Renderer):
    def __init__(self):
        super().__init__()

    def render_cell(self, table, row, col):
        cell = table.rows[row].cells[col]
        str = cell.fmt.fmt % cell.data
        if cell.fmt.bold:
            str = '{\\bf ' + str + '}'
        if cell.fmt.fgcolor is not None:
            color = cell.fmt.fgcolor.as_rgb()
            str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}'
        if cell.fmt.bgcolor is not None:
            color = cell.fmt.bgcolor.as_rgb()
            str = f'\\cellcolor[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str
        align = table.get_cell_align(row, col)
        if cell.span != 1 or align != table.aligns[col]:
            str = f'\\multicolumn{{{cell.span}}}{{{align}}}{{{str}}}'
        return str

    def render_separator(self, separator):
        if separator == Separator.HEAD:
            return '\\toprule'
        elif separator == Separator.INNER:
            return '\\midrule'
        elif separator == Separator.BOTTOM:
            return '\\bottomrule'

    def render(self, table):
        lines = ['\\begin{tabular}{' + ''.join(table.aligns) + '}']
        for ridx, row in enumerate(table.rows):
            if row.pre_separator is not None:
                lines.append(self.render_separator(row.pre_separator))
            line = []
            for cidx, cell in enumerate(row.cells):
                line.append(self.render_cell(table, ridx, cidx))
            lines.append(' & '.join(line) + ' \\\\')
            if row.post_separator is not None:
                lines.append(self.render_separator(row.post_separator))
        lines.append('\\end{tabular}')
        return '\n'.join(lines)
class HtmlRenderer(Renderer):
    def __init__(self, html_class='result_table'):
        super().__init__()
        self.html_class = html_class

    def render_cell(self, table, row, col):
        cell = table.rows[row].cells[col]
        str = cell.fmt.fmt % cell.data
        styles = []
        if cell.fmt.bold:
            styles.append('font-weight: bold;')
        if cell.fmt.fgcolor is not None:
            color = cell.fmt.fgcolor.as_RGB()
            styles.append(f'color: rgb({color[0]},{color[1]},{color[2]});')
        if cell.fmt.bgcolor is not None:
            color = cell.fmt.bgcolor.as_RGB()
            styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});')
        align = table.get_cell_align(row, col)
        if align == 'l':
            align = 'left'
        elif align == 'r':
            align = 'right'
        elif align == 'c':
            align = 'center'
        else:
            raise Exception('invalid align')
        styles.append(f'text-align: {align};')
        row = table.rows[row]
        if row.pre_separator is not None:
            styles.append(f'border-top: {self.render_separator(row.pre_separator)};')
        if row.post_separator is not None:
            styles.append(f'border-bottom: {self.render_separator(row.post_separator)};')
        style = ' '.join(styles)
        str = f' <td style="{style}" colspan="{cell.span}">{str}</td>\n'
        return str

    def render_separator(self, separator):
        if separator == Separator.HEAD:
            return '1.5pt solid black'
        elif separator == Separator.INNER:
            return '0.75pt solid black'
        elif separator == Separator.BOTTOM:
            return '1.5pt solid black'

    def render(self, table):
        lines = [f'<table width="100%" style="border-collapse: collapse" class={self.html_class}>']
        for ridx, row in enumerate(table.rows):
            line = [f' <tr>\n']
            for cidx, cell in enumerate(row.cells):
                line.append(self.render_cell(table, ridx, cidx))
            line.append(' </tr>\n')
            lines.append(' '.join(line))
        lines.append('</table>')
        return '\n'.join(lines)
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'),
                    best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
    rnames = data[rowname].unique()
    cnames = data[colname].unique()
    tab = Table(1 + len(cnames))

    header = [Cell('', align='r')]
    header.extend([Cell(h, align='r') for h in cnames])
    header = Row(header, pre_separator=Separator.HEAD, post_separator=Separator.INNER)
    tab.add_row(header)

    for rname in rnames:
        cells = [Cell(rname, align='l')]
        for cname in cnames:
            cdata = data[data[colname] == cname]
            if cname in best_is_max:
                bestval = cdata[valname].max()
                val = cdata[cdata[rowname] == rname][valname].max()
            else:
                bestval = cdata[valname].min()
                val = cdata[cdata[rowname] == rname][valname].min()
            if val == bestval:
                fmt = best_val_cell_fmt
            else:
                fmt = val_cell_fmt
            cells.append(Cell(val, align='r', fmt=fmt))
        tab.add_row(Row(cells))
    tab.rows[-1].post_separator = Separator.BOTTOM
    return tab
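As a rough usage sketch (not part of this file): pandas_to_table expects the data in long format, with one column naming the row, one naming the column, and one holding the value. The DataFrame below and its column names 'method', 'metric', 'val' are made up for illustration.

# Hypothetical data, for illustration only; lowest value per metric is rendered bold by default.
import pandas as pd

df = pd.DataFrame({
    'method': ['ours', 'ours', 'baseline', 'baseline'],
    'metric': ['rmse', 'mae', 'rmse', 'mae'],
    'val':    [0.12, 0.08, 0.20, 0.15],
})
tab = pandas_to_table(rowname='method', colname='metric', valname='val', data=df)
print(TerminalRenderer()(tab))
print(MarkdownRenderer()(tab))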
if __name__ == '__main__':
    # df = pd.read_pickle('full.df')
    # best_is_max = ['movF0.5', 'movF1.0']
    # tab = pandas_to_table(rowname='method', colname='metric', valname='val', data=df, best_is_max=best_is_max)
    # renderer = TerminalRenderer()
    # print(renderer(tab))

    tab = Table(7)
    # header = Row([Cell('header', span=7, align='c')], pre_separator=Separator.HEAD, post_separator=Separator.INNER)
    # tab.add_row(header)
    # header2 = Row([Cell('thisisaverylongheader', span=4, align='c'), Cell('vals2', span=3, align='c')], post_separator=Separator.INNER)
    # tab.add_row(header2)
    tab.add_row(Row([Cell(f'c{c}') for c in range(7)]))
    tab.rows[-1].post_separator = Separator.INNER
    tab.add_block(np.arange(15 * 7).reshape(15, 7))
    tab.rows[4].cells[2].fmt = CellFormat(bold=True)
    tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2, 0.6, 0.1))
    tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7, 0.1, 0.5))
    tab.rows[5].cells[3].fmt = CellFormat(bold=True, bgcolor=Color.rgb(0.7, 0.1, 0.5), fgcolor=Color.rgb(0.1, 0.1, 0.1))
    tab.rows[-1].post_separator = Separator.BOTTOM
    renderer = TerminalRenderer()
    print(renderer(tab))
    renderer = MarkdownRenderer()
    print(renderer(tab))

    # renderer = HtmlRenderer()
    # html_tab = renderer(tab)
    # print(html_tab)
    # with open('test.html', 'w') as fp:
    #     fp.write(html_tab)

    # import latex
    # renderer = LatexRenderer()
    # ltx_tab = renderer(tab)
    # print(ltx_tab)
    # with open('test.tex', 'w') as fp:
    #     latex.write_doc_prefix(fp, document_class='article')
    #     fp.write('this is text that should appear before the table and should be long enough to wrap around.\n'*40)
    #     fp.write('\\begin{table}')
    #     fp.write(ltx_tab)
    #     fp.write('\\end{table}')
    #     fp.write('this is text that should appear after the table and should be long enough to wrap around.\n'*40)
    #     latex.write_doc_suffix(fp)

View File

@ -8,6 +8,7 @@ import re
import pickle
import subprocess


def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
@ -16,71 +17,74 @@ def str2bool(v):
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
class StopWatch(object):
    def __init__(self):
        self.timings = OrderedDict()
        self.starts = {}

    def start(self, name):
        self.starts[name] = time.time()

    def stop(self, name):
        if name not in self.timings:
            self.timings[name] = []
        self.timings[name].append(time.time() - self.starts[name])

    def get(self, name=None, reduce=np.sum):
        if name is not None:
            return reduce(self.timings[name])
        else:
            ret = {}
            for k in self.timings:
                ret[k] = reduce(self.timings[k])
            return ret

    def __repr__(self):
        return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])

    def __str__(self):
        return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
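A minimal usage sketch of StopWatch (the timed work and printed numbers are placeholders):

sw = StopWatch()
sw.start('forward')
# ... timed work goes here ...
sw.stop('forward')
sw.start('forward')
sw.stop('forward')
print(sw.get('forward'))       # sum of both 'forward' intervals (np.sum is the default reduce)
print(sw.get(reduce=np.mean))  # dict with the mean timing per name
print(sw)                      # e.g. "forward: 0.000123[s]"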
class ETA(object):
    def __init__(self, length):
        self.length = length
        self.start_time = time.time()
        self.current_idx = 0
        self.current_time = time.time()

    def update(self, idx):
        self.current_idx = idx
        self.current_time = time.time()

    def get_elapsed_time(self):
        return self.current_time - self.start_time

    def get_item_time(self):
        return self.get_elapsed_time() / (self.current_idx + 1)

    def get_remaining_time(self):
        return self.get_item_time() * (self.length - self.current_idx + 1)

    def format_time(self, seconds):
        minutes, seconds = divmod(seconds, 60)
        hours, minutes = divmod(minutes, 60)
        hours = int(hours)
        minutes = int(minutes)
        return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'

    def get_elapsed_time_str(self):
        return self.format_time(self.get_elapsed_time())

    def get_remaining_time_str(self):
        return self.format_time(self.get_remaining_time())
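And a sketch of how ETA might be used inside a loop (the loop body is a placeholder):

eta = ETA(length=1000)
for idx in range(1000):
    # ... process item idx ...
    eta.update(idx)
    if idx % 100 == 0:
        print(f'elapsed {eta.get_elapsed_time_str()} / remaining {eta.get_remaining_time_str()}')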
def git_hash(cwd=None):
    ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    hash = ret.stdout
    if hash is not None and 'fatal' not in hash.decode():
        return hash.decode().strip()
    else:
        return None
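git_hash shells out to git describe --always and returns None outside a repository, so it can be logged unconditionally, for example:

print('code version:', git_hash() or 'unknown')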

View File

@ -4,107 +4,109 @@ import cv2
def get_patterns(path='syn', imsizes=[], crop=True):
    pattern_size = imsizes[0]
    if path == 'syn':
        np.random.seed(42)
        pattern = np.random.uniform(0, 1, size=pattern_size)
        pattern = (pattern < 0.1).astype(np.float32)
        pattern.reshape(*imsizes[0])
    else:
        pattern = cv2.imread(path)
        pattern = pattern.astype(np.float32)
        pattern /= 255

    if pattern.ndim == 2:
        pattern = np.stack([pattern for idx in range(3)], axis=2)

    if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]:
        r0 = (pattern.shape[0] - pattern_size[0]) // 2
        c0 = (pattern.shape[1] - pattern_size[1]) // 2
        pattern = pattern[r0:r0 + imsizes[0][0], c0:c0 + imsizes[0][1]]

    patterns = []
    for imsize in imsizes:
        pat = cv2.resize(pattern, (imsize[1], imsize[0]), interpolation=cv2.INTER_LINEAR)
        patterns.append(pat)

    return patterns
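A small usage sketch, assuming the image-size pyramid used elsewhere in this repository:

imsizes = [(488 // 2 ** s, 648 // 2 ** s) for s in range(4)]
pats = get_patterns('syn', imsizes)                     # random synthetic dot pattern
# pats = get_patterns('./kinect_pattern.png', imsizes)  # or a real pattern image, center-cropped and rescaled
print([p.shape for p in pats])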
def get_rotation_matrix(v0, v1):
    v0 = v0 / np.linalg.norm(v0)
    v1 = v1 / np.linalg.norm(v1)
    v = np.cross(v0, v1)
    c = np.dot(v0, v1)
    s = np.linalg.norm(v)
    I = np.eye(3)
    vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0)
    k = np.matrix(vXStr)
    r = I + k + k @ k * ((1 - c) / (s ** 2))
    return np.asarray(r.astype(np.float32))
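get_rotation_matrix builds a Rodrigues-style rotation that maps the direction of v0 onto v1; it is undefined for parallel vectors, where the cross product (and hence s) vanishes. A quick sanity check, as a sketch with made-up vectors:

v0 = np.array([0.0, 0.0, 1.0])
v1 = np.array([0.1, -0.2, 0.9])
R = get_rotation_matrix(v0, v1)
# the rotated unit v0 should align with the unit v1 (up to float32 precision)
print(np.allclose(R @ v0, v1 / np.linalg.norm(v1), atol=1e-5))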
def augment_image(img, rng, disp=None, grad=None, max_shift=64, max_blur=1.5, max_noise=10.0, max_sp_noise=0.001):
    # get min/max values of image
    min_val = np.min(img)
    max_val = np.max(img)

    # init augmented image
    img_aug = img

    # init disparity correction map
    disp_aug = disp
    grad_aug = grad

    # apply affine transformation
    if max_shift > 1:

        # affine parameters
        rows, cols = img.shape
        shear = 0
        shift = 0
        shear_correction = 0
        if rng.uniform(0, 1) < 0.75:
            shear = rng.uniform(-max_shift, max_shift)  # shear with 75% probability
        else:
            shift = rng.uniform(0, max_shift)  # shift with 25% probability
        if shear < 0: shear_correction = -shear

        # affine transformation
        a = shear / float(rows)
        b = shift + shear_correction

        # warp image
        T = np.float32([[1, a, b], [0, 1, 0]])
        img_aug = cv2.warpAffine(img_aug, T, (cols, rows))
        if grad is not None:
            grad_aug = cv2.warpAffine(grad, T, (cols, rows))

        # disparity correction map
        col = a * np.array(range(rows)) + b
        disp_delta = np.tile(col, (cols, 1)).transpose()
        if disp is not None:
            disp_aug = cv2.warpAffine(disp + disp_delta, T, (cols, rows))

    # gaussian smoothing
    if rng.uniform(0, 1) < 0.5:
        img_aug = cv2.GaussianBlur(img_aug, (5, 5), rng.uniform(0.2, max_blur))

    # per-pixel gaussian noise
    img_aug = img_aug + rng.randn(*img_aug.shape) * rng.uniform(0.0, max_noise) / 255.0

    # salt-and-pepper noise
    if rng.uniform(0, 1) < 0.5:
        ratio = rng.uniform(0.0, max_sp_noise)
        img_shape = img_aug.shape
        img_aug = img_aug.flatten()
        coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
        img_aug[coord] = max_val
        coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
        img_aug[coord] = min_val
        img_aug = np.reshape(img_aug, img_shape)

    # clip intensities back to [0,1]
    img_aug = np.maximum(img_aug, 0.0)
    img_aug = np.minimum(img_aug, 1.0)

    # return image
    return img_aug, disp_aug, grad_aug
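A minimal sketch of calling augment_image on a synthetic image (all inputs here are made up; the function expects intensities in [0, 1] and a seeded RandomState for reproducibility):

rng = np.random.RandomState(0)
img = rng.uniform(0.0, 1.0, size=(128, 160)).astype(np.float32)
disp = np.zeros_like(img)
grad = np.zeros_like(img)
img_aug, disp_aug, grad_aug = augment_image(img, rng, disp=disp, grad=grad,
                                            max_shift=0, max_blur=1.5, max_noise=10.0, max_sp_noise=0.001)
print(img_aug.min() >= 0.0 and img_aug.max() <= 1.0)  # output is clipped back to [0, 1]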

View File

@ -10,261 +10,259 @@ import cv2
import os
import collections
import sys

sys.path.append('../')
import renderer
import co
from commons import get_patterns, get_rotation_matrix
from lcn import lcn


def get_objs(shapenet_dir, obj_classes, num_perclass=100):
    shapenet = {'chair': '03001627',
                'airplane': '02691156',
                'car': '02958343',
                'watercraft': '04530566'}

    obj_paths = []
    for cls in obj_classes:
        if cls not in shapenet.keys():
            raise Exception('unknown class name')
        ids = shapenet[cls]
        obj_path = sorted(Path(f'{shapenet_dir}/{ids}').glob('**/models/*.obj'))
        obj_paths += obj_path[:num_perclass]
    print(f'found {len(obj_paths)} object paths')

    objs = []
    for obj_path in obj_paths:
        print(f'load {obj_path}')
        v, f, _, n = co.io3d.read_obj(obj_path)
        diffs = v.max(axis=0) - v.min(axis=0)
        v /= (0.5 * diffs.max())
        v -= (v.min(axis=0) + 1)
        f = f.astype(np.int32)
        objs.append((v, f, n))
    print(f'loaded {len(objs)} objects')

    return objs
def get_mesh(rng, min_z=0):
    # set up background board
    verts, faces, normals, colors = [], [], [], []
    v, f, n = co.geometry.xyplane(z=0, interleaved=True)
    v[:, 2] += -v[:, 2].min() + rng.uniform(2, 7)
    v[:, :2] *= 5e2
    v[:, 2] = np.mean(v[:, 2]) + (v[:, 2] - np.mean(v[:, 2])) * 5e2
    c = np.empty_like(v)
    c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
    verts.append(v)
    faces.append(f)
    normals.append(n)
    colors.append(c)

    # randomly sample 4 foreground objects for each scene
    for shape_idx in range(4):
        v, f, n = objs[rng.randint(0, len(objs))]
        v, f, n = v.copy(), f.copy(), n.copy()

        s = rng.uniform(0.25, 1)
        v *= s
        R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng))
        v = v @ R.T
        n = n @ R.T
        v[:, 2] += -v[:, 2].min() + min_z + rng.uniform(0.5, 3)
        v[:, :2] += rng.uniform(-1, 1, size=(1, 2))
        c = np.empty_like(v)
        c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
        verts.append(v.astype(np.float32))
        faces.append(f)
        normals.append(n)
        colors.append(c)

    verts, faces = co.geometry.stack_mesh(verts, faces)
    normals = np.vstack(normals).astype(np.float32)
    colors = np.vstack(colors).astype(np.float32)

    return verts, faces, colors, normals
def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4):
    tic = time.time()
    rng = np.random.RandomState()
    rng.seed(idx)

    verts, faces, colors, normals = get_mesh(rng)
    data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
    print(f'loading mesh for sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')

    # let the camera point to the center
    center = np.array([0, 0, 3], dtype=np.float32)

    basevec = np.array([-baseline, 0, 0], dtype=np.float32)
    unit = np.array([0, 0, 1], dtype=np.float32)

    cam_x_ = rng.uniform(-0.2, 0.2)
    cam_y_ = rng.uniform(-0.2, 0.2)
    cam_z_ = rng.uniform(-0.2, 0.2)

    ret = collections.defaultdict(list)
    blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1, 0.1), 0, 1)

    # capture the same static scene from different view points as a track
    for ind in range(track_length):

        cam_x = cam_x_ + rng.uniform(-0.1, 0.1)
        cam_y = cam_y_ + rng.uniform(-0.1, 0.1)
        cam_z = cam_z_ + rng.uniform(-0.1, 0.1)

        tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32)

        if np.linalg.norm(tcam[0:2]) < 1e-9:
            Rcam = np.eye(3, dtype=np.float32)
        else:
            Rcam = get_rotation_matrix(center, center - tcam)

        tproj = tcam + basevec
        Rproj = Rcam

        ret['R'].append(Rcam)
        ret['t'].append(tcam)

        cams = []
        projs = []

        # render the scene at multiple scales
        scales = [1, 0.5, 0.25, 0.125]

        for scale in scales:
            fx = K[0, 0] * scale
            fy = K[1, 1] * scale
            px = K[0, 2] * scale
            py = K[1, 2] * scale

            im_height = imsize[0] * scale
            im_width = imsize[1] * scale

            cams.append(renderer.PyCamera(fx, fy, px, py, Rcam, tcam, im_width, im_height))
            projs.append(renderer.PyCamera(fx, fy, px, py, Rproj, tproj, im_width, im_height))

        for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns):
            fl = K[0, 0] / (2 ** s)

            shader = renderer.PyShader(0.5, 1.5, 0.0, 10)
            pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu')
            pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35)

            # get the reflected laser pattern $R$
            im = pyrenderer.color().copy()
            depth = pyrenderer.depth().copy()
            disp = baseline * fl / depth
            mask = depth > 0
            im = np.mean(im, axis=2)

            # get the ambient image $A$
            ambient = pyrenderer.normal().copy()
            ambient = np.mean(ambient, axis=2)

            # get the noise free IR image $J$
            im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient
            ret[f'ambient{s}'].append(ambient[None].astype(np.float32))

            # get the gradient magnitude of the ambient image $|\nabla A|$
            ambient = ambient.astype(np.float32)
            sobelx = cv2.Sobel(ambient, cv2.CV_32F, 1, 0, ksize=5)
            sobely = cv2.Sobel(ambient, cv2.CV_32F, 0, 1, ksize=5)
            grad = np.sqrt(sobelx ** 2 + sobely ** 2)
            grad = np.maximum(grad - 0.8, 0.0)  # parameter

            # get the local contrast normalized grad LCN($|\nabla A|$)
            grad_lcn, grad_std = lcn.normalize(grad, 5, 0.1)
            grad_lcn = np.clip(grad_lcn, 0.0, 1.0)  # parameter

            ret[f'grad{s}'].append(grad_lcn[None].astype(np.float32))
            ret[f'im{s}'].append(im[None].astype(np.float32))
            ret[f'mask{s}'].append(mask[None].astype(np.float32))
            ret[f'disp{s}'].append(disp[None].astype(np.float32))

    for key in ret.keys():
        ret[key] = np.stack(ret[key], axis=0)

    # save to files
    out_dir = out_root / f'{idx:08d}'
    out_dir.mkdir(exist_ok=True, parents=True)
    for k, val in ret.items():
        for tidx in range(track_length):
            v = val[tidx]
            out_path = out_dir / f'{k}_{tidx}.npy'
            np.save(out_path, v)

    np.save(str(out_dir / 'blend_im.npy'), blend_im_rnd)

    print(f'create sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')


if __name__ == '__main__':

    np.random.seed(42)

    # output directory
    with open('../config.json') as fp:
        config = json.load(fp)
    data_root = Path(config['DATA_ROOT'])
    shapenet_root = config['SHAPENET_ROOT']

    data_type = 'syn'
    out_root = data_root / f'{data_type}'
    out_root.mkdir(parents=True, exist_ok=True)

    start = 0
    if len(sys.argv) >= 2 and isinstance(sys.argv[2], int):
        start = sys.argv[2]
    else:
        if sys.argv[2] == '--resume':
            try:
                start = max([int(dir) for dir in os.listdir(out_root) if str.isdigit(dir)]) or 0
            except:
                pass

    # load shapenet models
    obj_classes = ['chair']
    objs = get_objs(shapenet_root, obj_classes)

    # camera parameters
    imsize = (488, 648)
    imsizes = [(imsize[0] // (2 ** s), imsize[1] // (2 ** s)) for s in range(4)]
    # K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32)
    K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0, 0, 1]],
                 dtype=np.float32)
    focal_lengths = [K[0, 0] / (2 ** s) for s in range(4)]
    baseline = 0.075
    blend_im = 0.6
    noise = 0

    # capture the same static scene from different view points as a track
    track_length = 4

    # load pattern image
    pattern_path = './kinect_pattern.png'
    pattern_crop = True
    patterns = get_patterns(pattern_path, imsizes, pattern_crop)

    # write settings to file
    settings = {
        'imsizes': imsizes,
        'patterns': patterns,
        'focal_lengths': focal_lengths,
        'baseline': baseline,
        'K': K,
    }
    out_path = out_root / f'settings.pkl'
    print(f'write settings to {out_path}')
    with open(str(out_path), 'wb') as f:
        pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL)

    # start the job
    n_samples = 2 ** 10 + 2 ** 13
    for idx in range(start, n_samples):
        args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length)
        create_data(*args)
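For intuition on the disp = baseline * fl / depth line inside create_data: with the 0.075 m baseline and the full-resolution focal length of roughly 1929.6 px configured above, a point 3 m away projects to a disparity of about 0.075 * 1929.6 / 3 ≈ 48 px, and the halved per-scale focal lengths halve the disparity accordingly. A tiny sketch of that relation, reusing the values above:

baseline = 0.075
fl = 1929.5936336276382
for depth in [1.0, 3.0, 7.0]:
    print(f'depth {depth} m -> disparity {baseline * fl / depth:.1f} px at full resolution')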

View File

@ -21,128 +21,128 @@ from .commons import get_patterns, augment_image
from mpl_toolkits.mplot3d import Axes3D


class TrackSynDataset(torchext.BaseDataset):
    '''
    Load locally saved synthetic dataset
    Please run ./create_syn_data.sh to generate the dataset
    '''

    def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False):
        super().__init__(train=train)

        self.settings_path = settings_path
        self.sample_paths = sample_paths
        self.data_aug = data_aug
        self.train = train
        self.track_length = track_length
        assert (track_length <= 4)

        with open(str(settings_path), 'rb') as f:
            settings = pickle.load(f)
        self.imsizes = settings['imsizes']
        self.patterns = settings['patterns']
        self.focal_lengths = settings['focal_lengths']
        self.baseline = settings['baseline']
        self.K = settings['K']

        self.scale = len(self.imsizes)

        self.max_shift = 0
        self.max_blur = 0.5
        self.max_noise = 3.0
        self.max_sp_noise = 0.0005

    def __len__(self):
        return len(self.sample_paths)

    def __getitem__(self, idx):
        if not self.train:
            rng = self.get_rng(idx)
        else:
            rng = np.random.RandomState()
        sample_path = self.sample_paths[idx]

        if self.train:
            track_ind = np.random.permutation(4)[0:self.track_length]
        else:
            track_ind = [0]

        ret = {}
        ret['id'] = idx

        # load imgs, at all scales
        for sidx in range(len(self.imsizes)):
            imgs = []
            ambs = []
            grads = []
            for tidx in track_ind:
                imgs.append(np.load(os.path.join(sample_path, f'im{sidx}_{tidx}.npy')))
                ambs.append(np.load(os.path.join(sample_path, f'ambient{sidx}_{tidx}.npy')))
                grads.append(np.load(os.path.join(sample_path, f'grad{sidx}_{tidx}.npy')))
            ret[f'im{sidx}'] = np.stack(imgs, axis=0)
            ret[f'ambient{sidx}'] = np.stack(ambs, axis=0)
            ret[f'grad{sidx}'] = np.stack(grads, axis=0)

        # load disp and grad only at full resolution
        disps = []
        R = []
        t = []
        for tidx in track_ind:
            disps.append(np.load(os.path.join(sample_path, f'disp0_{tidx}.npy')))
            R.append(np.load(os.path.join(sample_path, f'R_{tidx}.npy')))
            t.append(np.load(os.path.join(sample_path, f't_{tidx}.npy')))
        ret[f'disp0'] = np.stack(disps, axis=0)
        ret['R'] = np.stack(R, axis=0)
        ret['t'] = np.stack(t, axis=0)

        blend_im = np.load(os.path.join(sample_path, 'blend_im.npy'))
        ret['blend_im'] = blend_im.astype(np.float32)

        #### apply data augmentation at different scales separately, only works for max_shift=0
        if self.data_aug:
            for sidx in range(len(self.imsizes)):
                if sidx == 0:
                    img = ret[f'im{sidx}']
                    disp = ret[f'disp{sidx}']
                    grad = ret[f'grad{sidx}']
                    img_aug = np.zeros_like(img)
                    disp_aug = np.zeros_like(img)
                    grad_aug = np.zeros_like(img)
                    for i in range(img.shape[0]):
                        img_aug_, disp_aug_, grad_aug_ = augment_image(img[i, 0], rng,
                                                                       disp=disp[i, 0], grad=grad[i, 0],
                                                                       max_shift=self.max_shift, max_blur=self.max_blur,
                                                                       max_noise=self.max_noise,
                                                                       max_sp_noise=self.max_sp_noise)
                        img_aug[i] = img_aug_[None].astype(np.float32)
                        disp_aug[i] = disp_aug_[None].astype(np.float32)
                        grad_aug[i] = grad_aug_[None].astype(np.float32)
                    ret[f'im{sidx}'] = img_aug
                    ret[f'disp{sidx}'] = disp_aug
                    ret[f'grad{sidx}'] = grad_aug
                else:
                    img = ret[f'im{sidx}']
                    img_aug = np.zeros_like(img)
                    for i in range(img.shape[0]):
                        img_aug_, _, _ = augment_image(img[i, 0], rng,
                                                       max_shift=self.max_shift, max_blur=self.max_blur,
                                                       max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
                        img_aug[i] = img_aug_[None].astype(np.float32)
                    ret[f'im{sidx}'] = img_aug

        if len(track_ind) == 1:
            for key, val in ret.items():
                if key != 'blend_im' and key != 'id':
                    ret[key] = val[0]

        return ret

    def getK(self, sidx=0):
        K = self.K.copy() / (2 ** sidx)
        K[2, 2] = 1
        return K


if __name__ == '__main__':
    pass
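A hedged sketch of how TrackSynDataset might be instantiated and wrapped in a PyTorch DataLoader; the data_root path below is a placeholder and depends on where create_syn_data.py wrote its output:

from pathlib import Path
from torch.utils.data import DataLoader

data_root = Path('/path/to/DATA_ROOT/syn')  # placeholder location
sample_paths = sorted(str(p) for p in data_root.glob('[0-9]*'))
dataset = TrackSynDataset(data_root / 'settings.pkl', sample_paths,
                          track_length=2, train=True, data_aug=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
sample = dataset[0]
print(sample['im0'].shape, dataset.getK(sidx=0))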

View File

@ -2,7 +2,7 @@
(lcn.html — the Cython 0.29 annotation page generated from lcn.pyx; this hunk only reflows the generated HTML markup and C glue code. The annotated source it documents is:)

import numpy as np
cimport cython

# use c square root function
cdef extern from "math.h":
    float sqrt(float x)

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)

# 3 parameters:
# - float image
# - kernel size (actually this is the radius, kernel is 2*k+1)
# - small constant epsilon that is used to avoid division by zero
def normalize(float[:, :] img, int kernel_size = 4, float epsilon = 0.01):
    ...
class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
if (values[1]) { if (values[1]) {
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) &amp;&amp; <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span> __pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) &amp;&amp; <span
class='py_c_api'>PyErr_Occurred</span>())) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} else { } else {
__pyx_v_kernel_size = ((int)4); __pyx_v_kernel_size = ((int)4);
} }
if (values[2]) { if (values[2]) {
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) &amp;&amp; <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span> __pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) &amp;&amp; <span
class='py_c_api'>PyErr_Occurred</span>())) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} else { } else {
__pyx_v_epsilon = ((float)0.01); __pyx_v_epsilon = ((float)0.01);
} }
} }
goto __pyx_L4_argument_unpacking_done; goto __pyx_L4_argument_unpacking_done;
__pyx_L5_argtuple_error:; __pyx_L5_argtuple_error:;
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span> <span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
__pyx_L3_error:; __pyx_L3_error:;
<span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename); <span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>(); <span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
@ -515,27 +545,49 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
return __pyx_r; return __pyx_r;
} }
/* … */ /* … */
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span> __pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span
class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19);
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19); <span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19);
/* … */ /* … */
__pyx_t_1 = PyCFunction_NewEx(&amp;__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span> __pyx_t_1 = PyCFunction_NewEx(&amp;__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span> __pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span
</pre><pre class="cython line score-0">&#xA0;<span class="">17</span>: </pre> class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<pre class="cython line score-0">&#xA0;<span class="">18</span>: <span class="c"># image dimensions</span></pre> </pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre> <pre class="cython line score-0">&#xA0;<span class="">17</span>: </pre>
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]); <pre class="cython line score-0">&#xA0;<span class="">18</span>: <span class="c"># image dimensions</span></pre>
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre> <pre class="cython line score-0"
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]); onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
</pre><pre class="cython line score-0">&#xA0;<span class="">21</span>: </pre> class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
<pre class="cython line score-0">&#xA0;<span class="">22</span>: <span class="c"># create outputs and output views</span></pre> class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
<pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre> class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span> <pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">21</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
<pre class="cython line score-46"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span
class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
@ -559,22 +611,34 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span> __pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
__pyx_v_img_lcn = __pyx_t_5; __pyx_v_img_lcn = __pyx_t_5;
__pyx_t_5 = 0; __pyx_t_5 = 0;
</pre><pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre> </pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span> <pre class="cython line score-46"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span
class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
__pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
@ -598,114 +662,236 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span> if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0; <span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
__pyx_v_img_std = __pyx_t_1; __pyx_v_img_std = __pyx_t_1;
__pyx_t_1 = 0; __pyx_t_1 = 0;
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span class="n">img_lcn</span></pre> </pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span> <pre class="cython line score-2"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span
class="n">img_lcn</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
__pyx_v_img_lcn_view = __pyx_t_6; __pyx_v_img_lcn_view = __pyx_t_6;
__pyx_t_6.memview = NULL; __pyx_t_6.memview = NULL;
__pyx_t_6.data = NULL; __pyx_t_6.data = NULL;
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span class="n">img_std</span></pre> </pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span> <pre class="cython line score-2"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span
class="n">img_std</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
__pyx_v_img_std_view = __pyx_t_6; __pyx_v_img_std_view = __pyx_t_6;
__pyx_t_6.memview = NULL; __pyx_t_6.memview = NULL;
__pyx_t_6.data = NULL; __pyx_t_6.data = NULL;
</pre><pre class="cython line score-0">&#xA0;<span class="">27</span>: </pre> </pre>
<pre class="cython line score-0">&#xA0;<span class="">28</span>: <span class="c"># temporary c variables</span></pre> <pre class="cython line score-0">&#xA0;<span class="">27</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span class="nf">stddev</span></pre> <pre class="cython line score-0">&#xA0;<span class="">28</span>: <span class="c"># temporary c variables</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre> <pre class="cython line score-0">&#xA0;<span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre> class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size; class="nf">stddev</span></pre>
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre> <pre class="cython line score-0">&#xA0;<span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon; class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span class="o">**</span><span class="mf">2</span></pre> class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2); <pre class="cython line score-0"
</pre><pre class="cython line score-0">&#xA0;<span class="">34</span>: </pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class="cython line score-0">&#xA0;<span class="">35</span>: <span class="c"># for all pixels do</span></pre> class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre> class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks); <pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span
class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span
class="o">**</span><span class="mf">2</span></pre>
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">34</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">35</span>: <span
class="c"># for all pixels do</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span
class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span
class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
__pyx_t_8 = __pyx_t_7; __pyx_t_8 = __pyx_t_7;
for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 &lt; __pyx_t_8; __pyx_t_9+=1) { for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 &lt; __pyx_t_8; __pyx_t_9+=1) {
__pyx_v_m = __pyx_t_9; __pyx_v_m = __pyx_t_9;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span
class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
__pyx_t_11 = __pyx_t_10; __pyx_t_11 = __pyx_t_10;
for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 &lt; __pyx_t_11; __pyx_t_12+=1) { for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 &lt; __pyx_t_11; __pyx_t_12+=1) {
__pyx_v_n = __pyx_t_12; __pyx_v_n = __pyx_t_12;
</pre><pre class="cython line score-0">&#xA0;<span class="">38</span>: </pre> </pre>
<pre class="cython line score-0">&#xA0;<span class="">39</span>: <span class="c"># calculate mean</span></pre> <pre class="cython line score-0">&#xA0;<span class="">38</span>: </pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre> <pre class="cython line score-0">&#xA0;<span class="">39</span>: <span class="c"># calculate mean</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = 0.0; <pre class="cython line score-0"
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1); class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span
class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
__pyx_t_14 = __pyx_t_13; __pyx_t_14 = __pyx_t_13;
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) { for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) {
__pyx_v_i = __pyx_t_15; __pyx_v_i = __pyx_t_15;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
__pyx_t_17 = __pyx_t_16; __pyx_t_17 = __pyx_t_16;
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) { for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) {
__pyx_v_j = __pyx_t_18; __pyx_v_j = __pyx_t_18;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span
class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
__pyx_t_20 = (__pyx_v_n + __pyx_v_j); __pyx_t_20 = (__pyx_v_n + __pyx_v_j);
__pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) )))); __pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) ))));
} }
} }
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span class="n">mean</span><span class="o">/</span><span class="n">num</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num); <pre class="cython line score-0"
</pre><pre class="cython line score-0">&#xA0;<span class="">45</span>: </pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class="cython line score-0">&#xA0;<span class="">46</span>: <span class="c"># calculate std dev</span></pre> class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre> class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0; <pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1); <pre class="cython line score-0">&#xA0;<span class="">45</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">46</span>: <span
class="c"># calculate std dev</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
__pyx_t_14 = __pyx_t_13; __pyx_t_14 = __pyx_t_13;
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) { for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) {
__pyx_v_i = __pyx_t_15; __pyx_v_i = __pyx_t_15;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
__pyx_t_17 = __pyx_t_16; __pyx_t_17 = __pyx_t_16;
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) { for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) {
__pyx_v_j = __pyx_t_18; __pyx_v_j = __pyx_t_18;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i); <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span
class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span
class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span
class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span
class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
__pyx_t_22 = (__pyx_v_n + __pyx_v_j); __pyx_t_22 = (__pyx_v_n + __pyx_v_j);
__pyx_t_23 = (__pyx_v_m + __pyx_v_i); __pyx_t_23 = (__pyx_v_m + __pyx_v_i);
__pyx_t_24 = (__pyx_v_n + __pyx_v_j); __pyx_t_24 = (__pyx_v_n + __pyx_v_j);
__pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean))); __pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean)));
} }
} }
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span class="n">num</span><span class="p">)</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num)); <pre class="cython line score-0"
</pre><pre class="cython line score-0">&#xA0;<span class="">52</span>: </pre> onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
<pre class="cython line score-0">&#xA0;<span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre> class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span class="o">+</span><span class="n">eps</span><span class="p">)</span></pre> class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m; class="n">num</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
</pre>
<pre class="cython line score-0">&#xA0;<span class="">52</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span
class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span
class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
__pyx_t_26 = __pyx_v_n; __pyx_t_26 = __pyx_v_n;
__pyx_t_27 = __pyx_v_m; __pyx_t_27 = __pyx_v_m;
__pyx_t_28 = __pyx_v_n; __pyx_t_28 = __pyx_v_n;
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps)); *((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps));
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">stddev</span></pre> </pre>
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m; <pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
class="n">stddev</span></pre>
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
__pyx_t_30 = __pyx_v_n; __pyx_t_30 = __pyx_v_n;
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev; *((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev;
} }
} }
</pre><pre class="cython line score-0">&#xA0;<span class="">56</span>: </pre> </pre>
<pre class="cython line score-0">&#xA0;<span class="">57</span>: <span class="c"># return both</span></pre> <pre class="cython line score-0">&#xA0;<span class="">56</span>: </pre>
<pre class="cython line score-10" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span class="n">img_std</span></pre> <pre class="cython line score-0">&#xA0;<span class="">57</span>: <span class="c"># return both</span></pre>
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r); <pre class="cython line score-10"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span
class="n">img_std</span></pre>
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
__pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span> __pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1); <span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_lcn); <span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_lcn);
@ -717,4 +903,7 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
__pyx_r = __pyx_t_1; __pyx_r = __pyx_t_1;
__pyx_t_1 = 0; __pyx_t_1 = 0;
goto __pyx_L0; goto __pyx_L0;
</pre></div></body></html> </pre>
</div>
</body>
</html>
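What this kernel computes is a plain local contrast normalization: per pixel, subtract the mean of a (2*kernel_size+1)² window and divide by that window's standard deviation plus epsilon. Below is a rough NumPy/SciPy sketch of the same computation, not part of the repository; note the border handling differs (the Cython loop leaves a kernel_size-wide margin untouched, while uniform_filter reflects at the image border).

    import numpy as np
    from scipy.ndimage import uniform_filter

    def normalize_np(img, kernel_size=4, epsilon=0.01):
        # box-filter mean and mean of squares over a (2*kernel_size+1)^2 window
        size = 2 * kernel_size + 1
        mean = uniform_filter(img, size=size)
        sq_mean = uniform_filter(img * img, size=size)
        # population std dev via E[x^2] - E[x]^2, clipped against tiny negatives
        stddev = np.sqrt(np.maximum(sq_mean - mean * mean, 0.0))
        img_lcn = (img - mean) / (stddev + epsilon)
        return img_lcn.astype(np.float32), stddev.astype(np.float32)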
View File
@ -2,5 +2,5 @@ from distutils.core import setup
from Cython.Build import cythonize from Cython.Build import cythonize
setup( setup(
ext_modules = cythonize("lcn.pyx",annotate=True) ext_modules=cythonize("lcn.pyx", annotate=True)
) )
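For context, a minimal sketch of how this extension is typically built and called; the build command and the sample array are assumptions, not part of this diff. Passing annotate=True is what produces the lcn.html shown above.

    # assumed in-place build step; regenerates lcn.html since annotate=True:
    #     python setup.py build_ext --inplace
    import numpy as np
    import lcn

    img = np.random.rand(64, 96).astype(np.float32)   # any 2-D float32 image
    img_lcn, img_std = lcn.normalize(img, kernel_size=4, epsilon=0.01)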
View File
@ -5,43 +5,43 @@ from scipy import misc
# load and convert to float # load and convert to float
img = misc.imread('img.png') img = misc.imread('img.png')
img = img.astype(np.float32)/255.0 img = img.astype(np.float32) / 255.0
# normalize # normalize
img_lcn, img_std = lcn.normalize(img,5,0.05) img_lcn, img_std = lcn.normalize(img, 5, 0.05)
# normalize to reasonable range between 0 and 1 # normalize to reasonable range between 0 and 1
#img_lcn = img_lcn/3.0 # img_lcn = img_lcn/3.0
#img_lcn = np.maximum(img_lcn,0.0) # img_lcn = np.maximum(img_lcn,0.0)
#img_lcn = np.minimum(img_lcn,1.0) # img_lcn = np.minimum(img_lcn,1.0)
# save to file # save to file
#misc.imsave('lcn2.png',img_lcn) # misc.imsave('lcn2.png',img_lcn)
print ("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \ print("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \
(img.shape[0], img.shape[1], img.dtype, img.min(), img.max())) (img.shape[0], img.shape[1], img.dtype, img.min(), img.max()))
print ("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \ print("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \
(img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max())) (img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max()))
# plot original image # plot original image
plt.figure(1) plt.figure(1)
img_plot = plt.imshow(img) img_plot = plt.imshow(img)
img_plot.set_cmap('gray') img_plot.set_cmap('gray')
plt.clim(0, 1) # fix range plt.clim(0, 1) # fix range
plt.tight_layout() plt.tight_layout()
# plot normalized image # plot normalized image
plt.figure(2) plt.figure(2)
img_lcn_plot = plt.imshow(img_lcn) img_lcn_plot = plt.imshow(img_lcn)
img_lcn_plot.set_cmap('gray') img_lcn_plot.set_cmap('gray')
#plt.clim(0, 1) # fix range # plt.clim(0, 1) # fix range
plt.tight_layout() plt.tight_layout()
# plot stddev image # plot stddev image
plt.figure(3) plt.figure(3)
img_std_plot = plt.imshow(img_std) img_std_plot = plt.imshow(img_std)
img_std_plot.set_cmap('gray') img_std_plot.set_cmap('gray')
#plt.clim(0, 0.1) # fix range # plt.clim(0, 0.1) # fix range
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()
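One caveat when rerunning this script: scipy.misc.imread was deprecated and removed in newer SciPy releases. A drop-in replacement using imageio (an assumed extra dependency, not declared in this repository) would be:

    import imageio
    import numpy as np

    img = imageio.imread('img.png')
    img = np.asarray(img, dtype=np.float32) / 255.0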
View File
@ -11,28 +11,27 @@ import dataset
def get_data(n, row_from, row_to, train): def get_data(n, row_from, row_to, train):
imsizes = [(256,384)] imsizes = [(256, 384)]
focal_lengths = [160] focal_lengths = [160]
dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train) dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)
ims = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.uint8) ims = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.uint8)
disps = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.float32) disps = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.float32)
for idx in range(n): for idx in range(n):
print(f'load sample {idx} train={train}') print(f'load sample {idx} train={train}')
sample = dset[idx] sample = dset[idx]
ims[idx] = (sample['im0'][0,row_from:row_to] * 255).astype(np.uint8) ims[idx] = (sample['im0'][0, row_from:row_to] * 255).astype(np.uint8)
disps[idx] = sample['disp0'][0,row_from:row_to] disps[idx] = sample['disp0'][0, row_from:row_to]
return ims, disps return ims, disps
params = hd.TrainParams( params = hd.TrainParams(
n_trees=4, n_trees=4,
max_tree_depth=, max_tree_depth=,
n_test_split_functions=50, n_test_split_functions=50,
n_test_thresholds=10, n_test_thresholds=10,
n_test_samples=4096, n_test_samples=4096,
min_samples_to_split=16, min_samples_to_split=16,
min_samples_for_leaf=8) min_samples_for_leaf=8)
n_disp_bins = 20 n_disp_bins = 20
depth_switch = 0 depth_switch = 0
@ -45,21 +44,23 @@ n_test_samples = 32
train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True) train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False) test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)
for tree_depth in [8,10,12,14,16]: for tree_depth in [8, 10, 12, 14, 16]:
depth_switch = tree_depth - 4 depth_switch = tree_depth - 4
prefix = f'td{tree_depth}_ds{depth_switch}' prefix = f'td{tree_depth}_ds{depth_switch}'
prefix = Path(f'./forests/{prefix}/') prefix = Path(f'./forests/{prefix}/')
prefix.mkdir(parents=True, exist_ok=True) prefix.mkdir(parents=True, exist_ok=True)
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr')) hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr')) es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
np.save(str(prefix / 'ta.npy'), test_disps) np.save(str(prefix / 'ta.npy'), test_disps)
np.save(str(prefix / 'es.npy'), es) np.save(str(prefix / 'es.npy'), es)
# plt.figure(); # plt.figure();
# plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4); # plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4);
# plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4); # plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4);
# plt.show() # plt.show()
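
The sweep above stores the ground-truth disparities (ta.npy) and the forest estimates (es.npy) for every tree depth. A small follow-up sketch, assuming the files written by the loop above, that summarizes the mean absolute disparity error and the >1px outlier fraction per configuration:

import numpy as np
from pathlib import Path

for tree_depth in [8, 10, 12, 14, 16]:
    depth_switch = tree_depth - 4
    prefix = Path(f'./forests/td{tree_depth}_ds{depth_switch}')
    ta = np.load(prefix / 'ta.npy')   # ground-truth disparities
    es = np.load(prefix / 'es.npy')   # forest estimates
    err = np.abs(es - ta)
    print(f'td={tree_depth:2d} ds={depth_switch:2d}  '
          f'MAE={err.mean():.4f}  >1px={(err > 1).mean() * 100:.2f}%')
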

View File

@ -8,7 +8,6 @@ import os
this_dir = os.path.dirname(__file__) this_dir = os.path.dirname(__file__)
extra_compile_args = ['-O3', '-std=c++11'] extra_compile_args = ['-O3', '-std=c++11']
extra_link_args = [] extra_link_args = []
@ -22,24 +21,20 @@ library_dirs = []
libraries = ['m'] libraries = ['m']
setup( setup(
name="hyperdepth", name="hyperdepth",
cmdclass= {'build_ext': build_ext}, cmdclass={'build_ext': build_ext},
ext_modules=[ ext_modules=[
Extension('hyperdepth', Extension('hyperdepth',
sources, sources,
extra_objects=extra_objects, extra_objects=extra_objects,
language='c++', language='c++',
library_dirs=library_dirs, library_dirs=library_dirs,
libraries=libraries, libraries=libraries,
include_dirs=[ include_dirs=[
np.get_include(), np.get_include(),
], ],
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args extra_link_args=extra_link_args
) )
] ]
) )
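
After building the extension declared above in place (typically python setup.py build_ext --inplace), a quick import sanity check might look like the following sketch; it is not part of the repository:

import hyperdepth

print('loaded from:', hyperdepth.__file__)
print('public names:', [n for n in dir(hyperdepth) if not n.startswith('_')])
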

View File

@ -6,10 +6,13 @@ orig = cv2.imread('disp_orig.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32) ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32) es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
plt.figure() plt.figure()
plt.subplot(2,2,1); plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma') plt.subplot(2, 2, 1);
plt.subplot(2,2,2); plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma') plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,3); plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma') plt.subplot(2, 2, 2);
plt.subplot(2,2,4); plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma') plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 3);
plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 4);
plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
plt.show() plt.show()
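
The figure above compares the original, ground-truth (ta) and estimated (es) disparity maps, which are apparently stored as 16x fixed-point PNGs (hence the division by 16). A short numeric summary that continues from the arrays loaded above, under that same assumption:

err = np.abs(es - ta) / 16.0                       # back to pixel units
print(f'MAE: {err.mean():.4f} px')
print(f'outliers > 1 px: {(err > 1.0).mean() * 100:.2f}%')
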

View File

@ -12,226 +12,263 @@ import torchext
from model import networks from model import networks
from data import dataset from data import dataset
class Worker(torchext.Worker): class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs): def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs) super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
self.ms = args.ms self.ms = args.ms
self.pattern_path = args.pattern_path self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight self.dp_weight = args.dp_weight
self.data_type = args.data_type self.data_type = args.data_type
self.imsizes = [(480,640)] self.imsizes = [(480, 640)]
for iter in range(3): for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2))) self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp: with open('config.json') as fp:
config = json.load(fp) config = json.load(fp)
data_root = Path(config['DATA_ROOT']) data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl' self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/')) sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2**10:] self.train_paths = sample_paths[2 ** 10:]
self.test_paths = sample_paths[:2**8] self.test_paths = sample_paths[:2 ** 8]
# supervise the edge encoder with only 2**8 samples # supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2**8 self.train_edge = len(self.train_paths) - 2 ** 8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05) self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss() self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device)) self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
# evaluate in the region where opencv Block Matching has valid values # evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0]) self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1 self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_mask = self.eval_mask.astype(np.bool) self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0]-2*13 self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1]-13-140 self.eval_w = self.imsizes[0][1] - 13 - 140
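
As a worked example of the evaluation crop above: for the (480, 640) base resolution the mask keeps rows 13..466 and columns 140..626, so eval_h = 480 - 2*13 = 454 and eval_w = 640 - 140 - 13 = 487. A tiny sketch mirroring that arithmetic:

import numpy as np

h, w = 480, 640
eval_mask = np.zeros((h, w), dtype=bool)
eval_mask[13:h - 13, 140:w - 13] = True
assert eval_mask.sum() == (h - 2 * 13) * (w - 140 - 13)   # 454 * 487 valid pixels
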
def get_train_set(self): def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=1) train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=1)
return train_set return train_set
def get_test_sets(self): def get_test_sets(self):
test_sets = torchext.TestSets() test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1) test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
test_sets.append('simple', test_set, test_frequency=1) track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
# initialize photometric loss modules according to image sizes # initialize photometric loss modules according to image sizes
self.losses = [] self.losses = []
for imsize, pat in zip(test_set.imsizes, test_set.patterns): for imsize, pat in zip(test_set.imsizes, test_set.patterns):
pat = pat.mean(axis=2) pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32)) pat = torch.from_numpy(pat[None][None].astype(np.float32))
pat = pat.to(self.train_device) pat = pat.to(self.train_device)
self.lcn_in = self.lcn_in.to(self.train_device) self.lcn_in = self.lcn_in.to(self.train_device)
pat,_ = self.lcn_in(pat) pat, _ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1) pat = torch.cat([pat for idx in range(3)], dim=1)
self.losses.append( networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) ) self.losses.append(networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat))
return test_sets return test_sets
def copy_data(self, data, device, requires_grad, train): def copy_data(self, data, device, requires_grad, train):
self.lcn_in = self.lcn_in.to(device) self.lcn_in = self.lcn_in.to(device)
self.data = {} self.data = {}
for key, val in data.items(): for key, val in data.items():
grad = 'im' in key and requires_grad grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad) self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
# apply lcn to IR input # apply lcn to IR input
# concatenate the normalized IR input and the original IR image # concatenate the normalized IR input and the original IR image
if 'im' in key and 'blend' not in key: if 'im' in key and 'blend' not in key:
im = self.data[key] im = self.data[key]
im_lcn,im_std = self.lcn_in(im) im_lcn, im_std = self.lcn_in(im)
im_cat = torch.cat((im_lcn, im), dim=1) im_cat = torch.cat((im_lcn, im), dim=1)
key_std = key.replace('im','std') key_std = key.replace('im', 'std')
self.data[key]=im_cat self.data[key] = im_cat
self.data[key_std] = im_std.to(device).detach() self.data[key_std] = im_std.to(device).detach()
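
To make the tensor bookkeeping in copy_data above concrete: an IR batch of shape (B, 1, H, W) is passed through the LCN module, the normalized result is concatenated with the raw image along the channel axis, so the network input 'im0' ends up with two channels, and the per-pixel standard deviation is kept separately under 'std0'. A minimal shape sketch with stand-ins for the LCN outputs:

import torch

im = torch.rand(8, 1, 480, 640)                              # raw IR input, B x 1 x H x W
im_lcn, im_std = torch.randn_like(im), torch.rand_like(im)   # stand-ins for lcn_in(im)
im_cat = torch.cat((im_lcn, im), dim=1)                      # channel 0 = LCN, channel 1 = raw
assert im_cat.shape == (8, 2, 480, 640) and im_std.shape == (8, 1, 480, 640)
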
def net_forward(self, net, train): def net_forward(self, net, train):
out = net(self.data['im0']) out = net(self.data['im0'])
return out return out
def loss_forward(self, out, train): def loss_forward(self, out, train):
out, edge = out out, edge = out
if not(isinstance(out, tuple) or isinstance(out, list)): if not (isinstance(out, tuple) or isinstance(out, list)):
out = [out] out = [out]
if not(isinstance(edge, tuple) or isinstance(edge, list)): if not (isinstance(edge, tuple) or isinstance(edge, list)):
edge = [edge] edge = [edge]
vals = [] vals = []
# apply photometric loss # apply photometric loss
for s,l,o in zip(itertools.count(), self.losses, out): for s, l, o in zip(itertools.count(), self.losses, out):
val, pattern_proj = l(o, self.data[f'im{s}'][:,0:1,...], self.data[f'std{s}']) val, pattern_proj = l(o, self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
if s == 0: if s == 0:
self.pattern_proj = pattern_proj.detach() self.pattern_proj = pattern_proj.detach()
vals.append(val) vals.append(val)
# apply disparity loss # apply disparity loss
# the edge maps are inverted (0 means edge), so use 1 - sigmoid(edge) # the edge maps are inverted (0 means edge), so use 1 - sigmoid(edge)
edge0 = 1-torch.sigmoid(edge[0]) edge0 = 1 - torch.sigmoid(edge[0])
val = self.disparity_loss(out[0], edge0) val = self.disparity_loss(out[0], edge0)
if self.dp_weight>0: if self.dp_weight > 0:
vals.append(val * self.dp_weight) vals.append(val * self.dp_weight)
# apply edge loss on a subset of training samples # apply edge loss on a subset of training samples
for s,e in zip(itertools.count(), edge): for s, e in zip(itertools.count(), edge):
# inverted ground-truth edge where 0 means edge # inverted ground-truth edge where 0 means edge
grad = self.data[f'grad{s}']<0.2 grad = self.data[f'grad{s}'] < 0.2
grad = grad.to(torch.float32) grad = grad.to(torch.float32)
ids = self.data['id'] ids = self.data['id']
mask = ids>self.train_edge mask = ids > self.train_edge
if mask.sum()>0: if mask.sum() > 0:
val = self.edge_loss(e[mask], grad[mask]) val = self.edge_loss(e[mask], grad[mask])
else: else:
val = torch.zeros_like(vals[0]) val = torch.zeros_like(vals[0])
if s == 0: if s == 0:
self.edge = e.detach() self.edge = e.detach()
self.edge = torch.sigmoid(self.edge) self.edge = torch.sigmoid(self.edge)
self.edge_gt = grad.detach() self.edge_gt = grad.detach()
vals.append(val) vals.append(val)
return vals return vals
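
One detail of loss_forward above worth spelling out: when no sample in the batch is selected for edge supervision (mask.sum() == 0), a zero tensor shaped like the first loss term is appended instead, so the list handed back to the training framework always has the same length. A small stand-alone sketch of that pattern:

import torch

vals = [torch.tensor(0.123)]               # e.g. the photometric term
mask = torch.zeros(8, dtype=torch.bool)    # no sample qualifies for edge supervision
if mask.sum() > 0:
    val = torch.tensor(1.0)                # placeholder for the real edge loss
else:
    val = torch.zeros_like(vals[0])        # keeps the number of loss terms constant
vals.append(val)
print(len(vals), [float(v) for v in vals])
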
def numpy_in_out(self, output): def numpy_in_out(self, output):
output, edge = output output, edge = output
if not(isinstance(output, tuple) or isinstance(output, list)): if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output] output = [output]
es = output[0].detach().to('cpu').numpy() es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32) gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,0:1,...].detach().to('cpu').numpy() im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy()
ma = gt>0 ma = gt > 0
return es, gt, im, ma return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma): def write_img(self, out_path, es, gt, im, ma):
logging.info(f'write img {out_path}') logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0])) u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
diff = np.abs(es - gt) diff = np.abs(es - gt)
vmin, vmax = np.nanmin(gt), np.nanmax(gt) vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2*(vmax-vmin) vmin = vmin - 0.2 * (vmax - vmin)
vmax = vmax + 0.2*(vmax-vmin) vmax = vmax + 0.2 * (vmax - vmin)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0] pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0] im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0]
pattern_diff = np.abs(im_orig - pattern_proj) pattern_diff = np.abs(im_orig - pattern_proj)
fig = plt.figure(figsize=(16, 16))
es_ = co.cmap.color_depth_map(es, scale=vmax)
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
diff_ = co.cmap.color_error_image(diff, BGR=True)
fig = plt.figure(figsize=(16,16)) # plot disparities, ground truth disparity is shown only for reference
es_ = co.cmap.color_depth_map(es, scale=vmax) ax = plt.subplot(3, 3, 1)
gt_ = co.cmap.color_depth_map(gt, scale=vmax) plt.imshow(es_[..., [2, 1, 0]])
diff_ = co.cmap.color_error_image(diff, BGR=True) plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
ax = plt.subplot(3, 3, 2)
plt.imshow(gt_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
ax = plt.subplot(3, 3, 3)
plt.imshow(diff_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Err. {diff.mean():.5f}')
# plot disparities, ground truth disparity is shown only for reference # plot edges
ax = plt.subplot(3,3,1); plt.imshow(es_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}') edge = self.edge.to('cpu').numpy()[0, 0]
ax = plt.subplot(3,3,2); plt.imshow(gt_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}') edge_gt = self.edge_gt.to('cpu').numpy()[0, 0]
ax = plt.subplot(3,3,3); plt.imshow(diff_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Err. {diff.mean():.5f}') edge_err = np.abs(edge - edge_gt)
ax = plt.subplot(3, 3, 4);
plt.imshow(edge, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
ax = plt.subplot(3, 3, 5);
plt.imshow(edge_gt, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
ax = plt.subplot(3, 3, 6);
plt.imshow(edge_err, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
# plot edges # plot normalized IR input and warped pattern
edge = self.edge.to('cpu').numpy()[0,0] ax = plt.subplot(3, 3, 7);
edge_gt = self.edge_gt.to('cpu').numpy()[0,0] plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray');
edge_err = np.abs(edge - edge_gt) plt.xticks([]);
ax = plt.subplot(3,3,4); plt.imshow(edge, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}') plt.yticks([]);
ax = plt.subplot(3,3,5); plt.imshow(edge_gt, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}') ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
ax = plt.subplot(3,3,6); plt.imshow(edge_err, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Err. {edge_err.mean():.5f}') ax = plt.subplot(3, 3, 8);
plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
im_std = self.data['std0'].to('cpu').numpy()[0, 0]
ax = plt.subplot(3, 3, 9);
plt.imshow(im_std, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
# plot normalized IR input and warped pattern plt.tight_layout()
ax = plt.subplot(3,3,7); plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}') plt.savefig(str(out_path))
ax = plt.subplot(3,3,8); plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}') plt.close(fig)
im_std = self.data['std0'].to('cpu').numpy()[0,0]
ax = plt.subplot(3,3,9); plt.imshow(im_std, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
plt.tight_layout() def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
plt.savefig(str(out_path)) if batch_idx % 512 == 0:
plt.close(fig) out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]): def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
if batch_idx % 512 == 0: es, gt, im, ma = self.numpy_in_out(output)
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
def callback_test_start(self, epoch, set_idx): es, gt, im, ma = self.crop_output(es, gt, im, ma)
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]): es = es.reshape(-1, 1)
es, gt, im, ma = self.numpy_in_out(output) gt = gt.reshape(-1, 1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
if batch_idx % 8 == 0: def callback_test_stop(self, epoch, set_idx, loss):
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png' logging.info(f'{self.metric}')
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0]) for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1,1)
gt = gt.reshape(-1,1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}')
for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
def crop_output(self, es, gt, im, ma):
bs = es.shape[0]
es = np.reshape(es[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
def crop_output(self, es, gt, im, ma):
bs = es.shape[0]
es = np.reshape(es[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
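
crop_output above relies on eval_mask being a single axis-aligned rectangle: boolean indexing flattens the selected pixels in row-major order, so they can be reshaped back into an (eval_h, eval_w) image. A minimal sketch of that mechanism on a small array:

import numpy as np

a = np.arange(2 * 1 * 6 * 8).reshape(2, 1, 6, 8)   # bs x 1 x H x W
mask = np.zeros((6, 8), dtype=bool)
mask[1:5, 2:7] = True                              # rectangular eval region
crop = np.reshape(a[:, :, mask], [2, 1, 4, 5])     # valid because the mask is a rectangle
assert np.array_equal(crop[0, 0], a[0, 0, 1:5, 2:7])
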
if __name__ == '__main__': if __name__ == '__main__':
pass pass

View File

@ -12,287 +12,324 @@ import torchext
from model import networks from model import networks
from data import dataset from data import dataset
class Worker(torchext.Worker): class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs): def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs) super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
self.ms = args.ms self.ms = args.ms
self.pattern_path = args.pattern_path self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight self.dp_weight = args.dp_weight
self.ge_weight = args.ge_weight self.ge_weight = args.ge_weight
self.track_length = args.track_length self.track_length = args.track_length
self.data_type = args.data_type self.data_type = args.data_type
assert(self.track_length>1) assert (self.track_length > 1)
self.imsizes = [(480,640)] self.imsizes = [(480, 640)]
for iter in range(3): for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2))) self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp: with open('config.json') as fp:
config = json.load(fp) config = json.load(fp)
data_root = Path(config['DATA_ROOT']) data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl' self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/')) sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2**10:] self.train_paths = sample_paths[2 ** 10:]
self.test_paths = sample_paths[:2**8] self.test_paths = sample_paths[:2 ** 8]
# supervise the edge encoder with only 2**8 samples # supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2**8 self.train_edge = len(self.train_paths) - 2 ** 8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05) self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss() self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device)) self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
# evaluate in the region where opencv Block Matching has valid values # evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0]) self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1 self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_mask = self.eval_mask.astype(np.bool) self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0]-2*13 self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1]-13-140 self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=self.track_length)
return train_set
def get_train_set(self): def get_test_sets(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length) test_sets = torchext.TestSets()
return train_set test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
def get_test_sets(self): self.ph_losses = []
test_sets = torchext.TestSets() self.ge_losses = []
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1) self.d2ds = []
test_sets.append('simple', test_set, test_frequency=1)
self.ph_losses = [] self.lcn_in = self.lcn_in.to('cuda')
self.ge_losses = [] for sidx in range(len(test_set.imsizes)):
self.d2ds = [] imsize = test_set.imsizes[sidx]
pat = test_set.patterns[sidx]
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
pat, _ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat)
self.lcn_in = self.lcn_in.to('cuda') K = test_set.getK(sidx)
for sidx in range(len(test_set.imsizes)): Ki = np.linalg.inv(K)
imsize = test_set.imsizes[sidx] K = torch.from_numpy(K)
pat = test_set.patterns[sidx] Ki = torch.from_numpy(Ki)
pat = pat.mean(axis=2) ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
pat,_ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat)
K = test_set.getK(sidx) self.ph_losses.append(ph_loss)
Ki = np.linalg.inv(K) self.ge_losses.append(ge_loss)
K = torch.from_numpy(K)
Ki = torch.from_numpy(Ki)
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
self.ph_losses.append( ph_loss ) d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
self.ge_losses.append( ge_loss ) self.d2ds.append(d2d)
d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
self.d2ds.append( d2d )
return test_sets return test_sets
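
The DispToDepth modules collected above presumably implement the standard rectified-stereo relation depth = focal_length * baseline / disparity, which is what the geometric loss needs in order to reproject one frame's depth into another; the real networks.DispToDepth may differ in details such as clamping. A minimal stand-alone sketch of the conversion:

import torch

def disp_to_depth(disp, focal_length, baseline, eps=1e-6):
    # rectified stereo: depth = f * b / d; eps avoids division by zero
    return focal_length * baseline / torch.clamp(disp, min=eps)

disp = torch.tensor([[10.0, 20.0, 40.0]])
print(disp_to_depth(disp, focal_length=160.0, baseline=0.075))   # baseline value is illustrative
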
def copy_data(self, data, device, requires_grad, train): def copy_data(self, data, device, requires_grad, train):
self.data = {} self.data = {}
self.lcn_in = self.lcn_in.to(device) self.lcn_in = self.lcn_in.to(device)
for key, val in data.items(): for key, val in data.items():
# from # from
# batch_size x track_length x ... # batch_size x track_length x ...
# to # to
# track_length x batch_size x ... # track_length x batch_size x ...
if len(val.shape)>2: if len(val.shape) > 2:
if train: if train:
val = val.transpose(0,1) val = val.transpose(0, 1)
else:
val = val.unsqueeze(0)
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
if 'im' in key and 'blend' not in key:
im = self.data[key]
tl = im.shape[0]
bs = im.shape[1]
im_lcn, im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
key_std = key.replace('im', 'std')
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
self.data[key] = im_cat
def net_forward(self, net, train):
im0 = self.data['im0']
tl = im0.shape[0]
bs = im0.shape[1]
im0 = im0.view(-1, *im0.shape[2:])
out, edge = net(im0)
if not (isinstance(out, tuple) or isinstance(out, list)):
out = out.view(tl, bs, *out.shape[1:])
edge = edge.view(tl, bs, *out.shape[1:])
else: else:
val = val.unsqueeze(0) out = [o.view(tl, bs, *o.shape[1:]) for o in out]
grad = 'im' in key and requires_grad edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
self.data[key] = val.to(device).requires_grad_(requires_grad=grad) return out, edge
if 'im' in key and 'blend' not in key:
im = self.data[key]
tl = im.shape[0]
bs = im.shape[1]
im_lcn,im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
key_std = key.replace('im','std')
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
self.data[key] = im_cat
def net_forward(self, net, train): def loss_forward(self, out, train):
im0 = self.data['im0'] out, edge = out
tl = im0.shape[0] if not (isinstance(out, tuple) or isinstance(out, list)):
bs = im0.shape[1] out = [out]
im0 = im0.view(-1, *im0.shape[2:]) vals = []
out, edge = net(im0) diffs = []
if not(isinstance(out, tuple) or isinstance(out, list)):
out = out.view(tl, bs, *out.shape[1:])
edge = edge.view(tl, bs, *out.shape[1:])
else:
out = [o.view(tl, bs, *o.shape[1:]) for o in out]
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
return out, edge
def loss_forward(self, out, train): # apply photometric loss
out, edge = out for s, l, o in zip(itertools.count(), self.ph_losses, out):
if not(isinstance(out, tuple) or isinstance(out, list)): im = self.data[f'im{s}']
out = [out] im = im.view(-1, *im.shape[2:])
vals = [] o = o.view(-1, *o.shape[2:])
diffs = [] std = self.data[f'std{s}']
std = std.view(-1, *std.shape[2:])
val, pattern_proj = l(o, im[:, 0:1, ...], std)
vals.append(val)
if s == 0:
self.pattern_proj = pattern_proj.detach()
# apply photometric loss # apply disparity loss
for s,l,o in zip(itertools.count(), self.ph_losses, out): # the edge maps are inverted (0 means edge), so use 1 - sigmoid(edge)
im = self.data[f'im{s}'] edge0 = 1 - torch.sigmoid(edge[0])
im = im.view(-1, *im.shape[2:]) edge0 = edge0.view(-1, *edge0.shape[2:])
o = o.view(-1, *o.shape[2:]) out0 = out[0].view(-1, *out[0].shape[2:])
std = self.data[f'std{s}'] val = self.disparity_loss(out0, edge0)
std = std.view(-1, *std.shape[2:]) if self.dp_weight > 0:
val, pattern_proj = l(o, im[:,0:1,...], std) vals.append(val * self.dp_weight)
vals.append(val)
if s == 0:
self.pattern_proj = pattern_proj.detach()
# apply disparity loss # apply edge loss on a subset of training samples
# the edge maps are inverted (0 means edge), so use 1 - sigmoid(edge) for s, e in zip(itertools.count(), edge):
edge0 = 1-torch.sigmoid(edge[0]) # inverted ground-truth edge where 0 means edge
edge0 = edge0.view(-1, *edge0.shape[2:]) grad = self.data[f'grad{s}'] < 0.2
out0 = out[0].view(-1, *out[0].shape[2:]) grad = grad.to(torch.float32)
val = self.disparity_loss(out0, edge0) ids = self.data['id']
if self.dp_weight>0: mask = ids > self.train_edge
vals.append(val * self.dp_weight) if mask.sum() > 0:
e = e[:, mask, :]
grad = grad[:, mask, :]
e = e.view(-1, *e.shape[2:])
grad = grad.view(-1, *grad.shape[2:])
val = self.edge_loss(e, grad)
else:
val = torch.zeros_like(vals[0])
vals.append(val)
# apply edge loss on a subset of training samples if train is False:
for s,e in zip(itertools.count(), edge): return vals
# inverted ground-truth edge where 0 means edge grad = self.data[f'grad{s}'] < 0.2
grad = self.data[f'grad{s}']<0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids>self.train_edge
if mask.sum()>0:
e = e[:,mask,:]
grad = grad[:,mask,:]
e = e.view(-1, *e.shape[2:])
grad = grad.view(-1, *grad.shape[2:])
val = self.edge_loss(e, grad)
else:
val = torch.zeros_like(vals[0])
vals.append(val)
if train is False: # apply geometric loss
return vals R = self.data['R']
t = self.data['t']
ge_num = self.track_length * (self.track_length - 1) / 2
for sidx in range(len(out)):
d2d = self.d2ds[sidx]
depth = d2d(out[sidx])
ge_loss = self.ge_losses[sidx]
imsize = self.imsizes[sidx]
for tidx0 in range(depth.shape[0]):
for tidx1 in range(tidx0 + 1, depth.shape[0]):
depth0 = depth[tidx0]
R0 = R[tidx0]
t0 = t[tidx0]
depth1 = depth[tidx1]
R1 = R[tidx1]
t1 = t[tidx1]
# apply geometric loss val = ge_loss(depth0, depth1, R0, t0, R1, t1)
R = self.data['R'] vals.append(val * self.ge_weight / ge_num)
t = self.data['t']
ge_num = self.track_length * (self.track_length-1) / 2
for sidx in range(len(out)):
d2d = self.d2ds[sidx]
depth = d2d(out[sidx])
ge_loss = self.ge_losses[sidx]
imsize = self.imsizes[sidx]
for tidx0 in range(depth.shape[0]):
for tidx1 in range(tidx0+1, depth.shape[0]):
depth0 = depth[tidx0]
R0 = R[tidx0]
t0 = t[tidx0]
depth1 = depth[tidx1]
R1 = R[tidx1]
t1 = t[tidx1]
val = ge_loss(depth0, depth1, R0, t0, R1, t1) return vals
vals.append(val * self.ge_weight / ge_num)
return vals def numpy_in_out(self, output):
output, edge = output
if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:, :, 0:1, ...].detach().to('cpu').numpy()
ma = gt > 0
return es, gt, im, ma
def numpy_in_out(self, output): def write_img(self, out_path, es, gt, im, ma):
output, edge = output logging.info(f'write img {out_path}')
if not(isinstance(output, tuple) or isinstance(output, list)): u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy()
ma = gt>0
return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma): diff = np.abs(es - gt)
logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
diff = np.abs(es - gt) vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2 * (vmax - vmin)
vmax = vmax + 0.2 * (vmax - vmin)
vmin, vmax = np.nanmin(gt), np.nanmax(gt) pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
vmin = vmin - 0.2*(vmax-vmin) im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0, 0]
vmax = vmax + 0.2*(vmax-vmin) pattern_diff = np.abs(im_orig - pattern_proj)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0] fig = plt.figure(figsize=(16, 16))
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0,0] es0 = co.cmap.color_depth_map(es[0], scale=vmax)
pattern_diff = np.abs(im_orig - pattern_proj) gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
diff0 = co.cmap.color_error_image(diff[0], BGR=True)
fig = plt.figure(figsize=(16,16)) # plot disparities, ground truth disparity is shown only for reference
es0 = co.cmap.color_depth_map(es[0], scale=vmax) ax = plt.subplot(3, 3, 1);
gt0 = co.cmap.color_depth_map(gt[0], scale=vmax) plt.imshow(es0[..., [2, 1, 0]]);
diff0 = co.cmap.color_error_image(diff[0], BGR=True) plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
ax = plt.subplot(3, 3, 2);
plt.imshow(gt0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
ax = plt.subplot(3, 3, 3);
plt.imshow(diff0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
# plot disparities, ground truth disparity is shown only for reference # plot disparities of the second frame in the track if exists
ax = plt.subplot(3,3,1); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}') if es.shape[0] >= 2:
ax = plt.subplot(3,3,2); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}') es1 = co.cmap.color_depth_map(es[1], scale=vmax)
ax = plt.subplot(3,3,3); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}') gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
diff1 = co.cmap.color_error_image(diff[1], BGR=True)
ax = plt.subplot(3, 3, 4);
plt.imshow(es1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
ax = plt.subplot(3, 3, 5);
plt.imshow(gt1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
ax = plt.subplot(3, 3, 6);
plt.imshow(diff1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
# plot disparities of the second frame in the track if exists # plot normalized IR inputs
if es.shape[0]>=2: ax = plt.subplot(3, 3, 7);
es1 = co.cmap.color_depth_map(es[1], scale=vmax) plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray');
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax) plt.xticks([]);
diff1 = co.cmap.color_error_image(diff[1], BGR=True) plt.yticks([]);
ax = plt.subplot(3,3,4); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}') ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
ax = plt.subplot(3,3,5); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}') if es.shape[0] >= 2:
ax = plt.subplot(3,3,6); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}') ax = plt.subplot(3, 3, 8);
plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
# plot normalized IR inputs plt.tight_layout()
ax = plt.subplot(3,3,7); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}') plt.savefig(str(out_path))
if es.shape[0]>=2: plt.close(fig)
ax = plt.subplot(3,3,8); plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
if batch_idx % 512 == 0: if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png' out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output) es, gt, im, ma = self.numpy_in_out(output)
masks = [ m.detach().to('cpu').numpy() for m in masks ] masks = [m.detach().to('cpu').numpy() for m in masks]
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0]) self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
def callback_test_start(self, epoch, set_idx): def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric( self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1), co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5]) co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
) )
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks): def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
es, gt, im, ma = self.numpy_in_out(output) es, gt, im, ma = self.numpy_in_out(output)
if batch_idx % 8 == 0: if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png' out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0]) self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
es, gt, im, ma = self.crop_output(es, gt, im, ma) es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1,1) es = es.reshape(-1, 1)
gt = gt.reshape(-1,1) gt = gt.reshape(-1, 1)
ma = ma.ravel() ma = ma.ravel()
self.metric.add(es, gt, ma) self.metric.add(es, gt, ma)
def callback_test_stop(self, epoch, set_idx, loss): def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}') logging.info(f'{self.metric}')
for k, v in self.metric.items(): for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v) self.metric_add_test(epoch, set_idx, k, v)
def crop_output(self, es, gt, im, ma):
tl = es.shape[0]
bs = es.shape[1]
es = np.reshape(es[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
def crop_output(self, es, gt, im, ma):
tl = es.shape[0]
bs = es.shape[1]
es = np.reshape(es[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
if __name__ == '__main__': if __name__ == '__main__':
pass pass
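
For reference on the geometric term in loss_forward above: with a track of T frames, every unordered pair of frames contributes one reprojection loss, so there are T*(T-1)/2 terms and each is scaled by ge_weight / ge_num to keep the total weight independent of T (one pair for T = 2, six pairs for T = 4). A tiny sketch of the pair enumeration:

track_length = 4
ge_num = track_length * (track_length - 1) / 2
pairs = [(t0, t1) for t0 in range(track_length) for t1 in range(t0 + 1, track_length)]
assert len(pairs) == ge_num == 6
print(pairs)   # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
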

View File

@ -8,559 +8,572 @@ import co
class TimedModule(torch.nn.Module): class TimedModule(torch.nn.Module):
def __init__(self, mod_name): def __init__(self, mod_name):
super().__init__() super().__init__()
self.mod_name = mod_name self.mod_name = mod_name
def tforward(self, *args, **kwargs): def tforward(self, *args, **kwargs):
raise Exception('not implemented') raise Exception('not implemented')
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
torch.cuda.synchronize() torch.cuda.synchronize()
with co.gtimer.Ctx(self.mod_name): with co.gtimer.Ctx(self.mod_name):
x = self.tforward(*args, **kwargs) x = self.tforward(*args, **kwargs)
torch.cuda.synchronize() torch.cuda.synchronize()
return x return x
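
TimedModule above wraps every forward pass in CUDA synchronization and a named timing context and expects subclasses to implement tforward instead of forward. A minimal sketch of how a subclass plugs in, assuming the class and the co.gtimer context defined and imported above and a CUDA device (forward calls torch.cuda.synchronize()):

import torch

class Square(TimedModule):                 # TimedModule as defined above
    def __init__(self):
        super().__init__(mod_name='Square')

    def tforward(self, x):
        return x * x

if torch.cuda.is_available():
    m = Square().cuda()
    print(m(torch.arange(4.0, device='cuda')))   # tensor([0., 1., 4., 9.], device='cuda:0')
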
class PosOutput(TimedModule): class PosOutput(TimedModule):
def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0): def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0):
super().__init__(mod_name='PosOutput') super().__init__(mod_name='PosOutput')
self.im_width = im_width self.im_width = im_width
self.im_width = im_width self.im_width = im_width
if type == 'pos': if type == 'pos':
self.layer = torch.nn.Sequential( self.layer = torch.nn.Sequential(
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1), torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset) SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
) )
elif type == 'pos_row': elif type == 'pos_row':
self.layer = torch.nn.Sequential( self.layer = torch.nn.Sequential(
MultiLinear(im_height, channels_in, 1), MultiLinear(im_height, channels_in, 1),
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset) SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
) )
self.u_pos = None self.u_pos = None
def tforward(self, x): def tforward(self, x):
if self.u_pos is None: if self.u_pos is None:
self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1,1,1,-1) self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1, 1, 1, -1)
self.u_pos = self.u_pos.to(x.device) self.u_pos = self.u_pos.to(x.device)
pos = self.layer(x) pos = self.layer(x)
disp = self.u_pos - pos disp = self.u_pos - pos
return disp return disp
class OutputLayerFactory(object): class OutputLayerFactory(object):
''' '''
Define type of output Define type of output
type options: type options:
linear: apply only conv channel, used for the edge decoder linear: apply only conv channel, used for the edge decoder
disp: estimate the disparity disp: estimate the disparity
disp_row: independently estimate the disparity per row disp_row: independently estimate the disparity per row
pos: estimate the absolute location pos: estimate the absolute location
pos_row: independently estimate the absolute location per row pos_row: independently estimate the absolute location per row
''' '''
def __init__(self, type='disp', params={}):
self.type = type
self.params = params
def __call__(self, channels_in, imsize): def __init__(self, type='disp', params={}):
self.type = type
self.params = params
if self.type == 'linear': def __call__(self, channels_in, imsize):
return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1)
elif self.type == 'disp': if self.type == 'linear':
return torch.nn.Sequential( return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1)
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
SigmoidAffine(**self.params)
)
elif self.type == 'disp_row': elif self.type == 'disp':
return torch.nn.Sequential( return torch.nn.Sequential(
MultiLinear(imsize[0], channels_in, 1), torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
SigmoidAffine(**self.params) SigmoidAffine(**self.params)
) )
elif self.type == 'pos' or self.type == 'pos_row': elif self.type == 'disp_row':
return PosOutput(channels_in, **self.params) return torch.nn.Sequential(
MultiLinear(imsize[0], channels_in, 1),
SigmoidAffine(**self.params)
)
else: elif self.type == 'pos' or self.type == 'pos_row':
raise Exception('unknown output layer type') return PosOutput(channels_in, **self.params)
else:
raise Exception('unknown output layer type')
class SigmoidAffine(TimedModule): class SigmoidAffine(TimedModule):
def __init__(self, alpha=1, beta=0, gamma=1, offset=0): def __init__(self, alpha=1, beta=0, gamma=1, offset=0):
super().__init__(mod_name='SigmoidAffine') super().__init__(mod_name='SigmoidAffine')
self.alpha = alpha self.alpha = alpha
self.beta = beta self.beta = beta
self.gamma = gamma self.gamma = gamma
self.offset = offset self.offset = offset
def tforward(self, x): def tforward(self, x):
return torch.sigmoid(x/self.gamma - self.offset) * self.alpha + self.beta return torch.sigmoid(x / self.gamma - self.offset) * self.alpha + self.beta
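
SigmoidAffine above maps an unconstrained activation x to sigmoid(x / gamma - offset) * alpha + beta, i.e. a value in the open interval (beta, beta + alpha); gamma and offset only rescale and shift the sigmoid. With alpha = 192 and beta = 0 (hypothetical disparity-range values, not necessarily those used in the repository) the output is a disparity between 0 and 192 pixels. A short numeric check of the formula:

import torch

alpha, beta, gamma, offset = 192.0, 0.0, 1.0, 0.0   # hypothetical values
x = torch.tensor([-10.0, 0.0, 10.0])
y = torch.sigmoid(x / gamma - offset) * alpha + beta
print(y)   # approx. [0.0087, 96.0, 191.99], always inside (beta, beta + alpha)
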
class MultiLinear(TimedModule): class MultiLinear(TimedModule):
def __init__(self, n, channels_in, channels_out): def __init__(self, n, channels_in, channels_out):
super().__init__(mod_name='MultiLinear') super().__init__(mod_name='MultiLinear')
self.channels_out = channels_out self.channels_out = channels_out
self.mods = torch.nn.ModuleList() self.mods = torch.nn.ModuleList()
for idx in range(n): for idx in range(n):
self.mods.append(torch.nn.Linear(channels_in, channels_out)) self.mods.append(torch.nn.Linear(channels_in, channels_out))
def tforward(self, x):
x = x.permute(2,0,3,1) # BxCxHxW => HxBxWxC
y = x.new_empty(*x.shape[:-1], self.channels_out)
for hidx in range(x.shape[0]):
y[hidx] = self.mods[hidx](x[hidx])
y = y.permute(1,3,0,2) # HxBxWxC => BxCxHxW
return y
def tforward(self, x):
x = x.permute(2, 0, 3, 1) # BxCxHxW => HxBxWxC
y = x.new_empty(*x.shape[:-1], self.channels_out)
for hidx in range(x.shape[0]):
y[hidx] = self.mods[hidx](x[hidx])
y = y.permute(1, 3, 0, 2) # HxBxWxC => BxCxHxW
return y
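
MultiLinear above keeps a separate torch.nn.Linear for every image row: the input is permuted from B x C x H x W to H x B x W x C, each of the H row-specific layers maps C channels to channels_out for its row, and the result is permuted back. A small shape check, assuming the class as defined above:

import torch

ml = MultiLinear(n=6, channels_in=16, channels_out=1)   # 6 rows, 16 channels in, 1 out
x = torch.randn(2, 16, 6, 10)                           # B x C x H x W
y = ml.tforward(x)                                      # tforward skips the CUDA sync in forward()
print(y.shape)                                          # torch.Size([2, 1, 6, 10])
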
class DispNetS(TimedModule): class DispNetS(TimedModule):
''' '''
Disparity Decoder based on DispNetS Disparity Decoder based on DispNetS
''' '''
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, channel_multiplier=1):
super(DispNetS, self).__init__(mod_name='DispNetS')
self.output_ms = output_ms def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False,
self.coordconv = coordconv channel_multiplier=1):
super(DispNetS, self).__init__(mod_name='DispNetS')
conv_planes = channel_multiplier * np.array( [32, 64, 128, 256, 512, 512, 512] ) self.output_ms = output_ms
self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7) self.coordconv = coordconv
self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3])
self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4])
self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
upconv_planes = channel_multiplier * np.array( [512, 512, 256, 128, 64, 32, 16] ) conv_planes = channel_multiplier * np.array([32, 64, 128, 256, 512, 512, 512])
self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0]) self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1]) self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2]) self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3]) self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3])
self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4]) self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4])
self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5]) self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6]) self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0]) upconv_planes = channel_multiplier * np.array([512, 512, 256, 128, 64, 32, 16])
self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1]) self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0])
self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2]) self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1])
self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3]) self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2])
self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4]) self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3])
self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5]) self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4])
self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6]) self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5])
self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6])
if isinstance(output_facs, list): self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3]) self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2]) self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1]) self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0]) self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
else: self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3]) self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6])
self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2])
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
if isinstance(output_facs, list):
self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3])
self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2])
self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
else:
self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3])
self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2])
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
def init_weights(self): def init_weights(self):
for m in self.modules(): for m in self.modules():
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d): if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
torch.nn.init.xavier_uniform_(m.weight, gain=0.1) torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
if m.bias is not None: if m.bias is not None:
torch.nn.init.zeros_(m.bias) torch.nn.init.zeros_(m.bias)
def downsample_conv(self, in_planes, out_planes, kernel_size=3): def downsample_conv(self, in_planes, out_planes, kernel_size=3):
if self.coordconv: if self.coordconv:
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2) conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
else: padding=(kernel_size - 1) // 2)
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2) else:
return torch.nn.Sequential( conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
conv, padding=(kernel_size - 1) // 2)
torch.nn.ReLU(inplace=True), return torch.nn.Sequential(
torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2), conv,
torch.nn.ReLU(inplace=True) torch.nn.ReLU(inplace=True),
) torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size - 1) // 2),
torch.nn.ReLU(inplace=True)
)
def conv(self, in_planes, out_planes): def conv(self, in_planes, out_planes):
return torch.nn.Sequential( return torch.nn.Sequential(
            torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True)
        )

    def upconv(self, in_planes, out_planes):
        return torch.nn.Sequential(
            torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1),
            torch.nn.ReLU(inplace=True)
        )

    def crop_like(self, input, ref):
        assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
        return input[:, :, :ref.size(2), :ref.size(3)]

    def tforward(self, x):
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3(out_conv2)
        out_conv4 = self.conv4(out_conv3)
        out_conv5 = self.conv5(out_conv4)
        out_conv6 = self.conv6(out_conv5)
        out_conv7 = self.conv7(out_conv6)

        out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6)
        concat7 = torch.cat((out_upconv7, out_conv6), 1)
        out_iconv7 = self.iconv7(concat7)

        out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5)
        concat6 = torch.cat((out_upconv6, out_conv5), 1)
        out_iconv6 = self.iconv6(concat6)

        out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4)
        concat5 = torch.cat((out_upconv5, out_conv4), 1)
        out_iconv5 = self.iconv5(concat5)

        out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3)
        concat4 = torch.cat((out_upconv4, out_conv3), 1)
        out_iconv4 = self.iconv4(concat4)
        disp4 = self.predict_disp4(out_iconv4)

        out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
        disp4_up = self.crop_like(
            torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
        concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
        out_iconv3 = self.iconv3(concat3)
        disp3 = self.predict_disp3(out_iconv3)

        out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
        disp3_up = self.crop_like(
            torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
        concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
        out_iconv2 = self.iconv2(concat2)
        disp2 = self.predict_disp2(out_iconv2)

        out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
        disp2_up = self.crop_like(
            torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
        concat1 = torch.cat((out_upconv1, disp2_up), 1)
        out_iconv1 = self.iconv1(concat1)
        disp1 = self.predict_disp1(out_iconv1)

        if self.output_ms:
            return disp1, disp2, disp3, disp4
        else:
            return disp1
class DispNetShallow(DispNetS):
    '''
    Edge Decoder based on DispNetS with fewer layers
    '''

    def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
        super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
        self.mod_name = 'DispNetShallow'
        conv_planes = [32, 64, 128, 256, 512, 512, 512]
        upconv_planes = [512, 512, 256, 128, 64, 32, 16]
        self.iconv3 = self.conv(upconv_planes[4] + conv_planes[1], upconv_planes[4])

    def tforward(self, x):
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3(out_conv2)

        out_upconv3 = self.crop_like(self.upconv3(out_conv3), out_conv2)
        concat3 = torch.cat((out_upconv3, out_conv2), 1)
        out_iconv3 = self.iconv3(concat3)
        disp3 = self.predict_disp3(out_iconv3)

        out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
        disp3_up = self.crop_like(
            torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
        concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
        out_iconv2 = self.iconv2(concat2)
        disp2 = self.predict_disp2(out_iconv2)

        out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
        disp2_up = self.crop_like(
            torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
        concat1 = torch.cat((out_upconv1, disp2_up), 1)
        out_iconv1 = self.iconv1(concat1)
        disp1 = self.predict_disp1(out_iconv1)

        if self.output_ms:
            return disp1, disp2, disp3
        else:
            return disp1
class DispEdgeDecoders(TimedModule):
    '''
    Disparity Decoder and Edge Decoder
    '''

    def __init__(self, *args, max_disp=128, **kwargs):
        super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')

        output_facs = [
            OutputLayerFactory(type='disp', params={'alpha': max_disp / (2 ** s), 'beta': 0, 'gamma': 1, 'offset': 3})
            for s in range(4)]
        self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)

        output_facs = [OutputLayerFactory(type='linear') for s in range(4)]
        self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs)

    def tforward(self, x):
        disp = self.disp_decoder(x)
        edge = self.edge_decoder(x)
        return disp, edge
class DispToDepth(TimedModule):
    def __init__(self, focal_length, baseline):
        super().__init__(mod_name='DispToDepth')
        self.baseline_focal_length = baseline * focal_length

    def tforward(self, disp):
        disp = torch.nn.functional.relu(disp) + 1e-12
        depth = self.baseline_focal_length / disp
        return depth
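DispToDepth applies the standard rectified-stereo relation depth = baseline * focal_length / disparity. A tiny standalone sketch with made-up calibration values (these are not constants from this repository):

```
import torch

focal_length = 570.0   # hypothetical focal length in pixels
baseline = 0.075       # hypothetical baseline in meters

disp = torch.tensor([[10.0, 20.0, 40.0]])
depth = baseline * focal_length / (torch.nn.functional.relu(disp) + 1e-12)
print(depth)  # ~[[4.27, 2.14, 1.07]] -- larger disparity means smaller depth
```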
class PosToDepth(DispToDepth):
    def __init__(self, focal_length, baseline, im_height, im_width):
        super().__init__(focal_length, baseline)
        self.mod_name = 'PosToDepth'

        self.im_height = im_height
        self.im_width = im_width
        self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1, 1, 1, -1)

    def tforward(self, pos):
        self.u_pos = self.u_pos.to(pos.device)
        disp = self.u_pos - pos
        return super().forward(disp)
class RectifiedPatternSimilarityLoss(TimedModule):
    '''
    Photometric Loss
    '''

    def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
        super().__init__(mod_name='RectifiedPatternSimilarityLoss')
        self.im_height = im_height
        self.im_width = im_width
        self.pattern = pattern.mean(dim=1, keepdim=True).contiguous()

        u, v = np.meshgrid(range(im_width), range(im_height))
        uv0 = np.stack((u, v), axis=2).reshape(-1, 1)
        uv0 = uv0.astype(np.float32).reshape(1, -1, 2)
        self.uv0 = torch.from_numpy(uv0)

        self.loss_type = loss_type
        self.loss_eps = loss_eps

    def tforward(self, disp0, im, std=None):
        self.pattern = self.pattern.to(disp0.device)
        self.uv0 = self.uv0.to(disp0.device)

        uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
        uv1 = torch.empty_like(uv0)
        uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1)
        uv1[..., 1] = uv0[..., 1]

        uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
        uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
        uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
        pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
        pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border')
        mask = torch.ones_like(im)
        if std is not None:
            mask = mask * std

        diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
        val = (mask * diff).sum() / mask.sum()
        return val, pattern_proj
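The photometric loss warps the reference pattern into the camera view with grid_sample, which expects sampling coordinates normalized to [-1, 1] (the two lines dividing by im_width-1 and im_height-1 above). A minimal, self-contained sketch of that normalization and warp; shapes and the constant disparity are illustrative only:

```
import torch
import torch.nn.functional as F

B, H, W = 1, 4, 6
pattern = torch.rand(B, 1, H, W)
disp = torch.full((B, H, W), 1.0)   # sample one column to the left of every output pixel

u = torch.arange(W, dtype=torch.float32).view(1, 1, W).expand(B, H, W)
v = torch.arange(H, dtype=torch.float32).view(1, H, 1).expand(B, H, W)
grid = torch.stack((2 * ((u - disp) / (W - 1) - 0.5),    # x in [-1, 1]
                    2 * (v / (H - 1) - 0.5)), dim=-1)    # y in [-1, 1]
warped = F.grid_sample(pattern, grid, padding_mode='border', align_corners=True)
print(warped.shape)  # torch.Size([1, 1, 4, 6])
```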
class DisparityLoss(TimedModule):
    '''
    Disparity Loss
    '''

    def __init__(self):
        super().__init__(mod_name='DisparityLoss')
        self.sobel = SobelFilter(norm=False)

        # if not edge_gt:
        self.b0 = 0.0503428816795
        self.b1 = 1.07274045944
        # else:
        #   self.b0 = 0.0587115108967
        #   self.b1 = 1.51931190491

    def tforward(self, disp, edge=None):
        self.sobel = self.sobel.to(disp.device)

        if edge is not None:
            grad = self.sobel(disp)
            grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
            pdf = (1 - edge) / self.b0 * torch.exp(-torch.abs(grad) / self.b0) + \
                  edge / self.b1 * torch.exp(-torch.abs(grad) / self.b1)
            val = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
        else:
            # on qifeng's data we don't have ambient info,
            # therefore we suppress edges everywhere
            grad = self.sobel(disp)
            grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
            grad = torch.clamp(grad, 0, 1.0)
            val = torch.mean(grad)

        return val
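The edge-conditioned branch above models disparity gradients with a mixture of two Laplacian-like densities, a narrow one (scale b0) away from edges and a wide one (scale b1) on edges, and penalizes the negative log-likelihood. A scalar sketch of that term, reusing the constants from the class:

```
import torch

b0, b1 = 0.0503428816795, 1.07274045944     # constants from DisparityLoss above
grad = torch.tensor([0.01, 0.5, 2.0])        # example |grad(disp)| magnitudes
edge = torch.tensor([0.0, 1.0, 1.0])         # example edge probabilities

pdf = (1 - edge) / b0 * torch.exp(-grad.abs() / b0) + edge / b1 * torch.exp(-grad.abs() / b1)
nll = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
print(nll)   # large gradients are much cheaper on edges (b1) than off edges (b0)
```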
class ProjectionBaseLoss(TimedModule):
    '''
    Base module of the Geometric Loss
    '''

    def __init__(self, K, Ki, im_height, im_width):
        super().__init__(mod_name='ProjectionBaseLoss')

        self.K = K.view(-1, 3, 3)

        self.im_height = im_height
        self.im_width = im_width

        u, v = np.meshgrid(range(im_width), range(im_height))
        uv = np.stack((u, v, np.ones_like(u)), axis=2).reshape(-1, 3)

        ray = uv @ Ki.numpy().T

        ray = ray.reshape(1, -1, 3).astype(np.float32)
        self.ray = torch.from_numpy(ray)

    def transform(self, xyz, R=None, t=None):
        if t is not None:
            bs = xyz.shape[0]
            xyz = xyz - t.reshape(bs, 1, 3)
        if R is not None:
            xyz = torch.bmm(xyz, R)
        return xyz

    def unproject(self, depth, R=None, t=None):
        self.ray = self.ray.to(depth.device)
        bs = depth.shape[0]

        xyz = depth.reshape(bs, -1, 1) * self.ray
        xyz = self.transform(xyz, R, t)
        return xyz

    def project(self, xyz, R, t):
        self.K = self.K.to(xyz.device)
        bs = xyz.shape[0]

        xyz = torch.bmm(xyz, R.transpose(1, 2))
        xyz = xyz + t.reshape(bs, 1, 3)

        Kt = self.K.transpose(1, 2).expand(bs, -1, -1)
        uv = torch.bmm(xyz, Kt)

        d = uv[:, :, 2:3]

        # avoid division by zero
        uv = uv[:, :, :2] / (torch.nn.functional.relu(d) + 1e-12)
        return uv, d

    def tforward(self, depth0, R0, t0, R1, t1):
        xyz = self.unproject(depth0, R0, t0)
        return self.project(xyz, R1, t1)
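ProjectionBaseLoss implements the standard pinhole model: a pixel (u, v) with depth d unprojects to d * Ki @ [u, v, 1], and applying K projects it back. A tiny numpy round trip with a made-up intrinsic matrix:

```
import numpy as np

K = np.array([[500.0, 0.0, 320.0],      # hypothetical intrinsics
              [0.0, 500.0, 240.0],
              [0.0, 0.0, 1.0]])
Ki = np.linalg.inv(K)

uv = np.array([100.0, 80.0, 1.0])       # homogeneous pixel coordinate
depth = 2.5
xyz = depth * (Ki @ uv)                 # unproject into camera space

proj = K @ xyz
print(np.allclose(proj[:2] / proj[2], uv[:2]))   # True: projecting back recovers the pixel
```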
class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
    '''
    Geometric Loss
    '''

    def __init__(self, *args, clamp=-1):
        super().__init__(*args)
        self.mod_name = 'ProjectionDepthSimilarityLoss'
        self.clamp = clamp

    def fwd(self, depth0, depth1, R0, t0, R1, t1):
        uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)

        uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
        uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
        uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()

        depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border')

        diff = torch.abs(d1.view(-1) - depth10.view(-1))

        if self.clamp > 0:
            diff = torch.clamp(diff, 0, self.clamp)

        # return diff without clamping for debugging
        return diff.mean()

    def tforward(self, depth0, depth1, R0, t0, R1, t1):
        l0 = self.fwd(depth0, depth1, R0, t0, R1, t1)
        l1 = self.fwd(depth1, depth0, R1, t1, R0, t0)
        return l0 + l1
class LCN(TimedModule):
    '''
    Local Contrast Normalization
    '''

    def __init__(self, radius, epsilon):
        super().__init__(mod_name='LCN')
        self.box_conv = torch.nn.Sequential(
            torch.nn.ReflectionPad2d(radius),
            torch.nn.Conv2d(1, 1, kernel_size=2 * radius + 1, bias=False)
        )
        self.box_conv[1].weight.requires_grad = False
        self.box_conv[1].weight.fill_(1.)

        self.epsilon = epsilon
        self.radius = radius

    def tforward(self, data):
        boxs = self.box_conv(data)

        avgs = boxs / (2 * self.radius + 1) ** 2
        boxs_n2 = boxs ** 2
        boxs_2n = self.box_conv(data ** 2)

        stds = torch.sqrt(boxs_2n / (2 * self.radius + 1) ** 2 - avgs ** 2 + 1e-6)
        stds = stds + self.epsilon

        return (data - avgs) / stds, stds
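LCN normalizes each pixel by the mean and standard deviation of its local window, out = (x - mean) / (std + epsilon), with the variance computed from box filters as E[x^2] - E[x]^2. A small numpy sketch of the same computation for a single 5x5 window (radius and epsilon are example values):

```
import numpy as np

radius, epsilon = 2, 0.05                                   # example hyper-parameters
window = np.random.rand(2 * radius + 1, 2 * radius + 1)     # 5x5 neighborhood of one pixel

mean = window.mean()
std = np.sqrt((window ** 2).mean() - mean ** 2 + 1e-6)      # same E[x^2] - E[x]^2 trick as above
center = window[radius, radius]
print((center - mean) / (std + epsilon))
```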
class SobelFilter(TimedModule):
    '''
    Sobel Filter
    '''

    def __init__(self, norm=False):
        super(SobelFilter, self).__init__(mod_name='SobelFilter')
        kx = np.array([[-5, -4, 0, 4, 5],
                       [-8, -10, 0, 10, 8],
                       [-10, -20, 0, 20, 10],
                       [-8, -10, 0, 10, 8],
                       [-5, -4, 0, 4, 5]]) / 240.0
        ky = kx.copy().transpose(1, 0)

        self.conv_x = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
        self.conv_x.weight = torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))

        self.conv_y = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
        self.conv_y.weight = torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))

        self.norm = norm

    def tforward(self, x):
        x = F.pad(x, (2, 2, 2, 2), "replicate")
        gx = self.conv_x(x)
        gy = self.conv_y(x)
        if self.norm:
            return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
        else:
            return torch.cat((gx, gy), dim=1)

View File

@ -6,7 +6,9 @@ This repository contains the code for the paper
**[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)**
<br>
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/), [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/), and [Andreas Geiger](http://www.cvlibs.net/)
<br>
[CVPR 2019](http://cvpr2019.thecvf.com/)
@ -24,40 +26,45 @@
}
```

## Dependencies

The network training/evaluation code is based on `PyTorch`.
```
PyTorch>=1.1
Cuda>=10.0
```
Updated on 07.06.2021: the code is now compatible with the latest PyTorch version (1.8).

The other Python packages can be installed with `anaconda`:
```
conda install --file requirements.txt
```

### Structured Light Renderer

To train and evaluate our method in a controlled setting, we implemented a structured light renderer. It can be used to render a virtual scene (an arbitrary triangle mesh) with the structured light pattern projected from a customizable projector location. To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`. Afterwards, the renderer can be built by running `make` within the `renderer` directory.

### PyTorch Extensions

The network training/evaluation code is based on `PyTorch`. We implemented some custom layers that need to be built in the `torchext` directory. Simply change into this directory and run
```
python setup.py build_ext --inplace
```

### Baseline HyperDepth

As a baseline we partially re-implemented the random-forest-based method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf). The code resides in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`. To build it, change into the directory and run
```
python setup.py build_ext --inplace
@ -65,42 +72,59 @@ python setup.py build_ext --inplace

## Running

### Creating Synthetic Data

To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by running
```
./create_syn_data.sh
```
If you are only interested in evaluating our pre-trained model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a validation set that contains a small number of images.

### Training Network

As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train the network on synthetic data for the first stage, run
```
python train_val.py
```
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by running
```
python train_val.py --loss phge
```

### Evaluating Network

To evaluate a specific checkpoint, e.g. the 50th epoch, one can run
```
python train_val.py --cmd retest --epoch 50
```

### Evaluating a Pre-trained Model

We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
```
mkdir -p output
mkdir -p output/exp_syn
wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params
python train_val.py --cmd retest --epoch 99
```
You can also download our validation set from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).

## Acknowledgement

This work was supported by the Intel Network on Intelligent Systems.

View File

@ -10,7 +10,7 @@ import json
this_dir = os.path.dirname(__file__)

with open('../config.json') as fp:
    config = json.load(fp)

extra_compile_args = ['-O3', '-std=c++11']
@ -20,7 +20,7 @@ cuda_lib = 'cudart'
sources = ['cyrender.pyx']
extra_objects = [
    os.path.join(this_dir, 'render/render_cpu.cpp.o'),
]
library_dirs = []
libraries = ['m']
@ -30,20 +30,20 @@ library_dirs.append(cuda_lib_dir)
libraries.append(cuda_lib)

setup(
    name="cyrender",
    cmdclass={'build_ext': build_ext},
    ext_modules=[
        Extension('cyrender',
                  sources,
                  extra_objects=extra_objects,
                  language='c++',
                  library_dirs=library_dirs,
                  libraries=libraries,
                  include_dirs=[
                      np.get_include(),
                  ],
                  extra_compile_args=extra_compile_args,
                  # extra_link_args=extra_link_args
                  )
    ]
)
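This build script, like the data scripts, reads its paths from `config.json` at the repository root. For orientation, a hypothetical way to write such a file with the keys mentioned in the README above (`CUDA_LIBRARY_PATH` for the renderer build, `SHAPENET_ROOT` and `DATA_ROOT` for data generation); the real file may contain additional keys and the paths here are placeholders:

```
import json

# Illustrative values only -- point these at your own installation.
config = {
    "CUDA_LIBRARY_PATH": "/usr/local/cuda/lib64",
    "SHAPENET_ROOT": "/path/to/ShapeNetCore.v2",
    "DATA_ROOT": "/path/to/connecting_the_dots_data",
}

with open("config.json", "w") as fp:
    json.dump(config, fp, indent=2)
```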

View File

@ -2,65 +2,65 @@ import torch
import torch.utils.data
import numpy as np


class TestSet(object):
    def __init__(self, name, dset, test_frequency=1):
        self.name = name
        self.dset = dset
        self.test_frequency = test_frequency


class TestSets(list):
    def append(self, name, dset, test_frequency=1):
        super().append(TestSet(name, dset, test_frequency))


class MultiDataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.current_epoch = 0

        self.datasets = []
        self.cum_n_samples = [0]

        for dataset in datasets:
            self.append(dataset)

    def append(self, dataset):
        self.datasets.append(dataset)
        self.__update_cum_n_samples(dataset)

    def __update_cum_n_samples(self, dataset):
        n_samples = self.cum_n_samples[-1] + len(dataset)
        self.cum_n_samples.append(n_samples)

    def dataset_updated(self):
        self.cum_n_samples = [0]
        for dset in self.datasets:
            self.__update_cum_n_samples(dset)

    def __len__(self):
        return self.cum_n_samples[-1]

    def __getitem__(self, idx):
        didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
        sidx = idx - self.cum_n_samples[didx]
        return self.datasets[didx][sidx]
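MultiDataset maps a global index to a (dataset, local index) pair via the cumulative sample counts and np.searchsorted. A standalone sketch of that lookup, assuming two datasets with 3 and 5 samples:

```
import numpy as np

cum_n_samples = [0, 3, 8]                # e.g. two datasets with 3 and 5 samples
for idx in [0, 2, 3, 7]:
    didx = np.searchsorted(cum_n_samples, idx, side='right') - 1
    sidx = idx - cum_n_samples[didx]
    print(idx, '->', (didx, sidx))       # 0->(0,0)  2->(0,2)  3->(1,0)  7->(1,4)
```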
class BaseDataset(torch.utils.data.Dataset):
    def __init__(self, train=True, fix_seed_per_epoch=False):
        self.current_epoch = 0
        self.train = train
        self.fix_seed_per_epoch = fix_seed_per_epoch

    def get_rng(self, idx):
        rng = np.random.RandomState()
        if self.train:
            if self.fix_seed_per_epoch:
                seed = 1 * len(self) + idx
            else:
                seed = (self.current_epoch + 1) * len(self) + idx
            rng.seed(seed)
        else:
            rng.seed(idx)
        return rng

View File

@ -2,146 +2,151 @@ import torch
from . import ext_cpu
from . import ext_cuda


class NNFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, in0, in1):
        args = (in0, in1)
        if in0.is_cuda:
            out = ext_cuda.nn_cuda(*args)
        else:
            out = ext_cpu.nn_cpu(*args)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        return None, None


def nn(in0, in1):
    return NNFunction.apply(in0, in1)


class CrossCheckFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, in0, in1):
        args = (in0, in1)
        if in0.is_cuda:
            out = ext_cuda.crosscheck_cuda(*args)
        else:
            out = ext_cpu.crosscheck_cpu(*args)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        return None, None


def crosscheck(in0, in1):
    return CrossCheckFunction.apply(in0, in1)


class ProjNNFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xyz0, xyz1, K, patch_size):
        args = (xyz0, xyz1, K, patch_size)
        if xyz0.is_cuda:
            out = ext_cuda.proj_nn_cuda(*args)
        else:
            out = ext_cpu.proj_nn_cpu(*args)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        return None, None, None, None


def proj_nn(xyz0, xyz1, K, patch_size):
    return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)


class XCorrVolFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, in0, in1, n_disps, block_size):
        args = (in0, in1, n_disps, block_size)
        if in0.is_cuda:
            out = ext_cuda.xcorrvol_cuda(*args)
        else:
            out = ext_cpu.xcorrvol_cpu(*args)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        return None, None, None, None


def xcorrvol(in0, in1, n_disps, block_size):
    return XCorrVolFunction.apply(in0, in1, n_disps, block_size)


class PhotometricLossFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, es, ta, block_size, type, eps):
        args = (es, ta, block_size, type, eps)
        ctx.save_for_backward(es, ta)
        ctx.block_size = block_size
        ctx.type = type
        ctx.eps = eps
        if es.is_cuda:
            out = ext_cuda.photometric_loss_forward(*args)
        else:
            out = ext_cpu.photometric_loss_forward(*args)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        es, ta = ctx.saved_tensors
        block_size = ctx.block_size
        type = ctx.type
        eps = ctx.eps
        args = (es, ta, grad_out.contiguous(), block_size, type, eps)
        if grad_out.is_cuda:
            grad_es = ext_cuda.photometric_loss_backward(*args)
        else:
            grad_es = ext_cpu.photometric_loss_backward(*args)
        return grad_es, None, None, None, None


def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
    type = type.lower()
    if type == 'mse':
        type = 0
    elif type == 'sad':
        type = 1
    elif type == 'census_mse':
        type = 2
    elif type == 'census_sad':
        type = 3
    else:
        raise Exception('invalid loss type')
    return PhotometricLossFunction.apply(es, ta, block_size, type, eps)


def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
    type = type.lower()
    p = block_size // 2
    es_pad = torch.nn.functional.pad(es, (p, p, p, p), mode='replicate')
    ta_pad = torch.nn.functional.pad(ta, (p, p, p, p), mode='replicate')
    es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
    ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
    es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
    ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
    if type == 'mse':
        ref = (es_uf - ta_uf) ** 2
    elif type == 'sad':
        ref = torch.abs(es_uf - ta_uf)
    elif type == 'census_mse' or type == 'census_sad':
        des = es_uf - es.unsqueeze(2)
        dta = ta_uf - ta.unsqueeze(2)
        h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
        h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
        diff = h_des - h_dta
        if type == 'census_mse':
            ref = diff * diff
        elif type == 'census_sad':
            ref = torch.abs(diff)
    else:
        raise Exception('invalid loss type')
    ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
    ref = torch.sum(ref, dim=1, keepdim=True) / block_size ** 2
    return ref
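photometric_loss_pytorch is the pure-PyTorch reference for the compiled extension and runs without building ext_cpu/ext_cuda. A minimal smoke test, assuming the function above is in scope (shapes are arbitrary):

```
import torch

es = torch.rand(2, 1, 16, 16)
ta = torch.rand(2, 1, 16, 16)
ref = photometric_loss_pytorch(es, ta, block_size=9, type='census_sad', eps=0.5)
print(ref.shape)   # torch.Size([2, 1, 16, 16]) -- per-pixel cost, averaged over the block
```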

View File

@ -4,24 +4,26 @@ import numpy as np
from .functions import *


class CoordConv2d(torch.nn.Module):
    def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels_in + 2, channels_out, kernel_size=kernel_size, padding=padding,
                                    stride=stride)

        self.uv = None

    def forward(self, x):
        if self.uv is None:
            height, width = x.shape[2], x.shape[3]
            u, v = np.meshgrid(range(width), range(height))
            u = 2 * u / (width - 1) - 1
            v = 2 * v / (height - 1) - 1
            uv = np.stack((u, v)).reshape(1, 2, height, width)
            self.uv = torch.from_numpy(uv.astype(np.float32))
        self.uv = self.uv.to(x.device)
        uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
        xuv = torch.cat((x, uv), dim=1)
        y = self.conv(xuv)
        return y
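CoordConv2d concatenates two normalized (u, v) coordinate channels to its input before the convolution, so channels_in refers to the input tensor alone. A quick shape check, assuming the class above is in scope (sizes are arbitrary):

```
import torch

layer = CoordConv2d(channels_in=3, channels_out=8, kernel_size=3, stride=1, padding=1)
x = torch.rand(2, 3, 32, 48)
print(layer(x).shape)   # torch.Size([2, 8, 32, 48])
```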

View File

@ -11,11 +11,12 @@ nvcc_args = [
]

setup(
    name='ext',
    ext_modules=[
        CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
        CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'],
                      extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
    ],
    cmdclass={'build_ext': BuildExtension},
    include_dirs=include_dirs
)

View File

@ -17,512 +17,516 @@ from collections import OrderedDict
class StopWatch(object):
    def __init__(self):
        self.timings = OrderedDict()
        self.starts = {}

    def start(self, name):
        self.starts[name] = time.time()

    def stop(self, name):
        if name not in self.timings:
            self.timings[name] = []
        self.timings[name].append(time.time() - self.starts[name])

    def get(self, name=None, reduce=np.sum):
        if name is not None:
            return reduce(self.timings[name])
        else:
            ret = {}
            for k in self.timings:
                ret[k] = reduce(self.timings[k])
            return ret

    def __repr__(self):
        return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])

    def __str__(self):
        return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
class ETA(object):
    def __init__(self, length):
        self.length = length
        self.start_time = time.time()
        self.current_idx = 0
        self.current_time = time.time()

    def update(self, idx):
        self.current_idx = idx
        self.current_time = time.time()

    def get_elapsed_time(self):
        return self.current_time - self.start_time

    def get_item_time(self):
        return self.get_elapsed_time() / (self.current_idx + 1)

    def get_remaining_time(self):
        return self.get_item_time() * (self.length - self.current_idx + 1)

    def format_time(self, seconds):
        minutes, seconds = divmod(seconds, 60)
        hours, minutes = divmod(minutes, 60)
        hours = int(hours)
        minutes = int(minutes)
        return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'

    def get_elapsed_time_str(self):
        return self.format_time(self.get_elapsed_time())

    def get_remaining_time_str(self):
        return self.format_time(self.get_remaining_time())
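StopWatch accumulates named timings and ETA extrapolates the remaining time from the average per-item time. A quick illustration of both, assuming the two classes above are in scope (the sleep only makes the timings non-zero):

```
import time

sw = StopWatch()
sw.start('total')
for _ in range(3):
    sw.start('step')
    time.sleep(0.01)
    sw.stop('step')
sw.stop('total')
print(sw)                              # e.g. "step: 0.03...[s], total: 0.03...[s]"

eta = ETA(length=100)
eta.update(9)                          # 10 items processed so far
print(eta.get_remaining_time_str())    # extrapolated, formatted as HH:MM:SS.ss
```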
class Worker(object):
    def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
                 num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
        self.out_root = Path(out_root)
        self.experiment_name = experiment_name
        self.epochs = epochs
        self.seed = seed
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.save_frequency = save_frequency
        self.train_device = train_device
        self.test_device = test_device
        self.max_train_iter = max_train_iter

        self.errs_list = []

        self.setup_experiment()

    def setup_experiment(self):
        self.exp_out_root = self.out_root / self.experiment_name
        self.exp_out_root.mkdir(parents=True, exist_ok=True)

        if logging.root: del logging.root.handlers[:]
        logging.basicConfig(
            level=logging.INFO,
            handlers=[
                logging.FileHandler(str(self.exp_out_root / 'train.log')),
                logging.StreamHandler()
            ],
            format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
        )

        logging.info('=' * 80)
        logging.info(f'Start of experiment: {self.experiment_name}')
        logging.info(socket.gethostname())
        self.log_datetime()
        logging.info('=' * 80)

        self.metric_path = self.exp_out_root / 'metrics.json'
        if self.metric_path.exists():
            with open(str(self.metric_path), 'r') as fp:
                self.metric_data = json.load(fp)
        else:
            self.metric_data = {}

        self.init_seed()
    def metric_add_train(self, epoch, key, val):
        epoch = str(epoch)
        key = str(key)
        if epoch not in self.metric_data:
            self.metric_data[epoch] = {}
        if 'train' not in self.metric_data[epoch]:
            self.metric_data[epoch]['train'] = {}
        self.metric_data[epoch]['train'][key] = val

    def metric_add_test(self, epoch, set_idx, key, val):
        epoch = str(epoch)
        set_idx = str(set_idx)
        key = str(key)
        if epoch not in self.metric_data:
            self.metric_data[epoch] = {}
        if 'test' not in self.metric_data[epoch]:
            self.metric_data[epoch]['test'] = {}
        if set_idx not in self.metric_data[epoch]['test']:
            self.metric_data[epoch]['test'][set_idx] = {}
        self.metric_data[epoch]['test'][set_idx][key] = val

    def metric_save(self):
        with open(str(self.metric_path), 'w') as fp:
            json.dump(self.metric_data, fp, indent=2)
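metric_add_train and metric_add_test build a nested dictionary that metric_save writes to metrics.json. A sketch of the resulting layout (keys and values are made up):

```
import json

metric_data = {
    "0": {
        "train": {"loss": 0.42},
        "test": {"0": {"err": 0.57}},
    },
}
print(json.dumps(metric_data, indent=2))
```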
    def init_seed(self, seed=None):
        if seed is not None:
            self.seed = seed
        logging.info(f'Set seed to {self.seed}')
        np.random.seed(self.seed)
        random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

    def log_datetime(self):
        logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

    def mem_report(self):
        for obj in gc.get_objects():
            if torch.is_tensor(obj):
                print(type(obj), obj.shape)

    def get_net_path(self, epoch, root=None):
        if root is None:
            root = self.exp_out_root
        return root / f'net_{epoch:04d}.params'

    def get_do_parser_cmds(self):
        return ['retrain', 'resume', 'retest', 'test_init']

    def get_do_parser(self):
        parser = argparse.ArgumentParser()
        parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
        parser.add_argument('--epoch', type=int, default=-1)
        return parser

    def do_cmd(self, args, net, optimizer, scheduler=None):
        if args.cmd == 'retrain':
            self.train(net, optimizer, resume=False, scheduler=scheduler)
        elif args.cmd == 'resume':
            self.train(net, optimizer, resume=True, scheduler=scheduler)
        elif args.cmd == 'retest':
            self.retest(net, epoch=args.epoch)
        elif args.cmd == 'test_init':
            test_sets = self.get_test_sets()
            self.test(-1, net, test_sets)
        else:
            raise Exception('invalid cmd')

    def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
        parser = self.get_do_parser()
        args, _ = parser.parse_known_args()

        if load_net_optimizer is not None and args.cmd not in ['schedule']:
            net, optimizer = load_net_optimizer()

        self.do_cmd(args, net, optimizer, scheduler=scheduler)

    def retest(self, net, epoch=-1):
        if epoch < 0:
            epochs = range(self.epochs)
        else:
            epochs = [epoch]

        test_sets = self.get_test_sets()

        for epoch in epochs:
            net_path = self.get_net_path(epoch)
            if net_path.exists():
                state_dict = torch.load(str(net_path))
                net.load_state_dict(state_dict)
                self.test(epoch, net, test_sets)

    def format_err_str(self, errs, div=1):
        err = sum(errs)
        if len(errs) > 1:
            err_str = f'{err / div:0.4f}=' + '+'.join([f'{e / div:0.4f}' for e in errs])
        else:
            err_str = f'{err / div:0.4f}'
        return err_str

    def write_err_img(self):
        err_img_path = self.exp_out_root / 'errs.png'
        fig = plt.figure(figsize=(16, 16))
        lines = []
        for idx, errs in enumerate(self.errs_list):
            line, = plt.plot(range(len(errs)), errs, label=f'error{idx}')
            lines.append(line)
        plt.tight_layout()
        plt.legend(handles=lines)
        plt.savefig(str(err_img_path))
        plt.close(fig)
    def callback_train_new_epoch(self, epoch, net, optimizer):
        pass

    def train(self, net, optimizer, resume=False, scheduler=None):
        logging.info('=' * 80)
        logging.info('Start training')
        self.log_datetime()
        logging.info('=' * 80)

        train_set = self.get_train_set()
        test_sets = self.get_test_sets()

        net = net.to(self.train_device)

        epoch = 0
        min_err = {ts.name: 1e9 for ts in test_sets}

        state_path = self.exp_out_root / 'state.dict'
        if resume and state_path.exists():
            logging.info('=' * 80)
            logging.info(f'Loading state from {state_path}')
            logging.info('=' * 80)

            state = torch.load(str(state_path))

            epoch = state['epoch'] + 1
            if 'min_err' in state:
                min_err = state['min_err']

            curr_state = net.state_dict()
            curr_state.update(state['state_dict'])
            net.load_state_dict(curr_state)

            try:
                optimizer.load_state_dict(state['optimizer'])
            except:
                logging.info('Warning: cannot load optimizer from state_dict')
                pass

            if 'cpu_rng_state' in state:
                torch.set_rng_state(state['cpu_rng_state'])
            if 'gpu_rng_state' in state:
                torch.cuda.set_rng_state(state['gpu_rng_state'])

        for epoch in range(epoch, self.epochs):
            self.callback_train_new_epoch(epoch, net, optimizer)

            # train epoch
            self.train_epoch(epoch, net, optimizer, train_set)

            # test epoch
            errs = self.test(epoch, net, test_sets)

            if (epoch + 1) % self.save_frequency == 0:
                net = net.to(self.train_device)

                # store state
                state_dict = {
                    'epoch': epoch,
                    'min_err': min_err,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'cpu_rng_state': torch.get_rng_state(),
                    'gpu_rng_state': torch.cuda.get_rng_state(),
                }
                logging.info(f'save state to {state_path}')
                state_path = self.exp_out_root / 'state.dict'
                torch.save(state_dict, str(state_path))

                for test_set_name in errs:
                    err = sum(errs[test_set_name])
                    if err < min_err[test_set_name]:
                        min_err[test_set_name] = err
                        state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
                        logging.info(f'save state to {state_path}')
                        torch.save(state_dict, str(state_path))

                # store network
                net_path = self.get_net_path(epoch)
                logging.info(f'save network to {net_path}')
                torch.save(net.state_dict(), str(net_path))

            if scheduler is not None:
                scheduler.step()

        logging.info('=' * 80)
        logging.info('Finished training')
        self.log_datetime()
        logging.info('=' * 80)
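The state.dict written by train() is an ordinary torch.save dictionary, so it can be inspected or restored manually. A minimal sketch, assuming the output layout from the README above (output/exp_syn); adjust the path to your own experiment:

```
import torch

state = torch.load('output/exp_syn/state.dict', map_location='cpu')
print(state['epoch'], sorted(state.keys()))
# net.load_state_dict(state['state_dict'])
# optimizer.load_state_dict(state['optimizer'])
```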
def get_train_set(self): # test epoch
# returns train_set errs = self.test(epoch, net, test_sets)
raise NotImplementedError()
def get_test_sets(self): if (epoch + 1) % self.save_frequency == 0:
# returns test_sets net = net.to(self.train_device)
raise NotImplementedError()
def copy_data(self, data, device, requires_grad, train): # store state
raise NotImplementedError() state_dict = {
'epoch': epoch,
'min_err': min_err,
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict(),
'cpu_rng_state': torch.get_rng_state(),
'gpu_rng_state': torch.cuda.get_rng_state(),
}
logging.info(f'save state to {state_path}')
state_path = self.exp_out_root / 'state.dict'
torch.save(state_dict, str(state_path))
def net_forward(self, net, train): for test_set_name in errs:
raise NotImplementedError() err = sum(errs[test_set_name])
if err < min_err[test_set_name]:
min_err[test_set_name] = err
state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
logging.info(f'save state to {state_path}')
torch.save(state_dict, str(state_path))
def loss_forward(self, output, train): # store network
raise NotImplementedError() net_path = self.get_net_path(epoch)
logging.info(f'save network to {net_path}')
torch.save(net.state_dict(), str(net_path))
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks): if scheduler is not None:
# err = False scheduler.step()
# for name, param in net.named_parameters():
# if not torch.isfinite(param.grad).all():
# print(name)
# err = True
# if err:
# import ipdb; ipdb.set_trace()
pass
def callback_train_start(self, epoch): logging.info('=' * 80)
pass logging.info('Finished training')
self.log_datetime()
logging.info('=' * 80)
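
The state written above makes interrupted runs restartable. A minimal usage sketch, not part of this commit: `worker` stands for a concrete Worker subclass instance, and the StepLR schedule is an assumed choice, not something taken from this repository.

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # assumed schedule
worker.train(net, optimizer, resume=True, scheduler=scheduler)
# resumes at state['epoch'] + 1 with model, optimizer and RNG state restored from exp_out_root/state.dict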
    def get_train_set(self):
        # returns train_set
        raise NotImplementedError()

    def get_test_sets(self):
        # returns test_sets
        raise NotImplementedError()

    def copy_data(self, data, device, requires_grad, train):
        raise NotImplementedError()

    def net_forward(self, net, train):
        raise NotImplementedError()

    def loss_forward(self, output, train):
        raise NotImplementedError()

    def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
        # err = False
        # for name, param in net.named_parameters():
        #     if not torch.isfinite(param.grad).all():
        #         print(name)
        #         err = True
        # if err:
        #     import ipdb; ipdb.set_trace()
        pass

    def callback_train_start(self, epoch):
        pass

    def callback_train_stop(self, epoch, loss):
        pass
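
The methods above are the hooks a concrete experiment has to implement. A hedged sketch of a minimal subclass follows, assuming the base class is importable as Worker; the dataset objects and the batch keys ('im', 'gt') are invented for illustration and are not code from this repository.

class MyWorker(Worker):  # hypothetical subclass for illustration only
    def get_train_set(self):
        return my_train_dataset  # assumed torch.utils.data.Dataset

    def get_test_sets(self):
        return [my_test_set]  # assumed objects exposing .name, .dset and .test_frequency

    def copy_data(self, data, device, requires_grad, train):
        # stash the batch on the requested device for net_forward/loss_forward
        self.data = {k: v.to(device) for k, v in data.items()}

    def net_forward(self, net, train):
        return net(self.data['im'])  # assumed input key

    def loss_forward(self, output, train):
        return torch.nn.functional.l1_loss(output, self.data['gt'])  # assumed target key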
    def train_epoch(self, epoch, net, optimizer, dset):
        self.callback_train_start(epoch)
        stopwatch = StopWatch()

        logging.info('=' * 80)
        logging.info('Train epoch %d' % epoch)

        dset.current_epoch = epoch
        train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True,
                                                   num_workers=self.num_workers, drop_last=True, pin_memory=False)

        net = net.to(self.train_device)
        net.train()

        mean_loss = None

        n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
        bar = ETA(length=n_batches)

        stopwatch.start('total')
        stopwatch.start('data')
        for batch_idx, data in enumerate(train_loader):
            if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
            self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
            stopwatch.stop('data')

            optimizer.zero_grad()

            stopwatch.start('forward')
            output = self.net_forward(net, train=True)
            if 'cuda' in self.train_device: torch.cuda.synchronize()
            stopwatch.stop('forward')

            stopwatch.start('loss')
            errs = self.loss_forward(output, train=True)
            if isinstance(errs, dict):
                masks = errs['masks']
                errs = errs['errs']
            else:
                masks = []
            if not isinstance(errs, list) and not isinstance(errs, tuple):
                errs = [errs]
            err = sum(errs)
            if 'cuda' in self.train_device: torch.cuda.synchronize()
            stopwatch.stop('loss')

            stopwatch.start('backward')
            err.backward()
            self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
            if 'cuda' in self.train_device: torch.cuda.synchronize()
            stopwatch.stop('backward')

            stopwatch.start('optimizer')
            optimizer.step()
            if 'cuda' in self.train_device: torch.cuda.synchronize()
            stopwatch.stop('optimizer')

            bar.update(batch_idx)
            if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
                err_str = self.format_err_str(errs)
                logging.info(
                    f'train e{epoch}: {batch_idx + 1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
                # self.write_err_img()

            if mean_loss is None:
                mean_loss = [0 for e in errs]
            for erridx, err in enumerate(errs):
                mean_loss[erridx] += err.item()

            stopwatch.start('data')
        stopwatch.stop('total')
        logging.info('timings: %s' % stopwatch)

        mean_loss = [l / len(train_loader) for l in mean_loss]
        self.callback_train_stop(epoch, mean_loss)
        self.metric_add_train(epoch, 'loss', mean_loss)

        # save metrics
        self.metric_save()

        err_str = self.format_err_str(mean_loss)
        logging.info(f'avg train_loss={err_str}')
        return mean_loss
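
train_epoch and test_epoch accept three return shapes from loss_forward: a single tensor, a list or tuple of loss terms (summed before backward), or a dict with 'errs' and 'masks'. A hedged illustration with invented loss terms and an assumed 'gt' batch key:

def loss_forward(self, output, train):  # sketch only, not part of the commit
    photometric = (output - self.data['gt']).abs().mean()
    smoothness = output.diff(dim=-1).abs().mean()
    # return photometric                       -> single tensor
    # return [photometric, smoothness]         -> list of terms, summed for backward
    return {'errs': [photometric, smoothness], 'masks': [torch.ones_like(output)]}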
    def callback_test_start(self, epoch, set_idx):
        pass

    def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
        pass

    def callback_test_stop(self, epoch, set_idx, loss):
        pass

    def test(self, epoch, net, test_sets):
        errs = {}
        for test_set_idx, test_set in enumerate(test_sets):
            if (epoch + 1) % test_set.test_frequency == 0:
                logging.info('=' * 80)
                logging.info(f'testing set {test_set.name}')
                err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
                errs[test_set.name] = err
        return errs

    def test_epoch(self, epoch, set_idx, net, dset):
        logging.info('-' * 80)
        logging.info('Test epoch %d' % epoch)
        dset.current_epoch = epoch
        test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False,
                                                  num_workers=self.num_workers, drop_last=False, pin_memory=False)

        net = net.to(self.test_device)
        net.eval()

        with torch.no_grad():
            mean_loss = None

            self.callback_test_start(epoch, set_idx)

            bar = ETA(length=len(test_loader))
            stopwatch = StopWatch()
            stopwatch.start('total')
            stopwatch.start('data')
            for batch_idx, data in enumerate(test_loader):
                # if batch_idx == 10: break
                self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
                stopwatch.stop('data')

                stopwatch.start('forward')
                output = self.net_forward(net, train=False)
                if 'cuda' in self.test_device: torch.cuda.synchronize()
                stopwatch.stop('forward')

                stopwatch.start('loss')
                errs = self.loss_forward(output, train=False)
                if isinstance(errs, dict):
                    masks = errs['masks']
                    errs = errs['errs']
                else:
                    masks = []
                if not isinstance(errs, list) and not isinstance(errs, tuple):
                    errs = [errs]

                bar.update(batch_idx)
                if batch_idx % 25 == 0:
                    err_str = self.format_err_str(errs)
                    logging.info(
                        f'test e{epoch}: {batch_idx + 1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')

                if mean_loss is None:
                    mean_loss = [0 for e in errs]
                for erridx, err in enumerate(errs):
                    mean_loss[erridx] += err.item()
                stopwatch.stop('loss')

                self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)

                stopwatch.start('data')
            stopwatch.stop('total')
            logging.info('timings: %s' % stopwatch)

            mean_loss = [l / len(test_loader) for l in mean_loss]
            self.callback_test_stop(epoch, set_idx, mean_loss)
            self.metric_add_test(epoch, set_idx, 'loss', mean_loss)

            # save metrics
            self.metric_save()

            err_str = self.format_err_str(mean_loss)
            logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
            return mean_loss
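
A hedged sketch of how the per-set best state saved during training could be reloaded for evaluation later; `worker` and `net` are assumed to be the objects built by the training script below, and 'test' is a placeholder set name following the state_set{name}_best.dict pattern used in train():

state = torch.load(str(worker.exp_out_root / 'state_settest_best.dict'))
net.load_state_dict(state['state_dict'])
mean_loss = worker.test_epoch(state['epoch'], 0, net, worker.get_test_sets()[0].dset)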

View File

@@ -5,25 +5,24 @@ from model import exp_synphge
from model import networks
from co.args import parse_args

# parse args
args = parse_args()

# loss types
if args.loss == 'ph':
    worker = exp_synph.Worker(args)
elif args.loss == 'phge':
    worker = exp_synphge.Worker(args)

# concatenation of original image and lcn image
channels_in = 2

# set up network
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes,
                                output_ms=worker.ms)

# optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# start the work
worker.do(net, optimizer)