Reformat $EVERYTHING

This commit is contained in:
2021-11-15 16:53:30 +01:00
parent 56f2aa7d5d
commit 43df77fb9b
32 changed files with 4171 additions and 3749 deletions
+33 -32
View File
@@ -11,28 +11,27 @@ import dataset
def get_data(n, row_from, row_to, train):
imsizes = [(256,384)]
focal_lengths = [160]
dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)
ims = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.uint8)
disps = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.float32)
for idx in range(n):
print(f'load sample {idx} train={train}')
sample = dset[idx]
ims[idx] = (sample['im0'][0,row_from:row_to] * 255).astype(np.uint8)
disps[idx] = sample['disp0'][0,row_from:row_to]
return ims, disps
imsizes = [(256, 384)]
focal_lengths = [160]
dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)
ims = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.uint8)
disps = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.float32)
for idx in range(n):
print(f'load sample {idx} train={train}')
sample = dset[idx]
ims[idx] = (sample['im0'][0, row_from:row_to] * 255).astype(np.uint8)
disps[idx] = sample['disp0'][0, row_from:row_to]
return ims, disps
params = hd.TrainParams(
n_trees=4,
max_tree_depth=,
n_test_split_functions=50,
n_test_thresholds=10,
n_test_samples=4096,
min_samples_to_split=16,
min_samples_for_leaf=8)
n_trees=4,
max_tree_depth=,
n_test_split_functions=50,
n_test_thresholds=10,
n_test_samples=4096,
min_samples_to_split=16,
min_samples_for_leaf=8)
n_disp_bins = 20
depth_switch = 0
@@ -45,21 +44,23 @@ n_test_samples = 32
train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)
for tree_depth in [8,10,12,14,16]:
depth_switch = tree_depth - 4
for tree_depth in [8, 10, 12, 14, 16]:
depth_switch = tree_depth - 4
prefix = f'td{tree_depth}_ds{depth_switch}'
prefix = Path(f'./forests/{prefix}/')
prefix.mkdir(parents=True, exist_ok=True)
prefix = f'td{tree_depth}_ds{depth_switch}'
prefix = Path(f'./forests/{prefix}/')
prefix.mkdir(parents=True, exist_ok=True)
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
np.save(str(prefix / 'ta.npy'), test_disps)
np.save(str(prefix / 'es.npy'), es)
np.save(str(prefix / 'ta.npy'), test_disps)
np.save(str(prefix / 'es.npy'), es)
# plt.figure();
# plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4);
# plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4);
# plt.show()
# plt.figure();
# plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4);
# plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4);
# plt.show()
+16 -21
View File
@@ -8,7 +8,6 @@ import os
this_dir = os.path.dirname(__file__)
extra_compile_args = ['-O3', '-std=c++11']
extra_link_args = []
@@ -22,24 +21,20 @@ library_dirs = []
libraries = ['m']
setup(
name="hyperdepth",
cmdclass= {'build_ext': build_ext},
ext_modules=[
Extension('hyperdepth',
sources,
extra_objects=extra_objects,
language='c++',
library_dirs=library_dirs,
libraries=libraries,
include_dirs=[
np.get_include(),
],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args
)
]
name="hyperdepth",
cmdclass={'build_ext': build_ext},
ext_modules=[
Extension('hyperdepth',
sources,
extra_objects=extra_objects,
language='c++',
library_dirs=library_dirs,
libraries=libraries,
include_dirs=[
np.get_include(),
],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args
)
]
)
+8 -5
View File
@@ -6,10 +6,13 @@ orig = cv2.imread('disp_orig.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
plt.figure()
plt.subplot(2,2,1); plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,2); plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,3); plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,4); plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
plt.subplot(2, 2, 1);
plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 2);
plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 3);
plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 4);
plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
plt.show()