Skip to content

Commit

Permalink
Comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Jul 30, 2024
1 parent 0bfc7e4 commit 1a7ae95
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/bandit/thompson.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ void ThompsonSamplingBandit<T>::update(T arm, float reward) {
// reward must be either 0 or 1

alphas[arm] += reward;
betas[arm] += 1-reward;
betas[arm] += 1.0f-reward;

if (dynamic_update && alphas[arm] + betas[arm] >= C)
{
alphas[arm] *= C/(C+1) ;
betas[arm] *= C/(C+1) ;
alphas[arm] *= C/(C+1.0f) ;
betas[arm] *= C/(C+1.0f) ;
}
}

Expand Down
14 changes: 14 additions & 0 deletions src/data/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class Dataset

/// @brief dataset features, as key value pairs
std::map<string, State> features;

// TODO: this should probably be a more complex type to include feature type
// and potentially other info, like arbitrary relations between features

Expand All @@ -92,6 +93,7 @@ class Dataset
bool use_batch;

Dataset operator()(const vector<size_t>& idx) const;

/// call init at the end of constructors
/// to define metafeatures of the data.
void init();
Expand Down Expand Up @@ -242,6 +244,18 @@ class Dataset
/* template<> ArrayXb get<ArrayXb>(std::string name) */
}; // class data

// TODO: serialization of features in order to nlohmann to work
// NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Dataset,
// features,
// y,
// classification,
// validation_size,
// use_validation,
// batch_size,
// use_batch,
// shuffle_split
// );

// // read csv
// Dataset read_csv(const std::string & path, MatrixXf& X, VectorXf& y,
// vector<string>& names, vector<char> &dtypes, bool& binary_endpoint, char sep) ;
Expand Down
2 changes: 1 addition & 1 deletion src/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class Engine{
};

// TODO: should I serialize data and search space as well?
// Only stuff to make new predictions or call fit again
// Only stuff to make new predictions should be serialized
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Engine<PT::Regressor>, params, best_ind, archive);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Engine<PT::BinaryClassifier>, params, best_ind, archive);
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Engine<PT::MulticlassClassifier>, params, best_ind, archive);
Expand Down
5 changes: 4 additions & 1 deletion src/vary/variation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,8 @@ std::optional<Individual<T>> Variation<T>::cross(
Individual<T> ind(child);
ind.set_variation("cx"); // TODO: use enum here to make it faster

// std::cout << "returning after crossover" << std::endl;

return ind;
}
}
Expand Down Expand Up @@ -562,7 +564,8 @@ std::optional<Individual<T>> Variation<T>::mutate(const Individual<T>& parent)
else // it must be"toggle_weight_off"
success = ToggleWeightOffMutation::mutate(child.Tree, spot, search_space, parameters);

// std::cout << "returning" << std::endl;
// std::cout << "returning after mutation " << choice << std::endl;

if (success
&& ( (child.size() <= parameters.max_size)
&& (child.depth() <= parameters.max_depth) )){
Expand Down

0 comments on commit 1a7ae95

Please sign in to comment.