This commit is contained in:
Yiyi Liao
2019-06-13 16:25:11 +02:00
parent 26157cbb80
commit f5e5c4bd3f
84 changed files with 31343 additions and 2 deletions
+46
View File
@@ -0,0 +1,46 @@
#include "hyperdepth.h"
// Evaluation driver: for every image row it loads the forest stored in
// "cforest_<row>.bin", runs inference on that row's samples and finally writes
// the input, target ("ta") and estimated ("es") maps as PNG files.
// NOTE(review): the extract_row_samples() argument order and the scalar use of
// argmax() below do not match the declarations visible elsewhere in this
// commit - confirm which version of hyperdepth.h this driver targets.
int main() {
  cv::Mat_<uint8_t> im = read_im(0);
  cv::Mat_<uint16_t> disp = read_disp(0);
  const int im_rows = im.rows;
  const int im_cols = im.cols;
  std::cout << im.rows << "/" << im.cols << std::endl;
  std::cout << disp.rows << "/" << disp.cols << std::endl;

  cv::Mat_<uint16_t> ta_disp(im_rows, im_cols);
  cv::Mat_<uint16_t> es_disp(im_rows, im_cols);
  int n_disp_bins = 16;

  for(int row = 0; row < im_rows; ++row) {
    // Samples of this row; the `false` flag is passed through unchanged.
    std::vector<TrainDatum> data;
    extract_row_samples(im, disp, row, data, false, n_disp_bins);

    // Each row has its own serialized forest.
    std::ostringstream forest_path;
    forest_path << "cforest_" << row << ".bin";
    BinarySerializationIn fin(forest_path.str());
    HDForest forest;
    forest.Load(fin);

    auto res = forest.inferencemt(data, 18);
    for(int col = 0; col < im_cols; ++col) {
      auto target = std::static_pointer_cast<ClassificationTarget>(data[col].target);
      const float ta = col - float(target->cl()) / n_disp_bins;
      // Only the estimate is clamped at zero; the target value is stored as is.
      const float es = std::max(0.f, col - float(res[col]->argmax()) / n_disp_bins);
      ta_disp(row, col) = int(ta * 16);
      es_disp(row, col) = int(es * 16);
    }
  }

  cv::imwrite("disp_orig.png", disp);
  cv::imwrite("disp_ta.png", ta_disp);
  cv::imwrite("disp_es.png", es_disp);
}
+287
View File
@@ -0,0 +1,287 @@
#include <sstream>
#include <iomanip>
#include "rf/forest.h"
#include "rf/spliteval.h"
class HyperdepthSplitEvaluator : public SplitEvaluator {
public:
HyperdepthSplitEvaluator(bool normalize, int n_classes, int n_disp_bins, int depth_switch)
: SplitEvaluator(normalize), n_classes_(n_classes), n_disp_bins_(n_disp_bins), depth_switch_(depth_switch) {}
virtual ~HyperdepthSplitEvaluator() {}
protected:
virtual float Purity(const std::vector<TrainDatum>& targets, int depth) const {
if(targets.size() == 0) return 0;
int n_classes = n_classes_;
if(depth >= depth_switch_) {
n_classes *= n_disp_bins_;
}
std::vector<int> ps;
ps.resize(n_classes, 0);
for(auto target : targets) {
auto ctarget = std::static_pointer_cast<ClassificationTarget>(target.optimize_target);
int cl = ctarget->cl();
if(depth < depth_switch_) {
cl /= n_disp_bins_;
}
ps[cl] += 1;
}
float h = 0;
for(int cl = 0; cl < n_classes; ++cl) {
float fi = float(ps[cl]) / float(targets.size());
if(fi > 0) {
h = h - fi * std::log(fi);
}
}
return h;
}
private:
int n_classes_;
int n_disp_bins_;
int depth_switch_;
};
class HyperdepthLeafFunction {
public:
HyperdepthLeafFunction() : n_classes_(-1) {}
HyperdepthLeafFunction(int n_classes) : n_classes_(n_classes) {}
virtual ~HyperdepthLeafFunction() {}
virtual std::shared_ptr<HyperdepthLeafFunction> Copy() const {
auto fcn = std::make_shared<HyperdepthLeafFunction>();
fcn->n_classes_ = n_classes_;
fcn->counts_.resize(counts_.size());
for(size_t idx = 0; idx < counts_.size(); ++idx) {
fcn->counts_[idx] = counts_[idx];
}
fcn->sum_counts_ = sum_counts_;
return fcn;
}
virtual std::shared_ptr<HyperdepthLeafFunction> Create(const std::vector<TrainDatum>& samples) {
auto stat = std::make_shared<HyperdepthLeafFunction>();
stat->counts_.resize(n_classes_, 0);
for(auto sample : samples) {
auto ctarget = std::static_pointer_cast<ClassificationTarget>(sample.target);
stat->counts_[ctarget->cl()] += 1;
}
stat->sum_counts_ = samples.size();
return stat;
}
virtual std::shared_ptr<HyperdepthLeafFunction> Reduce(const std::vector<std::shared_ptr<HyperdepthLeafFunction>>& fcns) const {
auto stat = std::make_shared<HyperdepthLeafFunction>();
auto cfcn0 = std::static_pointer_cast<HyperdepthLeafFunction>(fcns[0]);
stat->counts_.resize(cfcn0->counts_.size(), 0);
stat->sum_counts_ = 0;
for(auto fcn : fcns) {
auto cfcn = std::static_pointer_cast<HyperdepthLeafFunction>(fcn);
for(size_t cl = 0; cl < stat->counts_.size(); ++cl) {
stat->counts_[cl] += cfcn->counts_[cl];
}
stat->sum_counts_ += cfcn->sum_counts_;
}
return stat;
}
virtual std::tuple<int,int> argmax() const {
int max_idx = 0;
int max_count = counts_[0];
int max2_idx = -1;
int max2_count = -1;
for(size_t idx = 1; idx < counts_.size(); ++idx) {
if(counts_[idx] > max_count) {
max2_count = max_count;
max2_idx = max_idx;
max_count = counts_[idx];
max_idx = idx;
}
else if(counts_[idx] > max2_count) {
max2_count = counts_[idx];
max2_idx = idx;
}
}
return std::make_tuple(max_idx, max2_idx);
}
virtual std::vector<float> prob_vec() const {
std::vector<float> probs(counts_.size(), 0.f);
int sum = 0;
for(int cnt : counts_) {
sum += cnt;
}
for(size_t idx = 0; idx < counts_.size(); ++idx) {
probs[idx] = float(counts_[idx]) / sum;
}
return probs;
}
virtual void Save(SerializationOut& ar) const {
ar << n_classes_;
int n_counts = counts_.size();
ar << n_counts;
for(int idx = 0; idx < n_counts; ++idx) {
ar << counts_[idx];
}
ar << sum_counts_;
}
virtual void Load(SerializationIn& ar) {
ar >> n_classes_;
int n_counts;
ar >> n_counts;
counts_.resize(n_counts);
for(int idx = 0; idx < n_counts; ++idx) {
ar >> counts_[idx];
}
ar >> sum_counts_;
}
public:
int n_classes_;
std::vector<int> counts_;
int sum_counts_;
DISABLE_COPY_AND_ASSIGN(HyperdepthLeafFunction);
};
// Concrete building blocks of the hyperdepth forest.
using HDSplitFunctionT = SplitFunctionPixelDifference;
using HDLeafFunctionT = HyperdepthLeafFunction;
using HDSplitEvaluatorT = HyperdepthSplitEvaluator;
using HDForest = Forest<HDSplitFunctionT, HDLeafFunctionT>;
/// Non-owning view of a contiguous array addressed as
/// [sample][row][column] with extents nsamples x rows x cols.
/// The pointed-to memory must outlive the view.
template <typename T>
class Raw {
public:
  const T* raw;
  const int nsamples;
  const int rows;
  const int cols;

  Raw(const T* raw, int nsamples, int rows, int cols)
    : raw(raw), nsamples(nsamples), rows(rows), cols(cols) {}

  /// Returns the element of sample `n` at row `r`, column `c`.
  /// Indices are not range-checked.
  T operator()(int n, int r, int c) const {
    // Compute the flat offset in 64 bits: n * rows * cols exceeds the range
    // of int for large sample stacks.
    const long long offset = (static_cast<long long>(n) * rows + r) * cols + c;
    return raw[offset];
  }
};
/// Single-channel patch sample centred on pixel (rc, cc) of image `n` in a
/// Raw<uint8_t> stack. Only a reference to the stack is stored, so the stack
/// must outlive the sample.
class RawSample : public Sample {
public:
  RawSample(const Raw<uint8_t>& raw, int n, int rc, int cc, int patch_height, int patch_width)
    : Sample(1, patch_height, patch_width), raw(raw), n(n), rc(rc), cc(cc) {}

