Skip to content

Commit

Permalink
Option to print the search space or get a string with the output
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Jul 8, 2024
1 parent c102bb6 commit 03476d8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
1 change: 1 addition & 0 deletions src/bindings/bind_search_space.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ void bind_search_space(py::module &m)
py::arg("data"),
py::arg("user_ops"),
py::arg("weights_init") = true )
.def("__repr__", &br::SearchSpace::repr, "Representation for debugging the SearchSpace object")
.def("make_regressor", &br::SearchSpace::make_regressor,
py::arg("max_d") = 0,
py::arg("max_size") = 0,
Expand Down
50 changes: 28 additions & 22 deletions src/vary/search_space.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,28 @@ struct SearchSpace
/// @brief prints the search space map.
void print() const;

/// @brief returns a string with a json representation of the search space map
std::string repr() const {
string output = "=== Search space ===\n";
output += fmt::format("terminal_map: {}\n", this->terminal_map);
output += fmt::format("terminal_weights: {}\n", this->terminal_weights);

for (const auto& [ret_type, v] : this->node_map) {
for (const auto& [args_type, v2] : v) {
for (const auto& [node_type, node] : v2) {
output += fmt::format("node_map[{}][{}][{}] = {}, weight = {}\n",
ret_type,
ArgsName[args_type],
node_type,
node,
this->node_map_weights.at(ret_type).at(args_type).at(node_type)
);
}
}
}
return output;
};

private:
tree<Node>& PTC2(tree<Node>& Tree, tree<Node>::iterator root, int max_d, int max_size) const;

Expand Down Expand Up @@ -757,28 +779,12 @@ extern SearchSpace SS;
} // Brush

// format overload
template <> struct fmt::formatter<Brush::SearchSpace>: formatter<string_view> {
template <typename FormatContext>
auto format(const Brush::SearchSpace& SS, FormatContext& ctx) const {
string output = "Search Space\n===\n";
output += fmt::format("terminal_map: {}\n", SS.terminal_map);
output += fmt::format("terminal_weights: {}\n", SS.terminal_weights);

for (const auto& [ret_type, v] : SS.node_map) {
for (const auto& [args_type, v2] : v) {
for (const auto& [node_type, node] : v2) {
output += fmt::format("node_map[{}][{}][{}] = {}, weight = {}\n",
ret_type,
ArgsName[args_type],
node_type,
node,
SS.node_map_weights.at(ret_type).at(args_type).at(node_type)
);
}
}
template <>
struct fmt::formatter<Brush::SearchSpace>: formatter<string_view> {
template <typename FormatContext>
auto format(const Brush::SearchSpace& SS, FormatContext& ctx) const {
string output = SS.repr();
return formatter<string_view>::format(output, ctx);
}
output += "===";
return formatter<string_view>::format(output, ctx);
}
};
#endif

0 comments on commit 03476d8

Please sign in to comment.