diff --git a/conda/environments/cuml_dev_cuda10.1.yml b/conda/environments/cuml_dev_cuda10.1.yml index 98a80e0b05..00e4953fe8 100644 --- a/conda/environments/cuml_dev_cuda10.1.yml +++ b/conda/environments/cuml_dev_cuda10.1.yml @@ -21,7 +21,7 @@ dependencies: - faiss-proc=*=cuda - umap-learn - scikit-learn=0.23.1 -- treelite=0.93 +- treelite=1.0.0rc1 - pip - pip: - sphinx_markdown_tables diff --git a/conda/environments/cuml_dev_cuda10.2.yml b/conda/environments/cuml_dev_cuda10.2.yml index 891680c02e..852f22b06a 100644 --- a/conda/environments/cuml_dev_cuda10.2.yml +++ b/conda/environments/cuml_dev_cuda10.2.yml @@ -21,7 +21,7 @@ dependencies: - faiss-proc=*=cuda - umap-learn - scikit-learn=0.23.1 -- treelite=0.93 +- treelite=1.0.0rc1 - pip - pip: - sphinx_markdown_tables diff --git a/conda/environments/cuml_dev_cuda11.0.yml b/conda/environments/cuml_dev_cuda11.0.yml index dcc869650e..d8bbf2828b 100644 --- a/conda/environments/cuml_dev_cuda11.0.yml +++ b/conda/environments/cuml_dev_cuda11.0.yml @@ -21,7 +21,7 @@ dependencies: - faiss-proc=*=cuda - umap-learn - scikit-learn=0.23.1 -- treelite=0.93 +- treelite=1.0.0rc1 - pip - pip: - sphinx_markdown_tables diff --git a/conda/recipes/cuml/meta.yaml b/conda/recipes/cuml/meta.yaml index 740d238d1d..c849f87dd2 100644 --- a/conda/recipes/cuml/meta.yaml +++ b/conda/recipes/cuml/meta.yaml @@ -28,7 +28,7 @@ requirements: - setuptools - cython>=0.29,<0.30 - cmake>=3.14 - - treelite=0.93 + - treelite=1.0.0rc1 - cudf {{ minor_version }} - libcuml={{ version }} - libcumlprims {{ minor_version }} @@ -40,7 +40,7 @@ requirements: - dask-cudf {{ minor_version }} - libcuml={{ version }} - libcumlprims {{ minor_version }} - - treelite=0.93 + - treelite=1.0.0rc1 - cupy>7.1.0,<9.0.0a0 - nccl>=2.5 - ucx-py {{ minor_version }} diff --git a/conda/recipes/libcuml/meta.yaml b/conda/recipes/libcuml/meta.yaml index 012f911253..2a12a3a990 100644 --- a/conda/recipes/libcuml/meta.yaml +++ b/conda/recipes/libcuml/meta.yaml @@ -37,7 +37,7 @@ requirements: - ucx-py {{ minor_version }} - libcumlprims {{ minor_version }} - lapack - - treelite=0.93 + - treelite=1.0.0rc1 - faiss-proc=*=cuda - gtest=1.10.0 - libfaiss=1.6.3 @@ -47,7 +47,7 @@ requirements: - nccl>=2.5 - ucx-py {{ minor_version }} - {{ pin_compatible('cudatoolkit', max_pin='x.x') }} - - treelite=0.93 + - treelite=1.0.0rc1 - faiss-proc=*=cuda - libfaiss=1.6.3 diff --git a/cpp/cmake/Dependencies.cmake b/cpp/cmake/Dependencies.cmake index 860fbb859a..1c7ff00b8d 100644 --- a/cpp/cmake/Dependencies.cmake +++ b/cpp/cmake/Dependencies.cmake @@ -185,7 +185,7 @@ endif(BUILD_STATIC_FAISS) ############################################################################## # - treelite build ----------------------------------------------------------- -find_package(Treelite 0.93 REQUIRED) +find_package(Treelite 1.0.0 REQUIRED) ############################################################################## # - googletest build ----------------------------------------------------------- diff --git a/cpp/src/decisiontree/decisiontree_impl.cuh b/cpp/src/decisiontree/decisiontree_impl.cuh index 62dab85e39..d4f87e23c3 100644 --- a/cpp/src/decisiontree/decisiontree_impl.cuh +++ b/cpp/src/decisiontree/decisiontree_impl.cuh @@ -30,6 +30,7 @@ #include "levelalgo/metric.cuh" #include "memory.cuh" #include "quantile/quantile.cuh" +#include "treelite_util.h" namespace ML { @@ -140,7 +141,7 @@ struct Node_ID_info { template void build_treelite_tree(TreeBuilderHandle tree_builder, DecisionTree::TreeMetaDataNode *tree_ptr, - int num_output_group) { + int num_class) { int node_id = 0; TREELITE_CHECK(TreeliteTreeBuilderCreateNode(tree_builder, node_id)); @@ -174,28 +175,54 @@ void build_treelite_tree(TreeBuilderHandle tree_builder, TreeliteTreeBuilderCreateNode(tree_builder, node_id + 2)); // Set node from current level as numerical node. Children IDs known. + ValueHandle threshold; + TREELITE_CHECK(TreeliteTreeBuilderCreateValue( + &q_node.node.quesval, TreeliteType::value, &threshold)); TREELITE_CHECK(TreeliteTreeBuilderSetNumericalTestNode( tree_builder, q_node.unique_node_id, q_node.node.colid, - "<=", q_node.node.quesval, 1, node_id + 1, node_id + 2)); + "<=", threshold, 1, node_id + 1, node_id + 2)); + TREELITE_CHECK(TreeliteTreeBuilderDeleteValue(threshold)); node_id += 2; } else { - if (num_output_group == 1) { + if (num_class == 1) { + ValueHandle leaf_value; + if (std::is_same::value) { + // Integer output is not yet supported in Treelite codegen + float prediction = static_cast(q_node.node.prediction); + TREELITE_CHECK(TreeliteTreeBuilderCreateValue( + &prediction, TreeliteType::value, &leaf_value)); + } else { + TREELITE_CHECK(TreeliteTreeBuilderCreateValue( + &q_node.node.prediction, TreeliteType::value, &leaf_value)); + } TREELITE_CHECK(TreeliteTreeBuilderSetLeafNode( - tree_builder, q_node.unique_node_id, q_node.node.prediction)); + tree_builder, q_node.unique_node_id, leaf_value)); + TREELITE_CHECK(TreeliteTreeBuilderDeleteValue(leaf_value)); } else { - std::vector leaf_vector(num_output_group); - for (int j = 0; j < num_output_group; j++) { + std::vector leaf_vector(num_class); + std::vector leaf_vector_handle(num_class, nullptr); + for (int j = 0; j < num_class; j++) { if (q_node.node.prediction == j) { leaf_vector[j] = 1; } else { leaf_vector[j] = 0; } } + for (int j = 0; j < num_class; j++) { + TREELITE_CHECK(TreeliteTreeBuilderCreateValue( + &leaf_vector[j], TreeliteType::value, + &leaf_vector_handle[j])); + } TREELITE_CHECK(TreeliteTreeBuilderSetLeafVectorNode( - tree_builder, q_node.unique_node_id, leaf_vector.data(), - num_output_group)); + tree_builder, q_node.unique_node_id, leaf_vector_handle.data(), + num_class)); + for (int j = 0; j < num_class; j++) { + TREELITE_CHECK( + TreeliteTreeBuilderDeleteValue(leaf_vector_handle[j])); + } leaf_vector.clear(); + leaf_vector_handle.clear(); } } } @@ -521,17 +548,16 @@ template class DecisionTreeRegressor; template void build_treelite_tree( TreeBuilderHandle tree_builder, - DecisionTree::TreeMetaDataNode *tree_ptr, int num_output_group); + DecisionTree::TreeMetaDataNode *tree_ptr, int num_class); template void build_treelite_tree( TreeBuilderHandle tree_builder, - DecisionTree::TreeMetaDataNode *tree_ptr, int num_output_group); + DecisionTree::TreeMetaDataNode *tree_ptr, int num_class); template void build_treelite_tree( TreeBuilderHandle tree_builder, - DecisionTree::TreeMetaDataNode *tree_ptr, int num_output_group); + DecisionTree::TreeMetaDataNode *tree_ptr, int num_class); template void build_treelite_tree( TreeBuilderHandle tree_builder, - DecisionTree::TreeMetaDataNode *tree_ptr, - int num_output_group); + DecisionTree::TreeMetaDataNode *tree_ptr, int num_class); } //End namespace DecisionTree } //End namespace ML diff --git a/cpp/src/decisiontree/decisiontree_impl.h b/cpp/src/decisiontree/decisiontree_impl.h index 4ba3266130..55d927ef7f 100644 --- a/cpp/src/decisiontree/decisiontree_impl.h +++ b/cpp/src/decisiontree/decisiontree_impl.h @@ -57,7 +57,7 @@ std::string dump_node_as_json( template void build_treelite_tree(TreeBuilderHandle tree_builder, DecisionTree::TreeMetaDataNode *tree_ptr, - int num_output_group); + int num_class); struct DataInfo { unsigned int NLocalrows; diff --git a/cpp/src/decisiontree/treelite_util.h b/cpp/src/decisiontree/treelite_util.h new file mode 100644 index 0000000000..d47aa28acd --- /dev/null +++ b/cpp/src/decisiontree/treelite_util.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace ML { +namespace DecisionTree { + +template +class TreeliteType; + +template <> +class TreeliteType { + public: + static constexpr const char* value = "float32"; +}; + +template <> +class TreeliteType { + public: + static constexpr const char* value = "float64"; +}; + +template <> +class TreeliteType { + public: + static constexpr const char* value = "uint32"; +}; + +template <> +class TreeliteType { + public: + static_assert(sizeof(int) == sizeof(uint32_t), "int must be 32-bit"); + static constexpr const char* value = "uint32"; +}; + +} //End namespace DecisionTree + +} //End namespace ML diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu index 58bce9e885..8fa3c6d77e 100644 --- a/cpp/src/fil/fil.cu +++ b/cpp/src/fil/fil.cu @@ -397,20 +397,13 @@ void check_params(const forest_params_t* params, bool dense) { ASSERT(params->blocks_per_sm >= 0, "blocks_per_sm must be nonnegative"); } -int tree_root(const tl::Tree& tree) { +template +int tree_root(const tl::Tree& tree) { return 0; // Treelite format assumes that the root is 0 } -int max_depth_helper(const tl::Tree& tree, int node_id, int limit) { - if (tree.IsLeaf(node_id)) return 0; - ASSERT(limit > 0, - "recursion depth limit reached, might be a cycle in the tree"); - return 1 + - std::max(max_depth_helper(tree, tree.LeftChild(node_id), limit - 1), - max_depth_helper(tree, tree.RightChild(node_id), limit - 1)); -} - -inline int max_depth(const tl::Tree& tree) { +template +inline int max_depth(const tl::Tree& tree) { // trees of this depth aren't used, so it most likely means bad input data, // e.g. cycles in the forest const int DEPTH_LIMIT = 500; @@ -437,9 +430,12 @@ inline int max_depth(const tl::Tree& tree) { return max_depth; } -int max_depth(const tl::Model& model) { +template +int max_depth(const tl::ModelImpl& model) { int depth = 0; - for (const auto& tree : model.trees) depth = std::max(depth, max_depth(tree)); + for (const auto& tree : model.trees) { + depth = std::max(depth, max_depth(tree)); + } return depth; } @@ -475,16 +471,17 @@ inline void adjust_threshold(float* pthreshold, int* tl_left, int* tl_right, /** if the vector consists of zeros and a single one, return the position for the one (assumed class label). Else, asserts false. If the vector contains a NAN, asserts false */ -int find_class_label_from_one_hot(tl::tl_float* vector, int len) { +template +int find_class_label_from_one_hot(L* vector, int len) { bool found_label = false; int out; for (int i = 0; i < len; ++i) { - if (vector[i] == 1.0f) { + if (vector[i] == static_cast(1.0)) { ASSERT(!found_label, "label vector contains multiple 1.0f"); out = i; found_label = true; } else { - ASSERT(vector[i] == 0.0f, + ASSERT(vector[i] == static_cast(0.0), "label vector contains values other than 0.0 and 1.0"); } } @@ -492,8 +489,8 @@ int find_class_label_from_one_hot(tl::tl_float* vector, int len) { return out; } -template -void tl2fil_leaf_payload(fil_node_t* fil_node, const tl::Tree& tl_tree, +template +void tl2fil_leaf_payload(fil_node_t* fil_node, const tl::Tree& tl_tree, int tl_node_id, const forest_params_t& forest_params) { auto vec = tl_tree.LeafVector(tl_node_id); switch (forest_params.leaf_algo) { @@ -504,7 +501,7 @@ void tl2fil_leaf_payload(fil_node_t* fil_node, const tl::Tree& tl_tree, break; case leaf_algo_t::FLOAT_UNARY_BINARY: case leaf_algo_t::GROVE_PER_CLASS: - fil_node->val.f = tl_tree.LeafValue(tl_node_id); + fil_node->val.f = static_cast(tl_tree.LeafValue(tl_node_id)); ASSERT(!tl_tree.HasLeafVector(tl_node_id), "some but not all treelite leaves have leaf_vector()"); break; @@ -513,8 +510,9 @@ void tl2fil_leaf_payload(fil_node_t* fil_node, const tl::Tree& tl_tree, }; } +template void node2fil_dense(std::vector* pnodes, int root, int cur, - const tl::Tree& tree, int node_id, + const tl::Tree& tree, int node_id, const forest_params_t& forest_params) { if (tree.IsLeaf(node_id)) { (*pnodes)[root + cur] = dense_node(val_t{.f = NAN}, NAN, 0, false, true); @@ -527,7 +525,7 @@ void node2fil_dense(std::vector* pnodes, int root, int cur, "only numerical split nodes are supported"); int tl_left = tree.LeftChild(node_id), tl_right = tree.RightChild(node_id); bool default_left = tree.DefaultLeft(node_id); - float threshold = tree.Threshold(node_id); + float threshold = static_cast(tree.Threshold(node_id)); adjust_threshold(&threshold, &tl_left, &tl_right, &default_left, tree.ComparisonOp(node_id)); (*pnodes)[root + cur] = dense_node( @@ -537,14 +535,15 @@ void node2fil_dense(std::vector* pnodes, int root, int cur, node2fil_dense(pnodes, root, left + 1, tree, tl_right, forest_params); } +template void tree2fil_dense(std::vector* pnodes, int root, - const tl::Tree& tree, + const tl::Tree& tree, const forest_params_t& forest_params) { node2fil_dense(pnodes, root, 0, tree, tree_root(tree), forest_params); } -template -int tree2fil_sparse(std::vector* pnodes, const tl::Tree& tree, +template +int tree2fil_sparse(std::vector* pnodes, const tl::Tree& tree, const forest_params_t& forest_params) { typedef std::pair pair_t; std::stack stack; @@ -566,7 +565,7 @@ int tree2fil_sparse(std::vector* pnodes, const tl::Tree& tree, int tl_left = tree.LeftChild(node_id), tl_right = tree.RightChild(node_id); bool default_left = tree.DefaultLeft(node_id); - float threshold = tree.Threshold(node_id); + float threshold = static_cast(tree.Threshold(node_id)); adjust_threshold(&threshold, &tl_left, &tl_right, &default_left, tree.ComparisonOp(node_id)); @@ -595,8 +594,9 @@ int tree2fil_sparse(std::vector* pnodes, const tl::Tree& tree, return root; } -size_t tl_leaf_vector_size(const tl::Model& model) { - const tl::Tree& tree = model.trees[0]; +template +size_t tl_leaf_vector_size(const tl::ModelImpl& model) { + const tl::Tree& tree = model.trees[0]; int node_key; for (node_key = tree_root(tree); !tree.IsLeaf(node_key); node_key = tree.RightChild(node_key)) @@ -607,7 +607,8 @@ size_t tl_leaf_vector_size(const tl::Model& model) { // tl2fil_common is the part of conversion from a treelite model // common for dense and sparse forests -void tl2fil_common(forest_params_t* params, const tl::Model& model, +template +void tl2fil_common(forest_params_t* params, const tl::ModelImpl& model, const treelite_params_t* tl_params) { // fill in forest-indendent params params->algo = tl_params->algo; @@ -622,7 +623,7 @@ void tl2fil_common(forest_params_t* params, const tl::Model& model, size_t leaf_vec_size = tl_leaf_vector_size(model); std::string pred_transform(param.pred_transform); if (leaf_vec_size > 0) { - ASSERT(leaf_vec_size == model.num_output_group, + ASSERT(leaf_vec_size == model.task_param.num_class, "treelite model inconsistent"); params->num_classes = leaf_vec_size; params->leaf_algo = leaf_algo_t::CATEGORICAL_LEAF; @@ -636,8 +637,8 @@ void tl2fil_common(forest_params_t* params, const tl::Model& model, "are supported for multi-class models"); } else { - if (model.num_output_group > 1) { - params->num_classes = model.num_output_group; + if (model.task_param.num_class > 1) { + params->num_classes = static_cast(model.task_param.num_class); ASSERT(tl_params->output_class, "output_class==true is required for multi-class models"); ASSERT(pred_transform == "sigmoid" || pred_transform == "identity" || @@ -672,7 +673,7 @@ void tl2fil_common(forest_params_t* params, const tl::Model& model, params->output = output_t(params->output | output_t::CLASS); } // "random forest" in treelite means tree output averaging - if (model.random_forest_flag) { + if (model.average_tree_output) { params->output = output_t(params->output | output_t::AVG); } if (std::string(param.pred_transform) == "sigmoid") { @@ -684,8 +685,10 @@ void tl2fil_common(forest_params_t* params, const tl::Model& model, // uses treelite model with additional tl_params to initialize FIL params // and dense nodes (stored in *pnodes) +template void tl2fil_dense(std::vector* pnodes, forest_params_t* params, - const tl::Model& model, const treelite_params_t* tl_params) { + const tl::ModelImpl& model, + const treelite_params_t* tl_params) { tl2fil_common(params, model, tl_params); // convert the nodes @@ -699,24 +702,27 @@ void tl2fil_dense(std::vector* pnodes, forest_params_t* params, template struct tl2fil_sparse_check_t { - static void check(const tl::Model& model) { + template + static void check(const tl::ModelImpl& model) { ASSERT(false, "internal error: " - "only a specialization of this tempalte should be used"); + "only a specialization of this template should be used"); } }; template <> struct tl2fil_sparse_check_t { // no extra check for 16-byte sparse nodes - static void check(const tl::Model& model) {} + template + static void check(const tl::ModelImpl& model) {} }; template <> struct tl2fil_sparse_check_t { static const int MAX_FEATURES = 1 << sparse_node8::FID_NUM_BITS; static const int MAX_TREE_NODES = (1 << sparse_node8::LEFT_NUM_BITS) - 1; - static void check(const tl::Model& model) { + template + static void check(const tl::ModelImpl& model) { // check the number of features int num_features = model.num_feature; ASSERT(num_features <= MAX_FEATURES, @@ -725,7 +731,7 @@ struct tl2fil_sparse_check_t { num_features, MAX_FEATURES); // check the number of tree nodes - const std::vector& trees = model.trees; + const std::vector>& trees = model.trees; for (int i = 0; i < trees.size(); ++i) { int num_nodes = trees[i].num_nodes; ASSERT(num_nodes <= MAX_TREE_NODES, @@ -738,9 +744,9 @@ struct tl2fil_sparse_check_t { // uses treelite model with additional tl_params to initialize FIL params, // trees (stored in *ptrees) and sparse nodes (stored in *pnodes) -template +template void tl2fil_sparse(std::vector* ptrees, std::vector* pnodes, - forest_params_t* params, const tl::Model& model, + forest_params_t* params, const tl::ModelImpl& model, const treelite_params_t* tl_params) { tl2fil_common(params, model, tl_params); tl2fil_sparse_check_t::check(model); @@ -781,19 +787,34 @@ template void init_sparse(const raft::handle_t& h, forest_t* pf, const sparse_node8* nodes, const forest_params_t* params); +template void from_treelite(const raft::handle_t& handle, forest_t* pforest, - ModelHandle model, const treelite_params_t* tl_params) { + const tl::ModelImpl& model, + const treelite_params_t* tl_params) { + // Invariants on threshold and leaf types + static_assert(std::is_same::value || std::is_same::value, + "Model must contain float32 or float64 thresholds for splits"); + ASSERT((std::is_same::value || std::is_same::value), + "Models with integer leaf output are not yet supported"); + // Display appropriate warnings when float64 values are being casted into + // float32, as FIL only supports inferencing with float32 for the time being + if (std::is_same::value || std::is_same::value) { + CUML_LOG_WARN( + "Casting all thresholds and leaf values to float32, as FIL currently " + "doesn't support inferencing models with float64 values. " + "This may lead to predictions with reduced accuracy."); + } + storage_type_t storage_type = tl_params->storage_type; // build dense trees by default - const tl::Model& model_ref = *(tl::Model*)model; if (storage_type == storage_type_t::AUTO) { if (tl_params->algo == algo_t::ALGO_AUTO || tl_params->algo == algo_t::NAIVE) { - int depth = max_depth(model_ref); + int depth = max_depth(model); // max 2**25 dense nodes, 256 MiB dense model size const int LOG2_MAX_DENSE_NODES = 25; int log2_num_dense_nodes = - depth + 1 + int(ceil(std::log2(model_ref.trees.size()))); + depth + 1 + int(ceil(std::log2(model.trees.size()))); storage_type = log2_num_dense_nodes > LOG2_MAX_DENSE_NODES ? storage_type_t::SPARSE : storage_type_t::DENSE; @@ -807,7 +828,7 @@ void from_treelite(const raft::handle_t& handle, forest_t* pforest, switch (storage_type) { case storage_type_t::DENSE: { std::vector nodes; - tl2fil_dense(&nodes, ¶ms, model_ref, tl_params); + tl2fil_dense(&nodes, ¶ms, model, tl_params); init_dense(handle, pforest, nodes.data(), ¶ms); // sync is necessary as nodes is used in init_dense(), // but destructed at the end of this function @@ -817,7 +838,7 @@ void from_treelite(const raft::handle_t& handle, forest_t* pforest, case storage_type_t::SPARSE: { std::vector trees; std::vector nodes; - tl2fil_sparse(&trees, &nodes, ¶ms, model_ref, tl_params); + tl2fil_sparse(&trees, &nodes, ¶ms, model, tl_params); init_sparse(handle, pforest, trees.data(), nodes.data(), ¶ms); CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); break; @@ -825,7 +846,7 @@ void from_treelite(const raft::handle_t& handle, forest_t* pforest, case storage_type_t::SPARSE8: { std::vector trees; std::vector nodes; - tl2fil_sparse(&trees, &nodes, ¶ms, model_ref, tl_params); + tl2fil_sparse(&trees, &nodes, ¶ms, model, tl_params); init_sparse(handle, pforest, trees.data(), nodes.data(), ¶ms); CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); break; @@ -835,6 +856,15 @@ void from_treelite(const raft::handle_t& handle, forest_t* pforest, } } +void from_treelite(const raft::handle_t& handle, forest_t* pforest, + ModelHandle model, const treelite_params_t* tl_params) { + const tl::Model& model_ref = *(tl::Model*)model; + model_ref.Dispatch([&handle, pforest, tl_params](const auto& model_inner) { + // model_inner is of the concrete type tl::ModelImpl + from_treelite(handle, pforest, model_inner, tl_params); + }); +} + void free(const raft::handle_t& h, forest_t f) { f->free(h); delete f; diff --git a/cpp/src/randomforest/randomforest.cu b/cpp/src/randomforest/randomforest.cu index 0cac558bd4..1a2b7a08c8 100644 --- a/cpp/src/randomforest/randomforest.cu +++ b/cpp/src/randomforest/randomforest.cu @@ -311,11 +311,19 @@ void build_treelite_forest(ModelHandle* model, // The value should be set to 0 if the model is gradient boosted trees. int random_forest_flag = 1; ModelBuilderHandle model_builder; - // num_output_group is 1 for binary classification and regression - // num_output_group is #class for multiclass classification which is the same as task_category - int num_output_group = task_category > 2 ? task_category : 1; + // num_class is 1 for binary classification and regression + // num_class is #class for multiclass classification which is the same as task_category + int num_class = task_category > 2 ? task_category : 1; + + const char* leaf_type = DecisionTree::TreeliteType::value; + if (std::is_same::value) { + // Treelite codegen doesn't yet support integer leaf output + leaf_type = DecisionTree::TreeliteType::value; + } + TREELITE_CHECK(TreeliteCreateModelBuilder( - num_features, num_output_group, random_forest_flag, &model_builder)); + num_features, num_class, random_forest_flag, + DecisionTree::TreeliteType::value, leaf_type, &model_builder)); if (task_category > 2) { // Multi-class classification @@ -327,10 +335,11 @@ void build_treelite_forest(ModelHandle* model, DecisionTree::TreeMetaDataNode* tree_ptr = &forest->trees[i]; TreeBuilderHandle tree_builder; - TREELITE_CHECK(TreeliteCreateTreeBuilder(&tree_builder)); + TREELITE_CHECK(TreeliteCreateTreeBuilder( + DecisionTree::TreeliteType::value, leaf_type, &tree_builder)); if (tree_ptr->sparsetree.size() != 0) { DecisionTree::build_treelite_tree(tree_builder, tree_ptr, - num_output_group); + num_class); // The third argument -1 means append to the end of the tree list. TREELITE_CHECK( @@ -349,8 +358,9 @@ void build_treelite_forest(ModelHandle* model, * @param[in] tree_from_concatenated_forest: Tree info from the concatenated forest. * @param[in] tree_from_individual_forest: Tree info from the forest present in each worker. */ -void compare_trees(tl::Tree& tree_from_concatenated_forest, - tl::Tree& tree_from_individual_forest) { +template +void compare_trees(tl::Tree& tree_from_concatenated_forest, + tl::Tree& tree_from_individual_forest) { ASSERT(tree_from_concatenated_forest.num_nodes == tree_from_individual_forest.num_nodes, "Error! Mismatch the number of nodes present in a tree in the " @@ -416,24 +426,37 @@ void compare_concat_forest_to_subforests( for (int forest_idx = 0; forest_idx < treelite_handles.size(); forest_idx++) { tl::Model& model = *(tl::Model*)(treelite_handles[forest_idx]); + ASSERT( + concat_model.GetThresholdType() == model.GetThresholdType(), + "Error! Concatenated forest does not have the same threshold type as " + "the individual forests"); + ASSERT( + concat_model.GetLeafOutputType() == model.GetLeafOutputType(), + "Error! Concatenated forest does not have the same leaf output type as " + "the individual forests"); ASSERT( concat_model.num_feature == model.num_feature, "Error! number of features mismatch between concatenated forest and the" - " individual forests "); - ASSERT(concat_model.num_output_group == model.num_output_group, - "Error! number of output group mismatch between concatenated forest " - "and the" - " individual forests "); - ASSERT(concat_model.random_forest_flag == model.random_forest_flag, - "Error! random forest flag value mismatch between concatenated " - "forest and the" - " individual forests "); - - for (int indiv_trees = 0; indiv_trees < model.trees.size(); indiv_trees++) { - compare_trees(concat_model.trees[concat_mod_tree_num + indiv_trees], - model.trees[indiv_trees]); - } - concat_mod_tree_num = concat_mod_tree_num + model.trees.size(); + " individual forests"); + ASSERT(concat_model.task_param.num_class == model.task_param.num_class, + "Error! number of classes mismatch between concatenated forest " + "and the individual forests "); + ASSERT(concat_model.average_tree_output == model.average_tree_output, + "Error! average_tree_output flag value mismatch between " + "concatenated forest and the individual forests"); + + model.Dispatch([&concat_mod_tree_num, &concat_model](auto& model_inner) { + // model_inner is of the concrete type tl::ModelImpl + using model_type = std::remove_reference_t; + auto& concat_model_inner = dynamic_cast(concat_model); + for (int indiv_trees = 0; indiv_trees < model_inner.trees.size(); + indiv_trees++) { + compare_trees( + concat_model_inner.trees[concat_mod_tree_num + indiv_trees], + model_inner.trees[indiv_trees]); + } + concat_mod_tree_num = concat_mod_tree_num + model_inner.trees.size(); + }); } } @@ -447,17 +470,28 @@ void compare_concat_forest_to_subforests( */ ModelHandle concatenate_trees(std::vector treelite_handles) { tl::Model& first_model = *(tl::Model*)treelite_handles[0]; - tl::Model* concat_model = new tl::Model; - for (int forest_idx = 0; forest_idx < treelite_handles.size(); forest_idx++) { - tl::Model& model = *(tl::Model*)treelite_handles[forest_idx]; - for (const tl::Tree& tree : model.trees) { - concat_model->trees.push_back(tree.Clone()); - } - } - concat_model->num_feature = first_model.num_feature; - concat_model->num_output_group = first_model.num_output_group; - concat_model->random_forest_flag = first_model.random_forest_flag; - concat_model->param = first_model.param; + tl::Model* concat_model = + first_model.Dispatch([&treelite_handles](auto& first_model_inner) { + // first_model_inner is of the concrete type tl::ModelImpl + using model_type = std::remove_reference_t; + auto* concat_model = dynamic_cast( + tl::Model::Create(first_model_inner.GetThresholdType(), + first_model_inner.GetLeafOutputType()) + .release()); + for (int forest_idx = 0; forest_idx < treelite_handles.size(); + forest_idx++) { + tl::Model& model = *(tl::Model*)treelite_handles[forest_idx]; + auto& model_inner = dynamic_cast(model); + for (const auto& tree : model_inner.trees) { + concat_model->trees.push_back(tree.Clone()); + } + } + concat_model->num_feature = first_model_inner.num_feature; + concat_model->task_param = first_model_inner.task_param; + concat_model->average_tree_output = first_model_inner.average_tree_output; + concat_model->param = first_model_inner.param; + return static_cast(concat_model); + }); return concat_model; } diff --git a/cpp/src/randomforest/randomforest_impl.cuh b/cpp/src/randomforest/randomforest_impl.cuh index c9e9f8a25d..e86719935b 100644 --- a/cpp/src/randomforest/randomforest_impl.cuh +++ b/cpp/src/randomforest/randomforest_impl.cuh @@ -19,6 +19,7 @@ #endif #include #include +#include #include #include #include diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu index 59ea3eab39..a43dc8e6e2 100644 --- a/cpp/test/sg/fil_test.cu +++ b/cpp/test/sg/fil_test.cu @@ -500,13 +500,14 @@ class TreeliteFilTest : public BaseFilTest { case fil::leaf_algo_t::FLOAT_UNARY_BINARY: case fil::leaf_algo_t::GROVE_PER_CLASS: // default is fil::FLOAT_UNARY_BINARY - builder->SetLeafNode(key, dense_node.base_node::output().f); + builder->SetLeafNode( + key, tlf::Value::Create(dense_node.base_node::output().f)); break; case fil::leaf_algo_t::CATEGORICAL_LEAF: - std::vector vec(ps.num_classes); + std::vector vec(ps.num_classes); for (int i = 0; i < ps.num_classes; ++i) { - vec[i] = - i == dense_node.base_node::output().idx ? 1.0f : 0.0f; + vec[i] = tlf::Value::Create( + i == dense_node.base_node::output().idx ? 1.0f : 0.0f); } builder->SetLeafVectorNode(key, vec); } @@ -537,8 +538,9 @@ class TreeliteFilTest : public BaseFilTest { } int left_key = node_to_treelite(builder, pkey, root, left); int right_key = node_to_treelite(builder, pkey, root, right); - builder->SetNumericalTestNode(key, dense_node.fid(), ps.op, threshold, - default_left, left_key, right_key); + builder->SetNumericalTestNode(key, dense_node.fid(), ps.op, + tlf::Value::Create(threshold), default_left, + left_key, right_key); } return key; } @@ -549,7 +551,8 @@ class TreeliteFilTest : public BaseFilTest { int treelite_num_classes = ps.leaf_algo == fil::leaf_algo_t::FLOAT_UNARY_BINARY ? 1 : ps.num_classes; std::unique_ptr model_builder(new tlf::ModelBuilder( - ps.num_cols, treelite_num_classes, random_forest_flag)); + ps.num_cols, treelite_num_classes, random_forest_flag, + tl::TypeInfo::kFloat32, tl::TypeInfo::kFloat32)); // prediction transform if ((ps.output & fil::output_t::SIGMOID) != 0) { @@ -570,7 +573,8 @@ class TreeliteFilTest : public BaseFilTest { // build the trees for (int i_tree = 0; i_tree < ps.num_trees; ++i_tree) { - tlf::TreeBuilder* tree_builder = new tlf::TreeBuilder(); + tlf::TreeBuilder* tree_builder = + new tlf::TreeBuilder(tl::TypeInfo::kFloat32, tl::TypeInfo::kFloat32); int key_counter = 0; int root = i_tree * tree_num_nodes(); int root_key = node_to_treelite(tree_builder, &key_counter, root, root); @@ -580,8 +584,7 @@ class TreeliteFilTest : public BaseFilTest { } // commit the model - std::unique_ptr model(new tl::Model); - model_builder->CommitModel(model.get()); + std::unique_ptr model = model_builder->CommitModel(); // init FIL forest with the model fil::treelite_params_t params; diff --git a/cpp/test/sg/rf_treelite_test.cu b/cpp/test/sg/rf_treelite_test.cu index 5d3b771502..b0dded3e88 100644 --- a/cpp/test/sg/rf_treelite_test.cu +++ b/cpp/test/sg/rf_treelite_test.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include #include @@ -122,26 +123,25 @@ class RfTreeliteTestCommon : public ::testing::TestWithParam> { TREELITE_CHECK( TreelitePredictorLoad(lib_path.c_str(), worker_thread, &predictor)); - DenseBatchHandle dense_batch; - // Current RF dosen't seem to support missing value, put NaN to be safe. - float missing_value = std::numeric_limits::quiet_NaN(); - TREELITE_CHECK(TreeliteAssembleDenseBatch( - inference_data_h.data(), missing_value, params.n_inference_rows, - params.n_cols, &dense_batch)); + DMatrixHandle dmat; + // Current RF doesn't seem to support missing value, put NaN to be safe. + T missing_value = std::numeric_limits::quiet_NaN(); + TREELITE_CHECK(TreeliteDMatrixCreateFromMat( + inference_data_h.data(), ML::DecisionTree::TreeliteType::value, + params.n_inference_rows, params.n_cols, &missing_value, &dmat)); // Use dense batch so batch_sparse is 0. // pred_margin = true means to produce raw margins rather than transformed probability. - int batch_sparse = 0; bool pred_margin = false; // Allocate larger array for treelite predicted label with using multi-class classification to avoid seg faults. // Altough later we only use first params.n_inference_rows elements. size_t treelite_predicted_labels_size; TREELITE_CHECK(TreelitePredictorPredictBatch( - predictor, dense_batch, batch_sparse, verbose, pred_margin, - treelite_predicted_labels.data(), &treelite_predicted_labels_size)); + predictor, dmat, verbose, pred_margin, treelite_predicted_labels.data(), + &treelite_predicted_labels_size)); - TREELITE_CHECK(TreeliteDeleteDenseBatch(dense_batch)); + TREELITE_CHECK(TreeliteDMatrixFree(dmat)); TREELITE_CHECK(TreelitePredictorFree(predictor)); TREELITE_CHECK(TreeliteFreeModel(concatenated_forest_handle)); TREELITE_CHECK(TreeliteFreeModel(treelite_indiv_handles[0])); diff --git a/python/cuml/benchmark/algorithms.py b/python/cuml/benchmark/algorithms.py index e562ee1642..b321141c2b 100644 --- a/python/cuml/benchmark/algorithms.py +++ b/python/cuml/benchmark/algorithms.py @@ -193,7 +193,7 @@ def _labels_to_int_hook(data): def _treelite_format_hook(data): """Helper function converting data into treelite format""" - return treelite_runtime.Batch.from_npy2d(data[0]), data[1] + return treelite_runtime.DMatrix(data[0]), data[1] def all_algorithms(): diff --git a/python/cuml/ensemble/randomforest_shared.pyx b/python/cuml/ensemble/randomforest_shared.pyx index b58b2a442f..166289d0af 100644 --- a/python/cuml/ensemble/randomforest_shared.pyx +++ b/python/cuml/ensemble/randomforest_shared.pyx @@ -20,6 +20,7 @@ from libcpp.vector cimport vector from cython.operator cimport dereference as deref, preincrement as inc from cpython.object cimport PyObject from libc.stdint cimport uintptr_t +from libcpp.memory cimport unique_ptr from typing import Tuple, Dict, List, Union import numpy as np @@ -31,7 +32,8 @@ cdef extern from "treelite/tree.h" namespace "treelite": size_t nitem cdef cppclass Model: vector[PyBufferFrame] GetPyBuffer() except + - void InitFromPyBuffer(vector[PyBufferFrame] frames) except + + @staticmethod + unique_ptr[Model] CreateFromPyBuffer(vector[PyBufferFrame]) except + cdef extern from "Python.h": Py_buffer* PyMemoryView_GET_BUFFER(PyObject* mview) @@ -78,9 +80,7 @@ cdef list _get_frames(ModelHandle model): for v in (model).GetPyBuffer()] cdef ModelHandle _init_from_frames(vector[PyBufferFrame] frames) except *: - cdef Model* model_obj = new Model() - model_obj.InitFromPyBuffer(frames) - return model_obj + return Model.CreateFromPyBuffer(frames).release() def get_frames(model: uintptr_t) -> List[memoryview]: diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index c763c91baa..b85b741921 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -45,19 +45,14 @@ cdef extern from "treelite/c_api.h": ctypedef void* ModelHandle cdef int TreeliteLoadXGBoostModel(const char* filename, ModelHandle* out) except + - cdef int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, - size_t len, - ModelHandle* out) \ - except + + cdef int TreeliteLoadXGBoostJSON(const char* filename, + ModelHandle* out) except + cdef int TreeliteFreeModel(ModelHandle handle) except + cdef int TreeliteQueryNumTree(ModelHandle handle, size_t* out) except + cdef int TreeliteQueryNumFeature(ModelHandle handle, size_t* out) except + - cdef int TreeliteQueryNumOutputGroups(ModelHandle handle, - size_t* out) except + + cdef int TreeliteQueryNumClass(ModelHandle handle, size_t* out) except + cdef int TreeliteLoadLightGBMModel(const char* filename, ModelHandle* out) except + - cdef int TreeliteLoadProtobufModel(const char* filename, - ModelHandle* out) except + cdef const char* TreeliteGetLastError() @@ -133,6 +128,11 @@ cdef class TreeliteModel(): if res < 0: err = TreeliteGetLastError() raise RuntimeError("Failed to load %s (%s)" % (filename, err)) + elif model_type == "xgboost_json": + res = TreeliteLoadXGBoostJSON(filename_bytes, &handle) + if res < 0: + err = TreeliteGetLastError() + raise RuntimeError("Failed to load %s (%s)" % (filename, err)) elif model_type == "lightgbm": logger.warn("Treelite currently does not support float64 model" " parameters. Accuracy may degrade slightly relative" @@ -195,13 +195,13 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil": cdef forest_t from_treelite(handle_t& handle, forest_t*, ModelHandle, - treelite_params_t*) + treelite_params_t*) except + cdef class ForestInference_impl(): cdef object handle cdef forest_t forest_data - cdef size_t num_output_groups + cdef size_t num_class cdef bool output_class def __cinit__(self, @@ -284,10 +284,10 @@ cdef class ForestInference_impl(): if preds is None: shape = (n_rows, ) if predict_proba: - if self.num_output_groups <= 2: + if self.num_class <= 2: shape += (2,) else: - shape += (self.num_output_groups,) + shape += (self.num_class,) preds = CumlArray.empty(shape=shape, dtype=np.float32, order='C') elif (not isinstance(preds, cudf.Series) and not rmm.is_cuda_array(preds)): @@ -337,8 +337,8 @@ cdef class ForestInference_impl(): &self.forest_data, model_ptr, &treelite_params) - TreeliteQueryNumOutputGroups( model_ptr, - & self.num_output_groups) + TreeliteQueryNumClass( model_ptr, + & self.num_class) return self def load_from_treelite_model(self, @@ -348,8 +348,8 @@ cdef class ForestInference_impl(): float threshold, str storage_type, int blocks_per_sm): - TreeliteQueryNumOutputGroups( model.handle, - & self.num_output_groups) + TreeliteQueryNumClass( model.handle, + & self.num_class) return self.load_from_treelite_model_handle(model.handle, output_class, algo, threshold, storage_type, @@ -380,8 +380,8 @@ cdef class ForestInference_impl(): &self.forest_data, model_ptr, &treelite_params) - TreeliteQueryNumOutputGroups( model_ptr, - &self.num_output_groups) + TreeliteQueryNumClass( model_ptr, + &self.num_class) return self def __dealloc__(self):