From 4acf9e6eba5ea1c4c26c139bbecb2133e5db38aa Mon Sep 17 00:00:00 2001 From: gAldeia Date: Mon, 7 Oct 2024 16:21:32 -0300 Subject: [PATCH] Better normalization of bandit probabilities --- src/bandit/bandit.cpp | 30 ++++++++++++++-------------- src/bandit/linear_thompson.cpp | 36 ++++++---------------------------- 2 files changed, 21 insertions(+), 45 deletions(-) diff --git a/src/bandit/bandit.cpp b/src/bandit/bandit.cpp index 0904a4d1..15006d78 100644 --- a/src/bandit/bandit.cpp +++ b/src/bandit/bandit.cpp @@ -142,19 +142,19 @@ VectorXf Bandit::get_context(const tree& tree, Iter spot, const SearchS size_t tot_symbols = tot_operators + tot_features; // Print the header with the operator names and terminal names - // std::cout << "Operators: "; - // for (const auto& op_name : ss.op_names) { - // std::cout << op_name << " "; - // } - // std::cout << std::endl; + std::cout << "Operators: "; + for (const auto& op_name : ss.op_names) { + std::cout << op_name << " "; + } + std::cout << std::endl; - // std::cout << "Terminals: "; - // for (const auto& pair : ss.terminal_map) { - // for (const auto& terminal : pair.second) { - // std::cout << terminal.name << " "; - // } - // } - // std::cout << std::endl; + std::cout << "Terminals: "; + for (const auto& pair : ss.terminal_map) { + for (const auto& terminal : pair.second) { + std::cout << terminal.name << " "; + } + } + std::cout << std::endl; // Assert that tot_symbols is the same as context_size assert(tot_symbols == context_size); @@ -204,9 +204,9 @@ VectorXf Bandit::get_context(const tree& tree, Iter spot, const SearchS } } - // std::cout << "Context part 1: " << context.head(tot_symbols).transpose() << std::endl; - // std::cout << "Context part 2: " << context.segment(tot_symbols, tot_symbols).transpose() << std::endl; - // std::cout << "Context part 3: " << context.tail(tot_symbols).transpose() << std::endl; + std::cout << "Context part 1: " << context.head(tot_symbols).transpose() << std::endl; + std::cout << "Context part 2: " << context.segment(tot_symbols, tot_symbols).transpose() << std::endl; + std::cout << "Context part 3: " << context.tail(tot_symbols).transpose() << std::endl; return context; } diff --git a/src/bandit/linear_thompson.cpp b/src/bandit/linear_thompson.cpp index 1c43f2b8..68a0e66a 100644 --- a/src/bandit/linear_thompson.cpp +++ b/src/bandit/linear_thompson.cpp @@ -73,24 +73,14 @@ std::map LinearThompsonSamplingBandit::sample_probs(bool update) { w = mean + w; VectorXf u(n_arms); + + last_context.setOnes(); + u = w * last_context; // mat mul - // for (int i = 0; i < n_arms; ++i) { - // // cout << "Dot product for row " << i; - // float dot_product = w.row(i).dot(last_context); - // if (std::isnan(dot_product)) - // { - // dot_product = 0.0f; - // // cout << "(nan)"; - // } - // // cout << "Dot product for row " << i << ": " << dot_product << endl; - - // u(i) = dot_product; - // } - float total_prob = 0.0f; for (int i = 0; i < n_arms; ++i) { - float prob = std::exp(u(i)); + float prob = std::exp(u(i)) / std::exp(u.maxCoeff()); this->probabilities[arm_index_to_key[i]] = prob; total_prob += prob; } @@ -127,20 +117,6 @@ T LinearThompsonSamplingBandit::choose(const VectorXf& context) { u = w * context; // mat mul // cout << "u: " << u << endl; - // for (int i = 0; i < n_arms; ++i) { - // // cout << "Dot product for row " << i; - // float dot_product = w.row(i).dot(context); - // if (std::isnan(dot_product)) - // { - // dot_product = 0.0f; - // // cout << "(nan)"; - // } - - // // cout << "Dot product for row " << i << ": " << dot_product << endl; - - // u(i) = dot_product; - // } - Eigen::Index max_index; float max_value = u.maxCoeff(&max_index); // cout << "max_index: " << max_index << ", max_value: " << max_value << endl; @@ -173,7 +149,7 @@ void LinearThompsonSamplingBandit::update(T arm, float reward, VectorXf& cont B[arm_index] += context * context.transpose(); // cout << "B[arm_index] after update: " << B[arm_index] << endl; - m2_r.row(arm_index) += context * reward; + m2_r.row(arm_index) += (context * reward); // cout << "m2_r.row(arm_index) after update: " << m2_r.row(arm_index) << endl; B_inv[arm_index] = B[arm_index].inverse(); @@ -182,7 +158,7 @@ void LinearThompsonSamplingBandit::update(T arm, float reward, VectorXf& cont B_inv_sqrt[arm_index] = B_inv[arm_index].ldlt().matrixL(); // cout << "B_inv_sqrt[arm_index]: " << B_inv_sqrt[arm_index] << endl; - mean.row(arm_index) = B_inv[arm_index] * m2_r.row(arm_index); // mat mul + mean.row(arm_index) = B_inv[arm_index] * m2_r.row(arm_index).transpose(); // mat mul // cout << "mean.row(arm_index): " << mean.row(arm_index) << endl; // cout << "update finished" << endl;