Skip to content

Commit

Permalink
Improving display of split best when splitting on boolean features
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Nov 7, 2024
1 parent 10af3c1 commit 2fd7c39
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/program/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ void from_json(const json &j, Node& p)
// j.at("feature").get_to(p.feature);
p.set_feature(j.at("feature"));
}
if (j.contains("feature_type"))
{
p.set_feature_type(j.at("feature_type"));
}

// if node has a ret_type and arg_types, get them. if not we need to make
// a signature
Expand Down
8 changes: 7 additions & 1 deletion src/program/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ struct Node {

/// @brief Store the feature name associated with this node.
/// @param f feature name to keep; it is copied into the node.
inline void set_feature(string f)
{
    this->feature = f;
}
/// @brief Read back the feature name stored on this node.
/// @return a copy of the stored feature name.
inline string get_feature() const
{
    return this->feature;
}

/// @brief Store the data type of the feature associated with this node.
/// @param ft data type to record for the feature.
inline void set_feature_type(DataType ft)
{
    this->feature_type = ft;
}
/// @brief Read back the feature data type stored on this node.
/// @return the stored feature data type.
inline DataType get_feature_type() const
{
    return this->feature_type;
}

inline bool get_is_weighted() const {return this->is_weighted;};
inline void set_is_weighted(bool is_weighted){
Expand All @@ -257,9 +260,12 @@ struct Node {
};

private:

/// @brief feature name for terminals or splitting nodes
string feature;

/// @brief feature type for terminals or splitting nodes
DataType feature_type;

};

template <NodeType... T>
Expand Down
10 changes: 10 additions & 0 deletions src/program/split.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,16 @@ struct Operator<NT, S, Fit, enable_if_t<is_in_v<NT, NodeType::SplitOn, NodeType:
string feature = "";
tie(feature, threshold) = Split::get_best_variable_and_threshold(d, tn);
tn.data.set_feature(feature);

// TODO: improve this. Im interested only in ArrayB, maybe I could simplify
if (std::holds_alternative<ArrayXf>(d[feature]))
tn.data.set_feature_type(DataType::ArrayB);
else if (std::holds_alternative<ArrayXi>(d[feature]))
tn.data.set_feature_type(DataType::ArrayI);
else if (std::holds_alternative<ArrayXf>(d[feature]))
tn.data.set_feature_type(DataType::ArrayF);
else
HANDLE_ERROR_THROW("Unknown feature type in data\n");
}

return predict(d, tn);
Expand Down
2 changes: 1 addition & 1 deletion src/program/tree_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ string TreeNode::get_tree_model(bool pretty, string offset) const
}

if (Is<NodeType::SplitBest>(data.node_type)){
if (data.arg_types.at(0) == DataType::ArrayB)
if (data.get_feature_type() == DataType::ArrayB)
return fmt::format("If({})", data.get_feature()) + child_outputs;

return fmt::format("If({}>{:.2f})", data.get_feature(), data.W) +
Expand Down

0 comments on commit 2fd7c39

Please sign in to comment.