  /// Returns the pixel at patch coordinate (r, c); the channel argument is
  /// ignored. Coordinates that fall outside the image are clamped to its border.
  virtual float at(int ch, int r, int c) const {
    const int row_in_image = r + rc - height_ / 2;
    const int col_in_image = c + cc - width_ / 2;
    const int row_clamped = std::max(0, std::min(raw.rows-1, row_in_image));
    const int col_clamped = std::max(0, std::min(raw.cols-1, col_in_image));
    return raw(n, row_clamped, col_clamped);
  }

protected:
  const Raw<uint8_t>& raw;  // image stack the patch is read from
  int n;                    // index of the image in the stack
  int rc;                   // row of the patch centre
  int cc;                   // column of the patch centre
};
void extract_row_samples(const Raw<uint8_t>& im, const Raw<float>& disp, int row, int n_disp_bins, bool only_valid, std::vector<TrainDatum>& data) {
for(int n = 0; n < im.nsamples; ++n) {
for(int col = 0; col < im.cols; ++col) {
float d = disp(n, row, col);
float pos = col - d;
int cl = pos * n_disp_bins;
if((d < 0 || cl < 0) && only_valid) continue;
auto sample = std::make_shared<RawSample>(im, n, row, col, 32, 32);
auto target = std::make_shared<ClassificationTarget>(cl);
auto datum = TrainDatum(sample, target);
data.push_back(datum);
}
}
std::cout << "extracted " << data.size() << " train samples" << std::endl;
std::cout << "n_classes (" << im.cols << ") * n_disp_bins (" << n_disp_bins << ") = " << (im.cols * n_disp_bins) << std::endl;
}
void train(int row_from, int row_to, TrainParameters params, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, std::string forest_prefix) {
Raw<uint8_t> raw_ims(ims, n, h, w);
Raw<float> raw_disps(disps, n, h, w);
int n_classes = w;
auto gen_split_fcn = std::make_shared<HDSplitFunctionT>();
auto gen_leaf_fcn = std::make_shared<HDLeafFunctionT>(n_classes * n_disp_bins);
auto split_eval = std::make_shared<HDSplitEvaluatorT>(true, n_classes, n_disp_bins, depth_switch);
for(int row = row_from; row < row_to; ++row) {
std::cout << "train row " << row << std::endl;
std::vector<TrainDatum> data;
extract_row_samples(raw_ims, raw_disps, row, n_disp_bins, true, data);
TrainForestQueued<HDSplitFunctionT, HDLeafFunctionT, HDSplitEvaluatorT> train(params, gen_split_fcn, gen_leaf_fcn, split_eval, n_threads, true);
auto forest = train.Train(data, TrainType::TRAIN, nullptr);
std::ostringstream forest_path;
forest_path << forest_prefix << row << ".bin";
std::cout << "save forest of row " << row << " to " << forest_path.str() << std::endl;
BinarySerializationOut fout(forest_path.str());
forest->Save(fout);
}
}
/// Runs the per-row forests "<forest_prefix><row>.bin" for rows in
/// [row_from, row_to) and writes three floats per pixel into `out`.
///
/// `ims` and `disps` are contiguous n x h x w arrays; `out` is a contiguous
/// n x h x w x 3 array. For every pixel the channels are:
///   0: col - best_class / n_disp_bins
///   1: probability of the best class
///   2: absolute difference between the values derived from the best and the
///      second-best class (argmax() reports -1 as second index when there is
///      no second class)
/// Rows outside [row_from, row_to) are left untouched. `depth_switch` is not
/// used by this function.
void eval(int row_from, int row_to, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, std::string forest_prefix, float* out) {
  Raw<uint8_t> raw_ims(ims, n, h, w);
  Raw<float> raw_disps(disps, n, h, w);

  for(int row = row_from; row < row_to; ++row) {
    // Keep every column (only_valid = false) so that result index
    // nidx * w + col below addresses the matching sample.
    std::vector<TrainDatum> data;
    extract_row_samples(raw_ims, raw_disps, row, n_disp_bins, false, data);

    std::ostringstream forest_path;
    forest_path << forest_prefix << row << ".bin";
    std::cout << "eval row " << row << " - " << forest_path.str() << std::endl;

    BinarySerializationIn fin(forest_path.str());
    HDForest forest;
    forest.Load(fin);

    auto res = forest.inferencemt(data, n_threads);

    for(int nidx = 0; nidx < n; ++nidx) {
      for(int col = 0; col < w; ++col) {
        // Offsets are computed in 64 bits; n * h * w * 3 can exceed int.
        const long long sample_idx = static_cast<long long>(nidx) * w + col;
        const auto& fcn = res[sample_idx];

        int pos, pos2;
        std::tie(pos, pos2) = fcn->argmax();
        float disp = col - float(pos) / n_disp_bins;
        float disp2 = col - float(pos2) / n_disp_bins;
        float prob = fcn->prob_vec()[pos];

        const long long base = ((static_cast<long long>(nidx) * h + row) * w + col) * 3;
        out[base + 0] = disp;
        out[base + 1] = prob;
        out[base + 2] = std::abs(disp - disp2);
      }
    }
  }
}
+86
View File
@@ -0,0 +1,86 @@
cimport cython
import numpy as np
cimport numpy as np
from libc.stdlib cimport free, malloc
from libcpp cimport bool
from libcpp.string cimport string
from cpython cimport PyObject, Py_INCREF
CREATE_INIT = True # workaround, so cython builds a init function
np.import_array()
ctypedef unsigned char uint8_t
# C++ struct holding the forest training hyper-parameters (rf/train.h).
cdef extern from "rf/train.h":
    cdef cppclass TrainParameters:
        int n_trees
        int max_tree_depth
        int n_test_split_functions
        int n_test_thresholds
        int n_test_samples
        int min_samples_to_split
        int min_samples_for_leaf
        int print_node_info
        TrainParameters()
# Entry points implemented in hyperdepth.h. `ims` and `disps` are contiguous
# n x h x w buffers; `out` is a contiguous n x h x w x 3 float buffer.
cdef extern from "hyperdepth.h":
    void train(int row_from, int row_to, TrainParameters params,
               const uint8_t* ims, const float* disps, int n, int h, int w,
               int n_disp_bins, int depth_switch, int n_threads,
               string forest_prefix)
    void eval(int row_from, int row_to,
              const uint8_t* ims, const float* disps, int n, int h, int w,
              int n_disp_bins, int depth_switch, int n_threads,
              string forest_prefix, float* out)
cdef class TrainParams:
    """Python-side holder of a C++ TrainParameters struct."""
    cdef TrainParameters params

    def __cinit__(self, int n_trees=6, int max_tree_depth=8, int n_test_split_functions=50, int n_test_thresholds=10, int n_test_samples=4096, int min_samples_to_split=16, int min_samples_for_leaf=8, int print_node_info=100):
        # Copy every keyword argument into the wrapped struct.
        self.params.n_trees = n_trees
        self.params.max_tree_depth = max_tree_depth
        self.params.n_test_split_functions = n_test_split_functions
        self.params.n_test_thresholds = n_test_thresholds
        self.params.n_test_samples = n_test_samples
        self.params.min_samples_to_split = min_samples_to_split
        self.params.min_samples_for_leaf = min_samples_for_leaf
        self.params.print_node_info = print_node_info

    def __str__(self):
        # print_node_info is not part of the textual summary.
        return f'n_trees={self.params.n_trees}, max_tree_depth={self.params.max_tree_depth}, n_test_split_functions={self.params.n_test_split_functions}, n_test_thresholds={self.params.n_test_thresholds}, n_test_samples={self.params.n_test_samples}, min_samples_to_split={self.params.min_samples_to_split}, min_samples_for_leaf={self.params.min_samples_for_leaf}'
def train_forest(TrainParams params, uint8_t[:,:,::1] ims, float[:,:,::1] disps, int n_disp_bins=10, int depth_switch=0, int n_threads=18, str forest_prefix='forest', int row_from=-1, int row_to=-1):
    """Train one forest per image row and save it under `forest_prefix`.

    `ims` and `disps` are C-contiguous (n, h, w) arrays of identical shape.
    A negative `row_from` selects row 0; a negative `row_to`, or one larger
    than h, selects h. Raises Exception when the shapes differ.
    """
    cdef int n = ims.shape[0]
    cdef int h = ims.shape[1]
    cdef int w = ims.shape[2]

    if n != disps.shape[0] or h != disps.shape[1] or w != disps.shape[2]:
        raise Exception('ims.shape != disps.shape')

    # Clamp the requested row range to the image.
    if row_from < 0:
        row_from = 0
    if row_to < 0 or row_to > h:
        row_to = h

    train(row_from, row_to, params.params, &ims[0,0,0], &disps[0,0,0], n, h, w, n_disp_bins, depth_switch, n_threads, forest_prefix.encode())
def eval_forest(uint8_t[:,:,::1] ims, float[:,:,::1] disps, int n_disp_bins=10, int depth_switch=0, int n_threads=18, str forest_prefix='forest', int row_from=-1, int row_to=-1):
    """Evaluate the per-row forests stored under `forest_prefix`.

    `ims` and `disps` are C-contiguous (n, h, w) arrays of identical shape.
    A negative `row_from` selects row 0; a negative `row_to`, or one larger
    than h, selects h. Raises Exception when the shapes differ.

    Returns a float32 array of shape (n, h, w, 3). It is allocated with
    np.empty, so rows outside the evaluated range hold unspecified values.
    """
    cdef int n = ims.shape[0]
    cdef int h = ims.shape[1]
    cdef int w = ims.shape[2]

    if n != disps.shape[0] or h != disps.shape[1] or w != disps.shape[2]:
        raise Exception('ims.shape != disps.shape')

    # Clamp the requested row range to the image.
    if row_from < 0:
        row_from = 0
    if row_to < 0 or row_to > h:
        row_to = h

    out = np.empty((n, h, w, 3), dtype=np.float32)
    cdef float[:,:,:,::1] out_view = out
    eval(row_from, row_to, &ims[0,0,0], &disps[0,0,0], n, h, w, n_disp_bins, depth_switch, n_threads, forest_prefix.encode(), &out_view[0,0,0,0])
    return out
+65
View File
@@ -0,0 +1,65 @@
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
import sys
import hyperdepth as hd
sys.path.append('../')
import dataset
def get_data(n, row_from, row_to, train):
    """Load rows [row_from, row_to) of the first `n` dataset samples.

    Returns a tuple (ims, disps): a uint8 and a float32 array, both of shape
    (n, row_to - row_from, 384). Images are scaled by 255 before the uint8
    conversion.
    """
    imsizes = [(256, 384)]
    focal_lengths = [160]
    dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)

    n_rows = row_to - row_from
    width = imsizes[0][1]
    ims = np.empty((n, n_rows, width), dtype=np.uint8)
    disps = np.empty((n, n_rows, width), dtype=np.float32)

    for idx in range(n):
        print(f'load sample {idx} train={train}')
        sample = dset[idx]
        # Index 0 on the first axis is taken from both entries.
        im_rows = sample['im0'][0, row_from:row_to]
        ims[idx] = (im_rows * 255).astype(np.uint8)
        disps[idx] = sample['disp0'][0, row_from:row_to]

    return ims, disps
# Sweep over the maximum tree depth: for each depth a set of per-row forests
# is trained, evaluated on the test crop, and the targets/estimates are saved
# under ./forests/td<depth>_ds<depth_switch>/.
n_disp_bins = 20
depth_switch = 0

row_from = 100
row_to = 108
n_train_samples = 1024
n_test_samples = 32

train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)

