Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster model import for sklearn tree models #264

Merged
merged 20 commits into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions include/treelite/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,28 @@ TREELITE_DLL int TreeliteLoadXGBoostJSONString(const char* json_str,
TREELITE_DLL int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf,
size_t len,
ModelHandle* out);

/*!
 * \brief Load a scikit-learn random forest regressor from per-tree node arrays
 *
 * Every double-pointer argument holds one pointer per tree; the pointer at index i refers
 * to an array describing tree i. The Python wrapper (treelite.sklearn.import_model_v2)
 * fills each of them from the attribute of the same name on the scikit-learn estimator's
 * tree_ object and also uses this entry point for ExtraTreesRegressor.
 * \param n_estimators number of estimators in the ensemble (scikit-learn's n_estimators)
 * \param n_features number of features, as reported by the scikit-learn model
 * \param node_count node_count[i] is the number of nodes in tree i
 * \param children_left children_left[i] points to tree i's children_left array
 *                      (node_count[i] entries)
 * \param children_right children_right[i] points to tree i's children_right array
 *                       (node_count[i] entries)
 * \param feature feature[i] points to tree i's feature array (node_count[i] entries)
 * \param threshold threshold[i] points to tree i's threshold array (node_count[i] entries)
 * \param value value[i] points to tree i's value array; the Python wrapper passes it in
 *              C order with shape (node_count[i], 1, 1)
 * \param n_node_samples n_node_samples[i] points to tree i's n_node_samples array
 *                       (node_count[i] entries)
 * \param impurity impurity[i] points to tree i's impurity array (node_count[i] entries)
 * \param out used to save the handle of the loaded model
 * \return NOTE(review): presumably 0 for success and -1 for failure, like the other
 *         loaders in this header -- confirm against API_BEGIN/API_END
 */
TREELITE_DLL int TreeliteLoadSKLearnRandomForestRegressor(
    int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
    const int64_t** children_right, const int64_t** feature, const double** threshold,
    const double** value, const int64_t** n_node_samples, const double** impurity,
    ModelHandle* out);
/*!
 * \brief Load a scikit-learn random forest classifier from per-tree node arrays
 *
 * Every double-pointer argument holds one pointer per tree; the pointer at index i refers
 * to an array describing tree i. The Python wrapper (treelite.sklearn.import_model_v2)
 * fills each of them from the attribute of the same name on the scikit-learn estimator's
 * tree_ object and also uses this entry point for ExtraTreesClassifier.
 * \param n_estimators number of estimators in the ensemble (scikit-learn's n_estimators)
 * \param n_features number of features, as reported by the scikit-learn model
 * \param n_classes number of classes (scikit-learn's n_classes_)
 * \param node_count node_count[i] is the number of nodes in tree i
 * \param children_left children_left[i] points to tree i's children_left array
 *                      (node_count[i] entries)
 * \param children_right children_right[i] points to tree i's children_right array
 *                       (node_count[i] entries)
 * \param feature feature[i] points to tree i's feature array (node_count[i] entries)
 * \param threshold threshold[i] points to tree i's threshold array (node_count[i] entries)
 * \param value value[i] points to tree i's value array; the Python wrapper passes it in
 *              C order with shape (node_count[i], 1, n_classes)
 * \param n_node_samples n_node_samples[i] points to tree i's n_node_samples array
 *                       (node_count[i] entries)
 * \param impurity impurity[i] points to tree i's impurity array (node_count[i] entries)
 * \param out used to save the handle of the loaded model
 * \return NOTE(review): presumably 0 for success and -1 for failure, like the other
 *         loaders in this header -- confirm against API_BEGIN/API_END
 */
