From 8dca9a582daa400f951fd13111b62bf78c017ae5 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 10 Apr 2021 01:18:05 +0000 Subject: [PATCH 01/19] [WIP] Faster model import for sklearn tree models --- include/treelite/c_api.h | 7 +++ include/treelite/frontend.h | 6 +++ python/treelite/sklearn/__init__.py | 57 +++++++++++++++++++++- src/CMakeLists.txt | 1 + src/c_api/c_api.cc | 13 +++++ src/frontend/sklearn.cc | 73 +++++++++++++++++++++++++++++ 6 files changed, 156 insertions(+), 1 deletion(-) create mode 100644 src/frontend/sklearn.cc diff --git a/include/treelite/c_api.h b/include/treelite/c_api.h index 23fc5364..98f926eb 100644 --- a/include/treelite/c_api.h +++ b/include/treelite/c_api.h @@ -171,6 +171,13 @@ TREELITE_DLL int TreeliteLoadXGBoostJSONString(const char* json_str, TREELITE_DLL int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, ModelHandle* out); + +TREELITE_DLL int TreeliteLoadSKLearnRandomForestRegressor( + int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, + const int64_t** children_right, const int64_t** feature, const double** threshold, + const double** value, const int64_t** n_node_samples, const double** impurity, + ModelHandle* out); + /*! 
* \brief Query the number of trees in the model * \param handle model to query diff --git a/include/treelite/frontend.h b/include/treelite/frontend.h index 8f92c91f..c124338e 100644 --- a/include/treelite/frontend.h +++ b/include/treelite/frontend.h @@ -7,6 +7,7 @@ #ifndef TREELITE_FRONTEND_H_ #define TREELITE_FRONTEND_H_ +#include #include #include #include @@ -58,6 +59,11 @@ std::unique_ptr LoadXGBoostJSONModel(const char* filename); */ std::unique_ptr LoadXGBoostJSONModelString(const char* json_str, size_t length); +std::unique_ptr LoadSKLearnRandomForestRegressor( + int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, + const int64_t** children_right, const int64_t** feature, const double** threshold, + const double** value, const int64_t** n_node_samples, const double** impurity); + //-------------------------------------------------------------------------- // model builder interface: build trees incrementally //-------------------------------------------------------------------------- diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index b26e3b2f..86ac9e7a 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -1,7 +1,11 @@ # coding: utf-8 """Converter to ingest scikit-learn models into Treelite""" +from __future__ import absolute_import as _abs + from ..util import TreeliteError +from ..core import _LIB, c_array, _check_call +from ..frontend import Model from .common import SKLConverterBase from .gbm_regressor import SKLGBMRegressorMixin from .gbm_classifier import SKLGBMClassifierMixin @@ -9,6 +13,8 @@ from .rf_regressor import SKLRFRegressorMixin from .rf_classifier import SKLRFClassifierMixin from .rf_multi_classifier import SKLRFMultiClassifierMixin +import ctypes +import numpy as np def import_model(sklearn_model): @@ -97,4 +103,53 @@ class SKLRFMultiClassifierConverter(SKLRFMultiClassifierMixin, SKLConverterBase) pass -__all__ = 
['import_model'] +def import_model_v2(clf): + int64_ptr_type = ctypes.POINTER(ctypes.c_int64) + float64_ptr_type = ctypes.POINTER(ctypes.c_double) + + node_count = [] + children_left = [] + children_right = [] + feature = [] + threshold = [] + value = [] + n_node_samples = [] + impurity = [] + for i, estimator in enumerate(clf.estimators_): + tree = estimator.tree_ + node_count_v = tree.node_count + node_count.append(node_count_v) + assert tree.children_left.shape == (node_count_v,) + children_left_v = np.array(tree.children_left, copy=False, dtype=np.int64, order='C') + children_left.append(children_left_v.ctypes.data_as(int64_ptr_type)) + assert tree.children_right.shape == (node_count_v,) + children_right_v = np.array(tree.children_right, copy=False, dtype=np.int64, order='C') + children_right.append(children_right_v.ctypes.data_as(int64_ptr_type)) + assert tree.feature.shape == (node_count_v,) + feature_v = np.array(tree.feature, copy=False, dtype=np.int64, order='C') + feature.append(feature_v.ctypes.data_as(int64_ptr_type)) + assert tree.threshold.shape == (node_count_v,) + threshold_v = np.array(tree.threshold, copy=False, dtype=np.float64, order='C') + threshold.append(threshold_v.ctypes.data_as(float64_ptr_type)) + assert tree.value.shape == (node_count_v, 1, 1) + value_v = np.array(tree.value, copy=False, dtype=np.float64, order='C') + value.append(value_v.ctypes.data_as(float64_ptr_type)) + assert tree.n_node_samples.shape == (node_count_v,) + n_node_samples_v = np.array(tree.n_node_samples, copy=False, dtype=np.int64, order='C') + n_node_samples.append(n_node_samples_v.ctypes.data_as(int64_ptr_type)) + assert tree.impurity.shape == (node_count_v,) + impurity_v = np.array(tree.impurity, copy=False, dtype=np.float64, order='C') + impurity.append(impurity_v.ctypes.data_as(float64_ptr_type)) + + handle = ctypes.c_void_p() + _check_call(_LIB.TreeliteLoadSKLearnRandomForestRegressor( + ctypes.c_int(clf.n_estimators), ctypes.c_int(clf.n_features_), + 
c_array(ctypes.c_int64, node_count), c_array(int64_ptr_type, children_left), + c_array(int64_ptr_type, children_right), c_array(int64_ptr_type, feature), + c_array(float64_ptr_type, threshold), c_array(float64_ptr_type, value), + c_array(int64_ptr_type, n_node_samples), c_array(float64_ptr_type, impurity), + ctypes.byref(handle))) + return Model(handle) + + +__all__ = ['import_model', 'import_model_v2'] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8fd83910..38be2a13 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -90,6 +90,7 @@ target_sources(objtreelite compiler/pred_transform.h frontend/builder.cc frontend/lightgbm.cc + frontend/sklearn.cc frontend/xgboost.cc frontend/xgboost_json.cc frontend/xgboost_util.cc diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index fbf97775..8d183434 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -176,6 +176,19 @@ int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, ModelH API_END(); } +int TreeliteLoadSKLearnRandomForestRegressor( + int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, + const int64_t** children_right, const int64_t** feature, const double** threshold, + const double** value, const int64_t** n_node_samples, const double** impurity, + ModelHandle* out) { + API_BEGIN(); + std::unique_ptr model = frontend::LoadSKLearnRandomForestRegressor( + n_estimators, n_features, node_count, children_left, children_right, feature, threshold, + value, n_node_samples, impurity); + *out = static_cast(model.release()); + API_END(); +} + int TreeliteFreeModel(ModelHandle handle) { API_BEGIN(); delete static_cast(handle); diff --git a/src/frontend/sklearn.cc b/src/frontend/sklearn.cc new file mode 100644 index 00000000..6c426eec --- /dev/null +++ b/src/frontend/sklearn.cc @@ -0,0 +1,73 @@ +#include +#include +#include +#include +#include + +namespace treelite { +namespace frontend { + +std::unique_ptr LoadSKLearnRandomForestRegressor( + 
int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, + const int64_t** children_right, const int64_t** feature, const double** threshold, + const double** value, const int64_t** n_node_samples, const double** impurity) { + CHECK_GT(n_estimators, 0); + CHECK_GT(n_features, 0); + + std::unique_ptr model_ptr = treelite::Model::Create(); + auto* model = dynamic_cast*>(model_ptr.get()); + model->num_feature = n_features; + model->average_tree_output = true; + model->task_type = treelite::TaskType::kBinaryClfRegr; + model->task_param.grove_per_class = false; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = 1; + model->task_param.leaf_vector_size = 1; + std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); + model->param.global_bias = 0.0f; + + for (int tree_id = 0; tree_id < n_estimators; ++tree_id) { + model->trees.emplace_back(); + treelite::Tree& tree = model->trees.back(); + tree.Init(); + + // assign node ID's so that a breadth-wise traversal would yield + // the monotonic sequence 0, 1, 2, ... 
+ std::queue> Q; // (old ID, new ID) pair + Q.push({0, 0}); + const int64_t total_sample_cnt = n_node_samples[tree_id][0]; + while (!Q.empty()) { + int64_t node_id; + int new_node_id; + std::tie(node_id, new_node_id) = Q.front(); Q.pop(); + const int64_t left_child_id = children_left[tree_id][node_id]; + const int64_t right_child_id = children_right[tree_id][node_id]; + const int64_t sample_cnt = n_node_samples[tree_id][node_id]; + if (left_child_id == -1) { // leaf node + const double leaf_value = value[tree_id][node_id]; + tree.SetLeaf(new_node_id, leaf_value); + } else { + const int64_t split_index = feature[tree_id][node_id]; + const double split_cond = threshold[tree_id][node_id]; + const int64_t left_child_sample_cnt = n_node_samples[tree_id][left_child_id]; + const int64_t right_child_sample_cnt = n_node_samples[tree_id][right_child_id]; + const double gain = sample_cnt * ( + impurity[tree_id][node_id] + - left_child_sample_cnt * impurity[tree_id][left_child_id] / sample_cnt + - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt + ) / total_sample_cnt; + + tree.AddChilds(new_node_id); + tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); + tree.SetGain(new_node_id, gain); + Q.push({left_child_id, tree.LeftChild(new_node_id)}); + Q.push({right_child_id, tree.RightChild(new_node_id)}); + } + tree.SetDataCount(new_node_id, sample_cnt); + } + } + return model_ptr; +} + +} // namespace frontend +} // namespace treelite From 0ad56f6d4caf6c2a803c37d961d0e6906e522f80 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 10 Apr 2021 01:24:27 +0000 Subject: [PATCH 02/19] Comply with formatting guideline --- python/treelite/sklearn/__init__.py | 34 ++++++++++++++++++++++++----- src/frontend/sklearn.cc | 12 +++++++--- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index 86ac9e7a..56d55bd0 100644 --- 
a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -3,6 +3,9 @@ from __future__ import absolute_import as _abs +import ctypes +import numpy as np + from ..util import TreeliteError from ..core import _LIB, c_array, _check_call from ..frontend import Model @@ -13,8 +16,6 @@ from .rf_regressor import SKLRFRegressorMixin from .rf_classifier import SKLRFClassifierMixin from .rf_multi_classifier import SKLRFMultiClassifierMixin -import ctypes -import numpy as np def import_model(sklearn_model): @@ -103,10 +104,31 @@ class SKLRFMultiClassifierConverter(SKLRFMultiClassifierMixin, SKLConverterBase) pass -def import_model_v2(clf): +def import_model_v2(sklearn_model): + # pylint: disable=R0914 + """ + Load a tree ensemble model from a scikit-learn model object + + Parameters + ---------- + sklearn_model : object of type \ + :py:class:`~sklearn.ensemble.RandomForestRegressor` + Python handle to scikit-learn model + + Returns + ------- + model : :py:class:`~treelite.Model` object + loaded model + """ + class_name = sklearn_model.__class__.__name__ + module_name = sklearn_model.__module__.split('.')[0] + if module_name != 'sklearn': + raise Exception('Not a scikit-learn model') + if class_name != 'RandomForestRegressor': + raise Exception('Only RandomForestRegressor supported for now') + int64_ptr_type = ctypes.POINTER(ctypes.c_int64) float64_ptr_type = ctypes.POINTER(ctypes.c_double) - node_count = [] children_left = [] children_right = [] @@ -115,7 +137,7 @@ def import_model_v2(clf): value = [] n_node_samples = [] impurity = [] - for i, estimator in enumerate(clf.estimators_): + for estimator in sklearn_model.estimators_: tree = estimator.tree_ node_count_v = tree.node_count node_count.append(node_count_v) @@ -143,7 +165,7 @@ def import_model_v2(clf): handle = ctypes.c_void_p() _check_call(_LIB.TreeliteLoadSKLearnRandomForestRegressor( - ctypes.c_int(clf.n_estimators), ctypes.c_int(clf.n_features_), + ctypes.c_int(sklearn_model.n_estimators), 
ctypes.c_int(sklearn_model.n_features_), c_array(ctypes.c_int64, node_count), c_array(int64_ptr_type, children_left), c_array(int64_ptr_type, children_right), c_array(int64_ptr_type, feature), c_array(float64_ptr_type, threshold), c_array(float64_ptr_type, value), diff --git a/src/frontend/sklearn.cc b/src/frontend/sklearn.cc index 6c426eec..8e2bafef 100644 --- a/src/frontend/sklearn.cc +++ b/src/frontend/sklearn.cc @@ -1,3 +1,9 @@ +/*! + * Copyright (c) 2021 by Contributors + * \file sklearn.cc + * \brief Frontend for scikit-learn models + * \author Hyunsu Cho + */ #include #include #include @@ -13,7 +19,7 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( const double** value, const int64_t** n_node_samples, const double** impurity) { CHECK_GT(n_estimators, 0); CHECK_GT(n_features, 0); - + std::unique_ptr model_ptr = treelite::Model::Create(); auto* model = dynamic_cast*>(model_ptr.get()); model->num_feature = n_features; @@ -54,8 +60,8 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( const double gain = sample_cnt * ( impurity[tree_id][node_id] - left_child_sample_cnt * impurity[tree_id][left_child_id] / sample_cnt - - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt - ) / total_sample_cnt; + - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt) + / total_sample_cnt; tree.AddChilds(new_node_id); tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); From 5860bcde66f71edccc30f81608869a082313f315 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 10 Apr 2021 03:05:19 +0000 Subject: [PATCH 03/19] Add test for RandomForestRegressor --- tests/python/test_skl_importer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/python/test_skl_importer.py b/tests/python/test_skl_importer.py index 9a322a4d..e5b7bee8 100644 --- a/tests/python/test_skl_importer.py +++ b/tests/python/test_skl_importer.py @@ -128,7 +128,10 @@ def test_skl_converter_regressor(tmpdir, 
clazz, toolchain): # pylint: disable=t clf.fit(X, y) expected_pred = clf.predict(X) - model = treelite.sklearn.import_model(clf) + if clazz == RandomForestRegressor: + model = treelite.sklearn.import_model_v2(clf) + else: + model = treelite.sklearn.import_model(clf) assert model.num_feature == clf.n_features_ assert model.num_class == 1 assert model.num_tree == clf.n_estimators From 12b680ab5166ca5d2cbeb275669c248be54e6fc7 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 21 Apr 2021 13:58:11 -0700 Subject: [PATCH 04/19] Clean up boilerplate --- python/treelite/sklearn/__init__.py | 77 +++++++++++++++-------------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index 56d55bd0..aa3d7050 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -104,6 +104,29 @@ class SKLRFMultiClassifierConverter(SKLRFMultiClassifierMixin, SKLConverterBase) pass +class ArrayOfArrays: + def __init__(self, *, dtype): + int64_ptr_type = ctypes.POINTER(ctypes.c_int64) + float64_ptr_type = ctypes.POINTER(ctypes.c_double) + if dtype == np.int64: + self.ptr_type = int64_ptr_type + elif dtype == np.float64: + self.ptr_type = float64_ptr_type + else: + raise ValueError(f'dtype {dtype} is not supported') + self.dtype = dtype + self.collection = [] + + def add(self, array, *, expected_shape=None): + if expected_shape: + assert array.shape == expected_shape + v = np.array(array, copy=False, dtype=self.dtype, order='C') + self.collection.append(v.ctypes.data_as(self.ptr_type)) + + def as_c_array(self): + return c_array(self.ptr_type, self.collection) + + def import_model_v2(sklearn_model): # pylint: disable=R0914 """ @@ -127,49 +150,31 @@ def import_model_v2(sklearn_model): if class_name != 'RandomForestRegressor': raise Exception('Only RandomForestRegressor supported for now') - int64_ptr_type = ctypes.POINTER(ctypes.c_int64) - float64_ptr_type = 
ctypes.POINTER(ctypes.c_double) node_count = [] - children_left = [] - children_right = [] - feature = [] - threshold = [] - value = [] - n_node_samples = [] - impurity = [] + children_left = ArrayOfArrays(dtype=np.int64) + children_right = ArrayOfArrays(dtype=np.int64) + feature = ArrayOfArrays(dtype=np.int64) + threshold = ArrayOfArrays(dtype=np.float64) + value = ArrayOfArrays(dtype=np.float64) + n_node_samples = ArrayOfArrays(dtype=np.int64) + impurity = ArrayOfArrays(dtype=np.float64) for estimator in sklearn_model.estimators_: tree = estimator.tree_ - node_count_v = tree.node_count - node_count.append(node_count_v) - assert tree.children_left.shape == (node_count_v,) - children_left_v = np.array(tree.children_left, copy=False, dtype=np.int64, order='C') - children_left.append(children_left_v.ctypes.data_as(int64_ptr_type)) - assert tree.children_right.shape == (node_count_v,) - children_right_v = np.array(tree.children_right, copy=False, dtype=np.int64, order='C') - children_right.append(children_right_v.ctypes.data_as(int64_ptr_type)) - assert tree.feature.shape == (node_count_v,) - feature_v = np.array(tree.feature, copy=False, dtype=np.int64, order='C') - feature.append(feature_v.ctypes.data_as(int64_ptr_type)) - assert tree.threshold.shape == (node_count_v,) - threshold_v = np.array(tree.threshold, copy=False, dtype=np.float64, order='C') - threshold.append(threshold_v.ctypes.data_as(float64_ptr_type)) - assert tree.value.shape == (node_count_v, 1, 1) - value_v = np.array(tree.value, copy=False, dtype=np.float64, order='C') - value.append(value_v.ctypes.data_as(float64_ptr_type)) - assert tree.n_node_samples.shape == (node_count_v,) - n_node_samples_v = np.array(tree.n_node_samples, copy=False, dtype=np.int64, order='C') - n_node_samples.append(n_node_samples_v.ctypes.data_as(int64_ptr_type)) - assert tree.impurity.shape == (node_count_v,) - impurity_v = np.array(tree.impurity, copy=False, dtype=np.float64, order='C') - 
impurity.append(impurity_v.ctypes.data_as(float64_ptr_type)) + node_count.append(tree.node_count) + children_left.add(tree.children_left, expected_shape=(tree.node_count,)) + children_right.add(tree.children_right, expected_shape=(tree.node_count,)) + feature.add(tree.feature, expected_shape=(tree.node_count,)) + threshold.add(tree.threshold, expected_shape=(tree.node_count,)) + value.add(tree.value, expected_shape=(tree.node_count, 1, 1)) + n_node_samples.add(tree.n_node_samples, expected_shape=(tree.node_count,)) + impurity.add(tree.impurity, expected_shape=(tree.node_count,)) handle = ctypes.c_void_p() _check_call(_LIB.TreeliteLoadSKLearnRandomForestRegressor( ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), - c_array(ctypes.c_int64, node_count), c_array(int64_ptr_type, children_left), - c_array(int64_ptr_type, children_right), c_array(int64_ptr_type, feature), - c_array(float64_ptr_type, threshold), c_array(float64_ptr_type, value), - c_array(int64_ptr_type, n_node_samples), c_array(float64_ptr_type, impurity), + c_array(ctypes.c_int64, node_count), children_left.as_c_array(), + children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), + value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), ctypes.byref(handle))) return Model(handle) From d540cb76ff42e6bbaf20b5d8f7fa4f870638d3e8 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 21 Apr 2021 14:42:45 -0700 Subject: [PATCH 05/19] Add support for RandomForestClassifier with n_classes_ = 2 --- include/treelite/c_api.h | 5 ++ include/treelite/frontend.h | 6 ++ python/treelite/sklearn/__init__.py | 39 +++++++++---- src/c_api/c_api.cc | 13 +++++ src/frontend/sklearn.cc | 89 +++++++++++++++++++++++++++++ tests/python/test_skl_importer.py | 5 +- 6 files changed, 145 insertions(+), 12 deletions(-) diff --git a/include/treelite/c_api.h b/include/treelite/c_api.h index 98f926eb..c90846d0 100644 --- a/include/treelite/c_api.h +++ 
b/include/treelite/c_api.h @@ -177,6 +177,11 @@ TREELITE_DLL int TreeliteLoadSKLearnRandomForestRegressor( const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity, ModelHandle* out); +TREELITE_DLL int TreeliteLoadSKLearnRandomForestClassifier( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity, ModelHandle* out); /*! * \brief Query the number of trees in the model diff --git a/include/treelite/frontend.h b/include/treelite/frontend.h index c124338e..b66df0db 100644 --- a/include/treelite/frontend.h +++ b/include/treelite/frontend.h @@ -64,6 +64,12 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity); +std::unique_ptr LoadSKLearnRandomForestClassifier( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity); + //-------------------------------------------------------------------------- // model builder interface: build trees incrementally //-------------------------------------------------------------------------- diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index aa3d7050..56638e37 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -118,8 +118,10 @@ def __init__(self, *, dtype): self.collection = [] def add(self, array, *, expected_shape=None): + assert array.dtype == self.dtype if 
expected_shape: - assert array.shape == expected_shape + assert array.shape == expected_shape, \ + f'Expected shape: {expected_shape}, Got shape {array.shape}' v = np.array(array, copy=False, dtype=self.dtype, order='C') self.collection.append(v.ctypes.data_as(self.ptr_type)) @@ -146,9 +148,14 @@ def import_model_v2(sklearn_model): class_name = sklearn_model.__class__.__name__ module_name = sklearn_model.__module__.split('.')[0] if module_name != 'sklearn': - raise Exception('Not a scikit-learn model') - if class_name != 'RandomForestRegressor': - raise Exception('Only RandomForestRegressor supported for now') + raise TreeliteError('Not a scikit-learn model') + + if class_name == 'RandomForestRegressor': + leaf_value_expected_shape = lambda node_count: (node_count, 1, 1) + elif class_name == 'RandomForestClassifier': + leaf_value_expected_shape = lambda node_count: (node_count, 1, sklearn_model.n_classes_) + else: + raise TreeliteError(f'Not supported: {class_name}') node_count = [] children_left = ArrayOfArrays(dtype=np.int64) @@ -165,17 +172,27 @@ def import_model_v2(sklearn_model): children_right.add(tree.children_right, expected_shape=(tree.node_count,)) feature.add(tree.feature, expected_shape=(tree.node_count,)) threshold.add(tree.threshold, expected_shape=(tree.node_count,)) - value.add(tree.value, expected_shape=(tree.node_count, 1, 1)) + value.add(tree.value, expected_shape=leaf_value_expected_shape(tree.node_count)) n_node_samples.add(tree.n_node_samples, expected_shape=(tree.node_count,)) impurity.add(tree.impurity, expected_shape=(tree.node_count,)) handle = ctypes.c_void_p() - _check_call(_LIB.TreeliteLoadSKLearnRandomForestRegressor( - ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), - c_array(ctypes.c_int64, node_count), children_left.as_c_array(), - children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), - value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), - 
ctypes.byref(handle))) + if class_name == 'RandomForestRegressor': + _check_call(_LIB.TreeliteLoadSKLearnRandomForestRegressor( + ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), + c_array(ctypes.c_int64, node_count), children_left.as_c_array(), + children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), + value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), + ctypes.byref(handle))) + elif class_name == 'RandomForestClassifier': + _check_call(_LIB.TreeliteLoadSKLearnRandomForestClassifier( + ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), + ctypes.c_int(sklearn_model.n_classes_), c_array(ctypes.c_int64, node_count), + children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), + threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), + impurity.as_c_array(), ctypes.byref(handle))) + else: + raise TreeliteError(f'Not supported: {class_name}') return Model(handle) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 8d183434..12d0c277 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -189,6 +189,19 @@ int TreeliteLoadSKLearnRandomForestRegressor( API_END(); } +int TreeliteLoadSKLearnRandomForestClassifier( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity, ModelHandle* out) { + API_BEGIN(); + std::unique_ptr model = frontend::LoadSKLearnRandomForestClassifier( + n_estimators, n_features, n_classes, node_count, children_left, children_right, feature, + threshold, value, n_node_samples, impurity); + *out = static_cast(model.release()); + API_END(); +} + int TreeliteFreeModel(ModelHandle handle) { API_BEGIN(); delete static_cast(handle); diff --git a/src/frontend/sklearn.cc 
b/src/frontend/sklearn.cc index 8e2bafef..8d97a4d4 100644 --- a/src/frontend/sklearn.cc +++ b/src/frontend/sklearn.cc @@ -75,5 +75,94 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( return model_ptr; } +std::unique_ptr LoadSKLearnRandomForestClassifierBinary( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity) { + CHECK_GT(n_estimators, 0); + CHECK_GT(n_features, 0); + + std::unique_ptr model_ptr = treelite::Model::Create(); + auto* model = dynamic_cast*>(model_ptr.get()); + model->num_feature = n_features; + model->average_tree_output = true; + model->task_type = treelite::TaskType::kBinaryClfRegr; + model->task_param.grove_per_class = false; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = 1; + model->task_param.leaf_vector_size = 1; + std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); + model->param.global_bias = 0.0f; + + for (int tree_id = 0; tree_id < n_estimators; ++tree_id) { + model->trees.emplace_back(); + treelite::Tree& tree = model->trees.back(); + tree.Init(); + + // assign node ID's so that a breadth-wise traversal would yield + // the monotonic sequence 0, 1, 2, ... 
+ std::queue> Q; // (old ID, new ID) pair + Q.push({0, 0}); + const int64_t total_sample_cnt = n_node_samples[tree_id][0]; + while (!Q.empty()) { + int64_t node_id; + int new_node_id; + std::tie(node_id, new_node_id) = Q.front(); Q.pop(); + const int64_t left_child_id = children_left[tree_id][node_id]; + const int64_t right_child_id = children_right[tree_id][node_id]; + const int64_t sample_cnt = n_node_samples[tree_id][node_id]; + if (left_child_id == -1) { // leaf node + // # Get counts for each label (+/-) at this leaf node + const double* leaf_count = &value[tree_id][node_id * 2]; + // Compute the fraction of positive data points at this leaf node + const double fraction_positive = leaf_count[1] / (leaf_count[0] + leaf_count[1]); + tree.SetLeaf(new_node_id, fraction_positive); + } else { + const int64_t split_index = feature[tree_id][node_id]; + const double split_cond = threshold[tree_id][node_id]; + const int64_t left_child_sample_cnt = n_node_samples[tree_id][left_child_id]; + const int64_t right_child_sample_cnt = n_node_samples[tree_id][right_child_id]; + const double gain = sample_cnt * ( + impurity[tree_id][node_id] + - left_child_sample_cnt * impurity[tree_id][left_child_id] / sample_cnt + - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt) + / total_sample_cnt; + + tree.AddChilds(new_node_id); + tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); + tree.SetGain(new_node_id, gain); + Q.push({left_child_id, tree.LeftChild(new_node_id)}); + Q.push({right_child_id, tree.RightChild(new_node_id)}); + } + tree.SetDataCount(new_node_id, sample_cnt); + } + } + return model_ptr; +} + +std::unique_ptr LoadSKLearnRandomForestClassifierMulticlass( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const 
double** impurity) { +} + +std::unique_ptr LoadSKLearnRandomForestClassifier( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity) { + CHECK_GE(n_classes, 2); + if (n_classes == 2) { + return LoadSKLearnRandomForestClassifierBinary(n_estimators, n_features, n_classes, node_count, + children_left, children_right, feature, threshold, value, n_node_samples, impurity); + } else { + return LoadSKLearnRandomForestClassifierMulticlass(n_estimators, n_features, n_classes, + node_count, children_left, children_right, feature, threshold, value, n_node_samples, + impurity); + } +} + } // namespace frontend } // namespace treelite diff --git a/tests/python/test_skl_importer.py b/tests/python/test_skl_importer.py index e5b7bee8..f0ebd383 100644 --- a/tests/python/test_skl_importer.py +++ b/tests/python/test_skl_importer.py @@ -89,7 +89,10 @@ def test_skl_converter_binary_classifier(tmpdir, clazz, toolchain): clf.fit(X, y) expected_prob = clf.predict_proba(X)[:, 1] - model = treelite.sklearn.import_model(clf) + if clazz == RandomForestClassifier: + model = treelite.sklearn.import_model_v2(clf) + else: + model = treelite.sklearn.import_model(clf) assert model.num_feature == clf.n_features_ assert model.num_class == 1 assert model.num_tree == clf.n_estimators From ea2e885123bb65d2bdcc279be6361159d210081e Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 21 Apr 2021 14:57:01 -0700 Subject: [PATCH 06/19] Add support for RandomForestClassifier with n_classes_ > 2 --- src/frontend/sklearn.cc | 73 +++++++++++++++++++++++++++++-- tests/python/test_skl_importer.py | 5 ++- 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/src/frontend/sklearn.cc b/src/frontend/sklearn.cc index 8d97a4d4..2355cc1c 100644 --- a/src/frontend/sklearn.cc +++ 
b/src/frontend/sklearn.cc @@ -8,6 +8,8 @@ #include #include #include +#include +#include #include namespace treelite { @@ -37,7 +39,7 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( treelite::Tree& tree = model->trees.back(); tree.Init(); - // assign node ID's so that a breadth-wise traversal would yield + // Assign node ID's so that a breadth-wise traversal would yield // the monotonic sequence 0, 1, 2, ... std::queue> Q; // (old ID, new ID) pair Q.push({0, 0}); @@ -100,7 +102,7 @@ std::unique_ptr LoadSKLearnRandomForestClassifierBinary( treelite::Tree& tree = model->trees.back(); tree.Init(); - // assign node ID's so that a breadth-wise traversal would yield + // Assign node ID's so that a breadth-wise traversal would yield // the monotonic sequence 0, 1, 2, ... std::queue> Q; // (old ID, new ID) pair Q.push({0, 0}); @@ -113,7 +115,7 @@ std::unique_ptr LoadSKLearnRandomForestClassifierBinary( const int64_t right_child_id = children_right[tree_id][node_id]; const int64_t sample_cnt = n_node_samples[tree_id][node_id]; if (left_child_id == -1) { // leaf node - // # Get counts for each label (+/-) at this leaf node + // Get counts for each label (+/-) at this leaf node const double* leaf_count = &value[tree_id][node_id * 2]; // Compute the fraction of positive data points at this leaf node const double fraction_positive = leaf_count[1] / (leaf_count[0] + leaf_count[1]); @@ -146,6 +148,71 @@ std::unique_ptr LoadSKLearnRandomForestClassifierMulticlass( const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity) { + CHECK_GT(n_estimators, 0); + CHECK_GT(n_features, 0); + + std::unique_ptr model_ptr = treelite::Model::Create(); + auto* model = dynamic_cast*>(model_ptr.get()); + model->num_feature = n_features; + model->average_tree_output = true; + model->task_type = treelite::TaskType::kMultiClfProbDistLeaf; + 
model->task_param.grove_per_class = false; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = n_classes; + model->task_param.leaf_vector_size = n_classes; + std::strncpy(model->param.pred_transform, "identity_multiclass", + sizeof(model->param.pred_transform)); + model->param.global_bias = 0.0f; + + for (int tree_id = 0; tree_id < n_estimators; ++tree_id) { + model->trees.emplace_back(); + treelite::Tree& tree = model->trees.back(); + tree.Init(); + + // Assign node ID's so that a breadth-wise traversal would yield + // the monotonic sequence 0, 1, 2, ... + std::queue> Q; // (old ID, new ID) pair + Q.push({0, 0}); + const int64_t total_sample_cnt = n_node_samples[tree_id][0]; + while (!Q.empty()) { + int64_t node_id; + int new_node_id; + std::tie(node_id, new_node_id) = Q.front(); Q.pop(); + const int64_t left_child_id = children_left[tree_id][node_id]; + const int64_t right_child_id = children_right[tree_id][node_id]; + const int64_t sample_cnt = n_node_samples[tree_id][node_id]; + if (left_child_id == -1) { // leaf node + // Get counts for each label class at this leaf node + std::vector prob_distribution(&value[tree_id][node_id * n_classes], + &value[tree_id][(node_id + 1) * n_classes]); + // Compute the probability distribution over label classes + const double norm_factor = + std::accumulate(prob_distribution.begin(), prob_distribution.end(), 0.0); + std::for_each(prob_distribution.begin(), prob_distribution.end(), [norm_factor](double& e) { + e /= norm_factor; + }); + tree.SetLeafVector(new_node_id, prob_distribution); + } else { + const int64_t split_index = feature[tree_id][node_id]; + const double split_cond = threshold[tree_id][node_id]; + const int64_t left_child_sample_cnt = n_node_samples[tree_id][left_child_id]; + const int64_t right_child_sample_cnt = n_node_samples[tree_id][right_child_id]; + const double gain = sample_cnt * ( + impurity[tree_id][node_id] + - left_child_sample_cnt * 
impurity[tree_id][left_child_id] / sample_cnt + - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt) + / total_sample_cnt; + + tree.AddChilds(new_node_id); + tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); + tree.SetGain(new_node_id, gain); + Q.push({left_child_id, tree.LeftChild(new_node_id)}); + Q.push({right_child_id, tree.RightChild(new_node_id)}); + } + tree.SetDataCount(new_node_id, sample_cnt); + } + } + return model_ptr; } std::unique_ptr LoadSKLearnRandomForestClassifier( diff --git a/tests/python/test_skl_importer.py b/tests/python/test_skl_importer.py index f0ebd383..10e3a332 100644 --- a/tests/python/test_skl_importer.py +++ b/tests/python/test_skl_importer.py @@ -48,7 +48,10 @@ def test_skl_converter_multiclass_classifier(tmpdir, clazz, toolchain): clf.fit(X, y) expected_prob = clf.predict_proba(X) - model = treelite.sklearn.import_model(clf) + if clazz == RandomForestClassifier: + model = treelite.sklearn.import_model_v2(clf) + else: + model = treelite.sklearn.import_model(clf) assert model.num_feature == clf.n_features_ assert model.num_class == clf.n_classes_ assert (model.num_tree == From d9792c097eb0e06f8fffa1cc423e1c2e5b7be61a Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 21 Apr 2021 15:02:18 -0700 Subject: [PATCH 07/19] Add support for ExtraTreesRegressor and ExtraTreesClassifier --- python/treelite/sklearn/__init__.py | 15 ++++++++++----- tests/python/test_skl_importer.py | 6 +++--- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index 56638e37..5a58040c 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -27,6 +27,8 @@ def import_model(sklearn_model): sklearn_model : object of type \ :py:class:`~sklearn.ensemble.RandomForestRegressor` / \ :py:class:`~sklearn.ensemble.RandomForestClassifier` / \ + 
:py:class:`~sklearn.ensemble.ExtraTreesRegressor` / \ + :py:class:`~sklearn.ensemble.ExtraTreesClassifier` / \ :py:class:`~sklearn.ensemble.GradientBoostingRegressor` / \ :py:class:`~sklearn.ensemble.GradientBoostingClassifier` Python handle to scikit-learn model @@ -137,7 +139,10 @@ def import_model_v2(sklearn_model): Parameters ---------- sklearn_model : object of type \ - :py:class:`~sklearn.ensemble.RandomForestRegressor` + :py:class:`~sklearn.ensemble.RandomForestRegressor` / \ + :py:class:`~sklearn.ensemble.RandomForestClassifier` / \ + :py:class:`~sklearn.ensemble.ExtraTreesRegressor` / \ + :py:class:`~sklearn.ensemble.ExtraTreesClassifier` Python handle to scikit-learn model Returns @@ -150,9 +155,9 @@ def import_model_v2(sklearn_model): if module_name != 'sklearn': raise TreeliteError('Not a scikit-learn model') - if class_name == 'RandomForestRegressor': + if class_name in ['RandomForestRegressor', 'ExtraTreesRegressor']: leaf_value_expected_shape = lambda node_count: (node_count, 1, 1) - elif class_name == 'RandomForestClassifier': + elif class_name in ['RandomForestClassifier', 'ExtraTreesClassifier']: leaf_value_expected_shape = lambda node_count: (node_count, 1, sklearn_model.n_classes_) else: raise TreeliteError(f'Not supported: {class_name}') @@ -177,14 +182,14 @@ def import_model_v2(sklearn_model): impurity.add(tree.impurity, expected_shape=(tree.node_count,)) handle = ctypes.c_void_p() - if class_name == 'RandomForestRegressor': + if class_name in ['RandomForestRegressor', 'ExtraTreesRegressor']: _check_call(_LIB.TreeliteLoadSKLearnRandomForestRegressor( ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), c_array(ctypes.c_int64, node_count), children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), ctypes.byref(handle))) - elif class_name == 'RandomForestClassifier': + elif class_name in 
['RandomForestClassifier', 'ExtraTreesClassifier']: _check_call(_LIB.TreeliteLoadSKLearnRandomForestClassifier( ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), ctypes.c_int(sklearn_model.n_classes_), c_array(ctypes.c_int64, node_count), diff --git a/tests/python/test_skl_importer.py b/tests/python/test_skl_importer.py index 10e3a332..4fe05cf2 100644 --- a/tests/python/test_skl_importer.py +++ b/tests/python/test_skl_importer.py @@ -48,7 +48,7 @@ def test_skl_converter_multiclass_classifier(tmpdir, clazz, toolchain): clf.fit(X, y) expected_prob = clf.predict_proba(X) - if clazz == RandomForestClassifier: + if clazz in [RandomForestClassifier, ExtraTreesClassifier]: model = treelite.sklearn.import_model_v2(clf) else: model = treelite.sklearn.import_model(clf) @@ -92,7 +92,7 @@ def test_skl_converter_binary_classifier(tmpdir, clazz, toolchain): clf.fit(X, y) expected_prob = clf.predict_proba(X)[:, 1] - if clazz == RandomForestClassifier: + if clazz in [RandomForestClassifier, ExtraTreesClassifier]: model = treelite.sklearn.import_model_v2(clf) else: model = treelite.sklearn.import_model(clf) @@ -134,7 +134,7 @@ def test_skl_converter_regressor(tmpdir, clazz, toolchain): # pylint: disable=t clf.fit(X, y) expected_pred = clf.predict(X) - if clazz == RandomForestRegressor: + if clazz == [RandomForestRegressor, ExtraTreesRegressor]: model = treelite.sklearn.import_model_v2(clf) else: model = treelite.sklearn.import_model(clf) From 7b91a43bca37fbb8f77caace632fb62e94d501fd Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 21 Apr 2021 15:28:04 -0700 Subject: [PATCH 08/19] Add support for GradientBoostingRegressor --- include/treelite/c_api.h | 10 ++++ include/treelite/frontend.h | 10 +++- python/treelite/sklearn/__init__.py | 22 +++++-- src/c_api/c_api.cc | 26 ++++++++ src/frontend/sklearn.cc | 93 +++++++++++++++++++++++++++++ tests/python/test_skl_importer.py | 5 +- 6 files changed, 157 insertions(+), 9 deletions(-) diff --git 
a/include/treelite/c_api.h b/include/treelite/c_api.h index c90846d0..3a3ddef1 100644 --- a/include/treelite/c_api.h +++ b/include/treelite/c_api.h @@ -182,6 +182,16 @@ TREELITE_DLL int TreeliteLoadSKLearnRandomForestClassifier( const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity, ModelHandle* out); +TREELITE_DLL int TreeliteLoadSKLearnGradientBoostingRegressor( + int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, + const int64_t** children_right, const int64_t** feature, const double** threshold, + const double** value, const int64_t** n_node_samples, const double** impurity, + ModelHandle* out); +TREELITE_DLL int TreeliteLoadSKLearnGradientBoostingClassifier( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity, ModelHandle* out); /*! 
* \brief Query the number of trees in the model diff --git a/include/treelite/frontend.h b/include/treelite/frontend.h index b66df0db..feb63854 100644 --- a/include/treelite/frontend.h +++ b/include/treelite/frontend.h @@ -63,12 +63,20 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity); - std::unique_ptr LoadSKLearnRandomForestClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity); +std::unique_ptr LoadSKLearnGradientBoostingRegressor( + int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, + const int64_t** children_right, const int64_t** feature, const double** threshold, + const double** value, const int64_t** n_node_samples, const double** impurity); +std::unique_ptr LoadSKLearnGradientBoostingClassifier( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity); //-------------------------------------------------------------------------- // model builder interface: build trees incrementally diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index 5a58040c..3d452701 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -142,7 +142,8 @@ def import_model_v2(sklearn_model): :py:class:`~sklearn.ensemble.RandomForestRegressor` / \ 
:py:class:`~sklearn.ensemble.RandomForestClassifier` / \ :py:class:`~sklearn.ensemble.ExtraTreesRegressor` / \ - :py:class:`~sklearn.ensemble.ExtraTreesClassifier` + :py:class:`~sklearn.ensemble.ExtraTreesClassifier` / \ + :py:class:`~sklearn.ensemble.GradientBoostingRegressor` Python handle to scikit-learn model Returns @@ -155,7 +156,7 @@ def import_model_v2(sklearn_model): if module_name != 'sklearn': raise TreeliteError('Not a scikit-learn model') - if class_name in ['RandomForestRegressor', 'ExtraTreesRegressor']: + if class_name in ['RandomForestRegressor', 'ExtraTreesRegressor', 'GradientBoostingRegressor']: leaf_value_expected_shape = lambda node_count: (node_count, 1, 1) elif class_name in ['RandomForestClassifier', 'ExtraTreesClassifier']: leaf_value_expected_shape = lambda node_count: (node_count, 1, sklearn_model.n_classes_) @@ -171,13 +172,19 @@ def import_model_v2(sklearn_model): n_node_samples = ArrayOfArrays(dtype=np.int64) impurity = ArrayOfArrays(dtype=np.float64) for estimator in sklearn_model.estimators_: - tree = estimator.tree_ + if class_name.startswith('GradientBoosting'): + tree = estimator[0].tree_ + learning_rate = sklearn_model.learning_rate + else: + tree = estimator.tree_ + learning_rate = 1.0 node_count.append(tree.node_count) children_left.add(tree.children_left, expected_shape=(tree.node_count,)) children_right.add(tree.children_right, expected_shape=(tree.node_count,)) feature.add(tree.feature, expected_shape=(tree.node_count,)) threshold.add(tree.threshold, expected_shape=(tree.node_count,)) - value.add(tree.value, expected_shape=leaf_value_expected_shape(tree.node_count)) + value.add(tree.value * learning_rate, + expected_shape=leaf_value_expected_shape(tree.node_count)) n_node_samples.add(tree.n_node_samples, expected_shape=(tree.node_count,)) impurity.add(tree.impurity, expected_shape=(tree.node_count,)) @@ -196,6 +203,13 @@ def import_model_v2(sklearn_model): children_left.as_c_array(), children_right.as_c_array(), 
feature.as_c_array(), threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), ctypes.byref(handle))) + elif class_name == 'GradientBoostingRegressor': + _check_call(_LIB.TreeliteLoadSKLearnGradientBoostingRegressor( + ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), + c_array(ctypes.c_int64, node_count), children_left.as_c_array(), + children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), + value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), + ctypes.byref(handle))) else: raise TreeliteError(f'Not supported: {class_name}') return Model(handle) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 12d0c277..b6c18a00 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -202,6 +202,32 @@ int TreeliteLoadSKLearnRandomForestClassifier( API_END(); } +int TreeliteLoadSKLearnGradientBoostingRegressor( + int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, + const int64_t** children_right, const int64_t** feature, const double** threshold, + const double** value, const int64_t** n_node_samples, const double** impurity, + ModelHandle* out) { + API_BEGIN(); + std::unique_ptr model = frontend::LoadSKLearnGradientBoostingRegressor( + n_estimators, n_features, node_count, children_left, children_right, feature, threshold, + value, n_node_samples, impurity); + *out = static_cast(model.release()); + API_END(); +} + +int TreeliteLoadSKLearnGradientBoostingClassifier( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity, ModelHandle* out) { + API_BEGIN(); + std::unique_ptr model = frontend::LoadSKLearnGradientBoostingClassifier( + n_estimators, n_features, n_classes, node_count, children_left, 
children_right, feature, + threshold, value, n_node_samples, impurity); + *out = static_cast(model.release()); + API_END(); +} + int TreeliteFreeModel(ModelHandle handle) { API_BEGIN(); delete static_cast(handle); diff --git a/src/frontend/sklearn.cc b/src/frontend/sklearn.cc index 2355cc1c..27ea6e24 100644 --- a/src/frontend/sklearn.cc +++ b/src/frontend/sklearn.cc @@ -231,5 +231,98 @@ std::unique_ptr LoadSKLearnRandomForestClassifier( } } +std::unique_ptr LoadSKLearnGradientBoostingRegressor( + int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, + const int64_t** children_right, const int64_t** feature, const double** threshold, + const double** value, const int64_t** n_node_samples, const double** impurity) { + CHECK_GT(n_estimators, 0); + CHECK_GT(n_features, 0); + + std::unique_ptr model_ptr = treelite::Model::Create(); + auto* model = dynamic_cast*>(model_ptr.get()); + model->num_feature = n_features; + model->average_tree_output = false; + model->task_type = treelite::TaskType::kBinaryClfRegr; + model->task_param.grove_per_class = false; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = 1; + model->task_param.leaf_vector_size = 1; + std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); + model->param.global_bias = 0.0f; + + for (int tree_id = 0; tree_id < n_estimators; ++tree_id) { + model->trees.emplace_back(); + treelite::Tree& tree = model->trees.back(); + tree.Init(); + + // Assign node ID's so that a breadth-wise traversal would yield + // the monotonic sequence 0, 1, 2, ... 
+ std::queue> Q; // (old ID, new ID) pair + Q.push({0, 0}); + const int64_t total_sample_cnt = n_node_samples[tree_id][0]; + while (!Q.empty()) { + int64_t node_id; + int new_node_id; + std::tie(node_id, new_node_id) = Q.front(); Q.pop(); + const int64_t left_child_id = children_left[tree_id][node_id]; + const int64_t right_child_id = children_right[tree_id][node_id]; + const int64_t sample_cnt = n_node_samples[tree_id][node_id]; + if (left_child_id == -1) { // leaf node + const double leaf_value = value[tree_id][node_id]; + tree.SetLeaf(new_node_id, leaf_value); + } else { + const int64_t split_index = feature[tree_id][node_id]; + const double split_cond = threshold[tree_id][node_id]; + const int64_t left_child_sample_cnt = n_node_samples[tree_id][left_child_id]; + const int64_t right_child_sample_cnt = n_node_samples[tree_id][right_child_id]; + const double gain = sample_cnt * ( + impurity[tree_id][node_id] + - left_child_sample_cnt * impurity[tree_id][left_child_id] / sample_cnt + - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt) + / total_sample_cnt; + + tree.AddChilds(new_node_id); + tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); + tree.SetGain(new_node_id, gain); + Q.push({left_child_id, tree.LeftChild(new_node_id)}); + Q.push({right_child_id, tree.RightChild(new_node_id)}); + } + tree.SetDataCount(new_node_id, sample_cnt); + } + } + return model_ptr; +} + +std::unique_ptr LoadSKLearnGradientBoostingClassifierBinary( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity) { +} + +std::unique_ptr LoadSKLearnGradientBoostingClassifierMulticlass( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** 
children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity) { +} + +std::unique_ptr LoadSKLearnGradientBoostingClassifier( + int n_estimators, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity) { + CHECK_GE(n_classes, 2); + if (n_classes == 2) { + return LoadSKLearnGradientBoostingClassifierBinary(n_estimators, n_features, n_classes, + node_count, children_left, children_right, feature, threshold, value, n_node_samples, + impurity); + } else { + return LoadSKLearnGradientBoostingClassifierMulticlass(n_estimators, n_features, n_classes, + node_count, children_left, children_right, feature, threshold, value, n_node_samples, + impurity); + } +} + } // namespace frontend } // namespace treelite diff --git a/tests/python/test_skl_importer.py b/tests/python/test_skl_importer.py index 4fe05cf2..df642925 100644 --- a/tests/python/test_skl_importer.py +++ b/tests/python/test_skl_importer.py @@ -134,10 +134,7 @@ def test_skl_converter_regressor(tmpdir, clazz, toolchain): # pylint: disable=t clf.fit(X, y) expected_pred = clf.predict(X) - if clazz == [RandomForestRegressor, ExtraTreesRegressor]: - model = treelite.sklearn.import_model_v2(clf) - else: - model = treelite.sklearn.import_model(clf) + model = treelite.sklearn.import_model_v2(clf) assert model.num_feature == clf.n_features_ assert model.num_class == 1 assert model.num_tree == clf.n_estimators From c3a726b34b2295dab54865a96e5669eea2be0c2e Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 21 Apr 2021 15:56:25 -0700 Subject: [PATCH 09/19] Add support for GradientBoostingClassifier --- python/treelite/sklearn/__init__.py | 42 +++++++---- src/frontend/sklearn.cc | 112 ++++++++++++++++++++++++++++ 
tests/python/test_skl_importer.py | 10 +-- 3 files changed, 143 insertions(+), 21 deletions(-) diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index 3d452701..ac6beab6 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -143,7 +143,8 @@ def import_model_v2(sklearn_model): :py:class:`~sklearn.ensemble.RandomForestClassifier` / \ :py:class:`~sklearn.ensemble.ExtraTreesRegressor` / \ :py:class:`~sklearn.ensemble.ExtraTreesClassifier` / \ - :py:class:`~sklearn.ensemble.GradientBoostingRegressor` + :py:class:`~sklearn.ensemble.GradientBoostingRegressor` / \ + :py:class:`~sklearn.ensemble.GradientBoostingClassifier` Python handle to scikit-learn model Returns @@ -156,13 +157,18 @@ def import_model_v2(sklearn_model): if module_name != 'sklearn': raise TreeliteError('Not a scikit-learn model') - if class_name in ['RandomForestRegressor', 'ExtraTreesRegressor', 'GradientBoostingRegressor']: + if class_name in ['RandomForestRegressor', 'ExtraTreesRegressor', 'GradientBoostingRegressor', + 'GradientBoostingClassifier']: leaf_value_expected_shape = lambda node_count: (node_count, 1, 1) elif class_name in ['RandomForestClassifier', 'ExtraTreesClassifier']: leaf_value_expected_shape = lambda node_count: (node_count, 1, sklearn_model.n_classes_) else: raise TreeliteError(f'Not supported: {class_name}') + if class_name.startswith('GradientBoosting') and sklearn_model.init != 'zero': + raise treelite.TreeliteError("Gradient boosted trees must be trained with " + "the option init='zero'") + node_count = [] children_left = ArrayOfArrays(dtype=np.int64) children_right = ArrayOfArrays(dtype=np.int64) @@ -173,20 +179,23 @@ def import_model_v2(sklearn_model): impurity = ArrayOfArrays(dtype=np.float64) for estimator in sklearn_model.estimators_: if class_name.startswith('GradientBoosting'): - tree = estimator[0].tree_ + estimator_range = estimator learning_rate = sklearn_model.learning_rate else: - tree = 
estimator.tree_ + estimator_range = [estimator] learning_rate = 1.0 - node_count.append(tree.node_count) - children_left.add(tree.children_left, expected_shape=(tree.node_count,)) - children_right.add(tree.children_right, expected_shape=(tree.node_count,)) - feature.add(tree.feature, expected_shape=(tree.node_count,)) - threshold.add(tree.threshold, expected_shape=(tree.node_count,)) - value.add(tree.value * learning_rate, - expected_shape=leaf_value_expected_shape(tree.node_count)) - n_node_samples.add(tree.n_node_samples, expected_shape=(tree.node_count,)) - impurity.add(tree.impurity, expected_shape=(tree.node_count,)) + for sub_estimator in estimator_range: + tree = sub_estimator.tree_ + node_count.append(tree.node_count) + children_left.add(tree.children_left, expected_shape=(tree.node_count,)) + children_right.add(tree.children_right, expected_shape=(tree.node_count,)) + feature.add(tree.feature, expected_shape=(tree.node_count,)) + threshold.add(tree.threshold, expected_shape=(tree.node_count,)) + # Note: for gradient boosted trees, we shrink each leaf output by the learning rate + value.add(tree.value * learning_rate, + expected_shape=leaf_value_expected_shape(tree.node_count)) + n_node_samples.add(tree.n_node_samples, expected_shape=(tree.node_count,)) + impurity.add(tree.impurity, expected_shape=(tree.node_count,)) handle = ctypes.c_void_p() if class_name in ['RandomForestRegressor', 'ExtraTreesRegressor']: @@ -210,6 +219,13 @@ def import_model_v2(sklearn_model): children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), ctypes.byref(handle))) + elif class_name == 'GradientBoostingClassifier': + _check_call(_LIB.TreeliteLoadSKLearnGradientBoostingClassifier( + ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), + ctypes.c_int(sklearn_model.n_classes_), c_array(ctypes.c_int64, node_count), + children_left.as_c_array(), 
children_right.as_c_array(), feature.as_c_array(), + threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), + impurity.as_c_array(), ctypes.byref(handle))) else: raise TreeliteError(f'Not supported: {class_name}') return Model(handle) diff --git a/src/frontend/sklearn.cc b/src/frontend/sklearn.cc index 27ea6e24..096def7e 100644 --- a/src/frontend/sklearn.cc +++ b/src/frontend/sklearn.cc @@ -298,6 +298,62 @@ std::unique_ptr LoadSKLearnGradientBoostingClassifierBinary( const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity) { + CHECK_GT(n_estimators, 0); + CHECK_GT(n_features, 0); + + std::unique_ptr model_ptr = treelite::Model::Create(); + auto* model = dynamic_cast*>(model_ptr.get()); + model->num_feature = n_features; + model->average_tree_output = false; + model->task_type = treelite::TaskType::kBinaryClfRegr; + model->task_param.grove_per_class = false; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = 1; + model->task_param.leaf_vector_size = 1; + std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform)); + model->param.global_bias = 0.0f; + + for (int tree_id = 0; tree_id < n_estimators; ++tree_id) { + model->trees.emplace_back(); + treelite::Tree& tree = model->trees.back(); + tree.Init(); + + // Assign node ID's so that a breadth-wise traversal would yield + // the monotonic sequence 0, 1, 2, ... 
+ std::queue> Q; // (old ID, new ID) pair + Q.push({0, 0}); + const int64_t total_sample_cnt = n_node_samples[tree_id][0]; + while (!Q.empty()) { + int64_t node_id; + int new_node_id; + std::tie(node_id, new_node_id) = Q.front(); Q.pop(); + const int64_t left_child_id = children_left[tree_id][node_id]; + const int64_t right_child_id = children_right[tree_id][node_id]; + const int64_t sample_cnt = n_node_samples[tree_id][node_id]; + if (left_child_id == -1) { // leaf node + const double leaf_value = value[tree_id][node_id]; + tree.SetLeaf(new_node_id, leaf_value); + } else { + const int64_t split_index = feature[tree_id][node_id]; + const double split_cond = threshold[tree_id][node_id]; + const int64_t left_child_sample_cnt = n_node_samples[tree_id][left_child_id]; + const int64_t right_child_sample_cnt = n_node_samples[tree_id][right_child_id]; + const double gain = sample_cnt * ( + impurity[tree_id][node_id] + - left_child_sample_cnt * impurity[tree_id][left_child_id] / sample_cnt + - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt) + / total_sample_cnt; + + tree.AddChilds(new_node_id); + tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); + tree.SetGain(new_node_id, gain); + Q.push({left_child_id, tree.LeftChild(new_node_id)}); + Q.push({right_child_id, tree.RightChild(new_node_id)}); + } + tree.SetDataCount(new_node_id, sample_cnt); + } + } + return model_ptr; } std::unique_ptr LoadSKLearnGradientBoostingClassifierMulticlass( @@ -305,6 +361,62 @@ std::unique_ptr LoadSKLearnGradientBoostingClassifierMulticlass const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity) { + CHECK_GT(n_estimators, 0); + CHECK_GT(n_features, 0); + + std::unique_ptr model_ptr = treelite::Model::Create(); + auto* model = dynamic_cast*>(model_ptr.get()); + model->num_feature = n_features; + 
model->average_tree_output = false; + model->task_type = treelite::TaskType::kMultiClfGrovePerClass; + model->task_param.grove_per_class = true; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = n_classes; + model->task_param.leaf_vector_size = 1; + std::strncpy(model->param.pred_transform, "softmax", sizeof(model->param.pred_transform)); + model->param.global_bias = 0.0f; + + for (int tree_id = 0; tree_id < n_estimators * n_classes; ++tree_id) { + model->trees.emplace_back(); + treelite::Tree& tree = model->trees.back(); + tree.Init(); + + // Assign node ID's so that a breadth-wise traversal would yield + // the monotonic sequence 0, 1, 2, ... + std::queue> Q; // (old ID, new ID) pair + Q.push({0, 0}); + const int64_t total_sample_cnt = n_node_samples[tree_id][0]; + while (!Q.empty()) { + int64_t node_id; + int new_node_id; + std::tie(node_id, new_node_id) = Q.front(); Q.pop(); + const int64_t left_child_id = children_left[tree_id][node_id]; + const int64_t right_child_id = children_right[tree_id][node_id]; + const int64_t sample_cnt = n_node_samples[tree_id][node_id]; + if (left_child_id == -1) { // leaf node + const double leaf_value = value[tree_id][node_id]; + tree.SetLeaf(new_node_id, leaf_value); + } else { + const int64_t split_index = feature[tree_id][node_id]; + const double split_cond = threshold[tree_id][node_id]; + const int64_t left_child_sample_cnt = n_node_samples[tree_id][left_child_id]; + const int64_t right_child_sample_cnt = n_node_samples[tree_id][right_child_id]; + const double gain = sample_cnt * ( + impurity[tree_id][node_id] + - left_child_sample_cnt * impurity[tree_id][left_child_id] / sample_cnt + - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt) + / total_sample_cnt; + + tree.AddChilds(new_node_id); + tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); + tree.SetGain(new_node_id, gain); + Q.push({left_child_id, 
tree.LeftChild(new_node_id)}); + Q.push({right_child_id, tree.RightChild(new_node_id)}); + } + tree.SetDataCount(new_node_id, sample_cnt); + } + } + return model_ptr; } std::unique_ptr LoadSKLearnGradientBoostingClassifier( diff --git a/tests/python/test_skl_importer.py b/tests/python/test_skl_importer.py index df642925..248ef447 100644 --- a/tests/python/test_skl_importer.py +++ b/tests/python/test_skl_importer.py @@ -48,10 +48,7 @@ def test_skl_converter_multiclass_classifier(tmpdir, clazz, toolchain): clf.fit(X, y) expected_prob = clf.predict_proba(X) - if clazz in [RandomForestClassifier, ExtraTreesClassifier]: - model = treelite.sklearn.import_model_v2(clf) - else: - model = treelite.sklearn.import_model(clf) + model = treelite.sklearn.import_model_v2(clf) assert model.num_feature == clf.n_features_ assert model.num_class == clf.n_classes_ assert (model.num_tree == @@ -92,10 +89,7 @@ def test_skl_converter_binary_classifier(tmpdir, clazz, toolchain): clf.fit(X, y) expected_prob = clf.predict_proba(X)[:, 1] - if clazz in [RandomForestClassifier, ExtraTreesClassifier]: - model = treelite.sklearn.import_model_v2(clf) - else: - model = treelite.sklearn.import_model(clf) + model = treelite.sklearn.import_model_v2(clf) assert model.num_feature == clf.n_features_ assert model.num_class == 1 assert model.num_tree == clf.n_estimators From f1270486258e88543eb1730599f9ac9b01f7f91d Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 21 Apr 2021 16:19:04 -0700 Subject: [PATCH 10/19] Fix style --- python/treelite/sklearn/__init__.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index ac6beab6..6a7a30ab 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -107,6 +107,9 @@ class SKLRFMultiClassifierConverter(SKLRFMultiClassifierMixin, SKLConverterBase) class ArrayOfArrays: + """ + Utility class to marshall a collection of arrays 
in order to pass to a C function + """ def __init__(self, *, dtype): int64_ptr_type = ctypes.POINTER(ctypes.c_int64) float64_ptr_type = ctypes.POINTER(ctypes.c_double) @@ -120,6 +123,7 @@ def __init__(self, *, dtype): self.collection = [] def add(self, array, *, expected_shape=None): + """Add an array to the collection""" assert array.dtype == self.dtype if expected_shape: assert array.shape == expected_shape, \ @@ -128,11 +132,12 @@ def add(self, array, *, expected_shape=None): self.collection.append(v.ctypes.data_as(self.ptr_type)) def as_c_array(self): + """Prepare the collection to pass as an argument of a C function""" return c_array(self.ptr_type, self.collection) def import_model_v2(sklearn_model): - # pylint: disable=R0914 + # pylint: disable=R0914,R0912 """ Load a tree ensemble model from a scikit-learn model object @@ -166,8 +171,7 @@ def import_model_v2(sklearn_model): raise TreeliteError(f'Not supported: {class_name}') if class_name.startswith('GradientBoosting') and sklearn_model.init != 'zero': - raise treelite.TreeliteError("Gradient boosted trees must be trained with " - "the option init='zero'") + raise TreeliteError("Gradient boosted trees must be trained with the option init='zero'") node_count = [] children_left = ArrayOfArrays(dtype=np.int64) @@ -226,8 +230,6 @@ def import_model_v2(sklearn_model): children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), ctypes.byref(handle))) - else: - raise TreeliteError(f'Not supported: {class_name}') return Model(handle) From 625484bc1797dc73ac7da0efdfc42499ffa07b49 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 21 Apr 2021 16:51:36 -0700 Subject: [PATCH 11/19] Deduplicate code and refactor --- src/frontend/sklearn.cc | 450 +++++++++++++--------------------------- 1 file changed, 141 insertions(+), 309 deletions(-) diff --git a/src/frontend/sklearn.cc b/src/frontend/sklearn.cc index
096def7e..c63bcb14 100644 --- a/src/frontend/sklearn.cc +++ b/src/frontend/sklearn.cc @@ -15,26 +15,20 @@ namespace treelite { namespace frontend { -std::unique_ptr LoadSKLearnRandomForestRegressor( - int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, - const int64_t** children_right, const int64_t** feature, const double** threshold, - const double** value, const int64_t** n_node_samples, const double** impurity) { - CHECK_GT(n_estimators, 0); +template +std::unique_ptr LoadSKLearnModel( + int n_trees, int n_features, int n_classes, const int64_t* node_count, + const int64_t** children_left, const int64_t** children_right, const int64_t** feature, + const double** threshold, const double** value, const int64_t** n_node_samples, + const double** impurity, MetaHandlerFunc meta_handler, LeafHandlerFunc leaf_handler) { + CHECK_GT(n_trees, 0); CHECK_GT(n_features, 0); std::unique_ptr model_ptr = treelite::Model::Create(); + meta_handler(model_ptr.get(), n_features, n_classes); auto* model = dynamic_cast*>(model_ptr.get()); - model->num_feature = n_features; - model->average_tree_output = true; - model->task_type = treelite::TaskType::kBinaryClfRegr; - model->task_param.grove_per_class = false; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; - model->task_param.num_class = 1; - model->task_param.leaf_vector_size = 1; - std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); - model->param.global_bias = 0.0f; - for (int tree_id = 0; tree_id < n_estimators; ++tree_id) { + for (int tree_id = 0; tree_id < n_trees; ++tree_id) { model->trees.emplace_back(); treelite::Tree& tree = model->trees.back(); tree.Init(); @@ -52,8 +46,7 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( const int64_t right_child_id = children_right[tree_id][node_id]; const int64_t sample_cnt = n_node_samples[tree_id][node_id]; if (left_child_id == -1) { // leaf node - const double leaf_value = 
value[tree_id][node_id]; - tree.SetLeaf(new_node_id, leaf_value); + leaf_handler(tree_id, node_id, new_node_id, value, n_classes, tree); } else { const int64_t split_index = feature[tree_id][node_id]; const double split_cond = threshold[tree_id][node_id]; @@ -77,70 +70,57 @@ std::unique_ptr LoadSKLearnRandomForestRegressor( return model_ptr; } +std::unique_ptr LoadSKLearnRandomForestRegressor( + int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, + const int64_t** children_right, const int64_t** feature, const double** threshold, + const double** value, const int64_t** n_node_samples, const double** impurity) { + auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { + model->num_feature = n_features; + model->average_tree_output = true; + model->task_type = treelite::TaskType::kBinaryClfRegr; + model->task_param.grove_per_class = false; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = 1; + model->task_param.leaf_vector_size = 1; + std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); + model->param.global_bias = 0.0f; + }; + auto leaf_handler = [](int tree_id, int64_t node_id, int new_node_id, const double** value, + int n_classes, treelite::Tree& dest_tree) { + const double leaf_value = value[tree_id][node_id]; + dest_tree.SetLeaf(new_node_id, leaf_value); + }; + return LoadSKLearnModel(n_estimators, n_features, 1, node_count, children_left, children_right, + feature, threshold, value, n_node_samples, impurity, meta_handler, leaf_handler); +} + std::unique_ptr LoadSKLearnRandomForestClassifierBinary( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity) { - CHECK_GT(n_estimators, 0); - 
CHECK_GT(n_features, 0); - - std::unique_ptr model_ptr = treelite::Model::Create(); - auto* model = dynamic_cast*>(model_ptr.get()); - model->num_feature = n_features; - model->average_tree_output = true; - model->task_type = treelite::TaskType::kBinaryClfRegr; - model->task_param.grove_per_class = false; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; - model->task_param.num_class = 1; - model->task_param.leaf_vector_size = 1; - std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); - model->param.global_bias = 0.0f; - - for (int tree_id = 0; tree_id < n_estimators; ++tree_id) { - model->trees.emplace_back(); - treelite::Tree& tree = model->trees.back(); - tree.Init(); - - // Assign node ID's so that a breadth-wise traversal would yield - // the monotonic sequence 0, 1, 2, ... - std::queue> Q; // (old ID, new ID) pair - Q.push({0, 0}); - const int64_t total_sample_cnt = n_node_samples[tree_id][0]; - while (!Q.empty()) { - int64_t node_id; - int new_node_id; - std::tie(node_id, new_node_id) = Q.front(); Q.pop(); - const int64_t left_child_id = children_left[tree_id][node_id]; - const int64_t right_child_id = children_right[tree_id][node_id]; - const int64_t sample_cnt = n_node_samples[tree_id][node_id]; - if (left_child_id == -1) { // leaf node - // Get counts for each label (+/-) at this leaf node - const double* leaf_count = &value[tree_id][node_id * 2]; - // Compute the fraction of positive data points at this leaf node - const double fraction_positive = leaf_count[1] / (leaf_count[0] + leaf_count[1]); - tree.SetLeaf(new_node_id, fraction_positive); - } else { - const int64_t split_index = feature[tree_id][node_id]; - const double split_cond = threshold[tree_id][node_id]; - const int64_t left_child_sample_cnt = n_node_samples[tree_id][left_child_id]; - const int64_t right_child_sample_cnt = n_node_samples[tree_id][right_child_id]; - const double gain = sample_cnt * ( - impurity[tree_id][node_id] 
- - left_child_sample_cnt * impurity[tree_id][left_child_id] / sample_cnt - - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt) - / total_sample_cnt; - - tree.AddChilds(new_node_id); - tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); - tree.SetGain(new_node_id, gain); - Q.push({left_child_id, tree.LeftChild(new_node_id)}); - Q.push({right_child_id, tree.RightChild(new_node_id)}); - } - tree.SetDataCount(new_node_id, sample_cnt); - } - } - return model_ptr; + auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { + model->num_feature = n_features; + model->average_tree_output = true; + model->task_type = treelite::TaskType::kBinaryClfRegr; + model->task_param.grove_per_class = false; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = 1; + model->task_param.leaf_vector_size = 1; + std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); + model->param.global_bias = 0.0f; + }; + auto leaf_handler = [](int tree_id, int64_t node_id, int new_node_id, const double** value, + int n_classes, treelite::Tree& dest_tree) { + // Get counts for each label (+/-) at this leaf node + const double* leaf_count = &value[tree_id][node_id * 2]; + // Compute the fraction of positive data points at this leaf node + const double fraction_positive = leaf_count[1] / (leaf_count[0] + leaf_count[1]); + dest_tree.SetLeaf(new_node_id, fraction_positive); + }; + return LoadSKLearnModel(n_estimators, n_features, n_classes, node_count, children_left, + children_right, feature, threshold, value, n_node_samples, impurity, meta_handler, + leaf_handler); } std::unique_ptr LoadSKLearnRandomForestClassifierMulticlass( @@ -148,71 +128,34 @@ std::unique_ptr LoadSKLearnRandomForestClassifierMulticlass( const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const 
double** value, const int64_t** n_node_samples, const double** impurity) { - CHECK_GT(n_estimators, 0); - CHECK_GT(n_features, 0); - - std::unique_ptr model_ptr = treelite::Model::Create(); - auto* model = dynamic_cast*>(model_ptr.get()); - model->num_feature = n_features; - model->average_tree_output = true; - model->task_type = treelite::TaskType::kMultiClfProbDistLeaf; - model->task_param.grove_per_class = false; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; - model->task_param.num_class = n_classes; - model->task_param.leaf_vector_size = n_classes; - std::strncpy(model->param.pred_transform, "identity_multiclass", - sizeof(model->param.pred_transform)); - model->param.global_bias = 0.0f; - - for (int tree_id = 0; tree_id < n_estimators; ++tree_id) { - model->trees.emplace_back(); - treelite::Tree& tree = model->trees.back(); - tree.Init(); - - // Assign node ID's so that a breadth-wise traversal would yield - // the monotonic sequence 0, 1, 2, ... - std::queue> Q; // (old ID, new ID) pair - Q.push({0, 0}); - const int64_t total_sample_cnt = n_node_samples[tree_id][0]; - while (!Q.empty()) { - int64_t node_id; - int new_node_id; - std::tie(node_id, new_node_id) = Q.front(); Q.pop(); - const int64_t left_child_id = children_left[tree_id][node_id]; - const int64_t right_child_id = children_right[tree_id][node_id]; - const int64_t sample_cnt = n_node_samples[tree_id][node_id]; - if (left_child_id == -1) { // leaf node - // Get counts for each label class at this leaf node - std::vector prob_distribution(&value[tree_id][node_id * n_classes], - &value[tree_id][(node_id + 1) * n_classes]); - // Compute the probability distribution over label classes - const double norm_factor = - std::accumulate(prob_distribution.begin(), prob_distribution.end(), 0.0); - std::for_each(prob_distribution.begin(), prob_distribution.end(), [norm_factor](double& e) { - e /= norm_factor; - }); - tree.SetLeafVector(new_node_id, prob_distribution); - } else { 
- const int64_t split_index = feature[tree_id][node_id]; - const double split_cond = threshold[tree_id][node_id]; - const int64_t left_child_sample_cnt = n_node_samples[tree_id][left_child_id]; - const int64_t right_child_sample_cnt = n_node_samples[tree_id][right_child_id]; - const double gain = sample_cnt * ( - impurity[tree_id][node_id] - - left_child_sample_cnt * impurity[tree_id][left_child_id] / sample_cnt - - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt) - / total_sample_cnt; - - tree.AddChilds(new_node_id); - tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); - tree.SetGain(new_node_id, gain); - Q.push({left_child_id, tree.LeftChild(new_node_id)}); - Q.push({right_child_id, tree.RightChild(new_node_id)}); - } - tree.SetDataCount(new_node_id, sample_cnt); - } - } - return model_ptr; + auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { + model->num_feature = n_features; + model->average_tree_output = true; + model->task_type = treelite::TaskType::kMultiClfProbDistLeaf; + model->task_param.grove_per_class = false; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = n_classes; + model->task_param.leaf_vector_size = n_classes; + std::strncpy(model->param.pred_transform, "identity_multiclass", + sizeof(model->param.pred_transform)); + model->param.global_bias = 0.0f; + }; + auto leaf_handler = [](int tree_id, int64_t node_id, int new_node_id, const double** value, + int n_classes, treelite::Tree& dest_tree) { + // Get counts for each label class at this leaf node + std::vector prob_distribution(&value[tree_id][node_id * n_classes], + &value[tree_id][(node_id + 1) * n_classes]); + // Compute the probability distribution over label classes + const double norm_factor = + std::accumulate(prob_distribution.begin(), prob_distribution.end(), 0.0); + std::for_each(prob_distribution.begin(), prob_distribution.end(), 
[norm_factor](double& e) { + e /= norm_factor; + }); + dest_tree.SetLeafVector(new_node_id, prob_distribution); + }; + return LoadSKLearnModel(n_estimators, n_features, n_classes, node_count, children_left, + children_right, feature, threshold, value, n_node_samples, impurity, meta_handler, + leaf_handler); } std::unique_ptr LoadSKLearnRandomForestClassifier( @@ -235,62 +178,25 @@ std::unique_ptr LoadSKLearnGradientBoostingRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity) { - CHECK_GT(n_estimators, 0); - CHECK_GT(n_features, 0); - - std::unique_ptr model_ptr = treelite::Model::Create(); - auto* model = dynamic_cast*>(model_ptr.get()); - model->num_feature = n_features; - model->average_tree_output = false; - model->task_type = treelite::TaskType::kBinaryClfRegr; - model->task_param.grove_per_class = false; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; - model->task_param.num_class = 1; - model->task_param.leaf_vector_size = 1; - std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); - model->param.global_bias = 0.0f; - - for (int tree_id = 0; tree_id < n_estimators; ++tree_id) { - model->trees.emplace_back(); - treelite::Tree& tree = model->trees.back(); - tree.Init(); - - // Assign node ID's so that a breadth-wise traversal would yield - // the monotonic sequence 0, 1, 2, ... 
- std::queue> Q; // (old ID, new ID) pair - Q.push({0, 0}); - const int64_t total_sample_cnt = n_node_samples[tree_id][0]; - while (!Q.empty()) { - int64_t node_id; - int new_node_id; - std::tie(node_id, new_node_id) = Q.front(); Q.pop(); - const int64_t left_child_id = children_left[tree_id][node_id]; - const int64_t right_child_id = children_right[tree_id][node_id]; - const int64_t sample_cnt = n_node_samples[tree_id][node_id]; - if (left_child_id == -1) { // leaf node - const double leaf_value = value[tree_id][node_id]; - tree.SetLeaf(new_node_id, leaf_value); - } else { - const int64_t split_index = feature[tree_id][node_id]; - const double split_cond = threshold[tree_id][node_id]; - const int64_t left_child_sample_cnt = n_node_samples[tree_id][left_child_id]; - const int64_t right_child_sample_cnt = n_node_samples[tree_id][right_child_id]; - const double gain = sample_cnt * ( - impurity[tree_id][node_id] - - left_child_sample_cnt * impurity[tree_id][left_child_id] / sample_cnt - - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt) - / total_sample_cnt; - - tree.AddChilds(new_node_id); - tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); - tree.SetGain(new_node_id, gain); - Q.push({left_child_id, tree.LeftChild(new_node_id)}); - Q.push({right_child_id, tree.RightChild(new_node_id)}); - } - tree.SetDataCount(new_node_id, sample_cnt); - } - } - return model_ptr; + auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { + model->num_feature = n_features; + model->average_tree_output = false; + model->task_type = treelite::TaskType::kBinaryClfRegr; + model->task_param.grove_per_class = false; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = 1; + model->task_param.leaf_vector_size = 1; + std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform)); + model->param.global_bias = 0.0f; + }; + 
auto leaf_handler = [](int tree_id, int64_t node_id, int new_node_id, const double** value, + int n_classes, treelite::Tree& dest_tree) { + const double leaf_value = value[tree_id][node_id]; + dest_tree.SetLeaf(new_node_id, leaf_value); + }; + return LoadSKLearnModel(n_estimators, n_features, 1, node_count, children_left, + children_right, feature, threshold, value, n_node_samples, impurity, meta_handler, + leaf_handler); } std::unique_ptr LoadSKLearnGradientBoostingClassifierBinary( @@ -298,62 +204,25 @@ std::unique_ptr LoadSKLearnGradientBoostingClassifierBinary( const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity) { - CHECK_GT(n_estimators, 0); - CHECK_GT(n_features, 0); - - std::unique_ptr model_ptr = treelite::Model::Create(); - auto* model = dynamic_cast*>(model_ptr.get()); - model->num_feature = n_features; - model->average_tree_output = false; - model->task_type = treelite::TaskType::kBinaryClfRegr; - model->task_param.grove_per_class = false; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; - model->task_param.num_class = 1; - model->task_param.leaf_vector_size = 1; - std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform)); - model->param.global_bias = 0.0f; - - for (int tree_id = 0; tree_id < n_estimators; ++tree_id) { - model->trees.emplace_back(); - treelite::Tree& tree = model->trees.back(); - tree.Init(); - - // Assign node ID's so that a breadth-wise traversal would yield - // the monotonic sequence 0, 1, 2, ... 
- std::queue> Q; // (old ID, new ID) pair - Q.push({0, 0}); - const int64_t total_sample_cnt = n_node_samples[tree_id][0]; - while (!Q.empty()) { - int64_t node_id; - int new_node_id; - std::tie(node_id, new_node_id) = Q.front(); Q.pop(); - const int64_t left_child_id = children_left[tree_id][node_id]; - const int64_t right_child_id = children_right[tree_id][node_id]; - const int64_t sample_cnt = n_node_samples[tree_id][node_id]; - if (left_child_id == -1) { // leaf node - const double leaf_value = value[tree_id][node_id]; - tree.SetLeaf(new_node_id, leaf_value); - } else { - const int64_t split_index = feature[tree_id][node_id]; - const double split_cond = threshold[tree_id][node_id]; - const int64_t left_child_sample_cnt = n_node_samples[tree_id][left_child_id]; - const int64_t right_child_sample_cnt = n_node_samples[tree_id][right_child_id]; - const double gain = sample_cnt * ( - impurity[tree_id][node_id] - - left_child_sample_cnt * impurity[tree_id][left_child_id] / sample_cnt - - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt) - / total_sample_cnt; - - tree.AddChilds(new_node_id); - tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); - tree.SetGain(new_node_id, gain); - Q.push({left_child_id, tree.LeftChild(new_node_id)}); - Q.push({right_child_id, tree.RightChild(new_node_id)}); - } - tree.SetDataCount(new_node_id, sample_cnt); - } - } - return model_ptr; + auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { + model->num_feature = n_features; + model->average_tree_output = false; + model->task_type = treelite::TaskType::kBinaryClfRegr; + model->task_param.grove_per_class = false; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = 1; + model->task_param.leaf_vector_size = 1; + std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform)); + model->param.global_bias = 0.0f; + }; + 
auto leaf_handler = [](int tree_id, int64_t node_id, int new_node_id, const double** value, + int n_classes, treelite::Tree& dest_tree) { + const double leaf_value = value[tree_id][node_id]; + dest_tree.SetLeaf(new_node_id, leaf_value); + }; + return LoadSKLearnModel(n_estimators, n_features, n_classes, node_count, children_left, + children_right, feature, threshold, value, n_node_samples, impurity, meta_handler, + leaf_handler); } std::unique_ptr LoadSKLearnGradientBoostingClassifierMulticlass( @@ -361,62 +230,25 @@ std::unique_ptr LoadSKLearnGradientBoostingClassifierMulticlass const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity) { - CHECK_GT(n_estimators, 0); - CHECK_GT(n_features, 0); - - std::unique_ptr model_ptr = treelite::Model::Create(); - auto* model = dynamic_cast*>(model_ptr.get()); - model->num_feature = n_features; - model->average_tree_output = false; - model->task_type = treelite::TaskType::kMultiClfGrovePerClass; - model->task_param.grove_per_class = true; - model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; - model->task_param.num_class = n_classes; - model->task_param.leaf_vector_size = 1; - std::strncpy(model->param.pred_transform, "softmax", sizeof(model->param.pred_transform)); - model->param.global_bias = 0.0f; - - for (int tree_id = 0; tree_id < n_estimators * n_classes; ++tree_id) { - model->trees.emplace_back(); - treelite::Tree& tree = model->trees.back(); - tree.Init(); - - // Assign node ID's so that a breadth-wise traversal would yield - // the monotonic sequence 0, 1, 2, ... 
- std::queue> Q; // (old ID, new ID) pair - Q.push({0, 0}); - const int64_t total_sample_cnt = n_node_samples[tree_id][0]; - while (!Q.empty()) { - int64_t node_id; - int new_node_id; - std::tie(node_id, new_node_id) = Q.front(); Q.pop(); - const int64_t left_child_id = children_left[tree_id][node_id]; - const int64_t right_child_id = children_right[tree_id][node_id]; - const int64_t sample_cnt = n_node_samples[tree_id][node_id]; - if (left_child_id == -1) { // leaf node - const double leaf_value = value[tree_id][node_id]; - tree.SetLeaf(new_node_id, leaf_value); - } else { - const int64_t split_index = feature[tree_id][node_id]; - const double split_cond = threshold[tree_id][node_id]; - const int64_t left_child_sample_cnt = n_node_samples[tree_id][left_child_id]; - const int64_t right_child_sample_cnt = n_node_samples[tree_id][right_child_id]; - const double gain = sample_cnt * ( - impurity[tree_id][node_id] - - left_child_sample_cnt * impurity[tree_id][left_child_id] / sample_cnt - - right_child_sample_cnt * impurity[tree_id][right_child_id] / sample_cnt) - / total_sample_cnt; - - tree.AddChilds(new_node_id); - tree.SetNumericalSplit(new_node_id, split_index, split_cond, true, treelite::Operator::kLE); - tree.SetGain(new_node_id, gain); - Q.push({left_child_id, tree.LeftChild(new_node_id)}); - Q.push({right_child_id, tree.RightChild(new_node_id)}); - } - tree.SetDataCount(new_node_id, sample_cnt); - } - } - return model_ptr; + auto meta_handler = [](treelite::Model* model, int n_features, int n_classes) { + model->num_feature = n_features; + model->average_tree_output = false; + model->task_type = treelite::TaskType::kMultiClfGrovePerClass; + model->task_param.grove_per_class = true; + model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat; + model->task_param.num_class = n_classes; + model->task_param.leaf_vector_size = 1; + std::strncpy(model->param.pred_transform, "softmax", sizeof(model->param.pred_transform)); + model->param.global_bias 
= 0.0f; + }; + auto leaf_handler = [](int tree_id, int64_t node_id, int new_node_id, const double** value, + int n_classes, treelite::Tree& dest_tree) { + const double leaf_value = value[tree_id][node_id]; + dest_tree.SetLeaf(new_node_id, leaf_value); + }; + return LoadSKLearnModel(n_estimators * n_classes, n_features, n_classes, node_count, + children_left, children_right, feature, threshold, value, n_node_samples, impurity, + meta_handler, leaf_handler); } std::unique_ptr LoadSKLearnGradientBoostingClassifier( From d0d547d3b3bce7e454df9609ffd0ffab18246e2b Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 27 Apr 2021 10:31:10 -0700 Subject: [PATCH 12/19] Rename import_model_v2 -> import_model --- python/treelite/sklearn/__init__.py | 8 ++++---- tests/python/test_skl_importer.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index 6a7a30ab..4f64655d 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -18,7 +18,7 @@ from .rf_multi_classifier import SKLRFMultiClassifierMixin -def import_model(sklearn_model): +def import_model_with_model_builder(sklearn_model): """ Load a tree ensemble model from a scikit-learn model object @@ -51,7 +51,7 @@ def import_model(sklearn_model): clf.fit(X, y) import treelite.sklearn - model = treelite.sklearn.import_model(clf) + model = treelite.sklearn.import_model_with_model_builder(clf) """ class_name = sklearn_model.__class__.__name__ module_name = sklearn_model.__module__.split('.')[0] @@ -136,7 +136,7 @@ def as_c_array(self): return c_array(self.ptr_type, self.collection) -def import_model_v2(sklearn_model): +def import_model(sklearn_model): # pylint: disable=R0914,R0912 """ Load a tree ensemble model from a scikit-learn model object @@ -233,4 +233,4 @@ def import_model_v2(sklearn_model): return Model(handle) -__all__ = ['import_model', 'import_model_v2'] +__all__ = 
['import_model_with_model_builder', 'import_model'] diff --git a/tests/python/test_skl_importer.py b/tests/python/test_skl_importer.py index 248ef447..9a322a4d 100644 --- a/tests/python/test_skl_importer.py +++ b/tests/python/test_skl_importer.py @@ -48,7 +48,7 @@ def test_skl_converter_multiclass_classifier(tmpdir, clazz, toolchain): clf.fit(X, y) expected_prob = clf.predict_proba(X) - model = treelite.sklearn.import_model_v2(clf) + model = treelite.sklearn.import_model(clf) assert model.num_feature == clf.n_features_ assert model.num_class == clf.n_classes_ assert (model.num_tree == @@ -89,7 +89,7 @@ def test_skl_converter_binary_classifier(tmpdir, clazz, toolchain): clf.fit(X, y) expected_prob = clf.predict_proba(X)[:, 1] - model = treelite.sklearn.import_model_v2(clf) + model = treelite.sklearn.import_model(clf) assert model.num_feature == clf.n_features_ assert model.num_class == 1 assert model.num_tree == clf.n_estimators @@ -128,7 +128,7 @@ def test_skl_converter_regressor(tmpdir, clazz, toolchain): # pylint: disable=t clf.fit(X, y) expected_pred = clf.predict(X) - model = treelite.sklearn.import_model_v2(clf) + model = treelite.sklearn.import_model(clf) assert model.num_feature == clf.n_features_ assert model.num_class == 1 assert model.num_tree == clf.n_estimators From 7aa4b13d7b56ce4a13499b2bc243f7bfcf909407 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 27 Apr 2021 11:17:36 -0700 Subject: [PATCH 13/19] Direct users to prefer new import_model --- python/treelite/sklearn/__init__.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index 4f64655d..2d9355ca 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -20,7 +20,13 @@ def import_model_with_model_builder(sklearn_model): """ - Load a tree ensemble model from a scikit-learn model object + Load a tree ensemble model from a scikit-learn model 
object using the model builder API. + + .. note:: Use ``import_model`` for production use + + This function exists to demonstrate the use of the model builder API and is slow with + large models. For production, please use :py:func:`~treelite.sklearn.import_model` + which is significantly faster. Parameters ---------- @@ -156,6 +162,21 @@ def import_model(sklearn_model): ------- model : :py:class:`~treelite.Model` object loaded model + + Example + ------- + + .. code-block:: python + :emphasize-lines: 8 + + import sklearn.datasets + import sklearn.ensemble + X, y = sklearn.datasets.load_boston(return_X_y=True) + clf = sklearn.ensemble.RandomForestRegressor(n_estimators=10) + clf.fit(X, y) + + import treelite.sklearn + model = treelite.sklearn.import_model(clf) """ class_name = sklearn_model.__class__.__name__ module_name = sklearn_model.__module__.split('.')[0] From 3568ecd1370c5c3305998a7477db8526b2dfc147 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 27 Apr 2021 11:19:19 -0700 Subject: [PATCH 14/19] Remove __future__, as we support Python 3 only --- python/treelite/sklearn/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index 2d9355ca..ca31d3fa 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -1,8 +1,6 @@ # coding: utf-8 """Converter to ingest scikit-learn models into Treelite""" -from __future__ import absolute_import as _abs - import ctypes import numpy as np From 5e806f1fdd48a68f4be1f9040222e4dadacba114 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 27 Apr 2021 12:11:08 -0700 Subject: [PATCH 15/19] Pytest should skip at the lack of sklearn --- python/treelite/sklearn/__init__.py | 68 ++++++++++++++++++----------- tests/python/test_skl_importer.py | 8 ++++ 2 files changed, 50 insertions(+), 26 deletions(-) diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index ca31d3fa..981a3302 
100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -57,29 +57,34 @@ def import_model_with_model_builder(sklearn_model): import treelite.sklearn model = treelite.sklearn.import_model_with_model_builder(clf) """ - class_name = sklearn_model.__class__.__name__ - module_name = sklearn_model.__module__.split('.')[0] - - if module_name != 'sklearn': - raise Exception('Not a scikit-learn model') - - if class_name in ['RandomForestRegressor', 'ExtraTreesRegressor']: + try: + import sklearn.ensemble + from sklearn.ensemble import RandomForestRegressor as RandomForestR + from sklearn.ensemble import RandomForestClassifier as RandomForestC + from sklearn.ensemble import ExtraTreesRegressor as ExtraTreesR + from sklearn.ensemble import ExtraTreesClassifier as ExtraTreesC + from sklearn.ensemble import GradientBoostingRegressor as GradientBoostingR + from sklearn.ensemble import GradientBoostingClassifier as GradientBoostingC + except ImportError as e: + raise TreeliteError('This function requires scikit-learn package') from e + + if isinstance(sklearn_model, (RandomForestR, ExtraTreesR)): return SKLRFRegressorConverter.process_model(sklearn_model) - if class_name in ['RandomForestClassifier', 'ExtraTreesClassifier']: + if isinstance(sklearn_model, (RandomForestC, ExtraTreesC)): if sklearn_model.n_classes_ == 2: return SKLRFClassifierConverter.process_model(sklearn_model) if sklearn_model.n_classes_ > 2: return SKLRFMultiClassifierConverter.process_model(sklearn_model) raise TreeliteError('n_classes_ must be at least 2') - if class_name == 'GradientBoostingRegressor': + if isinstance(sklearn_model, GradientBoostingR): return SKLGBMRegressorConverter.process_model(sklearn_model) - if class_name == 'GradientBoostingClassifier': + if isinstance(sklearn_model, GradientBoostingC): if sklearn_model.n_classes_ == 2: return SKLGBMClassifierConverter.process_model(sklearn_model) if sklearn_model.n_classes_ > 2: return 
SKLGBMMultiClassifierConverter.process_model(sklearn_model) raise TreeliteError('n_classes_ must be at least 2') - raise TreeliteError('Unsupported model type: currently ' + + raise TreeliteError(f'Unsupported model type {sklearn_model.__class__.__name__}: currently ' + 'random forests, extremely randomized trees, and gradient boosted trees ' + 'are supported') @@ -176,20 +181,27 @@ def import_model(sklearn_model): import treelite.sklearn model = treelite.sklearn.import_model(clf) """ - class_name = sklearn_model.__class__.__name__ - module_name = sklearn_model.__module__.split('.')[0] - if module_name != 'sklearn': - raise TreeliteError('Not a scikit-learn model') - - if class_name in ['RandomForestRegressor', 'ExtraTreesRegressor', 'GradientBoostingRegressor', - 'GradientBoostingClassifier']: + try: + import sklearn.ensemble + from sklearn.ensemble import RandomForestRegressor as RandomForestR + from sklearn.ensemble import RandomForestClassifier as RandomForestC + from sklearn.ensemble import ExtraTreesRegressor as ExtraTreesR + from sklearn.ensemble import ExtraTreesClassifier as ExtraTreesC + from sklearn.ensemble import GradientBoostingRegressor as GradientBoostingR + from sklearn.ensemble import GradientBoostingClassifier as GradientBoostingC + except ImportError as e: + raise TreeliteError('This function requires scikit-learn package') from e + + if isinstance(sklearn_model, + (RandomForestR, ExtraTreesR, GradientBoostingR, GradientBoostingC)): leaf_value_expected_shape = lambda node_count: (node_count, 1, 1) - elif class_name in ['RandomForestClassifier', 'ExtraTreesClassifier']: + elif isinstance(sklearn_model, (RandomForestC, ExtraTreesC)): leaf_value_expected_shape = lambda node_count: (node_count, 1, sklearn_model.n_classes_) else: - raise TreeliteError(f'Not supported: {class_name}') + raise TreeliteError(f'Not supported model type: {sklearn_model.__class__.__name__}') - if class_name.startswith('GradientBoosting') and sklearn_model.init != 'zero': + 
if isinstance(sklearn_model, + (GradientBoostingR, GradientBoostingC)) and sklearn_model.init != 'zero': raise TreeliteError("Gradient boosted trees must be trained with the option init='zero'") node_count = [] @@ -201,7 +213,7 @@ def import_model(sklearn_model): n_node_samples = ArrayOfArrays(dtype=np.int64) impurity = ArrayOfArrays(dtype=np.float64) for estimator in sklearn_model.estimators_: - if class_name.startswith('GradientBoosting'): + if isinstance(sklearn_model, (GradientBoostingR, GradientBoostingC)): estimator_range = estimator learning_rate = sklearn_model.learning_rate else: @@ -221,34 +233,38 @@ def import_model(sklearn_model): impurity.add(tree.impurity, expected_shape=(tree.node_count,)) handle = ctypes.c_void_p() - if class_name in ['RandomForestRegressor', 'ExtraTreesRegressor']: + if isinstance(sklearn_model, (RandomForestR, ExtraTreesR)): _check_call(_LIB.TreeliteLoadSKLearnRandomForestRegressor( ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), c_array(ctypes.c_int64, node_count), children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), ctypes.byref(handle))) - elif class_name in ['RandomForestClassifier', 'ExtraTreesClassifier']: + elif isinstance(sklearn_model, (RandomForestC, ExtraTreesC)): _check_call(_LIB.TreeliteLoadSKLearnRandomForestClassifier( ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), ctypes.c_int(sklearn_model.n_classes_), c_array(ctypes.c_int64, node_count), children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), ctypes.byref(handle))) - elif class_name == 'GradientBoostingRegressor': + elif isinstance(sklearn_model, GradientBoostingR): _check_call(_LIB.TreeliteLoadSKLearnGradientBoostingRegressor( 
ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), c_array(ctypes.c_int64, node_count), children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), ctypes.byref(handle))) - elif class_name == 'GradientBoostingClassifier': + elif isinstance(sklearn_model, GradientBoostingC): _check_call(_LIB.TreeliteLoadSKLearnGradientBoostingClassifier( ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), ctypes.c_int(sklearn_model.n_classes_), c_array(ctypes.c_int64, node_count), children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), ctypes.byref(handle))) + else: + raise TreeliteError(f'Unsupported model type {sklearn_model.__class__.__name__}: ' + + 'currently random forests, extremely randomized trees, and gradient ' + + 'boosted trees are supported') return Model(handle) diff --git a/tests/python/test_skl_importer.py b/tests/python/test_skl_importer.py index 9a322a4d..58b3d5be 100644 --- a/tests/python/test_skl_importer.py +++ b/tests/python/test_skl_importer.py @@ -21,10 +21,18 @@ class RandomForestClassifier: # pylint: disable=missing-class-docstring, R0903 pass + class RandomForestRegressor: # pylint: disable=missing-class-docstring, R0903 + pass + + class GradientBoostingClassifier: # pylint: disable=missing-class-docstring, R0903 pass + class GradientBoostingRegressor: # pylint: disable=missing-class-docstring, R0903 + pass + + class ExtraTreesClassifier: # pylint: disable=missing-class-docstring, R0903 pass From 2e6487064f0fd7b32a5140c02245092d549fa645 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 27 Apr 2021 12:15:22 -0700 Subject: [PATCH 16/19] Test both old and new import methods in pytest --- tests/python/test_skl_importer.py | 25 +++++++++++++++++++------ 1 file 
changed, 19 insertions(+), 6 deletions(-) diff --git a/tests/python/test_skl_importer.py b/tests/python/test_skl_importer.py index 58b3d5be..e0b53a51 100644 --- a/tests/python/test_skl_importer.py +++ b/tests/python/test_skl_importer.py @@ -45,7 +45,8 @@ class ExtraTreesRegressor: # pylint: disable=missing-class-docstring, R0903 @pytest.mark.parametrize('toolchain', os_compatible_toolchains()) @pytest.mark.parametrize('clazz', [RandomForestClassifier, ExtraTreesClassifier, GradientBoostingClassifier]) -def test_skl_converter_multiclass_classifier(tmpdir, clazz, toolchain): +@pytest.mark.parametrize('import_method', ['import_old', 'import_new']) +def test_skl_converter_multiclass_classifier(tmpdir, import_method, clazz, toolchain): # pylint: disable=too-many-locals """Convert scikit-learn multi-class classifier""" X, y = load_iris(return_X_y=True) @@ -56,7 +57,10 @@ def test_skl_converter_multiclass_classifier(tmpdir, clazz, toolchain): clf.fit(X, y) expected_prob = clf.predict_proba(X) - model = treelite.sklearn.import_model(clf) + if import_method == 'import_new': + model = treelite.sklearn.import_model(clf) + else: + model = treelite.sklearn.import_model_with_model_builder(clf) assert model.num_feature == clf.n_features_ assert model.num_class == clf.n_classes_ assert (model.num_tree == @@ -86,7 +90,8 @@ def test_skl_converter_multiclass_classifier(tmpdir, clazz, toolchain): @pytest.mark.parametrize('toolchain', os_compatible_toolchains()) @pytest.mark.parametrize('clazz', [RandomForestClassifier, ExtraTreesClassifier, GradientBoostingClassifier]) -def test_skl_converter_binary_classifier(tmpdir, clazz, toolchain): +@pytest.mark.parametrize('import_method', ['import_old', 'import_new']) +def test_skl_converter_binary_classifier(tmpdir, import_method, clazz, toolchain): # pylint: disable=too-many-locals """Convert scikit-learn binary classifier""" X, y = load_breast_cancer(return_X_y=True) @@ -97,7 +102,10 @@ def test_skl_converter_binary_classifier(tmpdir, clazz, 
toolchain): clf.fit(X, y) expected_prob = clf.predict_proba(X)[:, 1] - model = treelite.sklearn.import_model(clf) + if import_method == 'import_new': + model = treelite.sklearn.import_model(clf) + else: + model = treelite.sklearn.import_model_with_model_builder(clf) assert model.num_feature == clf.n_features_ assert model.num_class == 1 assert model.num_tree == clf.n_estimators @@ -126,7 +134,9 @@ def test_skl_converter_binary_classifier(tmpdir, clazz, toolchain): @pytest.mark.parametrize('toolchain', os_compatible_toolchains()) @pytest.mark.parametrize('clazz', [RandomForestRegressor, ExtraTreesRegressor, GradientBoostingRegressor]) -def test_skl_converter_regressor(tmpdir, clazz, toolchain): # pylint: disable=too-many-locals +@pytest.mark.parametrize('import_method', ['import_old', 'import_new']) +def test_skl_converter_regressor(tmpdir, import_method, clazz, toolchain): + # pylint: disable=too-many-locals """Convert scikit-learn regressor""" X, y = load_boston(return_X_y=True) kwargs = {} @@ -136,7 +146,10 @@ def test_skl_converter_regressor(tmpdir, clazz, toolchain): # pylint: disable=t clf.fit(X, y) expected_pred = clf.predict(X) - model = treelite.sklearn.import_model(clf) + if import_method == 'import_new': + model = treelite.sklearn.import_model(clf) + else: + model = treelite.sklearn.import_model_with_model_builder(clf) assert model.num_feature == clf.n_features_ assert model.num_class == 1 assert model.num_tree == clf.n_estimators From e3da3abde1de0cabb8ca238fad00e1088f5a270c Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 27 Apr 2021 12:20:47 -0700 Subject: [PATCH 17/19] Put the new import_model() in its own Python file --- python/treelite/sklearn/__init__.py | 156 +-------------------------- python/treelite/sklearn/importer.py | 162 ++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 155 deletions(-) create mode 100644 python/treelite/sklearn/importer.py diff --git a/python/treelite/sklearn/__init__.py 
b/python/treelite/sklearn/__init__.py index 981a3302..051aa5d5 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -2,11 +2,10 @@ """Converter to ingest scikit-learn models into Treelite""" import ctypes -import numpy as np from ..util import TreeliteError -from ..core import _LIB, c_array, _check_call from ..frontend import Model +from .importer import import_model from .common import SKLConverterBase from .gbm_regressor import SKLGBMRegressorMixin from .gbm_classifier import SKLGBMClassifierMixin @@ -115,157 +114,4 @@ class SKLRFMultiClassifierConverter(SKLRFMultiClassifierMixin, SKLConverterBase) pass -class ArrayOfArrays: - """ - Utility class to marshall a collection of arrays in order to pass to a C function - """ - def __init__(self, *, dtype): - int64_ptr_type = ctypes.POINTER(ctypes.c_int64) - float64_ptr_type = ctypes.POINTER(ctypes.c_double) - if dtype == np.int64: - self.ptr_type = int64_ptr_type - elif dtype == np.float64: - self.ptr_type = float64_ptr_type - else: - raise ValueError(f'dtype {dtype} is not supported') - self.dtype = dtype - self.collection = [] - - def add(self, array, *, expected_shape=None): - """Add an array to the collection""" - assert array.dtype == self.dtype - if expected_shape: - assert array.shape == expected_shape, \ - f'Expected shape: {expected_shape}, Got shape {array.shape}' - v = np.array(array, copy=False, dtype=self.dtype, order='C') - self.collection.append(v.ctypes.data_as(self.ptr_type)) - - def as_c_array(self): - """Prepare the collection to pass as an argument of a C function""" - return c_array(self.ptr_type, self.collection) - - -def import_model(sklearn_model): - # pylint: disable=R0914,R0912 - """ - Load a tree ensemble model from a scikit-learn model object - - Parameters - ---------- - sklearn_model : object of type \ - :py:class:`~sklearn.ensemble.RandomForestRegressor` / \ - :py:class:`~sklearn.ensemble.RandomForestClassifier` / \ - 
:py:class:`~sklearn.ensemble.ExtraTreesRegressor` / \ - :py:class:`~sklearn.ensemble.ExtraTreesClassifier` / \ - :py:class:`~sklearn.ensemble.GradientBoostingRegressor` / \ - :py:class:`~sklearn.ensemble.GradientBoostingClassifier` - Python handle to scikit-learn model - - Returns - ------- - model : :py:class:`~treelite.Model` object - loaded model - - Example - ------- - - .. code-block:: python - :emphasize-lines: 8 - - import sklearn.datasets - import sklearn.ensemble - X, y = sklearn.datasets.load_boston(return_X_y=True) - clf = sklearn.ensemble.RandomForestRegressor(n_estimators=10) - clf.fit(X, y) - - import treelite.sklearn - model = treelite.sklearn.import_model(clf) - """ - try: - import sklearn.ensemble - from sklearn.ensemble import RandomForestRegressor as RandomForestR - from sklearn.ensemble import RandomForestClassifier as RandomForestC - from sklearn.ensemble import ExtraTreesRegressor as ExtraTreesR - from sklearn.ensemble import ExtraTreesClassifier as ExtraTreesC - from sklearn.ensemble import GradientBoostingRegressor as GradientBoostingR - from sklearn.ensemble import GradientBoostingClassifier as GradientBoostingC - except ImportError as e: - raise TreeliteError('This function requires scikit-learn package') from e - - if isinstance(sklearn_model, - (RandomForestR, ExtraTreesR, GradientBoostingR, GradientBoostingC)): - leaf_value_expected_shape = lambda node_count: (node_count, 1, 1) - elif isinstance(sklearn_model, (RandomForestC, ExtraTreesC)): - leaf_value_expected_shape = lambda node_count: (node_count, 1, sklearn_model.n_classes_) - else: - raise TreeliteError(f'Not supported model type: {sklearn_model.__class__.__name__}') - - if isinstance(sklearn_model, - (GradientBoostingR, GradientBoostingC)) and sklearn_model.init != 'zero': - raise TreeliteError("Gradient boosted trees must be trained with the option init='zero'") - - node_count = [] - children_left = ArrayOfArrays(dtype=np.int64) - children_right = ArrayOfArrays(dtype=np.int64) - 
feature = ArrayOfArrays(dtype=np.int64) - threshold = ArrayOfArrays(dtype=np.float64) - value = ArrayOfArrays(dtype=np.float64) - n_node_samples = ArrayOfArrays(dtype=np.int64) - impurity = ArrayOfArrays(dtype=np.float64) - for estimator in sklearn_model.estimators_: - if isinstance(sklearn_model, (GradientBoostingR, GradientBoostingC)): - estimator_range = estimator - learning_rate = sklearn_model.learning_rate - else: - estimator_range = [estimator] - learning_rate = 1.0 - for sub_estimator in estimator_range: - tree = sub_estimator.tree_ - node_count.append(tree.node_count) - children_left.add(tree.children_left, expected_shape=(tree.node_count,)) - children_right.add(tree.children_right, expected_shape=(tree.node_count,)) - feature.add(tree.feature, expected_shape=(tree.node_count,)) - threshold.add(tree.threshold, expected_shape=(tree.node_count,)) - # Note: for gradient boosted trees, we shrink each leaf output by the learning rate - value.add(tree.value * learning_rate, - expected_shape=leaf_value_expected_shape(tree.node_count)) - n_node_samples.add(tree.n_node_samples, expected_shape=(tree.node_count,)) - impurity.add(tree.impurity, expected_shape=(tree.node_count,)) - - handle = ctypes.c_void_p() - if isinstance(sklearn_model, (RandomForestR, ExtraTreesR)): - _check_call(_LIB.TreeliteLoadSKLearnRandomForestRegressor( - ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), - c_array(ctypes.c_int64, node_count), children_left.as_c_array(), - children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), - value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), - ctypes.byref(handle))) - elif isinstance(sklearn_model, (RandomForestC, ExtraTreesC)): - _check_call(_LIB.TreeliteLoadSKLearnRandomForestClassifier( - ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), - ctypes.c_int(sklearn_model.n_classes_), c_array(ctypes.c_int64, node_count), - children_left.as_c_array(), 
children_right.as_c_array(), feature.as_c_array(), - threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), - impurity.as_c_array(), ctypes.byref(handle))) - elif isinstance(sklearn_model, GradientBoostingR): - _check_call(_LIB.TreeliteLoadSKLearnGradientBoostingRegressor( - ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), - c_array(ctypes.c_int64, node_count), children_left.as_c_array(), - children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), - value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), - ctypes.byref(handle))) - elif isinstance(sklearn_model, GradientBoostingC): - _check_call(_LIB.TreeliteLoadSKLearnGradientBoostingClassifier( - ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), - ctypes.c_int(sklearn_model.n_classes_), c_array(ctypes.c_int64, node_count), - children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), - threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), - impurity.as_c_array(), ctypes.byref(handle))) - else: - raise TreeliteError(f'Unsupported model type {sklearn_model.__class__.__name__}: ' + - 'currently random forests, extremely randomized trees, and gradient ' + - 'boosted trees are supported') - return Model(handle) - - __all__ = ['import_model_with_model_builder', 'import_model'] diff --git a/python/treelite/sklearn/importer.py b/python/treelite/sklearn/importer.py new file mode 100644 index 00000000..43fae91e --- /dev/null +++ b/python/treelite/sklearn/importer.py @@ -0,0 +1,162 @@ +# coding: utf-8 +"""Converter to ingest scikit-learn models into Treelite""" + +import ctypes +import numpy as np + +from ..util import TreeliteError +from ..core import _LIB, c_array, _check_call +from ..frontend import Model + + +class ArrayOfArrays: + """ + Utility class to marshall a collection of arrays in order to pass to a C function + """ + def __init__(self, *, dtype): + 
int64_ptr_type = ctypes.POINTER(ctypes.c_int64) + float64_ptr_type = ctypes.POINTER(ctypes.c_double) + if dtype == np.int64: + self.ptr_type = int64_ptr_type + elif dtype == np.float64: + self.ptr_type = float64_ptr_type + else: + raise ValueError(f'dtype {dtype} is not supported') + self.dtype = dtype + self.collection = [] + + def add(self, array, *, expected_shape=None): + """Add an array to the collection""" + assert array.dtype == self.dtype + if expected_shape: + assert array.shape == expected_shape, \ + f'Expected shape: {expected_shape}, Got shape {array.shape}' + v = np.array(array, copy=False, dtype=self.dtype, order='C') + self.collection.append(v.ctypes.data_as(self.ptr_type)) + + def as_c_array(self): + """Prepare the collection to pass as an argument of a C function""" + return c_array(self.ptr_type, self.collection) + + +def import_model(sklearn_model): + # pylint: disable=R0914,R0912 + """ + Load a tree ensemble model from a scikit-learn model object + + Parameters + ---------- + sklearn_model : object of type \ + :py:class:`~sklearn.ensemble.RandomForestRegressor` / \ + :py:class:`~sklearn.ensemble.RandomForestClassifier` / \ + :py:class:`~sklearn.ensemble.ExtraTreesRegressor` / \ + :py:class:`~sklearn.ensemble.ExtraTreesClassifier` / \ + :py:class:`~sklearn.ensemble.GradientBoostingRegressor` / \ + :py:class:`~sklearn.ensemble.GradientBoostingClassifier` + Python handle to scikit-learn model + + Returns + ------- + model : :py:class:`~treelite.Model` object + loaded model + + Example + ------- + + .. 
code-block:: python + :emphasize-lines: 8 + + import sklearn.datasets + import sklearn.ensemble + X, y = sklearn.datasets.load_boston(return_X_y=True) + clf = sklearn.ensemble.RandomForestRegressor(n_estimators=10) + clf.fit(X, y) + + import treelite.sklearn + model = treelite.sklearn.import_model(clf) + """ + try: + import sklearn.ensemble + from sklearn.ensemble import RandomForestRegressor as RandomForestR + from sklearn.ensemble import RandomForestClassifier as RandomForestC + from sklearn.ensemble import ExtraTreesRegressor as ExtraTreesR + from sklearn.ensemble import ExtraTreesClassifier as ExtraTreesC + from sklearn.ensemble import GradientBoostingRegressor as GradientBoostingR + from sklearn.ensemble import GradientBoostingClassifier as GradientBoostingC + except ImportError as e: + raise TreeliteError('This function requires scikit-learn package') from e + + if isinstance(sklearn_model, + (RandomForestR, ExtraTreesR, GradientBoostingR, GradientBoostingC)): + leaf_value_expected_shape = lambda node_count: (node_count, 1, 1) + elif isinstance(sklearn_model, (RandomForestC, ExtraTreesC)): + leaf_value_expected_shape = lambda node_count: (node_count, 1, sklearn_model.n_classes_) + else: + raise TreeliteError(f'Not supported model type: {sklearn_model.__class__.__name__}') + + if isinstance(sklearn_model, + (GradientBoostingR, GradientBoostingC)) and sklearn_model.init != 'zero': + raise TreeliteError("Gradient boosted trees must be trained with the option init='zero'") + + node_count = [] + children_left = ArrayOfArrays(dtype=np.int64) + children_right = ArrayOfArrays(dtype=np.int64) + feature = ArrayOfArrays(dtype=np.int64) + threshold = ArrayOfArrays(dtype=np.float64) + value = ArrayOfArrays(dtype=np.float64) + n_node_samples = ArrayOfArrays(dtype=np.int64) + impurity = ArrayOfArrays(dtype=np.float64) + for estimator in sklearn_model.estimators_: + if isinstance(sklearn_model, (GradientBoostingR, GradientBoostingC)): + estimator_range = estimator + 
learning_rate = sklearn_model.learning_rate + else: + estimator_range = [estimator] + learning_rate = 1.0 + for sub_estimator in estimator_range: + tree = sub_estimator.tree_ + node_count.append(tree.node_count) + children_left.add(tree.children_left, expected_shape=(tree.node_count,)) + children_right.add(tree.children_right, expected_shape=(tree.node_count,)) + feature.add(tree.feature, expected_shape=(tree.node_count,)) + threshold.add(tree.threshold, expected_shape=(tree.node_count,)) + # Note: for gradient boosted trees, we shrink each leaf output by the learning rate + value.add(tree.value * learning_rate, + expected_shape=leaf_value_expected_shape(tree.node_count)) + n_node_samples.add(tree.n_node_samples, expected_shape=(tree.node_count,)) + impurity.add(tree.impurity, expected_shape=(tree.node_count,)) + + handle = ctypes.c_void_p() + if isinstance(sklearn_model, (RandomForestR, ExtraTreesR)): + _check_call(_LIB.TreeliteLoadSKLearnRandomForestRegressor( + ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), + c_array(ctypes.c_int64, node_count), children_left.as_c_array(), + children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), + value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), + ctypes.byref(handle))) + elif isinstance(sklearn_model, (RandomForestC, ExtraTreesC)): + _check_call(_LIB.TreeliteLoadSKLearnRandomForestClassifier( + ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), + ctypes.c_int(sklearn_model.n_classes_), c_array(ctypes.c_int64, node_count), + children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), + threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), + impurity.as_c_array(), ctypes.byref(handle))) + elif isinstance(sklearn_model, GradientBoostingR): + _check_call(_LIB.TreeliteLoadSKLearnGradientBoostingRegressor( + ctypes.c_int(sklearn_model.n_estimators), 
ctypes.c_int(sklearn_model.n_features_), + c_array(ctypes.c_int64, node_count), children_left.as_c_array(), + children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(), + value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(), + ctypes.byref(handle))) + elif isinstance(sklearn_model, GradientBoostingC): + _check_call(_LIB.TreeliteLoadSKLearnGradientBoostingClassifier( + ctypes.c_int(sklearn_model.n_estimators), ctypes.c_int(sklearn_model.n_features_), + ctypes.c_int(sklearn_model.n_classes_), c_array(ctypes.c_int64, node_count), + children_left.as_c_array(), children_right.as_c_array(), feature.as_c_array(), + threshold.as_c_array(), value.as_c_array(), n_node_samples.as_c_array(), + impurity.as_c_array(), ctypes.byref(handle))) + else: + raise TreeliteError(f'Unsupported model type {sklearn_model.__class__.__name__}: ' + + 'currently random forests, extremely randomized trees, and gradient ' + + 'boosted trees are supported') + return Model(handle) From 85bc853a63f6becf2959359f0aa61a660aea493b Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 27 Apr 2021 14:24:02 -0700 Subject: [PATCH 18/19] Add docstring to functions --- include/treelite/c_api.h | 103 ++++++++++++++++++++++++++++++++++++ include/treelite/frontend.h | 101 +++++++++++++++++++++++++++++++++-- 2 files changed, 201 insertions(+), 3 deletions(-) diff --git a/include/treelite/c_api.h b/include/treelite/c_api.h index 3a3ddef1..48a5f271 100644 --- a/include/treelite/c_api.h +++ b/include/treelite/c_api.h @@ -172,21 +172,124 @@ TREELITE_DLL int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, ModelHandle* out); +/*! + * \brief Load a scikit-learn random forest regressor model from a collection of arrays. Refer to + * https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to + * learn the meaning of the arrays in detail.
Note that this function can also be used to + * load an ensemble of extremely randomized trees (sklearn.ensemble.ExtraTreesRegressor). + * \param n_estimators number of trees in the random forest + * \param n_features number of features in the training data + * \param node_count node_count[i] stores the number of nodes in the i-th tree + * \param children_left children_left[i][k] stores the ID of the left child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param children_right children_right[i][k] stores the ID of the right child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param feature feature[i][k] stores the ID of the feature used in the binary tree split at node k + * of the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param threshold threshold[i][k] stores the threshold used in the binary tree split at node k of + * the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param value value[i][k] stores the leaf output of node k of the i-th tree. This is only defined + * if node k is a leaf node. + * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with + * node k of the i-th tree. + * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with + * node k of the i-th tree. + * \param out pointer to store the loaded model + * \return 0 for success, -1 for failure + */ TREELITE_DLL int TreeliteLoadSKLearnRandomForestRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity, ModelHandle* out); + +/*! + * \brief Load a scikit-learn random forest classifier model from a collection of arrays. 
Refer to + * https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to + * learn the meaning of the arrays in detail. Note that this function can also be used to + * load an ensemble of extremely randomized trees (sklearn.ensemble.ExtraTreesClassifier). + * \param n_estimators number of trees in the random forest + * \param n_features number of features in the training data + * \param n_classes number of classes in the target variable + * \param node_count node_count[i] stores the number of nodes in the i-th tree + * \param children_left children_left[i][k] stores the ID of the left child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param children_right children_right[i][k] stores the ID of the right child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param feature feature[i][k] stores the ID of the feature used in the binary tree split at node k + * of the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param threshold threshold[i][k] stores the threshold used in the binary tree split at node k of + * the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param value value[i][k] stores the leaf output of node k of the i-th tree. This is only defined + * if node k is a leaf node. + * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with + * node k of the i-th tree. + * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with + * node k of the i-th tree.
+ * \param out pointer to store the loaded model + * \return 0 for success, -1 for failure + */ TREELITE_DLL int TreeliteLoadSKLearnRandomForestClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity, ModelHandle* out); + +/*! + * \brief Load a scikit-learn gradient boosting regressor model from a collection of arrays. Refer + * to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to + * learn the meaning of the arrays in detail. + * \param n_estimators number of trees in the random forest + * \param n_features number of features in the training data + * \param node_count node_count[i] stores the number of nodes in the i-th tree + * \param children_left children_left[i][k] stores the ID of the left child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param children_right children_right[i][k] stores the ID of the right child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param feature feature[i][k] stores the ID of the feature used in the binary tree split at node k + * of the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param threshold threshold[i][k] stores the threshold used in the binary tree split at node k of + * the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param value value[i][k] stores the leaf output of node k of the i-th tree. This is only defined + * if node k is a leaf node. + * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with + * node k of the i-th tree.
+ * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with + * node k of the i-th tree. + * \param out pointer to store the loaded model + * \return 0 for success, -1 for failure + */ TREELITE_DLL int TreeliteLoadSKLearnGradientBoostingRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity, ModelHandle* out); + +/*! + * \brief Load a scikit-learn gradient boosting classifier model from a collection of arrays. Refer + * to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to + * learn the meaning of the arrays in detail. + * \param n_estimators number of trees in the random forest + * \param n_features number of features in the training data + * \param n_classes number of classes in the target variable + * \param node_count node_count[i] stores the number of nodes in the i-th tree + * \param children_left children_left[i][k] stores the ID of the left child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param children_right children_right[i][k] stores the ID of the right child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param feature feature[i][k] stores the ID of the feature used in the binary tree split at node k + * of the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param threshold threshold[i][k] stores the threshold used in the binary tree split at node k of + * the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param value value[i][k] stores the leaf output of node k of the i-th tree. This is only defined + * if node k is a leaf node.
+ * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with + * node k of the i-th tree. + * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with + * node k of the i-th tree. + * \param out pointer to store the loaded model + * \return 0 for success, -1 for failure + */ TREELITE_DLL int TreeliteLoadSKLearnGradientBoostingClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, diff --git a/include/treelite/frontend.h b/include/treelite/frontend.h index feb63854..a8f8a7b9 100644 --- a/include/treelite/frontend.h +++ b/include/treelite/frontend.h @@ -48,30 +48,125 @@ std::unique_ptr LoadXGBoostModel(const void* buf, size_t len); * \brief load a model file generated by XGBoost (dmlc/xgboost). The model file * must contain a decision tree ensemble in the JSON format. * \param filename name of model file - * \param out reference to loaded model + * \return loaded model */ std::unique_ptr LoadXGBoostJSONModel(const char* filename); /*! * \brief load an XGBoost model from a JSON string * \param json_str JSON char array * \param length length of JSON char array - * \param out reference to loaded model + * \return loaded model */ std::unique_ptr LoadXGBoostJSONModelString(const char* json_str, size_t length); - +/*! + * \brief Load a scikit-learn random forest regressor model from a collection of arrays. Refer to + * https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to + * learn the meaning of the arrays in detail. Note that this function can also be used to + * load an ensemble of extremely randomized trees (sklearn.ensemble.ExtraTreesRegressor).
+ * \param n_estimators number of trees in the random forest + * \param n_features number of features in the training data + * \param node_count node_count[i] stores the number of nodes in the i-th tree + * \param children_left children_left[i][k] stores the ID of the left child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param children_right children_right[i][k] stores the ID of the right child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param feature feature[i][k] stores the ID of the feature used in the binary tree split at node k + * of the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param threshold threshold[i][k] stores the threshold used in the binary tree split at node k of + * the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param value value[i][k] stores the leaf output of node k of the i-th tree. This is only defined + * if node k is a leaf node. + * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with + * node k of the i-th tree. + * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with + * node k of the i-th tree. + * \return loaded model + */ std::unique_ptr LoadSKLearnRandomForestRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity); +/*! + * \brief Load a scikit-learn random forest classifier model from a collection of arrays. Refer to + * https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to + * learn the meaning of the arrays in detail.
Note that this function can also be used to + * load an ensemble of extremely randomized trees (sklearn.ensemble.ExtraTreesClassifier). + * \param n_estimators number of trees in the random forest + * \param n_features number of features in the training data + * \param n_classes number of classes in the target variable + * \param node_count node_count[i] stores the number of nodes in the i-th tree + * \param children_left children_left[i][k] stores the ID of the left child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param children_right children_right[i][k] stores the ID of the right child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param feature feature[i][k] stores the ID of the feature used in the binary tree split at node k + * of the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param threshold threshold[i][k] stores the threshold used in the binary tree split at node k of + * the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param value value[i][k] stores the leaf output of node k of the i-th tree. This is only defined + * if node k is a leaf node. + * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with + * node k of the i-th tree. + * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with + * node k of the i-th tree. + * \return loaded model + */ std::unique_ptr LoadSKLearnRandomForestClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity); +/*! + * \brief Load a scikit-learn gradient boosting regressor model from a collection of arrays. 
Refer + * to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to + * learn the meaning of the arrays in detail. + * \param n_estimators number of trees in the random forest + * \param n_features number of features in the training data + * \param node_count node_count[i] stores the number of nodes in the i-th tree + * \param children_left children_left[i][k] stores the ID of the left child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param children_right children_right[i][k] stores the ID of the right child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param feature feature[i][k] stores the ID of the feature used in the binary tree split at node k + * of the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param threshold threshold[i][k] stores the threshold used in the binary tree split at node k of + * the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param value value[i][k] stores the leaf output of node k of the i-th tree. This is only defined + * if node k is a leaf node. + * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with + * node k of the i-th tree. + * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with + * node k of the i-th tree. + * \return loaded model + */ std::unique_ptr LoadSKLearnGradientBoostingRegressor( int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, const double** threshold, const double** value, const int64_t** n_node_samples, const double** impurity); +/*! + * \brief Load a scikit-learn gradient boosting classifier model from a collection of arrays.
Refer + * to https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html to + * learn the meaning of the arrays in detail. + * \param n_estimators number of trees in the random forest + * \param n_features number of features in the training data + * \param n_classes number of classes in the target variable + * \param node_count node_count[i] stores the number of nodes in the i-th tree + * \param children_left children_left[i][k] stores the ID of the left child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param children_right children_right[i][k] stores the ID of the right child node of node k of the + * i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param feature feature[i][k] stores the ID of the feature used in the binary tree split at node k + * of the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param threshold threshold[i][k] stores the threshold used in the binary tree split at node k of + * the i-th tree. This is only defined if node k is an internal (non-leaf) node. + * \param value value[i][k] stores the leaf output of node k of the i-th tree. This is only defined + * if node k is a leaf node. + * \param n_node_samples n_node_samples[i][k] stores the number of data samples associated with + * node k of the i-th tree. + * \param impurity impurity[i][k] stores the impurity measure (gini, entropy etc) associated with + * node k of the i-th tree.
+ * \return loaded model + */ std::unique_ptr LoadSKLearnGradientBoostingClassifier( int n_estimators, int n_features, int n_classes, const int64_t* node_count, const int64_t** children_left, const int64_t** children_right, const int64_t** feature, From a5b249dbd3609c99a9d33c453483292f731e49aa Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 27 Apr 2021 15:20:46 -0700 Subject: [PATCH 19/19] Comply with pylint check --- python/treelite/sklearn/__init__.py | 1 - python/treelite/sklearn/importer.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/python/treelite/sklearn/__init__.py b/python/treelite/sklearn/__init__.py index 051aa5d5..7d3bf474 100644 --- a/python/treelite/sklearn/__init__.py +++ b/python/treelite/sklearn/__init__.py @@ -57,7 +57,6 @@ def import_model_with_model_builder(sklearn_model): model = treelite.sklearn.import_model_with_model_builder(clf) """ try: - import sklearn.ensemble from sklearn.ensemble import RandomForestRegressor as RandomForestR from sklearn.ensemble import RandomForestClassifier as RandomForestC from sklearn.ensemble import ExtraTreesRegressor as ExtraTreesR diff --git a/python/treelite/sklearn/importer.py b/python/treelite/sklearn/importer.py index 43fae91e..a484aac1 100644 --- a/python/treelite/sklearn/importer.py +++ b/python/treelite/sklearn/importer.py @@ -40,7 +40,7 @@ def as_c_array(self): def import_model(sklearn_model): - # pylint: disable=R0914,R0912 + # pylint: disable=R0914,R0912,R0915 """ Load a tree ensemble model from a scikit-learn model object @@ -76,7 +76,6 @@ def import_model(sklearn_model): model = treelite.sklearn.import_model(clf) """ try: - import sklearn.ensemble from sklearn.ensemble import RandomForestRegressor as RandomForestR from sklearn.ensemble import RandomForestClassifier as RandomForestC from sklearn.ensemble import ExtraTreesRegressor as ExtraTreesR