for tree_depth in [8, 10, 12, 14, 16]:
    depth_switch = tree_depth - 4

    # The parameters are built per run so that max_tree_depth follows the
    # swept depth (it was previously left blank, which is a syntax error, and
    # was never tied to the loop variable).
    params = hd.TrainParams(
        n_trees=4,
        max_tree_depth=tree_depth,
        n_test_split_functions=50,
        n_test_thresholds=10,
        n_test_samples=4096,
        min_samples_to_split=16,
        min_samples_for_leaf=8)

    prefix = f'td{tree_depth}_ds{depth_switch}'
    prefix = Path(f'./forests/{prefix}/')
    prefix.mkdir(parents=True, exist_ok=True)

    hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
    es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))

    np.save(str(prefix / 'ta.npy'), test_disps)
    np.save(str(prefix / 'es.npy'), es)

    # plt.figure();
    # plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4);
    # plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4);
    # plt.show()
+18
View File
@@ -0,0 +1,18 @@
#ifndef COMMON_H
#define COMMON_H
#include <cmath>
#include <algorithm>
#if defined(_OPENMP)
#include <omp.h>
#endif
// Deletes the copy constructor and the copy assignment operator of
// `classname`. The expansion starts with `private:`, so every member declared
// after the macro in the same class is private as well; place it at the end
// of the class definition.
#define DISABLE_COPY_AND_ASSIGN(classname) \
private:\
classname(const classname&) = delete;\
classname& operator=(const classname&) = delete;
#endif
+72
View File
@@ -0,0 +1,72 @@
#pragma once
#include <vector>
/// Abstract multi-channel 2-D sample. Derived classes provide the element
/// access; the extents are fixed at construction.
class Sample {
public:
  Sample(int channels, int height, int width)
    : channels_(channels), height_(height), width_(width) {}

  virtual ~Sample() = default;

  /// Value at channel `c`, row `h`, column `w`.
  virtual float at(int c, int h, int w) const = 0;

  /// Call syntax for at().
  virtual float operator()(int c, int h, int w) const { return this->at(c, h, w); }

  virtual int channels() const { return channels_; }
  virtual int height() const { return height_; }
  virtual int width() const { return width_; }

protected:
  int channels_;
  int height_;
  int width_;
};

using SamplePtr = std::shared_ptr<Sample>;
/// Empty polymorphic base class of all training targets.
class Target {
public:
  Target() = default;
  virtual ~Target() = default;
};

using TargetPtr = std::shared_ptr<Target>;
using VecTargetPtr = std::vector<TargetPtr>;
using VecPtrTargetPtr = std::shared_ptr<VecTargetPtr>;
/// Target that stores a single integer class label.
class ClassificationTarget : public Target {
public:
  ClassificationTarget(int cl) : cl_(cl) {}
  virtual ~ClassificationTarget() = default;

  /// Returns the stored class label.
  int cl() const { return cl_; }

private:
  int cl_;
};

using ClassificationTargetPtr = std::shared_ptr<ClassificationTarget>;
/// One training example: the input sample, its target, and the target used
/// when optimizing splits.
struct TrainDatum {
  SamplePtr sample;
  TargetPtr target;
  TargetPtr optimize_target;

  /// All three pointers are null.
  TrainDatum() {}

  /// The optimization target is the target itself.
  TrainDatum(SamplePtr sample, TargetPtr target)
    : sample(sample), target(target), optimize_target(target) {}

  /// Uses a separate target for split optimization.
  TrainDatum(SamplePtr sample, TargetPtr target, TargetPtr optimize_target)
    : sample(sample), target(target), optimize_target(optimize_target) {}
};
+92
View File
@@ -0,0 +1,92 @@
#pragma once
#include "tree.h"
/// Ensemble of trees that share a split-function and a leaf-function type.
/// Predictions of the individual trees are combined with
/// LeafFunctionT::Reduce.
template <typename SplitFunctionT, typename LeafFunctionT>
class Forest {
public:
  Forest() {}
  virtual ~Forest() {}

  /// Runs `sample` through every tree and returns the reduced leaf function.
  /// Returns a null pointer when the forest holds no trees.
  std::shared_ptr<LeafFunctionT> inferencest(const SamplePtr& sample) const {
    if(trees_.empty()) {
      return std::shared_ptr<LeafFunctionT>();
    }

    //inference of individual trees
    std::vector<std::shared_ptr<LeafFunctionT>> fcns;
    fcns.reserve(trees_.size());
    for(const auto& tree : trees_) {
      fcns.push_back(tree->inference(sample));
    }

    //combine tree fcns/results and collect all results
    return fcns[0]->Reduce(fcns);
  }

  /// Runs inferencest() on every sample; result i belongs to samples[i].
  /// With OpenMP enabled the loop uses `n_threads` threads.
  std::vector<std::shared_ptr<LeafFunctionT>> inferencemt(const std::vector<SamplePtr>& samples, int n_threads) const {
    std::vector<std::shared_ptr<LeafFunctionT>> targets(samples.size());
    SetNumThreads(n_threads);
    #pragma omp parallel for
    for(size_t sample_idx = 0; sample_idx < samples.size(); ++sample_idx) {
      targets[sample_idx] = inferencest(samples[sample_idx]);
    }
    return targets;
  }

  /// Same as above, using the `sample` member of every TrainDatum.
  std::vector<std::shared_ptr<LeafFunctionT>> inferencemt(const std::vector<TrainDatum>& samples, int n_threads) const {
    std::vector<std::shared_ptr<LeafFunctionT>> targets(samples.size());
    SetNumThreads(n_threads);
    #pragma omp parallel for
    for(size_t sample_idx = 0; sample_idx < samples.size(); ++sample_idx) {
      targets[sample_idx] = inferencest(samples[sample_idx].sample);
    }
    return targets;
  }

  void AddTree(std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>> tree) {
    trees_.push_back(tree);
  }

  size_t trees_size() const { return trees_.size(); }
  // TreePtr trees(int idx) const { return trees_[idx]; }

  /// Writes the number of trees followed by every tree.
  virtual void Save(SerializationOut& ar) const {
    size_t n_trees = trees_.size();
    std::cout << "[DEBUG] write " << n_trees << " trees" << std::endl;
    ar << n_trees;

    std::cout << "[Forest][write] write number of trees " << n_trees << std::endl;
    for(size_t tree_idx = 0; tree_idx < trees_.size(); ++tree_idx) {
      std::cout << "[Forest][write] write tree nb. " << tree_idx << std::endl;
      trees_[tree_idx]->Save(ar);
    }
  }

  /// Replaces the current trees with the ones read from `ar`.
  virtual void Load(SerializationIn& ar) {
    // Start from 0 so that a failed read leaves an empty forest instead of
    // looping over an indeterminate count.
    size_t n_trees = 0;
    ar >> n_trees;
    std::cout << "[Forest][read] nTrees: " << n_trees << std::endl;

    trees_.clear();
    for(size_t i = 0; i < n_trees; ++i) {
      std::cout << "[Forest][read] read tree " << (i+1) << " of " << n_trees << " - " << std::endl;
      auto tree = std::make_shared<Tree<SplitFunctionT, LeafFunctionT>>();
      tree->Load(ar);
      trees_.push_back(tree);
      std::cout << "[Forest][read] finished read tree " << (i+1) << " of " << n_trees << std::endl;
    }
  }

private:
  // omp.h is only included when _OPENMP is defined, so the runtime call has
  // to be guarded the same way to keep non-OpenMP builds compiling.
  static void SetNumThreads(int n_threads) {
#if defined(_OPENMP)
    omp_set_num_threads(n_threads);
#else
    (void)n_threads;
#endif
  }

  std::vector<std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>>> trees_;
};
+99
View File
@@ -0,0 +1,99 @@
#pragma once
#include <iostream>
#include "common.h"
#include "data.h"
class ClassProbabilitiesLeafFunction {
public:
ClassProbabilitiesLeafFunction() : n_classes_(-1) {}
ClassProbabilitiesLeafFunction(int n_classes) : n_classes_(n_classes) {}
virtual ~ClassProbabilitiesLeafFunction() {}
virtual std::shared_ptr<ClassProbabilitiesLeafFunction> Copy() const {
auto fcn = std::make_shared<ClassProbabilitiesLeafFunction>();
fcn->n_classes_ = n_classes_;
fcn->counts_.resize(counts_.size());
for(size_t idx = 0; idx < counts_.size(); ++idx) {
fcn->counts_[idx] = counts_[idx];
}
fcn->sum_counts_ = sum_counts_;
return fcn;
}
virtual std::shared_ptr<ClassProbabilitiesLeafFunction> Create(const std::vector<TrainDatum>& samples) {
auto stat = std::make_shared<ClassProbabilitiesLeafFunction>();
stat->counts_.resize(n_classes_, 0);
for(auto sample : samples) {
auto ctarget = std::static_pointer_cast<ClassificationTarget>(sample.target);
stat->counts_[ctarget->cl()] += 1;
}
stat->sum_counts_ = samples.size();
return stat;
}
virtual std::shared_ptr<ClassProbabilitiesLeafFunction> Reduce(const std::vector<std::shared_ptr<ClassProbabilitiesLeafFunction>>& fcns) const {
auto stat = std::make_shared<ClassProbabilitiesLeafFunction>();
auto cfcn0 = std::static_pointer_cast<ClassProbabilitiesLeafFunction>(fcns[0]);
stat->counts_.resize(cfcn0->counts_.size(), 0);
stat->sum_counts_ = 0;
for(auto fcn : fcns) {
auto cfcn = std::static_pointer_cast<ClassProbabilitiesLeafFunction>(fcn);
for(size_t cl = 0; cl < stat->counts_.size(); ++cl) {
stat->counts_[cl] += cfcn->counts_[cl];
}
stat->sum_counts_ += cfcn->sum_counts_;
}
return stat;
}
virtual int argmax() const {
int max_idx = 0;
int max_count = counts_[0];
for(size_t idx = 1; idx < counts_.size(); ++idx) {
if(counts_[idx] > max_count) {
max_count = counts_[idx];
max_idx = idx;
}
}
return max_idx;
}
virtual void Save(SerializationOut& ar) const {
ar << n_classes_;
int n_counts = counts_.size();
ar << n_counts;
for(int idx = 0; idx < n_counts; ++idx) {
ar << counts_[idx];
}
ar << sum_counts_;
}
virtual void Load(SerializationIn& ar) {
ar >> n_classes_;
int n_counts;
ar >> n_counts;
counts_.resize(n_counts);
for(int idx = 0; idx < n_counts; ++idx) {
ar >> counts_[idx];
}
ar >> sum_counts_;
}
public:
int n_classes_;
std::vector<int> counts_;
int sum_counts_;
DISABLE_COPY_AND_ASSIGN(ClassProbabilitiesLeafFunction);
};
+158
View File
@@ -0,0 +1,158 @@
#pragma once
#include <memory>
#include "serialization.h"
#include "leaffcn.h"
#include "splitfcn.h"
/// Abstract tree node. Concrete nodes report an integer type tag that is
/// stored next to them when a tree is serialized.
class Node {
public:
  Node() = default;
  virtual ~Node() = default;

