Skip to content

Commit

Permalink
Normalizing probabilities. Avoiding unnecessary context elements
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Oct 6, 2024
1 parent 6483a00 commit 7283739
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 33 deletions.
34 changes: 25 additions & 9 deletions src/bandit/bandit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,29 @@ VectorXf Bandit<T>::get_context(const tree<Node>& tree, Iter spot, const SearchS

// std::cout << "Spot name: " << (*spot).name << std::endl;

size_t tot_operators = NodeTypes::Count;
size_t tot_operators = ss.op_names.size(); //NodeTypes::Count;
size_t tot_features = 0;

for (const auto& pair : ss.terminal_map)
tot_features += pair.second.size();

size_t tot_symbols = tot_operators + tot_features;

// Print the header with the operator names and terminal names
// std::cout << "Operators: ";
// for (const auto& op_name : ss.op_names) {
// std::cout << op_name << " ";
// }
// std::cout << std::endl;

// std::cout << "Terminals: ";
// for (const auto& pair : ss.terminal_map) {
// for (const auto& terminal : pair.second) {
// std::cout << terminal.name << " ";
// }
// }
// std::cout << std::endl;

// Assert that tot_symbols is the same as context_size
assert(tot_symbols == context_size);

Expand All @@ -163,7 +178,7 @@ VectorXf Bandit<T>::get_context(const tree<Node>& tree, Iter spot, const SearchS
pos_shift = 2;

// std::cout << "Position shift: " << pos_shift << std::endl;
if (Is<NodeType::Terminal>((*it).node_type)){
if (Is<NodeType::Terminal, NodeType::Constant, NodeType::MeanLabel>((*it).node_type)){
size_t feature_index = 0;

// iterating using terminal_types since it is ordered
Expand All @@ -176,14 +191,15 @@ VectorXf Bandit<T>::get_context(const tree<Node>& tree, Iter spot, const SearchS
++feature_index;
}
} else {
size_t op_index=0; // doesnt care the arg type, as long as the operator is correct
for (; op_index< NodeTypes::Count; op_index++){
if (static_cast<NodeType>(1UL << op_index) == (*it).node_type)
break;
auto it_op = std::find(ss.op_names.begin(), ss.op_names.end(), (*it).name);
if (it_op != ss.op_names.end()) {
size_t op_index = std::distance(ss.op_names.begin(), it_op);
context(pos_shift * tot_symbols + op_index) += 1.0;
// std::cout << "Below spot, operator: " << (*it).name << " of index " << pos_shift*tot_symbols + op_index << std::endl;
}
else {
HANDLE_ERROR_THROW("Undefined operator " + (*it).name + "\n");
}

context( pos_shift*tot_symbols + op_index ) += 1.0;
// std::cout << "Below spot, operator: " << (*it).name << " of index " << pos_shift*tot_symbols + op_index << std::endl;
}
}
}
Expand Down
25 changes: 8 additions & 17 deletions src/bandit/linear_thompson.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,26 +88,17 @@ std::map<T, float> LinearThompsonSamplingBandit<T>::sample_probs(bool update) {
// u(i) = dot_product;
// }

float total_prob = 0.0f;
for (int i = 0; i < n_arms; ++i) {
this->probabilities[arm_index_to_key[i]] = std::exp(u(i));
float prob = std::exp(u(i));
this->probabilities[arm_index_to_key[i]] = prob;
total_prob += prob;
}

// // Calculate probabilities
// std::map<T, float> probs;
// float total_prob = 0.0f;

// for (int i = 0; i < n_arms; ++i) {
// float prob = exp(u(i)) / exp(u.maxCoeff());
// probs[arm_index_to_key[i]] = prob;
// total_prob += prob;
// }

// // Normalize probabilities to ensure they sum to 1
// for (auto& pair : probs) {
// pair.second /= total_prob;
// }

// this->probabilities = probs;
// Normalize probabilities to ensure they sum to 1
for (auto& [k, v] : this->probabilities) {
this->probabilities[k] = std::min(this->probabilities[k] / total_prob, 1.0f);
}
}

return this->probabilities;
Expand Down
15 changes: 14 additions & 1 deletion src/vary/search_space.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,19 @@ void SearchSpace::init(const Dataset& d, const unordered_map<string,float>& user
this->terminal_map.clear();
this->terminal_types.clear();
this->terminal_weights.clear();
this->op_names.clear();

bool use_all = user_ops.size() == 0;
vector<string> op_names;
if (use_all)
{
for (size_t op_index=0; op_index< NodeTypes::Count; op_index++){
op_names.push_back(
NodeTypeName.at(static_cast<NodeType>(1UL << op_index))
);
}
}


for (const auto& [op, weight] : user_ops)
op_names.push_back(op);

Expand Down Expand Up @@ -211,12 +221,15 @@ void SearchSpace::init(const Dataset& d, const unordered_map<string,float>& user
// We need some ops in the search space so we can have the logit and offset
if (user_ops.find("OffsetSum") == user_ops.end())
extended_user_ops.insert({"OffsetSum", 0.0f});
op_names.push_back("OffsetSum");

if (unique_classes.size()==2 && (user_ops.find("Logistic") == user_ops.end())) {
extended_user_ops.insert({"Logistic", 0.0f});
op_names.push_back("Logistic");
}
else if (user_ops.find("Softmax") == user_ops.end()) {
extended_user_ops.insert({"Softmax", 0.0f});
op_names.push_back("Softmax");
}

if (extended_user_ops.size() > 0)
Expand Down
6 changes: 5 additions & 1 deletion src/vary/search_space.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ struct SearchSpace
/// @brief A vector storing the available return types of terminals.
vector<DataType> terminal_types;

/// @brief A vector storing the available operator names (used by bandits).
vector<string> op_names;

// serialization
#ifndef DOXYGEN_SKIP

Expand All @@ -128,7 +131,8 @@ struct SearchSpace
node_map_weights,
terminal_map,
terminal_weights,
terminal_types
terminal_types,
op_names
)

#endif
Expand Down
17 changes: 12 additions & 5 deletions src/vary/variation.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Variation {
if (parameters.cx_prob > 0.0)
variation_probs["cx"] = parameters.cx_prob;

size_t tot_operators = NodeTypes::Count;
size_t tot_operators = search_space.op_names.size(); //NodeTypes::Count;
size_t tot_features = 0;
for (const auto& pair : search_space.terminal_map)
tot_features += pair.second.size();
Expand Down Expand Up @@ -327,6 +327,13 @@ class Variation {
if (allPositive)
r = 1.0;

// linear bandit can handle non-bernoulli-like rewards
if (parameters.bandit.compare("linear_thompson") == 0)
{
r = std::count_if(deltas.begin(), deltas.end(),
[](float delta) { return delta > 0; });
}

// std::cout << "Updating variation bandit with reward: " << r << std::endl;

if (ind.get_variation().compare("born") != 0)
Expand All @@ -340,9 +347,9 @@ class Variation {
this->variation_bandit.update(choice, 0.0, root_context);
}

if (!ind.get_variation().compare("born") && !ind.get_variation().compare("cx")
&& !ind.get_variation().compare("subtree"))
{
// if (!ind.get_variation().compare("born") && !ind.get_variation().compare("cx")
// && !ind.get_variation().compare("subtree"))
// {
if (ind.get_sampled_nodes().size() > 0) {
// std::cout << "Updating terminal and operator bandits for sampled nodes" << std::endl;
const auto& changed_nodes = ind.get_sampled_nodes();
Expand All @@ -361,7 +368,7 @@ class Variation {
}
}
}
}
// }

pop.individuals.at(indices.at(i)) = std::make_shared<Individual<T>>(ind);
// std::cout << "Individual at index " << indices.at(i) << " updated successfully" << std::endl;
Expand Down

0 comments on commit 7283739

Please sign in to comment.