From e3c02633c0d7f36390f36943c74c202531ec597d Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 15 Feb 2024 22:42:49 -0800 Subject: [PATCH] Fix handling of normalization at leaf node --- src/model_loader/sklearn.cc | 22 ++++++++++++++++++---- tests/python/test_sklearn_integration.py | 2 +- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/model_loader/sklearn.cc b/src/model_loader/sklearn.cc index 5f926cad..677e3efe 100644 --- a/src/model_loader/sklearn.cc +++ b/src/model_loader/sklearn.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -23,6 +24,11 @@ namespace treelite::model_loader::sklearn { namespace detail { +namespace stdex = std::experimental; +// Multidimensional array views. Use row-major (C) layout +template +using Array2DView = stdex::mdspan, stdex::layout_right>; + class RandomForestRegressorMixIn { public: void HandleMetadata(model_builder::ModelBuilder& builder, int n_trees, int n_features, @@ -80,7 +86,7 @@ class RandomForestClassifierMixIn { } void HandleLeafNode(model_builder::ModelBuilder& builder, int tree_id, int node_id, - double const** value, [[maybe_unused]] int const* n_classes) const { + double const** value, std::int32_t const* n_classes) const { TREELITE_CHECK_GT(n_targets_, 0) << "n_targets not yet initialized. Was HandleMetadata() called?"; TREELITE_CHECK_GT(max_num_class_, 0) @@ -88,8 +94,16 @@ class RandomForestClassifierMixIn { std::vector leafvec(&value[tree_id][node_id * n_targets_ * max_num_class_], &value[tree_id][(node_id + 1) * n_targets_ * max_num_class_]); // Compute the probability distribution over label classes - double const norm_factor = std::accumulate(leafvec.begin(), leafvec.end(), 0.0); - std::for_each(leafvec.begin(), leafvec.end(), [norm_factor](double& e) { e /= norm_factor; }); + auto leaf_view = Array2DView(leafvec.data(), n_targets_, max_num_class_); + for (int target_id = 0; target_id < n_targets_; ++target_id) { + double norm_factor = 0.0; + for (std::int32_t class_id = 0; class_id < n_classes[target_id]; ++class_id) { + norm_factor += leaf_view(target_id, class_id); + } + for (std::int32_t class_id = 0; class_id < n_classes[target_id]; ++class_id) { + leaf_view(target_id, class_id) /= norm_factor; + } + } builder.LeafVector(leafvec); } @@ -116,7 +130,7 @@ class IsolationForestMixIn { } void HandleLeafNode(model_builder::ModelBuilder& builder, int tree_id, int node_id, - double const** value, [[maybe_unused]] int const* n_classes) const { + double const** value, [[maybe_unused]] std::int32_t const* n_classes) const { builder.LeafScalar(value[tree_id][node_id]); } diff --git a/tests/python/test_sklearn_integration.py b/tests/python/test_sklearn_integration.py index 5ecd2aca..6bc5236c 100644 --- a/tests/python/test_sklearn_integration.py +++ b/tests/python/test_sklearn_integration.py @@ -163,7 +163,7 @@ def test_skl_multitarget_multiclass_rf(n_classes, n_estimators): tl_model = treelite.sklearn.import_model(clf) out_prob = treelite.gtil.predict(tl_model, X) - expected_prob = clf.predict_proba(X) + expected_prob = np.transpose(clf.predict_proba(X), axes=(1, 0, 2)) np.testing.assert_almost_equal(out_prob, expected_prob, decimal=5)