  /// Returns a copy of this node.
  virtual std::shared_ptr<Node> Copy() const = 0;

  /// Type tag of the concrete node class.
  virtual int type() const = 0;

  virtual void Save(SerializationOut& ar) const = 0;
  virtual void Load(SerializationIn& ar) = 0;
};

using NodePtr = std::shared_ptr<Node>;
/// Terminal tree node that owns a leaf function.
template <typename LeafFunctionT>
class LeafNode : public Node {
public:
  static const int TYPE = 0;

  LeafNode() {}
  LeafNode(std::shared_ptr<LeafFunctionT> leaf_node_fcn) : leaf_node_fcn_(leaf_node_fcn) {}
  virtual ~LeafNode() {}

  /// Returns a new leaf holding a copy of the leaf function; the leaf
  /// function must be set.
  virtual NodePtr Copy() const {
    auto clone = std::make_shared<LeafNode>();
    clone->leaf_node_fcn_ = leaf_node_fcn_->Copy();
    return clone;
  }

  /// Writes only the leaf function; the type tag is written by the parent.
  virtual void Save(SerializationOut& ar) const {
    leaf_node_fcn_->Save(ar);
  }

  /// Replaces the leaf function with a freshly loaded one.
  virtual void Load(SerializationIn& ar) {
    auto loaded = std::make_shared<LeafFunctionT>();
    leaf_node_fcn_ = loaded;
    leaf_node_fcn_->Load(ar);
  }

  virtual int type() const { return TYPE; }

  std::shared_ptr<LeafFunctionT> leaf_node_fcn() const { return leaf_node_fcn_; }

private:
  std::shared_ptr<LeafFunctionT> leaf_node_fcn_;

DISABLE_COPY_AND_ASSIGN(LeafNode);
};
/// Inner tree node: a split function plus a left and a right child.
template <typename SplitFunctionT, typename LeafFunctionT>
class SplitNode : public Node {
public:
  static const int TYPE = 1;

  SplitNode() {}
  SplitNode(NodePtr left, NodePtr right, std::shared_ptr<SplitFunctionT> split_fcn)
    : left_(left), right_(right), split_fcn_(split_fcn) {}
  virtual ~SplitNode() {}

  /// Copies the whole subtree; both children and the split function must be set.
  virtual std::shared_ptr<Node> Copy() const {
    auto clone = std::make_shared<SplitNode>();
    clone->left_ = left_->Copy();
    clone->right_ = right_->Copy();
    clone->split_fcn_ = split_fcn_->Copy();
    return clone;
  }

  /// Forwards to the split function's Split().
  bool Split(SamplePtr sample) {
    return split_fcn_->Split(sample);
  }

  /// Writes the split function, then for each child (left first) its type
  /// tag followed by the child itself.
  virtual void Save(SerializationOut& ar) const {
    split_fcn_->Save(ar);

    //left
    const int left_type = left_->type();
    ar << left_type;
    left_->Save(ar);

    //right
    const int right_type = right_->type();
    ar << right_type;
    right_->Save(ar);
  }

  virtual void Load(SerializationIn& ar);

  virtual int type() const { return TYPE; }

  NodePtr left() const { return left_; }
  NodePtr right() const { return right_; }
  std::shared_ptr<SplitFunctionT> split_fcn() const { return split_fcn_; }

  void set_left(NodePtr left) { left_ = left; }
  void set_right(NodePtr right) { right_ = right; }
  void set_split_fcn(std::shared_ptr<SplitFunctionT> split_fcn) { split_fcn_ = split_fcn; }

public:
  NodePtr left_;
  NodePtr right_;
  std::shared_ptr<SplitFunctionT> split_fcn_;

DISABLE_COPY_AND_ASSIGN(SplitNode);
};
/// Creates an empty node for the given type tag (LeafNode::TYPE or
/// SplitNode::TYPE). Any other tag prints an error and terminates the
/// process with exit(-1).
template <typename SplitFunctionT, typename LeafFunctionT>
NodePtr MakeNode(int type) {
  const bool is_leaf = (type == LeafNode<LeafFunctionT>::TYPE);
  const bool is_split = (type == SplitNode<SplitFunctionT, LeafFunctionT>::TYPE);

  if(!is_leaf && !is_split) {
    std::cout << "[ERROR] unknown node type" << std::endl;
    exit(-1);
  }

  NodePtr node;
  if(is_leaf) {
    node = std::make_shared<LeafNode<LeafFunctionT>>();
  }
  else {
    node = std::make_shared<SplitNode<SplitFunctionT, LeafFunctionT>>();
  }
  return node;
}
/// Reads the split function and then, for each child (left first), its type
/// tag followed by the child itself. Children are loaded recursively.
template <typename SplitFunctionT, typename LeafFunctionT>
void SplitNode<SplitFunctionT, LeafFunctionT>::Load(SerializationIn& ar) {
  split_fcn_ = std::make_shared<SplitFunctionT>();
  split_fcn_->Load(ar);

  // The tags start at -1, which matches no node type: if a read fails (for
  // example on a truncated file) MakeNode reports the error instead of
  // acting on an indeterminate value.

  //left
  int left_type = -1;
  ar >> left_type;
  left_ = MakeNode<SplitFunctionT, LeafFunctionT>(left_type);
  left_->Load(ar);

  //right
  int right_type = -1;
  ar >> right_type;
  right_ = MakeNode<SplitFunctionT, LeafFunctionT>(right_type);
  right_->Load(ar);
}
+256
View File
@@ -0,0 +1,256 @@
#pragma once
#include <fstream>
/// Abstract output archive with one insertion operator per built-in
/// arithmetic type. Concrete archives decide the on-disk format.
class SerializationOut {
public:
  /// Remembers `path`; opening the file is left to the derived class.
  SerializationOut(const std::string& path) : path_(path) {}
  virtual ~SerializationOut() {}

  virtual SerializationOut& operator<<(const bool& v) = 0;
  virtual SerializationOut& operator<<(const char& v) = 0;
  virtual SerializationOut& operator<<(const int& v) = 0;
  virtual SerializationOut& operator<<(const unsigned int& v) = 0;
  virtual SerializationOut& operator<<(const long int& v) = 0;
  virtual SerializationOut& operator<<(const unsigned long int& v) = 0;
  virtual SerializationOut& operator<<(const long long int& v) = 0;
  virtual SerializationOut& operator<<(const unsigned long long int& v) = 0;
  virtual SerializationOut& operator<<(const float& v) = 0;
  virtual SerializationOut& operator<<(const double& v) = 0;

protected:
  // Stored by value: a reference member would dangle as soon as the archive
  // is constructed from a temporary string.
  const std::string path_;
};
/// Abstract input archive with one extraction operator per built-in
/// arithmetic type. Concrete archives decide the on-disk format.
class SerializationIn {
public:
  /// Remembers `path`; opening the file is left to the derived class.
  SerializationIn(const std::string& path) : path_(path) {}
  virtual ~SerializationIn() {}

  virtual SerializationIn& operator>>(bool& v) = 0;
  virtual SerializationIn& operator>>(char& v) = 0;
  virtual SerializationIn& operator>>(int& v) = 0;
  virtual SerializationIn& operator>>(unsigned int& v) = 0;
  virtual SerializationIn& operator>>(long int& v) = 0;
  virtual SerializationIn& operator>>(unsigned long int& v) = 0;
  virtual SerializationIn& operator>>(long long int& v) = 0;
  virtual SerializationIn& operator>>(unsigned long long int& v) = 0;
  virtual SerializationIn& operator>>(float& v) = 0;
  virtual SerializationIn& operator>>(double& v) = 0;

protected:
  // Stored by value: a reference member would dangle as soon as the archive
  // is constructed from a temporary string.
  const std::string path_;
};
/// Output archive that writes every value on its own line using the
/// stream's default text formatting.
class TextSerializationOut : public SerializationOut {
public:
  /// Opens `path` for writing. The stream state is not checked.
  TextSerializationOut(const std::string& path) : SerializationOut(path),
    f_(path.c_str()) {}
  virtual ~TextSerializationOut() {
    f_.close();
  }

  virtual SerializationOut& operator<<(const bool& v) { return WriteLine(v); }
  virtual SerializationOut& operator<<(const char& v) { return WriteLine(v); }
  virtual SerializationOut& operator<<(const int& v) { return WriteLine(v); }
  virtual SerializationOut& operator<<(const unsigned int& v) { return WriteLine(v); }
  virtual SerializationOut& operator<<(const long int& v) { return WriteLine(v); }
  virtual SerializationOut& operator<<(const unsigned long int& v) { return WriteLine(v); }
  virtual SerializationOut& operator<<(const long long int& v) { return WriteLine(v); }
  virtual SerializationOut& operator<<(const unsigned long long int& v) { return WriteLine(v); }
  virtual SerializationOut& operator<<(const float& v) { return WriteLine(v); }
  virtual SerializationOut& operator<<(const double& v) { return WriteLine(v); }

protected:
  std::ofstream f_;

private:
  // Writes the value followed by a newline. '\n' is used instead of
  // std::endl so the file is not flushed after every single value; the bytes
  // written are the same and the destructor's close() flushes the rest.
  template <typename T>
  SerializationOut& WriteLine(const T& v) {
    f_ << v << '\n';
    return (*this);
  }
};
/// Input archive that reads values with the stream's formatted extraction.
class TextSerializationIn : public SerializationIn {
public:
  /// Opens `path` for reading. The stream state is not checked.
  TextSerializationIn(const std::string& path) : SerializationIn(path),
    f_(path.c_str()) {}
  virtual ~TextSerializationIn() {
    f_.close();
  }

