diff --git a/src/program/operator.h b/src/program/operator.h index 2abceeba..9d2746ac 100644 --- a/src/program/operator.h +++ b/src/program/operator.h @@ -331,7 +331,24 @@ struct Operator using W = typename S::WeightType; RetType fit(const Dataset& d, TreeNode& tn) const { - tn.data.W = d.y.mean(); + // we take the mode of the labels if it is a classification problem + if (d.classification) + { + std::unordered_map counters; + for (float val : d.y) { + ++counters[val]; + } + + auto mode = std::max_element( + counters.begin(), counters.end(), + [](const auto& a, const auto& b) { return a.second < b.second; } + ); + + tn.data.W = mode->first; + } + else + tn.data.W = d.y.mean(); + return predict(d, tn); };