diff --git a/src/program/node.cpp b/src/program/node.cpp index ff57ff69..4c55dc81 100644 --- a/src/program/node.cpp +++ b/src/program/node.cpp @@ -241,6 +241,10 @@ void from_json(const json &j, Node& p) // j.at("feature").get_to(p.feature); p.set_feature(j.at("feature")); } + if (j.contains("feature_type")) + { + p.set_feature_type(j.at("feature_type")); + } // if node has a ret_type and arg_types, get them. if not we need to make // a signature diff --git a/src/program/node.h b/src/program/node.h index 8fcad729..5ab747d4 100644 --- a/src/program/node.h +++ b/src/program/node.h @@ -248,6 +248,9 @@ struct Node { inline void set_feature(string f){ feature = f; }; inline string get_feature() const { return feature; }; + + inline void set_feature_type(DataType ft){ feature_type = ft; }; + inline DataType get_feature_type() const { return feature_type; }; inline bool get_is_weighted() const {return this->is_weighted;}; inline void set_is_weighted(bool is_weighted){ @@ -257,9 +260,12 @@ struct Node { }; private: - /// @brief feature name for terminals or splitting nodes string feature; + + /// @brief feature type for terminals or splitting nodes + DataType feature_type; + }; template diff --git a/src/program/split.h b/src/program/split.h index 9b937ea8..9249f02f 100644 --- a/src/program/split.h +++ b/src/program/split.h @@ -238,6 +238,16 @@ struct Operator(d[feature])) + tn.data.set_feature_type(DataType::ArrayB); + else if (std::holds_alternative(d[feature])) + tn.data.set_feature_type(DataType::ArrayI); + else if (std::holds_alternative(d[feature])) + tn.data.set_feature_type(DataType::ArrayF); + else + HANDLE_ERROR_THROW("Unknown feature type in data\n"); } return predict(d, tn); diff --git a/src/program/tree_node.cpp b/src/program/tree_node.cpp index 1d2f89a5..7ab72569 100644 --- a/src/program/tree_node.cpp +++ b/src/program/tree_node.cpp @@ -39,7 +39,7 @@ string TreeNode::get_tree_model(bool pretty, string offset) const } if (Is(data.node_type)){ - if (data.arg_types.at(0) == DataType::ArrayB) + if (data.get_feature_type() == DataType::ArrayB) return fmt::format("If({})", data.get_feature()) + child_outputs; return fmt::format("If({}>{:.2f})", data.get_feature(), data.W) +