  virtual SerializationIn& operator>>(bool& v) { return Read(v); }
  virtual SerializationIn& operator>>(char& v) { return Read(v); }
  virtual SerializationIn& operator>>(int& v) { return Read(v); }
  virtual SerializationIn& operator>>(unsigned int& v) { return Read(v); }
  virtual SerializationIn& operator>>(long int& v) { return Read(v); }
  virtual SerializationIn& operator>>(unsigned long int& v) { return Read(v); }
  virtual SerializationIn& operator>>(long long int& v) { return Read(v); }
  virtual SerializationIn& operator>>(unsigned long long int& v) { return Read(v); }
  virtual SerializationIn& operator>>(float& v) { return Read(v); }
  virtual SerializationIn& operator>>(double& v) { return Read(v); }

protected:
  std::ifstream f_;

private:
  // Shared body of all extraction operators.
  template <typename T>
  SerializationIn& Read(T& v) {
    f_ >> v;
    return (*this);
  }
};
/// Output archive that writes the object representation of every value,
/// sizeof(type) bytes each, with no separators.
class BinarySerializationOut : public SerializationOut {
public:
  /// Opens `path` in binary mode. The stream state is not checked.
  BinarySerializationOut(const std::string& path) : SerializationOut(path),
    f_(path.c_str(), std::ios::binary) {}
  virtual ~BinarySerializationOut() {
    f_.close();
  }

  virtual SerializationOut& operator<<(const bool& v) { return Put(v); }
  virtual SerializationOut& operator<<(const char& v) { return Put(v); }
  virtual SerializationOut& operator<<(const int& v) { return Put(v); }
  virtual SerializationOut& operator<<(const unsigned int& v) { return Put(v); }
  virtual SerializationOut& operator<<(const long int& v) { return Put(v); }
  virtual SerializationOut& operator<<(const unsigned long int& v) { return Put(v); }
  virtual SerializationOut& operator<<(const long long int& v) { return Put(v); }
  virtual SerializationOut& operator<<(const unsigned long long int& v) { return Put(v); }
  virtual SerializationOut& operator<<(const float& v) { return Put(v); }
  virtual SerializationOut& operator<<(const double& v) { return Put(v); }

protected:
  std::ofstream f_;

private:
  // Shared body of all insertion operators: writes the bytes of `v`.
  template <typename T>
  SerializationOut& Put(const T& v) {
    f_.write(reinterpret_cast<const char*>(&v), sizeof(T));
    return (*this);
  }
};
/// Input archive that reads sizeof(type) bytes straight into every value.
class BinarySerializationIn : public SerializationIn {
public:
  /// Opens `path` in binary mode. The stream state is not checked.
  BinarySerializationIn(const std::string& path) : SerializationIn(path),
    f_(path.c_str(), std::ios::binary) {}
  virtual ~BinarySerializationIn() {
    f_.close();
  }

  virtual SerializationIn& operator>>(bool& v) { return Get(v); }
  virtual SerializationIn& operator>>(char& v) { return Get(v); }
  virtual SerializationIn& operator>>(int& v) { return Get(v); }
  virtual SerializationIn& operator>>(unsigned int& v) { return Get(v); }
  virtual SerializationIn& operator>>(long int& v) { return Get(v); }
  virtual SerializationIn& operator>>(unsigned long int& v) { return Get(v); }
  virtual SerializationIn& operator>>(long long int& v) { return Get(v); }
  virtual SerializationIn& operator>>(unsigned long long int& v) { return Get(v); }
  virtual SerializationIn& operator>>(float& v) { return Get(v); }
  virtual SerializationIn& operator>>(double& v) { return Get(v); }

protected:
  std::ifstream f_;

private:
  // Shared body of all extraction operators: fills `v` with raw bytes.
  template <typename T>
  SerializationIn& Get(T& v) {
    f_.read(reinterpret_cast<char*>(&v), sizeof(T));
    return (*this);
  }
};
+71
View File
@@ -0,0 +1,71 @@
#pragma once
/// Base class for scoring a candidate split from the two sample sets it
/// produces. Derived classes supply the per-set Purity().
class SplitEvaluator {
public:
  SplitEvaluator(bool normalize)
    : normalize_(normalize) {}

  virtual ~SplitEvaluator() {}

  /// Returns the sum of the left and right Purity() values. With
  /// normalization enabled each term is weighted by its share of the total
  /// sample count; otherwise both weights are 1.
  virtual float Eval(const std::vector<TrainDatum>& lefttargets, const std::vector<TrainDatum>& righttargets, int depth) const {
    const float purity_left = Purity(lefttargets, depth);
    const float purity_right = Purity(righttargets, depth);

    float weight_left = 1.0;
    float weight_right = 1.0;
    if(normalize_) {
      const unsigned int n_left = lefttargets.size();
      const unsigned int n_right = righttargets.size();
      const unsigned int n_total = n_left + n_right;
      weight_left = float(n_left) / float(n_total);
      weight_right = float(n_right) / float(n_total);
    }

    return purity_left * weight_left + purity_right * weight_right;
  }

protected:
  /// Score of a single sample set at tree depth `depth`.
  virtual float Purity(const std::vector<TrainDatum>& targets, int depth) const = 0;

protected:
  bool normalize_;
};
class ClassificationIGSplitEvaluator : public SplitEvaluator {
public:
ClassificationIGSplitEvaluator(bool normalize, int n_classes)
: SplitEvaluator(normalize), n_classes_(n_classes) {}
virtual ~ClassificationIGSplitEvaluator() {}
protected:
virtual float Purity(const std::vector<TrainDatum>& targets, int depth) const {
if(targets.size() == 0) return 0;
std::vector<int> ps;
ps.resize(n_classes_, 0);
for(auto target : targets) {
auto ctarget = std::static_pointer_cast<ClassificationTarget>(target.optimize_target);
ps[ctarget->cl()] += 1;
}
float h = 0;
for(int cl = 0; cl < n_classes_; ++cl) {
float fi = float(ps[cl]) / float(targets.size());
if(fi > 0) {
h = h - fi * std::log(fi);
}
}
return h;
}
private:
int n_classes_;
};
+106
View File
@@ -0,0 +1,106 @@
#pragma once
#include <random>
// Base class of a binary split test: a scalar response computed from a sample
// is compared against a threshold.
class SplitFunction {
public:
  // The threshold starts at 0 so that Split(), Save() and threshold() never
  // read an indeterminate value before set_threshold()/Load() is called.
  SplitFunction() : threshold_(0) {}

  virtual ~SplitFunction() {}

  // Scalar response of this split function for the given sample.
  virtual float Compute(SamplePtr sample) const = 0;

  // True when the response is strictly below the threshold.
  virtual bool Split(SamplePtr sample) const {
    return Compute(sample) < threshold_;
  }

  // Writes the threshold to the archive.
  virtual void Save(SerializationOut& ar) const {
    ar << threshold_;
  }

  // Reads the threshold from the archive.
  virtual void Load(SerializationIn& ar) {
    ar >> threshold_;
  }

  virtual float threshold() const { return threshold_; }
  virtual void set_threshold(float threshold) { threshold_ = threshold; }

protected:
  float threshold_;
};
// Split function whose response is the difference of two sample entries,
// addressed as (channel, row, column): sample(c0,h0,w0) - sample(c1,h1,w1).
class SplitFunctionPixelDifference : public SplitFunction {
public:
  // All coordinates start at 0 so that Compute(), Save() and Copy() on a
  // default-constructed instance never read indeterminate members.
  SplitFunctionPixelDifference()
    : c0_(0), c1_(0), h0_(0), h1_(0), w0_(0), w1_(0) {}

  virtual ~SplitFunctionPixelDifference() {}

  // Returns a new instance with the same threshold and coordinates.
  virtual std::shared_ptr<SplitFunctionPixelDifference> Copy() const {
    auto split_fcn = std::make_shared<SplitFunctionPixelDifference>();
    split_fcn->threshold_ = threshold_;
    split_fcn->c0_ = c0_;
    split_fcn->c1_ = c1_;
    split_fcn->h0_ = h0_;
    split_fcn->h1_ = h1_;
    split_fcn->w0_ = w0_;
    split_fcn->w1_ = w1_;
    return split_fcn;
  }

  // Returns a new instance whose two coordinates are drawn uniformly from the
  // extent reported by sample (channels/height/width). The threshold of the
  // result is not chosen here; callers set it via set_threshold().
  virtual std::shared_ptr<SplitFunctionPixelDifference> Generate(std::mt19937& rng, const SamplePtr sample) const {
    auto split_fcn = std::make_shared<SplitFunctionPixelDifference>();

    // The draw order (c0, c1, h0, h1, w0, w1) determines how rng is consumed.
    std::uniform_int_distribution<int> cdist(0, sample->channels()-1);
    split_fcn->c0_ = cdist(rng);
    split_fcn->c1_ = cdist(rng);

    std::uniform_int_distribution<int> hdist(0, sample->height()-1);
    split_fcn->h0_ = hdist(rng);
    split_fcn->h1_ = hdist(rng);

    std::uniform_int_distribution<int> wdist(0, sample->width()-1);
    split_fcn->w0_ = wdist(rng);
    split_fcn->w1_ = wdist(rng);

    return split_fcn;
  }

  // Difference between the two addressed sample entries.
  virtual float Compute(SamplePtr sample) const {
    return (*sample)(c0_, h0_, w0_) - (*sample)(c1_, h1_, w1_);
  }

  // Writes the base-class state followed by c0, c1, h0, h1, w0, w1.
  virtual void Save(SerializationOut& ar) const {
    SplitFunction::Save(ar);
    ar << c0_;
    ar << c1_;
    ar << h0_;
    ar << h1_;
    ar << w0_;
    ar << w1_;
  }

  // Reads the state in the same order in which Save() writes it.
  virtual void Load(SerializationIn& ar) {
    SplitFunction::Load(ar);
    ar >> c0_;
    ar >> c1_;
    ar >> h0_;
    ar >> h1_;
    ar >> w0_;
    ar >> w1_;
  }

private:
  int c0_;
  int c1_;
  int h0_;
  int h1_;
  int w0_;
  int w1_;

DISABLE_COPY_AND_ASSIGN(SplitFunctionPixelDifference);
};
+112
View File
@@ -0,0 +1,112 @@
#pragma once
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
// Fixed-size pool of worker threads that execute queued tasks in FIFO order.
class ThreadPool {
public:
    // Starts the given number of worker threads.
    ThreadPool(size_t);

