From 7283739378162bb4f7364f7198c935d1d6bbba10 Mon Sep 17 00:00:00 2001 From: gAldeia Date: Sun, 6 Oct 2024 19:22:42 -0300 Subject: [PATCH] Normalizing probabilities. avoiding unecessary context elements --- src/bandit/bandit.cpp | 34 +++++++++++++++++++++++++--------- src/bandit/linear_thompson.cpp | 25 ++++++++----------------- src/vary/search_space.cpp | 15 ++++++++++++++- src/vary/search_space.h | 6 +++++- src/vary/variation.h | 17 ++++++++++++----- 5 files changed, 64 insertions(+), 33 deletions(-) diff --git a/src/bandit/bandit.cpp b/src/bandit/bandit.cpp index 176e3b44..0904a4d1 100644 --- a/src/bandit/bandit.cpp +++ b/src/bandit/bandit.cpp @@ -133,7 +133,7 @@ VectorXf Bandit::get_context(const tree& tree, Iter spot, const SearchS // std::cout << "Spot name: " << (*spot).name << std::endl; - size_t tot_operators = NodeTypes::Count; + size_t tot_operators = ss.op_names.size(); //NodeTypes::Count; size_t tot_features = 0; for (const auto& pair : ss.terminal_map) @@ -141,6 +141,21 @@ VectorXf Bandit::get_context(const tree& tree, Iter spot, const SearchS size_t tot_symbols = tot_operators + tot_features; + // Print the header with the operator names and terminal names + // std::cout << "Operators: "; + // for (const auto& op_name : ss.op_names) { + // std::cout << op_name << " "; + // } + // std::cout << std::endl; + + // std::cout << "Terminals: "; + // for (const auto& pair : ss.terminal_map) { + // for (const auto& terminal : pair.second) { + // std::cout << terminal.name << " "; + // } + // } + // std::cout << std::endl; + // Assert that tot_symbols is the same as context_size assert(tot_symbols == context_size); @@ -163,7 +178,7 @@ VectorXf Bandit::get_context(const tree& tree, Iter spot, const SearchS pos_shift = 2; // std::cout << "Position shift: " << pos_shift << std::endl; - if (Is((*it).node_type)){ + if (Is((*it).node_type)){ size_t feature_index = 0; // iterating using terminal_types since it is ordered @@ -176,14 +191,15 @@ VectorXf Bandit::get_context(const tree& tree, Iter spot, const SearchS ++feature_index; } } else { - size_t op_index=0; // doesnt care the arg type, as long as the operator is correct - for (; op_index< NodeTypes::Count; op_index++){ - if (static_cast(1UL << op_index) == (*it).node_type) - break; + auto it_op = std::find(ss.op_names.begin(), ss.op_names.end(), (*it).name); + if (it_op != ss.op_names.end()) { + size_t op_index = std::distance(ss.op_names.begin(), it_op); + context(pos_shift * tot_symbols + op_index) += 1.0; + // std::cout << "Below spot, operator: " << (*it).name << " of index " << pos_shift*tot_symbols + op_index << std::endl; + } + else { + HANDLE_ERROR_THROW("Undefined operator " + (*it).name + "\n"); } - - context( pos_shift*tot_symbols + op_index ) += 1.0; - // std::cout << "Below spot, operator: " << (*it).name << " of index " << pos_shift*tot_symbols + op_index << std::endl; } } } diff --git a/src/bandit/linear_thompson.cpp b/src/bandit/linear_thompson.cpp index 5d9e39c0..1c43f2b8 100644 --- a/src/bandit/linear_thompson.cpp +++ b/src/bandit/linear_thompson.cpp @@ -88,26 +88,17 @@ std::map LinearThompsonSamplingBandit::sample_probs(bool update) { // u(i) = dot_product; // } + float total_prob = 0.0f; for (int i = 0; i < n_arms; ++i) { - this->probabilities[arm_index_to_key[i]] = std::exp(u(i)); + float prob = std::exp(u(i)); + this->probabilities[arm_index_to_key[i]] = prob; + total_prob += prob; } - // // Calculate probabilities - // std::map probs; - // float total_prob = 0.0f; - - // for (int i = 0; i < n_arms; ++i) { - // float prob = exp(u(i)) / exp(u.maxCoeff()); - // probs[arm_index_to_key[i]] = prob; - // total_prob += prob; - // } - - // // Normalize probabilities to ensure they sum to 1 - // for (auto& pair : probs) { - // pair.second /= total_prob; - // } - - // this->probabilities = probs; + // Normalize probabilities to ensure they sum to 1 + for (auto& [k, v] : this->probabilities) { + this->probabilities[k] = std::min(this->probabilities[k] / total_prob, 1.0f); + } } return this->probabilities; diff --git a/src/vary/search_space.cpp b/src/vary/search_space.cpp index f7008ed0..c2a35d2e 100644 --- a/src/vary/search_space.cpp +++ b/src/vary/search_space.cpp @@ -175,9 +175,19 @@ void SearchSpace::init(const Dataset& d, const unordered_map& user this->terminal_map.clear(); this->terminal_types.clear(); this->terminal_weights.clear(); + this->op_names.clear(); bool use_all = user_ops.size() == 0; - vector op_names; + if (use_all) + { + for (size_t op_index=0; op_index< NodeTypes::Count; op_index++){ + op_names.push_back( + NodeTypeName.at(static_cast(1UL << op_index)) + ); + } + } + + for (const auto& [op, weight] : user_ops) op_names.push_back(op); @@ -211,12 +221,15 @@ void SearchSpace::init(const Dataset& d, const unordered_map& user // We need some ops in the search space so we can have the logit and offset if (user_ops.find("OffsetSum") == user_ops.end()) extended_user_ops.insert({"OffsetSum", 0.0f}); + op_names.push_back("OffsetSum"); if (unique_classes.size()==2 && (user_ops.find("Logistic") == user_ops.end())) { extended_user_ops.insert({"Logistic", 0.0f}); + op_names.push_back("Logistic"); } else if (user_ops.find("Softmax") == user_ops.end()) { extended_user_ops.insert({"Softmax", 0.0f}); + op_names.push_back("Softmax"); } if (extended_user_ops.size() > 0) diff --git a/src/vary/search_space.h b/src/vary/search_space.h index 86d3db31..248b48f7 100644 --- a/src/vary/search_space.h +++ b/src/vary/search_space.h @@ -120,6 +120,9 @@ struct SearchSpace /// @brief A vector storing the available return types of terminals. vector terminal_types; + /// @brief A vector storing the available operator names (used by bandits). + vector op_names; + // serialization #ifndef DOXYGEN_SKIP @@ -128,7 +131,8 @@ struct SearchSpace node_map_weights, terminal_map, terminal_weights, - terminal_types + terminal_types, + op_names ) #endif diff --git a/src/vary/variation.h b/src/vary/variation.h index ff26e7f2..8441a8ff 100644 --- a/src/vary/variation.h +++ b/src/vary/variation.h @@ -79,7 +79,7 @@ class Variation { if (parameters.cx_prob > 0.0) variation_probs["cx"] = parameters.cx_prob; - size_t tot_operators = NodeTypes::Count; + size_t tot_operators = search_space.op_names.size(); //NodeTypes::Count; size_t tot_features = 0; for (const auto& pair : search_space.terminal_map) tot_features += pair.second.size(); @@ -327,6 +327,13 @@ class Variation { if (allPositive) r = 1.0; + // linear bandit can handle non-bernoulli-like rewards + if (parameters.bandit.compare("linear_thompson") == 0) + { + r = std::count_if(deltas.begin(), deltas.end(), + [](float delta) { return delta > 0; }); + } + // std::cout << "Updating variation bandit with reward: " << r << std::endl; if (ind.get_variation().compare("born") != 0) @@ -340,9 +347,9 @@ class Variation { this->variation_bandit.update(choice, 0.0, root_context); } - if (!ind.get_variation().compare("born") && !ind.get_variation().compare("cx") - && !ind.get_variation().compare("subtree")) - { + // if (!ind.get_variation().compare("born") && !ind.get_variation().compare("cx") + // && !ind.get_variation().compare("subtree")) + // { if (ind.get_sampled_nodes().size() > 0) { // std::cout << "Updating terminal and operator bandits for sampled nodes" << std::endl; const auto& changed_nodes = ind.get_sampled_nodes(); @@ -361,7 +368,7 @@ class Variation { } } } - } + // } pop.individuals.at(indices.at(i)) = std::make_shared>(ind); // std::cout << "Individual at index " << indices.at(i) << " updated successfully" << std::endl;