TREELITE_DLL int TreeliteLoadSKLearnRandomForestClassifier(
int n_estimators, int n_features, int n_classes, const int64_t* node_count,
const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
const double** threshold, const double** value, const int64_t** n_node_samples,
const double** impurity, ModelHandle* out);
/*!
 * \brief Load a scikit-learn gradient boosting regressor from per-tree node arrays
 *
 * Every double-pointer argument holds one pointer per tree; the pointer at index i refers
 * to an array describing tree i. The Python wrapper (treelite.sklearn.import_model_v2)
 * fills each of them from the attribute of the same name on the tree_ object of every
 * sub-estimator, walking the model's estimators_ row by row. The wrapper only accepts
 * models trained with init='zero' and multiplies the value arrays by the learning rate
 * before passing them in.
 * NOTE(review): the loader implementation is not visible from this header; confirm how
 * many per-tree entries it reads for a given n_estimators.
 * \param n_estimators number of estimators in the ensemble (scikit-learn's n_estimators)
 * \param n_features number of features, as reported by the scikit-learn model
 * \param node_count node_count[i] is the number of nodes in tree i
 * \param children_left children_left[i] points to tree i's children_left array
 *                      (node_count[i] entries)
 * \param children_right children_right[i] points to tree i's children_right array
 *                       (node_count[i] entries)
 * \param feature feature[i] points to tree i's feature array (node_count[i] entries)
 * \param threshold threshold[i] points to tree i's threshold array (node_count[i] entries)
 * \param value value[i] points to tree i's value array; the Python wrapper passes it in
 *              C order with shape (node_count[i], 1, 1)
 * \param n_node_samples n_node_samples[i] points to tree i's n_node_samples array
 *                       (node_count[i] entries)
 * \param impurity impurity[i] points to tree i's impurity array (node_count[i] entries)
 * \param out used to save the handle of the loaded model
 * \return NOTE(review): presumably 0 for success and -1 for failure, like the other
 *         loaders in this header -- confirm against API_BEGIN/API_END
 */
TREELITE_DLL int TreeliteLoadSKLearnGradientBoostingRegressor(
int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
const int64_t** children_right, const int64_t** feature, const double** threshold,
const double** value, const int64_t** n_node_samples, const double** impurity,
ModelHandle* out);
/*!
 * \brief Load a scikit-learn gradient boosting classifier from per-tree node arrays
 *
 * Every double-pointer argument holds one pointer per tree; the pointer at index i refers
 * to an array describing tree i. The Python wrapper (treelite.sklearn.import_model_v2)
 * fills each of them from the attribute of the same name on the tree_ object of every
 * sub-estimator, walking the model's estimators_ row by row. The wrapper only accepts
 * models trained with init='zero' and multiplies the value arrays by the learning rate
 * before passing them in.
 * NOTE(review): the loader implementation is not visible from this header; confirm how
 * many per-tree entries it reads for a given n_estimators and n_classes.
 * \param n_estimators number of estimators in the ensemble (scikit-learn's n_estimators)
 * \param n_features number of features, as reported by the scikit-learn model
 * \param n_classes number of classes (scikit-learn's n_classes_)
 * \param node_count node_count[i] is the number of nodes in tree i
 * \param children_left children_left[i] points to tree i's children_left array
 *                      (node_count[i] entries)
 * \param children_right children_right[i] points to tree i's children_right array
 *                       (node_count[i] entries)
 * \param feature feature[i] points to tree i's feature array (node_count[i] entries)
 * \param threshold threshold[i] points to tree i's threshold array (node_count[i] entries)
 * \param value value[i] points to tree i's value array; the Python wrapper passes it in
 *              C order with shape (node_count[i], 1, 1)
 * \param n_node_samples n_node_samples[i] points to tree i's n_node_samples array
 *                       (node_count[i] entries)
 * \param impurity impurity[i] points to tree i's impurity array (node_count[i] entries)
 * \param out used to save the handle of the loaded model
 * \return NOTE(review): presumably 0 for success and -1 for failure, like the other
 *         loaders in this header -- confirm against API_BEGIN/API_END
 */
TREELITE_DLL int TreeliteLoadSKLearnGradientBoostingClassifier(
int n_estimators, int n_features, int n_classes, const int64_t* node_count,
const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
const double** threshold, const double** value, const int64_t** n_node_samples,
const double** impurity, ModelHandle* out);

/*!
* \brief Query the number of trees in the model
* \param handle model to query
Expand Down
20 changes: 20 additions & 0 deletions include/treelite/frontend.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef TREELITE_FRONTEND_H_
#define TREELITE_FRONTEND_H_

#include <dmlc/logging.h>
#include <treelite/base.h>
#include <string>
#include <memory>
Expand Down Expand Up @@ -58,6 +59,25 @@ std::unique_ptr<treelite::Model> LoadXGBoostJSONModel(const char* filename);
*/
std::unique_ptr<treelite::Model> LoadXGBoostJSONModelString(const char* json_str, size_t length);

