From 658b68d70a64ed8058b3584355be44b868adc19d Mon Sep 17 00:00:00 2001 From: gAldeia Date: Tue, 16 Apr 2024 14:58:42 -0300 Subject: [PATCH] Fixed scorer_ not visible to python --- src/bindings/bind_params.cpp | 1 + src/params.h | 3 +++ src/program/node.cpp | 10 +++++++--- src/program/tree_node.cpp | 9 +++++---- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/bindings/bind_params.cpp b/src/bindings/bind_params.cpp index 77e67a30..8c0c080e 100644 --- a/src/bindings/bind_params.cpp +++ b/src/bindings/bind_params.cpp @@ -18,6 +18,7 @@ void bind_params(py::module& m) .def_property("max_stall", &Brush::Parameters::get_max_stall, &Brush::Parameters::set_max_stall) .def_property("max_time", &Brush::Parameters::get_max_time, &Brush::Parameters::set_max_time) .def_property("current_gen", &Brush::Parameters::get_current_gen, &Brush::Parameters::set_current_gen) + .def_property("scorer_", &Brush::Parameters::get_scorer_, &Brush::Parameters::set_scorer_) .def_property("load_population", &Brush::Parameters::get_load_population, &Brush::Parameters::set_load_population) .def_property("save_population", &Brush::Parameters::get_save_population, &Brush::Parameters::set_save_population) .def_property("num_islands", &Brush::Parameters::get_num_islands, &Brush::Parameters::set_num_islands) diff --git a/src/params.h b/src/params.h index f46c73a8..749df87b 100644 --- a/src/params.h +++ b/src/params.h @@ -96,6 +96,9 @@ struct Parameters void set_max_time(int new_max_time){ max_time = new_max_time; }; int get_max_time(){ return max_time; }; + void set_scorer_(string new_scorer_){ scorer_ = new_scorer_; }; + string get_scorer_(){ return scorer_; }; + void set_load_population(string new_load_population){ load_population = new_load_population; }; string get_load_population(){ return load_population; }; diff --git a/src/program/node.cpp b/src/program/node.cpp index bb9dc351..23a7be82 100644 --- a/src/program/node.cpp +++ b/src/program/node.cpp @@ -34,12 +34,16 @@ auto Node::get_name(bool include_weight) const noexcept -> std::string else if (Is(node_type)) { if (include_weight) - return fmt::format("{:.2f}", W); // Handle as if it was a constant - //explicitly print as a MeanLabel and include weight on label - return fmt::format("MeanLabel({:.2f})", W); + return fmt::format("{:.2f}*{}", W, feature); + + return feature; + } + else if (Is(node_type)){ + return fmt::format("{}+Sum", W); } else if (is_weighted && include_weight) return fmt::format("{:.2f}*{}",W,name); + return name; } diff --git a/src/program/tree_node.cpp b/src/program/tree_node.cpp index cd04fbee..c1695122 100644 --- a/src/program/tree_node.cpp +++ b/src/program/tree_node.cpp @@ -37,9 +37,7 @@ string TreeNode::get_tree_model(bool pretty, string offset) const if (sib != nullptr) child_outputs += "\n"; } - /* if (pretty) */ - /* return op_name + child_outputs; */ - /* else */ + return data.get_name() + child_outputs; }; //////////////////////////////////////////////////////////////////////////////// @@ -163,7 +161,10 @@ int TreeNode::get_complexity() const // include the `w` and `*` if the node is weighted (and it is not a constant or mean label) if (data.get_is_weighted() - && (Is(data.node_type) || Is(data.node_type)) ) + && (Is(data.node_type) + || Is(data.node_type) + || Is(data.node_type)) ) + return operator_complexities.at(NodeType::Mul)*( operator_complexities.at(NodeType::Constant) + node_complexity*(children_complexity_sum)