Skip to content

Commit

Permalink
MeanLabel uses the most frequent y value in classification tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Jul 10, 2024
1 parent 9368414 commit a2c9f4f
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion src/program/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,24 @@ struct Operator<NodeType::MeanLabel, S, Fit>
// Shorthand for the weight type exposed by the template parameter S.
using W = typename S::WeightType;

/// Fits the constant weight of this node from the targets in `d.y`, then
/// returns the node's prediction on the same data.
///
/// When `d.classification` is set, the weight is the mode (most frequent
/// value) of `d.y`; ties are resolved towards the smallest label so the
/// result does not depend on hash-table iteration order. Otherwise, or when
/// no label could be counted, the weight is `d.y.mean()`.
///
/// @param d   dataset providing the targets `y` and the `classification` flag.
/// @param tn  tree node whose `data.W` is overwritten with the fitted weight.
/// @return    the result of `predict(d, tn)` evaluated with the fitted weight.
RetType fit(const Dataset& d, TreeNode& tn) const {
    // Label frequencies; stays empty for regression problems.
    std::unordered_map<float, int> counters;
    if (d.classification)
    {
        for (float val : d.y) {
            ++counters[val];
        }
    }

    if (!counters.empty())
    {
        // Order by count first; on equal counts the smaller label ranks
        // higher, so the maximum is unique and reproducible across runs and
        // standard-library implementations.
        const auto mode = std::max_element(
            counters.begin(), counters.end(),
            [](const auto& a, const auto& b) {
                return a.second < b.second
                    || (a.second == b.second && a.first > b.first);
            }
        );

        tn.data.W = mode->first;
    }
    else
    {
        // Regression, or a classification dataset with no labels: use the
        // mean instead of dereferencing an end() iterator.
        // NOTE(review): what mean() yields for an empty y depends on y's
        // type -- confirm against the Dataset definition.
        tn.data.W = d.y.mean();
    }

    return predict(d, tn);
};

Expand Down

0 comments on commit a2c9f4f

Please sign in to comment.