/*!
 * \brief Load a scikit-learn random forest regressor from per-tree node arrays
 *
 * The C API entry point TreeliteLoadSKLearnRandomForestRegressor forwards its arguments to
 * this function unchanged. Every double-pointer argument holds one pointer per tree; the
 * pointer at index i refers to an array for tree i. The arrays are named after the
 * scikit-learn tree_ attributes that the Python wrapper copies them from.
 * \param n_estimators number of estimators in the ensemble
 * \param n_features number of features
 * \param node_count node_count[i] is the number of nodes in tree i
 * \param children_left per-tree children_left arrays
 * \param children_right per-tree children_right arrays
 * \param feature per-tree feature arrays
 * \param threshold per-tree threshold arrays
 * \param value per-tree value arrays
 * \param n_node_samples per-tree n_node_samples arrays
 * \param impurity per-tree impurity arrays
 * \return loaded model
 */
std::unique_ptr<treelite::Model> LoadSKLearnRandomForestRegressor(
int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
const int64_t** children_right, const int64_t** feature, const double** threshold,
const double** value, const int64_t** n_node_samples, const double** impurity);
/*!
 * \brief Load a scikit-learn random forest classifier from per-tree node arrays
 *
 * The C API entry point TreeliteLoadSKLearnRandomForestClassifier forwards its arguments to
 * this function unchanged. Every double-pointer argument holds one pointer per tree; the
 * pointer at index i refers to an array for tree i. The arrays are named after the
 * scikit-learn tree_ attributes that the Python wrapper copies them from.
 * \param n_estimators number of estimators in the ensemble
 * \param n_features number of features
 * \param n_classes number of classes
 * \param node_count node_count[i] is the number of nodes in tree i
 * \param children_left per-tree children_left arrays
 * \param children_right per-tree children_right arrays
 * \param feature per-tree feature arrays
 * \param threshold per-tree threshold arrays
 * \param value per-tree value arrays
 * \param n_node_samples per-tree n_node_samples arrays
 * \param impurity per-tree impurity arrays
 * \return loaded model
 */
std::unique_ptr<treelite::Model> LoadSKLearnRandomForestClassifier(
int n_estimators, int n_features, int n_classes, const int64_t* node_count,
const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
const double** threshold, const double** value, const int64_t** n_node_samples,
const double** impurity);
/*!
 * \brief Load a scikit-learn gradient boosting regressor from per-tree node arrays
 *
 * The C API entry point TreeliteLoadSKLearnGradientBoostingRegressor forwards its arguments
 * to this function unchanged. Every double-pointer argument holds one pointer per tree; the
 * pointer at index i refers to an array for tree i. The arrays are named after the
 * scikit-learn tree_ attributes that the Python wrapper copies them from; that wrapper
 * scales the value arrays by the learning rate before passing them in.
 * \param n_estimators number of estimators in the ensemble
 * \param n_features number of features
 * \param node_count node_count[i] is the number of nodes in tree i
 * \param children_left per-tree children_left arrays
 * \param children_right per-tree children_right arrays
 * \param feature per-tree feature arrays
 * \param threshold per-tree threshold arrays
 * \param value per-tree value arrays
 * \param n_node_samples per-tree n_node_samples arrays
 * \param impurity per-tree impurity arrays
 * \return loaded model
 */
std::unique_ptr<treelite::Model> LoadSKLearnGradientBoostingRegressor(
int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
const int64_t** children_right, const int64_t** feature, const double** threshold,
const double** value, const int64_t** n_node_samples, const double** impurity);
/*!
 * \brief Load a scikit-learn gradient boosting classifier from per-tree node arrays
 *
 * The C API entry point TreeliteLoadSKLearnGradientBoostingClassifier forwards its arguments
 * to this function unchanged. Every double-pointer argument holds one pointer per tree; the
 * pointer at index i refers to an array for tree i. The arrays are named after the
 * scikit-learn tree_ attributes that the Python wrapper copies them from; that wrapper
 * scales the value arrays by the learning rate before passing them in.
 * \param n_estimators number of estimators in the ensemble
 * \param n_features number of features
 * \param n_classes number of classes
 * \param node_count node_count[i] is the number of nodes in tree i
 * \param children_left per-tree children_left arrays
 * \param children_right per-tree children_right arrays
 * \param feature per-tree feature arrays
 * \param threshold per-tree threshold arrays
 * \param value per-tree value arrays
 * \param n_node_samples per-tree n_node_samples arrays
 * \param impurity per-tree impurity arrays
 * \return loaded model
 */
