Reformat $EVERYTHING
This commit is contained in:
parent
56f2aa7d5d
commit
43df77fb9b
@ -7,8 +7,9 @@
|
|||||||
# set matplotlib backend depending on env
|
# set matplotlib backend depending on env
|
||||||
import os
|
import os
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
||||||
if os.name == 'posix' and "DISPLAY" not in os.environ:
|
if os.name == 'posix' and "DISPLAY" not in os.environ:
|
||||||
matplotlib.use('Agg')
|
matplotlib.use('Agg')
|
||||||
|
|
||||||
from . import geometry
|
from . import geometry
|
||||||
from . import plt
|
from . import plt
|
||||||
|
@ -12,7 +12,7 @@ def parse_args():
|
|||||||
parser.add_argument('--loss',
|
parser.add_argument('--loss',
|
||||||
help='Train with \'ph\' for the first stage without geometric loss, \
|
help='Train with \'ph\' for the first stage without geometric loss, \
|
||||||
train with \'phge\' for the second stage with geometric loss',
|
train with \'phge\' for the second stage with geometric loss',
|
||||||
default='ph', choices=['ph','phge'], type=str)
|
default='ph', choices=['ph', 'phge'], type=str)
|
||||||
parser.add_argument('--data_type',
|
parser.add_argument('--data_type',
|
||||||
default='syn', choices=['syn'], type=str)
|
default='syn', choices=['syn'], type=str)
|
||||||
#
|
#
|
||||||
@ -66,6 +66,3 @@ def parse_args():
|
|||||||
def get_exp_name(args):
|
def get_exp_name(args):
|
||||||
name = f"exp_{args.data_type}"
|
name = f"exp_{args.data_type}"
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
59
co/cmap.py
59
co/cmap.py
@ -1,19 +1,20 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
_color_map_errors = np.array([
|
_color_map_errors = np.array([
|
||||||
[149, 54, 49], #0: log2(x) = -infinity
|
[149, 54, 49], # 0: log2(x) = -infinity
|
||||||
[180, 117, 69], #0.0625: log2(x) = -4
|
[180, 117, 69], # 0.0625: log2(x) = -4
|
||||||
[209, 173, 116], #0.125: log2(x) = -3
|
[209, 173, 116], # 0.125: log2(x) = -3
|
||||||
[233, 217, 171], #0.25: log2(x) = -2
|
[233, 217, 171], # 0.25: log2(x) = -2
|
||||||
[248, 243, 224], #0.5: log2(x) = -1
|
[248, 243, 224], # 0.5: log2(x) = -1
|
||||||
[144, 224, 254], #1.0: log2(x) = 0
|
[144, 224, 254], # 1.0: log2(x) = 0
|
||||||
[97, 174, 253], #2.0: log2(x) = 1
|
[97, 174, 253], # 2.0: log2(x) = 1
|
||||||
[67, 109, 244], #4.0: log2(x) = 2
|
[67, 109, 244], # 4.0: log2(x) = 2
|
||||||
[39, 48, 215], #8.0: log2(x) = 3
|
[39, 48, 215], # 8.0: log2(x) = 3
|
||||||
[38, 0, 165], #16.0: log2(x) = 4
|
[38, 0, 165], # 16.0: log2(x) = 4
|
||||||
[38, 0, 165] #inf: log2(x) = inf
|
[38, 0, 165] # inf: log2(x) = inf
|
||||||
]).astype(float)
|
]).astype(float)
|
||||||
|
|
||||||
|
|
||||||
def color_error_image(errors, scale=1, mask=None, BGR=True):
|
def color_error_image(errors, scale=1, mask=None, BGR=True):
|
||||||
"""
|
"""
|
||||||
Color an input error map.
|
Color an input error map.
|
||||||
@ -32,26 +33,28 @@ def color_error_image(errors, scale=1, mask=None, BGR=True):
|
|||||||
errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9)
|
errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9)
|
||||||
i0 = np.floor(errors_color_indices).astype(int)
|
i0 = np.floor(errors_color_indices).astype(int)
|
||||||
f1 = errors_color_indices - i0.astype(float)
|
f1 = errors_color_indices - i0.astype(float)
|
||||||
colored_errors_flat = _color_map_errors[i0, :] * (1-f1).reshape(-1,1) + _color_map_errors[i0+1, :] * f1.reshape(-1,1)
|
colored_errors_flat = _color_map_errors[i0, :] * (1 - f1).reshape(-1, 1) + _color_map_errors[i0 + 1,
|
||||||
|
:] * f1.reshape(-1, 1)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
colored_errors_flat[mask.flatten() == 0] = 255
|
colored_errors_flat[mask.flatten() == 0] = 255
|
||||||
|
|
||||||
if not BGR:
|
if not BGR:
|
||||||
colored_errors_flat = colored_errors_flat[:,[2,1,0]]
|
colored_errors_flat = colored_errors_flat[:, [2, 1, 0]]
|
||||||
|
|
||||||
return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int)
|
return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int)
|
||||||
|
|
||||||
|
|
||||||
_color_map_depths = np.array([
|
_color_map_depths = np.array([
|
||||||
[0, 0, 0], # 0.000
|
[0, 0, 0], # 0.000
|
||||||
[0, 0, 255], # 0.114
|
[0, 0, 255], # 0.114
|
||||||
[255, 0, 0], # 0.299
|
[255, 0, 0], # 0.299
|
||||||
[255, 0, 255], # 0.413
|
[255, 0, 255], # 0.413
|
||||||
[0, 255, 0], # 0.587
|
[0, 255, 0], # 0.587
|
||||||
[0, 255, 255], # 0.701
|
[0, 255, 255], # 0.701
|
||||||
[255, 255, 0], # 0.886
|
[255, 255, 0], # 0.886
|
||||||
[255, 255, 255], # 1.000
|
[255, 255, 255], # 1.000
|
||||||
[255, 255, 255], # 1.000
|
[255, 255, 255], # 1.000
|
||||||
]).astype(float)
|
]).astype(float)
|
||||||
_color_map_bincenters = np.array([
|
_color_map_bincenters = np.array([
|
||||||
0.0,
|
0.0,
|
||||||
@ -62,9 +65,10 @@ _color_map_bincenters = np.array([
|
|||||||
0.701,
|
0.701,
|
||||||
0.886,
|
0.886,
|
||||||
1.000,
|
1.000,
|
||||||
2.000, # doesn't make a difference, just strictly higher than 1
|
2.000, # doesn't make a difference, just strictly higher than 1
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
def color_depth_map(depths, scale=None):
|
def color_depth_map(depths, scale=None):
|
||||||
"""
|
"""
|
||||||
Color an input depth map.
|
Color an input depth map.
|
||||||
@ -82,12 +86,13 @@ def color_depth_map(depths, scale=None):
|
|||||||
|
|
||||||
values = np.clip(depths.flatten() / scale, 0, 1)
|
values = np.clip(depths.flatten() / scale, 0, 1)
|
||||||
# for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value?
|
# for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value?
|
||||||
lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1,-1)) * np.arange(0,9)).max(axis=1)
|
lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1, -1)) * np.arange(0, 9)).max(axis=1)
|
||||||
lower_bin_value = _color_map_bincenters[lower_bin]
|
lower_bin_value = _color_map_bincenters[lower_bin]
|
||||||
higher_bin_value = _color_map_bincenters[lower_bin + 1]
|
higher_bin_value = _color_map_bincenters[lower_bin + 1]
|
||||||
alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
|
alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
|
||||||
colors = _color_map_depths[lower_bin] * (1-alphas).reshape(-1,1) + _color_map_depths[lower_bin + 1] * alphas.reshape(-1,1)
|
colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[
|
||||||
|
lower_bin + 1] * alphas.reshape(-1, 1)
|
||||||
return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)
|
return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)
|
||||||
|
|
||||||
#from utils.debug import save_color_numpy
|
# from utils.debug import save_color_numpy
|
||||||
#save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000))
|
# save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000))
|
||||||
|
1302
co/geometry.py
1302
co/geometry.py
File diff suppressed because it is too large
Load Diff
42
co/gtimer.py
42
co/gtimer.py
@ -2,31 +2,37 @@ import numpy as np
|
|||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
||||||
|
|
||||||
class StopWatch(utils.StopWatch):
|
class StopWatch(utils.StopWatch):
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
print('='*80)
|
print('=' * 80)
|
||||||
print('gtimer:')
|
print('gtimer:')
|
||||||
total = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.sum).items()])
|
total = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.sum).items()])
|
||||||
print(f' [total] {total}')
|
print(f' [total] {total}')
|
||||||
mean = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.mean).items()])
|
mean = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.mean).items()])
|
||||||
print(f' [mean] {mean}')
|
print(f' [mean] {mean}')
|
||||||
median = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.median).items()])
|
median = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.median).items()])
|
||||||
print(f' [median] {median}')
|
print(f' [median] {median}')
|
||||||
print('='*80)
|
print('=' * 80)
|
||||||
|
|
||||||
|
|
||||||
GTIMER = StopWatch()
|
GTIMER = StopWatch()
|
||||||
|
|
||||||
|
|
||||||
def start(name):
|
def start(name):
|
||||||
GTIMER.start(name)
|
GTIMER.start(name)
|
||||||
|
|
||||||
|
|
||||||
def stop(name):
|
def stop(name):
|
||||||
GTIMER.stop(name)
|
GTIMER.stop(name)
|
||||||
|
|
||||||
|
|
||||||
class Ctx(object):
|
class Ctx(object):
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
start(self.name)
|
start(self.name)
|
||||||
|
|
||||||
def __exit__(self, *args):
|
def __exit__(self, *args):
|
||||||
stop(self.name)
|
stop(self.name)
|
||||||
|
475
co/io3d.py
475
co/io3d.py
@ -2,266 +2,273 @@ import struct
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
def _write_ply_point(fp, x,y,z, color=None, normal=None, binary=False):
|
|
||||||
args = [x,y,z]
|
|
||||||
if color is not None:
|
|
||||||
args += [int(color[0]), int(color[1]), int(color[2])]
|
|
||||||
if normal is not None:
|
|
||||||
args += [normal[0],normal[1],normal[2]]
|
|
||||||
if binary:
|
|
||||||
fmt = '<fff'
|
|
||||||
if color is not None:
|
|
||||||
fmt = fmt + 'BBB'
|
|
||||||
if normal is not None:
|
|
||||||
fmt = fmt + 'fff'
|
|
||||||
fp.write(struct.pack(fmt, *args))
|
|
||||||
else:
|
|
||||||
fmt = '%f %f %f'
|
|
||||||
if color is not None:
|
|
||||||
fmt = fmt + ' %d %d %d'
|
|
||||||
if normal is not None:
|
|
||||||
fmt = fmt + ' %f %f %f'
|
|
||||||
fmt += '\n'
|
|
||||||
fp.write(fmt % tuple(args))
|
|
||||||
|
|
||||||
def _write_ply_triangle(fp, i0,i1,i2, binary):
|
def _write_ply_point(fp, x, y, z, color=None, normal=None, binary=False):
|
||||||
if binary:
|
args = [x, y, z]
|
||||||
fp.write(struct.pack('<Biii', 3,i0,i1,i2))
|
if color is not None:
|
||||||
else:
|
args += [int(color[0]), int(color[1]), int(color[2])]
|
||||||
fp.write('3 %d %d %d\n' % (i0,i1,i2))
|
if normal is not None:
|
||||||
|
args += [normal[0], normal[1], normal[2]]
|
||||||
|
if binary:
|
||||||
|
fmt = '<fff'
|
||||||
|
if color is not None:
|
||||||
|
fmt = fmt + 'BBB'
|
||||||
|
if normal is not None:
|
||||||
|
fmt = fmt + 'fff'
|
||||||
|
fp.write(struct.pack(fmt, *args))
|
||||||
|
else:
|
||||||
|
fmt = '%f %f %f'
|
||||||
|
if color is not None:
|
||||||
|
fmt = fmt + ' %d %d %d'
|
||||||
|
if normal is not None:
|
||||||
|
fmt = fmt + ' %f %f %f'
|
||||||
|
fmt += '\n'
|
||||||
|
fp.write(fmt % tuple(args))
|
||||||
|
|
||||||
|
|
||||||
|
def _write_ply_triangle(fp, i0, i1, i2, binary):
|
||||||
|
if binary:
|
||||||
|
fp.write(struct.pack('<Biii', 3, i0, i1, i2))
|
||||||
|
else:
|
||||||
|
fp.write('3 %d %d %d\n' % (i0, i1, i2))
|
||||||
|
|
||||||
|
|
||||||
def _write_ply_header_line(fp, str, binary):
|
def _write_ply_header_line(fp, str, binary):
|
||||||
if binary:
|
if binary:
|
||||||
fp.write(str.encode())
|
fp.write(str.encode())
|
||||||
else:
|
else:
|
||||||
fp.write(str)
|
fp.write(str)
|
||||||
|
|
||||||
|
|
||||||
def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
|
def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
|
||||||
if verts.shape[1] != 3:
|
if verts.shape[1] != 3:
|
||||||
raise Exception('verts has to be of shape Nx3')
|
raise Exception('verts has to be of shape Nx3')
|
||||||
if trias is not None and trias.shape[1] != 3:
|
if trias is not None and trias.shape[1] != 3:
|
||||||
raise Exception('trias has to be of shape Nx3')
|
raise Exception('trias has to be of shape Nx3')
|
||||||
if color is not None and not callable(color) and not isinstance(color, np.ndarray) and color.shape[1] != 3:
|
if color is not None and not callable(color) and not isinstance(color, np.ndarray) and color.shape[1] != 3:
|
||||||
raise Exception('color has to be of shape Nx3 or a callable')
|
raise Exception('color has to be of shape Nx3 or a callable')
|
||||||
|
|
||||||
mode = 'wb' if binary else 'w'
|
mode = 'wb' if binary else 'w'
|
||||||
with open(path, mode) as fp:
|
with open(path, mode) as fp:
|
||||||
_write_ply_header_line(fp, "ply\n", binary)
|
_write_ply_header_line(fp, "ply\n", binary)
|
||||||
if binary:
|
if binary:
|
||||||
_write_ply_header_line(fp, "format binary_little_endian 1.0\n", binary)
|
_write_ply_header_line(fp, "format binary_little_endian 1.0\n", binary)
|
||||||
else:
|
|
||||||
_write_ply_header_line(fp, "format ascii 1.0\n", binary)
|
|
||||||
_write_ply_header_line(fp, "element vertex %d\n" % (verts.shape[0]), binary)
|
|
||||||
_write_ply_header_line(fp, "property float32 x\n", binary)
|
|
||||||
_write_ply_header_line(fp, "property float32 y\n", binary)
|
|
||||||
_write_ply_header_line(fp, "property float32 z\n", binary)
|
|
||||||
if color is not None:
|
|
||||||
_write_ply_header_line(fp, "property uchar red\n", binary)
|
|
||||||
_write_ply_header_line(fp, "property uchar green\n", binary)
|
|
||||||
_write_ply_header_line(fp, "property uchar blue\n", binary)
|
|
||||||
if normals is not None:
|
|
||||||
_write_ply_header_line(fp, "property float32 nx\n", binary)
|
|
||||||
_write_ply_header_line(fp, "property float32 ny\n", binary)
|
|
||||||
_write_ply_header_line(fp, "property float32 nz\n", binary)
|
|
||||||
if trias is not None:
|
|
||||||
_write_ply_header_line(fp, "element face %d\n" % (trias.shape[0]), binary)
|
|
||||||
_write_ply_header_line(fp, "property list uchar int32 vertex_indices\n", binary)
|
|
||||||
_write_ply_header_line(fp, "end_header\n", binary)
|
|
||||||
|
|
||||||
for vidx, v in enumerate(verts):
|
|
||||||
if color is not None:
|
|
||||||
if callable(color):
|
|
||||||
c = color(vidx)
|
|
||||||
elif color.shape[0] > 1:
|
|
||||||
c = color[vidx]
|
|
||||||
else:
|
else:
|
||||||
c = color[0]
|
_write_ply_header_line(fp, "format ascii 1.0\n", binary)
|
||||||
else:
|
_write_ply_header_line(fp, "element vertex %d\n" % (verts.shape[0]), binary)
|
||||||
c = None
|
_write_ply_header_line(fp, "property float32 x\n", binary)
|
||||||
if normals is None:
|
_write_ply_header_line(fp, "property float32 y\n", binary)
|
||||||
n = None
|
_write_ply_header_line(fp, "property float32 z\n", binary)
|
||||||
else:
|
if color is not None:
|
||||||
n = normals[vidx]
|
_write_ply_header_line(fp, "property uchar red\n", binary)
|
||||||
_write_ply_point(fp, v[0],v[1],v[2], c, n, binary)
|
_write_ply_header_line(fp, "property uchar green\n", binary)
|
||||||
|
_write_ply_header_line(fp, "property uchar blue\n", binary)
|
||||||
|
if normals is not None:
|
||||||
|
_write_ply_header_line(fp, "property float32 nx\n", binary)
|
||||||
|
_write_ply_header_line(fp, "property float32 ny\n", binary)
|
||||||
|
_write_ply_header_line(fp, "property float32 nz\n", binary)
|
||||||
|
if trias is not None:
|
||||||
|
_write_ply_header_line(fp, "element face %d\n" % (trias.shape[0]), binary)
|
||||||
|
_write_ply_header_line(fp, "property list uchar int32 vertex_indices\n", binary)
|
||||||
|
_write_ply_header_line(fp, "end_header\n", binary)
|
||||||
|
|
||||||
|
for vidx, v in enumerate(verts):
|
||||||
|
if color is not None:
|
||||||
|
if callable(color):
|
||||||
|
c = color(vidx)
|
||||||
|
elif color.shape[0] > 1:
|
||||||
|
c = color[vidx]
|
||||||
|
else:
|
||||||
|
c = color[0]
|
||||||
|
else:
|
||||||
|
c = None
|
||||||
|
if normals is None:
|
||||||
|
n = None
|
||||||
|
else:
|
||||||
|
n = normals[vidx]
|
||||||
|
_write_ply_point(fp, v[0], v[1], v[2], c, n, binary)
|
||||||
|
|
||||||
|
if trias is not None:
|
||||||
|
for t in trias:
|
||||||
|
_write_ply_triangle(fp, t[0], t[1], t[2], binary)
|
||||||
|
|
||||||
if trias is not None:
|
|
||||||
for t in trias:
|
|
||||||
_write_ply_triangle(fp, t[0],t[1],t[2], binary)
|
|
||||||
|
|
||||||
def faces_to_triangles(faces):
|
def faces_to_triangles(faces):
|
||||||
new_faces = []
|
new_faces = []
|
||||||
for f in faces:
|
for f in faces:
|
||||||
if f[0] == 3:
|
if f[0] == 3:
|
||||||
new_faces.append([f[1], f[2], f[3]])
|
new_faces.append([f[1], f[2], f[3]])
|
||||||
elif f[0] == 4:
|
elif f[0] == 4:
|
||||||
new_faces.append([f[1], f[2], f[3]])
|
new_faces.append([f[1], f[2], f[3]])
|
||||||
new_faces.append([f[3], f[4], f[1]])
|
new_faces.append([f[3], f[4], f[1]])
|
||||||
else:
|
else:
|
||||||
raise Exception('unknown face count %d', f[0])
|
raise Exception('unknown face count %d', f[0])
|
||||||
return new_faces
|
return new_faces
|
||||||
|
|
||||||
|
|
||||||
def read_ply(path):
|
def read_ply(path):
|
||||||
with open(path, 'rb') as f:
|
with open(path, 'rb') as f:
|
||||||
# parse header
|
# parse header
|
||||||
line = f.readline().decode().strip()
|
line = f.readline().decode().strip()
|
||||||
if line != 'ply':
|
if line != 'ply':
|
||||||
raise Exception('Header error')
|
raise Exception('Header error')
|
||||||
n_verts = 0
|
n_verts = 0
|
||||||
n_faces = 0
|
n_faces = 0
|
||||||
vert_types = {}
|
vert_types = {}
|
||||||
vert_bin_format = []
|
vert_bin_format = []
|
||||||
vert_bin_len = 0
|
vert_bin_len = 0
|
||||||
vert_bin_cols = 0
|
vert_bin_cols = 0
|
||||||
line = f.readline().decode()
|
line = f.readline().decode()
|
||||||
parse_vertex_prop = False
|
|
||||||
while line.strip() != 'end_header':
|
|
||||||
if 'format' in line:
|
|
||||||
if 'ascii' in line:
|
|
||||||
binary = False
|
|
||||||
elif 'binary_little_endian' in line:
|
|
||||||
binary = True
|
|
||||||
else:
|
|
||||||
raise Exception('invalid ply format')
|
|
||||||
if 'element face' in line:
|
|
||||||
splits = line.strip().split(' ')
|
|
||||||
n_faces = int(splits[-1])
|
|
||||||
parse_vertex_prop = False
|
parse_vertex_prop = False
|
||||||
if 'element camera' in line:
|
while line.strip() != 'end_header':
|
||||||
parse_vertex_prop = False
|
if 'format' in line:
|
||||||
if 'element vertex' in line:
|
if 'ascii' in line:
|
||||||
splits = line.strip().split(' ')
|
binary = False
|
||||||
n_verts = int(splits[-1])
|
elif 'binary_little_endian' in line:
|
||||||
parse_vertex_prop = True
|
binary = True
|
||||||
if parse_vertex_prop and 'property' in line:
|
else:
|
||||||
prop = line.strip().split()
|
raise Exception('invalid ply format')
|
||||||
if prop[1] == 'float':
|
if 'element face' in line:
|
||||||
vert_bin_format.append('f4')
|
splits = line.strip().split(' ')
|
||||||
vert_bin_len += 4
|
n_faces = int(splits[-1])
|
||||||
vert_bin_cols += 1
|
parse_vertex_prop = False
|
||||||
elif prop[1] == 'uchar':
|
if 'element camera' in line:
|
||||||
vert_bin_format.append('B')
|
parse_vertex_prop = False
|
||||||
vert_bin_len += 1
|
if 'element vertex' in line:
|
||||||
vert_bin_cols += 1
|
splits = line.strip().split(' ')
|
||||||
|
n_verts = int(splits[-1])
|
||||||
|
parse_vertex_prop = True
|
||||||
|
if parse_vertex_prop and 'property' in line:
|
||||||
|
prop = line.strip().split()
|
||||||
|
if prop[1] == 'float':
|
||||||
|
vert_bin_format.append('f4')
|
||||||
|
vert_bin_len += 4
|
||||||
|
vert_bin_cols += 1
|
||||||
|
elif prop[1] == 'uchar':
|
||||||
|
vert_bin_format.append('B')
|
||||||
|
vert_bin_len += 1
|
||||||
|
vert_bin_cols += 1
|
||||||
|
else:
|
||||||
|
raise Exception('invalid property')
|
||||||
|
vert_types[prop[2]] = len(vert_types)
|
||||||
|
line = f.readline().decode()
|
||||||
|
|
||||||
|
# parse content
|
||||||
|
if binary:
|
||||||
|
sz = n_verts * vert_bin_len
|
||||||
|
fmt = ','.join(vert_bin_format)
|
||||||
|
verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz))
|
||||||
|
verts = verts[0].astype(vert_bin_cols * 'f4,').view(dtype='f4').reshape((n_verts, -1))
|
||||||
|
faces = []
|
||||||
|
for idx in range(n_faces):
|
||||||
|
fmt = '<Biii'
|
||||||
|
length = struct.calcsize(fmt)
|
||||||
|
dat = f.read(length)
|
||||||
|
vals = struct.unpack(fmt, dat)
|
||||||
|
faces.append(vals)
|
||||||
|
faces = faces_to_triangles(faces)
|
||||||
|
faces = np.array(faces, dtype=np.int32)
|
||||||
else:
|
else:
|
||||||
raise Exception('invalid property')
|
verts = []
|
||||||
vert_types[prop[2]] = len(vert_types)
|
for idx in range(n_verts):
|
||||||
line = f.readline().decode()
|
vals = [float(v) for v in f.readline().decode().strip().split(' ')]
|
||||||
|
verts.append(vals)
|
||||||
|
verts = np.array(verts, dtype=np.float32)
|
||||||
|
faces = []
|
||||||
|
for idx in range(n_faces):
|
||||||
|
splits = f.readline().decode().strip().split(' ')
|
||||||
|
n_face_verts = int(splits[0])
|
||||||
|
vals = [int(v) for v in splits[0:n_face_verts + 1]]
|
||||||
|
faces.append(vals)
|
||||||
|
faces = faces_to_triangles(faces)
|
||||||
|
faces = np.array(faces, dtype=np.int32)
|
||||||
|
|
||||||
# parse content
|
xyz = None
|
||||||
if binary:
|
if 'x' in vert_types and 'y' in vert_types and 'z' in vert_types:
|
||||||
sz = n_verts * vert_bin_len
|
xyz = verts[:, [vert_types['x'], vert_types['y'], vert_types['z']]]
|
||||||
fmt = ','.join(vert_bin_format)
|
colors = None
|
||||||
verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz))
|
if 'red' in vert_types and 'green' in vert_types and 'blue' in vert_types:
|
||||||
verts = verts[0].astype(vert_bin_cols*'f4,').view(dtype='f4').reshape((n_verts,-1))
|
colors = verts[:, [vert_types['red'], vert_types['green'], vert_types['blue']]]
|
||||||
faces = []
|
colors /= 255
|
||||||
for idx in range(n_faces):
|
normals = None
|
||||||
fmt = '<Biii'
|
if 'nx' in vert_types and 'ny' in vert_types and 'nz' in vert_types:
|
||||||
length = struct.calcsize(fmt)
|
normals = verts[:, [vert_types['nx'], vert_types['ny'], vert_types['nz']]]
|
||||||
dat = f.read(length)
|
|
||||||
vals = struct.unpack(fmt, dat)
|
|
||||||
faces.append(vals)
|
|
||||||
faces = faces_to_triangles(faces)
|
|
||||||
faces = np.array(faces, dtype=np.int32)
|
|
||||||
else:
|
|
||||||
verts = []
|
|
||||||
for idx in range(n_verts):
|
|
||||||
vals = [float(v) for v in f.readline().decode().strip().split(' ')]
|
|
||||||
verts.append(vals)
|
|
||||||
verts = np.array(verts, dtype=np.float32)
|
|
||||||
faces = []
|
|
||||||
for idx in range(n_faces):
|
|
||||||
splits = f.readline().decode().strip().split(' ')
|
|
||||||
n_face_verts = int(splits[0])
|
|
||||||
vals = [int(v) for v in splits[0:n_face_verts+1]]
|
|
||||||
faces.append(vals)
|
|
||||||
faces = faces_to_triangles(faces)
|
|
||||||
faces = np.array(faces, dtype=np.int32)
|
|
||||||
|
|
||||||
xyz = None
|
return xyz, faces, colors, normals
|
||||||
if 'x' in vert_types and 'y' in vert_types and 'z' in vert_types:
|
|
||||||
xyz = verts[:,[vert_types['x'], vert_types['y'], vert_types['z']]]
|
|
||||||
colors = None
|
|
||||||
if 'red' in vert_types and 'green' in vert_types and 'blue' in vert_types:
|
|
||||||
colors = verts[:,[vert_types['red'], vert_types['green'], vert_types['blue']]]
|
|
||||||
colors /= 255
|
|
||||||
normals = None
|
|
||||||
if 'nx' in vert_types and 'ny' in vert_types and 'nz' in vert_types:
|
|
||||||
normals = verts[:,[vert_types['nx'], vert_types['ny'], vert_types['nz']]]
|
|
||||||
|
|
||||||
return xyz, faces, colors, normals
|
|
||||||
|
|
||||||
|
|
||||||
def _read_obj_split_f(s):
|
def _read_obj_split_f(s):
|
||||||
parts = s.split('/')
|
parts = s.split('/')
|
||||||
vidx = int(parts[0]) - 1
|
vidx = int(parts[0]) - 1
|
||||||
if len(parts) >= 2 and len(parts[1]) > 0:
|
if len(parts) >= 2 and len(parts[1]) > 0:
|
||||||
tidx = int(parts[1]) - 1
|
tidx = int(parts[1]) - 1
|
||||||
else:
|
else:
|
||||||
tidx = -1
|
tidx = -1
|
||||||
if len(parts) >= 3 and len(parts[2]) > 0:
|
if len(parts) >= 3 and len(parts[2]) > 0:
|
||||||
nidx = int(parts[2]) - 1
|
nidx = int(parts[2]) - 1
|
||||||
else:
|
else:
|
||||||
nidx = -1
|
nidx = -1
|
||||||
return vidx, tidx, nidx
|
return vidx, tidx, nidx
|
||||||
|
|
||||||
|
|
||||||
def read_obj(path):
|
def read_obj(path):
|
||||||
with open(path, 'r') as fp:
|
with open(path, 'r') as fp:
|
||||||
lines = fp.readlines()
|
lines = fp.readlines()
|
||||||
|
|
||||||
verts = []
|
verts = []
|
||||||
colors = []
|
colors = []
|
||||||
fnorms = []
|
fnorms = []
|
||||||
fnorm_map = collections.defaultdict(list)
|
fnorm_map = collections.defaultdict(list)
|
||||||
faces = []
|
faces = []
|
||||||
for line in lines:
|
for line in lines:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if line.startswith('#') or len(line) == 0:
|
if line.startswith('#') or len(line) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
parts = line.split()
|
parts = line.split()
|
||||||
if line.startswith('v '):
|
if line.startswith('v '):
|
||||||
parts = parts[1:]
|
parts = parts[1:]
|
||||||
x,y,z = float(parts[0]), float(parts[1]), float(parts[2])
|
x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
|
||||||
if len(parts) == 4 or len(parts) == 7:
|
if len(parts) == 4 or len(parts) == 7:
|
||||||
w = float(parts[3])
|
w = float(parts[3])
|
||||||
x,y,z = x/w, y/w, z/w
|
x, y, z = x / w, y / w, z / w
|
||||||
verts.append((x,y,z))
|
verts.append((x, y, z))
|
||||||
if len(parts) >= 6:
|
if len(parts) >= 6:
|
||||||
r,g,b = float(parts[-3]), float(parts[-2]), float(parts[-1])
|
r, g, b = float(parts[-3]), float(parts[-2]), float(parts[-1])
|
||||||
rgb.append((r,g,b))
|
rgb.append((r, g, b))
|
||||||
|
|
||||||
elif line.startswith('vn '):
|
elif line.startswith('vn '):
|
||||||
parts = parts[1:]
|
parts = parts[1:]
|
||||||
x,y,z = float(parts[0]), float(parts[1]), float(parts[2])
|
x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
|
||||||
fnorms.append((x,y,z))
|
fnorms.append((x, y, z))
|
||||||
|
|
||||||
elif line.startswith('f '):
|
elif line.startswith('f '):
|
||||||
parts = parts[1:]
|
parts = parts[1:]
|
||||||
if len(parts) != 3:
|
if len(parts) != 3:
|
||||||
raise Exception('only triangle meshes supported atm')
|
raise Exception('only triangle meshes supported atm')
|
||||||
vidx0, tidx0, nidx0 = _read_obj_split_f(parts[0])
|
vidx0, tidx0, nidx0 = _read_obj_split_f(parts[0])
|
||||||
vidx1, tidx1, nidx1 = _read_obj_split_f(parts[1])
|
vidx1, tidx1, nidx1 = _read_obj_split_f(parts[1])
|
||||||
vidx2, tidx2, nidx2 = _read_obj_split_f(parts[2])
|
vidx2, tidx2, nidx2 = _read_obj_split_f(parts[2])
|
||||||
|
|
||||||
faces.append((vidx0, vidx1, vidx2))
|
faces.append((vidx0, vidx1, vidx2))
|
||||||
if nidx0 >= 0:
|
if nidx0 >= 0:
|
||||||
fnorm_map[vidx0].append( nidx0 )
|
fnorm_map[vidx0].append(nidx0)
|
||||||
if nidx1 >= 0:
|
if nidx1 >= 0:
|
||||||
fnorm_map[vidx1].append( nidx1 )
|
fnorm_map[vidx1].append(nidx1)
|
||||||
if nidx2 >= 0:
|
if nidx2 >= 0:
|
||||||
fnorm_map[vidx2].append( nidx2 )
|
fnorm_map[vidx2].append(nidx2)
|
||||||
|
|
||||||
verts = np.array(verts)
|
verts = np.array(verts)
|
||||||
colors = np.array(colors)
|
colors = np.array(colors)
|
||||||
fnorms = np.array(fnorms)
|
fnorms = np.array(fnorms)
|
||||||
faces = np.array(faces)
|
faces = np.array(faces)
|
||||||
|
|
||||||
# face normals to vertex normals
|
# face normals to vertex normals
|
||||||
norms = np.zeros_like(verts)
|
norms = np.zeros_like(verts)
|
||||||
for vidx in fnorm_map.keys():
|
for vidx in fnorm_map.keys():
|
||||||
ind = fnorm_map[vidx]
|
ind = fnorm_map[vidx]
|
||||||
norms[vidx] = fnorms[ind].sum(axis=0)
|
norms[vidx] = fnorms[ind].sum(axis=0)
|
||||||
N = np.linalg.norm(norms, axis=1, keepdims=True)
|
N = np.linalg.norm(norms, axis=1, keepdims=True)
|
||||||
np.divide(norms, N, out=norms, where=N != 0)
|
np.divide(norms, N, out=norms, where=N != 0)
|
||||||
|
|
||||||
return verts, faces, colors, norms
|
return verts, faces, colors, norms
|
||||||
|
392
co/metric.py
392
co/metric.py
@ -1,248 +1,260 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from . import geometry
|
from . import geometry
|
||||||
|
|
||||||
|
|
||||||
def _process_inputs(estimate, target, mask):
|
def _process_inputs(estimate, target, mask):
|
||||||
if estimate.shape != target.shape:
|
if estimate.shape != target.shape:
|
||||||
raise Exception('estimate and target have to be same shape')
|
raise Exception('estimate and target have to be same shape')
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = np.ones(estimate.shape, dtype=np.bool)
|
mask = np.ones(estimate.shape, dtype=np.bool)
|
||||||
else:
|
else:
|
||||||
mask = mask != 0
|
mask = mask != 0
|
||||||
if estimate.shape != mask.shape:
|
if estimate.shape != mask.shape:
|
||||||
raise Exception('estimate and mask have to be same shape')
|
raise Exception('estimate and mask have to be same shape')
|
||||||
return estimate, target, mask
|
return estimate, target, mask
|
||||||
|
|
||||||
|
|
||||||
def mse(estimate, target, mask=None):
|
def mse(estimate, target, mask=None):
|
||||||
estimate, target, mask = _process_inputs(estimate, target, mask)
|
estimate, target, mask = _process_inputs(estimate, target, mask)
|
||||||
m = np.sum((estimate[mask] - target[mask])**2) / mask.sum()
|
m = np.sum((estimate[mask] - target[mask]) ** 2) / mask.sum()
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
def rmse(estimate, target, mask=None):
|
def rmse(estimate, target, mask=None):
|
||||||
return np.sqrt(mse(estimate, target, mask))
|
return np.sqrt(mse(estimate, target, mask))
|
||||||
|
|
||||||
|
|
||||||
def mae(estimate, target, mask=None):
|
def mae(estimate, target, mask=None):
|
||||||
estimate, target, mask = _process_inputs(estimate, target, mask)
|
estimate, target, mask = _process_inputs(estimate, target, mask)
|
||||||
m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
|
m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
def outlier_fraction(estimate, target, mask=None, threshold=0):
|
def outlier_fraction(estimate, target, mask=None, threshold=0):
|
||||||
estimate, target, mask = _process_inputs(estimate, target, mask)
|
estimate, target, mask = _process_inputs(estimate, target, mask)
|
||||||
diff = np.abs(estimate[mask] - target[mask])
|
diff = np.abs(estimate[mask] - target[mask])
|
||||||
m = (diff > threshold).sum() / mask.sum()
|
m = (diff > threshold).sum() / mask.sum()
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
class Metric(object):
|
class Metric(object):
|
||||||
def __init__(self, str_prefix=''):
|
def __init__(self, str_prefix=''):
|
||||||
self.str_prefix = str_prefix
|
self.str_prefix = str_prefix
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def add(self, es, ta, ma=None):
|
def add(self, es, ta, ma=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
return self.get().items()
|
return self.get().items()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])
|
|
||||||
|
|
||||||
class MultipleMetric(Metric):
|
class MultipleMetric(Metric):
|
||||||
def __init__(self, *metrics, **kwargs):
|
def __init__(self, *metrics, **kwargs):
|
||||||
self.metrics = [*metrics]
|
self.metrics = [*metrics]
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
for m in self.metrics:
|
for m in self.metrics:
|
||||||
m.reset()
|
m.reset()
|
||||||
|
|
||||||
def add(self, es, ta, ma=None):
|
def add(self, es, ta, ma=None):
|
||||||
for m in self.metrics:
|
for m in self.metrics:
|
||||||
m.add(es, ta, ma)
|
m.add(es, ta, ma)
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
ret = {}
|
ret = {}
|
||||||
for m in self.metrics:
|
for m in self.metrics:
|
||||||
vals = m.get()
|
vals = m.get()
|
||||||
for k in vals:
|
for k in vals:
|
||||||
ret[k] = vals[k]
|
ret[k] = vals[k]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return '\n'.join([str(m) for m in self.metrics])
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return '\n'.join([str(m) for m in self.metrics])
|
|
||||||
|
|
||||||
class BaseDistanceMetric(Metric):
|
class BaseDistanceMetric(Metric):
|
||||||
def __init__(self, name='', **kwargs):
|
def __init__(self, name='', **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.dists = []
|
self.dists = []
|
||||||
|
|
||||||
def add(self, es, ta, ma=None):
|
def add(self, es, ta, ma=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
dists = np.hstack(self.dists)
|
||||||
|
return {
|
||||||
|
f'dist{self.name}_mean': float(np.mean(dists)),
|
||||||
|
f'dist{self.name}_std': float(np.std(dists)),
|
||||||
|
f'dist{self.name}_median': float(np.median(dists)),
|
||||||
|
f'dist{self.name}_q10': float(np.percentile(dists, 10)),
|
||||||
|
f'dist{self.name}_q90': float(np.percentile(dists, 90)),
|
||||||
|
f'dist{self.name}_min': float(np.min(dists)),
|
||||||
|
f'dist{self.name}_max': float(np.max(dists)),
|
||||||
|
}
|
||||||
|
|
||||||
def get(self):
|
|
||||||
dists = np.hstack(self.dists)
|
|
||||||
return {
|
|
||||||
f'dist{self.name}_mean': float(np.mean(dists)),
|
|
||||||
f'dist{self.name}_std': float(np.std(dists)),
|
|
||||||
f'dist{self.name}_median': float(np.median(dists)),
|
|
||||||
f'dist{self.name}_q10': float(np.percentile(dists, 10)),
|
|
||||||
f'dist{self.name}_q90': float(np.percentile(dists, 90)),
|
|
||||||
f'dist{self.name}_min': float(np.min(dists)),
|
|
||||||
f'dist{self.name}_max': float(np.max(dists)),
|
|
||||||
}
|
|
||||||
|
|
||||||
class DistanceMetric(BaseDistanceMetric):
|
class DistanceMetric(BaseDistanceMetric):
|
||||||
def __init__(self, vec_length, p=2, **kwargs):
|
def __init__(self, vec_length, p=2, **kwargs):
|
||||||
super().__init__(name=f'{p}', **kwargs)
|
super().__init__(name=f'{p}', **kwargs)
|
||||||
self.vec_length = vec_length
|
self.vec_length = vec_length
|
||||||
self.p = p
|
self.p = p
|
||||||
|
|
||||||
|
def add(self, es, ta, ma=None):
|
||||||
|
if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
|
||||||
|
print(es.shape, ta.shape)
|
||||||
|
raise Exception('es and ta have to be of shape Nxdim')
|
||||||
|
if ma is not None:
|
||||||
|
es = es[ma != 0]
|
||||||
|
ta = ta[ma != 0]
|
||||||
|
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
|
||||||
|
self.dists.append(dist)
|
||||||
|
|
||||||
def add(self, es, ta, ma=None):
|
|
||||||
if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
|
|
||||||
print(es.shape, ta.shape)
|
|
||||||
raise Exception('es and ta have to be of shape Nxdim')
|
|
||||||
if ma is not None:
|
|
||||||
es = es[ma != 0]
|
|
||||||
ta = ta[ma != 0]
|
|
||||||
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
|
|
||||||
self.dists.append( dist )
|
|
||||||
|
|
||||||
class OutlierFractionMetric(DistanceMetric):
|
class OutlierFractionMetric(DistanceMetric):
|
||||||
def __init__(self, thresholds, *args, **kwargs):
|
def __init__(self, thresholds, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.thresholds = thresholds
|
self.thresholds = thresholds
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
dists = np.hstack(self.dists)
|
||||||
|
ret = {}
|
||||||
|
for t in self.thresholds:
|
||||||
|
ma = dists > t
|
||||||
|
ret[f'of{t}'] = float(ma.sum() / ma.size)
|
||||||
|
return ret
|
||||||
|
|
||||||
def get(self):
|
|
||||||
dists = np.hstack(self.dists)
|
|
||||||
ret = {}
|
|
||||||
for t in self.thresholds:
|
|
||||||
ma = dists > t
|
|
||||||
ret[f'of{t}'] = float(ma.sum() / ma.size)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
class RelativeDistanceMetric(BaseDistanceMetric):
|
class RelativeDistanceMetric(BaseDistanceMetric):
|
||||||
def __init__(self, vec_length, p=2, **kwargs):
|
def __init__(self, vec_length, p=2, **kwargs):
|
||||||
super().__init__(name=f'rel{p}', **kwargs)
|
super().__init__(name=f'rel{p}', **kwargs)
|
||||||
self.vec_length = vec_length
|
self.vec_length = vec_length
|
||||||
self.p = p
|
self.p = p
|
||||||
|
|
||||||
|
def add(self, es, ta, ma=None):
|
||||||
|
if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
|
||||||
|
raise Exception('es and ta have to be of shape Nxdim')
|
||||||
|
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
|
||||||
|
denom = np.linalg.norm(ta, ord=self.p, axis=1)
|
||||||
|
dist /= denom
|
||||||
|
if ma is not None:
|
||||||
|
dist = dist[ma != 0]
|
||||||
|
self.dists.append(dist)
|
||||||
|
|
||||||
def add(self, es, ta, ma=None):
|
|
||||||
if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
|
|
||||||
raise Exception('es and ta have to be of shape Nxdim')
|
|
||||||
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
|
|
||||||
denom = np.linalg.norm(ta, ord=self.p, axis=1)
|
|
||||||
dist /= denom
|
|
||||||
if ma is not None:
|
|
||||||
dist = dist[ma != 0]
|
|
||||||
self.dists.append( dist )
|
|
||||||
|
|
||||||
class RotmDistanceMetric(BaseDistanceMetric):
|
class RotmDistanceMetric(BaseDistanceMetric):
|
||||||
def __init__(self, type='identity', **kwargs):
|
def __init__(self, type='identity', **kwargs):
|
||||||
super().__init__(name=type, **kwargs)
|
super().__init__(name=type, **kwargs)
|
||||||
self.type = type
|
self.type = type
|
||||||
|
|
||||||
|
def add(self, es, ta, ma=None):
|
||||||
|
if es.shape != ta.shape or es.shape[1] != 3 or es.shape[2] != 3 or es.ndim != 3:
|
||||||
|
print(es.shape, ta.shape)
|
||||||
|
raise Exception('es and ta have to be of shape Nx3x3')
|
||||||
|
if ma is not None:
|
||||||
|
raise Exception('mask is not implemented')
|
||||||
|
if self.type == 'identity':
|
||||||
|
self.dists.append(geometry.rotm_distance_identity(es, ta))
|
||||||
|
elif self.type == 'geodesic':
|
||||||
|
self.dists.append(geometry.rotm_distance_geodesic_unit_sphere(es, ta))
|
||||||
|
else:
|
||||||
|
raise Exception('invalid distance type')
|
||||||
|
|
||||||
def add(self, es, ta, ma=None):
|
|
||||||
if es.shape != ta.shape or es.shape[1] != 3 or es.shape[2] != 3 or es.ndim != 3:
|
|
||||||
print(es.shape, ta.shape)
|
|
||||||
raise Exception('es and ta have to be of shape Nx3x3')
|
|
||||||
if ma is not None:
|
|
||||||
raise Exception('mask is not implemented')
|
|
||||||
if self.type == 'identity':
|
|
||||||
self.dists.append( geometry.rotm_distance_identity(es, ta) )
|
|
||||||
elif self.type == 'geodesic':
|
|
||||||
self.dists.append( geometry.rotm_distance_geodesic_unit_sphere(es, ta) )
|
|
||||||
else:
|
|
||||||
raise Exception('invalid distance type')
|
|
||||||
|
|
||||||
class QuaternionDistanceMetric(BaseDistanceMetric):
|
class QuaternionDistanceMetric(BaseDistanceMetric):
|
||||||
def __init__(self, type='angle', **kwargs):
|
def __init__(self, type='angle', **kwargs):
|
||||||
super().__init__(name=type, **kwargs)
|
super().__init__(name=type, **kwargs)
|
||||||
self.type = type
|
self.type = type
|
||||||
|
|
||||||
def add(self, es, ta, ma=None):
|
def add(self, es, ta, ma=None):
|
||||||
if es.shape != ta.shape or es.shape[1] != 4 or es.ndim != 2:
|
if es.shape != ta.shape or es.shape[1] != 4 or es.ndim != 2:
|
||||||
print(es.shape, ta.shape)
|
print(es.shape, ta.shape)
|
||||||
raise Exception('es and ta have to be of shape Nx4')
|
raise Exception('es and ta have to be of shape Nx4')
|
||||||
if ma is not None:
|
if ma is not None:
|
||||||
raise Exception('mask is not implemented')
|
raise Exception('mask is not implemented')
|
||||||
if self.type == 'angle':
|
if self.type == 'angle':
|
||||||
self.dists.append( geometry.quat_distance_angle(es, ta) )
|
self.dists.append(geometry.quat_distance_angle(es, ta))
|
||||||
elif self.type == 'mineucl':
|
elif self.type == 'mineucl':
|
||||||
self.dists.append( geometry.quat_distance_mineucl(es, ta) )
|
self.dists.append(geometry.quat_distance_mineucl(es, ta))
|
||||||
elif self.type == 'normdiff':
|
elif self.type == 'normdiff':
|
||||||
self.dists.append( geometry.quat_distance_normdiff(es, ta) )
|
self.dists.append(geometry.quat_distance_normdiff(es, ta))
|
||||||
else:
|
else:
|
||||||
raise Exception('invalid distance type')
|
raise Exception('invalid distance type')
|
||||||
|
|
||||||
|
|
||||||
class BinaryAccuracyMetric(Metric):
|
class BinaryAccuracyMetric(Metric):
|
||||||
def __init__(self, thresholds=np.linspace(0.0, 1.0, num=101, dtype=np.float64)[:-1], **kwargs):
|
def __init__(self, thresholds=np.linspace(0.0, 1.0, num=101, dtype=np.float64)[:-1], **kwargs):
|
||||||
self.thresholds = thresholds
|
self.thresholds = thresholds
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.tps = [0 for wp in self.thresholds]
|
self.tps = [0 for wp in self.thresholds]
|
||||||
self.fps = [0 for wp in self.thresholds]
|
self.fps = [0 for wp in self.thresholds]
|
||||||
self.fns = [0 for wp in self.thresholds]
|
self.fns = [0 for wp in self.thresholds]
|
||||||
self.tns = [0 for wp in self.thresholds]
|
self.tns = [0 for wp in self.thresholds]
|
||||||
self.n_pos = 0
|
self.n_pos = 0
|
||||||
self.n_neg = 0
|
self.n_neg = 0
|
||||||
|
|
||||||
def add(self, es, ta, ma=None):
|
def add(self, es, ta, ma=None):
|
||||||
if ma is not None:
|
if ma is not None:
|
||||||
raise Exception('mask is not implemented')
|
raise Exception('mask is not implemented')
|
||||||
es = es.ravel()
|
es = es.ravel()
|
||||||
ta = ta.ravel()
|
ta = ta.ravel()
|
||||||
if es.shape[0] != ta.shape[0]:
|
if es.shape[0] != ta.shape[0]:
|
||||||
raise Exception('invalid shape of es, or ta')
|
raise Exception('invalid shape of es, or ta')
|
||||||
if es.min() < 0 or es.max() > 1:
|
if es.min() < 0 or es.max() > 1:
|
||||||
raise Exception('estimate has wrong value range')
|
raise Exception('estimate has wrong value range')
|
||||||
ta_p = (ta == 1)
|
ta_p = (ta == 1)
|
||||||
ta_n = (ta == 0)
|
ta_n = (ta == 0)
|
||||||
es_p = es[ta_p]
|
es_p = es[ta_p]
|
||||||
es_n = es[ta_n]
|
es_n = es[ta_n]
|
||||||
for idx, wp in enumerate(self.thresholds):
|
for idx, wp in enumerate(self.thresholds):
|
||||||
wp = np.asscalar(wp)
|
wp = np.asscalar(wp)
|
||||||
self.tps[idx] += (es_p > wp).sum()
|
self.tps[idx] += (es_p > wp).sum()
|
||||||
self.fps[idx] += (es_n > wp).sum()
|
self.fps[idx] += (es_n > wp).sum()
|
||||||
self.fns[idx] += (es_p <= wp).sum()
|
self.fns[idx] += (es_p <= wp).sum()
|
||||||
self.tns[idx] += (es_n <= wp).sum()
|
self.tns[idx] += (es_n <= wp).sum()
|
||||||
self.n_pos += ta_p.sum()
|
self.n_pos += ta_p.sum()
|
||||||
self.n_neg += ta_n.sum()
|
self.n_neg += ta_n.sum()
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
tps = np.array(self.tps).astype(np.float32)
|
tps = np.array(self.tps).astype(np.float32)
|
||||||
fps = np.array(self.fps).astype(np.float32)
|
fps = np.array(self.fps).astype(np.float32)
|
||||||
fns = np.array(self.fns).astype(np.float32)
|
fns = np.array(self.fns).astype(np.float32)
|
||||||
tns = np.array(self.tns).astype(np.float32)
|
tns = np.array(self.tns).astype(np.float32)
|
||||||
wp = self.thresholds
|
wp = self.thresholds
|
||||||
|
|
||||||
ret = {}
|
ret = {}
|
||||||
|
|
||||||
precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps), where=tps + fps != 0)
|
precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps), where=tps + fps != 0)
|
||||||
recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps), where=tps + fns != 0) # tprs
|
recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps), where=tps + fns != 0) # tprs
|
||||||
fprs = np.divide(fps, fps + tns, out=np.zeros_like(tps), where=fps + tns != 0)
|
fprs = np.divide(fps, fps + tns, out=np.zeros_like(tps), where=fps + tns != 0)
|
||||||
|
|
||||||
precisions = np.r_[0, precisions, 1]
|
precisions = np.r_[0, precisions, 1]
|
||||||
recalls = np.r_[1, recalls, 0]
|
recalls = np.r_[1, recalls, 0]
|
||||||
fprs = np.r_[1, fprs, 0]
|
fprs = np.r_[1, fprs, 0]
|
||||||
|
|
||||||
ret['auc'] = float(-np.trapz(recalls, fprs))
|
ret['auc'] = float(-np.trapz(recalls, fprs))
|
||||||
ret['prauc'] = float(-np.trapz(precisions, recalls))
|
ret['prauc'] = float(-np.trapz(precisions, recalls))
|
||||||
ret['ap'] = float(-(np.diff(recalls) * precisions[:-1]).sum())
|
ret['ap'] = float(-(np.diff(recalls) * precisions[:-1]).sum())
|
||||||
|
|
||||||
accuracies = np.divide(tps + tns, tps + tns + fps + fns)
|
accuracies = np.divide(tps + tns, tps + tns + fps + fns)
|
||||||
aacc = np.mean(accuracies)
|
aacc = np.mean(accuracies)
|
||||||
for t in np.linspace(0,1,num=11)[1:-1]:
|
for t in np.linspace(0, 1, num=11)[1:-1]:
|
||||||
idx = np.argmin(np.abs(t - wp))
|
idx = np.argmin(np.abs(t - wp))
|
||||||
ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx])
|
ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx])
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
163
co/plt.py
163
co/plt.py
@ -6,94 +6,99 @@ import matplotlib.pyplot as plt
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
def save(path, remove_axis=False, dpi=300, fig=None):
|
def save(path, remove_axis=False, dpi=300, fig=None):
|
||||||
if fig is None:
|
if fig is None:
|
||||||
fig = plt.gcf()
|
fig = plt.gcf()
|
||||||
dirname = os.path.dirname(path)
|
dirname = os.path.dirname(path)
|
||||||
if dirname != '' and not os.path.exists(dirname):
|
if dirname != '' and not os.path.exists(dirname):
|
||||||
os.makedirs(dirname)
|
os.makedirs(dirname)
|
||||||
if remove_axis:
|
if remove_axis:
|
||||||
for ax in fig.axes:
|
for ax in fig.axes:
|
||||||
ax.axis('off')
|
ax.axis('off')
|
||||||
ax.margins(0,0)
|
ax.margins(0, 0)
|
||||||
fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
||||||
for ax in fig.axes:
|
for ax in fig.axes:
|
||||||
ax.xaxis.set_major_locator(plt.NullLocator())
|
ax.xaxis.set_major_locator(plt.NullLocator())
|
||||||
ax.yaxis.set_major_locator(plt.NullLocator())
|
ax.yaxis.set_major_locator(plt.NullLocator())
|
||||||
fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
|
fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
|
||||||
|
|
||||||
|
|
||||||
def color_map(im_, cmap='viridis', vmin=None, vmax=None):
|
def color_map(im_, cmap='viridis', vmin=None, vmax=None):
|
||||||
cm = plt.get_cmap(cmap)
|
cm = plt.get_cmap(cmap)
|
||||||
im = im_.copy()
|
im = im_.copy()
|
||||||
if vmin is None:
|
if vmin is None:
|
||||||
vmin = np.nanmin(im)
|
vmin = np.nanmin(im)
|
||||||
if vmax is None:
|
if vmax is None:
|
||||||
vmax = np.nanmax(im)
|
vmax = np.nanmax(im)
|
||||||
mask = np.logical_not(np.isfinite(im))
|
mask = np.logical_not(np.isfinite(im))
|
||||||
im[mask] = vmin
|
im[mask] = vmin
|
||||||
im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin)
|
im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin)
|
||||||
im = cm(im)
|
im = cm(im)
|
||||||
im = im[...,:3]
|
im = im[..., :3]
|
||||||
for c in range(3):
|
for c in range(3):
|
||||||
im[mask, c] = 1
|
im[mask, c] = 1
|
||||||
return im
|
return im
|
||||||
|
|
||||||
|
|
||||||
def interactive_legend(leg=None, fig=None, all_axes=True):
|
def interactive_legend(leg=None, fig=None, all_axes=True):
|
||||||
if leg is None:
|
if leg is None:
|
||||||
leg = plt.legend()
|
leg = plt.legend()
|
||||||
if fig is None:
|
if fig is None:
|
||||||
fig = plt.gcf()
|
fig = plt.gcf()
|
||||||
if all_axes:
|
if all_axes:
|
||||||
axs = fig.get_axes()
|
axs = fig.get_axes()
|
||||||
else:
|
|
||||||
axs = [fig.gca()]
|
|
||||||
|
|
||||||
# lined = dict()
|
|
||||||
# lines = ax.lines
|
|
||||||
# for legline, origline in zip(leg.get_lines(), ax.lines):
|
|
||||||
# legline.set_picker(5)
|
|
||||||
# lined[legline] = origline
|
|
||||||
lined = dict()
|
|
||||||
for lidx, legline in enumerate(leg.get_lines()):
|
|
||||||
legline.set_picker(5)
|
|
||||||
lined[legline] = [ax.lines[lidx] for ax in axs]
|
|
||||||
|
|
||||||
def onpick(event):
|
|
||||||
if event.mouseevent.dblclick:
|
|
||||||
tmp = [(k,v) for k,v in lined.items()]
|
|
||||||
else:
|
else:
|
||||||
tmp = [(event.artist, lined[event.artist])]
|
axs = [fig.gca()]
|
||||||
|
|
||||||
for legline, origline in tmp:
|
# lined = dict()
|
||||||
for ol in origline:
|
# lines = ax.lines
|
||||||
vis = not ol.get_visible()
|
# for legline, origline in zip(leg.get_lines(), ax.lines):
|
||||||
ol.set_visible(vis)
|
# legline.set_picker(5)
|
||||||
if vis:
|
# lined[legline] = origline
|
||||||
legline.set_alpha(1.0)
|
lined = dict()
|
||||||
else:
|
for lidx, legline in enumerate(leg.get_lines()):
|
||||||
legline.set_alpha(0.2)
|
legline.set_picker(5)
|
||||||
fig.canvas.draw()
|
lined[legline] = [ax.lines[lidx] for ax in axs]
|
||||||
|
|
||||||
|
def onpick(event):
|
||||||
|
if event.mouseevent.dblclick:
|
||||||
|
tmp = [(k, v) for k, v in lined.items()]
|
||||||
|
else:
|
||||||
|
tmp = [(event.artist, lined[event.artist])]
|
||||||
|
|
||||||
|
for legline, origline in tmp:
|
||||||
|
for ol in origline:
|
||||||
|
vis = not ol.get_visible()
|
||||||
|
ol.set_visible(vis)
|
||||||
|
if vis:
|
||||||
|
legline.set_alpha(1.0)
|
||||||
|
else:
|
||||||
|
legline.set_alpha(0.2)
|
||||||
|
fig.canvas.draw()
|
||||||
|
|
||||||
|
fig.canvas.mpl_connect('pick_event', onpick)
|
||||||
|
|
||||||
fig.canvas.mpl_connect('pick_event', onpick)
|
|
||||||
|
|
||||||
def non_annoying_pause(interval, focus_figure=False):
|
def non_annoying_pause(interval, focus_figure=False):
|
||||||
# https://github.com/matplotlib/matplotlib/issues/11131
|
# https://github.com/matplotlib/matplotlib/issues/11131
|
||||||
backend = mpl.rcParams['backend']
|
backend = mpl.rcParams['backend']
|
||||||
if backend in _interactive_bk:
|
if backend in _interactive_bk:
|
||||||
figManager = _pylab_helpers.Gcf.get_active()
|
figManager = _pylab_helpers.Gcf.get_active()
|
||||||
if figManager is not None:
|
if figManager is not None:
|
||||||
canvas = figManager.canvas
|
canvas = figManager.canvas
|
||||||
if canvas.figure.stale:
|
if canvas.figure.stale:
|
||||||
canvas.draw()
|
canvas.draw()
|
||||||
if focus_figure:
|
if focus_figure:
|
||||||
plt.show(block=False)
|
plt.show(block=False)
|
||||||
canvas.start_event_loop(interval)
|
canvas.start_event_loop(interval)
|
||||||
return
|
return
|
||||||
time.sleep(interval)
|
time.sleep(interval)
|
||||||
|
|
||||||
|
|
||||||
def remove_all_ticks(fig=None):
|
def remove_all_ticks(fig=None):
|
||||||
if fig is None:
|
if fig is None:
|
||||||
fig = plt.gcf()
|
fig = plt.gcf()
|
||||||
for ax in fig.axes:
|
for ax in fig.axes:
|
||||||
ax.axes.get_xaxis().set_visible(False)
|
ax.axes.get_xaxis().set_visible(False)
|
||||||
ax.axes.get_yaxis().set_visible(False)
|
ax.axes.get_yaxis().set_visible(False)
|
||||||
|
91
co/plt2d.py
91
co/plt2d.py
@ -3,55 +3,60 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
from . import geometry
|
from . import geometry
|
||||||
|
|
||||||
|
|
||||||
def image_matrix(ims, bgval=0):
|
def image_matrix(ims, bgval=0):
|
||||||
n = ims.shape[0]
|
n = ims.shape[0]
|
||||||
m = int( np.ceil(np.sqrt(n)) )
|
m = int(np.ceil(np.sqrt(n)))
|
||||||
h = ims.shape[1]
|
h = ims.shape[1]
|
||||||
w = ims.shape[2]
|
w = ims.shape[2]
|
||||||
mat = np.empty((m*h, m*w), dtype=ims.dtype)
|
mat = np.empty((m * h, m * w), dtype=ims.dtype)
|
||||||
mat.fill(bgval)
|
mat.fill(bgval)
|
||||||
idx = 0
|
idx = 0
|
||||||
for r in range(m):
|
for r in range(m):
|
||||||
for c in range(m):
|
for c in range(m):
|
||||||
if idx < n:
|
if idx < n:
|
||||||
mat[r*h:(r+1)*h, c*w:(c+1)*w] = ims[idx]
|
mat[r * h:(r + 1) * h, c * w:(c + 1) * w] = ims[idx]
|
||||||
idx += 1
|
idx += 1
|
||||||
return mat
|
return mat
|
||||||
|
|
||||||
|
|
||||||
def image_cat(ims, vertical=False):
|
def image_cat(ims, vertical=False):
|
||||||
offx = [0]
|
offx = [0]
|
||||||
offy = [0]
|
offy = [0]
|
||||||
if vertical:
|
if vertical:
|
||||||
width = max([im.shape[1] for im in ims])
|
width = max([im.shape[1] for im in ims])
|
||||||
offx += [0 for im in ims[:-1]]
|
offx += [0 for im in ims[:-1]]
|
||||||
offy += [im.shape[0] for im in ims[:-1]]
|
offy += [im.shape[0] for im in ims[:-1]]
|
||||||
height = sum([im.shape[0] for im in ims])
|
height = sum([im.shape[0] for im in ims])
|
||||||
else:
|
else:
|
||||||
height = max([im.shape[0] for im in ims])
|
height = max([im.shape[0] for im in ims])
|
||||||
offx += [im.shape[1] for im in ims[:-1]]
|
offx += [im.shape[1] for im in ims[:-1]]
|
||||||
offy += [0 for im in ims[:-1]]
|
offy += [0 for im in ims[:-1]]
|
||||||
width = sum([im.shape[1] for im in ims])
|
width = sum([im.shape[1] for im in ims])
|
||||||
offx = np.cumsum(offx)
|
offx = np.cumsum(offx)
|
||||||
offy = np.cumsum(offy)
|
offy = np.cumsum(offy)
|
||||||
|
|
||||||
im = np.zeros((height,width,*ims[0].shape[2:]), dtype=ims[0].dtype)
|
im = np.zeros((height, width, *ims[0].shape[2:]), dtype=ims[0].dtype)
|
||||||
for im0, ox, oy in zip(ims, offx, offy):
|
for im0, ox, oy in zip(ims, offx, offy):
|
||||||
im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0
|
im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0
|
||||||
|
|
||||||
|
return im, offx, offy
|
||||||
|
|
||||||
return im, offx, offy
|
|
||||||
|
|
||||||
def line(li, h, w, ax=None, *args, **kwargs):
|
def line(li, h, w, ax=None, *args, **kwargs):
|
||||||
if ax is None:
|
if ax is None:
|
||||||
ax = plt.gca()
|
ax = plt.gca()
|
||||||
xs = (-li[2] - li[1] * np.array((0, h-1))) / li[0]
|
xs = (-li[2] - li[1] * np.array((0, h - 1))) / li[0]
|
||||||
ys = (-li[2] - li[0] * np.array((0, w-1))) / li[1]
|
ys = (-li[2] - li[0] * np.array((0, w - 1))) / li[1]
|
||||||
pts = np.array([(0,ys[0]), (w-1, ys[1]), (xs[0], 0), (xs[1], h-1)])
|
pts = np.array([(0, ys[0]), (w - 1, ys[1]), (xs[0], 0), (xs[1], h - 1)])
|
||||||
pts = pts[np.logical_and(np.logical_and(pts[:,0] >= 0, pts[:,0] < w), np.logical_and(pts[:,1] >= 0, pts[:,1] < h))]
|
pts = pts[
|
||||||
ax.plot(pts[:,0], pts[:,1], *args, **kwargs)
|
np.logical_and(np.logical_and(pts[:, 0] >= 0, pts[:, 0] < w), np.logical_and(pts[:, 1] >= 0, pts[:, 1] < h))]
|
||||||
|
ax.plot(pts[:, 0], pts[:, 1], *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def depthshow(depth, *args, ax=None, **kwargs):
|
def depthshow(depth, *args, ax=None, **kwargs):
|
||||||
if ax is None:
|
if ax is None:
|
||||||
ax = plt.gca()
|
ax = plt.gca()
|
||||||
d = depth.copy()
|
d = depth.copy()
|
||||||
d[d < 0] = np.NaN
|
d[d < 0] = np.NaN
|
||||||
ax.imshow(d, *args, **kwargs)
|
ax.imshow(d, *args, **kwargs)
|
||||||
|
64
co/plt3d.py
64
co/plt3d.py
@ -4,35 +4,45 @@ from mpl_toolkits.mplot3d import Axes3D
|
|||||||
|
|
||||||
from . import geometry
|
from . import geometry
|
||||||
|
|
||||||
|
|
||||||
def ax3d(fig=None):
|
def ax3d(fig=None):
|
||||||
if fig is None:
|
if fig is None:
|
||||||
fig = plt.gcf()
|
fig = plt.gcf()
|
||||||
return fig.add_subplot(111, projection='3d')
|
return fig.add_subplot(111, projection='3d')
|
||||||
|
|
||||||
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1, label=None, **kwargs):
|
|
||||||
if ax is None:
|
|
||||||
ax = plt.gca()
|
|
||||||
C0 = geometry.translation_to_cameracenter(R, t).ravel()
|
|
||||||
C1 = C0 + R.T.dot( np.array([[-size],[-size],[3*size]], dtype=np.float32) ).ravel()
|
|
||||||
C2 = C0 + R.T.dot( np.array([[-size],[+size],[3*size]], dtype=np.float32) ).ravel()
|
|
||||||
C3 = C0 + R.T.dot( np.array([[+size],[+size],[3*size]], dtype=np.float32) ).ravel()
|
|
||||||
C4 = C0 + R.T.dot( np.array([[+size],[-size],[3*size]], dtype=np.float32) ).ravel()
|
|
||||||
|
|
||||||
if marker_C != '':
|
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1,
|
||||||
ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs)
|
label=None, **kwargs):
|
||||||
ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
if ax is None:
|
||||||
ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
ax = plt.gca()
|
||||||
ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
C0 = geometry.translation_to_cameracenter(R, t).ravel()
|
||||||
ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
C1 = C0 + R.T.dot(np.array([[-size], [-size], [3 * size]], dtype=np.float32)).ravel()
|
||||||
ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]], [C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
C2 = C0 + R.T.dot(np.array([[-size], [+size], [3 * size]], dtype=np.float32)).ravel()
|
||||||
|
C3 = C0 + R.T.dot(np.array([[+size], [+size], [3 * size]], dtype=np.float32)).ravel()
|
||||||
|
C4 = C0 + R.T.dot(np.array([[+size], [-size], [3 * size]], dtype=np.float32)).ravel()
|
||||||
|
|
||||||
|
if marker_C != '':
|
||||||
|
ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs)
|
||||||
|
ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
|
||||||
|
linewidth=linewidth, **kwargs)
|
||||||
|
ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle,
|
||||||
|
linewidth=linewidth, **kwargs)
|
||||||
|
ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle,
|
||||||
|
linewidth=linewidth, **kwargs)
|
||||||
|
ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle,
|
||||||
|
linewidth=linewidth, **kwargs)
|
||||||
|
ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]],
|
||||||
|
[C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
|
||||||
|
linewidth=linewidth, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def axis_equal(ax=None):
|
def axis_equal(ax=None):
|
||||||
if ax is None:
|
if ax is None:
|
||||||
ax = plt.gca()
|
ax = plt.gca()
|
||||||
extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
|
extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
|
||||||
sz = extents[:,1] - extents[:,0]
|
sz = extents[:, 1] - extents[:, 0]
|
||||||
centers = np.mean(extents, axis=1)
|
centers = np.mean(extents, axis=1)
|
||||||
maxsize = max(abs(sz))
|
maxsize = max(abs(sz))
|
||||||
r = maxsize/2
|
r = maxsize / 2
|
||||||
for ctr, dim in zip(centers, 'xyz'):
|
for ctr, dim in zip(centers, 'xyz'):
|
||||||
getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r)
|
getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r)
|
||||||
|
748
co/table.py
748
co/table.py
@ -3,443 +3,453 @@ import pandas as pd
|
|||||||
import enum
|
import enum
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
|
|
||||||
class Table(object):
|
class Table(object):
|
||||||
def __init__(self, n_cols):
|
def __init__(self, n_cols):
|
||||||
self.n_cols = n_cols
|
self.n_cols = n_cols
|
||||||
self.rows = []
|
self.rows = []
|
||||||
self.aligns = ['r' for c in range(n_cols)]
|
self.aligns = ['r' for c in range(n_cols)]
|
||||||
|
|
||||||
def get_cell_align(self, r, c):
|
def get_cell_align(self, r, c):
|
||||||
align = self.rows[r].cells[c].align
|
align = self.rows[r].cells[c].align
|
||||||
if align is None:
|
if align is None:
|
||||||
return self.aligns[c]
|
return self.aligns[c]
|
||||||
else:
|
|
||||||
return align
|
|
||||||
|
|
||||||
def add_row(self, row):
|
|
||||||
if row.ncols() != self.n_cols:
|
|
||||||
raise Exception(f'row has invalid number of cols, {row.ncols()} vs. {self.n_cols}')
|
|
||||||
self.rows.append(row)
|
|
||||||
|
|
||||||
def empty_row(self):
|
|
||||||
return Row.Empty(self.n_cols)
|
|
||||||
|
|
||||||
def expand_rows(self, n_add_cols=1):
|
|
||||||
if n_add_cols < 0: raise Exception('n_add_cols has to be positive')
|
|
||||||
self.n_cols += n_add_cols
|
|
||||||
for row in self.rows:
|
|
||||||
row.cells.extend([Cell() for cidx in range(n_add_cols)])
|
|
||||||
|
|
||||||
def add_block(self, data, row=-1, col=0, fmt=None, expand=False):
|
|
||||||
if row < 0: row = len(self.rows)
|
|
||||||
while len(self.rows) < row + len(data):
|
|
||||||
self.add_row(self.empty_row())
|
|
||||||
for r in range(len(data)):
|
|
||||||
cols = data[r]
|
|
||||||
if col + len(cols) > self.n_cols:
|
|
||||||
if expand:
|
|
||||||
self.expand_rows(col + len(cols) - self.n_cols)
|
|
||||||
else:
|
else:
|
||||||
raise Exception('number of cols does not fit in table')
|
return align
|
||||||
for c in range(len(cols)):
|
|
||||||
self.rows[row+r].cells[col+c] = Cell(data[r][c], fmt)
|
def add_row(self, row):
|
||||||
|
if row.ncols() != self.n_cols:
|
||||||
|
raise Exception(f'row has invalid number of cols, {row.ncols()} vs. {self.n_cols}')
|
||||||
|
self.rows.append(row)
|
||||||
|
|
||||||
|
def empty_row(self):
|
||||||
|
return Row.Empty(self.n_cols)
|
||||||
|
|
||||||
|
def expand_rows(self, n_add_cols=1):
|
||||||
|
if n_add_cols < 0: raise Exception('n_add_cols has to be positive')
|
||||||
|
self.n_cols += n_add_cols
|
||||||
|
for row in self.rows:
|
||||||
|
row.cells.extend([Cell() for cidx in range(n_add_cols)])
|
||||||
|
|
||||||
|
def add_block(self, data, row=-1, col=0, fmt=None, expand=False):
|
||||||
|
if row < 0: row = len(self.rows)
|
||||||
|
while len(self.rows) < row + len(data):
|
||||||
|
self.add_row(self.empty_row())
|
||||||
|
for r in range(len(data)):
|
||||||
|
cols = data[r]
|
||||||
|
if col + len(cols) > self.n_cols:
|
||||||
|
if expand:
|
||||||
|
self.expand_rows(col + len(cols) - self.n_cols)
|
||||||
|
else:
|
||||||
|
raise Exception('number of cols does not fit in table')
|
||||||
|
for c in range(len(cols)):
|
||||||
|
self.rows[row + r].cells[col + c] = Cell(data[r][c], fmt)
|
||||||
|
|
||||||
|
|
||||||
class Row(object):
|
class Row(object):
|
||||||
def __init__(self, cells, pre_separator=None, post_separator=None):
|
def __init__(self, cells, pre_separator=None, post_separator=None):
|
||||||
self.cells = cells
|
self.cells = cells
|
||||||
self.pre_separator = pre_separator
|
self.pre_separator = pre_separator
|
||||||
self.post_separator = post_separator
|
self.post_separator = post_separator
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def Empty(cls, n_cols):
|
def Empty(cls, n_cols):
|
||||||
return Row([Cell() for c in range(n_cols)])
|
return Row([Cell() for c in range(n_cols)])
|
||||||
|
|
||||||
def add_cell(self, cell):
|
def add_cell(self, cell):
|
||||||
self.cells.append(cell)
|
self.cells.append(cell)
|
||||||
|
|
||||||
def ncols(self):
|
|
||||||
return sum([c.span for c in self.cells])
|
|
||||||
|
|
||||||
|
def ncols(self):
|
||||||
|
return sum([c.span for c in self.cells])
|
||||||
|
|
||||||
|
|
||||||
class Color(object):
|
class Color(object):
|
||||||
def __init__(self, color=(0,0,0), fmt='rgb'):
|
def __init__(self, color=(0, 0, 0), fmt='rgb'):
|
||||||
if fmt == 'rgb':
|
if fmt == 'rgb':
|
||||||
self.color = color
|
self.color = color
|
||||||
elif fmt == 'RGB':
|
elif fmt == 'RGB':
|
||||||
self.color = tuple(c / 255 for c in color)
|
self.color = tuple(c / 255 for c in color)
|
||||||
else:
|
else:
|
||||||
return Exception('invalid color format')
|
return Exception('invalid color format')
|
||||||
|
|
||||||
def as_rgb(self):
|
def as_rgb(self):
|
||||||
return self.color
|
return self.color
|
||||||
|
|
||||||
def as_RGB(self):
|
def as_RGB(self):
|
||||||
return tuple(int(c * 255) for c in self.color)
|
return tuple(int(c * 255) for c in self.color)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def rgb(cls, r, g, b):
|
def rgb(cls, r, g, b):
|
||||||
return Color(color=(r,g,b), fmt='rgb')
|
return Color(color=(r, g, b), fmt='rgb')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def RGB(cls, r, g, b):
|
def RGB(cls, r, g, b):
|
||||||
return Color(color=(r,g,b), fmt='RGB')
|
return Color(color=(r, g, b), fmt='RGB')
|
||||||
|
|
||||||
|
|
||||||
class CellFormat(object):
|
class CellFormat(object):
|
||||||
def __init__(self, fmt='%s', fgcolor=None, bgcolor=None, bold=False):
|
def __init__(self, fmt='%s', fgcolor=None, bgcolor=None, bold=False):
|
||||||
self.fmt = fmt
|
self.fmt = fmt
|
||||||
self.fgcolor = fgcolor
|
self.fgcolor = fgcolor
|
||||||
self.bgcolor = bgcolor
|
self.bgcolor = bgcolor
|
||||||
self.bold = bold
|
self.bold = bold
|
||||||
|
|
||||||
|
|
||||||
class Cell(object):
|
class Cell(object):
|
||||||
def __init__(self, data=None, fmt=None, span=1, align=None):
|
def __init__(self, data=None, fmt=None, span=1, align=None):
|
||||||
self.data = data
|
self.data = data
|
||||||
if fmt is None:
|
if fmt is None:
|
||||||
fmt = CellFormat()
|
fmt = CellFormat()
|
||||||
self.fmt = fmt
|
self.fmt = fmt
|
||||||
self.span = span
|
self.span = span
|
||||||
self.align = align
|
self.align = align
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.fmt.fmt % self.data
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.fmt.fmt % self.data
|
|
||||||
|
|
||||||
class Separator(enum.Enum):
|
class Separator(enum.Enum):
|
||||||
HEAD = 1
|
HEAD = 1
|
||||||
BOTTOM = 2
|
BOTTOM = 2
|
||||||
INNER = 3
|
INNER = 3
|
||||||
|
|
||||||
|
|
||||||
class Renderer(object):
|
class Renderer(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def cell_str_len(self, cell):
|
def cell_str_len(self, cell):
|
||||||
return len(str(cell))
|
return len(str(cell))
|
||||||
|
|
||||||
def col_widths(self, table):
|
def col_widths(self, table):
|
||||||
widths = [0 for c in range(table.n_cols)]
|
widths = [0 for c in range(table.n_cols)]
|
||||||
for row in table.rows:
|
for row in table.rows:
|
||||||
cidx = 0
|
cidx = 0
|
||||||
for cell in row.cells:
|
for cell in row.cells:
|
||||||
if cell.span == 1:
|
if cell.span == 1:
|
||||||
strlen = self.cell_str_len(cell)
|
strlen = self.cell_str_len(cell)
|
||||||
widths[cidx] = max(widths[cidx], strlen)
|
widths[cidx] = max(widths[cidx], strlen)
|
||||||
cidx += cell.span
|
cidx += cell.span
|
||||||
return widths
|
return widths
|
||||||
|
|
||||||
def render(self, table):
|
def render(self, table):
|
||||||
raise NotImplementedError('not implemented')
|
raise NotImplementedError('not implemented')
|
||||||
|
|
||||||
def __call__(self, table):
|
def __call__(self, table):
|
||||||
return self.render(table)
|
return self.render(table)
|
||||||
|
|
||||||
def render_to_file_comment(self):
|
def render_to_file_comment(self):
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
def render_to_file(self, path, table):
|
||||||
|
txt = self.render(table)
|
||||||
|
with open(path, 'w') as fp:
|
||||||
|
fp.write(txt)
|
||||||
|
|
||||||
def render_to_file(self, path, table):
|
|
||||||
txt = self.render(table)
|
|
||||||
with open(path, 'w') as fp:
|
|
||||||
fp.write(txt)
|
|
||||||
|
|
||||||
class TerminalRenderer(Renderer):
|
class TerminalRenderer(Renderer):
|
||||||
def __init__(self, col_sep=' '):
|
def __init__(self, col_sep=' '):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.col_sep = col_sep
|
self.col_sep = col_sep
|
||||||
|
|
||||||
def render_cell(self, table, row, col, widths):
|
def render_cell(self, table, row, col, widths):
|
||||||
cell = table.rows[row].cells[col]
|
cell = table.rows[row].cells[col]
|
||||||
str = cell.fmt.fmt % cell.data
|
str = cell.fmt.fmt % cell.data
|
||||||
str_width = len(str)
|
str_width = len(str)
|
||||||
cell_width = sum([widths[idx] for idx in range(col, col+cell.span)])
|
cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
|
||||||
cell_width += len(self.col_sep) * (cell.span - 1)
|
cell_width += len(self.col_sep) * (cell.span - 1)
|
||||||
if len(str) > cell_width:
|
if len(str) > cell_width:
|
||||||
str = str[:cell_width]
|
str = str[:cell_width]
|
||||||
if cell.fmt.bold:
|
if cell.fmt.bold:
|
||||||
# str = sty.ef.bold + str + sty.rs.bold_dim
|
# str = sty.ef.bold + str + sty.rs.bold_dim
|
||||||
# str = sty.ef.bold + str + sty.rs.bold
|
# str = sty.ef.bold + str + sty.rs.bold
|
||||||
pass
|
pass
|
||||||
if cell.fmt.fgcolor is not None:
|
if cell.fmt.fgcolor is not None:
|
||||||
# color = cell.fmt.fgcolor.as_RGB()
|
# color = cell.fmt.fgcolor.as_RGB()
|
||||||
# str = sty.fg(*color) + str + sty.rs.fg
|
# str = sty.fg(*color) + str + sty.rs.fg
|
||||||
pass
|
pass
|
||||||
if str_width < cell_width:
|
if str_width < cell_width:
|
||||||
n_ws = (cell_width - str_width)
|
n_ws = (cell_width - str_width)
|
||||||
if table.get_cell_align(row, col) == 'r':
|
if table.get_cell_align(row, col) == 'r':
|
||||||
str = ' '*n_ws + str
|
str = ' ' * n_ws + str
|
||||||
elif table.get_cell_align(row, col) == 'l':
|
elif table.get_cell_align(row, col) == 'l':
|
||||||
str = str + ' '*n_ws
|
str = str + ' ' * n_ws
|
||||||
elif table.get_cell_align(row, col) == 'c':
|
elif table.get_cell_align(row, col) == 'c':
|
||||||
n_ws1 = n_ws // 2
|
n_ws1 = n_ws // 2
|
||||||
n_ws0 = n_ws - n_ws1
|
n_ws0 = n_ws - n_ws1
|
||||||
str = ' '*n_ws0 + str + ' '*n_ws1
|
str = ' ' * n_ws0 + str + ' ' * n_ws1
|
||||||
if cell.fmt.bgcolor is not None:
|
if cell.fmt.bgcolor is not None:
|
||||||
# color = cell.fmt.bgcolor.as_RGB()
|
# color = cell.fmt.bgcolor.as_RGB()
|
||||||
# str = sty.bg(*color) + str + sty.rs.bg
|
# str = sty.bg(*color) + str + sty.rs.bg
|
||||||
pass
|
pass
|
||||||
return str
|
return str
|
||||||
|
|
||||||
def render_separator(self, separator, tab, col_widths, total_width):
|
def render_separator(self, separator, tab, col_widths, total_width):
|
||||||
if separator == Separator.HEAD:
|
if separator == Separator.HEAD:
|
||||||
return '='*total_width
|
return '=' * total_width
|
||||||
elif separator == Separator.INNER:
|
elif separator == Separator.INNER:
|
||||||
return '-'*total_width
|
return '-' * total_width
|
||||||
elif separator == Separator.BOTTOM:
|
elif separator == Separator.BOTTOM:
|
||||||
return '='*total_width
|
return '=' * total_width
|
||||||
|
|
||||||
|
def render(self, table):
|
||||||
|
widths = self.col_widths(table)
|
||||||
|
total_width = sum(widths) + len(self.col_sep) * (table.n_cols - 1)
|
||||||
|
lines = []
|
||||||
|
for ridx, row in enumerate(table.rows):
|
||||||
|
if row.pre_separator is not None:
|
||||||
|
sepline = self.render_separator(row.pre_separator, table, widths, total_width)
|
||||||
|
if len(sepline) > 0:
|
||||||
|
lines.append(sepline)
|
||||||
|
line = []
|
||||||
|
for cidx, cell in enumerate(row.cells):
|
||||||
|
line.append(self.render_cell(table, ridx, cidx, widths))
|
||||||
|
lines.append(self.col_sep.join(line))
|
||||||
|
if row.post_separator is not None:
|
||||||
|
sepline = self.render_separator(row.post_separator, table, widths, total_width)
|
||||||
|
if len(sepline) > 0:
|
||||||
|
lines.append(sepline)
|
||||||
|
return '\n'.join(lines)
|
||||||
|
|
||||||
def render(self, table):
|
|
||||||
widths = self.col_widths(table)
|
|
||||||
total_width = sum(widths) + len(self.col_sep) * (table.n_cols - 1)
|
|
||||||
lines = []
|
|
||||||
for ridx, row in enumerate(table.rows):
|
|
||||||
if row.pre_separator is not None:
|
|
||||||
sepline = self.render_separator(row.pre_separator, table, widths, total_width)
|
|
||||||
if len(sepline) > 0:
|
|
||||||
lines.append(sepline)
|
|
||||||
line = []
|
|
||||||
for cidx, cell in enumerate(row.cells):
|
|
||||||
line.append(self.render_cell(table, ridx, cidx, widths))
|
|
||||||
lines.append(self.col_sep.join(line))
|
|
||||||
if row.post_separator is not None:
|
|
||||||
sepline = self.render_separator(row.post_separator, table, widths, total_width)
|
|
||||||
if len(sepline) > 0:
|
|
||||||
lines.append(sepline)
|
|
||||||
return '\n'.join(lines)
|
|
||||||
|
|
||||||
class MarkdownRenderer(TerminalRenderer):
|
class MarkdownRenderer(TerminalRenderer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(col_sep='|')
|
super().__init__(col_sep='|')
|
||||||
self.printed_color_warning = False
|
self.printed_color_warning = False
|
||||||
|
|
||||||
def print_color_warning(self):
|
def print_color_warning(self):
|
||||||
if not self.printed_color_warning:
|
if not self.printed_color_warning:
|
||||||
print('[WARNING] MarkdownRenderer does not support color yet')
|
print('[WARNING] MarkdownRenderer does not support color yet')
|
||||||
self.printed_color_warning = True
|
self.printed_color_warning = True
|
||||||
|
|
||||||
def cell_str_len(self, cell):
|
def cell_str_len(self, cell):
|
||||||
strlen = len(str(cell))
|
strlen = len(str(cell))
|
||||||
if cell.fmt.bold:
|
if cell.fmt.bold:
|
||||||
strlen += 4
|
strlen += 4
|
||||||
strlen = max(5, strlen)
|
strlen = max(5, strlen)
|
||||||
return strlen
|
return strlen
|
||||||
|
|
||||||
def render_cell(self, table, row, col, widths):
|
def render_cell(self, table, row, col, widths):
|
||||||
cell = table.rows[row].cells[col]
|
cell = table.rows[row].cells[col]
|
||||||
str = cell.fmt.fmt % cell.data
|
str = cell.fmt.fmt % cell.data
|
||||||
if cell.fmt.bold:
|
if cell.fmt.bold:
|
||||||
str = f'**{str}**'
|
str = f'**{str}**'
|
||||||
|
|
||||||
str_width = len(str)
|
str_width = len(str)
|
||||||
cell_width = sum([widths[idx] for idx in range(col, col+cell.span)])
|
cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
|
||||||
cell_width += len(self.col_sep) * (cell.span - 1)
|
cell_width += len(self.col_sep) * (cell.span - 1)
|
||||||
if len(str) > cell_width:
|
if len(str) > cell_width:
|
||||||
str = str[:cell_width]
|
str = str[:cell_width]
|
||||||
else:
|
else:
|
||||||
n_ws = (cell_width - str_width)
|
n_ws = (cell_width - str_width)
|
||||||
if table.get_cell_align(row, col) == 'r':
|
if table.get_cell_align(row, col) == 'r':
|
||||||
str = ' '*n_ws + str
|
str = ' ' * n_ws + str
|
||||||
elif table.get_cell_align(row, col) == 'l':
|
elif table.get_cell_align(row, col) == 'l':
|
||||||
str = str + ' '*n_ws
|
str = str + ' ' * n_ws
|
||||||
elif table.get_cell_align(row, col) == 'c':
|
elif table.get_cell_align(row, col) == 'c':
|
||||||
n_ws1 = n_ws // 2
|
n_ws1 = n_ws // 2
|
||||||
n_ws0 = n_ws - n_ws1
|
n_ws0 = n_ws - n_ws1
|
||||||
str = ' '*n_ws0 + str + ' '*n_ws1
|
str = ' ' * n_ws0 + str + ' ' * n_ws1
|
||||||
|
|
||||||
if col == 0: str = self.col_sep + str
|
if col == 0: str = self.col_sep + str
|
||||||
if col == table.n_cols - 1: str += self.col_sep
|
if col == table.n_cols - 1: str += self.col_sep
|
||||||
|
|
||||||
if cell.fmt.fgcolor is not None:
|
if cell.fmt.fgcolor is not None:
|
||||||
self.print_color_warning()
|
self.print_color_warning()
|
||||||
if cell.fmt.bgcolor is not None:
|
if cell.fmt.bgcolor is not None:
|
||||||
self.print_color_warning()
|
self.print_color_warning()
|
||||||
return str
|
return str
|
||||||
|
|
||||||
def render_separator(self, separator, tab, widths, total_width):
|
def render_separator(self, separator, tab, widths, total_width):
|
||||||
sep = ''
|
sep = ''
|
||||||
if separator == Separator.INNER:
|
if separator == Separator.INNER:
|
||||||
sep = self.col_sep
|
sep = self.col_sep
|
||||||
for idx, width in enumerate(widths):
|
for idx, width in enumerate(widths):
|
||||||
csep = '-' * (width - 2)
|
csep = '-' * (width - 2)
|
||||||
if tab.get_cell_align(1, idx) == 'r':
|
if tab.get_cell_align(1, idx) == 'r':
|
||||||
csep = '-' + csep + ':'
|
csep = '-' + csep + ':'
|
||||||
elif tab.get_cell_align(1, idx) == 'l':
|
elif tab.get_cell_align(1, idx) == 'l':
|
||||||
csep = ':' + csep + '-'
|
csep = ':' + csep + '-'
|
||||||
elif tab.get_cell_align(1, idx) == 'c':
|
elif tab.get_cell_align(1, idx) == 'c':
|
||||||
csep = ':' + csep + ':'
|
csep = ':' + csep + ':'
|
||||||
sep += csep + self.col_sep
|
sep += csep + self.col_sep
|
||||||
return sep
|
return sep
|
||||||
|
|
||||||
|
|
||||||
class LatexRenderer(Renderer):
|
class LatexRenderer(Renderer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def render_cell(self, table, row, col):
|
def render_cell(self, table, row, col):
|
||||||
cell = table.rows[row].cells[col]
|
cell = table.rows[row].cells[col]
|
||||||
str = cell.fmt.fmt % cell.data
|
str = cell.fmt.fmt % cell.data
|
||||||
if cell.fmt.bold:
|
if cell.fmt.bold:
|
||||||
str = '{\\bf '+ str + '}'
|
str = '{\\bf ' + str + '}'
|
||||||
if cell.fmt.fgcolor is not None:
|
if cell.fmt.fgcolor is not None:
|
||||||
color = cell.fmt.fgcolor.as_rgb()
|
color = cell.fmt.fgcolor.as_rgb()
|
||||||
str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}'
|
str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}'
|
||||||
if cell.fmt.bgcolor is not None:
|
if cell.fmt.bgcolor is not None:
|
||||||
color = cell.fmt.bgcolor.as_rgb()
|
color = cell.fmt.bgcolor.as_rgb()
|
||||||
str = f'\\cellcolor[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str
|
str = f'\\cellcolor[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str
|
||||||
align = table.get_cell_align(row, col)
|
align = table.get_cell_align(row, col)
|
||||||
if cell.span != 1 or align != table.aligns[col]:
|
if cell.span != 1 or align != table.aligns[col]:
|
||||||
str = f'\\multicolumn{{{cell.span}}}{{{align}}}{{{str}}}'
|
str = f'\\multicolumn{{{cell.span}}}{{{align}}}{{{str}}}'
|
||||||
return str
|
return str
|
||||||
|
|
||||||
def render_separator(self, separator):
|
def render_separator(self, separator):
|
||||||
if separator == Separator.HEAD:
|
if separator == Separator.HEAD:
|
||||||
return '\\toprule'
|
return '\\toprule'
|
||||||
elif separator == Separator.INNER:
|
elif separator == Separator.INNER:
|
||||||
return '\\midrule'
|
return '\\midrule'
|
||||||
elif separator == Separator.BOTTOM:
|
elif separator == Separator.BOTTOM:
|
||||||
return '\\bottomrule'
|
return '\\bottomrule'
|
||||||
|
|
||||||
|
def render(self, table):
|
||||||
|
lines = ['\\begin{tabular}{' + ''.join(table.aligns) + '}']
|
||||||
|
for ridx, row in enumerate(table.rows):
|
||||||
|
if row.pre_separator is not None:
|
||||||
|
lines.append(self.render_separator(row.pre_separator))
|
||||||
|
line = []
|
||||||
|
for cidx, cell in enumerate(row.cells):
|
||||||
|
line.append(self.render_cell(table, ridx, cidx))
|
||||||
|
lines.append(' & '.join(line) + ' \\\\')
|
||||||
|
if row.post_separator is not None:
|
||||||
|
lines.append(self.render_separator(row.post_separator))
|
||||||
|
lines.append('\\end{tabular}')
|
||||||
|
return '\n'.join(lines)
|
||||||
|
|
||||||
def render(self, table):
|
|
||||||
lines = ['\\begin{tabular}{' + ''.join(table.aligns) + '}']
|
|
||||||
for ridx, row in enumerate(table.rows):
|
|
||||||
if row.pre_separator is not None:
|
|
||||||
lines.append(self.render_separator(row.pre_separator))
|
|
||||||
line = []
|
|
||||||
for cidx, cell in enumerate(row.cells):
|
|
||||||
line.append(self.render_cell(table, ridx, cidx))
|
|
||||||
lines.append(' & '.join(line) + ' \\\\')
|
|
||||||
if row.post_separator is not None:
|
|
||||||
lines.append(self.render_separator(row.post_separator))
|
|
||||||
lines.append('\\end{tabular}')
|
|
||||||
return '\n'.join(lines)
|
|
||||||
|
|
||||||
class HtmlRenderer(Renderer):
|
class HtmlRenderer(Renderer):
|
||||||
def __init__(self, html_class='result_table'):
|
def __init__(self, html_class='result_table'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.html_class = html_class
|
self.html_class = html_class
|
||||||
|
|
||||||
def render_cell(self, table, row, col):
|
def render_cell(self, table, row, col):
|
||||||
cell = table.rows[row].cells[col]
|
cell = table.rows[row].cells[col]
|
||||||
str = cell.fmt.fmt % cell.data
|
str = cell.fmt.fmt % cell.data
|
||||||
styles = []
|
styles = []
|
||||||
if cell.fmt.bold:
|
if cell.fmt.bold:
|
||||||
styles.append('font-weight: bold;')
|
styles.append('font-weight: bold;')
|
||||||
if cell.fmt.fgcolor is not None:
|
if cell.fmt.fgcolor is not None:
|
||||||
color = cell.fmt.fgcolor.as_RGB()
|
color = cell.fmt.fgcolor.as_RGB()
|
||||||
styles.append(f'color: rgb({color[0]},{color[1]},{color[2]});')
|
styles.append(f'color: rgb({color[0]},{color[1]},{color[2]});')
|
||||||
if cell.fmt.bgcolor is not None:
|
if cell.fmt.bgcolor is not None:
|
||||||
color = cell.fmt.bgcolor.as_RGB()
|
color = cell.fmt.bgcolor.as_RGB()
|
||||||
styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});')
|
styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});')
|
||||||
align = table.get_cell_align(row, col)
|
align = table.get_cell_align(row, col)
|
||||||
if align == 'l': align = 'left'
|
if align == 'l':
|
||||||
elif align == 'r': align = 'right'
|
align = 'left'
|
||||||
elif align == 'c': align = 'center'
|
elif align == 'r':
|
||||||
else: raise Exception('invalid align')
|
align = 'right'
|
||||||
styles.append(f'text-align: {align};')
|
elif align == 'c':
|
||||||
row = table.rows[row]
|
align = 'center'
|
||||||
if row.pre_separator is not None:
|
else:
|
||||||
styles.append(f'border-top: {self.render_separator(row.pre_separator)};')
|
raise Exception('invalid align')
|
||||||
if row.post_separator is not None:
|
styles.append(f'text-align: {align};')
|
||||||
styles.append(f'border-bottom: {self.render_separator(row.post_separator)};')
|
row = table.rows[row]
|
||||||
style = ' '.join(styles)
|
if row.pre_separator is not None:
|
||||||
str = f' <td style="{style}" colspan="{cell.span}">{str}</td>\n'
|
styles.append(f'border-top: {self.render_separator(row.pre_separator)};')
|
||||||
return str
|
if row.post_separator is not None:
|
||||||
|
styles.append(f'border-bottom: {self.render_separator(row.post_separator)};')
|
||||||
|
style = ' '.join(styles)
|
||||||
|
str = f' <td style="{style}" colspan="{cell.span}">{str}</td>\n'
|
||||||
|
return str
|
||||||
|
|
||||||
def render_separator(self, separator):
|
def render_separator(self, separator):
|
||||||
if separator == Separator.HEAD:
|
if separator == Separator.HEAD:
|
||||||
return '1.5pt solid black'
|
return '1.5pt solid black'
|
||||||
elif separator == Separator.INNER:
|
elif separator == Separator.INNER:
|
||||||
return '0.75pt solid black'
|
return '0.75pt solid black'
|
||||||
elif separator == Separator.BOTTOM:
|
elif separator == Separator.BOTTOM:
|
||||||
return '1.5pt solid black'
|
return '1.5pt solid black'
|
||||||
|
|
||||||
def render(self, table):
|
def render(self, table):
|
||||||
lines = [f'<table width="100%" style="border-collapse: collapse" class={self.html_class}>']
|
lines = [f'<table width="100%" style="border-collapse: collapse" class={self.html_class}>']
|
||||||
for ridx, row in enumerate(table.rows):
|
for ridx, row in enumerate(table.rows):
|
||||||
line = [f' <tr>\n']
|
line = [f' <tr>\n']
|
||||||
for cidx, cell in enumerate(row.cells):
|
for cidx, cell in enumerate(row.cells):
|
||||||
line.append(self.render_cell(table, ridx, cidx))
|
line.append(self.render_cell(table, ridx, cidx))
|
||||||
line.append(' </tr>\n')
|
line.append(' </tr>\n')
|
||||||
lines.append(' '.join(line))
|
lines.append(' '.join(line))
|
||||||
lines.append('</table>')
|
lines.append('</table>')
|
||||||
return '\n'.join(lines)
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
|
||||||
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'), best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
|
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'),
|
||||||
rnames = data[rowname].unique()
|
best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
|
||||||
cnames = data[colname].unique()
|
rnames = data[rowname].unique()
|
||||||
tab = Table(1+len(cnames))
|
cnames = data[colname].unique()
|
||||||
|
tab = Table(1 + len(cnames))
|
||||||
|
|
||||||
header = [Cell('', align='r')]
|
header = [Cell('', align='r')]
|
||||||
header.extend([Cell(h, align='r') for h in cnames])
|
header.extend([Cell(h, align='r') for h in cnames])
|
||||||
header = Row(header, pre_separator=Separator.HEAD, post_separator=Separator.INNER)
|
header = Row(header, pre_separator=Separator.HEAD, post_separator=Separator.INNER)
|
||||||
tab.add_row(header)
|
tab.add_row(header)
|
||||||
|
|
||||||
for rname in rnames:
|
|
||||||
cells = [Cell(rname, align='l')]
|
|
||||||
for cname in cnames:
|
|
||||||
cdata = data[data[colname] == cname]
|
|
||||||
if cname in best_is_max:
|
|
||||||
bestval = cdata[valname].max()
|
|
||||||
val = cdata[cdata[rowname] == rname][valname].max()
|
|
||||||
else:
|
|
||||||
bestval = cdata[valname].min()
|
|
||||||
val = cdata[cdata[rowname] == rname][valname].min()
|
|
||||||
if val == bestval:
|
|
||||||
fmt = best_val_cell_fmt
|
|
||||||
else:
|
|
||||||
fmt = val_cell_fmt
|
|
||||||
cells.append(Cell(val, align='r', fmt=fmt))
|
|
||||||
tab.add_row(Row(cells))
|
|
||||||
tab.rows[-1].post_separator = Separator.BOTTOM
|
|
||||||
return tab
|
|
||||||
|
|
||||||
|
for rname in rnames:
|
||||||
|
cells = [Cell(rname, align='l')]
|
||||||
|
for cname in cnames:
|
||||||
|
cdata = data[data[colname] == cname]
|
||||||
|
if cname in best_is_max:
|
||||||
|
bestval = cdata[valname].max()
|
||||||
|
val = cdata[cdata[rowname] == rname][valname].max()
|
||||||
|
else:
|
||||||
|
bestval = cdata[valname].min()
|
||||||
|
val = cdata[cdata[rowname] == rname][valname].min()
|
||||||
|
if val == bestval:
|
||||||
|
fmt = best_val_cell_fmt
|
||||||
|
else:
|
||||||
|
fmt = val_cell_fmt
|
||||||
|
cells.append(Cell(val, align='r', fmt=fmt))
|
||||||
|
tab.add_row(Row(cells))
|
||||||
|
tab.rows[-1].post_separator = Separator.BOTTOM
|
||||||
|
return tab
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# df = pd.read_pickle('full.df')
|
# df = pd.read_pickle('full.df')
|
||||||
# best_is_max = ['movF0.5', 'movF1.0']
|
# best_is_max = ['movF0.5', 'movF1.0']
|
||||||
# tab = pandas_to_table(rowname='method', colname='metric', valname='val', data=df, best_is_max=best_is_max)
|
# tab = pandas_to_table(rowname='method', colname='metric', valname='val', data=df, best_is_max=best_is_max)
|
||||||
|
|
||||||
# renderer = TerminalRenderer()
|
# renderer = TerminalRenderer()
|
||||||
# print(renderer(tab))
|
# print(renderer(tab))
|
||||||
|
|
||||||
tab = Table(7)
|
tab = Table(7)
|
||||||
# header = Row([Cell('header', span=7, align='c')], pre_separator=Separator.HEAD, post_separator=Separator.INNER)
|
# header = Row([Cell('header', span=7, align='c')], pre_separator=Separator.HEAD, post_separator=Separator.INNER)
|
||||||
# tab.add_row(header)
|
# tab.add_row(header)
|
||||||
# header2 = Row([Cell('thisisaverylongheader', span=4, align='c'), Cell('vals2', span=3, align='c')], post_separator=Separator.INNER)
|
# header2 = Row([Cell('thisisaverylongheader', span=4, align='c'), Cell('vals2', span=3, align='c')], post_separator=Separator.INNER)
|
||||||
# tab.add_row(header2)
|
# tab.add_row(header2)
|
||||||
tab.add_row(Row([Cell(f'c{c}') for c in range(7)]))
|
tab.add_row(Row([Cell(f'c{c}') for c in range(7)]))
|
||||||
tab.rows[-1].post_separator = Separator.INNER
|
tab.rows[-1].post_separator = Separator.INNER
|
||||||
tab.add_block(np.arange(15*7).reshape(15,7))
|
tab.add_block(np.arange(15 * 7).reshape(15, 7))
|
||||||
tab.rows[4].cells[2].fmt = CellFormat(bold=True)
|
tab.rows[4].cells[2].fmt = CellFormat(bold=True)
|
||||||
tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2,0.6,0.1))
|
tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2, 0.6, 0.1))
|
||||||
tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7,0.1,0.5))
|
tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7, 0.1, 0.5))
|
||||||
tab.rows[5].cells[3].fmt = CellFormat(bold=True,bgcolor=Color.rgb(0.7,0.1,0.5),fgcolor=Color.rgb(0.1,0.1,0.1))
|
tab.rows[5].cells[3].fmt = CellFormat(bold=True, bgcolor=Color.rgb(0.7, 0.1, 0.5), fgcolor=Color.rgb(0.1, 0.1, 0.1))
|
||||||
tab.rows[-1].post_separator = Separator.BOTTOM
|
tab.rows[-1].post_separator = Separator.BOTTOM
|
||||||
|
|
||||||
renderer = TerminalRenderer()
|
renderer = TerminalRenderer()
|
||||||
print(renderer(tab))
|
print(renderer(tab))
|
||||||
renderer = MarkdownRenderer()
|
renderer = MarkdownRenderer()
|
||||||
print(renderer(tab))
|
print(renderer(tab))
|
||||||
|
|
||||||
# renderer = HtmlRenderer()
|
# renderer = HtmlRenderer()
|
||||||
# html_tab = renderer(tab)
|
# html_tab = renderer(tab)
|
||||||
# print(html_tab)
|
# print(html_tab)
|
||||||
# with open('test.html', 'w') as fp:
|
# with open('test.html', 'w') as fp:
|
||||||
# fp.write(html_tab)
|
# fp.write(html_tab)
|
||||||
|
|
||||||
# import latex
|
# import latex
|
||||||
|
|
||||||
# renderer = LatexRenderer()
|
# renderer = LatexRenderer()
|
||||||
# ltx_tab = renderer(tab)
|
# ltx_tab = renderer(tab)
|
||||||
# print(ltx_tab)
|
# print(ltx_tab)
|
||||||
|
|
||||||
# with open('test.tex', 'w') as fp:
|
# with open('test.tex', 'w') as fp:
|
||||||
# latex.write_doc_prefix(fp, document_class='article')
|
# latex.write_doc_prefix(fp, document_class='article')
|
||||||
# fp.write('this is text that should appear before the table and should be long enough to wrap around.\n'*40)
|
# fp.write('this is text that should appear before the table and should be long enough to wrap around.\n'*40)
|
||||||
# fp.write('\\begin{table}')
|
# fp.write('\\begin{table}')
|
||||||
# fp.write(ltx_tab)
|
# fp.write(ltx_tab)
|
||||||
# fp.write('\\end{table}')
|
# fp.write('\\end{table}')
|
||||||
# fp.write('this is text that should appear after the table and should be long enough to wrap around.\n'*40)
|
# fp.write('this is text that should appear after the table and should be long enough to wrap around.\n'*40)
|
||||||
# latex.write_doc_suffix(fp)
|
# latex.write_doc_suffix(fp)
|
||||||
|
108
co/utils.py
108
co/utils.py
@ -8,6 +8,7 @@ import re
|
|||||||
import pickle
|
import pickle
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
|
|
||||||
def str2bool(v):
|
def str2bool(v):
|
||||||
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
||||||
return True
|
return True
|
||||||
@ -16,71 +17,74 @@ def str2bool(v):
|
|||||||
else:
|
else:
|
||||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||||
|
|
||||||
|
|
||||||
class StopWatch(object):
|
class StopWatch(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.timings = OrderedDict()
|
self.timings = OrderedDict()
|
||||||
self.starts = {}
|
self.starts = {}
|
||||||
|
|
||||||
def start(self, name):
|
def start(self, name):
|
||||||
self.starts[name] = time.time()
|
self.starts[name] = time.time()
|
||||||
|
|
||||||
def stop(self, name):
|
def stop(self, name):
|
||||||
if name not in self.timings:
|
if name not in self.timings:
|
||||||
self.timings[name] = []
|
self.timings[name] = []
|
||||||
self.timings[name].append(time.time() - self.starts[name])
|
self.timings[name].append(time.time() - self.starts[name])
|
||||||
|
|
||||||
def get(self, name=None, reduce=np.sum):
|
def get(self, name=None, reduce=np.sum):
|
||||||
if name is not None:
|
if name is not None:
|
||||||
return reduce(self.timings[name])
|
return reduce(self.timings[name])
|
||||||
else:
|
else:
|
||||||
ret = {}
|
ret = {}
|
||||||
for k in self.timings:
|
for k in self.timings:
|
||||||
ret[k] = reduce(self.timings[k])
|
ret[k] = reduce(self.timings[k])
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
|
|
||||||
def __str__(self):
|
|
||||||
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
|
|
||||||
|
|
||||||
class ETA(object):
|
class ETA(object):
|
||||||
def __init__(self, length):
|
def __init__(self, length):
|
||||||
self.length = length
|
self.length = length
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
self.current_idx = 0
|
self.current_idx = 0
|
||||||
self.current_time = time.time()
|
self.current_time = time.time()
|
||||||
|
|
||||||
def update(self, idx):
|
def update(self, idx):
|
||||||
self.current_idx = idx
|
self.current_idx = idx
|
||||||
self.current_time = time.time()
|
self.current_time = time.time()
|
||||||
|
|
||||||
def get_elapsed_time(self):
|
def get_elapsed_time(self):
|
||||||
return self.current_time - self.start_time
|
return self.current_time - self.start_time
|
||||||
|
|
||||||
def get_item_time(self):
|
def get_item_time(self):
|
||||||
return self.get_elapsed_time() / (self.current_idx + 1)
|
return self.get_elapsed_time() / (self.current_idx + 1)
|
||||||
|
|
||||||
def get_remaining_time(self):
|
def get_remaining_time(self):
|
||||||
return self.get_item_time() * (self.length - self.current_idx + 1)
|
return self.get_item_time() * (self.length - self.current_idx + 1)
|
||||||
|
|
||||||
def format_time(self, seconds):
|
def format_time(self, seconds):
|
||||||
minutes, seconds = divmod(seconds, 60)
|
minutes, seconds = divmod(seconds, 60)
|
||||||
hours, minutes = divmod(minutes, 60)
|
hours, minutes = divmod(minutes, 60)
|
||||||
hours = int(hours)
|
hours = int(hours)
|
||||||
minutes = int(minutes)
|
minutes = int(minutes)
|
||||||
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
|
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
|
||||||
|
|
||||||
def get_elapsed_time_str(self):
|
def get_elapsed_time_str(self):
|
||||||
return self.format_time(self.get_elapsed_time())
|
return self.format_time(self.get_elapsed_time())
|
||||||
|
|
||||||
|
def get_remaining_time_str(self):
|
||||||
|
return self.format_time(self.get_remaining_time())
|
||||||
|
|
||||||
def get_remaining_time_str(self):
|
|
||||||
return self.format_time(self.get_remaining_time())
|
|
||||||
|
|
||||||
def git_hash(cwd=None):
|
def git_hash(cwd=None):
|
||||||
ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||||
hash = ret.stdout
|
hash = ret.stdout
|
||||||
if hash is not None and 'fatal' not in hash.decode():
|
if hash is not None and 'fatal' not in hash.decode():
|
||||||
return hash.decode().strip()
|
return hash.decode().strip()
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
112
data/commons.py
112
data/commons.py
@ -4,47 +4,47 @@ import cv2
|
|||||||
|
|
||||||
|
|
||||||
def get_patterns(path='syn', imsizes=[], crop=True):
|
def get_patterns(path='syn', imsizes=[], crop=True):
|
||||||
pattern_size = imsizes[0]
|
pattern_size = imsizes[0]
|
||||||
if path == 'syn':
|
if path == 'syn':
|
||||||
np.random.seed(42)
|
np.random.seed(42)
|
||||||
pattern = np.random.uniform(0,1, size=pattern_size)
|
pattern = np.random.uniform(0, 1, size=pattern_size)
|
||||||
pattern = (pattern < 0.1).astype(np.float32)
|
pattern = (pattern < 0.1).astype(np.float32)
|
||||||
pattern.reshape(*imsizes[0])
|
pattern.reshape(*imsizes[0])
|
||||||
else:
|
else:
|
||||||
pattern = cv2.imread(path)
|
pattern = cv2.imread(path)
|
||||||
pattern = pattern.astype(np.float32)
|
pattern = pattern.astype(np.float32)
|
||||||
pattern /= 255
|
pattern /= 255
|
||||||
|
|
||||||
if pattern.ndim == 2:
|
if pattern.ndim == 2:
|
||||||
pattern = np.stack([pattern for idx in range(3)], axis=2)
|
pattern = np.stack([pattern for idx in range(3)], axis=2)
|
||||||
|
|
||||||
if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]:
|
if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]:
|
||||||
r0 = (pattern.shape[0] - pattern_size[0]) // 2
|
r0 = (pattern.shape[0] - pattern_size[0]) // 2
|
||||||
c0 = (pattern.shape[1] - pattern_size[1]) // 2
|
c0 = (pattern.shape[1] - pattern_size[1]) // 2
|
||||||
pattern = pattern[r0:r0+imsizes[0][0], c0:c0+imsizes[0][1]]
|
pattern = pattern[r0:r0 + imsizes[0][0], c0:c0 + imsizes[0][1]]
|
||||||
|
|
||||||
patterns = []
|
patterns = []
|
||||||
for imsize in imsizes:
|
for imsize in imsizes:
|
||||||
pat = cv2.resize(pattern, (imsize[1],imsize[0]), interpolation=cv2.INTER_LINEAR)
|
pat = cv2.resize(pattern, (imsize[1], imsize[0]), interpolation=cv2.INTER_LINEAR)
|
||||||
patterns.append(pat)
|
patterns.append(pat)
|
||||||
|
|
||||||
|
return patterns
|
||||||
|
|
||||||
return patterns
|
|
||||||
|
|
||||||
def get_rotation_matrix(v0, v1):
|
def get_rotation_matrix(v0, v1):
|
||||||
v0 = v0/np.linalg.norm(v0)
|
v0 = v0 / np.linalg.norm(v0)
|
||||||
v1 = v1/np.linalg.norm(v1)
|
v1 = v1 / np.linalg.norm(v1)
|
||||||
v = np.cross(v0,v1)
|
v = np.cross(v0, v1)
|
||||||
c = np.dot(v0,v1)
|
c = np.dot(v0, v1)
|
||||||
s = np.linalg.norm(v)
|
s = np.linalg.norm(v)
|
||||||
I = np.eye(3)
|
I = np.eye(3)
|
||||||
vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0)
|
vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0)
|
||||||
k = np.matrix(vXStr)
|
k = np.matrix(vXStr)
|
||||||
r = I + k + k @ k * ((1 -c)/(s**2))
|
r = I + k + k @ k * ((1 - c) / (s ** 2))
|
||||||
return np.asarray(r.astype(np.float32))
|
return np.asarray(r.astype(np.float32))
|
||||||
|
|
||||||
|
|
||||||
def augment_image(img,rng,disp=None,grad=None,max_shift=64,max_blur=1.5,max_noise=10.0,max_sp_noise=0.001):
|
def augment_image(img, rng, disp=None, grad=None, max_shift=64, max_blur=1.5, max_noise=10.0, max_sp_noise=0.001):
|
||||||
|
|
||||||
# get min/max values of image
|
# get min/max values of image
|
||||||
min_val = np.min(img)
|
min_val = np.min(img)
|
||||||
max_val = np.max(img)
|
max_val = np.max(img)
|
||||||
@ -57,54 +57,56 @@ def augment_image(img,rng,disp=None,grad=None,max_shift=64,max_blur=1.5,max_nois
|
|||||||
grad_aug = grad
|
grad_aug = grad
|
||||||
|
|
||||||
# apply affine transformation
|
# apply affine transformation
|
||||||
if max_shift>1:
|
if max_shift > 1:
|
||||||
|
|
||||||
# affine parameters
|
# affine parameters
|
||||||
rows,cols = img.shape
|
rows, cols = img.shape
|
||||||
shear = 0
|
shear = 0
|
||||||
shift = 0
|
shift = 0
|
||||||
shear_correction = 0
|
shear_correction = 0
|
||||||
if rng.uniform(0,1)<0.75: shear = rng.uniform(-max_shift,max_shift) # shear with 75% probability
|
if rng.uniform(0, 1) < 0.75:
|
||||||
else: shift = rng.uniform(0,max_shift) # shift with 25% probability
|
shear = rng.uniform(-max_shift, max_shift) # shear with 75% probability
|
||||||
if shear<0: shear_correction = -shear
|
else:
|
||||||
|
shift = rng.uniform(0, max_shift) # shift with 25% probability
|
||||||
|
if shear < 0: shear_correction = -shear
|
||||||
|
|
||||||
# affine transformation
|
# affine transformation
|
||||||
a = shear/float(rows)
|
a = shear / float(rows)
|
||||||
b = shift+shear_correction
|
b = shift + shear_correction
|
||||||
|
|
||||||
# warp image
|
# warp image
|
||||||
T = np.float32([[1,a,b],[0,1,0]])
|
T = np.float32([[1, a, b], [0, 1, 0]])
|
||||||
img_aug = cv2.warpAffine(img_aug,T,(cols,rows))
|
img_aug = cv2.warpAffine(img_aug, T, (cols, rows))
|
||||||
if grad is not None:
|
if grad is not None:
|
||||||
grad_aug = cv2.warpAffine(grad,T,(cols,rows))
|
grad_aug = cv2.warpAffine(grad, T, (cols, rows))
|
||||||
|
|
||||||
# disparity correction map
|
# disparity correction map
|
||||||
col = a*np.array(range(rows))+b
|
col = a * np.array(range(rows)) + b
|
||||||
disp_delta = np.tile(col,(cols,1)).transpose()
|
disp_delta = np.tile(col, (cols, 1)).transpose()
|
||||||
if disp is not None:
|
if disp is not None:
|
||||||
disp_aug = cv2.warpAffine(disp+disp_delta,T,(cols,rows))
|
disp_aug = cv2.warpAffine(disp + disp_delta, T, (cols, rows))
|
||||||
|
|
||||||
# gaussian smoothing
|
# gaussian smoothing
|
||||||
if rng.uniform(0,1)<0.5:
|
if rng.uniform(0, 1) < 0.5:
|
||||||
img_aug = cv2.GaussianBlur(img_aug,(5,5),rng.uniform(0.2,max_blur))
|
img_aug = cv2.GaussianBlur(img_aug, (5, 5), rng.uniform(0.2, max_blur))
|
||||||
|
|
||||||
# per-pixel gaussian noise
|
# per-pixel gaussian noise
|
||||||
img_aug = img_aug + rng.randn(*img_aug.shape)*rng.uniform(0.0,max_noise)/255.0
|
img_aug = img_aug + rng.randn(*img_aug.shape) * rng.uniform(0.0, max_noise) / 255.0
|
||||||
|
|
||||||
# salt-and-pepper noise
|
# salt-and-pepper noise
|
||||||
if rng.uniform(0,1)<0.5:
|
if rng.uniform(0, 1) < 0.5:
|
||||||
ratio=rng.uniform(0.0,max_sp_noise)
|
ratio = rng.uniform(0.0, max_sp_noise)
|
||||||
img_shape = img_aug.shape
|
img_shape = img_aug.shape
|
||||||
img_aug = img_aug.flatten()
|
img_aug = img_aug.flatten()
|
||||||
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio))
|
coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
|
||||||
img_aug[coord] = max_val
|
img_aug[coord] = max_val
|
||||||
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio))
|
coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
|
||||||
img_aug[coord] = min_val
|
img_aug[coord] = min_val
|
||||||
img_aug = np.reshape(img_aug, img_shape)
|
img_aug = np.reshape(img_aug, img_shape)
|
||||||
|
|
||||||
# clip intensities back to [0,1]
|
# clip intensities back to [0,1]
|
||||||
img_aug = np.maximum(img_aug,0.0)
|
img_aug = np.maximum(img_aug, 0.0)
|
||||||
img_aug = np.minimum(img_aug,1.0)
|
img_aug = np.minimum(img_aug, 1.0)
|
||||||
|
|
||||||
# return image
|
# return image
|
||||||
return img_aug, disp_aug, grad_aug
|
return img_aug, disp_aug, grad_aug
|
||||||
|
@ -10,261 +10,259 @@ import cv2
|
|||||||
import os
|
import os
|
||||||
import collections
|
import collections
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append('../')
|
sys.path.append('../')
|
||||||
import renderer
|
import renderer
|
||||||
import co
|
import co
|
||||||
from commons import get_patterns,get_rotation_matrix
|
from commons import get_patterns, get_rotation_matrix
|
||||||
from lcn import lcn
|
from lcn import lcn
|
||||||
|
|
||||||
|
|
||||||
def get_objs(shapenet_dir, obj_classes, num_perclass=100):
|
def get_objs(shapenet_dir, obj_classes, num_perclass=100):
|
||||||
|
shapenet = {'chair': '03001627',
|
||||||
|
'airplane': '02691156',
|
||||||
|
'car': '02958343',
|
||||||
|
'watercraft': '04530566'}
|
||||||
|
|
||||||
shapenet = {'chair': '03001627',
|
obj_paths = []
|
||||||
'airplane': '02691156',
|
for cls in obj_classes:
|
||||||
'car': '02958343',
|
if cls not in shapenet.keys():
|
||||||
'watercraft': '04530566'}
|
raise Exception('unknown class name')
|
||||||
|
ids = shapenet[cls]
|
||||||
|
obj_path = sorted(Path(f'{shapenet_dir}/{ids}').glob('**/models/*.obj'))
|
||||||
|
obj_paths += obj_path[:num_perclass]
|
||||||
|
print(f'found {len(obj_paths)} object paths')
|
||||||
|
|
||||||
obj_paths = []
|
objs = []
|
||||||
for cls in obj_classes:
|
for obj_path in obj_paths:
|
||||||
if cls not in shapenet.keys():
|
print(f'load {obj_path}')
|
||||||
raise Exception('unknown class name')
|
v, f, _, n = co.io3d.read_obj(obj_path)
|
||||||
ids = shapenet[cls]
|
diffs = v.max(axis=0) - v.min(axis=0)
|
||||||
obj_path = sorted(Path(f'{shapenet_dir}/{ids}').glob('**/models/*.obj'))
|
v /= (0.5 * diffs.max())
|
||||||
obj_paths += obj_path[:num_perclass]
|
v -= (v.min(axis=0) + 1)
|
||||||
print(f'found {len(obj_paths)} object paths')
|
f = f.astype(np.int32)
|
||||||
|
objs.append((v, f, n))
|
||||||
|
print(f'loaded {len(objs)} objects')
|
||||||
|
|
||||||
objs = []
|
return objs
|
||||||
for obj_path in obj_paths:
|
|
||||||
print(f'load {obj_path}')
|
|
||||||
v, f, _, n = co.io3d.read_obj(obj_path)
|
|
||||||
diffs = v.max(axis=0) - v.min(axis=0)
|
|
||||||
v /= (0.5 * diffs.max())
|
|
||||||
v -= (v.min(axis=0) + 1)
|
|
||||||
f = f.astype(np.int32)
|
|
||||||
objs.append((v,f,n))
|
|
||||||
print(f'loaded {len(objs)} objects')
|
|
||||||
|
|
||||||
return objs
|
|
||||||
|
|
||||||
|
|
||||||
def get_mesh(rng, min_z=0):
|
def get_mesh(rng, min_z=0):
|
||||||
# set up background board
|
# set up background board
|
||||||
verts, faces, normals, colors = [], [], [], []
|
verts, faces, normals, colors = [], [], [], []
|
||||||
v, f, n = co.geometry.xyplane(z=0, interleaved=True)
|
v, f, n = co.geometry.xyplane(z=0, interleaved=True)
|
||||||
v[:,2] += -v[:,2].min() + rng.uniform(2,7)
|
v[:, 2] += -v[:, 2].min() + rng.uniform(2, 7)
|
||||||
v[:,:2] *= 5e2
|
v[:, :2] *= 5e2
|
||||||
v[:,2] = np.mean(v[:,2]) + (v[:,2] - np.mean(v[:,2])) * 5e2
|
v[:, 2] = np.mean(v[:, 2]) + (v[:, 2] - np.mean(v[:, 2])) * 5e2
|
||||||
c = np.empty_like(v)
|
|
||||||
c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32)
|
|
||||||
verts.append(v)
|
|
||||||
faces.append(f)
|
|
||||||
normals.append(n)
|
|
||||||
colors.append(c)
|
|
||||||
|
|
||||||
# randomly sample 4 foreground objects for each scene
|
|
||||||
for shape_idx in range(4):
|
|
||||||
v, f, n = objs[rng.randint(0,len(objs))]
|
|
||||||
v, f, n = v.copy(), f.copy(), n.copy()
|
|
||||||
|
|
||||||
s = rng.uniform(0.25, 1)
|
|
||||||
v *= s
|
|
||||||
R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng))
|
|
||||||
v = v @ R.T
|
|
||||||
n = n @ R.T
|
|
||||||
v[:,2] += -v[:,2].min() + min_z + rng.uniform(0.5, 3)
|
|
||||||
v[:,:2] += rng.uniform(-1, 1, size=(1,2))
|
|
||||||
|
|
||||||
c = np.empty_like(v)
|
c = np.empty_like(v)
|
||||||
c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32)
|
c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
|
||||||
|
verts.append(v)
|
||||||
verts.append(v.astype(np.float32))
|
|
||||||
faces.append(f)
|
faces.append(f)
|
||||||
normals.append(n)
|
normals.append(n)
|
||||||
colors.append(c)
|
colors.append(c)
|
||||||
|
|
||||||
verts, faces = co.geometry.stack_mesh(verts, faces)
|
# randomly sample 4 foreground objects for each scene
|
||||||
normals = np.vstack(normals).astype(np.float32)
|
for shape_idx in range(4):
|
||||||
colors = np.vstack(colors).astype(np.float32)
|
v, f, n = objs[rng.randint(0, len(objs))]
|
||||||
return verts, faces, colors, normals
|
v, f, n = v.copy(), f.copy(), n.copy()
|
||||||
|
|
||||||
|
s = rng.uniform(0.25, 1)
|
||||||
|
v *= s
|
||||||
|
R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng))
|
||||||
|
v = v @ R.T
|
||||||
|
n = n @ R.T
|
||||||
|
v[:, 2] += -v[:, 2].min() + min_z + rng.uniform(0.5, 3)
|
||||||
|
v[:, :2] += rng.uniform(-1, 1, size=(1, 2))
|
||||||
|
|
||||||
|
c = np.empty_like(v)
|
||||||
|
c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
|
||||||
|
|
||||||
|
verts.append(v.astype(np.float32))
|
||||||
|
faces.append(f)
|
||||||
|
normals.append(n)
|
||||||
|
colors.append(c)
|
||||||
|
|
||||||
|
verts, faces = co.geometry.stack_mesh(verts, faces)
|
||||||
|
normals = np.vstack(normals).astype(np.float32)
|
||||||
|
colors = np.vstack(colors).astype(np.float32)
|
||||||
|
return verts, faces, colors, normals
|
||||||
|
|
||||||
|
|
||||||
def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4):
|
def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4):
|
||||||
|
tic = time.time()
|
||||||
|
rng = np.random.RandomState()
|
||||||
|
|
||||||
tic = time.time()
|
rng.seed(idx)
|
||||||
rng = np.random.RandomState()
|
|
||||||
|
|
||||||
rng.seed(idx)
|
verts, faces, colors, normals = get_mesh(rng)
|
||||||
|
data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
|
||||||
|
print(f'loading mesh for sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
|
||||||
|
|
||||||
verts, faces, colors, normals = get_mesh(rng)
|
# let the camera point to the center
|
||||||
data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
|
center = np.array([0, 0, 3], dtype=np.float32)
|
||||||
print(f'loading mesh for sample {idx+1}/{n_samples} took {time.time()-tic}[s]')
|
|
||||||
|
basevec = np.array([-baseline, 0, 0], dtype=np.float32)
|
||||||
|
unit = np.array([0, 0, 1], dtype=np.float32)
|
||||||
|
|
||||||
|
cam_x_ = rng.uniform(-0.2, 0.2)
|
||||||
|
cam_y_ = rng.uniform(-0.2, 0.2)
|
||||||
|
cam_z_ = rng.uniform(-0.2, 0.2)
|
||||||
|
|
||||||
|
ret = collections.defaultdict(list)
|
||||||
|
blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1, 0.1), 0, 1)
|
||||||
|
|
||||||
|
# capture the same static scene from different view points as a track
|
||||||
|
for ind in range(track_length):
|
||||||
|
|
||||||
|
cam_x = cam_x_ + rng.uniform(-0.1, 0.1)
|
||||||
|
cam_y = cam_y_ + rng.uniform(-0.1, 0.1)
|
||||||
|
cam_z = cam_z_ + rng.uniform(-0.1, 0.1)
|
||||||
|
|
||||||
|
tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32)
|
||||||
|
|
||||||
|
if np.linalg.norm(tcam[0:2]) < 1e-9:
|
||||||
|
Rcam = np.eye(3, dtype=np.float32)
|
||||||
|
else:
|
||||||
|
Rcam = get_rotation_matrix(center, center - tcam)
|
||||||
|
|
||||||
|
tproj = tcam + basevec
|
||||||
|
Rproj = Rcam
|
||||||
|
|
||||||
|
ret['R'].append(Rcam)
|
||||||
|
ret['t'].append(tcam)
|
||||||
|
|
||||||
|
cams = []
|
||||||
|
projs = []
|
||||||
|
|
||||||
|
# render the scene at multiple scales
|
||||||
|
scales = [1, 0.5, 0.25, 0.125]
|
||||||
|
|
||||||
|
for scale in scales:
|
||||||
|
fx = K[0, 0] * scale
|
||||||
|
fy = K[1, 1] * scale
|
||||||
|
px = K[0, 2] * scale
|
||||||
|
py = K[1, 2] * scale
|
||||||
|
im_height = imsize[0] * scale
|
||||||
|
im_width = imsize[1] * scale
|
||||||
|
cams.append(renderer.PyCamera(fx, fy, px, py, Rcam, tcam, im_width, im_height))
|
||||||
|
projs.append(renderer.PyCamera(fx, fy, px, py, Rproj, tproj, im_width, im_height))
|
||||||
|
|
||||||
|
for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns):
|
||||||
|
fl = K[0, 0] / (2 ** s)
|
||||||
|
|
||||||
|
shader = renderer.PyShader(0.5, 1.5, 0.0, 10)
|
||||||
|
pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu')
|
||||||
|
pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35)
|
||||||
|
|
||||||
|
# get the reflected laser pattern $R$
|
||||||
|
im = pyrenderer.color().copy()
|
||||||
|
depth = pyrenderer.depth().copy()
|
||||||
|
disp = baseline * fl / depth
|
||||||
|
mask = depth > 0
|
||||||
|
im = np.mean(im, axis=2)
|
||||||
|
|
||||||
|
# get the ambient image $A$
|
||||||
|
ambient = pyrenderer.normal().copy()
|
||||||
|
ambient = np.mean(ambient, axis=2)
|
||||||
|
|
||||||
|
# get the noise free IR image $J$
|
||||||
|
im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient
|
||||||
|
ret[f'ambient{s}'].append(ambient[None].astype(np.float32))
|
||||||
|
|
||||||
|
# get the gradient magnitude of the ambient image $|\nabla A|$
|
||||||
|
ambient = ambient.astype(np.float32)
|
||||||
|
sobelx = cv2.Sobel(ambient, cv2.CV_32F, 1, 0, ksize=5)
|
||||||
|
sobely = cv2.Sobel(ambient, cv2.CV_32F, 0, 1, ksize=5)
|
||||||
|
grad = np.sqrt(sobelx ** 2 + sobely ** 2)
|
||||||
|
grad = np.maximum(grad - 0.8, 0.0) # parameter
|
||||||
|
|
||||||
|
# get the local contract normalized grad LCN($|\nabla A|$)
|
||||||
|
grad_lcn, grad_std = lcn.normalize(grad, 5, 0.1)
|
||||||
|
grad_lcn = np.clip(grad_lcn, 0.0, 1.0) # parameter
|
||||||
|
ret[f'grad{s}'].append(grad_lcn[None].astype(np.float32))
|
||||||
|
|
||||||
|
ret[f'im{s}'].append(im[None].astype(np.float32))
|
||||||
|
ret[f'mask{s}'].append(mask[None].astype(np.float32))
|
||||||
|
ret[f'disp{s}'].append(disp[None].astype(np.float32))
|
||||||
|
|
||||||
|
for key in ret.keys():
|
||||||
|
ret[key] = np.stack(ret[key], axis=0)
|
||||||
|
|
||||||
|
# save to files
|
||||||
|
out_dir = out_root / f'{idx:08d}'
|
||||||
|
out_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
for k, val in ret.items():
|
||||||
|
for tidx in range(track_length):
|
||||||
|
v = val[tidx]
|
||||||
|
out_path = out_dir / f'{k}_{tidx}.npy'
|
||||||
|
np.save(out_path, v)
|
||||||
|
np.save(str(out_dir / 'blend_im.npy'), blend_im_rnd)
|
||||||
|
|
||||||
|
print(f'create sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
|
||||||
|
|
||||||
|
|
||||||
# let the camera point to the center
|
if __name__ == '__main__':
|
||||||
center = np.array([0,0,3], dtype=np.float32)
|
|
||||||
|
|
||||||
basevec = np.array([-baseline,0,0], dtype=np.float32)
|
np.random.seed(42)
|
||||||
unit = np.array([0,0,1],dtype=np.float32)
|
|
||||||
|
|
||||||
cam_x_ = rng.uniform(-0.2,0.2)
|
# output directory
|
||||||
cam_y_ = rng.uniform(-0.2,0.2)
|
with open('../config.json') as fp:
|
||||||
cam_z_ = rng.uniform(-0.2,0.2)
|
config = json.load(fp)
|
||||||
|
data_root = Path(config['DATA_ROOT'])
|
||||||
|
shapenet_root = config['SHAPENET_ROOT']
|
||||||
|
|
||||||
ret = collections.defaultdict(list)
|
data_type = 'syn'
|
||||||
blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1,0.1), 0,1)
|
out_root = data_root / f'{data_type}'
|
||||||
|
out_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# capture the same static scene from different view points as a track
|
start = 0
|
||||||
for ind in range(track_length):
|
if len(sys.argv) >= 2 and isinstance(sys.argv[2], int):
|
||||||
|
start = sys.argv[2]
|
||||||
cam_x = cam_x_ + rng.uniform(-0.1,0.1)
|
|
||||||
cam_y = cam_y_ + rng.uniform(-0.1,0.1)
|
|
||||||
cam_z = cam_z_ + rng.uniform(-0.1,0.1)
|
|
||||||
|
|
||||||
tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32)
|
|
||||||
|
|
||||||
if np.linalg.norm(tcam[0:2])<1e-9:
|
|
||||||
Rcam = np.eye(3, dtype=np.float32)
|
|
||||||
else:
|
else:
|
||||||
Rcam = get_rotation_matrix(center, center-tcam)
|
if sys.argv[2] == '--resume':
|
||||||
|
try:
|
||||||
|
start = max([int(dir) for dir in os.listdir(out_root) if str.isdigit(dir)]) or 0
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
tproj = tcam + basevec
|
# load shapenet models
|
||||||
Rproj = Rcam
|
obj_classes = ['chair']
|
||||||
|
objs = get_objs(shapenet_root, obj_classes)
|
||||||
|
|
||||||
ret['R'].append(Rcam)
|
# camera parameters
|
||||||
ret['t'].append(tcam)
|
imsize = (488, 648)
|
||||||
|
imsizes = [(imsize[0] // (2 ** s), imsize[1] // (2 ** s)) for s in range(4)]
|
||||||
|
# K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32)
|
||||||
|
K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0, 0, 1]],
|
||||||
|
dtype=np.float32)
|
||||||
|
focal_lengths = [K[0, 0] / (2 ** s) for s in range(4)]
|
||||||
|
baseline = 0.075
|
||||||
|
blend_im = 0.6
|
||||||
|
noise = 0
|
||||||
|
|
||||||
cams = []
|
# capture the same static scene from different view points as a track
|
||||||
projs = []
|
track_length = 4
|
||||||
|
|
||||||
# render the scene at multiple scales
|
# load pattern image
|
||||||
scales = [1, 0.5, 0.25, 0.125]
|
pattern_path = './kinect_pattern.png'
|
||||||
|
pattern_crop = True
|
||||||
|
patterns = get_patterns(pattern_path, imsizes, pattern_crop)
|
||||||
|
|
||||||
for scale in scales:
|
# write settings to file
|
||||||
fx = K[0,0] * scale
|
settings = {
|
||||||
fy = K[1,1] * scale
|
'imsizes': imsizes,
|
||||||
px = K[0,2] * scale
|
'patterns': patterns,
|
||||||
py = K[1,2] * scale
|
'focal_lengths': focal_lengths,
|
||||||
im_height = imsize[0] * scale
|
'baseline': baseline,
|
||||||
im_width = imsize[1] * scale
|
'K': K,
|
||||||
cams.append( renderer.PyCamera(fx,fy,px,py, Rcam, tcam, im_width, im_height) )
|
}
|
||||||
projs.append( renderer.PyCamera(fx,fy,px,py, Rproj, tproj, im_width, im_height) )
|
out_path = out_root / f'settings.pkl'
|
||||||
|
print(f'write settings to {out_path}')
|
||||||
|
with open(str(out_path), 'wb') as f:
|
||||||
|
pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
# start the job
|
||||||
for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns):
|
n_samples = 2 ** 10 + 2 ** 13
|
||||||
fl = K[0,0] / (2**s)
|
for idx in range(start, n_samples):
|
||||||
|
args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length)
|
||||||
shader = renderer.PyShader(0.5,1.5,0.0,10)
|
create_data(*args)
|
||||||
pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu')
|
|
||||||
pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35)
|
|
||||||
|
|
||||||
# get the reflected laser pattern $R$
|
|
||||||
im = pyrenderer.color().copy()
|
|
||||||
depth = pyrenderer.depth().copy()
|
|
||||||
disp = baseline * fl / depth
|
|
||||||
mask = depth > 0
|
|
||||||
im = np.mean(im, axis=2)
|
|
||||||
|
|
||||||
# get the ambient image $A$
|
|
||||||
ambient = pyrenderer.normal().copy()
|
|
||||||
ambient = np.mean(ambient, axis=2)
|
|
||||||
|
|
||||||
# get the noise free IR image $J$
|
|
||||||
im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient
|
|
||||||
ret[f'ambient{s}'].append( ambient[None].astype(np.float32) )
|
|
||||||
|
|
||||||
# get the gradient magnitude of the ambient image $|\nabla A|$
|
|
||||||
ambient = ambient.astype(np.float32)
|
|
||||||
sobelx = cv2.Sobel(ambient,cv2.CV_32F,1,0,ksize=5)
|
|
||||||
sobely = cv2.Sobel(ambient,cv2.CV_32F,0,1,ksize=5)
|
|
||||||
grad = np.sqrt(sobelx**2 + sobely**2)
|
|
||||||
grad = np.maximum(grad-0.8,0.0) # parameter
|
|
||||||
|
|
||||||
# get the local contract normalized grad LCN($|\nabla A|$)
|
|
||||||
grad_lcn, grad_std = lcn.normalize(grad,5,0.1)
|
|
||||||
grad_lcn = np.clip(grad_lcn,0.0,1.0) # parameter
|
|
||||||
ret[f'grad{s}'].append( grad_lcn[None].astype(np.float32))
|
|
||||||
|
|
||||||
ret[f'im{s}'].append( im[None].astype(np.float32))
|
|
||||||
ret[f'mask{s}'].append(mask[None].astype(np.float32))
|
|
||||||
ret[f'disp{s}'].append(disp[None].astype(np.float32))
|
|
||||||
|
|
||||||
for key in ret.keys():
|
|
||||||
ret[key] = np.stack(ret[key], axis=0)
|
|
||||||
|
|
||||||
# save to files
|
|
||||||
out_dir = out_root / f'{idx:08d}'
|
|
||||||
out_dir.mkdir(exist_ok=True, parents=True)
|
|
||||||
for k,val in ret.items():
|
|
||||||
for tidx in range(track_length):
|
|
||||||
v = val[tidx]
|
|
||||||
out_path = out_dir / f'{k}_{tidx}.npy'
|
|
||||||
np.save(out_path, v)
|
|
||||||
np.save( str(out_dir /'blend_im.npy'), blend_im_rnd)
|
|
||||||
|
|
||||||
print(f'create sample {idx+1}/{n_samples} took {time.time()-tic}[s]')
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__=='__main__':
|
|
||||||
|
|
||||||
np.random.seed(42)
|
|
||||||
|
|
||||||
# output directory
|
|
||||||
with open('../config.json') as fp:
|
|
||||||
config = json.load(fp)
|
|
||||||
data_root = Path(config['DATA_ROOT'])
|
|
||||||
shapenet_root = config['SHAPENET_ROOT']
|
|
||||||
|
|
||||||
data_type = 'syn'
|
|
||||||
out_root = data_root / f'{data_type}'
|
|
||||||
out_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
start = 0
|
|
||||||
if len(sys.argv) >= 2 and isinstance(sys.argv[2], int):
|
|
||||||
start = sys.argv[2]
|
|
||||||
else:
|
|
||||||
if sys.argv[2] == '--resume':
|
|
||||||
try:
|
|
||||||
start = max([int(dir) for dir in os.listdir(out_root) if str.isdigit(dir)]) or 0
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# load shapenet models
|
|
||||||
obj_classes = ['chair']
|
|
||||||
objs = get_objs(shapenet_root, obj_classes)
|
|
||||||
|
|
||||||
# camera parameters
|
|
||||||
imsize = (488, 648)
|
|
||||||
imsizes = [(imsize[0]//(2**s), imsize[1]//(2**s)) for s in range(4)]
|
|
||||||
# K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32)
|
|
||||||
K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0 ,0, 1]], dtype=np.float32)
|
|
||||||
focal_lengths = [K[0,0]/(2**s) for s in range(4)]
|
|
||||||
baseline=0.075
|
|
||||||
blend_im = 0.6
|
|
||||||
noise = 0
|
|
||||||
|
|
||||||
# capture the same static scene from different view points as a track
|
|
||||||
track_length = 4
|
|
||||||
|
|
||||||
# load pattern image
|
|
||||||
pattern_path = './kinect_pattern.png'
|
|
||||||
pattern_crop = True
|
|
||||||
patterns = get_patterns(pattern_path, imsizes, pattern_crop)
|
|
||||||
|
|
||||||
# write settings to file
|
|
||||||
settings = {
|
|
||||||
'imsizes': imsizes,
|
|
||||||
'patterns': patterns,
|
|
||||||
'focal_lengths': focal_lengths,
|
|
||||||
'baseline': baseline,
|
|
||||||
'K': K,
|
|
||||||
}
|
|
||||||
out_path = out_root / f'settings.pkl'
|
|
||||||
print(f'write settings to {out_path}')
|
|
||||||
with open(str(out_path), 'wb') as f:
|
|
||||||
pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL)
|
|
||||||
|
|
||||||
# start the job
|
|
||||||
n_samples = 2**10 + 2**13
|
|
||||||
for idx in range(start, n_samples):
|
|
||||||
args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length)
|
|
||||||
create_data(*args)
|
|
||||||
|
216
data/dataset.py
216
data/dataset.py
@ -21,128 +21,128 @@ from .commons import get_patterns, augment_image
|
|||||||
|
|
||||||
from mpl_toolkits.mplot3d import Axes3D
|
from mpl_toolkits.mplot3d import Axes3D
|
||||||
|
|
||||||
|
|
||||||
class TrackSynDataset(torchext.BaseDataset):
|
class TrackSynDataset(torchext.BaseDataset):
|
||||||
'''
|
'''
|
||||||
Load locally saved synthetic dataset
|
Load locally saved synthetic dataset
|
||||||
Please run ./create_syn_data.sh to generate the dataset
|
Please run ./create_syn_data.sh to generate the dataset
|
||||||
'''
|
'''
|
||||||
def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False):
|
|
||||||
super().__init__(train=train)
|
|
||||||
|
|
||||||
self.settings_path = settings_path
|
def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False):
|
||||||
self.sample_paths = sample_paths
|
super().__init__(train=train)
|
||||||
self.data_aug = data_aug
|
|
||||||
self.train = train
|
|
||||||
self.track_length=track_length
|
|
||||||
assert(track_length<=4)
|
|
||||||
|
|
||||||
with open(str(settings_path), 'rb') as f:
|
self.settings_path = settings_path
|
||||||
settings = pickle.load(f)
|
self.sample_paths = sample_paths
|
||||||
self.imsizes = settings['imsizes']
|
self.data_aug = data_aug
|
||||||
self.patterns = settings['patterns']
|
self.train = train
|
||||||
self.focal_lengths = settings['focal_lengths']
|
self.track_length = track_length
|
||||||
self.baseline = settings['baseline']
|
assert (track_length <= 4)
|
||||||
self.K = settings['K']
|
|
||||||
|
|
||||||
self.scale = len(self.imsizes)
|
with open(str(settings_path), 'rb') as f:
|
||||||
|
settings = pickle.load(f)
|
||||||
|
self.imsizes = settings['imsizes']
|
||||||
|
self.patterns = settings['patterns']
|
||||||
|
self.focal_lengths = settings['focal_lengths']
|
||||||
|
self.baseline = settings['baseline']
|
||||||
|
self.K = settings['K']
|
||||||
|
|
||||||
self.max_shift=0
|
self.scale = len(self.imsizes)
|
||||||
self.max_blur=0.5
|
|
||||||
self.max_noise=3.0
|
|
||||||
self.max_sp_noise=0.0005
|
|
||||||
|
|
||||||
def __len__(self):
|
self.max_shift = 0
|
||||||
return len(self.sample_paths)
|
self.max_blur = 0.5
|
||||||
|
self.max_noise = 3.0
|
||||||
|
self.max_sp_noise = 0.0005
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __len__(self):
|
||||||
if not self.train:
|
return len(self.sample_paths)
|
||||||
rng = self.get_rng(idx)
|
|
||||||
else:
|
|
||||||
rng = np.random.RandomState()
|
|
||||||
sample_path = self.sample_paths[idx]
|
|
||||||
|
|
||||||
if self.train:
|
def __getitem__(self, idx):
|
||||||
track_ind = np.random.permutation(4)[0:self.track_length]
|
if not self.train:
|
||||||
else:
|
rng = self.get_rng(idx)
|
||||||
track_ind = [0]
|
|
||||||
|
|
||||||
ret = {}
|
|
||||||
ret['id'] = idx
|
|
||||||
|
|
||||||
# load imgs, at all scales
|
|
||||||
for sidx in range(len(self.imsizes)):
|
|
||||||
imgs = []
|
|
||||||
ambs = []
|
|
||||||
grads = []
|
|
||||||
for tidx in track_ind:
|
|
||||||
imgs.append(np.load(os.path.join(sample_path,f'im{sidx}_{tidx}.npy')))
|
|
||||||
ambs.append(np.load(os.path.join(sample_path,f'ambient{sidx}_{tidx}.npy')))
|
|
||||||
grads.append(np.load(os.path.join(sample_path,f'grad{sidx}_{tidx}.npy')))
|
|
||||||
ret[f'im{sidx}'] = np.stack(imgs, axis=0)
|
|
||||||
ret[f'ambient{sidx}'] = np.stack(ambs, axis=0)
|
|
||||||
ret[f'grad{sidx}'] = np.stack(grads, axis=0)
|
|
||||||
|
|
||||||
# load disp and grad only at full resolution
|
|
||||||
disps = []
|
|
||||||
R = []
|
|
||||||
t = []
|
|
||||||
for tidx in track_ind:
|
|
||||||
disps.append(np.load(os.path.join(sample_path,f'disp0_{tidx}.npy')))
|
|
||||||
R.append(np.load(os.path.join(sample_path,f'R_{tidx}.npy')))
|
|
||||||
t.append(np.load(os.path.join(sample_path,f't_{tidx}.npy')))
|
|
||||||
ret[f'disp0'] = np.stack(disps, axis=0)
|
|
||||||
ret['R'] = np.stack(R, axis=0)
|
|
||||||
ret['t'] = np.stack(t, axis=0)
|
|
||||||
|
|
||||||
blend_im = np.load(os.path.join(sample_path,'blend_im.npy'))
|
|
||||||
ret['blend_im'] = blend_im.astype(np.float32)
|
|
||||||
|
|
||||||
#### apply data augmentation at different scales seperately, only work for max_shift=0
|
|
||||||
if self.data_aug:
|
|
||||||
for sidx in range(len(self.imsizes)):
|
|
||||||
if sidx==0:
|
|
||||||
img = ret[f'im{sidx}']
|
|
||||||
disp = ret[f'disp{sidx}']
|
|
||||||
grad = ret[f'grad{sidx}']
|
|
||||||
img_aug = np.zeros_like(img)
|
|
||||||
disp_aug = np.zeros_like(img)
|
|
||||||
grad_aug = np.zeros_like(img)
|
|
||||||
for i in range(img.shape[0]):
|
|
||||||
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i,0],rng,
|
|
||||||
disp=disp[i,0],grad=grad[i,0],
|
|
||||||
max_shift=self.max_shift, max_blur=self.max_blur,
|
|
||||||
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
|
|
||||||
img_aug[i] = img_aug_[None].astype(np.float32)
|
|
||||||
disp_aug[i] = disp_aug_[None].astype(np.float32)
|
|
||||||
grad_aug[i] = grad_aug_[None].astype(np.float32)
|
|
||||||
ret[f'im{sidx}'] = img_aug
|
|
||||||
ret[f'disp{sidx}'] = disp_aug
|
|
||||||
ret[f'grad{sidx}'] = grad_aug
|
|
||||||
else:
|
else:
|
||||||
img = ret[f'im{sidx}']
|
rng = np.random.RandomState()
|
||||||
img_aug = np.zeros_like(img)
|
sample_path = self.sample_paths[idx]
|
||||||
for i in range(img.shape[0]):
|
|
||||||
img_aug_, _, _ = augment_image(img[i,0],rng,
|
|
||||||
max_shift=self.max_shift, max_blur=self.max_blur,
|
|
||||||
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
|
|
||||||
img_aug[i] = img_aug_[None].astype(np.float32)
|
|
||||||
ret[f'im{sidx}'] = img_aug
|
|
||||||
|
|
||||||
if len(track_ind)==1:
|
if self.train:
|
||||||
for key, val in ret.items():
|
track_ind = np.random.permutation(4)[0:self.track_length]
|
||||||
if key!='blend_im' and key!='id':
|
else:
|
||||||
ret[key] = val[0]
|
track_ind = [0]
|
||||||
|
|
||||||
|
ret = {}
|
||||||
|
ret['id'] = idx
|
||||||
|
|
||||||
return ret
|
# load imgs, at all scales
|
||||||
|
for sidx in range(len(self.imsizes)):
|
||||||
|
imgs = []
|
||||||
|
ambs = []
|
||||||
|
grads = []
|
||||||
|
for tidx in track_ind:
|
||||||
|
imgs.append(np.load(os.path.join(sample_path, f'im{sidx}_{tidx}.npy')))
|
||||||
|
ambs.append(np.load(os.path.join(sample_path, f'ambient{sidx}_{tidx}.npy')))
|
||||||
|
grads.append(np.load(os.path.join(sample_path, f'grad{sidx}_{tidx}.npy')))
|
||||||
|
ret[f'im{sidx}'] = np.stack(imgs, axis=0)
|
||||||
|
ret[f'ambient{sidx}'] = np.stack(ambs, axis=0)
|
||||||
|
ret[f'grad{sidx}'] = np.stack(grads, axis=0)
|
||||||
|
|
||||||
def getK(self, sidx=0):
|
# load disp and grad only at full resolution
|
||||||
K = self.K.copy() / (2**sidx)
|
disps = []
|
||||||
K[2,2] = 1
|
R = []
|
||||||
return K
|
t = []
|
||||||
|
for tidx in track_ind:
|
||||||
|
disps.append(np.load(os.path.join(sample_path, f'disp0_{tidx}.npy')))
|
||||||
|
R.append(np.load(os.path.join(sample_path, f'R_{tidx}.npy')))
|
||||||
|
t.append(np.load(os.path.join(sample_path, f't_{tidx}.npy')))
|
||||||
|
ret[f'disp0'] = np.stack(disps, axis=0)
|
||||||
|
ret['R'] = np.stack(R, axis=0)
|
||||||
|
ret['t'] = np.stack(t, axis=0)
|
||||||
|
|
||||||
|
blend_im = np.load(os.path.join(sample_path, 'blend_im.npy'))
|
||||||
|
ret['blend_im'] = blend_im.astype(np.float32)
|
||||||
|
|
||||||
|
#### apply data augmentation at different scales seperately, only work for max_shift=0
|
||||||
|
if self.data_aug:
|
||||||
|
for sidx in range(len(self.imsizes)):
|
||||||
|
if sidx == 0:
|
||||||
|
img = ret[f'im{sidx}']
|
||||||
|
disp = ret[f'disp{sidx}']
|
||||||
|
grad = ret[f'grad{sidx}']
|
||||||
|
img_aug = np.zeros_like(img)
|
||||||
|
disp_aug = np.zeros_like(img)
|
||||||
|
grad_aug = np.zeros_like(img)
|
||||||
|
for i in range(img.shape[0]):
|
||||||
|
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i, 0], rng,
|
||||||
|
disp=disp[i, 0], grad=grad[i, 0],
|
||||||
|
max_shift=self.max_shift, max_blur=self.max_blur,
|
||||||
|
max_noise=self.max_noise,
|
||||||
|
max_sp_noise=self.max_sp_noise)
|
||||||
|
img_aug[i] = img_aug_[None].astype(np.float32)
|
||||||
|
disp_aug[i] = disp_aug_[None].astype(np.float32)
|
||||||
|
grad_aug[i] = grad_aug_[None].astype(np.float32)
|
||||||
|
ret[f'im{sidx}'] = img_aug
|
||||||
|
ret[f'disp{sidx}'] = disp_aug
|
||||||
|
ret[f'grad{sidx}'] = grad_aug
|
||||||
|
else:
|
||||||
|
img = ret[f'im{sidx}']
|
||||||
|
img_aug = np.zeros_like(img)
|
||||||
|
for i in range(img.shape[0]):
|
||||||
|
img_aug_, _, _ = augment_image(img[i, 0], rng,
|
||||||
|
max_shift=self.max_shift, max_blur=self.max_blur,
|
||||||
|
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
|
||||||
|
img_aug[i] = img_aug_[None].astype(np.float32)
|
||||||
|
ret[f'im{sidx}'] = img_aug
|
||||||
|
|
||||||
|
if len(track_ind) == 1:
|
||||||
|
for key, val in ret.items():
|
||||||
|
if key != 'blend_im' and key != 'id':
|
||||||
|
ret[key] = val[0]
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def getK(self, sidx=0):
|
||||||
|
K = self.K.copy() / (2 ** sidx)
|
||||||
|
K[2, 2] = 1
|
||||||
|
return K
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
<!-- Generated by Cython 0.29 -->
|
<!-- Generated by Cython 0.29 -->
|
||||||
<html>
|
<html>
|
||||||
<head>
|
<head>
|
||||||
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
|
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||||
<title>Cython: lcn.pyx</title>
|
<title>Cython: lcn.pyx</title>
|
||||||
<style type="text/css">
|
<style type="text/css">
|
||||||
|
|
||||||
@ -355,17 +355,23 @@ body.cython { font-family: courier; font-size: 12; }
|
|||||||
.cython .vi { color: #19177C } /* Name.Variable.Instance */
|
.cython .vi { color: #19177C } /* Name.Variable.Instance */
|
||||||
.cython .vm { color: #19177C } /* Name.Variable.Magic */
|
.cython .vm { color: #19177C } /* Name.Variable.Magic */
|
||||||
.cython .il { color: #666666 } /* Literal.Number.Integer.Long */
|
.cython .il { color: #666666 } /* Literal.Number.Integer.Long */
|
||||||
|
|
||||||
</style>
|
</style>
|
||||||
</head>
|
</head>
|
||||||
<body class="cython">
|
<body class="cython">
|
||||||
<p><span style="border-bottom: solid 1px grey;">Generated by Cython 0.29</span></p>
|
<p><span style="border-bottom: solid 1px grey;">Generated by Cython 0.29</span></p>
|
||||||
<p>
|
<p>
|
||||||
<span style="background-color: #FFFF00">Yellow lines</span> hint at Python interaction.<br />
|
<span style="background-color: #FFFF00">Yellow lines</span> hint at Python interaction.<br/>
|
||||||
Click on a line that starts with a "<code>+</code>" to see the C code that Cython generated for it.
|
Click on a line that starts with a "<code>+</code>" to see the C code that Cython generated for it.
|
||||||
</p>
|
</p>
|
||||||
<p>Raw output: <a href="lcn.c">lcn.c</a></p>
|
<p>Raw output: <a href="lcn.c">lcn.c</a></p>
|
||||||
<div class="cython"><pre class="cython line score-16" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre>
|
<div class="cython">
|
||||||
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
|
<pre class="cython line score-16"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span
|
||||||
|
class="k">as</span> <span class="nn">np</span></pre>
|
||||||
|
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_np, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
|
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_np, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
@ -374,22 +380,39 @@ body.cython { font-family: courier; font-size: 12; }
|
|||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_test, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
|
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_test, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
</pre><pre class="cython line score-0"> <span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre>
|
</pre>
|
||||||
<pre class="cython line score-0"> <span class="">03</span>: </pre>
|
<pre class="cython line score-0"> <span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">04</span>: <span class="c"># use c square root function</span></pre>
|
<pre class="cython line score-0"> <span class="">03</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">05</span>: <span class="k">cdef</span> <span class="kr">extern</span> <span class="k">from</span> <span class="s">"math.h"</span><span class="p">:</span></pre>
|
<pre class="cython line score-0"> <span class="">04</span>: <span class="c"># use c square root function</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre>
|
<pre class="cython line score-0"> <span class="">05</span>: <span class="k">cdef</span> <span
|
||||||
<pre class="cython line score-0"> <span class="">07</span>: </pre>
|
class="kr">extern</span> <span class="k">from</span> <span class="s">"math.h"</span><span
|
||||||
<pre class="cython line score-0"> <span class="">08</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre>
|
class="p">:</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">09</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre>
|
<pre class="cython line score-0"> <span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span
|
||||||
<pre class="cython line score-0"> <span class="">10</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span class="p">)</span></pre>
|
class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">11</span>: </pre>
|
<pre class="cython line score-0"> <span class="">07</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">12</span>: <span class="c"># 3 parameters:</span></pre>
|
<pre class="cython line score-0"> <span class="">08</span>: <span class="nd">@cython</span><span
|
||||||
<pre class="cython line score-0"> <span class="">13</span>: <span class="c"># - float image</span></pre>
|
class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span
|
||||||
<pre class="cython line score-0"> <span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre>
|
class="bp">False</span><span class="p">)</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre>
|
<pre class="cython line score-0"> <span class="">09</span>: <span class="nd">@cython</span><span
|
||||||
<pre class="cython line score-67" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">):</span></pre>
|
class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span
|
||||||
<pre class='cython code score-67 '>/* Python wrapper */
|
class="bp">False</span><span class="p">)</span></pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">10</span>: <span class="nd">@cython</span><span
|
||||||
|
class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span
|
||||||
|
class="p">)</span></pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">11</span>: </pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">12</span>: <span class="c"># 3 parameters:</span></pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">13</span>: <span class="c"># - float image</span></pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre>
|
||||||
|
<pre class="cython line score-67"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span
|
||||||
|
class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span
|
||||||
|
class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span
|
||||||
|
class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span
|
||||||
|
class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span
|
||||||
|
class="p">):</span></pre>
|
||||||
|
<pre class='cython code score-67 '>/* Python wrapper */
|
||||||
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
|
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
|
||||||
static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0};
|
static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0};
|
||||||
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
|
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
|
||||||
@ -434,7 +457,8 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (unlikely(kw_args > 0)) {
|
if (unlikely(kw_args > 0)) {
|
||||||
if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") < 0)) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") < 0)) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
switch (<span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)) {
|
switch (<span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)) {
|
||||||
@ -447,21 +471,27 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
|
|||||||
default: goto __pyx_L5_argtuple_error;
|
default: goto __pyx_L5_argtuple_error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
|
__pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||||
if (values[1]) {
|
if (values[1]) {
|
||||||
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) && <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) && <span
|
||||||
|
class='py_c_api'>PyErr_Occurred</span>())) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||||
} else {
|
} else {
|
||||||
__pyx_v_kernel_size = ((int)4);
|
__pyx_v_kernel_size = ((int)4);
|
||||||
}
|
}
|
||||||
if (values[2]) {
|
if (values[2]) {
|
||||||
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) && <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) && <span
|
||||||
|
class='py_c_api'>PyErr_Occurred</span>())) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||||
} else {
|
} else {
|
||||||
__pyx_v_epsilon = ((float)0.01);
|
__pyx_v_epsilon = ((float)0.01);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
goto __pyx_L4_argument_unpacking_done;
|
goto __pyx_L4_argument_unpacking_done;
|
||||||
__pyx_L5_argtuple_error:;
|
__pyx_L5_argtuple_error:;
|
||||||
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||||
__pyx_L3_error:;
|
__pyx_L3_error:;
|
||||||
<span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
|
<span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
|
||||||
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
|
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
|
||||||
@ -515,27 +545,49 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
|
|||||||
return __pyx_r;
|
return __pyx_r;
|
||||||
}
|
}
|
||||||
/* … */
|
/* … */
|
||||||
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19);
|
||||||
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19);
|
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19);
|
||||||
/* … */
|
/* … */
|
||||||
__pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
__pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
|
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) < 0) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span
|
||||||
</pre><pre class="cython line score-0"> <span class="">17</span>: </pre>
|
class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||||
<pre class="cython line score-0"> <span class="">18</span>: <span class="c"># image dimensions</span></pre>
|
</pre>
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
|
<pre class="cython line score-0"> <span class="">17</span>: </pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
|
<pre class="cython line score-0"> <span class="">18</span>: <span class="c"># image dimensions</span></pre>
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
|
<pre class="cython line score-0"
|
||||||
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
</pre><pre class="cython line score-0"> <span class="">21</span>: </pre>
|
class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
|
||||||
<pre class="cython line score-0"> <span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
|
class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
|
||||||
<pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre>
|
class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
|
||||||
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
|
||||||
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
|
||||||
|
class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
|
||||||
|
class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
|
||||||
|
</pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">21</span>: </pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
|
||||||
|
<pre class="cython line score-46"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span
|
||||||
|
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
|
||||||
|
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
|
||||||
|
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
|
||||||
|
class="n">float32</span><span class="p">)</span></pre>
|
||||||
|
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
__pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
__pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
@ -559,22 +611,34 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
|
|||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
|
||||||
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) < 0) <span class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
|
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) < 0) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
||||||
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
|
||||||
__pyx_v_img_lcn = __pyx_t_5;
|
__pyx_v_img_lcn = __pyx_t_5;
|
||||||
__pyx_t_5 = 0;
|
__pyx_t_5 = 0;
|
||||||
</pre><pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
<pre class="cython line score-46"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span
|
||||||
|
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
|
||||||
|
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
|
||||||
|
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
|
||||||
|
class="n">float32</span><span class="p">)</span></pre>
|
||||||
|
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||||
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
||||||
__pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
__pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
@ -598,114 +662,236 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
|
|||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
|
||||||
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||||
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
||||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
|
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) < 0) <span
|
||||||
|
class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||||
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
|
||||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
|
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
|
||||||
__pyx_v_img_std = __pyx_t_1;
|
__pyx_v_img_std = __pyx_t_1;
|
||||||
__pyx_t_1 = 0;
|
__pyx_t_1 = 0;
|
||||||
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span class="n">img_lcn</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
|
<pre class="cython line score-2"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
|
||||||
|
class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span
|
||||||
|
class="n">img_lcn</span></pre>
|
||||||
|
<pre class='cython code score-2 '> __pyx_t_6 = <span
|
||||||
|
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
|
||||||
__pyx_v_img_lcn_view = __pyx_t_6;
|
__pyx_v_img_lcn_view = __pyx_t_6;
|
||||||
__pyx_t_6.memview = NULL;
|
__pyx_t_6.memview = NULL;
|
||||||
__pyx_t_6.data = NULL;
|
__pyx_t_6.data = NULL;
|
||||||
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span class="n">img_std</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
|
<pre class="cython line score-2"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
|
||||||
|
class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span
|
||||||
|
class="n">img_std</span></pre>
|
||||||
|
<pre class='cython code score-2 '> __pyx_t_6 = <span
|
||||||
|
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span
|
||||||
|
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
|
||||||
__pyx_v_img_std_view = __pyx_t_6;
|
__pyx_v_img_std_view = __pyx_t_6;
|
||||||
__pyx_t_6.memview = NULL;
|
__pyx_t_6.memview = NULL;
|
||||||
__pyx_t_6.data = NULL;
|
__pyx_t_6.data = NULL;
|
||||||
</pre><pre class="cython line score-0"> <span class="">27</span>: </pre>
|
</pre>
|
||||||
<pre class="cython line score-0"> <span class="">28</span>: <span class="c"># temporary c variables</span></pre>
|
<pre class="cython line score-0"> <span class="">27</span>: </pre>
|
||||||
<pre class="cython line score-0"> <span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span class="nf">stddev</span></pre>
|
<pre class="cython line score-0"> <span class="">28</span>: <span class="c"># temporary c variables</span></pre>
|
||||||
<pre class="cython line score-0"> <span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
|
<pre class="cython line score-0"> <span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
|
class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span
|
||||||
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
|
class="nf">stddev</span></pre>
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
|
<pre class="cython line score-0"> <span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
|
||||||
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
|
class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span class="o">**</span><span class="mf">2</span></pre>
|
class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
|
<pre class="cython line score-0"
|
||||||
</pre><pre class="cython line score-0"> <span class="">34</span>: </pre>
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
<pre class="cython line score-0"> <span class="">35</span>: <span class="c"># for all pixels do</span></pre>
|
class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
|
class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
|
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
|
||||||
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
|
||||||
|
class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
|
||||||
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
|
||||||
|
class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span
|
||||||
|
class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span
|
||||||
|
class="o">**</span><span class="mf">2</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
|
||||||
|
</pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">34</span>: </pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">35</span>: <span
|
||||||
|
class="c"># for all pixels do</span></pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span
|
||||||
|
class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span
|
||||||
|
class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
|
||||||
__pyx_t_8 = __pyx_t_7;
|
__pyx_t_8 = __pyx_t_7;
|
||||||
for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 < __pyx_t_8; __pyx_t_9+=1) {
|
for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 < __pyx_t_8; __pyx_t_9+=1) {
|
||||||
__pyx_v_m = __pyx_t_9;
|
__pyx_v_m = __pyx_t_9;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span
|
||||||
|
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span
|
||||||
|
class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
|
||||||
__pyx_t_11 = __pyx_t_10;
|
__pyx_t_11 = __pyx_t_10;
|
||||||
for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 < __pyx_t_11; __pyx_t_12+=1) {
|
for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 < __pyx_t_11; __pyx_t_12+=1) {
|
||||||
__pyx_v_n = __pyx_t_12;
|
__pyx_v_n = __pyx_t_12;
|
||||||
</pre><pre class="cython line score-0"> <span class="">38</span>: </pre>
|
</pre>
|
||||||
<pre class="cython line score-0"> <span class="">39</span>: <span class="c"># calculate mean</span></pre>
|
<pre class="cython line score-0"> <span class="">38</span>: </pre>
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre>
|
<pre class="cython line score-0"> <span class="">39</span>: <span class="c"># calculate mean</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
|
<pre class="cython line score-0"
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
|
class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span
|
||||||
|
class="mf">0</span><span class="p">;</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
|
||||||
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span
|
||||||
|
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
|
||||||
|
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
|
||||||
|
class="mf">1</span><span class="p">):</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
|
||||||
__pyx_t_14 = __pyx_t_13;
|
__pyx_t_14 = __pyx_t_13;
|
||||||
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
|
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
|
||||||
__pyx_v_i = __pyx_t_15;
|
__pyx_v_i = __pyx_t_15;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span
|
||||||
|
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
|
||||||
|
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
|
||||||
|
class="mf">1</span><span class="p">):</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
|
||||||
__pyx_t_17 = __pyx_t_16;
|
__pyx_t_17 = __pyx_t_16;
|
||||||
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
|
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
|
||||||
__pyx_v_j = __pyx_t_18;
|
__pyx_v_j = __pyx_t_18;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span
|
||||||
|
class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
|
||||||
|
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
|
||||||
|
class="p">]</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
|
||||||
__pyx_t_20 = (__pyx_v_n + __pyx_v_j);
|
__pyx_t_20 = (__pyx_v_n + __pyx_v_j);
|
||||||
__pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) ))));
|
__pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) ))));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
|
<pre class="cython line score-0"
|
||||||
</pre><pre class="cython line score-0"> <span class="">45</span>: </pre>
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
<pre class="cython line score-0"> <span class="">46</span>: <span class="c"># calculate std dev</span></pre>
|
class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre>
|
class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
|
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
|
<pre class="cython line score-0"> <span class="">45</span>: </pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">46</span>: <span
|
||||||
|
class="c"># calculate std dev</span></pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span
|
||||||
|
class="mf">0</span><span class="p">;</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
|
||||||
|
</pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span
|
||||||
|
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
|
||||||
|
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
|
||||||
|
class="mf">1</span><span class="p">):</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
|
||||||
__pyx_t_14 = __pyx_t_13;
|
__pyx_t_14 = __pyx_t_13;
|
||||||
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
|
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
|
||||||
__pyx_v_i = __pyx_t_15;
|
__pyx_v_i = __pyx_t_15;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span
|
||||||
|
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
|
||||||
|
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
|
||||||
|
class="mf">1</span><span class="p">):</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
|
||||||
__pyx_t_17 = __pyx_t_16;
|
__pyx_t_17 = __pyx_t_16;
|
||||||
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
|
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
|
||||||
__pyx_v_j = __pyx_t_18;
|
__pyx_v_j = __pyx_t_18;
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span
|
||||||
|
class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span
|
||||||
|
class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
|
||||||
|
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
|
||||||
|
class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span
|
||||||
|
class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span
|
||||||
|
class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span
|
||||||
|
class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
|
||||||
|
class="p">)</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
|
||||||
__pyx_t_22 = (__pyx_v_n + __pyx_v_j);
|
__pyx_t_22 = (__pyx_v_n + __pyx_v_j);
|
||||||
__pyx_t_23 = (__pyx_v_m + __pyx_v_i);
|
__pyx_t_23 = (__pyx_v_m + __pyx_v_i);
|
||||||
__pyx_t_24 = (__pyx_v_n + __pyx_v_j);
|
__pyx_t_24 = (__pyx_v_n + __pyx_v_j);
|
||||||
__pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean)));
|
__pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span class="n">num</span><span class="p">)</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
|
<pre class="cython line score-0"
|
||||||
</pre><pre class="cython line score-0"> <span class="">52</span>: </pre>
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
<pre class="cython line score-0"> <span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
|
class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span
|
||||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
|
class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span
|
||||||
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
|
class="n">num</span><span class="p">)</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
|
||||||
|
</pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">52</span>: </pre>
|
||||||
|
<pre class="cython line score-0"> <span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
|
||||||
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span
|
||||||
|
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
|
||||||
|
class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span
|
||||||
|
class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
|
||||||
|
class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span
|
||||||
|
class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
|
||||||
__pyx_t_26 = __pyx_v_n;
|
__pyx_t_26 = __pyx_v_n;
|
||||||
__pyx_t_27 = __pyx_v_m;
|
__pyx_t_27 = __pyx_v_m;
|
||||||
__pyx_t_28 = __pyx_v_n;
|
__pyx_t_28 = __pyx_v_n;
|
||||||
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps));
|
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps));
|
||||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">stddev</span></pre>
|
</pre>
|
||||||
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
|
<pre class="cython line score-0"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span
|
||||||
|
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
|
||||||
|
class="n">stddev</span></pre>
|
||||||
|
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
|
||||||
__pyx_t_30 = __pyx_v_n;
|
__pyx_t_30 = __pyx_v_n;
|
||||||
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev;
|
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</pre><pre class="cython line score-0"> <span class="">56</span>: </pre>
|
</pre>
|
||||||
<pre class="cython line score-0"> <span class="">57</span>: <span class="c"># return both</span></pre>
|
<pre class="cython line score-0"> <span class="">56</span>: </pre>
|
||||||
<pre class="cython line score-10" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span class="n">img_std</span></pre>
|
<pre class="cython line score-0"> <span class="">57</span>: <span class="c"># return both</span></pre>
|
||||||
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
|
<pre class="cython line score-10"
|
||||||
|
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
|
||||||
|
class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span
|
||||||
|
class="n">img_std</span></pre>
|
||||||
|
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
|
||||||
__pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span>
|
__pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span>
|
||||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||||
<span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_lcn);
|
<span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_lcn);
|
||||||
@ -717,4 +903,7 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
|
|||||||
__pyx_r = __pyx_t_1;
|
__pyx_r = __pyx_t_1;
|
||||||
__pyx_t_1 = 0;
|
__pyx_t_1 = 0;
|
||||||
goto __pyx_L0;
|
goto __pyx_L0;
|
||||||
</pre></div></body></html>
|
</pre>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
@ -2,5 +2,5 @@ from distutils.core import setup
|
|||||||
from Cython.Build import cythonize
|
from Cython.Build import cythonize
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
ext_modules = cythonize("lcn.pyx",annotate=True)
|
ext_modules=cythonize("lcn.pyx", annotate=True)
|
||||||
)
|
)
|
||||||
|
@ -5,43 +5,43 @@ from scipy import misc
|
|||||||
|
|
||||||
# load and convert to float
|
# load and convert to float
|
||||||
img = misc.imread('img.png')
|
img = misc.imread('img.png')
|
||||||
img = img.astype(np.float32)/255.0
|
img = img.astype(np.float32) / 255.0
|
||||||
|
|
||||||
# normalize
|
# normalize
|
||||||
img_lcn, img_std = lcn.normalize(img,5,0.05)
|
img_lcn, img_std = lcn.normalize(img, 5, 0.05)
|
||||||
|
|
||||||
# normalize to reasonable range between 0 and 1
|
# normalize to reasonable range between 0 and 1
|
||||||
#img_lcn = img_lcn/3.0
|
# img_lcn = img_lcn/3.0
|
||||||
#img_lcn = np.maximum(img_lcn,0.0)
|
# img_lcn = np.maximum(img_lcn,0.0)
|
||||||
#img_lcn = np.minimum(img_lcn,1.0)
|
# img_lcn = np.minimum(img_lcn,1.0)
|
||||||
|
|
||||||
# save to file
|
# save to file
|
||||||
#misc.imsave('lcn2.png',img_lcn)
|
# misc.imsave('lcn2.png',img_lcn)
|
||||||
|
|
||||||
print ("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \
|
print("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \
|
||||||
(img.shape[0], img.shape[1], img.dtype, img.min(), img.max()))
|
(img.shape[0], img.shape[1], img.dtype, img.min(), img.max()))
|
||||||
print ("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \
|
print("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \
|
||||||
(img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max()))
|
(img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max()))
|
||||||
|
|
||||||
# plot original image
|
# plot original image
|
||||||
plt.figure(1)
|
plt.figure(1)
|
||||||
img_plot = plt.imshow(img)
|
img_plot = plt.imshow(img)
|
||||||
img_plot.set_cmap('gray')
|
img_plot.set_cmap('gray')
|
||||||
plt.clim(0, 1) # fix range
|
plt.clim(0, 1) # fix range
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
# plot normalized image
|
# plot normalized image
|
||||||
plt.figure(2)
|
plt.figure(2)
|
||||||
img_lcn_plot = plt.imshow(img_lcn)
|
img_lcn_plot = plt.imshow(img_lcn)
|
||||||
img_lcn_plot.set_cmap('gray')
|
img_lcn_plot.set_cmap('gray')
|
||||||
#plt.clim(0, 1) # fix range
|
# plt.clim(0, 1) # fix range
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
# plot stddev image
|
# plot stddev image
|
||||||
plt.figure(3)
|
plt.figure(3)
|
||||||
img_std_plot = plt.imshow(img_std)
|
img_std_plot = plt.imshow(img_std)
|
||||||
img_std_plot.set_cmap('gray')
|
img_std_plot.set_cmap('gray')
|
||||||
#plt.clim(0, 0.1) # fix range
|
# plt.clim(0, 0.1) # fix range
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
@ -11,28 +11,27 @@ import dataset
|
|||||||
|
|
||||||
|
|
||||||
def get_data(n, row_from, row_to, train):
|
def get_data(n, row_from, row_to, train):
|
||||||
imsizes = [(256,384)]
|
imsizes = [(256, 384)]
|
||||||
focal_lengths = [160]
|
focal_lengths = [160]
|
||||||
dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)
|
dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)
|
||||||
ims = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.uint8)
|
ims = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.uint8)
|
||||||
disps = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.float32)
|
disps = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.float32)
|
||||||
for idx in range(n):
|
for idx in range(n):
|
||||||
print(f'load sample {idx} train={train}')
|
print(f'load sample {idx} train={train}')
|
||||||
sample = dset[idx]
|
sample = dset[idx]
|
||||||
ims[idx] = (sample['im0'][0,row_from:row_to] * 255).astype(np.uint8)
|
ims[idx] = (sample['im0'][0, row_from:row_to] * 255).astype(np.uint8)
|
||||||
disps[idx] = sample['disp0'][0,row_from:row_to]
|
disps[idx] = sample['disp0'][0, row_from:row_to]
|
||||||
return ims, disps
|
return ims, disps
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
params = hd.TrainParams(
|
params = hd.TrainParams(
|
||||||
n_trees=4,
|
n_trees=4,
|
||||||
max_tree_depth=,
|
max_tree_depth=,
|
||||||
n_test_split_functions=50,
|
n_test_split_functions=50,
|
||||||
n_test_thresholds=10,
|
n_test_thresholds=10,
|
||||||
n_test_samples=4096,
|
n_test_samples=4096,
|
||||||
min_samples_to_split=16,
|
min_samples_to_split=16,
|
||||||
min_samples_for_leaf=8)
|
min_samples_for_leaf=8)
|
||||||
|
|
||||||
n_disp_bins = 20
|
n_disp_bins = 20
|
||||||
depth_switch = 0
|
depth_switch = 0
|
||||||
@ -45,21 +44,23 @@ n_test_samples = 32
|
|||||||
train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
|
train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
|
||||||
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)
|
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)
|
||||||
|
|
||||||
for tree_depth in [8,10,12,14,16]:
|
for tree_depth in [8, 10, 12, 14, 16]:
|
||||||
depth_switch = tree_depth - 4
|
depth_switch = tree_depth - 4
|
||||||
|
|
||||||
prefix = f'td{tree_depth}_ds{depth_switch}'
|
prefix = f'td{tree_depth}_ds{depth_switch}'
|
||||||
prefix = Path(f'./forests/{prefix}/')
|
prefix = Path(f'./forests/{prefix}/')
|
||||||
prefix.mkdir(parents=True, exist_ok=True)
|
prefix.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
|
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
|
||||||
|
forest_prefix=str(prefix / 'fr'))
|
||||||
|
|
||||||
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
|
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
|
||||||
|
forest_prefix=str(prefix / 'fr'))
|
||||||
|
|
||||||
np.save(str(prefix / 'ta.npy'), test_disps)
|
np.save(str(prefix / 'ta.npy'), test_disps)
|
||||||
np.save(str(prefix / 'es.npy'), es)
|
np.save(str(prefix / 'es.npy'), es)
|
||||||
|
|
||||||
# plt.figure();
|
# plt.figure();
|
||||||
# plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4);
|
# plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4);
|
||||||
# plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4);
|
# plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4);
|
||||||
# plt.show()
|
# plt.show()
|
||||||
|
@ -8,7 +8,6 @@ import os
|
|||||||
|
|
||||||
this_dir = os.path.dirname(__file__)
|
this_dir = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
|
||||||
extra_compile_args = ['-O3', '-std=c++11']
|
extra_compile_args = ['-O3', '-std=c++11']
|
||||||
extra_link_args = []
|
extra_link_args = []
|
||||||
|
|
||||||
@ -22,24 +21,20 @@ library_dirs = []
|
|||||||
libraries = ['m']
|
libraries = ['m']
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="hyperdepth",
|
name="hyperdepth",
|
||||||
cmdclass= {'build_ext': build_ext},
|
cmdclass={'build_ext': build_ext},
|
||||||
ext_modules=[
|
ext_modules=[
|
||||||
Extension('hyperdepth',
|
Extension('hyperdepth',
|
||||||
sources,
|
sources,
|
||||||
extra_objects=extra_objects,
|
extra_objects=extra_objects,
|
||||||
language='c++',
|
language='c++',
|
||||||
library_dirs=library_dirs,
|
library_dirs=library_dirs,
|
||||||
libraries=libraries,
|
libraries=libraries,
|
||||||
include_dirs=[
|
include_dirs=[
|
||||||
np.get_include(),
|
np.get_include(),
|
||||||
],
|
],
|
||||||
extra_compile_args=extra_compile_args,
|
extra_compile_args=extra_compile_args,
|
||||||
extra_link_args=extra_link_args
|
extra_link_args=extra_link_args
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,10 +6,13 @@ orig = cv2.imread('disp_orig.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
|||||||
ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
||||||
es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
plt.figure()
|
plt.figure()
|
||||||
plt.subplot(2,2,1); plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
|
plt.subplot(2, 2, 1);
|
||||||
plt.subplot(2,2,2); plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
|
plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
|
||||||
plt.subplot(2,2,3); plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
|
plt.subplot(2, 2, 2);
|
||||||
plt.subplot(2,2,4); plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
|
plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
|
||||||
|
plt.subplot(2, 2, 3);
|
||||||
|
plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
|
||||||
|
plt.subplot(2, 2, 4);
|
||||||
|
plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
@ -12,226 +12,263 @@ import torchext
|
|||||||
from model import networks
|
from model import networks
|
||||||
from data import dataset
|
from data import dataset
|
||||||
|
|
||||||
|
|
||||||
class Worker(torchext.Worker):
|
class Worker(torchext.Worker):
|
||||||
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
|
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
|
||||||
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
|
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
|
||||||
|
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
|
||||||
|
save_frequency=save_frequency, **kwargs)
|
||||||
|
|
||||||
self.ms = args.ms
|
self.ms = args.ms
|
||||||
self.pattern_path = args.pattern_path
|
self.pattern_path = args.pattern_path
|
||||||
self.lcn_radius = args.lcn_radius
|
self.lcn_radius = args.lcn_radius
|
||||||
self.dp_weight = args.dp_weight
|
self.dp_weight = args.dp_weight
|
||||||
self.data_type = args.data_type
|
self.data_type = args.data_type
|
||||||
|
|
||||||
self.imsizes = [(480,640)]
|
self.imsizes = [(488, 648)]
|
||||||
for iter in range(3):
|
for iter in range(3):
|
||||||
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2)))
|
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
|
||||||
|
|
||||||
with open('config.json') as fp:
|
with open('config.json') as fp:
|
||||||
config = json.load(fp)
|
config = json.load(fp)
|
||||||
data_root = Path(config['DATA_ROOT'])
|
data_root = Path(config['DATA_ROOT'])
|
||||||
self.settings_path = data_root / self.data_type / 'settings.pkl'
|
self.settings_path = data_root / self.data_type / 'settings.pkl'
|
||||||
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
|
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
|
||||||
|
|
||||||
self.train_paths = sample_paths[2**10:]
|
self.train_paths = sample_paths[2 ** 10:]
|
||||||
self.test_paths = sample_paths[:2**8]
|
self.test_paths = sample_paths[:2 ** 8]
|
||||||
|
|
||||||
# supervise the edge encoder with only 2**8 samples
|
# supervise the edge encoder with only 2**8 samples
|
||||||
self.train_edge = len(self.train_paths) - 2**8
|
self.train_edge = len(self.train_paths) - 2 ** 8
|
||||||
|
|
||||||
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
|
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
|
||||||
self.disparity_loss = networks.DisparityLoss()
|
self.disparity_loss = networks.DisparityLoss()
|
||||||
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
|
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
|
||||||
|
|
||||||
# evaluate in the region where opencv Block Matching has valid values
|
# evaluate in the region where opencv Block Matching has valid values
|
||||||
self.eval_mask = np.zeros(self.imsizes[0])
|
self.eval_mask = np.zeros(self.imsizes[0])
|
||||||
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1
|
self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
|
||||||
self.eval_mask = self.eval_mask.astype(np.bool)
|
self.eval_mask = self.eval_mask.astype(np.bool)
|
||||||
self.eval_h = self.imsizes[0][0]-2*13
|
self.eval_h = self.imsizes[0][0] - 2 * 13
|
||||||
self.eval_w = self.imsizes[0][1]-13-140
|
self.eval_w = self.imsizes[0][1] - 13 - 140
|
||||||
|
|
||||||
def get_train_set(self):
|
def get_train_set(self):
|
||||||
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=1)
|
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
|
||||||
|
track_length=1)
|
||||||
|
|
||||||
return train_set
|
return train_set
|
||||||
|
|
||||||
def get_test_sets(self):
|
def get_test_sets(self):
|
||||||
test_sets = torchext.TestSets()
|
test_sets = torchext.TestSets()
|
||||||
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
|
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
|
||||||
test_sets.append('simple', test_set, test_frequency=1)
|
track_length=1)
|
||||||
|
test_sets.append('simple', test_set, test_frequency=1)
|
||||||
|
|
||||||
# initialize photometric loss modules according to image sizes
|
# initialize photometric loss modules according to image sizes
|
||||||
self.losses = []
|
self.losses = []
|
||||||
for imsize, pat in zip(test_set.imsizes, test_set.patterns):
|
for imsize, pat in zip(test_set.imsizes, test_set.patterns):
|
||||||
pat = pat.mean(axis=2)
|
pat = pat.mean(axis=2)
|
||||||
pat = torch.from_numpy(pat[None][None].astype(np.float32))
|
pat = torch.from_numpy(pat[None][None].astype(np.float32))
|
||||||
pat = pat.to(self.train_device)
|
pat = pat.to(self.train_device)
|
||||||
self.lcn_in = self.lcn_in.to(self.train_device)
|
self.lcn_in = self.lcn_in.to(self.train_device)
|
||||||
pat,_ = self.lcn_in(pat)
|
pat, _ = self.lcn_in(pat)
|
||||||
pat = torch.cat([pat for idx in range(3)], dim=1)
|
pat = torch.cat([pat for idx in range(3)], dim=1)
|
||||||
self.losses.append( networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) )
|
self.losses.append(networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat))
|
||||||
|
|
||||||
return test_sets
|
return test_sets
|
||||||
|
|
||||||
def copy_data(self, data, device, requires_grad, train):
|
def copy_data(self, data, device, requires_grad, train):
|
||||||
self.lcn_in = self.lcn_in.to(device)
|
self.lcn_in = self.lcn_in.to(device)
|
||||||
|
|
||||||
self.data = {}
|
self.data = {}
|
||||||
for key, val in data.items():
|
for key, val in data.items():
|
||||||
grad = 'im' in key and requires_grad
|
grad = 'im' in key and requires_grad
|
||||||
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
|
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
|
||||||
|
|
||||||
# apply lcn to IR input
|
# apply lcn to IR input
|
||||||
# concatenate the normalized IR input and the original IR image
|
# concatenate the normalized IR input and the original IR image
|
||||||
if 'im' in key and 'blend' not in key:
|
if 'im' in key and 'blend' not in key:
|
||||||
im = self.data[key]
|
im = self.data[key]
|
||||||
im_lcn,im_std = self.lcn_in(im)
|
im_lcn, im_std = self.lcn_in(im)
|
||||||
im_cat = torch.cat((im_lcn, im), dim=1)
|
im_cat = torch.cat((im_lcn, im), dim=1)
|
||||||
key_std = key.replace('im','std')
|
key_std = key.replace('im', 'std')
|
||||||
self.data[key]=im_cat
|
self.data[key] = im_cat
|
||||||
self.data[key_std] = im_std.to(device).detach()
|
self.data[key_std] = im_std.to(device).detach()
|
||||||
|
|
||||||
def net_forward(self, net, train):
|
def net_forward(self, net, train):
|
||||||
out = net(self.data['im0'])
|
out = net(self.data['im0'])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def loss_forward(self, out, train):
|
def loss_forward(self, out, train):
|
||||||
out, edge = out
|
out, edge = out
|
||||||
if not(isinstance(out, tuple) or isinstance(out, list)):
|
if not (isinstance(out, tuple) or isinstance(out, list)):
|
||||||
out = [out]
|
out = [out]
|
||||||
if not(isinstance(edge, tuple) or isinstance(edge, list)):
|
if not (isinstance(edge, tuple) or isinstance(edge, list)):
|
||||||
edge = [edge]
|
edge = [edge]
|
||||||
|
|
||||||
vals = []
|
vals = []
|
||||||
|
|
||||||
# apply photometric loss
|
# apply photometric loss
|
||||||
for s,l,o in zip(itertools.count(), self.losses, out):
|
for s, l, o in zip(itertools.count(), self.losses, out):
|
||||||
val, pattern_proj = l(o, self.data[f'im{s}'][:,0:1,...], self.data[f'std{s}'])
|
val, pattern_proj = l(o, self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
|
||||||
if s == 0:
|
if s == 0:
|
||||||
self.pattern_proj = pattern_proj.detach()
|
self.pattern_proj = pattern_proj.detach()
|
||||||
vals.append(val)
|
vals.append(val)
|
||||||
|
|
||||||
# apply disparity loss
|
# apply disparity loss
|
||||||
# 1-edge as ground truth edge if inversed
|
# 1-edge as ground truth edge if inversed
|
||||||
edge0 = 1-torch.sigmoid(edge[0])
|
edge0 = 1 - torch.sigmoid(edge[0])
|
||||||
val = self.disparity_loss(out[0], edge0)
|
val = self.disparity_loss(out[0], edge0)
|
||||||
if self.dp_weight>0:
|
if self.dp_weight > 0:
|
||||||
vals.append(val * self.dp_weight)
|
vals.append(val * self.dp_weight)
|
||||||
|
|
||||||
# apply edge loss on a subset of training samples
|
# apply edge loss on a subset of training samples
|
||||||
for s,e in zip(itertools.count(), edge):
|
for s, e in zip(itertools.count(), edge):
|
||||||
# inversed ground truth edge where 0 means edge
|
# inversed ground truth edge where 0 means edge
|
||||||
grad = self.data[f'grad{s}']<0.2
|
grad = self.data[f'grad{s}'] < 0.2
|
||||||
grad = grad.to(torch.float32)
|
grad = grad.to(torch.float32)
|
||||||
ids = self.data['id']
|
ids = self.data['id']
|
||||||
mask = ids>self.train_edge
|
mask = ids > self.train_edge
|
||||||
if mask.sum()>0:
|
if mask.sum() > 0:
|
||||||
val = self.edge_loss(e[mask], grad[mask])
|
val = self.edge_loss(e[mask], grad[mask])
|
||||||
else:
|
else:
|
||||||
val = torch.zeros_like(vals[0])
|
val = torch.zeros_like(vals[0])
|
||||||
if s == 0:
|
if s == 0:
|
||||||
self.edge = e.detach()
|
self.edge = e.detach()
|
||||||
self.edge = torch.sigmoid(self.edge)
|
self.edge = torch.sigmoid(self.edge)
|
||||||
self.edge_gt = grad.detach()
|
self.edge_gt = grad.detach()
|
||||||
vals.append(val)
|
vals.append(val)
|
||||||
|
|
||||||
return vals
|
return vals
|
||||||
|
|
||||||
def numpy_in_out(self, output):
|
def numpy_in_out(self, output):
|
||||||
output, edge = output
|
output, edge = output
|
||||||
if not(isinstance(output, tuple) or isinstance(output, list)):
|
if not (isinstance(output, tuple) or isinstance(output, list)):
|
||||||
output = [output]
|
output = [output]
|
||||||
es = output[0].detach().to('cpu').numpy()
|
es = output[0].detach().to('cpu').numpy()
|
||||||
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
|
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
|
||||||
im = self.data['im0'][:,0:1,...].detach().to('cpu').numpy()
|
im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy()
|
||||||
|
|
||||||
ma = gt>0
|
ma = gt > 0
|
||||||
return es, gt, im, ma
|
return es, gt, im, ma
|
||||||
|
|
||||||
def write_img(self, out_path, es, gt, im, ma):
|
def write_img(self, out_path, es, gt, im, ma):
|
||||||
logging.info(f'write img {out_path}')
|
logging.info(f'write img {out_path}')
|
||||||
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
|
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
|
||||||
|
|
||||||
diff = np.abs(es - gt)
|
diff = np.abs(es - gt)
|
||||||
|
|
||||||
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
|
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
|
||||||
vmin = vmin - 0.2*(vmax-vmin)
|
vmin = vmin - 0.2 * (vmax - vmin)
|
||||||
vmax = vmax + 0.2*(vmax-vmin)
|
vmax = vmax + 0.2 * (vmax - vmin)
|
||||||
|
|
||||||
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0]
|
pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
|
||||||
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0]
|
im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0]
|
||||||
pattern_diff = np.abs(im_orig - pattern_proj)
|
pattern_diff = np.abs(im_orig - pattern_proj)
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=(16, 16))
|
||||||
|
es_ = co.cmap.color_depth_map(es, scale=vmax)
|
||||||
|
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
|
||||||
|
diff_ = co.cmap.color_error_image(diff, BGR=True)
|
||||||
|
|
||||||
fig = plt.figure(figsize=(16,16))
|
# plot disparities, ground truth disparity is shown only for reference
|
||||||
es_ = co.cmap.color_depth_map(es, scale=vmax)
|
ax = plt.subplot(3, 3, 1)
|
||||||
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
|
plt.imshow(es_[..., [2, 1, 0]])
|
||||||
diff_ = co.cmap.color_error_image(diff, BGR=True)
|
plt.xticks([])
|
||||||
|
plt.yticks([])
|
||||||
|
ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 2)
|
||||||
|
plt.imshow(gt_[..., [2, 1, 0]])
|
||||||
|
plt.xticks([])
|
||||||
|
plt.yticks([])
|
||||||
|
ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 3)
|
||||||
|
plt.imshow(diff_[..., [2, 1, 0]])
|
||||||
|
plt.xticks([])
|
||||||
|
plt.yticks([])
|
||||||
|
ax.set_title(f'Disparity Err. {diff.mean():.5f}')
|
||||||
|
|
||||||
# plot disparities, ground truth disparity is shown only for reference
|
# plot edges
|
||||||
ax = plt.subplot(3,3,1); plt.imshow(es_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
|
edge = self.edge.to('cpu').numpy()[0, 0]
|
||||||
ax = plt.subplot(3,3,2); plt.imshow(gt_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
|
edge_gt = self.edge_gt.to('cpu').numpy()[0, 0]
|
||||||
ax = plt.subplot(3,3,3); plt.imshow(diff_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Err. {diff.mean():.5f}')
|
edge_err = np.abs(edge - edge_gt)
|
||||||
|
ax = plt.subplot(3, 3, 4);
|
||||||
|
plt.imshow(edge, cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
|
||||||
|
ax = plt.subplot(3, 3, 5);
|
||||||
|
plt.imshow(edge_gt, cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
|
||||||
|
ax = plt.subplot(3, 3, 6);
|
||||||
|
plt.imshow(edge_err, cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
|
||||||
|
|
||||||
# plot edges
|
# plot normalized IR input and warped pattern
|
||||||
edge = self.edge.to('cpu').numpy()[0,0]
|
ax = plt.subplot(3, 3, 7);
|
||||||
edge_gt = self.edge_gt.to('cpu').numpy()[0,0]
|
plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray');
|
||||||
edge_err = np.abs(edge - edge_gt)
|
plt.xticks([]);
|
||||||
ax = plt.subplot(3,3,4); plt.imshow(edge, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
|
plt.yticks([]);
|
||||||
ax = plt.subplot(3,3,5); plt.imshow(edge_gt, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
|
ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
|
||||||
ax = plt.subplot(3,3,6); plt.imshow(edge_err, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
|
ax = plt.subplot(3, 3, 8);
|
||||||
|
plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
|
||||||
|
im_std = self.data['std0'].to('cpu').numpy()[0, 0]
|
||||||
|
ax = plt.subplot(3, 3, 9);
|
||||||
|
plt.imshow(im_std, cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
|
||||||
|
|
||||||
# plot normalized IR input and warped pattern
|
plt.tight_layout()
|
||||||
ax = plt.subplot(3,3,7); plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
|
plt.savefig(str(out_path))
|
||||||
ax = plt.subplot(3,3,8); plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
|
plt.close(fig)
|
||||||
im_std = self.data['std0'].to('cpu').numpy()[0,0]
|
|
||||||
ax = plt.subplot(3,3,9); plt.imshow(im_std, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
|
|
||||||
|
|
||||||
plt.tight_layout()
|
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
|
||||||
plt.savefig(str(out_path))
|
if batch_idx % 512 == 0:
|
||||||
plt.close(fig)
|
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
|
||||||
|
es, gt, im, ma = self.numpy_in_out(output)
|
||||||
|
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
|
||||||
|
|
||||||
|
def callback_test_start(self, epoch, set_idx):
|
||||||
|
self.metric = co.metric.MultipleMetric(
|
||||||
|
co.metric.DistanceMetric(vec_length=1),
|
||||||
|
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
|
||||||
|
)
|
||||||
|
|
||||||
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
|
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
|
||||||
if batch_idx % 512 == 0:
|
es, gt, im, ma = self.numpy_in_out(output)
|
||||||
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
|
|
||||||
es, gt, im, ma = self.numpy_in_out(output)
|
|
||||||
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
|
|
||||||
|
|
||||||
|
if batch_idx % 8 == 0:
|
||||||
|
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
|
||||||
|
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
|
||||||
|
|
||||||
def callback_test_start(self, epoch, set_idx):
|
es, gt, im, ma = self.crop_output(es, gt, im, ma)
|
||||||
self.metric = co.metric.MultipleMetric(
|
|
||||||
co.metric.DistanceMetric(vec_length=1),
|
|
||||||
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
|
|
||||||
)
|
|
||||||
|
|
||||||
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
|
es = es.reshape(-1, 1)
|
||||||
es, gt, im, ma = self.numpy_in_out(output)
|
gt = gt.reshape(-1, 1)
|
||||||
|
ma = ma.ravel()
|
||||||
|
self.metric.add(es, gt, ma)
|
||||||
|
|
||||||
if batch_idx % 8 == 0:
|
def callback_test_stop(self, epoch, set_idx, loss):
|
||||||
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
|
logging.info(f'{self.metric}')
|
||||||
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
|
for k, v in self.metric.items():
|
||||||
|
self.metric_add_test(epoch, set_idx, k, v)
|
||||||
es, gt, im, ma = self.crop_output(es, gt, im, ma)
|
|
||||||
|
|
||||||
es = es.reshape(-1,1)
|
|
||||||
gt = gt.reshape(-1,1)
|
|
||||||
ma = ma.ravel()
|
|
||||||
self.metric.add(es, gt, ma)
|
|
||||||
|
|
||||||
def callback_test_stop(self, epoch, set_idx, loss):
|
|
||||||
logging.info(f'{self.metric}')
|
|
||||||
for k, v in self.metric.items():
|
|
||||||
self.metric_add_test(epoch, set_idx, k, v)
|
|
||||||
|
|
||||||
def crop_output(self, es, gt, im, ma):
|
|
||||||
bs = es.shape[0]
|
|
||||||
es = np.reshape(es[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
|
||||||
gt = np.reshape(gt[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
|
||||||
im = np.reshape(im[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
|
||||||
ma = np.reshape(ma[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
|
||||||
return es, gt, im, ma
|
|
||||||
|
|
||||||
|
def crop_output(self, es, gt, im, ma):
|
||||||
|
bs = es.shape[0]
|
||||||
|
es = np.reshape(es[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
||||||
|
gt = np.reshape(gt[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
||||||
|
im = np.reshape(im[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
||||||
|
ma = np.reshape(ma[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
||||||
|
return es, gt, im, ma
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
pass
|
||||||
|
@ -12,287 +12,324 @@ import torchext
|
|||||||
from model import networks
|
from model import networks
|
||||||
from data import dataset
|
from data import dataset
|
||||||
|
|
||||||
|
|
||||||
class Worker(torchext.Worker):
|
class Worker(torchext.Worker):
|
||||||
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
|
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
|
||||||
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
|
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
|
||||||
|
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
|
||||||
|
save_frequency=save_frequency, **kwargs)
|
||||||
|
|
||||||
self.ms = args.ms
|
self.ms = args.ms
|
||||||
self.pattern_path = args.pattern_path
|
self.pattern_path = args.pattern_path
|
||||||
self.lcn_radius = args.lcn_radius
|
self.lcn_radius = args.lcn_radius
|
||||||
self.dp_weight = args.dp_weight
|
self.dp_weight = args.dp_weight
|
||||||
self.ge_weight = args.ge_weight
|
self.ge_weight = args.ge_weight
|
||||||
self.track_length = args.track_length
|
self.track_length = args.track_length
|
||||||
self.data_type = args.data_type
|
self.data_type = args.data_type
|
||||||
assert(self.track_length>1)
|
assert (self.track_length > 1)
|
||||||
|
|
||||||
self.imsizes = [(480,640)]
|
self.imsizes = [(480, 640)]
|
||||||
for iter in range(3):
|
for iter in range(3):
|
||||||
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2)))
|
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
|
||||||
|
|
||||||
with open('config.json') as fp:
|
with open('config.json') as fp:
|
||||||
config = json.load(fp)
|
config = json.load(fp)
|
||||||
data_root = Path(config['DATA_ROOT'])
|
data_root = Path(config['DATA_ROOT'])
|
||||||
self.settings_path = data_root / self.data_type / 'settings.pkl'
|
self.settings_path = data_root / self.data_type / 'settings.pkl'
|
||||||
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
|
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
|
||||||
|
|
||||||
self.train_paths = sample_paths[2**10:]
|
self.train_paths = sample_paths[2 ** 10:]
|
||||||
self.test_paths = sample_paths[:2**8]
|
self.test_paths = sample_paths[:2 ** 8]
|
||||||
|
|
||||||
# supervise the edge encoder with only 2**8 samples
|
# supervise the edge encoder with only 2**8 samples
|
||||||
self.train_edge = len(self.train_paths) - 2**8
|
self.train_edge = len(self.train_paths) - 2 ** 8
|
||||||
|
|
||||||
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
|
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
|
||||||
self.disparity_loss = networks.DisparityLoss()
|
self.disparity_loss = networks.DisparityLoss()
|
||||||
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
|
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
|
||||||
|
|
||||||
# evaluate in the region where opencv Block Matching has valid values
|
# evaluate in the region where opencv Block Matching has valid values
|
||||||
self.eval_mask = np.zeros(self.imsizes[0])
|
self.eval_mask = np.zeros(self.imsizes[0])
|
||||||
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1
|
self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
|
||||||
self.eval_mask = self.eval_mask.astype(np.bool)
|
self.eval_mask = self.eval_mask.astype(np.bool)
|
||||||
self.eval_h = self.imsizes[0][0]-2*13
|
self.eval_h = self.imsizes[0][0] - 2 * 13
|
||||||
self.eval_w = self.imsizes[0][1]-13-140
|
self.eval_w = self.imsizes[0][1] - 13 - 140
|
||||||
|
|
||||||
|
def get_train_set(self):
|
||||||
|
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
|
||||||
|
track_length=self.track_length)
|
||||||
|
return train_set
|
||||||
|
|
||||||
def get_train_set(self):
|
def get_test_sets(self):
|
||||||
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length)
|
test_sets = torchext.TestSets()
|
||||||
return train_set
|
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
|
||||||
|
track_length=1)
|
||||||
|
test_sets.append('simple', test_set, test_frequency=1)
|
||||||
|
|
||||||
def get_test_sets(self):
|
self.ph_losses = []
|
||||||
test_sets = torchext.TestSets()
|
self.ge_losses = []
|
||||||
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
|
self.d2ds = []
|
||||||
test_sets.append('simple', test_set, test_frequency=1)
|
|
||||||
|
|
||||||
self.ph_losses = []
|
self.lcn_in = self.lcn_in.to('cuda')
|
||||||
self.ge_losses = []
|
for sidx in range(len(test_set.imsizes)):
|
||||||
self.d2ds = []
|
imsize = test_set.imsizes[sidx]
|
||||||
|
pat = test_set.patterns[sidx]
|
||||||
|
pat = pat.mean(axis=2)
|
||||||
|
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
|
||||||
|
pat, _ = self.lcn_in(pat)
|
||||||
|
pat = torch.cat([pat for idx in range(3)], dim=1)
|
||||||
|
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat)
|
||||||
|
|
||||||
self.lcn_in = self.lcn_in.to('cuda')
|
K = test_set.getK(sidx)
|
||||||
for sidx in range(len(test_set.imsizes)):
|
Ki = np.linalg.inv(K)
|
||||||
imsize = test_set.imsizes[sidx]
|
K = torch.from_numpy(K)
|
||||||
pat = test_set.patterns[sidx]
|
Ki = torch.from_numpy(Ki)
|
||||||
pat = pat.mean(axis=2)
|
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
|
||||||
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
|
|
||||||
pat,_ = self.lcn_in(pat)
|
|
||||||
pat = torch.cat([pat for idx in range(3)], dim=1)
|
|
||||||
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat)
|
|
||||||
|
|
||||||
K = test_set.getK(sidx)
|
self.ph_losses.append(ph_loss)
|
||||||
Ki = np.linalg.inv(K)
|
self.ge_losses.append(ge_loss)
|
||||||
K = torch.from_numpy(K)
|
|
||||||
Ki = torch.from_numpy(Ki)
|
|
||||||
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
|
|
||||||
|
|
||||||
self.ph_losses.append( ph_loss )
|
d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
|
||||||
self.ge_losses.append( ge_loss )
|
self.d2ds.append(d2d)
|
||||||
|
|
||||||
d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
|
return test_sets
|
||||||
self.d2ds.append( d2d )
|
|
||||||
|
|
||||||
return test_sets
|
def copy_data(self, data, device, requires_grad, train):
|
||||||
|
self.data = {}
|
||||||
|
|
||||||
def copy_data(self, data, device, requires_grad, train):
|
self.lcn_in = self.lcn_in.to(device)
|
||||||
self.data = {}
|
for key, val in data.items():
|
||||||
|
# from
|
||||||
|
# batch_size x track_length x ...
|
||||||
|
# to
|
||||||
|
# track_length x batch_size x ...
|
||||||
|
if len(val.shape) > 2:
|
||||||
|
if train:
|
||||||
|
val = val.transpose(0, 1)
|
||||||
|
else:
|
||||||
|
val = val.unsqueeze(0)
|
||||||
|
grad = 'im' in key and requires_grad
|
||||||
|
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
|
||||||
|
if 'im' in key and 'blend' not in key:
|
||||||
|
im = self.data[key]
|
||||||
|
tl = im.shape[0]
|
||||||
|
bs = im.shape[1]
|
||||||
|
im_lcn, im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
|
||||||
|
key_std = key.replace('im', 'std')
|
||||||
|
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
|
||||||
|
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
|
||||||
|
self.data[key] = im_cat
|
||||||
|
|
||||||
self.lcn_in = self.lcn_in.to(device)
|
def net_forward(self, net, train):
|
||||||
for key, val in data.items():
|
im0 = self.data['im0']
|
||||||
# from
|
tl = im0.shape[0]
|
||||||
# batch_size x track_length x ...
|
bs = im0.shape[1]
|
||||||
# to
|
im0 = im0.view(-1, *im0.shape[2:])
|
||||||
# track_length x batch_size x ...
|
out, edge = net(im0)
|
||||||
if len(val.shape)>2:
|
if not (isinstance(out, tuple) or isinstance(out, list)):
|
||||||
if train:
|
out = out.view(tl, bs, *out.shape[1:])
|
||||||
val = val.transpose(0,1)
|
edge = edge.view(tl, bs, *out.shape[1:])
|
||||||
else:
|
else:
|
||||||
val = val.unsqueeze(0)
|
out = [o.view(tl, bs, *o.shape[1:]) for o in out]
|
||||||
grad = 'im' in key and requires_grad
|
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
|
||||||
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
|
return out, edge
|
||||||
if 'im' in key and 'blend' not in key:
|
|
||||||
im = self.data[key]
|
|
||||||
tl = im.shape[0]
|
|
||||||
bs = im.shape[1]
|
|
||||||
im_lcn,im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
|
|
||||||
key_std = key.replace('im','std')
|
|
||||||
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
|
|
||||||
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
|
|
||||||
self.data[key] = im_cat
|
|
||||||
|
|
||||||
def net_forward(self, net, train):
|
def loss_forward(self, out, train):
|
||||||
im0 = self.data['im0']
|
out, edge = out
|
||||||
tl = im0.shape[0]
|
if not (isinstance(out, tuple) or isinstance(out, list)):
|
||||||
bs = im0.shape[1]
|
out = [out]
|
||||||
im0 = im0.view(-1, *im0.shape[2:])
|
vals = []
|
||||||
out, edge = net(im0)
|
diffs = []
|
||||||
if not(isinstance(out, tuple) or isinstance(out, list)):
|
|
||||||
out = out.view(tl, bs, *out.shape[1:])
|
|
||||||
edge = edge.view(tl, bs, *out.shape[1:])
|
|
||||||
else:
|
|
||||||
out = [o.view(tl, bs, *o.shape[1:]) for o in out]
|
|
||||||
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
|
|
||||||
return out, edge
|
|
||||||
|
|
||||||
def loss_forward(self, out, train):
|
# apply photometric loss
|
||||||
out, edge = out
|
for s, l, o in zip(itertools.count(), self.ph_losses, out):
|
||||||
if not(isinstance(out, tuple) or isinstance(out, list)):
|
im = self.data[f'im{s}']
|
||||||
out = [out]
|
im = im.view(-1, *im.shape[2:])
|
||||||
vals = []
|
o = o.view(-1, *o.shape[2:])
|
||||||
diffs = []
|
std = self.data[f'std{s}']
|
||||||
|
std = std.view(-1, *std.shape[2:])
|
||||||
|
val, pattern_proj = l(o, im[:, 0:1, ...], std)
|
||||||
|
vals.append(val)
|
||||||
|
if s == 0:
|
||||||
|
self.pattern_proj = pattern_proj.detach()
|
||||||
|
|
||||||
# apply photometric loss
|
# apply disparity loss
|
||||||
for s,l,o in zip(itertools.count(), self.ph_losses, out):
|
# 1-edge as ground truth edge if inversed
|
||||||
im = self.data[f'im{s}']
|
edge0 = 1 - torch.sigmoid(edge[0])
|
||||||
im = im.view(-1, *im.shape[2:])
|
edge0 = edge0.view(-1, *edge0.shape[2:])
|
||||||
o = o.view(-1, *o.shape[2:])
|
out0 = out[0].view(-1, *out[0].shape[2:])
|
||||||
std = self.data[f'std{s}']
|
val = self.disparity_loss(out0, edge0)
|
||||||
std = std.view(-1, *std.shape[2:])
|
if self.dp_weight > 0:
|
||||||
val, pattern_proj = l(o, im[:,0:1,...], std)
|
vals.append(val * self.dp_weight)
|
||||||
vals.append(val)
|
|
||||||
if s == 0:
|
|
||||||
self.pattern_proj = pattern_proj.detach()
|
|
||||||
|
|
||||||
# apply disparity loss
|
# apply edge loss on a subset of training samples
|
||||||
# 1-edge as ground truth edge if inversed
|
for s, e in zip(itertools.count(), edge):
|
||||||
edge0 = 1-torch.sigmoid(edge[0])
|
# inversed ground truth edge where 0 means edge
|
||||||
edge0 = edge0.view(-1, *edge0.shape[2:])
|
grad = self.data[f'grad{s}'] < 0.2
|
||||||
out0 = out[0].view(-1, *out[0].shape[2:])
|
grad = grad.to(torch.float32)
|
||||||
val = self.disparity_loss(out0, edge0)
|
ids = self.data['id']
|
||||||
if self.dp_weight>0:
|
mask = ids > self.train_edge
|
||||||
vals.append(val * self.dp_weight)
|
if mask.sum() > 0:
|
||||||
|
e = e[:, mask, :]
|
||||||
|
grad = grad[:, mask, :]
|
||||||
|
e = e.view(-1, *e.shape[2:])
|
||||||
|
grad = grad.view(-1, *grad.shape[2:])
|
||||||
|
val = self.edge_loss(e, grad)
|
||||||
|
else:
|
||||||
|
val = torch.zeros_like(vals[0])
|
||||||
|
vals.append(val)
|
||||||
|
|
||||||
# apply edge loss on a subset of training samples
|
if train is False:
|
||||||
for s,e in zip(itertools.count(), edge):
|
return vals
|
||||||
# inversed ground truth edge where 0 means edge
|
|
||||||
grad = self.data[f'grad{s}']<0.2
|
|
||||||
grad = grad.to(torch.float32)
|
|
||||||
ids = self.data['id']
|
|
||||||
mask = ids>self.train_edge
|
|
||||||
if mask.sum()>0:
|
|
||||||
e = e[:,mask,:]
|
|
||||||
grad = grad[:,mask,:]
|
|
||||||
e = e.view(-1, *e.shape[2:])
|
|
||||||
grad = grad.view(-1, *grad.shape[2:])
|
|
||||||
val = self.edge_loss(e, grad)
|
|
||||||
else:
|
|
||||||
val = torch.zeros_like(vals[0])
|
|
||||||
vals.append(val)
|
|
||||||
|
|
||||||
if train is False:
|
# apply geometric loss
|
||||||
return vals
|
R = self.data['R']
|
||||||
|
t = self.data['t']
|
||||||
|
ge_num = self.track_length * (self.track_length - 1) / 2
|
||||||
|
for sidx in range(len(out)):
|
||||||
|
d2d = self.d2ds[sidx]
|
||||||
|
depth = d2d(out[sidx])
|
||||||
|
ge_loss = self.ge_losses[sidx]
|
||||||
|
imsize = self.imsizes[sidx]
|
||||||
|
for tidx0 in range(depth.shape[0]):
|
||||||
|
for tidx1 in range(tidx0 + 1, depth.shape[0]):
|
||||||
|
depth0 = depth[tidx0]
|
||||||
|
R0 = R[tidx0]
|
||||||
|
t0 = t[tidx0]
|
||||||
|
depth1 = depth[tidx1]
|
||||||
|
R1 = R[tidx1]
|
||||||
|
t1 = t[tidx1]
|
||||||
|
|
||||||
# apply geometric loss
|
val = ge_loss(depth0, depth1, R0, t0, R1, t1)
|
||||||
R = self.data['R']
|
vals.append(val * self.ge_weight / ge_num)
|
||||||
t = self.data['t']
|
|
||||||
ge_num = self.track_length * (self.track_length-1) / 2
|
|
||||||
for sidx in range(len(out)):
|
|
||||||
d2d = self.d2ds[sidx]
|
|
||||||
depth = d2d(out[sidx])
|
|
||||||
ge_loss = self.ge_losses[sidx]
|
|
||||||
imsize = self.imsizes[sidx]
|
|
||||||
for tidx0 in range(depth.shape[0]):
|
|
||||||
for tidx1 in range(tidx0+1, depth.shape[0]):
|
|
||||||
depth0 = depth[tidx0]
|
|
||||||
R0 = R[tidx0]
|
|
||||||
t0 = t[tidx0]
|
|
||||||
depth1 = depth[tidx1]
|
|
||||||
R1 = R[tidx1]
|
|
||||||
t1 = t[tidx1]
|
|
||||||
|
|
||||||
val = ge_loss(depth0, depth1, R0, t0, R1, t1)
|
return vals
|
||||||
vals.append(val * self.ge_weight / ge_num)
|
|
||||||
|
|
||||||
return vals
|
def numpy_in_out(self, output):
|
||||||
|
output, edge = output
|
||||||
|
if not (isinstance(output, tuple) or isinstance(output, list)):
|
||||||
|
output = [output]
|
||||||
|
es = output[0].detach().to('cpu').numpy()
|
||||||
|
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
|
||||||
|
im = self.data['im0'][:, :, 0:1, ...].detach().to('cpu').numpy()
|
||||||
|
ma = gt > 0
|
||||||
|
return es, gt, im, ma
|
||||||
|
|
||||||
def numpy_in_out(self, output):
|
def write_img(self, out_path, es, gt, im, ma):
|
||||||
output, edge = output
|
logging.info(f'write img {out_path}')
|
||||||
if not(isinstance(output, tuple) or isinstance(output, list)):
|
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
|
||||||
output = [output]
|
|
||||||
es = output[0].detach().to('cpu').numpy()
|
|
||||||
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
|
|
||||||
im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy()
|
|
||||||
ma = gt>0
|
|
||||||
return es, gt, im, ma
|
|
||||||
|
|
||||||
def write_img(self, out_path, es, gt, im, ma):
|
diff = np.abs(es - gt)
|
||||||
logging.info(f'write img {out_path}')
|
|
||||||
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
|
|
||||||
|
|
||||||
diff = np.abs(es - gt)
|
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
|
||||||
|
vmin = vmin - 0.2 * (vmax - vmin)
|
||||||
|
vmax = vmax + 0.2 * (vmax - vmin)
|
||||||
|
|
||||||
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
|
pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
|
||||||
vmin = vmin - 0.2*(vmax-vmin)
|
im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0, 0]
|
||||||
vmax = vmax + 0.2*(vmax-vmin)
|
pattern_diff = np.abs(im_orig - pattern_proj)
|
||||||
|
|
||||||
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0]
|
fig = plt.figure(figsize=(16, 16))
|
||||||
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0,0]
|
es0 = co.cmap.color_depth_map(es[0], scale=vmax)
|
||||||
pattern_diff = np.abs(im_orig - pattern_proj)
|
gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
|
||||||
|
diff0 = co.cmap.color_error_image(diff[0], BGR=True)
|
||||||
|
|
||||||
fig = plt.figure(figsize=(16,16))
|
# plot disparities, ground truth disparity is shown only for reference
|
||||||
es0 = co.cmap.color_depth_map(es[0], scale=vmax)
|
ax = plt.subplot(3, 3, 1);
|
||||||
gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
|
plt.imshow(es0[..., [2, 1, 0]]);
|
||||||
diff0 = co.cmap.color_error_image(diff[0], BGR=True)
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 2);
|
||||||
|
plt.imshow(gt0[..., [2, 1, 0]]);
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 3);
|
||||||
|
plt.imshow(diff0[..., [2, 1, 0]]);
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
|
||||||
|
|
||||||
# plot disparities, ground truth disparity is shown only for reference
|
# plot disparities of the second frame in the track if exists
|
||||||
ax = plt.subplot(3,3,1); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
|
if es.shape[0] >= 2:
|
||||||
ax = plt.subplot(3,3,2); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
|
es1 = co.cmap.color_depth_map(es[1], scale=vmax)
|
||||||
ax = plt.subplot(3,3,3); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
|
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
|
||||||
|
diff1 = co.cmap.color_error_image(diff[1], BGR=True)
|
||||||
|
ax = plt.subplot(3, 3, 4);
|
||||||
|
plt.imshow(es1[..., [2, 1, 0]]);
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 5);
|
||||||
|
plt.imshow(gt1[..., [2, 1, 0]]);
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
|
||||||
|
ax = plt.subplot(3, 3, 6);
|
||||||
|
plt.imshow(diff1[..., [2, 1, 0]]);
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
|
||||||
|
|
||||||
# plot disparities of the second frame in the track if exists
|
# plot normalized IR inputs
|
||||||
if es.shape[0]>=2:
|
ax = plt.subplot(3, 3, 7);
|
||||||
es1 = co.cmap.color_depth_map(es[1], scale=vmax)
|
plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray');
|
||||||
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
|
plt.xticks([]);
|
||||||
diff1 = co.cmap.color_error_image(diff[1], BGR=True)
|
plt.yticks([]);
|
||||||
ax = plt.subplot(3,3,4); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
|
ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
|
||||||
ax = plt.subplot(3,3,5); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
|
if es.shape[0] >= 2:
|
||||||
ax = plt.subplot(3,3,6); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
|
ax = plt.subplot(3, 3, 8);
|
||||||
|
plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray');
|
||||||
|
plt.xticks([]);
|
||||||
|
plt.yticks([]);
|
||||||
|
ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
|
||||||
|
|
||||||
# plot normalized IR inputs
|
plt.tight_layout()
|
||||||
ax = plt.subplot(3,3,7); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
|
plt.savefig(str(out_path))
|
||||||
if es.shape[0]>=2:
|
plt.close(fig)
|
||||||
ax = plt.subplot(3,3,8); plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
|
|
||||||
|
|
||||||
plt.tight_layout()
|
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
|
||||||
plt.savefig(str(out_path))
|
if batch_idx % 512 == 0:
|
||||||
plt.close(fig)
|
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
|
||||||
|
es, gt, im, ma = self.numpy_in_out(output)
|
||||||
|
masks = [m.detach().to('cpu').numpy() for m in masks]
|
||||||
|
self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
|
||||||
|
|
||||||
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
|
def callback_test_start(self, epoch, set_idx):
|
||||||
if batch_idx % 512 == 0:
|
self.metric = co.metric.MultipleMetric(
|
||||||
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
|
co.metric.DistanceMetric(vec_length=1),
|
||||||
es, gt, im, ma = self.numpy_in_out(output)
|
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
|
||||||
masks = [ m.detach().to('cpu').numpy() for m in masks ]
|
)
|
||||||
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0])
|
|
||||||
|
|
||||||
def callback_test_start(self, epoch, set_idx):
|
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
|
||||||
self.metric = co.metric.MultipleMetric(
|
es, gt, im, ma = self.numpy_in_out(output)
|
||||||
co.metric.DistanceMetric(vec_length=1),
|
|
||||||
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
|
|
||||||
)
|
|
||||||
|
|
||||||
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
|
if batch_idx % 8 == 0:
|
||||||
es, gt, im, ma = self.numpy_in_out(output)
|
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
|
||||||
|
self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
|
||||||
|
|
||||||
if batch_idx % 8 == 0:
|
es, gt, im, ma = self.crop_output(es, gt, im, ma)
|
||||||
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
|
|
||||||
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0])
|
|
||||||
|
|
||||||
es, gt, im, ma = self.crop_output(es, gt, im, ma)
|
es = es.reshape(-1, 1)
|
||||||
|
gt = gt.reshape(-1, 1)
|
||||||
|
ma = ma.ravel()
|
||||||
|
self.metric.add(es, gt, ma)
|
||||||
|
|
||||||
es = es.reshape(-1,1)
|
def callback_test_stop(self, epoch, set_idx, loss):
|
||||||
gt = gt.reshape(-1,1)
|
logging.info(f'{self.metric}')
|
||||||
ma = ma.ravel()
|
for k, v in self.metric.items():
|
||||||
self.metric.add(es, gt, ma)
|
self.metric_add_test(epoch, set_idx, k, v)
|
||||||
|
|
||||||
def callback_test_stop(self, epoch, set_idx, loss):
|
def crop_output(self, es, gt, im, ma):
|
||||||
logging.info(f'{self.metric}')
|
tl = es.shape[0]
|
||||||
for k, v in self.metric.items():
|
bs = es.shape[1]
|
||||||
self.metric_add_test(epoch, set_idx, k, v)
|
es = np.reshape(es[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
|
||||||
|
gt = np.reshape(gt[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
|
||||||
|
im = np.reshape(im[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
|
||||||
|
ma = np.reshape(ma[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
|
||||||
|
return es, gt, im, ma
|
||||||
|
|
||||||
def crop_output(self, es, gt, im, ma):
|
|
||||||
tl = es.shape[0]
|
|
||||||
bs = es.shape[1]
|
|
||||||
es = np.reshape(es[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
|
|
||||||
gt = np.reshape(gt[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
|
|
||||||
im = np.reshape(im[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
|
|
||||||
ma = np.reshape(ma[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
|
|
||||||
return es, gt, im, ma
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
pass
|
||||||
|
@ -8,559 +8,572 @@ import co
|
|||||||
|
|
||||||
|
|
||||||
class TimedModule(torch.nn.Module):
|
class TimedModule(torch.nn.Module):
|
||||||
def __init__(self, mod_name):
|
def __init__(self, mod_name):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mod_name = mod_name
|
self.mod_name = mod_name
|
||||||
|
|
||||||
def tforward(self, *args, **kwargs):
|
def tforward(self, *args, **kwargs):
|
||||||
raise Exception('not implemented')
|
raise Exception('not implemented')
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
with co.gtimer.Ctx(self.mod_name):
|
with co.gtimer.Ctx(self.mod_name):
|
||||||
x = self.tforward(*args, **kwargs)
|
x = self.tforward(*args, **kwargs)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class PosOutput(TimedModule):
|
class PosOutput(TimedModule):
|
||||||
def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0):
|
def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0):
|
||||||
super().__init__(mod_name='PosOutput')
|
super().__init__(mod_name='PosOutput')
|
||||||
self.im_width = im_width
|
self.im_width = im_width
|
||||||
self.im_width = im_width
|
self.im_width = im_width
|
||||||
|
|
||||||
if type == 'pos':
|
if type == 'pos':
|
||||||
self.layer = torch.nn.Sequential(
|
self.layer = torch.nn.Sequential(
|
||||||
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
|
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
|
||||||
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
|
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
|
||||||
)
|
)
|
||||||
elif type == 'pos_row':
|
elif type == 'pos_row':
|
||||||
self.layer = torch.nn.Sequential(
|
self.layer = torch.nn.Sequential(
|
||||||
MultiLinear(im_height, channels_in, 1),
|
MultiLinear(im_height, channels_in, 1),
|
||||||
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
|
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.u_pos = None
|
self.u_pos = None
|
||||||
|
|
||||||
def tforward(self, x):
|
def tforward(self, x):
|
||||||
if self.u_pos is None:
|
if self.u_pos is None:
|
||||||
self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1,1,1,-1)
|
self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1, 1, 1, -1)
|
||||||
self.u_pos = self.u_pos.to(x.device)
|
self.u_pos = self.u_pos.to(x.device)
|
||||||
pos = self.layer(x)
|
pos = self.layer(x)
|
||||||
disp = self.u_pos - pos
|
disp = self.u_pos - pos
|
||||||
return disp
|
return disp
|
||||||
|
|
||||||
|
|
||||||
class OutputLayerFactory(object):
|
class OutputLayerFactory(object):
|
||||||
'''
|
'''
|
||||||
Define type of output
|
Define type of output
|
||||||
type options:
|
type options:
|
||||||
linear: apply only conv channel, used for the edge decoder
|
linear: apply only conv channel, used for the edge decoder
|
||||||
disp: estimate the disparity
|
disp: estimate the disparity
|
||||||
disp_row: independently estimate the disparity per row
|
disp_row: independently estimate the disparity per row
|
||||||
pos: estimate the absolute location
|
pos: estimate the absolute location
|
||||||
pos_row: independently estimate the absolute location per row
|
pos_row: independently estimate the absolute location per row
|
||||||
'''
|
'''
|
||||||
def __init__(self, type='disp', params={}):
|
|
||||||
self.type = type
|
|
||||||
self.params = params
|
|
||||||
|
|
||||||
def __call__(self, channels_in, imsize):
|
def __init__(self, type='disp', params={}):
|
||||||
|
self.type = type
|
||||||
|
self.params = params
|
||||||
|
|
||||||
if self.type == 'linear':
|
def __call__(self, channels_in, imsize):
|
||||||
return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1)
|
|
||||||
|
|
||||||
elif self.type == 'disp':
|
if self.type == 'linear':
|
||||||
return torch.nn.Sequential(
|
return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1)
|
||||||
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
|
|
||||||
SigmoidAffine(**self.params)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif self.type == 'disp_row':
|
elif self.type == 'disp':
|
||||||
return torch.nn.Sequential(
|
return torch.nn.Sequential(
|
||||||
MultiLinear(imsize[0], channels_in, 1),
|
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
|
||||||
SigmoidAffine(**self.params)
|
SigmoidAffine(**self.params)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self.type == 'pos' or self.type == 'pos_row':
|
elif self.type == 'disp_row':
|
||||||
return PosOutput(channels_in, **self.params)
|
return torch.nn.Sequential(
|
||||||
|
MultiLinear(imsize[0], channels_in, 1),
|
||||||
|
SigmoidAffine(**self.params)
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
elif self.type == 'pos' or self.type == 'pos_row':
|
||||||
raise Exception('unknown output layer type')
|
return PosOutput(channels_in, **self.params)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception('unknown output layer type')
|
||||||
|
|
||||||
|
|
||||||
class SigmoidAffine(TimedModule):
|
class SigmoidAffine(TimedModule):
|
||||||
def __init__(self, alpha=1, beta=0, gamma=1, offset=0):
|
def __init__(self, alpha=1, beta=0, gamma=1, offset=0):
|
||||||
super().__init__(mod_name='SigmoidAffine')
|
super().__init__(mod_name='SigmoidAffine')
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.offset = offset
|
self.offset = offset
|
||||||
|
|
||||||
def tforward(self, x):
|
def tforward(self, x):
|
||||||
return torch.sigmoid(x/self.gamma - self.offset) * self.alpha + self.beta
|
return torch.sigmoid(x / self.gamma - self.offset) * self.alpha + self.beta
|
||||||
|
|
||||||
|
|
||||||
class MultiLinear(TimedModule):
|
class MultiLinear(TimedModule):
|
||||||
def __init__(self, n, channels_in, channels_out):
|
def __init__(self, n, channels_in, channels_out):
|
||||||
super().__init__(mod_name='MultiLinear')
|
super().__init__(mod_name='MultiLinear')
|
||||||
self.channels_out = channels_out
|
self.channels_out = channels_out
|
||||||
self.mods = torch.nn.ModuleList()
|
self.mods = torch.nn.ModuleList()
|
||||||
for idx in range(n):
|
for idx in range(n):
|
||||||
self.mods.append(torch.nn.Linear(channels_in, channels_out))
|
self.mods.append(torch.nn.Linear(channels_in, channels_out))
|
||||||
|
|
||||||
def tforward(self, x):
|
|
||||||
x = x.permute(2,0,3,1) # BxCxHxW => HxBxWxC
|
|
||||||
y = x.new_empty(*x.shape[:-1], self.channels_out)
|
|
||||||
for hidx in range(x.shape[0]):
|
|
||||||
y[hidx] = self.mods[hidx](x[hidx])
|
|
||||||
y = y.permute(1,3,0,2) # HxBxWxC => BxCxHxW
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
def tforward(self, x):
|
||||||
|
x = x.permute(2, 0, 3, 1) # BxCxHxW => HxBxWxC
|
||||||
|
y = x.new_empty(*x.shape[:-1], self.channels_out)
|
||||||
|
for hidx in range(x.shape[0]):
|
||||||
|
y[hidx] = self.mods[hidx](x[hidx])
|
||||||
|
y = y.permute(1, 3, 0, 2) # HxBxWxC => BxCxHxW
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
class DispNetS(TimedModule):
|
class DispNetS(TimedModule):
|
||||||
'''
|
'''
|
||||||
Disparity Decoder based on DispNetS
|
Disparity Decoder based on DispNetS
|
||||||
'''
|
'''
|
||||||
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, channel_multiplier=1):
|
|
||||||
super(DispNetS, self).__init__(mod_name='DispNetS')
|
|
||||||
|
|
||||||
self.output_ms = output_ms
|
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False,
|
||||||
self.coordconv = coordconv
|
channel_multiplier=1):
|
||||||
|
super(DispNetS, self).__init__(mod_name='DispNetS')
|
||||||
|
|
||||||
conv_planes = channel_multiplier * np.array( [32, 64, 128, 256, 512, 512, 512] )
|
self.output_ms = output_ms
|
||||||
self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
|
self.coordconv = coordconv
|
||||||
self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
|
|
||||||
self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
|
|
||||||
self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3])
|
|
||||||
self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4])
|
|
||||||
self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
|
|
||||||
self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
|
|
||||||
|
|
||||||
upconv_planes = channel_multiplier * np.array( [512, 512, 256, 128, 64, 32, 16] )
|
conv_planes = channel_multiplier * np.array([32, 64, 128, 256, 512, 512, 512])
|
||||||
self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0])
|
self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
|
||||||
self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1])
|
self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
|
||||||
self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2])
|
self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
|
||||||
self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3])
|
self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3])
|
||||||
self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4])
|
self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4])
|
||||||
self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5])
|
self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
|
||||||
self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6])
|
self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
|
||||||
|
|
||||||
self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
|
upconv_planes = channel_multiplier * np.array([512, 512, 256, 128, 64, 32, 16])
|
||||||
self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
|
self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0])
|
||||||
self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
|
self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1])
|
||||||
self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
|
self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2])
|
||||||
self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
|
self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3])
|
||||||
self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
|
self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4])
|
||||||
self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6])
|
self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5])
|
||||||
|
self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6])
|
||||||
|
|
||||||
if isinstance(output_facs, list):
|
self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
|
||||||
self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3])
|
self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
|
||||||
self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2])
|
self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
|
||||||
self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
|
self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
|
||||||
self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
|
self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
|
||||||
else:
|
self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
|
||||||
self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3])
|
self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6])
|
||||||
self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2])
|
|
||||||
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
|
|
||||||
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
|
|
||||||
|
|
||||||
|
if isinstance(output_facs, list):
|
||||||
|
self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3])
|
||||||
|
self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2])
|
||||||
|
self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
|
||||||
|
self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
|
||||||
|
else:
|
||||||
|
self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3])
|
||||||
|
self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2])
|
||||||
|
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
|
||||||
|
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
|
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
|
||||||
torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
|
torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
torch.nn.init.zeros_(m.bias)
|
torch.nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
def downsample_conv(self, in_planes, out_planes, kernel_size=3):
|
def downsample_conv(self, in_planes, out_planes, kernel_size=3):
|
||||||
if self.coordconv:
|
if self.coordconv:
|
||||||
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
|
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
|
||||||
else:
|
padding=(kernel_size - 1) // 2)
|
||||||
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
|
else:
|
||||||
return torch.nn.Sequential(
|
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
|
||||||
conv,
|
padding=(kernel_size - 1) // 2)
|
||||||
torch.nn.ReLU(inplace=True),
|
return torch.nn.Sequential(
|
||||||
torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2),
|
conv,
|
||||||
torch.nn.ReLU(inplace=True)
|
torch.nn.ReLU(inplace=True),
|
||||||
)
|
torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size - 1) // 2),
|
||||||
|
torch.nn.ReLU(inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
def conv(self, in_planes, out_planes):
|
def conv(self, in_planes, out_planes):
|
||||||
return torch.nn.Sequential(
|
return torch.nn.Sequential(
|
||||||
torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
|
torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
|
||||||
torch.nn.ReLU(inplace=True)
|
torch.nn.ReLU(inplace=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
def upconv(self, in_planes, out_planes):
|
def upconv(self, in_planes, out_planes):
|
||||||
return torch.nn.Sequential(
|
return torch.nn.Sequential(
|
||||||
torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1),
|
torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1),
|
||||||
torch.nn.ReLU(inplace=True)
|
torch.nn.ReLU(inplace=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
def crop_like(self, input, ref):
|
def crop_like(self, input, ref):
|
||||||
assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
|
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
|
||||||
return input[:, :, :ref.size(2), :ref.size(3)]
|
return input[:, :, :ref.size(2), :ref.size(3)]
|
||||||
|
|
||||||
def tforward(self, x):
|
def tforward(self, x):
|
||||||
out_conv1 = self.conv1(x)
|
out_conv1 = self.conv1(x)
|
||||||
out_conv2 = self.conv2(out_conv1)
|
out_conv2 = self.conv2(out_conv1)
|
||||||
out_conv3 = self.conv3(out_conv2)
|
out_conv3 = self.conv3(out_conv2)
|
||||||
out_conv4 = self.conv4(out_conv3)
|
out_conv4 = self.conv4(out_conv3)
|
||||||
out_conv5 = self.conv5(out_conv4)
|
out_conv5 = self.conv5(out_conv4)
|
||||||
out_conv6 = self.conv6(out_conv5)
|
out_conv6 = self.conv6(out_conv5)
|
||||||
out_conv7 = self.conv7(out_conv6)
|
out_conv7 = self.conv7(out_conv6)
|
||||||
|
|
||||||
out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6)
|
out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6)
|
||||||
concat7 = torch.cat((out_upconv7, out_conv6), 1)
|
concat7 = torch.cat((out_upconv7, out_conv6), 1)
|
||||||
out_iconv7 = self.iconv7(concat7)
|
out_iconv7 = self.iconv7(concat7)
|
||||||
|
|
||||||
out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5)
|
out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5)
|
||||||
concat6 = torch.cat((out_upconv6, out_conv5), 1)
|
concat6 = torch.cat((out_upconv6, out_conv5), 1)
|
||||||
out_iconv6 = self.iconv6(concat6)
|
out_iconv6 = self.iconv6(concat6)
|
||||||
|
|
||||||
out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4)
|
out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4)
|
||||||
concat5 = torch.cat((out_upconv5, out_conv4), 1)
|
concat5 = torch.cat((out_upconv5, out_conv4), 1)
|
||||||
out_iconv5 = self.iconv5(concat5)
|
out_iconv5 = self.iconv5(concat5)
|
||||||
|
|
||||||
out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3)
|
out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3)
|
||||||
concat4 = torch.cat((out_upconv4, out_conv3), 1)
|
concat4 = torch.cat((out_upconv4, out_conv3), 1)
|
||||||
out_iconv4 = self.iconv4(concat4)
|
out_iconv4 = self.iconv4(concat4)
|
||||||
disp4 = self.predict_disp4(out_iconv4)
|
disp4 = self.predict_disp4(out_iconv4)
|
||||||
|
|
||||||
out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
|
out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
|
||||||
disp4_up = self.crop_like(torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
|
disp4_up = self.crop_like(
|
||||||
concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
|
torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
|
||||||
out_iconv3 = self.iconv3(concat3)
|
concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
|
||||||
disp3 = self.predict_disp3(out_iconv3)
|
out_iconv3 = self.iconv3(concat3)
|
||||||
|
disp3 = self.predict_disp3(out_iconv3)
|
||||||
|
|
||||||
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
||||||
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
disp3_up = self.crop_like(
|
||||||
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
|
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
||||||
out_iconv2 = self.iconv2(concat2)
|
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
|
||||||
disp2 = self.predict_disp2(out_iconv2)
|
out_iconv2 = self.iconv2(concat2)
|
||||||
|
disp2 = self.predict_disp2(out_iconv2)
|
||||||
|
|
||||||
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
|
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
|
||||||
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
|
disp2_up = self.crop_like(
|
||||||
concat1 = torch.cat((out_upconv1, disp2_up), 1)
|
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
|
||||||
out_iconv1 = self.iconv1(concat1)
|
concat1 = torch.cat((out_upconv1, disp2_up), 1)
|
||||||
disp1 = self.predict_disp1(out_iconv1)
|
out_iconv1 = self.iconv1(concat1)
|
||||||
|
disp1 = self.predict_disp1(out_iconv1)
|
||||||
|
|
||||||
if self.output_ms:
|
if self.output_ms:
|
||||||
return disp1, disp2, disp3, disp4
|
return disp1, disp2, disp3, disp4
|
||||||
else:
|
else:
|
||||||
return disp1
|
return disp1
|
||||||
|
|
||||||
|
|
||||||
class DispNetShallow(DispNetS):
|
class DispNetShallow(DispNetS):
|
||||||
'''
|
'''
|
||||||
Edge Decoder based on DispNetS with fewer layers
|
Edge Decoder based on DispNetS with fewer layers
|
||||||
'''
|
'''
|
||||||
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
|
|
||||||
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
|
|
||||||
self.mod_name = 'DispNetShallow'
|
|
||||||
conv_planes = [32, 64, 128, 256, 512, 512, 512]
|
|
||||||
upconv_planes = [512, 512, 256, 128, 64, 32, 16]
|
|
||||||
self.iconv3 = self.conv(upconv_planes[4] + conv_planes[1], upconv_planes[4])
|
|
||||||
|
|
||||||
def tforward(self, x):
|
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
|
||||||
out_conv1 = self.conv1(x)
|
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
|
||||||
out_conv2 = self.conv2(out_conv1)
|
self.mod_name = 'DispNetShallow'
|
||||||
out_conv3 = self.conv3(out_conv2)
|
conv_planes = [32, 64, 128, 256, 512, 512, 512]
|
||||||
|
upconv_planes = [512, 512, 256, 128, 64, 32, 16]
|
||||||
|
self.iconv3 = self.conv(upconv_planes[4] + conv_planes[1], upconv_planes[4])
|
||||||
|
|
||||||
out_upconv3 = self.crop_like(self.upconv3(out_conv3), out_conv2)
|
def tforward(self, x):
|
||||||
concat3 = torch.cat((out_upconv3, out_conv2), 1)
|
out_conv1 = self.conv1(x)
|
||||||
out_iconv3 = self.iconv3(concat3)
|
out_conv2 = self.conv2(out_conv1)
|
||||||
disp3 = self.predict_disp3(out_iconv3)
|
out_conv3 = self.conv3(out_conv2)
|
||||||
|
|
||||||
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
out_upconv3 = self.crop_like(self.upconv3(out_conv3), out_conv2)
|
||||||
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
concat3 = torch.cat((out_upconv3, out_conv2), 1)
|
||||||
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
|
out_iconv3 = self.iconv3(concat3)
|
||||||
out_iconv2 = self.iconv2(concat2)
|
disp3 = self.predict_disp3(out_iconv3)
|
||||||
disp2 = self.predict_disp2(out_iconv2)
|
|
||||||
|
|
||||||
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
|
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
||||||
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
|
disp3_up = self.crop_like(
|
||||||
concat1 = torch.cat((out_upconv1, disp2_up), 1)
|
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
||||||
out_iconv1 = self.iconv1(concat1)
|
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
|
||||||
disp1 = self.predict_disp1(out_iconv1)
|
out_iconv2 = self.iconv2(concat2)
|
||||||
|
disp2 = self.predict_disp2(out_iconv2)
|
||||||
|
|
||||||
if self.output_ms:
|
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
|
||||||
return disp1, disp2, disp3
|
disp2_up = self.crop_like(
|
||||||
else:
|
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
|
||||||
return disp1
|
concat1 = torch.cat((out_upconv1, disp2_up), 1)
|
||||||
|
out_iconv1 = self.iconv1(concat1)
|
||||||
|
disp1 = self.predict_disp1(out_iconv1)
|
||||||
|
|
||||||
|
if self.output_ms:
|
||||||
|
return disp1, disp2, disp3
|
||||||
|
else:
|
||||||
|
return disp1
|
||||||
|
|
||||||
|
|
||||||
class DispEdgeDecoders(TimedModule):
|
class DispEdgeDecoders(TimedModule):
|
||||||
'''
|
'''
|
||||||
Disparity Decoder and Edge Decoder
|
Disparity Decoder and Edge Decoder
|
||||||
'''
|
'''
|
||||||
def __init__(self, *args, max_disp=128, **kwargs):
|
|
||||||
super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')
|
|
||||||
|
|
||||||
output_facs = [OutputLayerFactory( type='disp', params={ 'alpha': max_disp/(2**s), 'beta': 0, 'gamma': 1, 'offset': 3}) for s in range(4)]
|
def __init__(self, *args, max_disp=128, **kwargs):
|
||||||
self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)
|
super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')
|
||||||
|
|
||||||
output_facs = [OutputLayerFactory( type='linear' ) for s in range(4)]
|
output_facs = [
|
||||||
self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs)
|
OutputLayerFactory(type='disp', params={'alpha': max_disp / (2 ** s), 'beta': 0, 'gamma': 1, 'offset': 3})
|
||||||
|
for s in range(4)]
|
||||||
|
self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)
|
||||||
|
|
||||||
def tforward(self, x):
|
output_facs = [OutputLayerFactory(type='linear') for s in range(4)]
|
||||||
disp = self.disp_decoder(x)
|
self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs)
|
||||||
edge = self.edge_decoder(x)
|
|
||||||
return disp, edge
|
def tforward(self, x):
|
||||||
|
disp = self.disp_decoder(x)
|
||||||
|
edge = self.edge_decoder(x)
|
||||||
|
return disp, edge
|
||||||
|
|
||||||
|
|
||||||
class DispToDepth(TimedModule):
|
class DispToDepth(TimedModule):
|
||||||
def __init__(self, focal_length, baseline):
|
def __init__(self, focal_length, baseline):
|
||||||
super().__init__(mod_name='DispToDepth')
|
super().__init__(mod_name='DispToDepth')
|
||||||
self.baseline_focal_length = baseline * focal_length
|
self.baseline_focal_length = baseline * focal_length
|
||||||
|
|
||||||
def tforward(self, disp):
|
def tforward(self, disp):
|
||||||
disp = torch.nn.functional.relu(disp) + 1e-12
|
disp = torch.nn.functional.relu(disp) + 1e-12
|
||||||
depth = self.baseline_focal_length / disp
|
depth = self.baseline_focal_length / disp
|
||||||
return depth
|
return depth
|
||||||
|
|
||||||
|
|
||||||
class PosToDepth(DispToDepth):
|
class PosToDepth(DispToDepth):
|
||||||
def __init__(self, focal_length, baseline, im_height, im_width):
|
def __init__(self, focal_length, baseline, im_height, im_width):
|
||||||
super().__init__(focal_length, baseline)
|
super().__init__(focal_length, baseline)
|
||||||
self.mod_name = 'PosToDepth'
|
self.mod_name = 'PosToDepth'
|
||||||
|
|
||||||
self.im_height = im_height
|
self.im_height = im_height
|
||||||
self.im_width = im_width
|
self.im_width = im_width
|
||||||
self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1,1,1,-1)
|
self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1, 1, 1, -1)
|
||||||
|
|
||||||
def tforward(self, pos):
|
|
||||||
self.u_pos = self.u_pos.to(pos.device)
|
|
||||||
disp = self.u_pos - pos
|
|
||||||
return super().forward(disp)
|
|
||||||
|
|
||||||
|
def tforward(self, pos):
|
||||||
|
self.u_pos = self.u_pos.to(pos.device)
|
||||||
|
disp = self.u_pos - pos
|
||||||
|
return super().forward(disp)
|
||||||
|
|
||||||
|
|
||||||
class RectifiedPatternSimilarityLoss(TimedModule):
|
class RectifiedPatternSimilarityLoss(TimedModule):
|
||||||
'''
|
'''
|
||||||
Photometric Loss
|
Photometric Loss
|
||||||
'''
|
'''
|
||||||
def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
|
|
||||||
super().__init__(mod_name='RectifiedPatternSimilarityLoss')
|
|
||||||
self.im_height = im_height
|
|
||||||
self.im_width = im_width
|
|
||||||
self.pattern = pattern.mean(dim=1, keepdim=True).contiguous()
|
|
||||||
|
|
||||||
u, v = np.meshgrid(range(im_width), range(im_height))
|
def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
|
||||||
uv0 = np.stack((u,v), axis=2).reshape(-1,1)
|
super().__init__(mod_name='RectifiedPatternSimilarityLoss')
|
||||||
uv0 = uv0.astype(np.float32).reshape(1,-1,2)
|
self.im_height = im_height
|
||||||
self.uv0 = torch.from_numpy(uv0)
|
self.im_width = im_width
|
||||||
|
self.pattern = pattern.mean(dim=1, keepdim=True).contiguous()
|
||||||
|
|
||||||
self.loss_type = loss_type
|
u, v = np.meshgrid(range(im_width), range(im_height))
|
||||||
self.loss_eps = loss_eps
|
uv0 = np.stack((u, v), axis=2).reshape(-1, 1)
|
||||||
|
uv0 = uv0.astype(np.float32).reshape(1, -1, 2)
|
||||||
|
self.uv0 = torch.from_numpy(uv0)
|
||||||
|
|
||||||
def tforward(self, disp0, im, std=None):
|
self.loss_type = loss_type
|
||||||
self.pattern = self.pattern.to(disp0.device)
|
self.loss_eps = loss_eps
|
||||||
self.uv0 = self.uv0.to(disp0.device)
|
|
||||||
|
|
||||||
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
|
def tforward(self, disp0, im, std=None):
|
||||||
uv1 = torch.empty_like(uv0)
|
self.pattern = self.pattern.to(disp0.device)
|
||||||
uv1[...,0] = uv0[...,0] - disp0.contiguous().view(disp0.shape[0],-1)
|
self.uv0 = self.uv0.to(disp0.device)
|
||||||
uv1[...,1] = uv0[...,1]
|
|
||||||
|
|
||||||
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
|
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
|
||||||
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
|
uv1 = torch.empty_like(uv0)
|
||||||
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
|
uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1)
|
||||||
pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
|
uv1[..., 1] = uv0[..., 1]
|
||||||
pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border')
|
|
||||||
mask = torch.ones_like(im)
|
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
|
||||||
if std is not None:
|
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
|
||||||
mask = mask*std
|
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
|
||||||
|
pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
|
||||||
|
pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border')
|
||||||
|
mask = torch.ones_like(im)
|
||||||
|
if std is not None:
|
||||||
|
mask = mask * std
|
||||||
|
|
||||||
|
diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
|
||||||
|
val = (mask * diff).sum() / mask.sum()
|
||||||
|
return val, pattern_proj
|
||||||
|
|
||||||
diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
|
|
||||||
val = (mask*diff).sum() / mask.sum()
|
|
||||||
return val, pattern_proj
|
|
||||||
|
|
||||||
class DisparityLoss(TimedModule):
|
class DisparityLoss(TimedModule):
|
||||||
'''
|
'''
|
||||||
Disparity Loss
|
Disparity Loss
|
||||||
'''
|
'''
|
||||||
def __init__(self):
|
|
||||||
super().__init__(mod_name='DisparityLoss')
|
|
||||||
self.sobel = SobelFilter(norm=False)
|
|
||||||
|
|
||||||
#if not edge_gt:
|
def __init__(self):
|
||||||
self.b0=0.0503428816795
|
super().__init__(mod_name='DisparityLoss')
|
||||||
self.b1=1.07274045944
|
self.sobel = SobelFilter(norm=False)
|
||||||
#else:
|
|
||||||
# self.b0=0.0587115108967
|
|
||||||
# self.b1=1.51931190491
|
|
||||||
|
|
||||||
def tforward(self, disp, edge=None):
|
# if not edge_gt:
|
||||||
self.sobel=self.sobel.to(disp.device)
|
self.b0 = 0.0503428816795
|
||||||
|
self.b1 = 1.07274045944
|
||||||
|
# else:
|
||||||
|
# self.b0=0.0587115108967
|
||||||
|
# self.b1=1.51931190491
|
||||||
|
|
||||||
if edge is not None:
|
def tforward(self, disp, edge=None):
|
||||||
grad = self.sobel(disp)
|
self.sobel = self.sobel.to(disp.device)
|
||||||
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8)
|
|
||||||
pdf = (1-edge)/self.b0 * torch.exp(-torch.abs(grad)/self.b0) + \
|
|
||||||
edge/self.b1 * torch.exp(-torch.abs(grad)/self.b1)
|
|
||||||
val = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
|
|
||||||
else:
|
|
||||||
# on qifeng's data we don't have ambient info
|
|
||||||
# therefore we supress edge everywhere
|
|
||||||
grad = self.sobel(disp)
|
|
||||||
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8)
|
|
||||||
grad= torch.clamp(grad, 0, 1.0)
|
|
||||||
val = torch.mean(grad)
|
|
||||||
|
|
||||||
return val
|
if edge is not None:
|
||||||
|
grad = self.sobel(disp)
|
||||||
|
grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
|
||||||
|
pdf = (1 - edge) / self.b0 * torch.exp(-torch.abs(grad) / self.b0) + \
|
||||||
|
edge / self.b1 * torch.exp(-torch.abs(grad) / self.b1)
|
||||||
|
val = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
|
||||||
|
else:
|
||||||
|
# on qifeng's data we don't have ambient info
|
||||||
|
# therefore we supress edge everywhere
|
||||||
|
grad = self.sobel(disp)
|
||||||
|
grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
|
||||||
|
grad = torch.clamp(grad, 0, 1.0)
|
||||||
|
val = torch.mean(grad)
|
||||||
|
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
class ProjectionBaseLoss(TimedModule):
|
class ProjectionBaseLoss(TimedModule):
|
||||||
'''
|
'''
|
||||||
Base module of the Geometric Loss
|
Base module of the Geometric Loss
|
||||||
'''
|
'''
|
||||||
def __init__(self, K, Ki, im_height, im_width):
|
|
||||||
super().__init__(mod_name='ProjectionBaseLoss')
|
|
||||||
|
|
||||||
self.K = K.view(-1,3,3)
|
def __init__(self, K, Ki, im_height, im_width):
|
||||||
|
super().__init__(mod_name='ProjectionBaseLoss')
|
||||||
|
|
||||||
self.im_height = im_height
|
self.K = K.view(-1, 3, 3)
|
||||||
self.im_width = im_width
|
|
||||||
|
|
||||||
u, v = np.meshgrid(range(im_width), range(im_height))
|
self.im_height = im_height
|
||||||
uv = np.stack((u,v,np.ones_like(u)), axis=2).reshape(-1,3)
|
self.im_width = im_width
|
||||||
|
|
||||||
ray = uv @ Ki.numpy().T
|
u, v = np.meshgrid(range(im_width), range(im_height))
|
||||||
|
uv = np.stack((u, v, np.ones_like(u)), axis=2).reshape(-1, 3)
|
||||||
|
|
||||||
ray = ray.reshape(1,-1,3).astype(np.float32)
|
ray = uv @ Ki.numpy().T
|
||||||
self.ray = torch.from_numpy(ray)
|
|
||||||
|
|
||||||
def transform(self, xyz, R=None, t=None):
|
ray = ray.reshape(1, -1, 3).astype(np.float32)
|
||||||
if t is not None:
|
self.ray = torch.from_numpy(ray)
|
||||||
bs = xyz.shape[0]
|
|
||||||
xyz = xyz - t.reshape(bs,1,3)
|
|
||||||
if R is not None:
|
|
||||||
xyz = torch.bmm(xyz, R)
|
|
||||||
return xyz
|
|
||||||
|
|
||||||
def unproject(self, depth, R=None, t=None):
|
def transform(self, xyz, R=None, t=None):
|
||||||
self.ray = self.ray.to(depth.device)
|
if t is not None:
|
||||||
bs = depth.shape[0]
|
bs = xyz.shape[0]
|
||||||
|
xyz = xyz - t.reshape(bs, 1, 3)
|
||||||
|
if R is not None:
|
||||||
|
xyz = torch.bmm(xyz, R)
|
||||||
|
return xyz
|
||||||
|
|
||||||
xyz = depth.reshape(bs,-1,1) * self.ray
|
def unproject(self, depth, R=None, t=None):
|
||||||
xyz = self.transform(xyz, R, t)
|
self.ray = self.ray.to(depth.device)
|
||||||
return xyz
|
bs = depth.shape[0]
|
||||||
|
|
||||||
def project(self, xyz, R, t):
|
xyz = depth.reshape(bs, -1, 1) * self.ray
|
||||||
self.K = self.K.to(xyz.device)
|
xyz = self.transform(xyz, R, t)
|
||||||
bs = xyz.shape[0]
|
return xyz
|
||||||
|
|
||||||
xyz = torch.bmm(xyz, R.transpose(1,2))
|
def project(self, xyz, R, t):
|
||||||
xyz = xyz + t.reshape(bs,1,3)
|
self.K = self.K.to(xyz.device)
|
||||||
|
bs = xyz.shape[0]
|
||||||
|
|
||||||
Kt = self.K.transpose(1,2).expand(bs,-1,-1)
|
xyz = torch.bmm(xyz, R.transpose(1, 2))
|
||||||
uv = torch.bmm(xyz, Kt)
|
xyz = xyz + t.reshape(bs, 1, 3)
|
||||||
|
|
||||||
d = uv[:,:,2:3]
|
Kt = self.K.transpose(1, 2).expand(bs, -1, -1)
|
||||||
|
uv = torch.bmm(xyz, Kt)
|
||||||
|
|
||||||
# avoid division by zero
|
d = uv[:, :, 2:3]
|
||||||
uv = uv[:,:,:2] / (torch.nn.functional.relu(d) + 1e-12)
|
|
||||||
return uv, d
|
|
||||||
|
|
||||||
|
# avoid division by zero
|
||||||
|
uv = uv[:, :, :2] / (torch.nn.functional.relu(d) + 1e-12)
|
||||||
|
return uv, d
|
||||||
|
|
||||||
def tforward(self, depth0, R0, t0, R1, t1):
|
def tforward(self, depth0, R0, t0, R1, t1):
|
||||||
xyz = self.unproject(depth0, R0, t0)
|
xyz = self.unproject(depth0, R0, t0)
|
||||||
return self.project(xyz, R1, t1)
|
return self.project(xyz, R1, t1)
|
||||||
|
|
||||||
|
|
||||||
class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
|
class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
|
||||||
'''
|
'''
|
||||||
Geometric Loss
|
Geometric Loss
|
||||||
'''
|
'''
|
||||||
def __init__(self, *args, clamp=-1):
|
|
||||||
super().__init__(*args)
|
|
||||||
self.mod_name = 'ProjectionDepthSimilarityLoss'
|
|
||||||
self.clamp = clamp
|
|
||||||
|
|
||||||
def fwd(self, depth0, depth1, R0, t0, R1, t1):
|
def __init__(self, *args, clamp=-1):
|
||||||
uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
|
super().__init__(*args)
|
||||||
|
self.mod_name = 'ProjectionDepthSimilarityLoss'
|
||||||
|
self.clamp = clamp
|
||||||
|
|
||||||
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
|
def fwd(self, depth0, depth1, R0, t0, R1, t1):
|
||||||
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
|
uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
|
||||||
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
|
|
||||||
|
|
||||||
depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border')
|
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
|
||||||
|
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
|
||||||
|
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
|
||||||
|
|
||||||
diff = torch.abs(d1.view(-1) - depth10.view(-1))
|
depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border')
|
||||||
|
|
||||||
if self.clamp > 0:
|
diff = torch.abs(d1.view(-1) - depth10.view(-1))
|
||||||
diff = torch.clamp(diff, 0, self.clamp)
|
|
||||||
|
|
||||||
# return diff without clamping for debugging
|
if self.clamp > 0:
|
||||||
return diff.mean()
|
diff = torch.clamp(diff, 0, self.clamp)
|
||||||
|
|
||||||
def tforward(self, depth0, depth1, R0, t0, R1, t1):
|
# return diff without clamping for debugging
|
||||||
l0 = self.fwd(depth0, depth1, R0, t0, R1, t1)
|
return diff.mean()
|
||||||
l1 = self.fwd(depth1, depth0, R1, t1, R0, t0)
|
|
||||||
return l0+l1
|
|
||||||
|
|
||||||
|
def tforward(self, depth0, depth1, R0, t0, R1, t1):
|
||||||
|
l0 = self.fwd(depth0, depth1, R0, t0, R1, t1)
|
||||||
|
l1 = self.fwd(depth1, depth0, R1, t1, R0, t0)
|
||||||
|
return l0 + l1
|
||||||
|
|
||||||
|
|
||||||
class LCN(TimedModule):
|
class LCN(TimedModule):
|
||||||
'''
|
'''
|
||||||
Local Contract Normalization
|
Local Contract Normalization
|
||||||
'''
|
'''
|
||||||
def __init__(self, radius, epsilon):
|
|
||||||
super().__init__(mod_name='LCN')
|
|
||||||
self.box_conv = torch.nn.Sequential(
|
|
||||||
torch.nn.ReflectionPad2d(radius),
|
|
||||||
torch.nn.Conv2d(1, 1, kernel_size=2*radius+1, bias=False)
|
|
||||||
)
|
|
||||||
self.box_conv[1].weight.requires_grad=False
|
|
||||||
self.box_conv[1].weight.fill_(1.)
|
|
||||||
|
|
||||||
self.epsilon = epsilon
|
def __init__(self, radius, epsilon):
|
||||||
self.radius = radius
|
super().__init__(mod_name='LCN')
|
||||||
|
self.box_conv = torch.nn.Sequential(
|
||||||
|
torch.nn.ReflectionPad2d(radius),
|
||||||
|
torch.nn.Conv2d(1, 1, kernel_size=2 * radius + 1, bias=False)
|
||||||
|
)
|
||||||
|
self.box_conv[1].weight.requires_grad = False
|
||||||
|
self.box_conv[1].weight.fill_(1.)
|
||||||
|
|
||||||
def tforward(self, data):
|
self.epsilon = epsilon
|
||||||
boxs = self.box_conv(data)
|
self.radius = radius
|
||||||
|
|
||||||
avgs = boxs / (2*self.radius+1)**2
|
def tforward(self, data):
|
||||||
boxs_n2 = boxs**2
|
boxs = self.box_conv(data)
|
||||||
boxs_2n = self.box_conv(data**2)
|
|
||||||
|
|
||||||
stds = torch.sqrt(boxs_2n / (2*self.radius+1)**2 - avgs**2 + 1e-6)
|
avgs = boxs / (2 * self.radius + 1) ** 2
|
||||||
stds = stds + self.epsilon
|
boxs_n2 = boxs ** 2
|
||||||
|
boxs_2n = self.box_conv(data ** 2)
|
||||||
|
|
||||||
return (data - avgs) / stds, stds
|
stds = torch.sqrt(boxs_2n / (2 * self.radius + 1) ** 2 - avgs ** 2 + 1e-6)
|
||||||
|
stds = stds + self.epsilon
|
||||||
|
|
||||||
|
return (data - avgs) / stds, stds
|
||||||
|
|
||||||
|
|
||||||
class SobelFilter(TimedModule):
|
class SobelFilter(TimedModule):
|
||||||
'''
|
'''
|
||||||
Sobel Filter
|
Sobel Filter
|
||||||
'''
|
'''
|
||||||
def __init__(self, norm=False):
|
|
||||||
super(SobelFilter, self).__init__(mod_name='SobelFilter')
|
|
||||||
kx = np.array([[-5, -4, 0, 4, 5],
|
|
||||||
[-8, -10, 0, 10, 8],
|
|
||||||
[-10, -20, 0, 20, 10],
|
|
||||||
[-8, -10, 0, 10, 8],
|
|
||||||
[-5, -4, 0, 4, 5]])/240.0
|
|
||||||
ky = kx.copy().transpose(1,0)
|
|
||||||
|
|
||||||
self.conv_x=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
|
def __init__(self, norm=False):
|
||||||
self.conv_x.weight=torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))
|
super(SobelFilter, self).__init__(mod_name='SobelFilter')
|
||||||
|
kx = np.array([[-5, -4, 0, 4, 5],
|
||||||
|
[-8, -10, 0, 10, 8],
|
||||||
|
[-10, -20, 0, 20, 10],
|
||||||
|
[-8, -10, 0, 10, 8],
|
||||||
|
[-5, -4, 0, 4, 5]]) / 240.0
|
||||||
|
ky = kx.copy().transpose(1, 0)
|
||||||
|
|
||||||
self.conv_y=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
|
self.conv_x = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
|
||||||
self.conv_y.weight=torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))
|
self.conv_x.weight = torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))
|
||||||
|
|
||||||
self.norm=norm
|
self.conv_y = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
|
||||||
|
self.conv_y.weight = torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))
|
||||||
|
|
||||||
def tforward(self,x):
|
self.norm = norm
|
||||||
x = F.pad(x, (2,2,2,2), "replicate")
|
|
||||||
gx = self.conv_x(x)
|
|
||||||
gy = self.conv_y(x)
|
|
||||||
if self.norm:
|
|
||||||
return torch.sqrt(gx**2 + gy**2 + 1e-8)
|
|
||||||
else:
|
|
||||||
return torch.cat((gx, gy), dim=1)
|
|
||||||
|
|
||||||
|
def tforward(self, x):
|
||||||
|
x = F.pad(x, (2, 2, 2, 2), "replicate")
|
||||||
|
gx = self.conv_x(x)
|
||||||
|
gy = self.conv_y(x)
|
||||||
|
if self.norm:
|
||||||
|
return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
|
||||||
|
else:
|
||||||
|
return torch.cat((gx, gy), dim=1)
|
||||||
|
64
readme.md
64
readme.md
@ -6,7 +6,9 @@ This repository contains the code for the paper
|
|||||||
|
|
||||||
**[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)**
|
**[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)**
|
||||||
<br>
|
<br>
|
||||||
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/), [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/), and [Andreas Geiger](http://www.cvlibs.net/)
|
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/)
|
||||||
|
, [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/),
|
||||||
|
and [Andreas Geiger](http://www.cvlibs.net/)
|
||||||
<br>
|
<br>
|
||||||
[CVPR 2019](http://cvpr2019.thecvf.com/)
|
[CVPR 2019](http://cvpr2019.thecvf.com/)
|
||||||
|
|
||||||
@ -24,40 +26,45 @@ If you find this code useful for your research, please cite
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Dependencies
|
## Dependencies
|
||||||
|
|
||||||
The network training/evaluation code is based on `Pytorch`.
|
The network training/evaluation code is based on `Pytorch`.
|
||||||
|
|
||||||
```
|
```
|
||||||
PyTorch>=1.1
|
PyTorch>=1.1
|
||||||
Cuda>=10.0
|
Cuda>=10.0
|
||||||
```
|
```
|
||||||
|
|
||||||
Updated on 07.06.2021: The code is now compatible with the latest Pytorch version (1.8).
|
Updated on 07.06.2021: The code is now compatible with the latest Pytorch version (1.8).
|
||||||
|
|
||||||
The other python packages can be installed with `anaconda`:
|
The other python packages can be installed with `anaconda`:
|
||||||
|
|
||||||
```
|
```
|
||||||
conda install --file requirements.txt
|
conda install --file requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
### Structured Light Renderer
|
### Structured Light Renderer
|
||||||
To train and evaluate our method in a controlled setting, we implemented an structured light renderer.
|
|
||||||
It can be used to render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable projector location.
|
To train and evaluate our method in a controlled setting, we implemented an structured light renderer. It can be used to
|
||||||
To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`.
|
render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable
|
||||||
Afterwards, the renderer can be build by running `make` within the `renderer` directory.
|
projector location. To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`. Afterwards,
|
||||||
|
the renderer can be build by running `make` within the `renderer` directory.
|
||||||
|
|
||||||
### PyTorch Extensions
|
### PyTorch Extensions
|
||||||
The network training/evaluation code is based on `PyTorch`.
|
|
||||||
We implemented some custom layers that need to be built in the `torchext` directory.
|
The network training/evaluation code is based on `PyTorch`. We implemented some custom layers that need to be built in
|
||||||
Simply change into this directory and run
|
the `torchext` directory. Simply change into this directory and run
|
||||||
|
|
||||||
```
|
```
|
||||||
python setup.py build_ext --inplace
|
python setup.py build_ext --inplace
|
||||||
```
|
```
|
||||||
|
|
||||||
### Baseline HyperDepth
|
### Baseline HyperDepth
|
||||||
As baseline we partially re-implemented the random forest based method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf).
|
|
||||||
The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`.
|
As baseline we partially re-implemented the random forest based
|
||||||
To build it change into the directory and run
|
method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf)
|
||||||
|
. The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`
|
||||||
|
. To build it change into the directory and run
|
||||||
|
|
||||||
```
|
```
|
||||||
python setup.py build_ext --inplace
|
python setup.py build_ext --inplace
|
||||||
@ -65,42 +72,59 @@ python setup.py build_ext --inplace
|
|||||||
|
|
||||||
## Running
|
## Running
|
||||||
|
|
||||||
|
|
||||||
### Creating Synthetic Data
|
### Creating Synthetic Data
|
||||||
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by running
|
|
||||||
|
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and
|
||||||
|
correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by
|
||||||
|
running
|
||||||
|
|
||||||
```
|
```
|
||||||
./create_syn_data.sh
|
./create_syn_data.sh
|
||||||
```
|
```
|
||||||
If you are only interested in evaluating our pre-trained model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a validation set that contains a small amount of images.
|
|
||||||
|
If you are only interested in evaluating our pre-trained
|
||||||
|
model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a
|
||||||
|
validation set that contains a small amount of images.
|
||||||
|
|
||||||
### Training Network
|
### Training Network
|
||||||
|
|
||||||
As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train the network on synthetic data for the first stage run
|
As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train
|
||||||
|
the network on synthetic data for the first stage run
|
||||||
|
|
||||||
```
|
```
|
||||||
python train_val.py
|
python train_val.py
|
||||||
```
|
```
|
||||||
|
|
||||||
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by running
|
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by
|
||||||
|
running
|
||||||
|
|
||||||
```
|
```
|
||||||
python train_val.py --loss phge
|
python train_val.py --loss phge
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
### Evaluating Network
|
### Evaluating Network
|
||||||
|
|
||||||
To evaluate a specific checkpoint, e.g. the 50th epoch, one can run
|
To evaluate a specific checkpoint, e.g. the 50th epoch, one can run
|
||||||
|
|
||||||
```
|
```
|
||||||
python train_val.py --cmd retest --epoch 50
|
python train_val.py --cmd retest --epoch 50
|
||||||
```
|
```
|
||||||
|
|
||||||
### Evaluating a Pre-trained Model
|
### Evaluating a Pre-trained Model
|
||||||
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
|
|
||||||
|
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and
|
||||||
|
changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
|
||||||
|
|
||||||
```
|
```
|
||||||
mkdir -p output
|
mkdir -p output
|
||||||
mkdir -p output/exp_syn
|
mkdir -p output/exp_syn
|
||||||
wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params
|
wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params
|
||||||
python train_val.py --cmd retest --epoch 99
|
python train_val.py --cmd retest --epoch 99
|
||||||
```
|
```
|
||||||
You can also download our validation set from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
|
|
||||||
|
You can also download our validation set
|
||||||
|
from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
|
||||||
|
|
||||||
## Acknowledgement
|
## Acknowledgement
|
||||||
|
|
||||||
This work was supported by the Intel Network on Intelligent Systems.
|
This work was supported by the Intel Network on Intelligent Systems.
|
||||||
|
@ -10,7 +10,7 @@ import json
|
|||||||
this_dir = os.path.dirname(__file__)
|
this_dir = os.path.dirname(__file__)
|
||||||
|
|
||||||
with open('../config.json') as fp:
|
with open('../config.json') as fp:
|
||||||
config = json.load(fp)
|
config = json.load(fp)
|
||||||
|
|
||||||
extra_compile_args = ['-O3', '-std=c++11']
|
extra_compile_args = ['-O3', '-std=c++11']
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ cuda_lib = 'cudart'
|
|||||||
|
|
||||||
sources = ['cyrender.pyx']
|
sources = ['cyrender.pyx']
|
||||||
extra_objects = [
|
extra_objects = [
|
||||||
os.path.join(this_dir, 'render/render_cpu.cpp.o'),
|
os.path.join(this_dir, 'render/render_cpu.cpp.o'),
|
||||||
]
|
]
|
||||||
library_dirs = []
|
library_dirs = []
|
||||||
libraries = ['m']
|
libraries = ['m']
|
||||||
@ -30,20 +30,20 @@ library_dirs.append(cuda_lib_dir)
|
|||||||
libraries.append(cuda_lib)
|
libraries.append(cuda_lib)
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="cyrender",
|
name="cyrender",
|
||||||
cmdclass= {'build_ext': build_ext},
|
cmdclass={'build_ext': build_ext},
|
||||||
ext_modules=[
|
ext_modules=[
|
||||||
Extension('cyrender',
|
Extension('cyrender',
|
||||||
sources,
|
sources,
|
||||||
extra_objects=extra_objects,
|
extra_objects=extra_objects,
|
||||||
language='c++',
|
language='c++',
|
||||||
library_dirs=library_dirs,
|
library_dirs=library_dirs,
|
||||||
libraries=libraries,
|
libraries=libraries,
|
||||||
include_dirs=[
|
include_dirs=[
|
||||||
np.get_include(),
|
np.get_include(),
|
||||||
],
|
],
|
||||||
extra_compile_args=extra_compile_args,
|
extra_compile_args=extra_compile_args,
|
||||||
# extra_link_args=extra_link_args
|
# extra_link_args=extra_link_args
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -2,65 +2,65 @@ import torch
|
|||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class TestSet(object):
|
class TestSet(object):
|
||||||
def __init__(self, name, dset, test_frequency=1):
|
def __init__(self, name, dset, test_frequency=1):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.dset = dset
|
self.dset = dset
|
||||||
self.test_frequency = test_frequency
|
self.test_frequency = test_frequency
|
||||||
|
|
||||||
|
|
||||||
class TestSets(list):
|
class TestSets(list):
|
||||||
def append(self, name, dset, test_frequency=1):
|
def append(self, name, dset, test_frequency=1):
|
||||||
super().append(TestSet(name, dset, test_frequency))
|
super().append(TestSet(name, dset, test_frequency))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MultiDataset(torch.utils.data.Dataset):
|
class MultiDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, *datasets):
|
def __init__(self, *datasets):
|
||||||
self.current_epoch = 0
|
self.current_epoch = 0
|
||||||
|
|
||||||
self.datasets = []
|
self.datasets = []
|
||||||
self.cum_n_samples = [0]
|
self.cum_n_samples = [0]
|
||||||
|
|
||||||
for dataset in datasets:
|
for dataset in datasets:
|
||||||
self.append(dataset)
|
self.append(dataset)
|
||||||
|
|
||||||
def append(self, dataset):
|
def append(self, dataset):
|
||||||
self.datasets.append(dataset)
|
self.datasets.append(dataset)
|
||||||
self.__update_cum_n_samples(dataset)
|
self.__update_cum_n_samples(dataset)
|
||||||
|
|
||||||
def __update_cum_n_samples(self, dataset):
|
def __update_cum_n_samples(self, dataset):
|
||||||
n_samples = self.cum_n_samples[-1] + len(dataset)
|
n_samples = self.cum_n_samples[-1] + len(dataset)
|
||||||
self.cum_n_samples.append(n_samples)
|
self.cum_n_samples.append(n_samples)
|
||||||
|
|
||||||
def dataset_updated(self):
|
def dataset_updated(self):
|
||||||
self.cum_n_samples = [0]
|
self.cum_n_samples = [0]
|
||||||
for dset in self.datasets:
|
for dset in self.datasets:
|
||||||
self.__update_cum_n_samples(dset)
|
self.__update_cum_n_samples(dset)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.cum_n_samples[-1]
|
return self.cum_n_samples[-1]
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
|
|
||||||
sidx = idx - self.cum_n_samples[didx]
|
|
||||||
return self.datasets[didx][sidx]
|
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
|
||||||
|
sidx = idx - self.cum_n_samples[didx]
|
||||||
|
return self.datasets[didx][sidx]
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(torch.utils.data.Dataset):
|
class BaseDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, train=True, fix_seed_per_epoch=False):
|
def __init__(self, train=True, fix_seed_per_epoch=False):
|
||||||
self.current_epoch = 0
|
self.current_epoch = 0
|
||||||
self.train = train
|
self.train = train
|
||||||
self.fix_seed_per_epoch = fix_seed_per_epoch
|
self.fix_seed_per_epoch = fix_seed_per_epoch
|
||||||
|
|
||||||
def get_rng(self, idx):
|
def get_rng(self, idx):
|
||||||
rng = np.random.RandomState()
|
rng = np.random.RandomState()
|
||||||
if self.train:
|
if self.train:
|
||||||
if self.fix_seed_per_epoch:
|
if self.fix_seed_per_epoch:
|
||||||
seed = 1 * len(self) + idx
|
seed = 1 * len(self) + idx
|
||||||
else:
|
else:
|
||||||
seed = (self.current_epoch + 1) * len(self) + idx
|
seed = (self.current_epoch + 1) * len(self) + idx
|
||||||
rng.seed(seed)
|
rng.seed(seed)
|
||||||
else:
|
else:
|
||||||
rng.seed(idx)
|
rng.seed(idx)
|
||||||
return rng
|
return rng
|
||||||
|
@ -2,146 +2,151 @@ import torch
|
|||||||
from . import ext_cpu
|
from . import ext_cpu
|
||||||
from . import ext_cuda
|
from . import ext_cuda
|
||||||
|
|
||||||
class NNFunction(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, in0, in1):
|
|
||||||
args = (in0, in1)
|
|
||||||
if in0.is_cuda:
|
|
||||||
out = ext_cuda.nn_cuda(*args)
|
|
||||||
else:
|
|
||||||
out = ext_cpu.nn_cpu(*args)
|
|
||||||
return out
|
|
||||||
|
|
||||||
@staticmethod
|
class NNFunction(torch.autograd.Function):
|
||||||
def backward(ctx, grad_out):
|
@staticmethod
|
||||||
return None, None
|
def forward(ctx, in0, in1):
|
||||||
|
args = (in0, in1)
|
||||||
|
if in0.is_cuda:
|
||||||
|
out = ext_cuda.nn_cuda(*args)
|
||||||
|
else:
|
||||||
|
out = ext_cpu.nn_cpu(*args)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_out):
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
def nn(in0, in1):
|
def nn(in0, in1):
|
||||||
return NNFunction.apply(in0, in1)
|
return NNFunction.apply(in0, in1)
|
||||||
|
|
||||||
|
|
||||||
class CrossCheckFunction(torch.autograd.Function):
|
class CrossCheckFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, in0, in1):
|
def forward(ctx, in0, in1):
|
||||||
args = (in0, in1)
|
args = (in0, in1)
|
||||||
if in0.is_cuda:
|
if in0.is_cuda:
|
||||||
out = ext_cuda.crosscheck_cuda(*args)
|
out = ext_cuda.crosscheck_cuda(*args)
|
||||||
else:
|
else:
|
||||||
out = ext_cpu.crosscheck_cpu(*args)
|
out = ext_cpu.crosscheck_cpu(*args)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_out):
|
||||||
|
return None, None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_out):
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
def crosscheck(in0, in1):
|
def crosscheck(in0, in1):
|
||||||
return CrossCheckFunction.apply(in0, in1)
|
return CrossCheckFunction.apply(in0, in1)
|
||||||
|
|
||||||
|
|
||||||
class ProjNNFunction(torch.autograd.Function):
|
class ProjNNFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, xyz0, xyz1, K, patch_size):
|
def forward(ctx, xyz0, xyz1, K, patch_size):
|
||||||
args = (xyz0, xyz1, K, patch_size)
|
args = (xyz0, xyz1, K, patch_size)
|
||||||
if xyz0.is_cuda:
|
if xyz0.is_cuda:
|
||||||
out = ext_cuda.proj_nn_cuda(*args)
|
out = ext_cuda.proj_nn_cuda(*args)
|
||||||
else:
|
else:
|
||||||
out = ext_cpu.proj_nn_cpu(*args)
|
out = ext_cpu.proj_nn_cpu(*args)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_out):
|
||||||
|
return None, None, None, None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_out):
|
|
||||||
return None, None, None, None
|
|
||||||
|
|
||||||
def proj_nn(xyz0, xyz1, K, patch_size):
|
def proj_nn(xyz0, xyz1, K, patch_size):
|
||||||
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
|
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class XCorrVolFunction(torch.autograd.Function):
|
class XCorrVolFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, in0, in1, n_disps, block_size):
|
def forward(ctx, in0, in1, n_disps, block_size):
|
||||||
args = (in0, in1, n_disps, block_size)
|
args = (in0, in1, n_disps, block_size)
|
||||||
if in0.is_cuda:
|
if in0.is_cuda:
|
||||||
out = ext_cuda.xcorrvol_cuda(*args)
|
out = ext_cuda.xcorrvol_cuda(*args)
|
||||||
else:
|
else:
|
||||||
out = ext_cpu.xcorrvol_cpu(*args)
|
out = ext_cpu.xcorrvol_cpu(*args)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_out):
|
||||||
|
return None, None, None, None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_out):
|
|
||||||
return None, None, None, None
|
|
||||||
|
|
||||||
def xcorrvol(in0, in1, n_disps, block_size):
|
def xcorrvol(in0, in1, n_disps, block_size):
|
||||||
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
|
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PhotometricLossFunction(torch.autograd.Function):
|
class PhotometricLossFunction(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, es, ta, block_size, type, eps):
|
def forward(ctx, es, ta, block_size, type, eps):
|
||||||
args = (es, ta, block_size, type, eps)
|
args = (es, ta, block_size, type, eps)
|
||||||
ctx.save_for_backward(es, ta)
|
ctx.save_for_backward(es, ta)
|
||||||
ctx.block_size = block_size
|
ctx.block_size = block_size
|
||||||
ctx.type = type
|
ctx.type = type
|
||||||
ctx.eps = eps
|
ctx.eps = eps
|
||||||
if es.is_cuda:
|
if es.is_cuda:
|
||||||
out = ext_cuda.photometric_loss_forward(*args)
|
out = ext_cuda.photometric_loss_forward(*args)
|
||||||
else:
|
else:
|
||||||
out = ext_cpu.photometric_loss_forward(*args)
|
out = ext_cpu.photometric_loss_forward(*args)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_out):
|
||||||
|
es, ta = ctx.saved_tensors
|
||||||
|
block_size = ctx.block_size
|
||||||
|
type = ctx.type
|
||||||
|
eps = ctx.eps
|
||||||
|
args = (es, ta, grad_out.contiguous(), block_size, type, eps)
|
||||||
|
if grad_out.is_cuda:
|
||||||
|
grad_es = ext_cuda.photometric_loss_backward(*args)
|
||||||
|
else:
|
||||||
|
grad_es = ext_cpu.photometric_loss_backward(*args)
|
||||||
|
return grad_es, None, None, None, None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_out):
|
|
||||||
es, ta = ctx.saved_tensors
|
|
||||||
block_size = ctx.block_size
|
|
||||||
type = ctx.type
|
|
||||||
eps = ctx.eps
|
|
||||||
args = (es, ta, grad_out.contiguous(), block_size, type, eps)
|
|
||||||
if grad_out.is_cuda:
|
|
||||||
grad_es = ext_cuda.photometric_loss_backward(*args)
|
|
||||||
else:
|
|
||||||
grad_es = ext_cpu.photometric_loss_backward(*args)
|
|
||||||
return grad_es, None, None, None, None
|
|
||||||
|
|
||||||
def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
|
def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
|
||||||
type = type.lower()
|
type = type.lower()
|
||||||
if type == 'mse':
|
if type == 'mse':
|
||||||
type = 0
|
type = 0
|
||||||
elif type == 'sad':
|
elif type == 'sad':
|
||||||
type = 1
|
type = 1
|
||||||
elif type == 'census_mse':
|
elif type == 'census_mse':
|
||||||
type = 2
|
type = 2
|
||||||
elif type == 'census_sad':
|
elif type == 'census_sad':
|
||||||
type = 3
|
type = 3
|
||||||
else:
|
else:
|
||||||
raise Exception('invalid loss type')
|
raise Exception('invalid loss type')
|
||||||
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
|
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
|
||||||
|
|
||||||
|
|
||||||
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
|
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
|
||||||
type = type.lower()
|
type = type.lower()
|
||||||
p = block_size // 2
|
p = block_size // 2
|
||||||
es_pad = torch.nn.functional.pad(es, (p,p,p,p), mode='replicate')
|
es_pad = torch.nn.functional.pad(es, (p, p, p, p), mode='replicate')
|
||||||
ta_pad = torch.nn.functional.pad(ta, (p,p,p,p), mode='replicate')
|
ta_pad = torch.nn.functional.pad(ta, (p, p, p, p), mode='replicate')
|
||||||
es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
|
es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
|
||||||
ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
|
ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
|
||||||
es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
|
es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
|
||||||
ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
|
ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
|
||||||
if type == 'mse':
|
if type == 'mse':
|
||||||
ref = (es_uf - ta_uf)**2
|
ref = (es_uf - ta_uf) ** 2
|
||||||
elif type == 'sad':
|
elif type == 'sad':
|
||||||
ref = torch.abs(es_uf - ta_uf)
|
ref = torch.abs(es_uf - ta_uf)
|
||||||
elif type == 'census_mse' or type == 'census_sad':
|
elif type == 'census_mse' or type == 'census_sad':
|
||||||
des = es_uf - es.unsqueeze(2)
|
des = es_uf - es.unsqueeze(2)
|
||||||
dta = ta_uf - ta.unsqueeze(2)
|
dta = ta_uf - ta.unsqueeze(2)
|
||||||
h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
|
h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
|
||||||
h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
|
h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
|
||||||
diff = h_des - h_dta
|
diff = h_des - h_dta
|
||||||
if type == 'census_mse':
|
if type == 'census_mse':
|
||||||
ref = diff * diff
|
ref = diff * diff
|
||||||
elif type == 'census_sad':
|
elif type == 'census_sad':
|
||||||
ref = torch.abs(diff)
|
ref = torch.abs(diff)
|
||||||
else:
|
else:
|
||||||
raise Exception('invalid loss type')
|
raise Exception('invalid loss type')
|
||||||
ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
|
ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
|
||||||
ref = torch.sum(ref, dim=1, keepdim=True) / block_size**2
|
ref = torch.sum(ref, dim=1, keepdim=True) / block_size ** 2
|
||||||
return ref
|
return ref
|
||||||
|
@ -4,24 +4,26 @@ import numpy as np
|
|||||||
|
|
||||||
from .functions import *
|
from .functions import *
|
||||||
|
|
||||||
|
|
||||||
class CoordConv2d(torch.nn.Module):
|
class CoordConv2d(torch.nn.Module):
|
||||||
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
|
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.conv = torch.nn.Conv2d(channels_in+2, channels_out, kernel_size=kernel_size, padding=padding, stride=stride)
|
self.conv = torch.nn.Conv2d(channels_in + 2, channels_out, kernel_size=kernel_size, padding=padding,
|
||||||
|
stride=stride)
|
||||||
|
|
||||||
self.uv = None
|
self.uv = None
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.uv is None:
|
if self.uv is None:
|
||||||
height, width = x.shape[2], x.shape[3]
|
height, width = x.shape[2], x.shape[3]
|
||||||
u, v = np.meshgrid(range(width), range(height))
|
u, v = np.meshgrid(range(width), range(height))
|
||||||
u = 2 * u / (width - 1) - 1
|
u = 2 * u / (width - 1) - 1
|
||||||
v = 2 * v / (height - 1) - 1
|
v = 2 * v / (height - 1) - 1
|
||||||
uv = np.stack((u, v)).reshape(1, 2, height, width)
|
uv = np.stack((u, v)).reshape(1, 2, height, width)
|
||||||
self.uv = torch.from_numpy( uv.astype(np.float32) )
|
self.uv = torch.from_numpy(uv.astype(np.float32))
|
||||||
self.uv = self.uv.to(x.device)
|
self.uv = self.uv.to(x.device)
|
||||||
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
|
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
|
||||||
xuv = torch.cat((x, uv), dim=1)
|
xuv = torch.cat((x, uv), dim=1)
|
||||||
y = self.conv(xuv)
|
y = self.conv(xuv)
|
||||||
return y
|
return y
|
||||||
|
@ -11,11 +11,12 @@ nvcc_args = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='ext',
|
name='ext',
|
||||||
ext_modules=[
|
ext_modules=[
|
||||||
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
|
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
|
||||||
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'], extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
|
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'],
|
||||||
],
|
extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
|
||||||
cmdclass={'build_ext': BuildExtension},
|
],
|
||||||
include_dirs=include_dirs
|
cmdclass={'build_ext': BuildExtension},
|
||||||
|
include_dirs=include_dirs
|
||||||
)
|
)
|
||||||
|
@ -17,512 +17,516 @@ from collections import OrderedDict
|
|||||||
|
|
||||||
|
|
||||||
class StopWatch(object):
|
class StopWatch(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.timings = OrderedDict()
|
self.timings = OrderedDict()
|
||||||
self.starts = {}
|
self.starts = {}
|
||||||
|
|
||||||
def start(self, name):
|
def start(self, name):
|
||||||
self.starts[name] = time.time()
|
self.starts[name] = time.time()
|
||||||
|
|
||||||
def stop(self, name):
|
def stop(self, name):
|
||||||
if name not in self.timings:
|
if name not in self.timings:
|
||||||
self.timings[name] = []
|
self.timings[name] = []
|
||||||
self.timings[name].append(time.time() - self.starts[name])
|
self.timings[name].append(time.time() - self.starts[name])
|
||||||
|
|
||||||
def get(self, name=None, reduce=np.sum):
|
def get(self, name=None, reduce=np.sum):
|
||||||
if name is not None:
|
if name is not None:
|
||||||
return reduce(self.timings[name])
|
return reduce(self.timings[name])
|
||||||
else:
|
else:
|
||||||
ret = {}
|
ret = {}
|
||||||
for k in self.timings:
|
for k in self.timings:
|
||||||
ret[k] = reduce(self.timings[k])
|
ret[k] = reduce(self.timings[k])
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
|
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
||||||
def __str__(self):
|
|
||||||
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
|
def __str__(self):
|
||||||
|
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
|
||||||
|
|
||||||
|
|
||||||
class ETA(object):
|
class ETA(object):
|
||||||
def __init__(self, length):
|
def __init__(self, length):
|
||||||
self.length = length
|
self.length = length
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
self.current_idx = 0
|
self.current_idx = 0
|
||||||
self.current_time = time.time()
|
self.current_time = time.time()
|
||||||
|
|
||||||
def update(self, idx):
|
def update(self, idx):
|
||||||
self.current_idx = idx
|
self.current_idx = idx
|
||||||
self.current_time = time.time()
|
self.current_time = time.time()
|
||||||
|
|
||||||
def get_elapsed_time(self):
|
def get_elapsed_time(self):
|
||||||
return self.current_time - self.start_time
|
return self.current_time - self.start_time
|
||||||
|
|
||||||
def get_item_time(self):
|
def get_item_time(self):
|
||||||
return self.get_elapsed_time() / (self.current_idx + 1)
|
return self.get_elapsed_time() / (self.current_idx + 1)
|
||||||
|
|
||||||
def get_remaining_time(self):
|
def get_remaining_time(self):
|
||||||
return self.get_item_time() * (self.length - self.current_idx + 1)
|
return self.get_item_time() * (self.length - self.current_idx + 1)
|
||||||
|
|
||||||
def format_time(self, seconds):
|
def format_time(self, seconds):
|
||||||
minutes, seconds = divmod(seconds, 60)
|
minutes, seconds = divmod(seconds, 60)
|
||||||
hours, minutes = divmod(minutes, 60)
|
hours, minutes = divmod(minutes, 60)
|
||||||
hours = int(hours)
|
hours = int(hours)
|
||||||
minutes = int(minutes)
|
minutes = int(minutes)
|
||||||
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
|
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
|
||||||
|
|
||||||
def get_elapsed_time_str(self):
|
def get_elapsed_time_str(self):
|
||||||
return self.format_time(self.get_elapsed_time())
|
return self.format_time(self.get_elapsed_time())
|
||||||
|
|
||||||
|
def get_remaining_time_str(self):
|
||||||
|
return self.format_time(self.get_remaining_time())
|
||||||
|
|
||||||
def get_remaining_time_str(self):
|
|
||||||
return self.format_time(self.get_remaining_time())
|
|
||||||
|
|
||||||
class Worker(object):
|
class Worker(object):
|
||||||
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
|
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
|
||||||
self.out_root = Path(out_root)
|
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
|
||||||
self.experiment_name = experiment_name
|
self.out_root = Path(out_root)
|
||||||
self.epochs = epochs
|
self.experiment_name = experiment_name
|
||||||
self.seed = seed
|
self.epochs = epochs
|
||||||
self.train_batch_size = train_batch_size
|
self.seed = seed
|
||||||
self.test_batch_size = test_batch_size
|
self.train_batch_size = train_batch_size
|
||||||
self.num_workers = num_workers
|
self.test_batch_size = test_batch_size
|
||||||
self.save_frequency = save_frequency
|
self.num_workers = num_workers
|
||||||
self.train_device = train_device
|
self.save_frequency = save_frequency
|
||||||
self.test_device = test_device
|
self.train_device = train_device
|
||||||
self.max_train_iter = max_train_iter
|
self.test_device = test_device
|
||||||
|
self.max_train_iter = max_train_iter
|
||||||
|
|
||||||
self.errs_list=[]
|
self.errs_list = []
|
||||||
|
|
||||||
self.setup_experiment()
|
self.setup_experiment()
|
||||||
|
|
||||||
def setup_experiment(self):
|
def setup_experiment(self):
|
||||||
self.exp_out_root = self.out_root / self.experiment_name
|
self.exp_out_root = self.out_root / self.experiment_name
|
||||||
self.exp_out_root.mkdir(parents=True, exist_ok=True)
|
self.exp_out_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if logging.root: del logging.root.handlers[:]
|
if logging.root: del logging.root.handlers[:]
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
handlers=[
|
handlers=[
|
||||||
logging.FileHandler( str(self.exp_out_root / 'train.log') ),
|
logging.FileHandler(str(self.exp_out_root / 'train.log')),
|
||||||
logging.StreamHandler()
|
logging.StreamHandler()
|
||||||
],
|
],
|
||||||
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
|
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info('='*80)
|
logging.info('=' * 80)
|
||||||
logging.info(f'Start of experiment: {self.experiment_name}')
|
logging.info(f'Start of experiment: {self.experiment_name}')
|
||||||
logging.info(socket.gethostname())
|
logging.info(socket.gethostname())
|
||||||
self.log_datetime()
|
self.log_datetime()
|
||||||
logging.info('='*80)
|
logging.info('=' * 80)
|
||||||
|
|
||||||
self.metric_path = self.exp_out_root / 'metrics.json'
|
self.metric_path = self.exp_out_root / 'metrics.json'
|
||||||
if self.metric_path.exists():
|
if self.metric_path.exists():
|
||||||
with open(str(self.metric_path), 'r') as fp:
|
with open(str(self.metric_path), 'r') as fp:
|
||||||
self.metric_data = json.load(fp)
|
self.metric_data = json.load(fp)
|
||||||
else:
|
else:
|
||||||
self.metric_data = {}
|
self.metric_data = {}
|
||||||
|
|
||||||
self.init_seed()
|
self.init_seed()
|
||||||
|
|
||||||
def metric_add_train(self, epoch, key, val):
|
def metric_add_train(self, epoch, key, val):
|
||||||
epoch = str(epoch)
|
epoch = str(epoch)
|
||||||
key = str(key)
|
key = str(key)
|
||||||
if epoch not in self.metric_data:
|
if epoch not in self.metric_data:
|
||||||
self.metric_data[epoch] = {}
|
self.metric_data[epoch] = {}
|
||||||
if 'train' not in self.metric_data[epoch]:
|
if 'train' not in self.metric_data[epoch]:
|
||||||
self.metric_data[epoch]['train'] = {}
|
self.metric_data[epoch]['train'] = {}
|
||||||
self.metric_data[epoch]['train'][key] = val
|
self.metric_data[epoch]['train'][key] = val
|
||||||
|
|
||||||
def metric_add_test(self, epoch, set_idx, key, val):
|
def metric_add_test(self, epoch, set_idx, key, val):
|
||||||
epoch = str(epoch)
|
epoch = str(epoch)
|
||||||
set_idx = str(set_idx)
|
set_idx = str(set_idx)
|
||||||
key = str(key)
|
key = str(key)
|
||||||
if epoch not in self.metric_data:
|
if epoch not in self.metric_data:
|
||||||
self.metric_data[epoch] = {}
|
self.metric_data[epoch] = {}
|
||||||
if 'test' not in self.metric_data[epoch]:
|
if 'test' not in self.metric_data[epoch]:
|
||||||
self.metric_data[epoch]['test'] = {}
|
self.metric_data[epoch]['test'] = {}
|
||||||
if set_idx not in self.metric_data[epoch]['test']:
|
if set_idx not in self.metric_data[epoch]['test']:
|
||||||
self.metric_data[epoch]['test'][set_idx] = {}
|
self.metric_data[epoch]['test'][set_idx] = {}
|
||||||
self.metric_data[epoch]['test'][set_idx][key] = val
|
self.metric_data[epoch]['test'][set_idx][key] = val
|
||||||
|
|
||||||
def metric_save(self):
|
def metric_save(self):
|
||||||
with open(str(self.metric_path), 'w') as fp:
|
with open(str(self.metric_path), 'w') as fp:
|
||||||
json.dump(self.metric_data, fp, indent=2)
|
json.dump(self.metric_data, fp, indent=2)
|
||||||
|
|
||||||
def init_seed(self, seed=None):
|
def init_seed(self, seed=None):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
logging.info(f'Set seed to {self.seed}')
|
logging.info(f'Set seed to {self.seed}')
|
||||||
np.random.seed(self.seed)
|
np.random.seed(self.seed)
|
||||||
random.seed(self.seed)
|
random.seed(self.seed)
|
||||||
torch.manual_seed(self.seed)
|
torch.manual_seed(self.seed)
|
||||||
torch.cuda.manual_seed(self.seed)
|
torch.cuda.manual_seed(self.seed)
|
||||||
|
|
||||||
def log_datetime(self):
|
def log_datetime(self):
|
||||||
logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||||
|
|
||||||
def mem_report(self):
|
def mem_report(self):
|
||||||
for obj in gc.get_objects():
|
for obj in gc.get_objects():
|
||||||
if torch.is_tensor(obj):
|
if torch.is_tensor(obj):
|
||||||
print(type(obj), obj.shape)
|
print(type(obj), obj.shape)
|
||||||
|
|
||||||
def get_net_path(self, epoch, root=None):
|
def get_net_path(self, epoch, root=None):
|
||||||
if root is None:
|
if root is None:
|
||||||
root = self.exp_out_root
|
root = self.exp_out_root
|
||||||
return root / f'net_{epoch:04d}.params'
|
return root / f'net_{epoch:04d}.params'
|
||||||
|
|
||||||
def get_do_parser_cmds(self):
|
def get_do_parser_cmds(self):
|
||||||
return ['retrain', 'resume', 'retest', 'test_init']
|
return ['retrain', 'resume', 'retest', 'test_init']
|
||||||
|
|
||||||
def get_do_parser(self):
|
def get_do_parser(self):
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
|
parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
|
||||||
parser.add_argument('--epoch', type=int, default=-1)
|
parser.add_argument('--epoch', type=int, default=-1)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def do_cmd(self, args, net, optimizer, scheduler=None):
|
def do_cmd(self, args, net, optimizer, scheduler=None):
|
||||||
if args.cmd == 'retrain':
|
if args.cmd == 'retrain':
|
||||||
self.train(net, optimizer, resume=False, scheduler=scheduler)
|
self.train(net, optimizer, resume=False, scheduler=scheduler)
|
||||||
elif args.cmd == 'resume':
|
elif args.cmd == 'resume':
|
||||||
self.train(net, optimizer, resume=True, scheduler=scheduler)
|
self.train(net, optimizer, resume=True, scheduler=scheduler)
|
||||||
elif args.cmd == 'retest':
|
elif args.cmd == 'retest':
|
||||||
self.retest(net, epoch=args.epoch)
|
self.retest(net, epoch=args.epoch)
|
||||||
elif args.cmd == 'test_init':
|
elif args.cmd == 'test_init':
|
||||||
test_sets = self.get_test_sets()
|
test_sets = self.get_test_sets()
|
||||||
self.test(-1, net, test_sets)
|
self.test(-1, net, test_sets)
|
||||||
else:
|
else:
|
||||||
raise Exception('invalid cmd')
|
raise Exception('invalid cmd')
|
||||||
|
|
||||||
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
|
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
|
||||||
parser = self.get_do_parser()
|
parser = self.get_do_parser()
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
if load_net_optimizer is not None and args.cmd not in ['schedule']:
|
if load_net_optimizer is not None and args.cmd not in ['schedule']:
|
||||||
net, optimizer = load_net_optimizer()
|
net, optimizer = load_net_optimizer()
|
||||||
|
|
||||||
self.do_cmd(args, net, optimizer, scheduler=scheduler)
|
self.do_cmd(args, net, optimizer, scheduler=scheduler)
|
||||||
|
|
||||||
def retest(self, net, epoch=-1):
|
def retest(self, net, epoch=-1):
|
||||||
if epoch < 0:
|
if epoch < 0:
|
||||||
epochs = range(self.epochs)
|
epochs = range(self.epochs)
|
||||||
else:
|
else:
|
||||||
epochs = [epoch]
|
epochs = [epoch]
|
||||||
|
|
||||||
test_sets = self.get_test_sets()
|
test_sets = self.get_test_sets()
|
||||||
|
|
||||||
for epoch in epochs:
|
for epoch in epochs:
|
||||||
net_path = self.get_net_path(epoch)
|
net_path = self.get_net_path(epoch)
|
||||||
if net_path.exists():
|
if net_path.exists():
|
||||||
state_dict = torch.load(str(net_path))
|
state_dict = torch.load(str(net_path))
|
||||||
net.load_state_dict(state_dict)
|
net.load_state_dict(state_dict)
|
||||||
self.test(epoch, net, test_sets)
|
self.test(epoch, net, test_sets)
|
||||||
|
|
||||||
def format_err_str(self, errs, div=1):
|
def format_err_str(self, errs, div=1):
|
||||||
err = sum(errs)
|
err = sum(errs)
|
||||||
if len(errs) > 1:
|
if len(errs) > 1:
|
||||||
err_str = f'{err/div:0.4f}=' + '+'.join([f'{e/div:0.4f}' for e in errs])
|
err_str = f'{err / div:0.4f}=' + '+'.join([f'{e / div:0.4f}' for e in errs])
|
||||||
else:
|
else:
|
||||||
err_str = f'{err/div:0.4f}'
|
err_str = f'{err / div:0.4f}'
|
||||||
return err_str
|
return err_str
|
||||||
|
|
||||||
def write_err_img(self):
|
def write_err_img(self):
|
||||||
err_img_path = self.exp_out_root / 'errs.png'
|
err_img_path = self.exp_out_root / 'errs.png'
|
||||||
fig = plt.figure(figsize=(16,16))
|
fig = plt.figure(figsize=(16, 16))
|
||||||
lines=[]
|
lines = []
|
||||||
for idx,errs in enumerate(self.errs_list):
|
for idx, errs in enumerate(self.errs_list):
|
||||||
line,=plt.plot(range(len(errs)), errs, label=f'error{idx}')
|
line, = plt.plot(range(len(errs)), errs, label=f'error{idx}')
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.legend(handles=lines)
|
plt.legend(handles=lines)
|
||||||
plt.savefig(str(err_img_path))
|
plt.savefig(str(err_img_path))
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
|
def callback_train_new_epoch(self, epoch, net, optimizer):
|
||||||
def callback_train_new_epoch(self, epoch, net, optimizer):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def train(self, net, optimizer, resume=False, scheduler=None):
|
|
||||||
logging.info('='*80)
|
|
||||||
logging.info('Start training')
|
|
||||||
self.log_datetime()
|
|
||||||
logging.info('='*80)
|
|
||||||
|
|
||||||
train_set = self.get_train_set()
|
|
||||||
test_sets = self.get_test_sets()
|
|
||||||
|
|
||||||
net = net.to(self.train_device)
|
|
||||||
|
|
||||||
epoch = 0
|
|
||||||
min_err = {ts.name: 1e9 for ts in test_sets}
|
|
||||||
|
|
||||||
state_path = self.exp_out_root / 'state.dict'
|
|
||||||
if resume and state_path.exists():
|
|
||||||
logging.info('='*80)
|
|
||||||
logging.info(f'Loading state from {state_path}')
|
|
||||||
logging.info('='*80)
|
|
||||||
state = torch.load(str(state_path))
|
|
||||||
epoch = state['epoch'] + 1
|
|
||||||
if 'min_err' in state:
|
|
||||||
min_err = state['min_err']
|
|
||||||
|
|
||||||
curr_state = net.state_dict()
|
|
||||||
curr_state.update(state['state_dict'])
|
|
||||||
net.load_state_dict(curr_state)
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
optimizer.load_state_dict(state['optimizer'])
|
|
||||||
except:
|
|
||||||
logging.info('Warning: cannot load optimizer from state_dict')
|
|
||||||
pass
|
pass
|
||||||
if 'cpu_rng_state' in state:
|
|
||||||
torch.set_rng_state(state['cpu_rng_state'])
|
|
||||||
if 'gpu_rng_state' in state:
|
|
||||||
torch.cuda.set_rng_state(state['gpu_rng_state'])
|
|
||||||
|
|
||||||
for epoch in range(epoch, self.epochs):
|
def train(self, net, optimizer, resume=False, scheduler=None):
|
||||||
self.callback_train_new_epoch(epoch, net, optimizer)
|
logging.info('=' * 80)
|
||||||
|
logging.info('Start training')
|
||||||
|
self.log_datetime()
|
||||||
|
logging.info('=' * 80)
|
||||||
|
|
||||||
# train epoch
|
train_set = self.get_train_set()
|
||||||
self.train_epoch(epoch, net, optimizer, train_set)
|
test_sets = self.get_test_sets()
|
||||||
|
|
||||||
# test epoch
|
|
||||||
errs = self.test(epoch, net, test_sets)
|
|
||||||
|
|
||||||
if (epoch + 1) % self.save_frequency == 0:
|
|
||||||
net = net.to(self.train_device)
|
net = net.to(self.train_device)
|
||||||
|
|
||||||
# store state
|
epoch = 0
|
||||||
state_dict = {
|
min_err = {ts.name: 1e9 for ts in test_sets}
|
||||||
'epoch': epoch,
|
|
||||||
'min_err': min_err,
|
|
||||||
'state_dict': net.state_dict(),
|
|
||||||
'optimizer': optimizer.state_dict(),
|
|
||||||
'cpu_rng_state': torch.get_rng_state(),
|
|
||||||
'gpu_rng_state': torch.cuda.get_rng_state(),
|
|
||||||
}
|
|
||||||
logging.info(f'save state to {state_path}')
|
|
||||||
state_path = self.exp_out_root / 'state.dict'
|
state_path = self.exp_out_root / 'state.dict'
|
||||||
torch.save(state_dict, str(state_path))
|
if resume and state_path.exists():
|
||||||
|
logging.info('=' * 80)
|
||||||
|
logging.info(f'Loading state from {state_path}')
|
||||||
|
logging.info('=' * 80)
|
||||||
|
state = torch.load(str(state_path))
|
||||||
|
epoch = state['epoch'] + 1
|
||||||
|
if 'min_err' in state:
|
||||||
|
min_err = state['min_err']
|
||||||
|
|
||||||
for test_set_name in errs:
|
curr_state = net.state_dict()
|
||||||
err = sum(errs[test_set_name])
|
curr_state.update(state['state_dict'])
|
||||||
if err < min_err[test_set_name]:
|
net.load_state_dict(curr_state)
|
||||||
min_err[test_set_name] = err
|
|
||||||
state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
|
|
||||||
logging.info(f'save state to {state_path}')
|
|
||||||
torch.save(state_dict, str(state_path))
|
|
||||||
|
|
||||||
# store network
|
try:
|
||||||
net_path = self.get_net_path(epoch)
|
optimizer.load_state_dict(state['optimizer'])
|
||||||
logging.info(f'save network to {net_path}')
|
except:
|
||||||
torch.save(net.state_dict(), str(net_path))
|
logging.info('Warning: cannot load optimizer from state_dict')
|
||||||
|
pass
|
||||||
|
if 'cpu_rng_state' in state:
|
||||||
|
torch.set_rng_state(state['cpu_rng_state'])
|
||||||
|
if 'gpu_rng_state' in state:
|
||||||
|
torch.cuda.set_rng_state(state['gpu_rng_state'])
|
||||||
|
|
||||||
if scheduler is not None:
|
for epoch in range(epoch, self.epochs):
|
||||||
scheduler.step()
|
self.callback_train_new_epoch(epoch, net, optimizer)
|
||||||
|
|
||||||
logging.info('='*80)
|
# train epoch
|
||||||
logging.info('Finished training')
|
self.train_epoch(epoch, net, optimizer, train_set)
|
||||||
self.log_datetime()
|
|
||||||
logging.info('='*80)
|
|
||||||
|
|
||||||
def get_train_set(self):
|
# test epoch
|
||||||
# returns train_set
|
errs = self.test(epoch, net, test_sets)
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_test_sets(self):
|
if (epoch + 1) % self.save_frequency == 0:
|
||||||
# returns test_sets
|
net = net.to(self.train_device)
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def copy_data(self, data, device, requires_grad, train):
|
# store state
|
||||||
raise NotImplementedError()
|
state_dict = {
|
||||||
|
'epoch': epoch,
|
||||||
|
'min_err': min_err,
|
||||||
|
'state_dict': net.state_dict(),
|
||||||
|
'optimizer': optimizer.state_dict(),
|
||||||
|
'cpu_rng_state': torch.get_rng_state(),
|
||||||
|
'gpu_rng_state': torch.cuda.get_rng_state(),
|
||||||
|
}
|
||||||
|
logging.info(f'save state to {state_path}')
|
||||||
|
state_path = self.exp_out_root / 'state.dict'
|
||||||
|
torch.save(state_dict, str(state_path))
|
||||||
|
|
||||||
def net_forward(self, net, train):
|
for test_set_name in errs:
|
||||||
raise NotImplementedError()
|
err = sum(errs[test_set_name])
|
||||||
|
if err < min_err[test_set_name]:
|
||||||
|
min_err[test_set_name] = err
|
||||||
|
state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
|
||||||
|
logging.info(f'save state to {state_path}')
|
||||||
|
torch.save(state_dict, str(state_path))
|
||||||
|
|
||||||
def loss_forward(self, output, train):
|
# store network
|
||||||
raise NotImplementedError()
|
net_path = self.get_net_path(epoch)
|
||||||
|
logging.info(f'save network to {net_path}')
|
||||||
|
torch.save(net.state_dict(), str(net_path))
|
||||||
|
|
||||||
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
|
if scheduler is not None:
|
||||||
# err = False
|
scheduler.step()
|
||||||
# for name, param in net.named_parameters():
|
|
||||||
# if not torch.isfinite(param.grad).all():
|
|
||||||
# print(name)
|
|
||||||
# err = True
|
|
||||||
# if err:
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
pass
|
|
||||||
|
|
||||||
def callback_train_start(self, epoch):
|
logging.info('=' * 80)
|
||||||
pass
|
logging.info('Finished training')
|
||||||
|
self.log_datetime()
|
||||||
|
logging.info('=' * 80)
|
||||||
|
|
||||||
def callback_train_stop(self, epoch, loss):
|
def get_train_set(self):
|
||||||
pass
|
# returns train_set
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def train_epoch(self, epoch, net, optimizer, dset):
|
def get_test_sets(self):
|
||||||
self.callback_train_start(epoch)
|
# returns test_sets
|
||||||
stopwatch = StopWatch()
|
raise NotImplementedError()
|
||||||
|
|
||||||
logging.info('='*80)
|
def copy_data(self, data, device, requires_grad, train):
|
||||||
logging.info('Train epoch %d' % epoch)
|
raise NotImplementedError()
|
||||||
|
|
||||||
dset.current_epoch = epoch
|
def net_forward(self, net, train):
|
||||||
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False)
|
raise NotImplementedError()
|
||||||
|
|
||||||
net = net.to(self.train_device)
|
def loss_forward(self, output, train):
|
||||||
net.train()
|
raise NotImplementedError()
|
||||||
|
|
||||||
mean_loss = None
|
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
|
||||||
|
# err = False
|
||||||
|
# for name, param in net.named_parameters():
|
||||||
|
# if not torch.isfinite(param.grad).all():
|
||||||
|
# print(name)
|
||||||
|
# err = True
|
||||||
|
# if err:
|
||||||
|
# import ipdb; ipdb.set_trace()
|
||||||
|
pass
|
||||||
|
|
||||||
n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
|
def callback_train_start(self, epoch):
|
||||||
bar = ETA(length=n_batches)
|
pass
|
||||||
|
|
||||||
stopwatch.start('total')
|
def callback_train_stop(self, epoch, loss):
|
||||||
stopwatch.start('data')
|
pass
|
||||||
for batch_idx, data in enumerate(train_loader):
|
|
||||||
if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
|
|
||||||
self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
|
|
||||||
stopwatch.stop('data')
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
def train_epoch(self, epoch, net, optimizer, dset):
|
||||||
|
self.callback_train_start(epoch)
|
||||||
|
stopwatch = StopWatch()
|
||||||
|
|
||||||
stopwatch.start('forward')
|
logging.info('=' * 80)
|
||||||
output = self.net_forward(net, train=True)
|
logging.info('Train epoch %d' % epoch)
|
||||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
|
||||||
stopwatch.stop('forward')
|
|
||||||
|
|
||||||
stopwatch.start('loss')
|
dset.current_epoch = epoch
|
||||||
errs = self.loss_forward(output, train=True)
|
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True,
|
||||||
if isinstance(errs, dict):
|
num_workers=self.num_workers, drop_last=True, pin_memory=False)
|
||||||
masks = errs['masks']
|
|
||||||
errs = errs['errs']
|
|
||||||
else:
|
|
||||||
masks = []
|
|
||||||
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
|
||||||
errs = [errs]
|
|
||||||
err = sum(errs)
|
|
||||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
|
||||||
stopwatch.stop('loss')
|
|
||||||
|
|
||||||
stopwatch.start('backward')
|
net = net.to(self.train_device)
|
||||||
err.backward()
|
net.train()
|
||||||
self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
|
|
||||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
|
||||||
stopwatch.stop('backward')
|
|
||||||
|
|
||||||
stopwatch.start('optimizer')
|
mean_loss = None
|
||||||
optimizer.step()
|
|
||||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
|
||||||
stopwatch.stop('optimizer')
|
|
||||||
|
|
||||||
bar.update(batch_idx)
|
n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
|
||||||
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
|
bar = ETA(length=n_batches)
|
||||||
err_str = self.format_err_str(errs)
|
|
||||||
logging.info(f'train e{epoch}: {batch_idx+1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
|
||||||
#self.write_err_img()
|
|
||||||
|
|
||||||
|
|
||||||
if mean_loss is None:
|
|
||||||
mean_loss = [0 for e in errs]
|
|
||||||
for erridx, err in enumerate(errs):
|
|
||||||
mean_loss[erridx] += err.item()
|
|
||||||
|
|
||||||
stopwatch.start('data')
|
|
||||||
stopwatch.stop('total')
|
|
||||||
logging.info('timings: %s' % stopwatch)
|
|
||||||
|
|
||||||
mean_loss = [l / len(train_loader) for l in mean_loss]
|
|
||||||
self.callback_train_stop(epoch, mean_loss)
|
|
||||||
self.metric_add_train(epoch, 'loss', mean_loss)
|
|
||||||
|
|
||||||
# save metrics
|
|
||||||
self.metric_save()
|
|
||||||
|
|
||||||
err_str = self.format_err_str(mean_loss)
|
|
||||||
logging.info(f'avg train_loss={err_str}')
|
|
||||||
return mean_loss
|
|
||||||
|
|
||||||
def callback_test_start(self, epoch, set_idx):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def callback_test_stop(self, epoch, set_idx, loss):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test(self, epoch, net, test_sets):
|
|
||||||
errs = {}
|
|
||||||
for test_set_idx, test_set in enumerate(test_sets):
|
|
||||||
if (epoch + 1) % test_set.test_frequency == 0:
|
|
||||||
logging.info('='*80)
|
|
||||||
logging.info(f'testing set {test_set.name}')
|
|
||||||
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
|
|
||||||
errs[test_set.name] = err
|
|
||||||
return errs
|
|
||||||
|
|
||||||
def test_epoch(self, epoch, set_idx, net, dset):
|
|
||||||
logging.info('-'*80)
|
|
||||||
logging.info('Test epoch %d' % epoch)
|
|
||||||
dset.current_epoch = epoch
|
|
||||||
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False)
|
|
||||||
|
|
||||||
net = net.to(self.test_device)
|
|
||||||
net.eval()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
mean_loss = None
|
|
||||||
|
|
||||||
self.callback_test_start(epoch, set_idx)
|
|
||||||
|
|
||||||
bar = ETA(length=len(test_loader))
|
|
||||||
stopwatch = StopWatch()
|
|
||||||
stopwatch.start('total')
|
|
||||||
stopwatch.start('data')
|
|
||||||
for batch_idx, data in enumerate(test_loader):
|
|
||||||
# if batch_idx == 10: break
|
|
||||||
self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
|
|
||||||
stopwatch.stop('data')
|
|
||||||
|
|
||||||
stopwatch.start('forward')
|
|
||||||
output = self.net_forward(net, train=False)
|
|
||||||
if 'cuda' in self.test_device: torch.cuda.synchronize()
|
|
||||||
stopwatch.stop('forward')
|
|
||||||
|
|
||||||
stopwatch.start('loss')
|
|
||||||
errs = self.loss_forward(output, train=False)
|
|
||||||
if isinstance(errs, dict):
|
|
||||||
masks = errs['masks']
|
|
||||||
errs = errs['errs']
|
|
||||||
else:
|
|
||||||
masks = []
|
|
||||||
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
|
||||||
errs = [errs]
|
|
||||||
|
|
||||||
bar.update(batch_idx)
|
|
||||||
if batch_idx % 25 == 0:
|
|
||||||
err_str = self.format_err_str(errs)
|
|
||||||
logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
|
||||||
|
|
||||||
if mean_loss is None:
|
|
||||||
mean_loss = [0 for e in errs]
|
|
||||||
for erridx, err in enumerate(errs):
|
|
||||||
mean_loss[erridx] += err.item()
|
|
||||||
stopwatch.stop('loss')
|
|
||||||
|
|
||||||
self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)
|
|
||||||
|
|
||||||
|
stopwatch.start('total')
|
||||||
stopwatch.start('data')
|
stopwatch.start('data')
|
||||||
stopwatch.stop('total')
|
for batch_idx, data in enumerate(train_loader):
|
||||||
logging.info('timings: %s' % stopwatch)
|
if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
|
||||||
|
self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
|
||||||
|
stopwatch.stop('data')
|
||||||
|
|
||||||
mean_loss = [l / len(test_loader) for l in mean_loss]
|
optimizer.zero_grad()
|
||||||
self.callback_test_stop(epoch, set_idx, mean_loss)
|
|
||||||
self.metric_add_test(epoch, set_idx, 'loss', mean_loss)
|
|
||||||
|
|
||||||
# save metrics
|
stopwatch.start('forward')
|
||||||
self.metric_save()
|
output = self.net_forward(net, train=True)
|
||||||
|
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||||
|
stopwatch.stop('forward')
|
||||||
|
|
||||||
err_str = self.format_err_str(mean_loss)
|
stopwatch.start('loss')
|
||||||
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
|
errs = self.loss_forward(output, train=True)
|
||||||
return mean_loss
|
if isinstance(errs, dict):
|
||||||
|
masks = errs['masks']
|
||||||
|
errs = errs['errs']
|
||||||
|
else:
|
||||||
|
masks = []
|
||||||
|
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
||||||
|
errs = [errs]
|
||||||
|
err = sum(errs)
|
||||||
|
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||||
|
stopwatch.stop('loss')
|
||||||
|
|
||||||
|
stopwatch.start('backward')
|
||||||
|
err.backward()
|
||||||
|
self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
|
||||||
|
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||||
|
stopwatch.stop('backward')
|
||||||
|
|
||||||
|
stopwatch.start('optimizer')
|
||||||
|
optimizer.step()
|
||||||
|
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||||
|
stopwatch.stop('optimizer')
|
||||||
|
|
||||||
|
bar.update(batch_idx)
|
||||||
|
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
|
||||||
|
err_str = self.format_err_str(errs)
|
||||||
|
logging.info(
|
||||||
|
f'train e{epoch}: {batch_idx + 1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
||||||
|
# self.write_err_img()
|
||||||
|
|
||||||
|
if mean_loss is None:
|
||||||
|
mean_loss = [0 for e in errs]
|
||||||
|
for erridx, err in enumerate(errs):
|
||||||
|
mean_loss[erridx] += err.item()
|
||||||
|
|
||||||
|
stopwatch.start('data')
|
||||||
|
stopwatch.stop('total')
|
||||||
|
logging.info('timings: %s' % stopwatch)
|
||||||
|
|
||||||
|
mean_loss = [l / len(train_loader) for l in mean_loss]
|
||||||
|
self.callback_train_stop(epoch, mean_loss)
|
||||||
|
self.metric_add_train(epoch, 'loss', mean_loss)
|
||||||
|
|
||||||
|
# save metrics
|
||||||
|
self.metric_save()
|
||||||
|
|
||||||
|
err_str = self.format_err_str(mean_loss)
|
||||||
|
logging.info(f'avg train_loss={err_str}')
|
||||||
|
return mean_loss
|
||||||
|
|
||||||
|
def callback_test_start(self, epoch, set_idx):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def callback_test_stop(self, epoch, set_idx, loss):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test(self, epoch, net, test_sets):
|
||||||
|
errs = {}
|
||||||
|
for test_set_idx, test_set in enumerate(test_sets):
|
||||||
|
if (epoch + 1) % test_set.test_frequency == 0:
|
||||||
|
logging.info('=' * 80)
|
||||||
|
logging.info(f'testing set {test_set.name}')
|
||||||
|
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
|
||||||
|
errs[test_set.name] = err
|
||||||
|
return errs
|
||||||
|
|
||||||
|
def test_epoch(self, epoch, set_idx, net, dset):
|
||||||
|
logging.info('-' * 80)
|
||||||
|
logging.info('Test epoch %d' % epoch)
|
||||||
|
dset.current_epoch = epoch
|
||||||
|
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False,
|
||||||
|
num_workers=self.num_workers, drop_last=False, pin_memory=False)
|
||||||
|
|
||||||
|
net = net.to(self.test_device)
|
||||||
|
net.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
mean_loss = None
|
||||||
|
|
||||||
|
self.callback_test_start(epoch, set_idx)
|
||||||
|
|
||||||
|
bar = ETA(length=len(test_loader))
|
||||||
|
stopwatch = StopWatch()
|
||||||
|
stopwatch.start('total')
|
||||||
|
stopwatch.start('data')
|
||||||
|
for batch_idx, data in enumerate(test_loader):
|
||||||
|
# if batch_idx == 10: break
|
||||||
|
self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
|
||||||
|
stopwatch.stop('data')
|
||||||
|
|
||||||
|
stopwatch.start('forward')
|
||||||
|
output = self.net_forward(net, train=False)
|
||||||
|
if 'cuda' in self.test_device: torch.cuda.synchronize()
|
||||||
|
stopwatch.stop('forward')
|
||||||
|
|
||||||
|
stopwatch.start('loss')
|
||||||
|
errs = self.loss_forward(output, train=False)
|
||||||
|
if isinstance(errs, dict):
|
||||||
|
masks = errs['masks']
|
||||||
|
errs = errs['errs']
|
||||||
|
else:
|
||||||
|
masks = []
|
||||||
|
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
||||||
|
errs = [errs]
|
||||||
|
|
||||||
|
bar.update(batch_idx)
|
||||||
|
if batch_idx % 25 == 0:
|
||||||
|
err_str = self.format_err_str(errs)
|
||||||
|
logging.info(
|
||||||
|
f'test e{epoch}: {batch_idx + 1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
||||||
|
|
||||||
|
if mean_loss is None:
|
||||||
|
mean_loss = [0 for e in errs]
|
||||||
|
for erridx, err in enumerate(errs):
|
||||||
|
mean_loss[erridx] += err.item()
|
||||||
|
stopwatch.stop('loss')
|
||||||
|
|
||||||
|
self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)
|
||||||
|
|
||||||
|
stopwatch.start('data')
|
||||||
|
stopwatch.stop('total')
|
||||||
|
logging.info('timings: %s' % stopwatch)
|
||||||
|
|
||||||
|
mean_loss = [l / len(test_loader) for l in mean_loss]
|
||||||
|
self.callback_test_stop(epoch, set_idx, mean_loss)
|
||||||
|
self.metric_add_test(epoch, set_idx, 'loss', mean_loss)
|
||||||
|
|
||||||
|
# save metrics
|
||||||
|
self.metric_save()
|
||||||
|
|
||||||
|
err_str = self.format_err_str(mean_loss)
|
||||||
|
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
|
||||||
|
return mean_loss
|
||||||
|
15
train_val.py
15
train_val.py
@ -5,25 +5,24 @@ from model import exp_synphge
|
|||||||
from model import networks
|
from model import networks
|
||||||
from co.args import parse_args
|
from co.args import parse_args
|
||||||
|
|
||||||
|
|
||||||
# parse args
|
# parse args
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# loss types
|
# loss types
|
||||||
if args.loss=='ph':
|
if args.loss == 'ph':
|
||||||
worker = exp_synph.Worker(args)
|
worker = exp_synph.Worker(args)
|
||||||
elif args.loss=='phge':
|
elif args.loss == 'phge':
|
||||||
worker = exp_synphge.Worker(args)
|
worker = exp_synphge.Worker(args)
|
||||||
|
|
||||||
# concatenation of original image and lcn image
|
# concatenation of original image and lcn image
|
||||||
channels_in=2
|
channels_in = 2
|
||||||
|
|
||||||
# set up network
|
# set up network
|
||||||
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes, output_ms=worker.ms)
|
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes,
|
||||||
|
output_ms=worker.ms)
|
||||||
|
|
||||||
# optimizer
|
# optimizer
|
||||||
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
|
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
|
||||||
|
|
||||||
# start the work
|
# start the work
|
||||||
worker.do(net, optimizer)
|
worker.do(net, optimizer)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user