    // Queues f(args...) for execution and returns a future for its result.
    // Throws std::runtime_error if the pool is already stopping.
    template<class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type>;

    // Lets the workers drain the queue, then joins them.
    ~ThreadPool();

    // True while any task is either executing or still waiting in the queue.
    // Both conditions are evaluated under the same lock that workers hold
    // while taking a task, so a task that has been enqueued but not yet
    // started (or that is between "dequeued" and "started") is never missed.
    bool has_running_tasks() {
        std::unique_lock<std::mutex> lock(queue_mutex);
        return n_running_tasks > 0 || !tasks.empty();
    }

private:
    // need to keep track of threads so we can join them
    std::vector< std::thread > workers;
    // the task queue
    std::queue< std::function<void()> > tasks;
    // number of tasks taken from the queue and not yet finished;
    // guarded by queue_mutex
    int n_running_tasks;

    // synchronization
    std::mutex queue_mutex;
    std::condition_variable condition;
    bool stop;
};

// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
    : n_running_tasks(0), stop(false)
{
    for(size_t i = 0; i < threads; ++i)
        workers.emplace_back(
            [this]
            {
                for(;;)
                {
                    std::function<void()> task;

                    {
                        std::unique_lock<std::mutex> lock(this->queue_mutex);
                        this->condition.wait(lock,
                            [this]{ return this->stop || !this->tasks.empty(); });
                        if(this->stop && this->tasks.empty())
                            return;
                        task = std::move(this->tasks.front());
                        this->tasks.pop();
                        // Count the task as running in the same critical
                        // section that removes it from the queue.
                        ++this->n_running_tasks;
                    }

                    task();
                    // Destroy the callable (and everything it captured)
                    // before the pool can be observed as idle.
                    task = nullptr;

                    {
                        std::unique_lock<std::mutex> lock(this->queue_mutex);
                        --this->n_running_tasks;
                    }
                }
            }
        );
}

// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
    -> std::future<typename std::result_of<F(Args...)>::type>
{
    using return_type = typename std::result_of<F(Args...)>::type;

    auto task = std::make_shared< std::packaged_task<return_type()> >(
            std::bind(std::forward<F>(f), std::forward<Args>(args)...)
        );

    std::future<return_type> res = task->get_future();
    {
        std::unique_lock<std::mutex> lock(queue_mutex);

        // don't allow enqueueing after stopping the pool
        if(stop)
            throw std::runtime_error("enqueue on stopped ThreadPool");

        tasks.emplace([task](){ (*task)(); });
    }
    condition.notify_one();
    return res;
}

// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for(std::thread &worker: workers)
        worker.join();
}
+423
View File
@@ -0,0 +1,423 @@
#pragma once
#include <chrono>
#include <set>
#include <queue>
#include "threadpool.h"
#include "forest.h"
#include "spliteval.h"
// Selects how a trainer treats nodes of a previously trained forest.
// As used by TrainForestRecursive: where a node is not split further,
// TRAIN and RETRAIN_WITH_REPLACEMENT create a fresh leaf from the current
// samples, whereas RETRAIN copies the corresponding old node when one exists.
enum class TrainType : int {
TRAIN = 0,
RETRAIN = 1,
RETRAIN_WITH_REPLACEMENT = 2
};
// Hyper-parameters of forest training; every field has a default value.
struct TrainParameters {
  // Training mode. The trainers visible in this file take the mode as an
  // argument of Train() rather than reading this field.
  TrainType train_type = TrainType::TRAIN;
  // Number of trees in the forest.
  int n_trees = 5;
  // Nodes at this depth are never split.
  int max_tree_depth = 7;
  // Candidate split functions generated per node.
  int n_test_split_functions = 50;
  // Thresholds tried per candidate split function.
  int n_test_thresholds = 10;
  // Upper bound on the samples drawn to evaluate candidate splits.
  int n_test_samples = 100;
  // A node is split only if it holds more samples than this.
  int min_samples_to_split = 14;
  // A candidate split is rejected if either side gets fewer samples.
  int min_samples_for_leaf = 7;
  // Progress is printed every this many created nodes (verbose mode).
  int print_node_info = 100;

  TrainParameters() {}
};
// Common base of the forest trainers: holds the training configuration and
// the building blocks shared by the concrete strategies (sub-sampling,
// splitting, split optimization, leaf creation, progress reporting).
//
// params is stored by reference, so the TrainParameters object passed to the
// constructor must outlive the trainer.
template <typename SplitFunctionT, typename LeafFunctionT, typename SplitEvaluatorT>
class TrainForest {
public:
  // gen_split_fcn / gen_leaf_fcn act as factories (Generate() / Create()),
  // split_eval scores candidate splits, n_threads is the degree of
  // parallelism used by the subclasses, verbose enables progress output.
  TrainForest(const TrainParameters& params, const std::shared_ptr<SplitFunctionT> gen_split_fcn, const std::shared_ptr<LeafFunctionT> gen_leaf_fcn, const std::shared_ptr<SplitEvaluatorT> split_eval, int n_threads, bool verbose)
    : params_(params), gen_split_fcn_(gen_split_fcn), gen_leaf_fcn_(gen_leaf_fcn), split_eval_(split_eval), n_threads(n_threads), verbose_(verbose) {
    n_created_nodes_ = 0;

    // Node count of max_tree_depth full binary levels below the root,
    // times the number of trees: the initial progress denominator.
    n_max_nodes_ = 1;
    unsigned long n_nodes_d = 1;
    for(int depth = 0; depth < params.max_tree_depth; ++depth) {
      n_nodes_d *= 2;
      n_max_nodes_ += n_nodes_d;
    }
    n_max_nodes_ *= params.n_trees;
  }

  virtual ~TrainForest() {}

  // Trains a forest on samples; old_forest may be null.
  virtual std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>>& old_forest) = 0;

protected:
  // Prints the training configuration when verbose.
  virtual void PrintParams() {
    if(verbose_){
      #pragma omp critical (TrainForest_train)
      {
        std::cout << "[TRAIN] training forest " << std::endl;
        std::cout << "[TRAIN] n_trees : " << params_.n_trees << std::endl;
        std::cout << "[TRAIN] max_tree_depth : " << params_.max_tree_depth << std::endl;
        std::cout << "[TRAIN] n_test_split_functions: " << params_.n_test_split_functions << std::endl;
        std::cout << "[TRAIN] n_test_thresholds : " << params_.n_test_thresholds << std::endl;
        std::cout << "[TRAIN] n_test_samples : " << params_.n_test_samples << std::endl;
        std::cout << "[TRAIN] min_samples_to_split : " << params_.min_samples_to_split << std::endl;
      }
    }
  }

  // Progress bookkeeping (verbose mode only): counts the created node and,
  // for a leaf, removes the nodes that can no longer be created below it
  // from the expected total. The trainers call this from several worker
  // threads, so the shared counters are updated under a mutex.
  virtual void UpdateNodeInfo(unsigned int depth, bool leaf) {
    if(verbose_) {
      std::lock_guard<std::mutex> lock(node_info_mutex_);

      n_created_nodes_ += 1;

      if(leaf) {
        unsigned long n_nodes_d = 1;
        unsigned int n_remove_max_nodes = 0;
        for(int d = depth; d < params_.max_tree_depth; ++d) {
          n_nodes_d *= 2;
          n_remove_max_nodes += n_nodes_d;
        }
        n_max_nodes_ -= n_remove_max_nodes;
      }

      if(n_created_nodes_ % params_.print_node_info == 0 || n_created_nodes_ == n_max_nodes_) {
        std::cout << "[Forest]"
          << " created node number " << n_created_nodes_
          << " @ depth " << depth
          << ", max. " << n_max_nodes_ << " left"
          << " => " << (double(n_created_nodes_) / double(n_max_nodes_))
          << " done" << std::endl;
      }
    }
  }

  // Draws min(all.size(), n_test_samples) distinct elements of all into
  // sampled (in ascending index order). An empty input yields an empty
  // result instead of constructing an invalid distribution range.
  virtual void SampleData(const std::vector<TrainDatum>& all, std::vector<TrainDatum>& sampled, std::mt19937& rng) {
    if(all.empty()) {
      sampled.clear();
      return;
    }

    unsigned int n = all.size();
    unsigned int k = params_.n_test_samples;
    k = n < k ? n : k;

    // Rejection sampling of k distinct indices.
    std::set<int> indices;
    std::uniform_int_distribution<int> udist(0, all.size()-1);
    while(indices.size() < k) {
      int idx = udist(rng);
      indices.insert(idx);
    }

    sampled.resize(k);
    int sidx = 0;
    for(int idx : indices) {
      sampled[sidx] = all[idx];
      sidx += 1;
    }
  }

  // Appends each sample to left when split_function->Split() is true for it,
  // otherwise to right. Elements are visited by reference so that each datum
  // is copied only once, into its destination vector.
  virtual void Split(const std::shared_ptr<SplitFunctionT>& split_function, const std::vector<TrainDatum>& samples, std::vector<TrainDatum>& left, std::vector<TrainDatum>& right) {
    for(const auto& sample : samples) {
      if(split_function->Split(sample.sample)) {
        left.push_back(sample);
      }
      else {
        right.push_back(sample);
      }
    }
  }

  // Random search for the best split of samples: generates
  // n_test_split_functions candidates, tries n_test_thresholds thresholds for
  // each (taken from the response of a random sub-sample), and keeps the
  // candidate with the lowest evaluator cost. Returns null when no candidate
  // leaves at least min_samples_for_leaf samples on both sides, or when there
  // is nothing to evaluate.
  virtual std::shared_ptr<SplitFunctionT> OptimizeSplitFunction(const std::vector<TrainDatum>& samples, int depth, std::mt19937& rng) {
    std::shared_ptr<SplitFunctionT> best_split_fcn;

    // samples[0] is used as the prototype for Generate() below.
    if(samples.empty()) {
      return best_split_fcn;
    }

    std::vector<TrainDatum> split_samples;
    SampleData(samples, split_samples, rng);
    // Can be empty when n_test_samples is 0: no threshold can be drawn then.
    if(split_samples.empty()) {
      return best_split_fcn;
    }

    unsigned int min_samples_for_leaf = params_.min_samples_for_leaf;

    float min_cost = std::numeric_limits<float>::max();
    float best_threshold = 0;

    std::uniform_int_distribution<int> udist(0, split_samples.size()-1);
    // Reused across candidates to avoid reallocating for every threshold.
    std::vector<TrainDatum> left;
    std::vector<TrainDatum> right;

    for(int split_fcn_idx = 0; split_fcn_idx < params_.n_test_split_functions; ++split_fcn_idx) {
      auto split_fcn = gen_split_fcn_->Generate(rng, samples[0].sample);

      for(int threshold_idx = 0; threshold_idx < params_.n_test_thresholds; ++threshold_idx) {
        int rand_split_sample_idx = udist(rng);
        float threshold = split_fcn->Compute(split_samples[rand_split_sample_idx].sample);
        split_fcn->set_threshold(threshold);

        left.clear();
        right.clear();
        Split(split_fcn, split_samples, left, right);
        if(left.size() < min_samples_for_leaf || right.size() < min_samples_for_leaf) {
          continue;
        }

        float split_cost = split_eval_->Eval(left, right, depth);

        if(split_cost < min_cost) {
          min_cost = split_cost;
          best_split_fcn = split_fcn;
          // The candidate object is shared and its threshold keeps changing
          // in this loop, so the winning threshold is remembered separately.
          best_threshold = threshold;
        }
      }
    }

    if(best_split_fcn != nullptr) {
      best_split_fcn->set_threshold(best_threshold);
    }

    return best_split_fcn;
  }

  // Creates a leaf node from samples and records it in the progress counters.
  virtual NodePtr CreateLeafNode(const std::vector<TrainDatum>& samples, unsigned int depth) {
    auto leaf_fct = gen_leaf_fcn_->Create(samples);
    auto node = std::make_shared<LeafNode<LeafFunctionT>>(leaf_fct);

    UpdateNodeInfo(depth, true);

    return node;
  }

