Skip to content

Commit

Permalink
Fit, predict and predict_proba for individual classes
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Apr 25, 2024
1 parent c64fad7 commit ddeb6fd
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 13 deletions.
37 changes: 32 additions & 5 deletions src/bindings/bind_individuals.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,23 @@
namespace nl = nlohmann;
namespace br = Brush;

using Reg = br::Pop::Individual<br::ProgramType::Regressor>;
using Cls = br::Pop::Individual<br::ProgramType::BinaryClassifier>;
using MCls = br::Pop::Individual<br::ProgramType::MulticlassClassifier>;
using Rep = br::Pop::Individual<br::ProgramType::Representer>;

using stream_redirect = py::call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>;

// TODO: unify PT or T
template <br::ProgramType T>
void bind_individual(py::module& m, string name)
{
using Class = br::Pop::Individual<T>;

// Return type used when casting the `predict` overloads below, selected at
// compile time from the individual kind:
//   Reg  -> ArrayXf, Cls -> ArrayXb, MCls -> ArrayXi,
//   anything else (the remaining alias above is Rep) -> ArrayXXf.
using RetType = std::conditional_t<
std::is_same_v<Class,Reg>, ArrayXf,
std::conditional_t<std::is_same_v<Class,Cls>, ArrayXb,
std::conditional_t<std::is_same_v<Class,MCls>, ArrayXi, ArrayXXf>>>;

py::class_<Class> ind(m, name.data() );
ind.def(py::init<>())
Expand All @@ -26,7 +36,18 @@ void bind_individual(py::module& m, string name)
.def_property("objectives", &Class::get_objectives, &Class::set_objectives)
.def_property_readonly("program", &Class::get_program)
.def_property_readonly("fitness", &Class::get_fitness)
// .def_property("complexity", &Class::get_complexity, &Class::set_complexity)
.def("fit",
static_cast<Class &(Class::*)(const Dataset &d)>(&Class::fit),
"fit from Dataset object")
.def("fit",
static_cast<Class &(Class::*)(const Ref<const ArrayXXf> &X, const Ref<const ArrayXf> &y)>(&Class::fit),
"fit from X,y data")
.def("predict",
static_cast<RetType (Class::*)(const Dataset &d)>(&Class::predict),
"predict from Dataset object")
.def("predict",
static_cast<RetType (Class::*)(const Ref<const ArrayXXf> &X)>(&Class::predict),
"predict from X data")
.def(py::pickle(
[](const Class &p) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
Expand All @@ -42,9 +63,15 @@ void bind_individual(py::module& m, string name)
)
;

// if constexpr (std::is_same_v<T,Cls>)
// {

// }
// predict_proba is only bound for binary-classifier individuals; both
// overloads are cast to member functions returning ArrayXf.
// The docstrings say "probabilities" so that Python's help() distinguishes
// these methods from the plain `predict` bindings.
if constexpr (std::is_same_v<Class, Cls>)
{
    ind.def("predict_proba",
            static_cast<ArrayXf (Class::*)(const Dataset &d)>(&Class::predict_proba),
            "predict probabilities from Dataset object")
       .def("predict_proba",
            static_cast<ArrayXf (Class::*)(const Ref<const ArrayXXf> &X)>(&Class::predict_proba),
            "predict probabilities from X data")
       ;
}

}
1 change: 0 additions & 1 deletion src/bindings/bind_programs.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ void bind_program(py::module& m, string name)
;
if constexpr (std::is_same_v<T,Cls>)
{
// TODO: have these in individual and wrapper
prog.def("predict_proba",
static_cast<ArrayXf (T::*)(const Dataset &d)>(&T::predict_proba),
"predict from Dataset object")
Expand Down
2 changes: 1 addition & 1 deletion src/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ void Engine<T>::run(Dataset &data)
bool updated_best = this->update_best(data);

// TODO: use_arch
if ( params.verbosity>1 || !logfile.empty()) {
if ( params.verbosity>1 || !params.logfile.empty()) {
calculate_stats();
}

Expand Down
37 changes: 31 additions & 6 deletions src/ind/individual.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Individual{

// error is the aggregation of the error vector, and can be user specified

// This flag is used to avoid re-fitting an individual. The program's own is_fitted_ flag is used to perform checks (like in predict with weights). They are two different things, and I think I'll keep it this way (the individual is just a container to keep program and fitness together).
bool is_fitted_ = false;

VectorXf error; ///< training error (used in lexicase selectors)
Expand All @@ -44,16 +45,40 @@ class Individual{
// program = SS.make_program<T>(params, params.max_depth, params.max_size);
};

// fitness, objectives, complexity, etc.
void fit(Dataset& data) {
// TODO: replace occurrences of program.fit with these (also predict and predict_proba)
/// Fits the wrapped program on `data` and marks this individual as fitted.
/// @param data dataset forwarded unchanged to `program.fit`.
/// @return reference to this individual, so calls can be chained.
Individual<T>& fit(const Dataset& data)
{
    program.fit(data);

    // Individual-level flag (used to avoid re-fitting); it is kept separate
    // from the program's own fitted flag.
    is_fitted_ = true;

    return *this;
}
/// Convenience overload: builds a Dataset from `X` and `y`, then fits on it.
/// @param X feature data passed to the Dataset constructor.
/// @param y target data passed to the Dataset constructor.
/// @return reference to this individual (result of the Dataset overload).
Individual<T>& fit(const Ref<const ArrayXXf>& X,
                   const Ref<const ArrayXf>& y)
{
    // Named, non-const local: keeps the same overload selection as before.
    Dataset train_data(X, y);
    return fit(train_data);
}

/// Predicts on `data` by delegating to the wrapped program.
/// @return whatever `program.predict(data)` returns (type deduced).
auto predict(const Dataset& data)
{
    return program.predict(data);
}
auto predict(const Ref<const ArrayXXf>& X)
{
Dataset d(X);
return predict(d);
};

/// Probability prediction on a Dataset, delegated to the wrapped program.
/// Constrained so that P (which defaults to T) is a binary or multiclass
/// classifier program type.
/// @return whatever `program.predict_proba(d)` returns (type deduced).
template <ProgramType P = T>
    requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
auto predict_proba(const Dataset& d)
{
    return program.predict_proba(d);
}

/// Convenience overload: wraps `X` in a Dataset and returns the result of
/// the Dataset-based `predict_proba`.
/// Constrained so that P (which defaults to T) is a binary or multiclass
/// classifier program type.
template <ProgramType P = T>
    requires((P == PT::BinaryClassifier) || (P == PT::MulticlassClassifier))
auto predict_proba(const Ref<const ArrayXXf>& X)
{
    Dataset input_data(X);
    return predict_proba(input_data);
}

// Overload taking a non-const Dataset reference; forwards to program.predict.
auto predict(Dataset& data) { return program.predict(data); };

// TODO: predict proba and classification related methods.

// just getters
bool get_is_fitted() const { return this->is_fitted_; };
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/test_individuals.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
// TODO: test predict, predict proba, fit.

0 comments on commit ddeb6fd

Please sign in to comment.