Skip to content

Commit

Permalink
Better normalization of bandit probabilities
Browse files: browse the repository at this point in the history
  • Loading branch information
gAldeia committed Oct 7, 2024
1 parent 7283739 commit 4acf9e6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 45 deletions.
30 changes: 15 additions & 15 deletions src/bandit/bandit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,19 +142,19 @@ VectorXf Bandit<T>::get_context(const tree<Node>& tree, Iter spot, const SearchS
size_t tot_symbols = tot_operators + tot_features;

// Print the header with the operator names and terminal names
// std::cout << "Operators: ";
// for (const auto& op_name : ss.op_names) {
// std::cout << op_name << " ";
// }
// std::cout << std::endl;
std::cout << "Operators: ";
for (const auto& op_name : ss.op_names) {
std::cout << op_name << " ";
}
std::cout << std::endl;

// std::cout << "Terminals: ";
// for (const auto& pair : ss.terminal_map) {
// for (const auto& terminal : pair.second) {
// std::cout << terminal.name << " ";
// }
// }
// std::cout << std::endl;
std::cout << "Terminals: ";
for (const auto& pair : ss.terminal_map) {
for (const auto& terminal : pair.second) {
std::cout << terminal.name << " ";
}
}
std::cout << std::endl;

// Assert that tot_symbols is the same as context_size
assert(tot_symbols == context_size);
Expand Down Expand Up @@ -204,9 +204,9 @@ VectorXf Bandit<T>::get_context(const tree<Node>& tree, Iter spot, const SearchS
}
}

// std::cout << "Context part 1: " << context.head(tot_symbols).transpose() << std::endl;
// std::cout << "Context part 2: " << context.segment(tot_symbols, tot_symbols).transpose() << std::endl;
// std::cout << "Context part 3: " << context.tail(tot_symbols).transpose() << std::endl;
std::cout << "Context part 1: " << context.head(tot_symbols).transpose() << std::endl;
std::cout << "Context part 2: " << context.segment(tot_symbols, tot_symbols).transpose() << std::endl;
std::cout << "Context part 3: " << context.tail(tot_symbols).transpose() << std::endl;

return context;
}
Expand Down
36 changes: 6 additions & 30 deletions src/bandit/linear_thompson.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,14 @@ std::map<T, float> LinearThompsonSamplingBandit<T>::sample_probs(bool update) {
w = mean + w;

VectorXf u(n_arms);

last_context.setOnes();

u = w * last_context; // mat mul

// for (int i = 0; i < n_arms; ++i) {
// // cout << "Dot product for row " << i;
// float dot_product = w.row(i).dot(last_context);
// if (std::isnan(dot_product))
// {
// dot_product = 0.0f;
// // cout << "(nan)";
// }
// // cout << "Dot product for row " << i << ": " << dot_product << endl;

// u(i) = dot_product;
// }

float total_prob = 0.0f;
for (int i = 0; i < n_arms; ++i) {
float prob = std::exp(u(i));
float prob = std::exp(u(i)) / std::exp(u.maxCoeff());
this->probabilities[arm_index_to_key[i]] = prob;
total_prob += prob;
}
Expand Down Expand Up @@ -127,20 +117,6 @@ T LinearThompsonSamplingBandit<T>::choose(const VectorXf& context) {
u = w * context; // mat mul
// cout << "u: " << u << endl;

// for (int i = 0; i < n_arms; ++i) {
// // cout << "Dot product for row " << i;
// float dot_product = w.row(i).dot(context);
// if (std::isnan(dot_product))
// {
// dot_product = 0.0f;
// // cout << "(nan)";
// }

// // cout << "Dot product for row " << i << ": " << dot_product << endl;

// u(i) = dot_product;
// }

Eigen::Index max_index;
float max_value = u.maxCoeff(&max_index);
// cout << "max_index: " << max_index << ", max_value: " << max_value << endl;
Expand Down Expand Up @@ -173,7 +149,7 @@ void LinearThompsonSamplingBandit<T>::update(T arm, float reward, VectorXf& cont
B[arm_index] += context * context.transpose();
// cout << "B[arm_index] after update: " << B[arm_index] << endl;

m2_r.row(arm_index) += context * reward;
m2_r.row(arm_index) += (context * reward);
// cout << "m2_r.row(arm_index) after update: " << m2_r.row(arm_index) << endl;

B_inv[arm_index] = B[arm_index].inverse();
Expand All @@ -182,7 +158,7 @@ void LinearThompsonSamplingBandit<T>::update(T arm, float reward, VectorXf& cont
B_inv_sqrt[arm_index] = B_inv[arm_index].ldlt().matrixL();
// cout << "B_inv_sqrt[arm_index]: " << B_inv_sqrt[arm_index] << endl;

mean.row(arm_index) = B_inv[arm_index] * m2_r.row(arm_index); // mat mul
mean.row(arm_index) = B_inv[arm_index] * m2_r.row(arm_index).transpose(); // mat mul
// cout << "mean.row(arm_index): " << mean.row(arm_index) << endl;

// cout << "update finished" << endl;
Expand Down

0 comments on commit 4acf9e6

Please sign in to comment.