std::unique_ptr<treelite::Model> LoadSKLearnGradientBoostingClassifier(
int n_estimators, int n_features, int n_classes, const int64_t* node_count,
const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
const double** threshold, const double** value, const int64_t** n_node_samples,
const double** impurity);

//--------------------------------------------------------------------------
// model builder interface: build trees incrementally
//--------------------------------------------------------------------------
Expand Down
138 changes: 137 additions & 1 deletion python/treelite/sklearn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
# coding: utf-8
"""Converter to ingest scikit-learn models into Treelite"""

from __future__ import absolute_import as _abs
hcho3 marked this conversation as resolved.
Show resolved Hide resolved

import ctypes
import numpy as np

from ..util import TreeliteError
from ..core import _LIB, c_array, _check_call
from ..frontend import Model
from .common import SKLConverterBase
from .gbm_regressor import SKLGBMRegressorMixin
from .gbm_classifier import SKLGBMClassifierMixin
Expand All @@ -20,6 +27,8 @@ def import_model(sklearn_model):
sklearn_model : object of type \
:py:class:`~sklearn.ensemble.RandomForestRegressor` / \
:py:class:`~sklearn.ensemble.RandomForestClassifier` / \
:py:class:`~sklearn.ensemble.ExtraTreesRegressor` / \
:py:class:`~sklearn.ensemble.ExtraTreesClassifier` / \
:py:class:`~sklearn.ensemble.GradientBoostingRegressor` / \
:py:class:`~sklearn.ensemble.GradientBoostingClassifier`
Python handle to scikit-learn model
Expand Down Expand Up @@ -97,4 +106,131 @@ class SKLRFMultiClassifierConverter(SKLRFMultiClassifierMixin, SKLConverterBase)
pass


__all__ = ['import_model']
class ArrayOfArrays:
    """
    Utility class to marshal a collection of arrays in order to pass to a C function

    All arrays in one collection share a single dtype (int64 or float64). Each added array
    is exposed as a ctypes pointer to its first element; ``as_c_array`` packs those pointers
    into a C array of pointers.

    Parameters
    ----------
    dtype : numpy dtype
        Element type of every array in the collection; ``np.int64`` or ``np.float64``

    Raises
    ------
    ValueError
        If ``dtype`` is neither ``np.int64`` nor ``np.float64``
    """
    def __init__(self, *, dtype):
        if dtype == np.int64:
            self.ptr_type = ctypes.POINTER(ctypes.c_int64)
        elif dtype == np.float64:
            self.ptr_type = ctypes.POINTER(ctypes.c_double)
        else:
            raise ValueError(f'dtype {dtype} is not supported')
        self.dtype = dtype
        self.collection = []
        # Backing arrays for the pointers in self.collection. Holding them here keeps their
        # buffers alive for as long as this object lives, independent of whether the
        # installed NumPy ties an array's lifetime to the pointer returned by data_as().
        self._arrays = []

    def add(self, array, *, expected_shape=None):
        """
        Add an array to the collection

        Parameters
        ----------
        array : :py:class:`numpy.ndarray`
            Array to add; its dtype must equal the dtype of the collection
        expected_shape : tuple, optional
            If given, the exact shape the array must have
        """
        assert array.dtype == self.dtype
        if expected_shape is not None:
            assert array.shape == expected_shape, \
                f'Expected shape: {expected_shape}, Got shape {array.shape}'
        # ascontiguousarray returns the input untouched when it is already C-contiguous and
        # copies only when it must. np.array(..., copy=False) is avoided because NumPy 2
        # raises ValueError there whenever a copy is required (e.g. for a strided view).
        v = np.ascontiguousarray(array, dtype=self.dtype)
        self._arrays.append(v)
        self.collection.append(v.ctypes.data_as(self.ptr_type))

    def as_c_array(self):
        """Prepare the collection to pass as an argument of a C function"""
        return c_array(self.ptr_type, self.collection)


