Skip to content

Commit

Permalink
Initializing operator bandit
Browse files Browse the repository at this point in the history
  • Loading branch information
not-gAldeia committed Aug 14, 2024
1 parent dd10584 commit d4b87ef
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
8 changes: 4 additions & 4 deletions src/vary/search_space.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ struct SearchSpace
using ArgsHash = std::size_t;

template<typename T>
using Map = unordered_map<DataType, // return type
unordered_map<ArgsHash, // hash of arg types
unordered_map<NodeType, // node type
T>>>; // the data!
using Map = unordered_map<DataType, // return type
unordered_map<ArgsHash, // hash of arg types
unordered_map<NodeType, // node type
T>>>; // the data!

/**
* @brief Maps return types to argument types to node types.
Expand Down
36 changes: 29 additions & 7 deletions src/vary/variation.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,34 @@ class Variation {
}
}

// TODO: op bandit?
// this->op_bandit = Bandit<DataType>(this->parameters.bandit,
// this->search_space.node_map_weights.size() );

// one bandit for each return type. If we look at implementation of
// sample op, the thing that matters is the most nested probabilities, so we will
// learn only that
for (auto& [ret_type, arg_w_map]: search_space.node_map)
{
std::cout << "creating bandit..." << std::endl;

// TODO: this could be made much easier using user_ops
map<string, float> node_probs;
for (const auto& [args_type, node_map] : arg_w_map)
{
for (const auto& [node_type, node]: node_map)
{
auto weight = search_space.node_map_weights.at(ret_type).at(args_type).at(node_type);

// Attempt to emplace; if the key exists, do nothing
auto [it, inserted] = node_probs.try_emplace(node.name, weight);

// If the key already existed, update its value
if (!inserted) {
// it->second += weight;
}

std::cout << node.name << ", " << it->second << std::endl;
}
}
op_bandits[ret_type] = Bandit<string>(parameters.bandit, node_probs );
}
};

/**
Expand Down Expand Up @@ -185,9 +209,7 @@ class Variation {

map<DataType, Bandit<string>> terminal_bandits;

// TODO: implement bandit for operators
// Bandit<DataType> op_bandit;

map<DataType, Bandit<string>> op_bandits;
};

// // Explicitly instantiate the template for brush program types
Expand Down

0 comments on commit d4b87ef

Please sign in to comment.