diff --git a/src/vary/search_space.h b/src/vary/search_space.h index 80a89586..9a2e9687 100644 --- a/src/vary/search_space.h +++ b/src/vary/search_space.h @@ -85,10 +85,10 @@ struct SearchSpace using ArgsHash = std::size_t; template - using Map = unordered_map>>; // the data! + using Map = unordered_map>>; // the data! /** * @brief Maps return types to argument types to node types. diff --git a/src/vary/variation.h b/src/vary/variation.h index 3241ebab..489a280c 100644 --- a/src/vary/variation.h +++ b/src/vary/variation.h @@ -124,10 +124,34 @@ class Variation { } } - // TODO: op bandit? - // this->op_bandit = Bandit(this->parameters.bandit, - // this->search_space.node_map_weights.size() ); - + // one bandit for each return type. If we look at implementation of + // sample op, the thing that matters is the most nested probabilities, so we will + // learn only that + for (auto& [ret_type, arg_w_map]: search_space.node_map) + { + std::cout << "creating bandit..." << std::endl; + + // TODO: this could be made much easier using user_ops + map node_probs; + for (const auto& [args_type, node_map] : arg_w_map) + { + for (const auto& [node_type, node]: node_map) + { + auto weight = search_space.node_map_weights.at(ret_type).at(args_type).at(node_type); + + // Attempt to emplace; if the key exists, do nothing + auto [it, inserted] = node_probs.try_emplace(node.name, weight); + + // If the key already existed, update its value + if (!inserted) { + // it->second += weight; + } + + std::cout << node.name << ", " << it->second << std::endl; + } + } + op_bandits[ret_type] = Bandit(parameters.bandit, node_probs ); + } }; /** @@ -185,9 +209,7 @@ class Variation { map> terminal_bandits; - // TODO: implement bandit for operators - // Bandit op_bandit; - + map> op_bandits; }; // // Explicitly instantiate the template for brush program types