def import_model_v2(sklearn_model):
    # pylint: disable=R0914,R0912
    """
    Load a tree ensemble model from a scikit-learn model object

    The node arrays of every fitted tree are collected and handed to the native library,
    which constructs the Treelite model from them.

    Parameters
    ----------
    sklearn_model : object of type \
                    :py:class:`~sklearn.ensemble.RandomForestRegressor` / \
                    :py:class:`~sklearn.ensemble.RandomForestClassifier` / \
                    :py:class:`~sklearn.ensemble.ExtraTreesRegressor` / \
                    :py:class:`~sklearn.ensemble.ExtraTreesClassifier` / \
                    :py:class:`~sklearn.ensemble.GradientBoostingRegressor` / \
                    :py:class:`~sklearn.ensemble.GradientBoostingClassifier`
        Python handle to scikit-learn model

    Returns
    -------
    model : :py:class:`~treelite.Model` object
        loaded model

    Raises
    ------
    TreeliteError
        If the object is not a scikit-learn model, if its class is not one of those listed
        above, or if it is a gradient boosting model not trained with ``init='zero'``
    """
    class_name = sklearn_model.__class__.__name__
    module_name = sklearn_model.__module__.split('.')[0]
    if module_name != 'sklearn':
        raise TreeliteError('Not a scikit-learn model')

    # Supported class -> (name of the native loader, whether that loader takes n_classes).
    # The extra-trees models go through the random forest loaders.
    loaders = {
        'RandomForestRegressor': ('TreeliteLoadSKLearnRandomForestRegressor', False),
        'ExtraTreesRegressor': ('TreeliteLoadSKLearnRandomForestRegressor', False),
        'RandomForestClassifier': ('TreeliteLoadSKLearnRandomForestClassifier', True),
        'ExtraTreesClassifier': ('TreeliteLoadSKLearnRandomForestClassifier', True),
        'GradientBoostingRegressor': ('TreeliteLoadSKLearnGradientBoostingRegressor', False),
        'GradientBoostingClassifier': ('TreeliteLoadSKLearnGradientBoostingClassifier', True),
    }
    if class_name not in loaders:
        raise TreeliteError(f'Not supported: {class_name}')
    loader_name, takes_n_classes = loaders[class_name]
    is_gradient_boosting = class_name.startswith('GradientBoosting')

    if is_gradient_boosting and sklearn_model.init != 'zero':
        raise TreeliteError("Gradient boosted trees must be trained with the option init='zero'")

    # Only the forest classifiers are expected to carry one value per class in each node;
    # every other supported model (including the gradient boosting classifier) is expected
    # to carry a single value per node.
    if class_name in ('RandomForestClassifier', 'ExtraTreesClassifier'):
        leaf_value_width = sklearn_model.n_classes_
    else:
        leaf_value_width = 1

    # For gradient boosted trees, each leaf output is shrunk by the learning rate
    learning_rate = sklearn_model.learning_rate if is_gradient_boosting else 1.0

    node_count = []
    children_left = ArrayOfArrays(dtype=np.int64)
    children_right = ArrayOfArrays(dtype=np.int64)
    feature = ArrayOfArrays(dtype=np.int64)
    threshold = ArrayOfArrays(dtype=np.float64)
    value = ArrayOfArrays(dtype=np.float64)
    n_node_samples = ArrayOfArrays(dtype=np.int64)
    impurity = ArrayOfArrays(dtype=np.float64)
    for estimator in sklearn_model.estimators_:
        # Each entry of a gradient boosting model's estimators_ is itself a sequence of
        # sub-estimators; forest models hold one estimator per entry.
        estimator_range = estimator if is_gradient_boosting else [estimator]
        for sub_estimator in estimator_range:
            tree = sub_estimator.tree_
            per_node = (tree.node_count,)
            node_count.append(tree.node_count)
            children_left.add(tree.children_left, expected_shape=per_node)
            children_right.add(tree.children_right, expected_shape=per_node)
            feature.add(tree.feature, expected_shape=per_node)
            threshold.add(tree.threshold, expected_shape=per_node)
            value.add(tree.value * learning_rate,
                      expected_shape=(tree.node_count, 1, leaf_value_width))
            n_node_samples.add(tree.n_node_samples, expected_shape=per_node)
            impurity.add(tree.impurity, expected_shape=per_node)

    # Count the fitted estimators instead of trusting the n_estimators hyperparameter.
    # With early stopping (n_iter_no_change) gradient boosting can fit fewer stages than
    # requested, and the arrays built above only contain entries for fitted trees.
    n_estimators = len(sklearn_model.estimators_)
    # n_features_ was deprecated in scikit-learn 1.0 and removed later; prefer its
    # replacement n_features_in_ and fall back only for older releases that lack it.
    n_features = getattr(sklearn_model, 'n_features_in_', None)
    if n_features is None:
        n_features = sklearn_model.n_features_

    handle = ctypes.c_void_p()
    args = [ctypes.c_int(n_estimators), ctypes.c_int(n_features)]
    if takes_n_classes:
        args.append(ctypes.c_int(sklearn_model.n_classes_))
    args.extend([
        c_array(ctypes.c_int64, node_count), children_left.as_c_array(),
        children_right.as_c_array(), feature.as_c_array(), threshold.as_c_array(),
        value.as_c_array(), n_node_samples.as_c_array(), impurity.as_c_array(),
        ctypes.byref(handle)])
    _check_call(getattr(_LIB, loader_name)(*args))
    return Model(handle)