protected:
  const TrainParameters& params_;

  const std::shared_ptr<SplitFunctionT> gen_split_fcn_;
  const std::shared_ptr<LeafFunctionT> gen_leaf_fcn_;
  const std::shared_ptr<SplitEvaluatorT> split_eval_;

  int n_threads;
  bool verbose_;

  unsigned long n_created_nodes_;
  unsigned long n_max_nodes_;

  // Guards n_created_nodes_ / n_max_nodes_ in UpdateNodeInfo().
  std::mutex node_info_mutex_;
};
// Trainer that grows every tree depth-first by recursion; the trees of a
// forest are trained inside an OpenMP parallel loop.
template <typename SplitFunctionT, typename LeafFunctionT, typename SplitEvaluatorT>
class TrainForestRecursive : public TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT> {
public:
  TrainForestRecursive(const TrainParameters& params, const std::shared_ptr<SplitFunctionT> gen_split_fcn, const std::shared_ptr<LeafFunctionT> gen_leaf_fcn, const std::shared_ptr<SplitEvaluatorT> split_eval, int n_threads, bool verbose)
    : TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT>(params, gen_split_fcn, gen_leaf_fcn, split_eval, n_threads, verbose) {}

  virtual ~TrainForestRecursive() {}

  // Trains n_trees trees on samples and returns them as a forest. When
  // old_forest is non-null, tree i of it (if present) is handed to the
  // per-tree training as the previous tree. Trees are added in completion
  // order.
  virtual std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>>& old_forest) {
    this->PrintParams();

    const auto forest_start = std::chrono::system_clock::now();
    auto forest = std::make_shared<Forest<SplitFunctionT, LeafFunctionT>>();

    omp_set_num_threads(this->n_threads);
    #pragma omp parallel for ordered
    for(size_t tree_idx = 0; tree_idx < this->params_.n_trees; ++tree_idx) {
      const auto tree_start = std::chrono::system_clock::now();

      #pragma omp critical (TrainForest_train)
      {
        if(this->verbose_){
          std::cout << "[TRAIN][START] training tree " << tree_idx << " of " << this->params_.n_trees << std::endl;
        }
      }

      // Matching tree of the previous forest, if there is one.
      std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>> previous_tree;
      if(old_forest != nullptr && tree_idx < old_forest->trees_size()) {
        previous_tree = old_forest->trees(tree_idx);
      }

      // Every tree gets its own freshly seeded generator.
      std::random_device seed_source;
      std::mt19937 rng(seed_source());

      auto tree = this->Train(samples, train_type, previous_tree, rng);

      #pragma omp critical (TrainForest_train)
      {
        forest->AddTree(tree);
        if(this->verbose_){
          const auto tree_end = std::chrono::system_clock::now();
          const auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(tree_end - tree_start);
          std::cout << "[TRAIN][FINISHED] training tree " << tree_idx << " of " << this->params_.n_trees << " - took " << (ms.count() * 1e-3) << "[s]" << std::endl;
          std::cout << "[TRAIN][FINISHED] " << (this->params_.n_trees - forest->trees_size()) << " left for training" << std::endl;
        }
      }
    }

    if(this->verbose_){
      const auto forest_end = std::chrono::system_clock::now();
      const auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(forest_end - forest_start);
      std::cout << "[TRAIN][FINISHED] training forest - took " << (ms.count() * 1e-3) << "[s]" << std::endl;
    }

    return forest;
  }

private:
  // Trains a single tree, starting from the root of old_tree when given.
  virtual std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>>& old_tree, std::mt19937& rng) {
    NodePtr previous_root = (old_tree != nullptr) ? old_tree->root() : NodePtr();
    NodePtr root = this->Train(samples, train_type, previous_root, 0, rng);
    return std::make_shared<Tree<SplitFunctionT, LeafFunctionT>>(root);
  }

  // Node used wherever the recursion stops: a fresh leaf, unless the mode is
  // RETRAIN and an old node exists, in which case that old node is copied.
  NodePtr MakeTerminalNode(const std::vector<TrainDatum>& samples, TrainType train_type, const NodePtr& old_node, unsigned int depth) {
    const bool build_new_leaf = old_node == nullptr
      || train_type == TrainType::TRAIN
      || train_type == TrainType::RETRAIN_WITH_REPLACEMENT;
    if(build_new_leaf) {
      return this->CreateLeafNode(samples, depth);
    }
    if(train_type == TrainType::RETRAIN) {
      return old_node->Copy();
    }
    std::cout << "[ERROR] unknown train type" << std::endl;
    exit(-1);
  }

  // Trains the subtree for samples at the given depth. An old split node
  // contributes a copy of its split function and its children as the old
  // nodes of the recursion; otherwise a new split function is optimized.
  virtual NodePtr Train(const std::vector<TrainDatum>& samples, TrainType train_type, const NodePtr& old_node, unsigned int depth, std::mt19937& rng) {
    // Depth limit reached or too few samples: stop here.
    const bool may_split = depth < this->params_.max_tree_depth && samples.size() > this->params_.min_samples_to_split;
    if(!may_split) {
      return MakeTerminalNode(samples, train_type, old_node, depth);
    }

    std::shared_ptr<SplitFunctionT> best_split_fcn;
    // Children of the old split node; stay null when nothing is reused.
    NodePtr old_left;
    NodePtr old_right;

    if(old_node == nullptr || old_node->type() == LeafNode<LeafFunctionT>::TYPE) {
      best_split_fcn = this->OptimizeSplitFunction(samples, depth, rng);
    }
    else if(old_node->type() == SplitNode<SplitFunctionT, LeafFunctionT>::TYPE) {
      auto old_split = std::static_pointer_cast<SplitNode<SplitFunctionT, LeafFunctionT>>(old_node);
      best_split_fcn = old_split->split_fcn()->Copy();
      old_left = old_split->left();
      old_right = old_split->right();
    }

    // No usable split function: stop here.
    if(best_split_fcn == nullptr) {
      return MakeTerminalNode(samples, train_type, old_node, depth);
    }

    // (1) split samples
    std::vector<TrainDatum> leftsamples, rightsamples;
    this->Split(best_split_fcn, samples, leftsamples, rightsamples);

    // (2) output node information
    this->UpdateNodeInfo(depth, false);

    // (3) recurse, left subtree first (both consume the same rng)
    NodePtr left = this->Train(leftsamples, train_type, old_left, depth + 1, rng);
    NodePtr right = this->Train(rightsamples, train_type, old_right, depth + 1, rng);

    return std::make_shared<SplitNode<SplitFunctionT, LeafFunctionT>>(left, right, best_split_fcn);
  }
};
// Work item of the queued trainer: the samples that reached a node, the
// depth of that node, and the slot into which the finished node is written.
struct QueueTuple {
  int depth;
  std::vector<TrainDatum> train_data;
  // Not owned; must stay valid until the job has written its node.
  NodePtr* parent;

  QueueTuple() : depth(-1), train_data(), parent(nullptr) {}

  // train_data is taken by value and moved into the member, so passing an
  // rvalue costs no copy and passing an lvalue costs exactly one.
  QueueTuple(int depth, std::vector<TrainDatum> train_data, NodePtr* parent) :
    depth(depth), train_data(std::move(train_data)), parent(parent) {}
};
// Trainer that grows all trees concurrently on a thread pool: every node is
// one job, and a job that creates a split node enqueues one job per child.
template <typename SplitFunctionT, typename LeafFunctionT, typename SplitEvaluatorT>
class TrainForestQueued : public TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT> {
public:
  TrainForestQueued(const TrainParameters& params, const std::shared_ptr<SplitFunctionT> gen_split_fcn, const std::shared_ptr<LeafFunctionT> gen_leaf_fcn, const std::shared_ptr<SplitEvaluatorT> split_eval, int n_threads, bool verbose)
    : TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT>(params, gen_split_fcn, gen_leaf_fcn, split_eval, n_threads, verbose) {}

  virtual ~TrainForestQueued() {}

