init
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
#ifndef COMMON_H
|
||||
#define COMMON_H
|
||||
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
|
||||
#define DISABLE_COPY_AND_ASSIGN(classname) \
|
||||
private:\
|
||||
classname(const classname&) = delete;\
|
||||
classname& operator=(const classname&) = delete;
|
||||
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,72 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
||||
class Sample {
|
||||
public:
|
||||
Sample(int channels, int height, int width)
|
||||
: channels_(channels), height_(height), width_(width) {}
|
||||
|
||||
virtual ~Sample() {}
|
||||
|
||||
virtual float at(int c, int h, int w) const = 0;
|
||||
|
||||
virtual float operator()(int c, int h, int w) const {
|
||||
return at(c,h,w);
|
||||
}
|
||||
|
||||
virtual int channels() const { return channels_; }
|
||||
virtual int height() const { return height_; }
|
||||
virtual int width() const { return width_; }
|
||||
|
||||
protected:
|
||||
int channels_;
|
||||
int height_;
|
||||
int width_;
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<Sample> SamplePtr;
|
||||
|
||||
|
||||
|
||||
|
||||
class Target {
|
||||
public:
|
||||
Target() {}
|
||||
virtual ~Target() {}
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<Target> TargetPtr;
|
||||
typedef std::vector<TargetPtr> VecTargetPtr;
|
||||
typedef std::shared_ptr<VecTargetPtr> VecPtrTargetPtr;
|
||||
|
||||
|
||||
class ClassificationTarget : public Target {
|
||||
public:
|
||||
ClassificationTarget(int cl) : cl_(cl) {}
|
||||
virtual ~ClassificationTarget() {}
|
||||
int cl() const { return cl_; }
|
||||
|
||||
private:
|
||||
int cl_;
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<ClassificationTarget> ClassificationTargetPtr;
|
||||
|
||||
|
||||
|
||||
|
||||
struct TrainDatum {
|
||||
SamplePtr sample;
|
||||
TargetPtr target;
|
||||
TargetPtr optimize_target;
|
||||
|
||||
TrainDatum() : sample(nullptr), target(nullptr), optimize_target(nullptr) {}
|
||||
|
||||
TrainDatum(SamplePtr sample, TargetPtr target)
|
||||
: sample(sample), target(target), optimize_target(target) {}
|
||||
|
||||
TrainDatum(SamplePtr sample, TargetPtr target, TargetPtr optimize_target)
|
||||
: sample(sample), target(target), optimize_target(optimize_target) {}
|
||||
};
|
||||
@@ -0,0 +1,92 @@
|
||||
#pragma once
|
||||
|
||||
#include "tree.h"
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT>
|
||||
class Forest {
|
||||
public:
|
||||
Forest() {}
|
||||
virtual ~Forest() {}
|
||||
|
||||
std::shared_ptr<LeafFunctionT> inferencest(const SamplePtr& sample) const {
|
||||
int n_trees = trees_.size();
|
||||
|
||||
std::vector<std::shared_ptr<LeafFunctionT>> fcns;
|
||||
|
||||
//inference of individual trees
|
||||
for(int tree_idx = 0; tree_idx < n_trees; ++tree_idx) {
|
||||
std::shared_ptr<LeafFunctionT> tree_fcn = trees_[tree_idx]->inference(sample);
|
||||
fcns.push_back(tree_fcn);
|
||||
}
|
||||
|
||||
//combine tree fcns/results and collect all results
|
||||
return fcns[0]->Reduce(fcns);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<LeafFunctionT>> inferencemt(const std::vector<SamplePtr>& samples, int n_threads) const {
|
||||
std::vector<std::shared_ptr<LeafFunctionT>> targets(samples.size());
|
||||
|
||||
omp_set_num_threads(n_threads);
|
||||
#pragma omp parallel for
|
||||
for(size_t sample_idx = 0; sample_idx < samples.size(); ++sample_idx) {
|
||||
targets[sample_idx] = inferencest(samples[sample_idx]);
|
||||
}
|
||||
|
||||
return targets;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<LeafFunctionT>> inferencemt(const std::vector<TrainDatum>& samples, int n_threads) const {
|
||||
std::vector<std::shared_ptr<LeafFunctionT>> targets(samples.size());
|
||||
|
||||
omp_set_num_threads(n_threads);
|
||||
#pragma omp parallel for
|
||||
for(size_t sample_idx = 0; sample_idx < samples.size(); ++sample_idx) {
|
||||
targets[sample_idx] = inferencest(samples[sample_idx].sample);
|
||||
}
|
||||
|
||||
return targets;
|
||||
}
|
||||
|
||||
void AddTree(std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>> tree) {
|
||||
trees_.push_back(tree);
|
||||
}
|
||||
|
||||
size_t trees_size() const { return trees_.size(); }
|
||||
// TreePtr trees(int idx) const { return trees_[idx]; }
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
size_t n_trees = trees_.size();
|
||||
std::cout << "[DEBUG] write " << n_trees << " trees" << std::endl;
|
||||
ar << n_trees;
|
||||
|
||||
if(true) std::cout << "[Forest][write] write number of trees " << n_trees << std::endl;
|
||||
|
||||
for(size_t tree_idx = 0; tree_idx < trees_.size(); ++tree_idx) {
|
||||
if(true) std::cout << "[Forest][write] write tree nb. " << tree_idx << std::endl;
|
||||
trees_[tree_idx]->Save(ar);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
size_t n_trees;
|
||||
ar >> n_trees;
|
||||
|
||||
if(true) std::cout << "[Forest][read] nTrees: " << n_trees << std::endl;
|
||||
|
||||
trees_.clear();
|
||||
for(size_t i = 0; i < n_trees; ++i) {
|
||||
if(true) std::cout << "[Forest][read] read tree " << (i+1) << " of " << n_trees << " - " << std::endl;
|
||||
|
||||
auto tree = std::make_shared<Tree<SplitFunctionT, LeafFunctionT>>();
|
||||
tree->Load(ar);
|
||||
trees_.push_back(tree);
|
||||
|
||||
if(true) std::cout << "[Forest][read] finished read tree " << (i+1) << " of " << n_trees << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>>> trees_;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "common.h"
|
||||
#include "data.h"
|
||||
|
||||
|
||||
class ClassProbabilitiesLeafFunction {
|
||||
public:
|
||||
ClassProbabilitiesLeafFunction() : n_classes_(-1) {}
|
||||
ClassProbabilitiesLeafFunction(int n_classes) : n_classes_(n_classes) {}
|
||||
virtual ~ClassProbabilitiesLeafFunction() {}
|
||||
|
||||
virtual std::shared_ptr<ClassProbabilitiesLeafFunction> Copy() const {
|
||||
auto fcn = std::make_shared<ClassProbabilitiesLeafFunction>();
|
||||
fcn->n_classes_ = n_classes_;
|
||||
fcn->counts_.resize(counts_.size());
|
||||
for(size_t idx = 0; idx < counts_.size(); ++idx) {
|
||||
fcn->counts_[idx] = counts_[idx];
|
||||
}
|
||||
fcn->sum_counts_ = sum_counts_;
|
||||
|
||||
return fcn;
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<ClassProbabilitiesLeafFunction> Create(const std::vector<TrainDatum>& samples) {
|
||||
auto stat = std::make_shared<ClassProbabilitiesLeafFunction>();
|
||||
|
||||
stat->counts_.resize(n_classes_, 0);
|
||||
for(auto sample : samples) {
|
||||
auto ctarget = std::static_pointer_cast<ClassificationTarget>(sample.target);
|
||||
stat->counts_[ctarget->cl()] += 1;
|
||||
}
|
||||
stat->sum_counts_ = samples.size();
|
||||
|
||||
return stat;
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<ClassProbabilitiesLeafFunction> Reduce(const std::vector<std::shared_ptr<ClassProbabilitiesLeafFunction>>& fcns) const {
|
||||
auto stat = std::make_shared<ClassProbabilitiesLeafFunction>();
|
||||
auto cfcn0 = std::static_pointer_cast<ClassProbabilitiesLeafFunction>(fcns[0]);
|
||||
stat->counts_.resize(cfcn0->counts_.size(), 0);
|
||||
stat->sum_counts_ = 0;
|
||||
|
||||
for(auto fcn : fcns) {
|
||||
auto cfcn = std::static_pointer_cast<ClassProbabilitiesLeafFunction>(fcn);
|
||||
for(size_t cl = 0; cl < stat->counts_.size(); ++cl) {
|
||||
stat->counts_[cl] += cfcn->counts_[cl];
|
||||
}
|
||||
stat->sum_counts_ += cfcn->sum_counts_;
|
||||
}
|
||||
|
||||
return stat;
|
||||
}
|
||||
|
||||
virtual int argmax() const {
|
||||
int max_idx = 0;
|
||||
int max_count = counts_[0];
|
||||
for(size_t idx = 1; idx < counts_.size(); ++idx) {
|
||||
if(counts_[idx] > max_count) {
|
||||
max_count = counts_[idx];
|
||||
max_idx = idx;
|
||||
}
|
||||
}
|
||||
return max_idx;
|
||||
}
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
ar << n_classes_;
|
||||
int n_counts = counts_.size();
|
||||
ar << n_counts;
|
||||
for(int idx = 0; idx < n_counts; ++idx) {
|
||||
ar << counts_[idx];
|
||||
}
|
||||
ar << sum_counts_;
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
ar >> n_classes_;
|
||||
int n_counts;
|
||||
ar >> n_counts;
|
||||
counts_.resize(n_counts);
|
||||
for(int idx = 0; idx < n_counts; ++idx) {
|
||||
ar >> counts_[idx];
|
||||
}
|
||||
ar >> sum_counts_;
|
||||
}
|
||||
|
||||
public:
|
||||
int n_classes_;
|
||||
|
||||
std::vector<int> counts_;
|
||||
int sum_counts_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(ClassProbabilitiesLeafFunction);
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "serialization.h"
|
||||
#include "leaffcn.h"
|
||||
#include "splitfcn.h"
|
||||
|
||||
class Node {
|
||||
public:
|
||||
Node() {}
|
||||
virtual ~Node() {}
|
||||
|
||||
virtual std::shared_ptr<Node> Copy() const = 0;
|
||||
|
||||
virtual int type() const = 0;
|
||||
|
||||
virtual void Save(SerializationOut& ar) const = 0;
|
||||
virtual void Load(SerializationIn& ar) = 0;
|
||||
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<Node> NodePtr;
|
||||
|
||||
|
||||
template <typename LeafFunctionT>
|
||||
class LeafNode : public Node {
|
||||
public:
|
||||
static const int TYPE = 0;
|
||||
|
||||
LeafNode() {}
|
||||
LeafNode(std::shared_ptr<LeafFunctionT> leaf_node_fcn) : leaf_node_fcn_(leaf_node_fcn) {}
|
||||
|
||||
virtual ~LeafNode() {}
|
||||
|
||||
virtual NodePtr Copy() const {
|
||||
auto node = std::make_shared<LeafNode>();
|
||||
node->leaf_node_fcn_ = leaf_node_fcn_->Copy();
|
||||
return node;
|
||||
}
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
leaf_node_fcn_->Save(ar);
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
leaf_node_fcn_ = std::make_shared<LeafFunctionT>();
|
||||
leaf_node_fcn_->Load(ar);
|
||||
}
|
||||
|
||||
virtual int type() const { return TYPE; };
|
||||
std::shared_ptr<LeafFunctionT> leaf_node_fcn() const { return leaf_node_fcn_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<LeafFunctionT> leaf_node_fcn_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(LeafNode);
|
||||
};
|
||||
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT>
|
||||
class SplitNode : public Node {
|
||||
public:
|
||||
static const int TYPE = 1;
|
||||
|
||||
SplitNode() {}
|
||||
|
||||
SplitNode(NodePtr left, NodePtr right, std::shared_ptr<SplitFunctionT> split_fcn) :
|
||||
left_(left), right_(right), split_fcn_(split_fcn)
|
||||
{}
|
||||
|
||||
virtual ~SplitNode() {}
|
||||
|
||||
virtual std::shared_ptr<Node> Copy() const {
|
||||
std::shared_ptr<SplitNode> node = std::make_shared<SplitNode>();
|
||||
node->left_ = left_->Copy();
|
||||
node->right_ = right_->Copy();
|
||||
node->split_fcn_ = split_fcn_->Copy();
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
bool Split(SamplePtr sample) {
|
||||
return split_fcn_->Split(sample);
|
||||
}
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
split_fcn_->Save(ar);
|
||||
|
||||
//left
|
||||
int type = left_->type();
|
||||
ar << type;
|
||||
left_->Save(ar);
|
||||
|
||||
//right
|
||||
type = right_->type();
|
||||
ar << type;
|
||||
right_->Save(ar);
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar);
|
||||
|
||||
|
||||
virtual int type() const { return TYPE; }
|
||||
|
||||
NodePtr left() const { return left_; }
|
||||
NodePtr right() const { return right_; }
|
||||
std::shared_ptr<SplitFunctionT> split_fcn() const { return split_fcn_; }
|
||||
|
||||
void set_left(NodePtr left) { left_ = left; }
|
||||
void set_right(NodePtr right) { right_ = right; }
|
||||
void set_split_fcn(std::shared_ptr<SplitFunctionT> split_fcn) { split_fcn_ = split_fcn; }
|
||||
|
||||
public:
|
||||
NodePtr left_;
|
||||
NodePtr right_;
|
||||
std::shared_ptr<SplitFunctionT> split_fcn_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(SplitNode);
|
||||
};
|
||||
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT>
|
||||
NodePtr MakeNode(int type) {
|
||||
NodePtr node;
|
||||
if(type == LeafNode<LeafFunctionT>::TYPE) {
|
||||
node = std::make_shared<LeafNode<LeafFunctionT>>();
|
||||
}
|
||||
else if(type == SplitNode<SplitFunctionT, LeafFunctionT>::TYPE) {
|
||||
node = std::make_shared<SplitNode<SplitFunctionT, LeafFunctionT>>();
|
||||
}
|
||||
else {
|
||||
std::cout << "[ERROR] unknown node type" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT>
|
||||
void SplitNode<SplitFunctionT, LeafFunctionT>::Load(SerializationIn& ar) {
|
||||
|
||||
split_fcn_ = std::make_shared<SplitFunctionT>();
|
||||
split_fcn_->Load(ar);
|
||||
|
||||
//left
|
||||
int left_type;
|
||||
ar >> left_type;
|
||||
left_ = MakeNode<SplitFunctionT, LeafFunctionT>(left_type);
|
||||
left_->Load(ar);
|
||||
|
||||
//right
|
||||
int right_type;
|
||||
ar >> right_type;
|
||||
right_ = MakeNode<SplitFunctionT, LeafFunctionT>(right_type);
|
||||
right_->Load(ar);
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
|
||||
class SerializationOut {
|
||||
public:
|
||||
SerializationOut(const std::string& path) : path_(path) {}
|
||||
virtual ~SerializationOut() {}
|
||||
|
||||
virtual SerializationOut& operator<<(const bool& v) = 0;
|
||||
virtual SerializationOut& operator<<(const char& v) = 0;
|
||||
virtual SerializationOut& operator<<(const int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const unsigned int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const long int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const unsigned long int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const long long int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const unsigned long long int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const float& v) = 0;
|
||||
virtual SerializationOut& operator<<(const double& v) = 0;
|
||||
|
||||
protected:
|
||||
const std::string& path_;
|
||||
};
|
||||
|
||||
class SerializationIn {
|
||||
public:
|
||||
SerializationIn(const std::string& path) : path_(path) {}
|
||||
virtual ~SerializationIn() {}
|
||||
|
||||
virtual SerializationIn& operator>>(bool& v) = 0;
|
||||
virtual SerializationIn& operator>>(char& v) = 0;
|
||||
virtual SerializationIn& operator>>(int& v) = 0;
|
||||
virtual SerializationIn& operator>>(unsigned int& v) = 0;
|
||||
virtual SerializationIn& operator>>(long int& v) = 0;
|
||||
virtual SerializationIn& operator>>(unsigned long int& v) = 0;
|
||||
virtual SerializationIn& operator>>(long long int& v) = 0;
|
||||
virtual SerializationIn& operator>>(unsigned long long int& v) = 0;
|
||||
virtual SerializationIn& operator>>(float& v) = 0;
|
||||
virtual SerializationIn& operator>>(double& v) = 0;
|
||||
|
||||
protected:
|
||||
const std::string& path_;
|
||||
};
|
||||
|
||||
class TextSerializationOut : public SerializationOut {
|
||||
public:
|
||||
TextSerializationOut(const std::string& path) : SerializationOut(path),
|
||||
f_(path.c_str()) {}
|
||||
virtual ~TextSerializationOut() {
|
||||
f_.close();
|
||||
}
|
||||
|
||||
virtual SerializationOut& operator<<(const bool& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const char& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const long int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned long int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const long long int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned long long int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const float& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const double& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::ofstream f_;
|
||||
};
|
||||
|
||||
class TextSerializationIn : public SerializationIn {
|
||||
public:
|
||||
TextSerializationIn(const std::string& path) : SerializationIn(path),
|
||||
f_(path.c_str()) {}
|
||||
virtual ~TextSerializationIn() {
|
||||
f_.close();
|
||||
}
|
||||
|
||||
virtual SerializationIn& operator>>(bool& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(char& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(long int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned long int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(long long int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned long long int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(float& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(double& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::ifstream f_;
|
||||
};
|
||||
|
||||
class BinarySerializationOut : public SerializationOut {
|
||||
public:
|
||||
BinarySerializationOut(const std::string& path) : SerializationOut(path),
|
||||
f_(path.c_str(), std::ios::binary) {}
|
||||
virtual ~BinarySerializationOut() {
|
||||
f_.close();
|
||||
}
|
||||
|
||||
virtual SerializationOut& operator<<(const bool& v) {
|
||||
f_.write((char*)&v, sizeof(bool));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const char& v) {
|
||||
f_.write((char*)&v, sizeof(char));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const int& v) {
|
||||
f_.write((char*)&v, sizeof(int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned int& v) {
|
||||
f_.write((char*)&v, sizeof(unsigned int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const long int& v) {
|
||||
f_.write((char*)&v, sizeof(long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned long int& v) {
|
||||
f_.write((char*)&v, sizeof(unsigned long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const long long int& v) {
|
||||
f_.write((char*)&v, sizeof(long long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned long long int& v) {
|
||||
f_.write((char*)&v, sizeof(unsigned long long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const float& v) {
|
||||
f_.write((char*)&v, sizeof(float));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const double& v) {
|
||||
f_.write((char*)&v, sizeof(double));
|
||||
return (*this);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::ofstream f_;
|
||||
};
|
||||
|
||||
class BinarySerializationIn : public SerializationIn {
|
||||
public:
|
||||
BinarySerializationIn(const std::string& path) : SerializationIn(path),
|
||||
f_(path.c_str(), std::ios::binary) {}
|
||||
virtual ~BinarySerializationIn() {
|
||||
f_.close();
|
||||
}
|
||||
|
||||
virtual SerializationIn& operator>>(bool& v) {
|
||||
f_.read((char*)&v, sizeof(bool));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(char& v) {
|
||||
f_.read((char*)&v, sizeof(char));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(int& v) {
|
||||
f_.read((char*)&v, sizeof(int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned int& v) {
|
||||
f_.read((char*)&v, sizeof(unsigned int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(long int& v) {
|
||||
f_.read((char*)&v, sizeof(long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned long int& v) {
|
||||
f_.read((char*)&v, sizeof(unsigned long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(long long int& v) {
|
||||
f_.read((char*)&v, sizeof(long long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned long long int& v) {
|
||||
f_.read((char*)&v, sizeof(unsigned long long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(float& v) {
|
||||
f_.read((char*)&v, sizeof(float));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(double& v) {
|
||||
f_.read((char*)&v, sizeof(double));
|
||||
return (*this);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::ifstream f_;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
#pragma once
|
||||
|
||||
class SplitEvaluator {
|
||||
public:
|
||||
SplitEvaluator(bool normalize)
|
||||
: normalize_(normalize) {}
|
||||
|
||||
virtual ~SplitEvaluator() {}
|
||||
|
||||
virtual float Eval(const std::vector<TrainDatum>& lefttargets, const std::vector<TrainDatum>& righttargets, int depth) const {
|
||||
float purity_left = Purity(lefttargets, depth);
|
||||
float purity_right = Purity(righttargets, depth);
|
||||
|
||||
float normalize_left = 1.0;
|
||||
float normalize_right = 1.0;
|
||||
|
||||
if(normalize_) {
|
||||
unsigned int n_left = lefttargets.size();
|
||||
unsigned int n_right = righttargets.size();
|
||||
unsigned int n_total = n_left + n_right;
|
||||
|
||||
normalize_left = float(n_left) / float(n_total);
|
||||
normalize_right = float(n_right) / float(n_total);
|
||||
}
|
||||
|
||||
float purity = purity_left * normalize_left + purity_right * normalize_right;
|
||||
|
||||
return purity;
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual float Purity(const std::vector<TrainDatum>& targets, int depth) const = 0;
|
||||
|
||||
protected:
|
||||
bool normalize_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class ClassificationIGSplitEvaluator : public SplitEvaluator {
|
||||
public:
|
||||
ClassificationIGSplitEvaluator(bool normalize, int n_classes)
|
||||
: SplitEvaluator(normalize), n_classes_(n_classes) {}
|
||||
virtual ~ClassificationIGSplitEvaluator() {}
|
||||
|
||||
protected:
|
||||
virtual float Purity(const std::vector<TrainDatum>& targets, int depth) const {
|
||||
if(targets.size() == 0) return 0;
|
||||
|
||||
std::vector<int> ps;
|
||||
ps.resize(n_classes_, 0);
|
||||
for(auto target : targets) {
|
||||
auto ctarget = std::static_pointer_cast<ClassificationTarget>(target.optimize_target);
|
||||
ps[ctarget->cl()] += 1;
|
||||
}
|
||||
|
||||
float h = 0;
|
||||
for(int cl = 0; cl < n_classes_; ++cl) {
|
||||
float fi = float(ps[cl]) / float(targets.size());
|
||||
if(fi > 0) {
|
||||
h = h - fi * std::log(fi);
|
||||
}
|
||||
}
|
||||
|
||||
return h;
|
||||
}
|
||||
|
||||
private:
|
||||
int n_classes_;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
#pragma once
|
||||
|
||||
#include <random>
|
||||
|
||||
class SplitFunction {
|
||||
public:
|
||||
SplitFunction() {}
|
||||
virtual ~SplitFunction() {}
|
||||
|
||||
virtual float Compute(SamplePtr sample) const = 0;
|
||||
|
||||
virtual bool Split(SamplePtr sample) const {
|
||||
return Compute(sample) < threshold_;
|
||||
}
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
ar << threshold_;
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
ar >> threshold_;
|
||||
}
|
||||
|
||||
virtual float threshold() const { return threshold_; }
|
||||
virtual void set_threshold(float threshold) { threshold_ = threshold; }
|
||||
|
||||
protected:
|
||||
float threshold_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SplitFunctionPixelDifference : public SplitFunction {
|
||||
public:
|
||||
|
||||
SplitFunctionPixelDifference() {}
|
||||
virtual ~SplitFunctionPixelDifference() {}
|
||||
|
||||
virtual std::shared_ptr<SplitFunctionPixelDifference> Copy() const {
|
||||
std::shared_ptr<SplitFunctionPixelDifference> split_fcn = std::make_shared<SplitFunctionPixelDifference>();
|
||||
split_fcn->threshold_ = threshold_;
|
||||
split_fcn->c0_ = c0_;
|
||||
split_fcn->c1_ = c1_;
|
||||
split_fcn->h0_ = h0_;
|
||||
split_fcn->h1_ = h1_;
|
||||
split_fcn->w0_ = w0_;
|
||||
split_fcn->w1_ = w1_;
|
||||
|
||||
return split_fcn;
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SplitFunctionPixelDifference> Generate(std::mt19937& rng, const SamplePtr sample) const {
|
||||
std::shared_ptr<SplitFunctionPixelDifference> split_fcn = std::make_shared<SplitFunctionPixelDifference>();
|
||||
|
||||
std::uniform_int_distribution<int> cdist(0, sample->channels()-1);
|
||||
split_fcn->c0_ = cdist(rng);
|
||||
split_fcn->c1_ = cdist(rng);
|
||||
|
||||
std::uniform_int_distribution<int> hdist(0, sample->height()-1);
|
||||
split_fcn->h0_ = hdist(rng);
|
||||
split_fcn->h1_ = hdist(rng);
|
||||
|
||||
std::uniform_int_distribution<int> wdist(0, sample->width()-1);
|
||||
split_fcn->w0_ = wdist(rng);
|
||||
split_fcn->w1_ = wdist(rng);
|
||||
|
||||
return split_fcn;
|
||||
}
|
||||
|
||||
virtual float Compute(SamplePtr sample) const {
|
||||
return (*sample)(c0_, h0_, w0_) - (*sample)(c1_, h1_, w1_);
|
||||
}
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
SplitFunction::Save(ar);
|
||||
ar << c0_;
|
||||
ar << c1_;
|
||||
ar << h0_;
|
||||
ar << h1_;
|
||||
ar << w0_;
|
||||
ar << w1_;
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
SplitFunction::Load(ar);
|
||||
|
||||
ar >> c0_;
|
||||
ar >> c1_;
|
||||
ar >> h0_;
|
||||
ar >> h1_;
|
||||
ar >> w0_;
|
||||
ar >> w1_;
|
||||
}
|
||||
|
||||
private:
|
||||
int c0_;
|
||||
int c1_;
|
||||
int h0_;
|
||||
int h1_;
|
||||
int w0_;
|
||||
int w1_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(SplitFunctionPixelDifference);
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <future>
|
||||
#include <functional>
|
||||
#include <stdexcept>
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
ThreadPool(size_t);
|
||||
template<class F, class... Args>
|
||||
auto enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>;
|
||||
~ThreadPool();
|
||||
|
||||
bool has_running_tasks() {
|
||||
std::unique_lock<std::mutex> lock(running_tasks_mutex);
|
||||
return n_running_tasks > 0;
|
||||
}
|
||||
|
||||
private:
|
||||
// need to keep track of threads so we can join them
|
||||
std::vector< std::thread > workers;
|
||||
// the task queue
|
||||
std::queue< std::function<void()> > tasks;
|
||||
|
||||
int n_running_tasks;
|
||||
|
||||
// synchronization
|
||||
std::mutex queue_mutex;
|
||||
std::mutex running_tasks_mutex;
|
||||
std::condition_variable condition;
|
||||
bool stop;
|
||||
};
|
||||
|
||||
// the constructor just launches some amount of workers
|
||||
inline ThreadPool::ThreadPool(size_t threads)
|
||||
: n_running_tasks(0), stop(false)
|
||||
{
|
||||
for(size_t i = 0;i<threads;++i)
|
||||
workers.emplace_back(
|
||||
[this]
|
||||
{
|
||||
for(;;)
|
||||
{
|
||||
std::function<void()> task;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(lock,
|
||||
[this]{ return this->stop || !this->tasks.empty(); });
|
||||
if(this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->running_tasks_mutex);
|
||||
n_running_tasks++;
|
||||
}
|
||||
task();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->running_tasks_mutex);
|
||||
n_running_tasks--;
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// add new work item to the pool
|
||||
template<class F, class... Args>
|
||||
auto ThreadPool::enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>
|
||||
{
|
||||
using return_type = typename std::result_of<F(Args...)>::type;
|
||||
|
||||
auto task = std::make_shared< std::packaged_task<return_type()> >(
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
|
||||
);
|
||||
|
||||
std::future<return_type> res = task->get_future();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
|
||||
// don't allow enqueueing after stopping the pool
|
||||
if(stop)
|
||||
throw std::runtime_error("enqueue on stopped ThreadPool");
|
||||
|
||||
tasks.emplace([task](){ (*task)(); });
|
||||
}
|
||||
condition.notify_one();
|
||||
return res;
|
||||
}
|
||||
|
||||
// the destructor joins all threads
|
||||
inline ThreadPool::~ThreadPool()
|
||||
{
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
condition.notify_all();
|
||||
for(std::thread &worker: workers)
|
||||
worker.join();
|
||||
}
|
||||
@@ -0,0 +1,423 @@
|
||||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
|
||||
#include "threadpool.h"
|
||||
|
||||
#include "forest.h"
|
||||
#include "spliteval.h"
|
||||
|
||||
|
||||
enum class TrainType : int {
|
||||
TRAIN = 0,
|
||||
RETRAIN = 1,
|
||||
RETRAIN_WITH_REPLACEMENT = 2
|
||||
};
|
||||
|
||||
struct TrainParameters {
|
||||
TrainType train_type;
|
||||
int n_trees;
|
||||
int max_tree_depth;
|
||||
int n_test_split_functions;
|
||||
int n_test_thresholds;
|
||||
int n_test_samples;
|
||||
int min_samples_to_split;
|
||||
int min_samples_for_leaf;
|
||||
int print_node_info;
|
||||
|
||||
TrainParameters() :
|
||||
train_type(TrainType::TRAIN),
|
||||
n_trees(5),
|
||||
max_tree_depth(7),
|
||||
n_test_split_functions(50),
|
||||
n_test_thresholds(10),
|
||||
n_test_samples(100),
|
||||
min_samples_to_split(14),
|
||||
min_samples_for_leaf(7),
|
||||
print_node_info(100)
|
||||
{}
|
||||
};
|
||||
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT, typename SplitEvaluatorT>
|
||||
class TrainForest {
|
||||
public:
|
||||
TrainForest(const TrainParameters& params, const std::shared_ptr<SplitFunctionT> gen_split_fcn, const std::shared_ptr<LeafFunctionT> gen_leaf_fcn, const std::shared_ptr<SplitEvaluatorT> split_eval, int n_threads, bool verbose)
|
||||
: params_(params), gen_split_fcn_(gen_split_fcn), gen_leaf_fcn_(gen_leaf_fcn), split_eval_(split_eval), n_threads(n_threads), verbose_(verbose) {
|
||||
|
||||
n_created_nodes_ = 0;
|
||||
n_max_nodes_ = 1;
|
||||
unsigned long n_nodes_d = 1;
|
||||
for(int depth = 0; depth < params.max_tree_depth; ++depth) {
|
||||
n_nodes_d *= 2;
|
||||
n_max_nodes_ += n_nodes_d;
|
||||
}
|
||||
n_max_nodes_ *= params.n_trees;
|
||||
}
|
||||
|
||||
virtual ~TrainForest() {}
|
||||
|
||||
virtual std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>>& old_forest) = 0;
|
||||
|
||||
protected:
|
||||
virtual void PrintParams() {
|
||||
if(verbose_){
|
||||
#pragma omp critical (TrainForest_train)
|
||||
{
|
||||
std::cout << "[TRAIN] training forest " << std::endl;
|
||||
std::cout << "[TRAIN] n_trees : " << params_.n_trees << std::endl;
|
||||
std::cout << "[TRAIN] max_tree_depth : " << params_.max_tree_depth << std::endl;
|
||||
std::cout << "[TRAIN] n_test_split_functions: " << params_.n_test_split_functions << std::endl;
|
||||
std::cout << "[TRAIN] n_test_thresholds : " << params_.n_test_thresholds << std::endl;
|
||||
std::cout << "[TRAIN] n_test_samples : " << params_.n_test_samples << std::endl;
|
||||
std::cout << "[TRAIN] min_samples_to_split : " << params_.min_samples_to_split << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual void UpdateNodeInfo(unsigned int depth, bool leaf) {
|
||||
if(verbose_) {
|
||||
n_created_nodes_ += 1;
|
||||
|
||||
if(leaf) {
|
||||
unsigned long n_nodes_d = 1;
|
||||
unsigned int n_remove_max_nodes = 0;
|
||||
for(int d = depth; d < params_.max_tree_depth; ++d) {
|
||||
n_nodes_d *= 2;
|
||||
n_remove_max_nodes += n_nodes_d;
|
||||
}
|
||||
n_max_nodes_ -= n_remove_max_nodes;
|
||||
}
|
||||
|
||||
if(n_created_nodes_ % params_.print_node_info == 0 || n_created_nodes_ == n_max_nodes_) {
|
||||
std::cout << "[Forest]"
|
||||
<< " created node number " << n_created_nodes_
|
||||
<< " @ depth " << depth
|
||||
<< ", max. " << n_max_nodes_ << " left"
|
||||
<< " => " << (double(n_created_nodes_) / double(n_max_nodes_))
|
||||
<< " done" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual void SampleData(const std::vector<TrainDatum>& all, std::vector<TrainDatum>& sampled, std::mt19937& rng) {
|
||||
unsigned int n = all.size();
|
||||
unsigned int k = params_.n_test_samples;
|
||||
k = n < k ? n : k;
|
||||
|
||||
std::set<int> indices;
|
||||
std::uniform_int_distribution<int> udist(0, all.size()-1);
|
||||
while(indices.size() < k) {
|
||||
int idx = udist(rng);
|
||||
indices.insert(idx);
|
||||
}
|
||||
|
||||
sampled.resize(k);
|
||||
int sidx = 0;
|
||||
for(int idx : indices) {
|
||||
sampled[sidx] = all[idx];
|
||||
sidx += 1;
|
||||
}
|
||||
}
|
||||
|
||||
virtual void Split(const std::shared_ptr<SplitFunctionT>& split_function, const std::vector<TrainDatum>& samples, std::vector<TrainDatum>& left, std::vector<TrainDatum>& right) {
|
||||
for(auto sample : samples) {
|
||||
if(split_function->Split(sample.sample)) {
|
||||
left.push_back(sample);
|
||||
}
|
||||
else {
|
||||
right.push_back(sample);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
virtual std::shared_ptr<SplitFunctionT> OptimizeSplitFunction(const std::vector<TrainDatum>& samples, int depth, std::mt19937& rng) {
|
||||
std::vector<TrainDatum> split_samples;
|
||||
SampleData(samples, split_samples, rng);
|
||||
|
||||
unsigned int min_samples_for_leaf = params_.min_samples_for_leaf;
|
||||
|
||||
float min_cost = std::numeric_limits<float>::max();
|
||||
std::shared_ptr<SplitFunctionT> best_split_fcn;
|
||||
float best_threshold = 0;
|
||||
|
||||
for(int split_fcn_idx = 0; split_fcn_idx < params_.n_test_split_functions; ++split_fcn_idx) {
|
||||
auto split_fcn = gen_split_fcn_->Generate(rng, samples[0].sample);
|
||||
|
||||
for(int threshold_idx = 0; threshold_idx < params_.n_test_thresholds; ++threshold_idx) {
|
||||
std::uniform_int_distribution<int> udist(0, split_samples.size()-1);
|
||||
int rand_split_sample_idx = udist(rng);
|
||||
float threshold = split_fcn->Compute(split_samples[rand_split_sample_idx].sample);
|
||||
split_fcn->set_threshold(threshold);
|
||||
|
||||
std::vector<TrainDatum> left;
|
||||
std::vector<TrainDatum> right;
|
||||
Split(split_fcn, split_samples, left, right);
|
||||
if(left.size() < min_samples_for_leaf || right.size() < min_samples_for_leaf) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// std::cout << "split done " << left.size() << "," << right.size() << std::endl;
|
||||
float split_cost = split_eval_->Eval(left, right, depth);
|
||||
// std::cout << ", " << split_cost << ", " << threshold << "; " << std::endl;
|
||||
|
||||
if(split_cost < min_cost) {
|
||||
min_cost = split_cost;
|
||||
best_split_fcn = split_fcn;
|
||||
best_threshold = threshold; //need theshold extra because of pointer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(best_split_fcn != nullptr) {
|
||||
best_split_fcn->set_threshold(best_threshold);
|
||||
}
|
||||
|
||||
return best_split_fcn;
|
||||
}
|
||||
|
||||
|
||||
virtual NodePtr CreateLeafNode(const std::vector<TrainDatum>& samples, unsigned int depth) {
|
||||
auto leaf_fct = gen_leaf_fcn_->Create(samples);
|
||||
auto node = std::make_shared<LeafNode<LeafFunctionT>>(leaf_fct);
|
||||
|
||||
UpdateNodeInfo(depth, true);
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
protected:
|
||||
const TrainParameters& params_;
|
||||
const std::shared_ptr<SplitFunctionT> gen_split_fcn_;
|
||||
const std::shared_ptr<LeafFunctionT> gen_leaf_fcn_;
|
||||
const std::shared_ptr<SplitEvaluatorT> split_eval_;
|
||||
int n_threads;
|
||||
bool verbose_;
|
||||
|
||||
unsigned long n_created_nodes_;
|
||||
unsigned long n_max_nodes_;
|
||||
};
|
||||
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT, typename SplitEvaluatorT>
|
||||
class TrainForestRecursive : public TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT> {
|
||||
public:
|
||||
TrainForestRecursive(const TrainParameters& params, const std::shared_ptr<SplitFunctionT> gen_split_fcn, const std::shared_ptr<LeafFunctionT> gen_leaf_fcn, const std::shared_ptr<SplitEvaluatorT> split_eval, int n_threads, bool verbose)
|
||||
: TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT>(params, gen_split_fcn, gen_leaf_fcn, split_eval, n_threads, verbose) {}
|
||||
|
||||
virtual ~TrainForestRecursive() {}
|
||||
|
||||
virtual std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>>& old_forest) {
|
||||
|
||||
this->PrintParams();
|
||||
|
||||
auto tim = std::chrono::system_clock::now();
|
||||
auto forest = std::make_shared<Forest<SplitFunctionT, LeafFunctionT>>();
|
||||
|
||||
omp_set_num_threads(this->n_threads);
|
||||
#pragma omp parallel for ordered
|
||||
for(size_t treeIdx = 0; treeIdx < this->params_.n_trees; ++treeIdx) {
|
||||
auto treetim = std::chrono::system_clock::now();
|
||||
|
||||
#pragma omp critical (TrainForest_train)
|
||||
{
|
||||
if(this->verbose_){
|
||||
std::cout << "[TRAIN][START] training tree " << treeIdx << " of " << this->params_.n_trees << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>> old_tree;
|
||||
if(old_forest != 0 && treeIdx < old_forest->trees_size()) {
|
||||
old_tree = old_forest->trees(treeIdx);
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 rng(rd());
|
||||
|
||||
auto tree = Train(samples, train_type, old_tree,rng);
|
||||
|
||||
#pragma omp critical (TrainForest_train)
|
||||
{
|
||||
forest->AddTree(tree);
|
||||
if(this->verbose_){
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(now - treetim);
|
||||
std::cout << "[TRAIN][FINISHED] training tree " << treeIdx << " of " << this->params_.n_trees << " - took " << (ms.count() * 1e-3) << "[s]" << std::endl;
|
||||
std::cout << "[TRAIN][FINISHED] " << (this->params_.n_trees - forest->trees_size()) << " left for training" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(this->verbose_){
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(now - tim);
|
||||
std::cout << "[TRAIN][FINISHED] training forest - took " << (ms.count() * 1e-3) << "[s]" << std::endl;
|
||||
}
|
||||
|
||||
return forest;
|
||||
}
|
||||
|
||||
private:
|
||||
virtual std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>>& old_tree, std::mt19937& rng) {
|
||||
NodePtr old_root;
|
||||
if(old_tree != nullptr) {
|
||||
old_root = old_tree->root();
|
||||
}
|
||||
|
||||
NodePtr root = Train(samples, train_type, old_root, 0, rng);
|
||||
return std::make_shared<Tree<SplitFunctionT, LeafFunctionT>>(root);
|
||||
}
|
||||
|
||||
virtual NodePtr Train(const std::vector<TrainDatum>& samples, TrainType train_type, const NodePtr& old_node, unsigned int depth, std::mt19937& rng) {
|
||||
|
||||
if(depth < this->params_.max_tree_depth && samples.size() > this->params_.min_samples_to_split) {
|
||||
std::shared_ptr<SplitFunctionT> best_split_fcn;
|
||||
bool was_split_node = false;
|
||||
if(old_node == nullptr || old_node->type() == LeafNode<LeafFunctionT>::TYPE) {
|
||||
best_split_fcn = this->OptimizeSplitFunction(samples, depth, rng);
|
||||
was_split_node = false;
|
||||
}
|
||||
else if(old_node->type() == SplitNode<SplitFunctionT, LeafFunctionT>::TYPE) {
|
||||
auto split_node = std::static_pointer_cast<SplitNode<SplitFunctionT, LeafFunctionT>>(old_node);
|
||||
best_split_fcn = split_node->split_fcn()->Copy();
|
||||
was_split_node = true;
|
||||
}
|
||||
|
||||
if(best_split_fcn == nullptr) {
|
||||
if(old_node == nullptr || train_type == TrainType::TRAIN || train_type == TrainType::RETRAIN_WITH_REPLACEMENT) {
|
||||
return this->CreateLeafNode(samples, depth);
|
||||
}
|
||||
else if(train_type == TrainType::RETRAIN) {
|
||||
return old_node->Copy();
|
||||
}
|
||||
else {
|
||||
std::cout << "[ERROR] unknown train type" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// (1) split samples
|
||||
std::vector<TrainDatum> leftsamples, rightsamples;
|
||||
this->Split(best_split_fcn, samples, leftsamples, rightsamples);
|
||||
|
||||
//output node information
|
||||
this->UpdateNodeInfo(depth, false);
|
||||
|
||||
//create split node - recursively train the siblings
|
||||
if(was_split_node) {
|
||||
auto split_node = std::static_pointer_cast<SplitNode<SplitFunctionT, LeafFunctionT>>(old_node);
|
||||
NodePtr left = this->Train(leftsamples, train_type, split_node->left(), depth + 1, rng);
|
||||
NodePtr right = this->Train(rightsamples, train_type, split_node->right(), depth + 1, rng);
|
||||
auto new_node = std::make_shared<SplitNode<SplitFunctionT, LeafFunctionT>>(left, right, best_split_fcn);
|
||||
return new_node;
|
||||
}
|
||||
else {
|
||||
NodePtr left = this->Train(leftsamples, train_type, nullptr, depth + 1, rng);
|
||||
NodePtr right = this->Train(rightsamples, train_type, nullptr, depth + 1, rng);
|
||||
auto new_node = std::make_shared<SplitNode<SplitFunctionT, LeafFunctionT>>(left, right, best_split_fcn);
|
||||
return new_node;
|
||||
}
|
||||
} // if samples < min_samples || depth >= max_depth then make leaf node
|
||||
else {
|
||||
if(old_node == 0 || train_type == TrainType::TRAIN || train_type == TrainType::RETRAIN_WITH_REPLACEMENT) {
|
||||
return this->CreateLeafNode(samples, depth);
|
||||
}
|
||||
else if(train_type == TrainType::RETRAIN) {
|
||||
return old_node->Copy();
|
||||
}
|
||||
else {
|
||||
std::cout << "[ERROR] unknown train type" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
struct QueueTuple {
|
||||
int depth;
|
||||
std::vector<TrainDatum> train_data;
|
||||
NodePtr* parent;
|
||||
|
||||
QueueTuple() : depth(-1), train_data(), parent(nullptr) {}
|
||||
QueueTuple(int depth, std::vector<TrainDatum> train_data, NodePtr* parent) :
|
||||
depth(depth), train_data(train_data), parent(parent) {}
|
||||
};
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT, typename SplitEvaluatorT>
|
||||
class TrainForestQueued : public TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT> {
|
||||
public:
|
||||
TrainForestQueued(const TrainParameters& params, const std::shared_ptr<SplitFunctionT> gen_split_fcn, const std::shared_ptr<LeafFunctionT> gen_leaf_fcn, const std::shared_ptr<SplitEvaluatorT> split_eval, int n_threads, bool verbose)
|
||||
: TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT>(params, gen_split_fcn, gen_leaf_fcn, split_eval, n_threads, verbose) {}
|
||||
|
||||
virtual ~TrainForestQueued() {}
|
||||
|
||||
virtual std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>>& old_forest) {
|
||||
this->PrintParams();
|
||||
|
||||
auto tim = std::chrono::system_clock::now();
|
||||
auto forest = std::make_shared<Forest<SplitFunctionT, LeafFunctionT>>();
|
||||
|
||||
std::cout << "[TRAIN] create pool with " << this->n_threads << " threads" << std::endl;
|
||||
auto pool = std::make_shared<ThreadPool>(this->n_threads);
|
||||
for(int treeidx = 0; treeidx < this->params_.n_trees; ++treeidx) {
|
||||
auto tree = std::make_shared<Tree<SplitFunctionT, LeafFunctionT>>();
|
||||
forest->AddTree(tree);
|
||||
AddJob(pool, QueueTuple(0, samples, &(tree->root_)));
|
||||
}
|
||||
|
||||
while(pool->has_running_tasks()) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
|
||||
if(this->verbose_){
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(now - tim);
|
||||
std::cout << "[TRAIN][FINISHED] training forest - took " << (ms.count() * 1e-3) << "[s]" << std::endl;
|
||||
}
|
||||
|
||||
return forest;
|
||||
}
|
||||
|
||||
private:
|
||||
virtual void AddJob(std::shared_ptr<ThreadPool> pool, QueueTuple data) {
|
||||
pool->enqueue([this](std::shared_ptr<ThreadPool> pool, QueueTuple data) {
|
||||
std::random_device rd;
|
||||
std::mt19937 rng(rd());
|
||||
|
||||
std::shared_ptr<SplitFunctionT> best_split_fcn = nullptr;
|
||||
|
||||
if(data.depth < this->params_.max_tree_depth && int(data.train_data.size()) > this->params_.min_samples_to_split) {
|
||||
best_split_fcn = this->OptimizeSplitFunction(data.train_data, data.depth, rng);
|
||||
}
|
||||
|
||||
if(best_split_fcn == nullptr) {
|
||||
auto node = this->CreateLeafNode(data.train_data, data.depth);
|
||||
*(data.parent) = node;
|
||||
}
|
||||
else {
|
||||
this->UpdateNodeInfo(data.depth, false);
|
||||
auto node = std::make_shared<SplitNode<SplitFunctionT, LeafFunctionT>>();
|
||||
node->split_fcn_ = best_split_fcn;
|
||||
*(data.parent) = node;
|
||||
|
||||
QueueTuple left;
|
||||
QueueTuple right;
|
||||
this->Split(best_split_fcn, data.train_data, left.train_data, right.train_data);
|
||||
|
||||
left.depth = data.depth + 1;
|
||||
right.depth = data.depth + 1;
|
||||
|
||||
left.parent = &(node->left_);
|
||||
right.parent = &(node->right_);
|
||||
|
||||
this->AddJob(pool, left);
|
||||
this->AddJob(pool, right);
|
||||
}
|
||||
}, pool, data);
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,55 @@
|
||||
#pragma once
|
||||
|
||||
#include "node.h"
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT>
|
||||
class Tree {
|
||||
public:
|
||||
Tree() : root_(nullptr) {}
|
||||
Tree(NodePtr root) : root_(root) {}
|
||||
|
||||
virtual ~Tree() {}
|
||||
|
||||
std::shared_ptr<LeafFunctionT> inference(const SamplePtr sample) const {
|
||||
if(root_ == nullptr) {
|
||||
std::cout << "[ERROR] tree inference root node is NULL";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
NodePtr node = root_;
|
||||
while(node->type() == SplitNode<SplitFunctionT, LeafFunctionT>::TYPE) {
|
||||
auto splitNode = std::static_pointer_cast<SplitNode<SplitFunctionT, LeafFunctionT>>(node);
|
||||
bool left = splitNode->Split(sample);
|
||||
if(left) {
|
||||
node = splitNode->left();
|
||||
}
|
||||
else {
|
||||
node = splitNode->right();
|
||||
}
|
||||
}
|
||||
|
||||
auto leaf_node = std::static_pointer_cast<LeafNode<LeafFunctionT>>(node);
|
||||
return leaf_node->leaf_node_fcn();
|
||||
}
|
||||
|
||||
NodePtr root() const { return root_; }
|
||||
void set_root(NodePtr root) { root_ = root; }
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
int type = root_->type();
|
||||
ar << type;
|
||||
root_->Save(ar);
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
int type;
|
||||
ar >> type;
|
||||
root_ = MakeNode<SplitFunctionT, LeafFunctionT>(type);
|
||||
root_->Load(ar);
|
||||
}
|
||||
|
||||
|
||||
public:
|
||||
NodePtr root_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user