Skip to content

Commit

Permalink
Fixed scorer_ not visible to python
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Apr 16, 2024
1 parent 7976fe2 commit 658b68d
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/bindings/bind_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ void bind_params(py::module& m)
.def_property("max_stall", &Brush::Parameters::get_max_stall, &Brush::Parameters::set_max_stall)
.def_property("max_time", &Brush::Parameters::get_max_time, &Brush::Parameters::set_max_time)
.def_property("current_gen", &Brush::Parameters::get_current_gen, &Brush::Parameters::set_current_gen)
.def_property("scorer_", &Brush::Parameters::get_scorer_, &Brush::Parameters::set_scorer_)
.def_property("load_population", &Brush::Parameters::get_load_population, &Brush::Parameters::set_load_population)
.def_property("save_population", &Brush::Parameters::get_save_population, &Brush::Parameters::set_save_population)
.def_property("num_islands", &Brush::Parameters::get_num_islands, &Brush::Parameters::set_num_islands)
Expand Down
3 changes: 3 additions & 0 deletions src/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ struct Parameters
void set_max_time(int new_max_time){ max_time = new_max_time; };
int get_max_time(){ return max_time; };

void set_scorer_(string new_scorer_){ scorer_ = new_scorer_; };
string get_scorer_(){ return scorer_; };

void set_load_population(string new_load_population){ load_population = new_load_population; };
string get_load_population(){ return load_population; };

Expand Down
10 changes: 7 additions & 3 deletions src/program/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,16 @@ auto Node::get_name(bool include_weight) const noexcept -> std::string
else if (Is<NodeType::MeanLabel>(node_type))
{
if (include_weight)
return fmt::format("{:.2f}", W); // Handle as if it was a constant
//explicitly print as a MeanLabel and include weight on label
return fmt::format("MeanLabel({:.2f})", W);
return fmt::format("{:.2f}*{}", W, feature);

return feature;
}
else if (Is<NodeType::OffsetSum>(node_type)){
return fmt::format("{}+Sum", W);
}
else if (is_weighted && include_weight)
return fmt::format("{:.2f}*{}",W,name);

return name;
}

Expand Down
9 changes: 5 additions & 4 deletions src/program/tree_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ string TreeNode::get_tree_model(bool pretty, string offset) const
if (sib != nullptr)
child_outputs += "\n";
}
/* if (pretty) */
/* return op_name + child_outputs; */
/* else */

return data.get_name() + child_outputs;
};
////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -163,7 +161,10 @@ int TreeNode::get_complexity() const

// include the `w` and `*` if the node is weighted (and it is not a constant or mean label)
if (data.get_is_weighted()
&& (Is<NodeType::Constant>(data.node_type) || Is<NodeType::MeanLabel>(data.node_type)) )
&& (Is<NodeType::Constant>(data.node_type)
|| Is<NodeType::MeanLabel>(data.node_type)
|| Is<NodeType::OffsetSum>(data.node_type)) )

return operator_complexities.at(NodeType::Mul)*(
operator_complexities.at(NodeType::Constant) +
node_complexity*(children_complexity_sum)
Expand Down

0 comments on commit 658b68d

Please sign in to comment.