Skip to content

Commit

Permalink
Fixed bad logic creating context with different data types
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Oct 1, 2024
1 parent 83835f2 commit 169304a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 60 deletions.
88 changes: 31 additions & 57 deletions src/bandit/bandit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ VectorXf Bandit<T>::get_context(const tree<Node>& tree, Iter spot, const SearchS

size_t tot_symbols = tot_operators + tot_features;

// Assert that tot_symbols is the same as context_size
assert(tot_symbols == context_size);

VectorXf context( 3 * tot_symbols );
context.setZero();

Expand All @@ -149,67 +152,38 @@ VectorXf Bandit<T>::get_context(const tree<Node>& tree, Iter spot, const SearchS
// std::cout << "Check succeeded for node: " << (*it).name << std::endl;
// std::cout << "Depth of spot: " << tree.depth(spot) << std::endl;
// std::cout << "Depth of it: " << tree.depth(it) << std::endl;
if (it == spot) {
// std::cout << "It is the spot, searching for it " << std::endl;

if (Is<NodeType::Terminal>((*it).node_type)){
size_t feature_index = 0;

// iterating using terminal_types since it is ordered
for (const auto& dtype : ss.terminal_types) {
for (const auto& terminal : ss.terminal_map.at(dtype)) {
if (terminal.name == (*it).name) {
context((tot_operators + feature_index) + tot_symbols) += 1.0;
// std::cout << "Spot terminal: " << terminal.name << " at feature index " << feature_index << std::endl;
break;
}
++feature_index;
}
}
} else {
// finding the index of the operator
size_t op_index=0;
for (; op_index< NodeTypes::Count; op_index++){
if (static_cast<NodeType>(1UL << op_index) == (*it).node_type)
break;
// std::cout << "It is the spot, searching for it " << std::endl;

// deciding if it is above or below the spot
size_t pos_shift = 0; // above
if (it == spot) { // spot
pos_shift = 1;
}
else if (tree.is_in_subtree(it, spot)) // below
pos_shift = 2;

// std::cout << "Position shift: " << pos_shift << std::endl;
if (Is<NodeType::Terminal>((*it).node_type)){
size_t feature_index = 0;

// iterating using terminal_types since it is ordered
for (const auto& terminal : ss.terminal_map.at((*it).ret_type)) {
if (terminal.name == (*it).name) {
context((tot_operators + feature_index) + pos_shift*tot_symbols) += 1.0;
// std::cout << "Below spot, terminal: " << terminal.name << " at feature index " << feature_index << std::endl;
break;
}

context(tot_symbols + op_index) += 1.0;
// std::cout << "Spot operator: " << (*it).name << " of index " << tot_symbols + op_index << std::endl;
++feature_index;
}
} else {
// std::cout << "Below spot " << std::endl;

// deciding if it is above or below the spot
size_t pos_shift = 0;
if (tree.is_in_subtree(it, spot))
pos_shift = 2;

// std::cout << "Position shift: " << pos_shift << std::endl;
if (Is<NodeType::Terminal>((*it).node_type)){
size_t feature_index = 0;

// iterating using terminal_types since it is ordered
for (const auto& dtype : ss.terminal_types) {
for (const auto& terminal : ss.terminal_map.at(dtype)) {
if (terminal.name == (*it).name) {
context((tot_operators + feature_index) + pos_shift*tot_symbols) += 1.0;
// std::cout << "Below spot, terminal: " << terminal.name << " at feature index " << feature_index << std::endl;
break;
}
++feature_index;
}
}
} else {
size_t op_index=0;
for (; op_index< NodeTypes::Count; op_index++){
if (static_cast<NodeType>(1UL << op_index) == (*it).node_type)
break;
}

context( pos_shift*tot_symbols + op_index ) += 1.0;
// std::cout << "Below spot, operator: " << (*it).name << " of index " << pos_shift*tot_symbols + op_index << std::endl;
size_t op_index=0; // doesnt care the arg type, as long as the operator is correct
for (; op_index< NodeTypes::Count; op_index++){
if (static_cast<NodeType>(1UL << op_index) == (*it).node_type)
break;
}

context( pos_shift*tot_symbols + op_index ) += 1.0;
// std::cout << "Below spot, operator: " << (*it).name << " of index " << pos_shift*tot_symbols + op_index << std::endl;
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,13 @@ bool Engine<T>::update_best()
for (size_t j = 0; j < this->best_ind.fitness.get_wvalues().size(); ++j) {
if (ind.fitness.get_wvalues()[j] > this->best_ind.fitness.get_wvalues()[j]) {
passed = true;
break;
}
else if (ind.fitness.get_wvalues()[j] < this->best_ind.fitness.get_wvalues()[j]) {
if (ind.fitness.get_wvalues()[j] < this->best_ind.fitness.get_wvalues()[j]) {
// it is not better, and it is also not equal. So, it is worse. Stop here.
break;
}
// if no break, then its equal, so we keep going
}

if (passed)
Expand Down
2 changes: 0 additions & 2 deletions src/pop/population.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,6 @@ vector<size_t> Population<T>::hall_of_fame(unsigned rank)
hof.clear();

for (int i = 0; i < merged_islands.size(); ++i) {

std::vector<unsigned int> dom;
int dcount = 0;

auto p = individuals.at(merged_islands[i]);
Expand Down

0 comments on commit 169304a

Please sign in to comment.