  // Creates n_trees empty trees, enqueues one root job per tree and then
  // polls the pool every 100 ms until it reports no running tasks.
  // train_type and old_forest are not used by this trainer.
  // NOTE(review): the wait is only as reliable as
  // ThreadPool::has_running_tasks(); confirm that it also covers jobs that
  // are queued but not yet started.
  virtual std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>>& old_forest) {
    this->PrintParams();

    auto tim = std::chrono::system_clock::now();
    auto forest = std::make_shared<Forest<SplitFunctionT, LeafFunctionT>>();

    std::cout << "[TRAIN] create pool with " << this->n_threads << " threads" << std::endl;
    auto pool = std::make_shared<ThreadPool>(this->n_threads);
    for(int treeidx = 0; treeidx < this->params_.n_trees; ++treeidx) {
      auto tree = std::make_shared<Tree<SplitFunctionT, LeafFunctionT>>();
      forest->AddTree(tree);
      // Each root job needs its own copy of the full training set.
      AddJob(pool, QueueTuple(0, samples, &(tree->root_)));
    }

    while(pool->has_running_tasks()) {
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }

    if(this->verbose_){
      auto now = std::chrono::system_clock::now();
      auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(now - tim);
      std::cout << "[TRAIN][FINISHED] training forest - took " << (ms.count() * 1e-3) << "[s]" << std::endl;
    }

    return forest;
  }

private:
  // Enqueues the job that builds the node for data and writes it to
  // *data.parent. The work item is moved into the pool and read by the job
  // through a const reference, so its sample vector is not deep-copied on
  // the way; the two child work items are moved into their jobs as well.
  virtual void AddJob(std::shared_ptr<ThreadPool> pool, QueueTuple data) {
    pool->enqueue([this](const std::shared_ptr<ThreadPool>& pool, const QueueTuple& data) {
      // Every job seeds its own generator.
      std::random_device rd;
      std::mt19937 rng(rd());

      std::shared_ptr<SplitFunctionT> best_split_fcn = nullptr;

      if(data.depth < this->params_.max_tree_depth && int(data.train_data.size()) > this->params_.min_samples_to_split) {
        best_split_fcn = this->OptimizeSplitFunction(data.train_data, data.depth, rng);
      }

      if(best_split_fcn == nullptr) {
        // Depth/sample limit reached or no valid split found: make a leaf.
        auto node = this->CreateLeafNode(data.train_data, data.depth);
        *(data.parent) = node;
      }
      else {
        this->UpdateNodeInfo(data.depth, false);
        auto node = std::make_shared<SplitNode<SplitFunctionT, LeafFunctionT>>();
        node->split_fcn_ = best_split_fcn;
        *(data.parent) = node;

        QueueTuple left;
        QueueTuple right;
        this->Split(best_split_fcn, data.train_data, left.train_data, right.train_data);

        left.depth = data.depth + 1;
        right.depth = data.depth + 1;

        // The children are written into the slots of the new split node,
        // which is kept alive through *data.parent.
        left.parent = &(node->left_);
        right.parent = &(node->right_);

        this->AddJob(pool, std::move(left));
        this->AddJob(pool, std::move(right));
      }
    }, pool, std::move(data));
  }
};
+55
View File
@@ -0,0 +1,55 @@
#pragma once
#include "node.h"
// A single decision tree: owns the root node and routes samples from the
// root down to a leaf.
template <typename SplitFunctionT, typename LeafFunctionT>
class Tree {
public:
  // A default-constructed tree has no root until set_root()/Load() (or a
  // trainer writing to root_) provides one.
  Tree() : root_(nullptr) {}
  Tree(NodePtr root) : root_(root) {}

  virtual ~Tree() {}

  // Descends from the root, going left whenever the split node's Split()
  // is true for the sample, and returns the leaf function of the leaf that
  // is reached. Terminates the process if the tree has no root.
  std::shared_ptr<LeafFunctionT> inference(const SamplePtr sample) const {
    if(root_ == nullptr) {
      std::cout << "[ERROR] tree inference root node is NULL";
      exit(-1);
    }

    NodePtr node = root_;
    while(node->type() == SplitNode<SplitFunctionT, LeafFunctionT>::TYPE) {
      auto splitNode = std::static_pointer_cast<SplitNode<SplitFunctionT, LeafFunctionT>>(node);
      bool left = splitNode->Split(sample);
      if(left) {
        node = splitNode->left();
      }
      else {
        node = splitNode->right();
      }
    }

    auto leaf_node = std::static_pointer_cast<LeafNode<LeafFunctionT>>(node);
    return leaf_node->leaf_node_fcn();
  }

  NodePtr root() const { return root_; }
  void set_root(NodePtr root) { root_ = root; }

  // Writes the type tag of the root followed by the root itself.
  // A tree without a root cannot be saved; this is reported the same way
  // inference() reports it instead of dereferencing a null pointer.
  virtual void Save(SerializationOut& ar) const {
    if(root_ == nullptr) {
      std::cout << "[ERROR] tree save root node is NULL";
      exit(-1);
    }

    int type = root_->type();
    ar << type;
    root_->Save(ar);
  }

  // Reads the type tag, creates a node of that type and lets it load itself.
  virtual void Load(SerializationIn& ar) {
    int type;
    ar >> type;

    root_ = MakeNode<SplitFunctionT, LeafFunctionT>(type);
    // NOTE(review): MakeNode is defined elsewhere; guard in case it yields
    // a null node for a type tag it does not know.
    if(root_ == nullptr) {
      std::cout << "[ERROR] tree load root node is NULL";
      exit(-1);
    }
    root_->Load(ar);
  }

public:
  NodePtr root_;
};
+45
View File
@@ -0,0 +1,45 @@
from distutils.core import setup
from Cython.Build import cythonize
from distutils.extension import Extension
from Cython.Distutils import build_ext
import numpy as np
import platform
import os
# Build script for the `hyperdepth` Cython extension (C++11, OpenMP enabled).
this_dir = os.path.dirname(__file__)

# OpenMP is always enabled: the flag is needed for compiling and for linking.
print('using openmp')
extra_compile_args = ['-O3', '-std=c++11', '-fopenmp']
extra_link_args = ['-fopenmp']

sources = ['hyperdepth.pyx']
extra_objects = []
library_dirs = []
libraries = ['m']

# The single extension module, compiled as C++ against the NumPy headers.
hyperdepth_extension = Extension(
    'hyperdepth',
    sources,
    extra_objects=extra_objects,
    language='c++',
    library_dirs=library_dirs,
    libraries=libraries,
    include_dirs=[np.get_include()],
    extra_compile_args=extra_compile_args,
    extra_link_args=extra_link_args,
)

setup(
    name="hyperdepth",
    cmdclass={'build_ext': build_ext},
    ext_modules=[hyperdepth_extension],
)
+52
View File
@@ -0,0 +1,52 @@
#include "hyperdepth.h"
#include "rf/train.h"
// Trains one forest per image row and writes each to "cforest_<row>.bin".
// The image size is taken from sample 0; samples 0..11 provide the training
// data of every row.
int main() {
  cv::Mat_<uint8_t> im = read_im(0);
  cv::Mat_<uint16_t> disp = read_disp(0);
  int im_rows = im.rows;
  int im_cols = im.cols;

  std::cout << im.rows << "/" << im.cols << std::endl;
  std::cout << disp.rows << "/" << disp.cols << std::endl;

  TrainParameters params;
  params.n_trees = 6;
  params.n_test_samples = 2048;
  params.min_samples_to_split = 16;
  params.min_samples_for_leaf = 8;
  params.n_test_split_functions = 50;
  params.n_test_thresholds = 10;
  params.max_tree_depth = 8;

  int n_classes = im_cols;
  int n_disp_bins = 16;
  int depth_switch = 4;

  // TrainForestQueued takes the worker count and the verbose flag as separate
  // arguments. hardware_concurrency() may return 0 when the count is unknown,
  // in which case a single worker is used.
  const unsigned int hw_threads = std::thread::hardware_concurrency();
  const int n_threads = hw_threads > 0 ? static_cast<int>(hw_threads) : 1;
  const bool verbose = true;

  auto gen_split_fcn = std::make_shared<HDSplitFunctionT>();
  auto gen_leaf_fcn = std::make_shared<HDLeafFunctionT>(n_classes * n_disp_bins);
  auto split_eval = std::make_shared<HDSplitEvaluatorT>(true, n_classes, n_disp_bins, depth_switch);

  for(int row = 0; row < im_rows; ++row) {
    // Collect the samples of this row from all training images.
    std::vector<TrainDatum> train_data;
    for(int idx = 0; idx < 12; ++idx) {
      std::cout << "read sample " << idx << std::endl;
      im = read_im(idx);
      disp = read_disp(idx);

      extract_row_samples(im, disp, row, train_data, true, n_disp_bins);
    }
    std::cout << "extracted " << train_data.size() << " train samples" << std::endl;
    std::cout << "n_classes (" << n_classes << ") * n_disp_bins (" << n_disp_bins << ") = " << (n_classes * n_disp_bins) << std::endl;

    TrainForestQueued<HDSplitFunctionT, HDLeafFunctionT, HDSplitEvaluatorT> train(params, gen_split_fcn, gen_leaf_fcn, split_eval, n_threads, verbose);

    auto forest = train.Train(train_data, TrainType::TRAIN, nullptr);

    std::cout << "training done" << std::endl;

    std::ostringstream forest_path;
    forest_path << "cforest_" << row << ".bin";
    BinarySerializationOut fout(forest_path.str());
    forest->Save(fout);
  }

  return 0;
}
+15
View File
@@ -0,0 +1,15 @@
import numpy as np
import matplotlib.pyplot as plt
import cv2
# Shows the stored disparity maps side by side: original, target, estimate,
# and the absolute difference between estimate and target. The stored values
# are divided by 16 before display.
def _load_disp(path):
    """Read an image at its stored bit depth and return it as float32."""
    return cv2.imread(path, cv2.IMREAD_ANYDEPTH).astype(np.float32)


orig = _load_disp('disp_orig.png')
ta = _load_disp('disp_ta.png')
es = _load_disp('disp_es.png')

# (image, upper colour limit) for the four cells of the 2x2 grid, in order.
panels = [
    (orig / 16, 4),
    (ta / 16, 4),
    (es / 16, 4),
    (np.abs(es - ta) / 16, 1),
]

plt.figure()
for slot, (image, upper) in enumerate(panels, start=1):
    plt.subplot(2, 2, slot)
    plt.imshow(image, vmin=0, vmax=upper, cmap='magma')
plt.show()