Skip to content

Commit

Permalink
Fixed sample weights not working in lexicase
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Nov 8, 2024
1 parent 2451619 commit 691aa04
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 37 deletions.
6 changes: 3 additions & 3 deletions pybrush/EstimatorInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ def _wrap_parameters(self, y, **extra_kwargs):
params = Parameters()

# Setting up the classification or regression problem
params.classification = self.mode == "classification"
if params.classification:
if self.mode == "classification":
params.classification = True
params.set_n_classes(y)
params.set_class_weights(y)
params.set_sample_weights(y)
Expand Down Expand Up @@ -256,7 +256,7 @@ def _wrap_parameters(self, y, **extra_kwargs):
params.max_time = self.max_time

# Sampling probabilities
params.weights_init=self.weights_init
params.weights_init = self.weights_init
params.bandit = self.bandit
params.mutation_probs = self.mutation_probs

Expand Down
13 changes: 8 additions & 5 deletions src/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,17 @@ void Engine<T>::run(Dataset &data)
this->init();

if (params.load_population != "") {
// std::cout << "Loading population from: " << params.load_population << std::endl;
this->pop.load(params.load_population);

// invalidating all individuals
// for (auto& individual : this->pop.individuals) {
// if (individual != nullptr) {
// individual->set_is_fitted(false);
// }
// }
for (auto& individual : this->pop.individuals) {
if (individual != nullptr) {
individual->set_is_fitted(false);
// std::cout << "Invalidated individual with ID: " << individual->id << std::endl;
}
}
// std::cout << "Population loaded and individuals invalidated." << std::endl;
}
else
this->pop.init(this->ss, this->params);
Expand Down
4 changes: 2 additions & 2 deletions src/eval/evaluation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ void Evaluation<T>::assign_fit(Individual<T>& ind, const Dataset& data,
VectorXf val_errors;
f_v = S.score(ind, validation, val_errors, params);

if (val)
ind.error = val_errors;
// if (val) // never use validation data here. This is used in lexicase selection
// ind.error = val_errors;
}

