Skip to content

Commit

Permalink
Using bandits to sample from the search space directly
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Sep 25, 2024
1 parent f6f4480 commit 96cf141
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 41 deletions.
4 changes: 1 addition & 3 deletions src/vary/search_space.h
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,8 @@ struct SearchSpace

vector<Node> matches;
vector<float> weights;
for (const auto& kv: ret_match)
for (const auto& [arg_hash, node_type_map]: ret_match)
{
auto arg_hash = kv.first;
auto node_type_map = kv.second;
if (node_type_map.find(type) != node_type_map.end())
{
matches.push_back(node_type_map.at(type));
Expand Down
49 changes: 20 additions & 29 deletions src/vary/variation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class PointMutation : public MutationBase
{
// get_node_like will sample a similar node based on node_map_weights or
// terminal_weights, and maybe will return a Node.
optional<Node> newNode = variator.search_space.get_node_like(spot.node->data);
auto context = variator.get_context(Tree, spot);
optional<Node> newNode = variator.bandit_get_node_like(spot.node->data, context);

if (!newNode) // overload to check if newNode == nullopt
return false;
Expand Down Expand Up @@ -89,8 +90,9 @@ class InsertMutation : public MutationBase
// size restriction, which will be relaxed here (just as it is in the PTC2
// algorithm). This mutation can create a new expression that exceeds the
// maximum size by the highest arity among the operators.
std::optional<Node> n = variator.search_space.sample_op_with_arg(
spot_type, spot_type, true, params.max_size-Tree.size()-1);
auto context = variator.get_context(Tree, spot);
std::optional<Node> n = variator.bandit_sample_op_with_arg(
spot_type, spot_type, context, params.max_size-Tree.size()-1);

if (!n) // there is no operator with compatible arguments
return false;
Expand Down Expand Up @@ -692,12 +694,12 @@ void Variation<T>::vary(Population<T>& pop, int island,
}
else
{
std::cout << "Performing mutation " << std::endl;
// std::cout << "Performing mutation " << std::endl;
auto variation_result = mutate(mom);
cout << "finished mutation" << endl;
// cout << "finished mutation" << endl;
ind_parents = {mom};
tie(opt, context) = variation_result;
cout << "unpacked" << endl;
// cout << "unpacked" << endl;
}

// this assumes that islands do not share indexes before doing variation
Expand Down Expand Up @@ -861,29 +863,18 @@ void Variation<T>::update_ss(Population<T>& pop)
}

// operators: getting new probabilities for op nodes
for (auto& bandit : op_bandits) {
auto ret_type = bandit.first;

auto op_probs = bandit.second.sample_probs(true);
for (auto& op : op_probs) {

auto op_name = op.first;
auto op_prob = op.second;

// Search for the index that matches the op name (for all different arguments)
for (const auto& [args_type, node_map] : search_space.node_map.at(ret_type))
{
auto it = std::find_if(
node_map.begin(),
node_map.end(),
[&](auto& entry) { // entry is a pair of index and node
return entry.second.name == op_name; });

if (it != node_map.end()) {
auto index = it->first;

// Update the node_map_weights with the second value
search_space.node_map_weights.at(ret_type).at(args_type).at(index) = op_prob;
for (auto& [ret_type, bandit_map] : op_bandits) {
for (auto& [args_type, bandit] : bandit_map) {
auto op_probs = bandit.sample_probs(true);

for (auto& [op_name, op_prob] : op_probs) {

for (const auto& [node_type, node_value]: search_space.node_map.at(ret_type).at(args_type))
{
// std::cout << " - Node name: " << node_value.name << std::endl;
if (node_value.name == op_name) {
search_space.node_map_weights.at(ret_type).at(args_type).at(node_type) = op_prob;
}
}
}
}
Expand Down
147 changes: 138 additions & 9 deletions src/vary/variation.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,16 @@ class Variation {
// learn only that
for (auto& [ret_type, arg_w_map]: search_space.node_map)
{
// if (op_bandits.find(ret_type) == op_bandits.end())
// op_bandits.at(ret_type) = map<size_t, Bandit<string>>();

// TODO: this could be made much easier using user_ops
map<string, float> node_probs;
for (const auto& [args_type, node_map] : arg_w_map)
{
// if (op_bandits.at(ret_type).find(args_type) != op_bandits.at(ret_type).end())
// continue

for (const auto& [node_type, node]: node_map)
{
auto weight = search_space.node_map_weights.at(ret_type).at(args_type).at(node_type);
Expand All @@ -127,8 +132,8 @@ class Variation {
// it->second += weight;
}
}
op_bandits[ret_type][args_type] = Bandit<string>(parameters.bandit, node_probs);
}
op_bandits[ret_type] = Bandit<string>(parameters.bandit, node_probs);
}
};

Expand Down Expand Up @@ -319,21 +324,23 @@ class Variation {
this->variation_bandit.update(choice, 0.0, context);
}

if (ind.get_variation() != "born" && ind.get_variation() != "cx")
if (ind.get_variation() != "born" && ind.get_variation() != "cx"
&& ind.get_variation() != "subtree")
{
if (ind.get_sampled_nodes().size() > 0) {
const auto& changed_nodes = ind.get_sampled_nodes();
for (const auto& node : changed_nodes) {
for (auto& node : changed_nodes) {
if (node.get_arg_count() == 0) {
auto datatype = node.get_ret_type();
// std::cout << "Updating terminal bandit for node: " << node.get_feature() << std::endl;
// std::cout << "Updating terminal bandit for node: " << node.name << std::endl;
this->terminal_bandits[datatype].update(node.get_feature(), r, context);
}
else {
auto ret_type = node.get_ret_type();
auto args_type = node.args_type();
auto name = node.name;
// std::cout << "Updating operator bandit for node: " << name << std::endl;
this->op_bandits[ret_type].update(name, r, context);
this->op_bandits[ret_type][args_type].update(name, r, context);
}
}
}
Expand All @@ -348,29 +355,151 @@ class Variation {
// bandit_sample_terminal
std::optional<Node> bandit_sample_terminal(DataType R, VectorXf& context)
{
if (terminal_bandits.find(R) == terminal_bandits.end())
// std::cout << "bandit_sample_terminal called with DataType: " << std::endl;

if (terminal_bandits.find(R) == terminal_bandits.end()) {
// std::cout << "No bandit found for DataType: " << std::endl;
return std::nullopt;
}

auto& bandit = terminal_bandits.at(R);
string terminal_name = bandit.choose(context);
// std::cout << "Bandit chose terminal name: " << terminal_name << std::endl;

auto it = std::find_if(
search_space.terminal_map.at(R).begin(),
search_space.terminal_map.at(R).end(),
[&](auto& node) { return node.get_feature() == terminal_name; });

if (it != search_space.terminal_map.at(R).end()) {
auto index = std::distance(search_space.terminal_map.at(R).begin(), it);
// std::cout << "Terminal found at index: " << index << std::endl;
return search_space.terminal_map.at(R).at(index);
}

// std::cout << "Terminal not found for name: " << terminal_name << std::endl;
return std::nullopt;
};

// bandit_get_node_like
std::optional<Node> bandit_get_node_like(Node node, VectorXf& context)
{
// std::cout << "bandit_get_node_like called with node: " << node.name << std::endl;

// TODO: use search_space.terminal_types here (and in search_space get_node_like as well)
if (Is<NodeType::Terminal, NodeType::Constant, NodeType::MeanLabel>(node.node_type)){
// std::cout << "Node is of type Terminal, Constant, or MeanLabel" << std::endl;
return bandit_sample_terminal(node.ret_type, context);
}

if (op_bandits.find(node.ret_type) == op_bandits.end()) {
// std::cout << "No bandit found for return type: " << std::endl;
return std::nullopt;
}
if (op_bandits.at(node.ret_type).find(node.args_type()) == op_bandits.at(node.ret_type).end()) {
// std::cout << "No bandit found for arg type: " << std::endl;
return std::nullopt;
}

auto& bandit = op_bandits[node.ret_type][node.args_type()];
string node_name = bandit.choose(context);
// std::cout << "Bandit chose node name: " << node_name << std::endl;

auto entries = search_space.node_map[node.ret_type][node.args_type()];
// std::cout << "Ret match size: " << entries.size() << std::endl;

for (const auto& [node_type, node_value]: entries)
{
// std::cout << " - Node name: " << node_value.name << std::endl;
if (node_value.name == node_name) {
// std::cout << "Node name match: " << node_value.name << std::endl;
return node_value;
}
}

return std::nullopt;
};

// bandit_sample_op_with_arg
std::optional<Node> bandit_sample_op_with_arg(DataType ret, DataType arg,
VectorXf& context, int max_args=0)
{
auto args_map = search_space.node_map.at(ret);
vector<size_t> matches;
matches.resize(0);

for (const auto& [args_type, name_map]: args_map) {
for (const auto& [name, node]: name_map) {
auto node_arg_types = node.get_arg_types();

auto within_size_limit = !(max_args) || (node.get_arg_count() <= max_args);

if (in(node_arg_types, arg) && within_size_limit) {
bool compatible = true;
for (const auto& arg_type: node_arg_types) {
if (arg_type != arg) {
if ( ! in(search_space.terminal_types, arg_type) ) {
compatible = false;
break;
}
}
}
if (! compatible)
continue;

matches.push_back(node.args_type());
}
}
}

if (matches.size()==0)
return std::nullopt;

// any bandit to do the job
auto args_type = *r.select_randomly(matches.begin(),
matches.end() );
auto& bandit = op_bandits[ret][args_type];
string node_name = bandit.choose(context);

auto entries = search_space.node_map[ret][args_type];
for (const auto& [node_type, node_value]: entries)
{
if (node_value.name == node_name) {
return node_value;
}
}

return std::nullopt;
};

// bandit_sample_op
// bandit_sample_subtree
std::optional<Node> bandit_sample_op(DataType ret, VectorXf& context)
{
if (search_space.node_map.find(ret) == search_space.node_map.end())
return std::nullopt;

// any bandit to do the job
auto& [args_type, bandit] = *r.select_randomly(op_bandits[ret].begin(),
op_bandits[ret].end() );

string node_name = bandit.choose(context);

auto entries = search_space.node_map[ret][args_type];
for (const auto& [node_type, node_value]: entries)
{
if (node_value.name == node_name) {
return node_value;
}
}

return std::nullopt;
};

// bandit_sample_subtree // TODO: should I implement this? (its going to be hard).
// without this one being performed directly by the bandits, we then rely on
// the sampled probabilities we update after every generation. Since there are lots
// of samplings, I think it is ok to not update them and just use the distribution they learned.


VectorXf get_context(const tree<Node>& tree, Iter spot) {
return variation_bandit.get_context(tree, spot); }

Expand All @@ -383,7 +512,7 @@ class Variation {
// and also propagate what they learn back to the search space at the end of the execution.
Bandit<string> variation_bandit;
map<DataType, Bandit<string>> terminal_bandits;
map<DataType, Bandit<string>> op_bandits;
map<DataType, map<size_t, Bandit<string>>> op_bandits;
};

// // Explicitly instantiate the template for brush program types
Expand Down

0 comments on commit 96cf141

Please sign in to comment.