Skip to content

Commit

Permalink
Working on operator set bandit. new tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Aug 19, 2024
1 parent 35dfb78 commit 5a37674
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 9 deletions.
5 changes: 1 addition & 4 deletions src/vary/variation.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ class Variation {
// learn only that
for (auto& [ret_type, arg_w_map]: search_space.node_map)
{
std::cout << "creating bandit..." << std::endl;

// TODO: this could be made much easier using user_ops
map<string, float> node_probs;
Expand All @@ -146,11 +145,9 @@ class Variation {
if (!inserted) {
// it->second += weight;
}

std::cout << node.name << ", " << it->second << std::endl;
}
}
op_bandits[ret_type] = Bandit<string>(parameters.bandit, node_probs );
op_bandits[ret_type] = Bandit<string>(parameters.bandit, node_probs);
}
};

Expand Down
43 changes: 39 additions & 4 deletions tests/cpp/test_brush.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ TEST(Engine, SavingLoadingFixedNodes)
Parameters params;
params.set_verbosity(2);
params.set_scorer("log");
params.set_cx_prob(1.0); // TODO: debug. why if I set this to 0.0 it does not work?
params.set_cx_prob(0.0);
params.set_save_population("./tests/cpp/__pop_clf.json");

Brush::ClassifierEngine est(params, ss);
Expand All @@ -205,9 +205,14 @@ TEST(Engine, SavingLoadingFixedNodes)
ASSERT_TRUE(cx_child_root.fixed==true);
}

params.set_load_population("./tests/cpp/__pop_clf.json");
// TODO: why if I set cx_prob to 0.0 it does not work? (maybe because Im using the same params object for the two engines? do i need to remove save_pop file first?)

Brush::ClassifierEngine est2(params, ss);
Parameters params2;
params2.set_verbosity(2);
params2.set_scorer("log");
params2.set_load_population("./tests/cpp/__pop_clf.json");

Brush::ClassifierEngine est2(params2, ss);
est2.run(data);

cout << "Checking if all individuals in the population have the logistic node as its root after loading a previously saved pop to resume execution" << endl;
Expand All @@ -225,4 +230,34 @@ TEST(Engine, SavingLoadingFixedNodes)
ASSERT_TRUE(cx_child_root.get_prob_change()==0.0);
ASSERT_TRUE(cx_child_root.fixed==true);
}
}
}


TEST(Engine, MaxStall)
{
MatrixXf X(10,2);
ArrayXf y(10);
X << 0.85595296, 0.55417453, 0.8641915 , 0.99481109, 0.99123376,
0.9742618 , 0.70894019, 0.94940306, 0.99748867, 0.54205151,

0.5170537 , 0.8324005 , 0.50316305, 0.10173936, 0.13211973,
0.2254195 , 0.70526861, 0.31406024, 0.07082619, 0.84034526;

y << 3.55634251, 3.13854087, 3.55887523, 3.29462895, 3.33443517,
3.4378868 , 3.41092345, 3.5087468 , 3.25110243, 3.11382179;

Dataset data(X,y);
SearchSpace ss(data);

Parameters params;
params.set_pop_size(100);
params.set_max_gens(1000000);
params.set_mig_prob(0.0);
params.set_max_stall(10);
params.set_verbosity(2);

Brush::RegressorEngine est(params, ss);
est.run(data);


}
56 changes: 55 additions & 1 deletion tests/cpp/test_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,58 @@ TEST(Data, MixedVariableTypes)
// operations, and are used to resolve the evaluation of an expression.
// dtable_fit.print();
// dtable_predict.print();
}
}
TEST(Data, ShuffleTrueFalse)
{
MatrixXf X(20,3);
X << 0 , 1, 0 ,
0.0, 1.0, 1.0,
2 , 1.0, -3.0,
2 , 1 , 3 ,
2.1, 3.7, -5.2,
0 , 1, 0 ,
0.0, 1.0, 1.0,
2 , 1.0, -3.0,
2 , 1 , 3 ,
2.1, 3.7, -5.2,
0 , 1, 0 ,
0.0, 1.0, 1.0,
2 , 1.0, -3.0,
2 , 1 , 3 ,
2.1, 3.7, -5.2,
0 , 1, 0 ,
0.0, 1.0, 1.0,
2 , 1.0, -3.0,
2 , 1 , 3 ,
2.1, 3.7, -5.2;

X.transposeInPlace();

ArrayXf y(20);

y << 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0;

// vector<string> vn = {};
// map<string, State> Z = {};
// vector<string> ft = {};

// Dataset(const ArrayXXf& X,
// const Ref<const ArrayXf>& y_ = ArrayXf(),
// const vector<string>& vn = {},
// const map<string, State>& Z = {},
// const vector<string>& ft = {},
// bool c = false,
// float validation_size = 0.0,
// float batch_size = 1.0,
// bool shuffle_split = false

Dataset dt1(X, y, {}, {}, {}, true, 0.3, 1.0, true);
Dataset dt2(X, y, {}, {}, {}, true, 0.3, 1.0, false);
Dataset dt3(X, y, {}, {}, {}, true, 0.0, 1.0, true);
Dataset dt4(X, y, {}, {}, {}, true, 0.0, 1.0, false);
Dataset dt5(X, y, {}, {}, {}, true, 1.0, 1.0, true);
Dataset dt6(X, y, {}, {}, {}, true, 1.0, 1.0, false);

// TODO: write some assertions here
}

0 comments on commit 5a37674

Please sign in to comment.