Skip to content

Commit

Permalink
Fix handling of normalization at leaf node
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Feb 16, 2024
1 parent 4f46aba commit e3c0263
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
22 changes: 18 additions & 4 deletions src/model_loader/sklearn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <algorithm>
#include <cstdint>
#include <experimental/mdspan>
#include <limits>
#include <memory>
#include <numeric>
Expand All @@ -23,6 +24,11 @@ namespace treelite::model_loader::sklearn {

namespace detail {

// Short alias for the namespace from which mdspan and its helpers are taken below.
namespace stdex = std::experimental;
// Multidimensional array views. Use row-major (C) layout
// Array2DView<ElemT> is a non-owning rank-2 view over ElemT: both extents are dynamic
// (dextents), std::uint64_t is the index type, and layout_right makes the last index
// vary fastest in memory.
template <typename ElemT>
using Array2DView = stdex::mdspan<ElemT, stdex::dextents<std::uint64_t, 2>, stdex::layout_right>;

class RandomForestRegressorMixIn {
public:
void HandleMetadata(model_builder::ModelBuilder& builder, int n_trees, int n_features,
Expand Down Expand Up @@ -80,16 +86,24 @@ class RandomForestClassifierMixIn {
}

// Builds the leaf output for node `node_id` of tree `tree_id`: copies that node's slice of
// n_targets_ * max_num_class_ doubles out of value[tree_id], normalizes the copy, and hands
// it to builder.LeafVector(). The input array itself is not modified.
// NOTE(review): this listing looks like a diff rendered without +/- markers, so superseded
// and replacement lines appear side by side below -- confirm against the actual file.
void HandleLeafNode(model_builder::ModelBuilder& builder, int tree_id, int node_id,
// Presumably the superseded parameter line (n_classes declared but unused):
double const** value, [[maybe_unused]] int const* n_classes) const {
// Presumably the replacement parameter line; n_classes[target_id] bounds the class loops below:
double const** value, std::int32_t const* n_classes) const {
// Both members must be positive before they are used to compute the slice bounds.
TREELITE_CHECK_GT(n_targets_, 0)
<< "n_targets not yet initialized. Was HandleMetadata() called?";
TREELITE_CHECK_GT(max_num_class_, 0)
<< "max_num_class not yet initialized. Was HandleMetadata() called?";
// Copy the half-open element range [node_id * n_targets_ * max_num_class_,
// (node_id + 1) * n_targets_ * max_num_class_) of value[tree_id].
std::vector<double> leafvec(&value[tree_id][node_id * n_targets_ * max_num_class_],
&value[tree_id][(node_id + 1) * n_targets_ * max_num_class_]);
// Compute the probability distribution over label classes
// Presumably the superseded normalization: a single sum over the whole leaf vector,
// i.e. across all targets at once.
double const norm_factor = std::accumulate(leafvec.begin(), leafvec.end(), 0.0);
std::for_each(leafvec.begin(), leafvec.end(), [norm_factor](double& e) { e /= norm_factor; });
// Presumably the replacement normalization: view the copy as an
// (n_targets_, max_num_class_) array and normalize each target's row separately.
auto leaf_view = Array2DView<double>(leafvec.data(), n_targets_, max_num_class_);
for (int target_id = 0; target_id < n_targets_; ++target_id) {
// Sum of the first n_classes[target_id] entries of this target's row.
double norm_factor = 0.0;
for (std::int32_t class_id = 0; class_id < n_classes[target_id]; ++class_id) {
norm_factor += leaf_view(target_id, class_id);
}
// Only those first n_classes[target_id] entries are rescaled; entries from there up to
// max_num_class_ keep their copied values.
// NOTE(review): there is no check that norm_factor is nonzero before dividing.
for (std::int32_t class_id = 0; class_id < n_classes[target_id]; ++class_id) {
leaf_view(target_id, class_id) /= norm_factor;
}
}
builder.LeafVector(leafvec);
}

Expand All @@ -116,7 +130,7 @@ class IsolationForestMixIn {
}

// Emits the leaf output for node `node_id` of tree `tree_id` as a single scalar,
// value[tree_id][node_id], via builder.LeafScalar(). n_classes is not read here.
void HandleLeafNode(model_builder::ModelBuilder& builder, int tree_id, int node_id,
// NOTE(review): two versions of the next parameter line are shown (int vs std::int32_t);
// presumably residue of a diff rendered without +/- markers -- confirm against the actual file.
double const** value, [[maybe_unused]] int const* n_classes) const {
double const** value, [[maybe_unused]] std::int32_t const* n_classes) const {
builder.LeafScalar(value[tree_id][node_id]);
}

Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_sklearn_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_skl_multitarget_multiclass_rf(n_classes, n_estimators):

tl_model = treelite.sklearn.import_model(clf)
out_prob = treelite.gtil.predict(tl_model, X)
expected_prob = clf.predict_proba(X)
expected_prob = np.transpose(clf.predict_proba(X), axes=(1, 0, 2))
np.testing.assert_almost_equal(out_prob, expected_prob, decimal=5)


Expand Down

0 comments on commit e3c0263

Please sign in to comment.