__all__ = ['import_model', 'import_model_v2']
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ target_sources(objtreelite
compiler/pred_transform.h
frontend/builder.cc
frontend/lightgbm.cc
frontend/sklearn.cc
frontend/xgboost.cc
frontend/xgboost_json.cc
frontend/xgboost_util.cc
Expand Down
52 changes: 52 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,58 @@ int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, ModelH
API_END();
}

int TreeliteLoadSKLearnRandomForestRegressor(
    int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
    const int64_t** children_right, const int64_t** feature, const double** threshold,
    const double** value, const int64_t** n_node_samples, const double** impurity,
    ModelHandle* out) {
  API_BEGIN();
  // Build the model in the frontend, then release ownership into the caller's handle.
  auto loaded = frontend::LoadSKLearnRandomForestRegressor(
      n_estimators, n_features, node_count, children_left, children_right, feature, threshold,
      value, n_node_samples, impurity);
  *out = static_cast<ModelHandle>(loaded.release());
  API_END();
}

int TreeliteLoadSKLearnRandomForestClassifier(
    int n_estimators, int n_features, int n_classes, const int64_t* node_count,
    const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
    const double** threshold, const double** value, const int64_t** n_node_samples,
    const double** impurity, ModelHandle* out) {
  API_BEGIN();
  // Build the model in the frontend, then release ownership into the caller's handle.
  auto loaded = frontend::LoadSKLearnRandomForestClassifier(
      n_estimators, n_features, n_classes, node_count, children_left, children_right, feature,
      threshold, value, n_node_samples, impurity);
  *out = static_cast<ModelHandle>(loaded.release());
  API_END();
}

int TreeliteLoadSKLearnGradientBoostingRegressor(
    int n_estimators, int n_features, const int64_t* node_count, const int64_t** children_left,
    const int64_t** children_right, const int64_t** feature, const double** threshold,
    const double** value, const int64_t** n_node_samples, const double** impurity,
    ModelHandle* out) {
  API_BEGIN();
  // Build the model in the frontend, then release ownership into the caller's handle.
  auto loaded = frontend::LoadSKLearnGradientBoostingRegressor(
      n_estimators, n_features, node_count, children_left, children_right, feature, threshold,
      value, n_node_samples, impurity);
  *out = static_cast<ModelHandle>(loaded.release());
  API_END();
}

int TreeliteLoadSKLearnGradientBoostingClassifier(
    int n_estimators, int n_features, int n_classes, const int64_t* node_count,
    const int64_t** children_left, const int64_t** children_right, const int64_t** feature,
    const double** threshold, const double** value, const int64_t** n_node_samples,
    const double** impurity, ModelHandle* out) {
  API_BEGIN();
  // Build the model in the frontend, then release ownership into the caller's handle.
  auto loaded = frontend::LoadSKLearnGradientBoostingClassifier(
      n_estimators, n_features, n_classes, node_count, children_left, children_right, feature,
      threshold, value, n_node_samples, impurity);
  *out = static_cast<ModelHandle>(loaded.release());
  API_END();
}

int TreeliteFreeModel(ModelHandle handle) {
API_BEGIN();
delete static_cast<Model*>(handle);
Expand Down
Loading