Skip to content

Commit

Permalink
dynamic bandit implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Jul 29, 2024
1 parent bbf5f06 commit eaba775
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/bandit/bandit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace MAB {

template <typename T>
Bandit<T>::Bandit() {
set_type("dummy");
set_type("dynamic_thompson");
set_arms({});
set_probs({});
set_bandit();
Expand Down Expand Up @@ -42,6 +42,8 @@ void Bandit<T>::set_bandit() {
// other methods to raise an error if bandit was not set
if (type == "thompson") {
pbandit = make_unique<ThompsonSamplingBandit<T>>(probabilities);
} else if (type == "dynamic_thompson") {
pbandit = make_unique<ThompsonSamplingBandit<T>>(probabilities, true);
} else if (type == "dummy") {
pbandit = make_unique<DummyBandit<T>>(probabilities);
} else {
Expand Down
28 changes: 23 additions & 5 deletions src/bandit/thompson.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ namespace Brush {
namespace MAB {

template <typename T>
ThompsonSamplingBandit<T>::ThompsonSamplingBandit(vector<T> arms)
ThompsonSamplingBandit<T>::ThompsonSamplingBandit(vector<T> arms, bool dynamic)
: BanditOperator<T>(arms)
, dynamic_update(dynamic)
{
for (const auto& arm : arms) {
alphas[arm] = 2;
Expand All @@ -14,8 +15,9 @@ ThompsonSamplingBandit<T>::ThompsonSamplingBandit(vector<T> arms)
}

template <typename T>
ThompsonSamplingBandit<T>::ThompsonSamplingBandit(map<T, float> arms_probs)
ThompsonSamplingBandit<T>::ThompsonSamplingBandit(map<T, float> arms_probs, bool dynamic)
: BanditOperator<T>(arms_probs)
, dynamic_update(dynamic)
{
for (const auto& pair : arms_probs) {
alphas[pair.first] = 2;
Expand All @@ -26,7 +28,8 @@ ThompsonSamplingBandit<T>::ThompsonSamplingBandit(map<T, float> arms_probs)

template <typename T>
std::map<T, float> ThompsonSamplingBandit<T>::sample_probs(bool update) {

// gets sampling probabilities using the bandit

// from https://stackoverflow.com/questions/4181403/generate-random-number-based-on-beta-distribution-using-boost
// You'll first want to draw a random number uniformly from the
// range (0,1). Given any distribution, you can then plug that number
Expand Down Expand Up @@ -59,8 +62,16 @@ std::map<T, float> ThompsonSamplingBandit<T>::sample_probs(bool update) {

prob = X/(X+Y);

this->probabilities[arm] = prob;
// avoiding deadlocks when sampling from search space
this->probabilities[arm] = std::max(prob, 0.01f);
}

// assert that the sum is not zero
float totalProb = 0.0f;
for (const auto& pair : this->probabilities) {
totalProb += pair.second;
}
assert(totalProb != 0.0f && "Sum of probabilities is zero!");
}

return this->probabilities;
Expand All @@ -69,8 +80,15 @@ std::map<T, float> ThompsonSamplingBandit<T>::sample_probs(bool update) {
template <typename T>
void ThompsonSamplingBandit<T>::update(T arm, float reward) {
// reward must be either 0 or 1

alphas[arm] += reward;
betas[arm] += 1 - reward;
betas[arm] += 1-reward;

if (dynamic_update && alphas[arm] + betas[arm] >= C)
{
alphas[arm] *= C/(C+1) ;
betas[arm] *= C/(C+1) ;
}
}

} // MAB
Expand Down
8 changes: 4 additions & 4 deletions src/bandit/thompson.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ template <typename T>
class ThompsonSamplingBandit : public BanditOperator<T>
{
public:
ThompsonSamplingBandit(vector<T> arms);
ThompsonSamplingBandit(map<T, float> arms_probs);
ThompsonSamplingBandit(vector<T> arms, bool dynamic=false);
ThompsonSamplingBandit(map<T, float> arms_probs, bool dynamic=false);
~ThompsonSamplingBandit(){};

std::map<T, float> sample_probs(bool update);
void update(T arm, float reward);

private:
// additional stuff should come here
bool dynamic_update;
float C = 1000;

std::map<T, int> alphas;
std::map<T, int> betas;
Expand Down
2 changes: 1 addition & 1 deletion src/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ void Engine<T>::run(Dataset &data)

[&]() {
// getting the updated versions
if (params.bandit != "duummy")
if (params.bandit != "dummy")
{
// TODO: make the probabilities add up to 1 (this doesn't matter for the cpp side, but it is good practice and helps when comparing different probabilities)
this->ss = variator.search_space;
Expand Down
2 changes: 1 addition & 1 deletion src/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct Parameters
unsigned int max_size = 50;

vector<string> objectives{"error","complexity"}; // error should be generic and deduced based on mode
string bandit = "dummy"; // should I rename that?
string bandit = "dynamic_thompson"; // TODO: should I rename dummy?
string sel = "lexicase"; //selection method
string surv = "nsga2"; //survival method
std::unordered_map<string, float> functions;
Expand Down
7 changes: 6 additions & 1 deletion src/vary/variation.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class Variation {

this->variation_bandit = Bandit<string>(parameters.bandit, variation_probs);

// TODO: should I set C parameter based on pop size or leave it fixed?
// TODO: update string comparisons to use .compare method
// if (parameters.bandit.compare("dynamic_thompson")==0)
// this->variation_bandit.pbandit.set_C(parameters.pop_size);

// initializing one bandit for each terminal type
for (const auto& entry : this->search_space.terminal_weights) {
// entry is a tuple <dataType, vector<float>> where the vector is the weights
Expand All @@ -119,7 +124,7 @@ class Variation {
}
}


// TODO: op bandit?
// this->op_bandit = Bandit<DataType>(this->parameters.bandit,
// this->search_space.node_map_weights.size() );

Expand Down

0 comments on commit eaba775

Please sign in to comment.