// This is what is going to determine the weights for the individual's fitness
Expand Down
2 changes: 1 addition & 1 deletion src/ind/individual.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Individual{

// storing what changed in relation to parent inside variation
string variation = "born"; // spontaneous generation (born), crossover, or which type of mutation
vector<Node> sampled_nodes; // nodes that were sampled in mutation
vector<Node> sampled_nodes = {}; // nodes that were sampled in mutation

VectorXf error; ///< training error (used in lexicase selectors)

Expand Down
44 changes: 26 additions & 18 deletions src/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ struct Parameters

string scorer="mse"; ///< actual loss function used, determined by error

vector<int> classes; ///< class labels
vector<int> classes = vector<int>(); ///< class labels
vector<float> class_weights = vector<float>(); ///< weights for each class
vector<float> sample_weights = vector<float>(); ///< weights for each sample

// for creating the dataset from X and y in Engine<T>::fit. Ignored if
// the user supplies a dataset
bool classification;
unsigned int n_classes;
bool classification = false;
unsigned int n_classes = 0;

// validation partition
bool shuffle_split = false;
Expand Down Expand Up @@ -188,31 +188,36 @@ struct Parameters
bool get_weights_init(){ return weights_init; };

void set_n_classes(const ArrayXf& y){
vector<int> uc = unique( ArrayXi(y.cast<int>()) );

if (int(uc.at(0)) != 0)
HANDLE_ERROR_THROW("Class labels must start at 0");

vector<int> cont_classes(uc.size());
iota(cont_classes.begin(), cont_classes.end(), 0);
for (int i = 0; i < cont_classes.size(); ++i)
if (classification)
{
if ( int(uc.at(i)) != cont_classes.at(i))
HANDLE_ERROR_THROW("Class labels must be contiguous");
vector<int> uc = unique( ArrayXi(y.cast<int>()) );

if (int(uc.at(0)) != 0)
HANDLE_ERROR_THROW("Class labels must start at 0");

vector<int> cont_classes(uc.size());
iota(cont_classes.begin(), cont_classes.end(), 0);
for (int i = 0; i < cont_classes.size(); ++i)
{
if ( int(uc.at(i)) != cont_classes.at(i))
HANDLE_ERROR_THROW("Class labels must be contiguous");
}
n_classes = uc.size();
// classes = uc;
}
n_classes = uc.size();
};
void set_class_weights(const ArrayXf& y){
class_weights.resize(n_classes); // set_n_classes must be called first
for (unsigned i = 0; i < n_classes; ++i){
class_weights.at(i) = float((y.cast<int>().array() == i).count())/y.size();
class_weights.at(i) = (1 - class_weights.at(i))*float(n_classes);
class_weights.at(i) = (1.0 - class_weights.at(i))*float(n_classes);
}
};
void set_sample_weights(const ArrayXf& y){
sample_weights.clear(); // set_class_weights must be called first
for (unsigned i = 0; i < y.size(); ++i)
sample_weights.push_back(class_weights.at(int(y(i))));
sample_weights.resize(0); // set_class_weights must be called first
if (!class_weights.empty())
for (unsigned i = 0; i < y.size(); ++i)
sample_weights.push_back(class_weights.at(int(y(i))));
};

unsigned int get_n_classes(){ return n_classes; };
Expand Down Expand Up @@ -261,6 +266,9 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Parameters,
mig_prob,
classification,
n_classes,
classes, // TODO: get rid of this parameter? for some reason, when i remove it (or set it to any value) the load population starts to fail with regression
class_weights,
sample_weights,
validation_size,
feature_names,
batch_size,
Expand Down
5 changes: 5 additions & 0 deletions src/pop/population.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,18 @@ void Population<T>::load(string filename)
std::string line;
indata >> line;

// std::cout << "Debug: Read line from file " << std::endl;

json j = json::parse(line);
from_json(j, *this);

// std::cout << "Debug: Parsed JSON successfully." << std::endl;

logger.log("Loaded population from " + filename + " of size = "
+ to_string(this->size()),1);

indata.close();
// std::cout << "Debug: Closed input file." << std::endl;
}

/// update individual vector size and island indexes
Expand Down
16 changes: 8 additions & 8 deletions src/pop/population.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ class Population{
};
};

NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(
Population<PT::Regressor>, individuals, island_indexes, pop_size, num_islands, mig_prob, linear_complexity);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(
Population<PT::BinaryClassifier>, individuals, island_indexes, pop_size, num_islands, mig_prob, linear_complexity);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(
Population<PT::MulticlassClassifier>, individuals, island_indexes, pop_size, num_islands, mig_prob, linear_complexity);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(
Population<PT::Representer>, individuals, island_indexes, pop_size, num_islands, mig_prob, linear_complexity);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Population<PT::Regressor>,
individuals, island_indexes, pop_size, num_islands, mig_prob, linear_complexity);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Population<PT::BinaryClassifier>,
individuals, island_indexes, pop_size, num_islands, mig_prob, linear_complexity);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Population<PT::MulticlassClassifier>,
individuals, island_indexes, pop_size, num_islands, mig_prob, linear_complexity);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Population<PT::Representer>,
individuals, island_indexes, pop_size, num_islands, mig_prob, linear_complexity);

}// Pop
}// Brush
Expand Down
4 changes: 4 additions & 0 deletions src/selection/lexicase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ vector<size_t> Lexicase<T>::select(Population<T>& pop, int island,
vector<size_t> cases; // cases (samples)
if (params.classification && !params.class_weights.empty())
{
// NOTE: when calling lexicase, make sure `errors` is from training
// data, and not from validation data. This is because the sample
// weights indexes are based on train partition

// for classification problems, weight case selection
// by class weights
cases.resize(0);
Expand Down

0 comments on commit 691aa04

Please sign in to comment.