#pragma once #include #include "serialization.h" #include "leaffcn.h" #include "splitfcn.h" class Node { public: Node() {} virtual ~Node() {} virtual std::shared_ptr Copy() const = 0; virtual int type() const = 0; virtual void Save(SerializationOut& ar) const = 0; virtual void Load(SerializationIn& ar) = 0; }; typedef std::shared_ptr NodePtr; template class LeafNode : public Node { public: static const int TYPE = 0; LeafNode() {} LeafNode(std::shared_ptr leaf_node_fcn) : leaf_node_fcn_(leaf_node_fcn) {} virtual ~LeafNode() {} virtual NodePtr Copy() const { auto node = std::make_shared(); node->leaf_node_fcn_ = leaf_node_fcn_->Copy(); return node; } virtual void Save(SerializationOut& ar) const { leaf_node_fcn_->Save(ar); } virtual void Load(SerializationIn& ar) { leaf_node_fcn_ = std::make_shared(); leaf_node_fcn_->Load(ar); } virtual int type() const { return TYPE; }; std::shared_ptr leaf_node_fcn() const { return leaf_node_fcn_; } private: std::shared_ptr leaf_node_fcn_; DISABLE_COPY_AND_ASSIGN(LeafNode); }; template class SplitNode : public Node { public: static const int TYPE = 1; SplitNode() {} SplitNode(NodePtr left, NodePtr right, std::shared_ptr split_fcn) : left_(left), right_(right), split_fcn_(split_fcn) {} virtual ~SplitNode() {} virtual std::shared_ptr Copy() const { std::shared_ptr node = std::make_shared(); node->left_ = left_->Copy(); node->right_ = right_->Copy(); node->split_fcn_ = split_fcn_->Copy(); return node; } bool Split(SamplePtr sample) { return split_fcn_->Split(sample); } virtual void Save(SerializationOut& ar) const { split_fcn_->Save(ar); //left int type = left_->type(); ar << type; left_->Save(ar); //right type = right_->type(); ar << type; right_->Save(ar); } virtual void Load(SerializationIn& ar); virtual int type() const { return TYPE; } NodePtr left() const { return left_; } NodePtr right() const { return right_; } std::shared_ptr split_fcn() const { return split_fcn_; } void set_left(NodePtr left) { left_ = left; } void set_right(NodePtr right) { right_ = right; } void set_split_fcn(std::shared_ptr split_fcn) { split_fcn_ = split_fcn; } public: NodePtr left_; NodePtr right_; std::shared_ptr split_fcn_; DISABLE_COPY_AND_ASSIGN(SplitNode); }; template NodePtr MakeNode(int type) { NodePtr node; if(type == LeafNode::TYPE) { node = std::make_shared>(); } else if(type == SplitNode::TYPE) { node = std::make_shared>(); } else { std::cout << "[ERROR] unknown node type" << std::endl; exit(-1); } return node; } template void SplitNode::Load(SerializationIn& ar) { split_fcn_ = std::make_shared(); split_fcn_->Load(ar); //left int left_type; ar >> left_type; left_ = MakeNode(left_type); left_->Load(ar); //right int right_type; ar >> right_type; right_ = MakeNode(right_type); right_->Load(ar); }