From 57124ceecda502b600b163f5fdf3f3ae76d1c2c5 Mon Sep 17 00:00:00 2001 From: Andy Adinets Date: Wed, 13 Apr 2022 17:43:28 +0200 Subject: [PATCH] float64 support in treelite->FIL import and Python layer (#4690) `float64` support in treelite->FIL import and Python layer Authors: - Andy Adinets (https://github.com/canonizer) - Levs Dolgovs (https://github.com/levsnv) Approvers: - Philip Hyunsu Cho (https://github.com/hcho3) - William Hicks (https://github.com/wphicks) URL: https://github.com/rapidsai/cuml/pull/4690 --- cpp/bench/sg/fil.cu | 4 +- cpp/include/cuml/fil/fil.h | 18 ++-- cpp/src/fil/treelite_import.cu | 69 ++++++++------ cpp/test/sg/fil_test.cu | 166 ++++++++++++++++++++------------- cpp/test/sg/rf_test.cu | 15 +-- python/cuml/fil/fil.pyx | 80 ++++++++++++---- 6 files changed, 223 insertions(+), 129 deletions(-) diff --git a/cpp/bench/sg/fil.cu b/cpp/bench/sg/fil.cu index adf283fbaf..67017fd9f5 100644 --- a/cpp/bench/sg/fil.cu +++ b/cpp/bench/sg/fil.cu @@ -91,7 +91,9 @@ class FIL : public RegressionFixture { .threads_per_tree = 1, .n_items = 0, .pforest_shape_str = nullptr}; - ML::fil::from_treelite(*handle, &forest, model, &tl_params); + ML::fil::forest_variant forest_variant; + ML::fil::from_treelite(*handle, &forest_variant, model, &tl_params); + forest = std::get>(forest_variant); // only time prediction this->loopOnState(state, [this]() { diff --git a/cpp/include/cuml/fil/fil.h b/cpp/include/cuml/fil/fil.h index 581fe3eb13..2d5d786520 100644 --- a/cpp/include/cuml/fil/fil.h +++ b/cpp/include/cuml/fil/fil.h @@ -20,6 +20,8 @@ #include +#include // for std::get<>, std::variant<> + #include namespace raft { @@ -29,10 +31,8 @@ class handle_t; namespace ML { namespace fil { -/** @note FIL only supports inference with single precision. - * TODO(canonizer): parameterize the functions and structures by the data type - * and the threshold/weight type. - */ +/** @note FIL supports inference with both single and double precision. However, + the floating-point type used in the data and model must be the same. */ /** Inference algorithm to use. */ enum algo_t { @@ -76,6 +76,13 @@ struct forest; template using forest_t = forest*; +/** forest32_t and forest64_t are definitions required in Cython */ +using forest32_t = forest*; +using forest64_t = forest*; + +/** forest_variant is used to get a forest represented with either float or double. */ +using forest_variant = std::variant, forest_t>; + /** MAX_N_ITEMS determines the maximum allowed value for tl_params::n_items */ constexpr int MAX_N_ITEMS = 4; @@ -114,9 +121,8 @@ struct treelite_params_t { * @param model treelite model used to initialize the forest * @param tl_params additional parameters for the forest */ -// TODO (canonizer): use std::variant forest_t>* for pforest void from_treelite(const raft::handle_t& handle, - forest_t* pforest, + forest_variant* pforest, ModelHandle model, const treelite_params_t* tl_params); diff --git a/cpp/src/fil/treelite_import.cu b/cpp/src/fil/treelite_import.cu index 68634fe26a..2b9e320c95 100644 --- a/cpp/src/fil/treelite_import.cu +++ b/cpp/src/fil/treelite_import.cu @@ -40,6 +40,7 @@ #include // for std::size_t #include // for uint8_t #include // for ios, stringstream +#include // for std::numeric_limits #include // for std::stack #include // for std::string #include // for std::is_same @@ -223,7 +224,8 @@ cat_sets_owner allocate_cat_sets_owner(const tl::ModelImpl& model) return cat_sets; } -void adjust_threshold(float* pthreshold, bool* swap_child_nodes, tl::Operator comparison_op) +template +void adjust_threshold(real_t* pthreshold, bool* swap_child_nodes, tl::Operator comparison_op) { // in treelite (take left node if val [op] threshold), // the meaning of the condition is reversed compared to FIL; @@ -237,12 +239,12 @@ void adjust_threshold(float* pthreshold, bool* swap_child_nodes, tl::Operator co case tl::Operator::kLT: break; case tl::Operator::kLE: // x <= y is equivalent to x < y', where y' is the next representable float - *pthreshold = std::nextafterf(*pthreshold, std::numeric_limits::infinity()); + *pthreshold = std::nextafterf(*pthreshold, std::numeric_limits::infinity()); break; case tl::Operator::kGT: // x > y is equivalent to x >= y', where y' is the next representable float // left and right still need to be swapped - *pthreshold = std::nextafterf(*pthreshold, std::numeric_limits::infinity()); + *pthreshold = std::nextafterf(*pthreshold, std::numeric_limits::infinity()); case tl::Operator::kGE: // swap left and right *swap_child_nodes = !*swap_child_nodes; @@ -279,7 +281,7 @@ void tl2fil_leaf_payload(fil_node_t* fil_node, const tl::Tree& tl_tree, int tl_node_id, const forest_params_t& forest_params, - std::vector* vector_leaf, + std::vector* vector_leaf, size_t* leaf_counter) { auto vec = tl_tree.LeafVector(tl_node_id); @@ -301,7 +303,7 @@ void tl2fil_leaf_payload(fil_node_t* fil_node, } case leaf_algo_t::FLOAT_UNARY_BINARY: case leaf_algo_t::GROVE_PER_CLASS: - fil_node->val.f = static_cast(tl_tree.LeafValue(tl_node_id)); + fil_node->val.f = static_cast(tl_tree.LeafValue(tl_node_id)); ASSERT(!tl_tree.HasLeafVector(tl_node_id), "some but not all treelite leaves have leaf_vector()"); break; @@ -323,14 +325,15 @@ conversion_state tl2fil_inner_node(int fil_left_child, cat_sets_owner* cat_sets, std::size_t* bit_pool_offset) { + using real_t = typename fil_node_t::real_type; int tl_left = tree.LeftChild(tl_node_id), tl_right = tree.RightChild(tl_node_id); - val_t split = {.f = NAN}; // yes there's a default initializer already + val_t split = {.f = std::numeric_limits::quiet_NaN()}; int feature_id = tree.SplitIndex(tl_node_id); bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical && tree.MatchingCategories(tl_node_id).size() > 0; bool swap_child_nodes = false; if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kNumerical) { - split.f = static_cast(tree.Threshold(tl_node_id)); + split.f = static_cast(tree.Threshold(tl_node_id)); adjust_threshold(&split.f, &swap_child_nodes, tree.ComparisonOp(tl_node_id)); } else if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical) { // for FIL, the list of categories is always for the right child @@ -346,14 +349,14 @@ conversion_state tl2fil_inner_node(int fil_left_child, } } else { // always branch left in FIL. Already accounted for Treelite branching direction above. - split.f = NAN; + split.f = std::numeric_limits::quiet_NaN(); } } else { ASSERT(false, "only numerical and categorical split nodes are supported"); } bool default_left = tree.DefaultLeft(tl_node_id) ^ swap_child_nodes; fil_node_t node( - val_t{}, split, feature_id, default_left, false, is_categorical, fil_left_child); + val_t{}, split, feature_id, default_left, false, is_categorical, fil_left_child); return conversion_state{node, swap_child_nodes}; } @@ -363,7 +366,7 @@ int tree2fil(std::vector& nodes, const tl::Tree& tree, std::size_t tree_idx, const forest_params_t& forest_params, - std::vector* vector_leaf, + std::vector* vector_leaf, std::size_t* leaf_counter, cat_sets_owner* cat_sets) { @@ -443,10 +446,11 @@ std::stringstream depth_hist_and_max(const tl::ModelImpl& model) forest_shape << "Total: branches: " << total_branches << " leaves: " << total_leaves << " nodes: " << total_nodes << endl; forest_shape << "Avg nodes per tree: " << setprecision(2) - << total_nodes / (float)hist[0].n_branch_nodes << endl; + << total_nodes / static_cast(hist[0].n_branch_nodes) << endl; forest_shape.copyfmt(default_state); forest_shape << "Leaf depth: min: " << min_leaf_depth << " avg: " << setprecision(2) << fixed - << leaves_times_depth / (float)total_leaves << " max: " << hist.size() - 1 << endl; + << leaves_times_depth / static_cast(total_leaves) + << " max: " << hist.size() - 1 << endl; forest_shape.copyfmt(default_state); vector hist_bytes(hist.size() * sizeof(hist[0])); @@ -575,9 +579,10 @@ void node_traits::check(const treelite::ModelImpl& template struct tl2fil_t { + using real_t = typename fil_node_t::real_type; std::vector roots_; std::vector nodes_; - std::vector vector_leaf_; + std::vector vector_leaf_; forest_params_t params_; cat_sets_owner cat_sets_; const tl::ModelImpl& model_; @@ -631,7 +636,7 @@ struct tl2fil_t { } /// initializes FIL forest object, to be ready to infer - void init_forest(const raft::handle_t& handle, forest_t* pforest) + void init_forest(const raft::handle_t& handle, forest_t* pforest) { ML::fil::init( handle, pforest, cat_sets_.accessor(), vector_leaf_, roots_.data(), nodes_.data(), ¶ms_); @@ -646,7 +651,7 @@ struct tl2fil_t { template void convert(const raft::handle_t& handle, - forest_t* pforest, + forest_t* pforest, const tl::ModelImpl& model, const treelite_params_t& tl_params) { @@ -664,24 +669,21 @@ constexpr bool type_supported() template void from_treelite(const raft::handle_t& handle, - forest_t* pforest, + forest_variant* pforest_variant, const tl::ModelImpl& model, const treelite_params_t* tl_params) { + // floating-point type used for model representation + using real_t = decltype(threshold_t(0) + leaf_t(0)); + + // get the pointer to the right forest variant + *pforest_variant = (forest_t)nullptr; + forest_t* pforest = &std::get>(*pforest_variant); + // Invariants on threshold and leaf types static_assert(type_supported(), "Model must contain float32 or float64 thresholds for splits"); ASSERT(type_supported(), "Models with integer leaf output are not yet supported"); - // Display appropriate warnings when float64 values are being casted into - // float32, as FIL only supports inferencing with float32 for the time being - if (std::is_same::value || std::is_same::value) { - CUML_LOG_WARN( - "Casting all thresholds and leaf values to float32, as FIL currently " - "doesn't support inferencing models with float64 values. " - "This may lead to predictions with reduced accuracy."); - } - // same as std::common_type: float+double=double, float+int64_t=float - using real_t = decltype(threshold_t(0) + leaf_t(0)); storage_type_t storage_type = tl_params->storage_type; // build dense trees by default @@ -702,18 +704,25 @@ void from_treelite(const raft::handle_t& handle, switch (storage_type) { case storage_type_t::DENSE: - convert>(handle, pforest, model, *tl_params); + convert>(handle, pforest, model, *tl_params); break; case storage_type_t::SPARSE: - convert>(handle, pforest, model, *tl_params); + convert>(handle, pforest, model, *tl_params); + break; + case storage_type_t::SPARSE8: + // SPARSE8 is only supported for float32 + if constexpr (std::is_same_v) { + convert(handle, pforest, model, *tl_params); + } else { + ASSERT(false, "SPARSE8 is only supported for float32 treelite models"); + } break; - case storage_type_t::SPARSE8: convert(handle, pforest, model, *tl_params); break; default: ASSERT(false, "tl_params->sparse must be one of AUTO, DENSE or SPARSE"); } } void from_treelite(const raft::handle_t& handle, - forest_t* pforest, + forest_variant* pforest, ModelHandle model, const treelite_params_t* tl_params) { diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu index cec210280a..cb3bcf8495 100644 --- a/cpp/test/sg/fil_test.cu +++ b/cpp/test/sg/fil_test.cu @@ -196,8 +196,9 @@ __global__ void floats_to_bit_stream_k(uint8_t* dst, real_t* src, std::size_t si dst[idx] = byte; } +template void adjust_threshold_to_treelite( - float* pthreshold, int* tl_left, int* tl_right, bool* default_left, tl::Operator comparison_op) + real_t* pthreshold, int* tl_left, int* tl_right, bool* default_left, tl::Operator comparison_op) { // in treelite (take left node if val [op] threshold), // the meaning of the condition is reversed compared to FIL; @@ -213,12 +214,12 @@ void adjust_threshold_to_treelite( case tl::Operator::kLT: break; case tl::Operator::kLE: // x <= y is equivalent to x < y', where y' is the next representable float - *pthreshold = std::nextafterf(*pthreshold, -std::numeric_limits::infinity()); + *pthreshold = std::nextafterf(*pthreshold, -std::numeric_limits::infinity()); break; case tl::Operator::kGT: // x > y is equivalent to x >= y', where y' is the next representable float // left and right still need to be swapped - *pthreshold = std::nextafterf(*pthreshold, -std::numeric_limits::infinity()); + *pthreshold = std::nextafterf(*pthreshold, -std::numeric_limits::infinity()); case tl::Operator::kGE: // swap left and right std::swap(*tl_left, *tl_right); @@ -745,7 +746,8 @@ using PredictSparse16Float32FilTest = BasePredictFilTest>; using PredictSparse8FilTest = BasePredictFilTest; -class TreeliteFilTest : public BaseFilTest { +template +class TreeliteFilTest : public BaseFilTest { protected: /** adds nodes[node] of tree starting at index root to builder at index at *pkey, increments *pkey, @@ -754,28 +756,29 @@ class TreeliteFilTest : public BaseFilTest { { int key = (*pkey)++; builder->CreateNode(key); - const fil::dense_node& dense_node = nodes[node]; + const fil::dense_node& dense_node = this->nodes[node]; std::vector left_categories; if (dense_node.is_leaf()) { - switch (ps.leaf_algo) { + switch (this->ps.leaf_algo) { case fil::leaf_algo_t::FLOAT_UNARY_BINARY: case fil::leaf_algo_t::GROVE_PER_CLASS: // default is fil::FLOAT_UNARY_BINARY - builder->SetLeafNode(key, tlf::Value::Create(dense_node.output())); + builder->SetLeafNode(key, tlf::Value::Create(dense_node.template output())); break; case fil::leaf_algo_t::CATEGORICAL_LEAF: { - std::vector vec(ps.num_classes); - for (int i = 0; i < ps.num_classes; ++i) { - vec[i] = tlf::Value::Create(i == dense_node.output() ? 1.0f : 0.0f); + std::vector vec(this->ps.num_classes); + for (int i = 0; i < this->ps.num_classes; ++i) { + vec[i] = + tlf::Value::Create(i == dense_node.template output() ? real_t(1) : real_t(0)); } builder->SetLeafVectorNode(key, vec); break; } case fil::leaf_algo_t::VECTOR_LEAF: { - std::vector vec(ps.num_classes); - for (int i = 0; i < ps.num_classes; ++i) { - auto idx = dense_node.output(); - vec[i] = tlf::Value::Create(vector_leaf[idx * ps.num_classes + i]); + std::vector vec(this->ps.num_classes); + for (int i = 0; i < this->ps.num_classes; ++i) { + auto idx = dense_node.template output(); + vec[i] = tlf::Value::Create(this->vector_leaf[idx * this->ps.num_classes + i]); } builder->SetLeafVectorNode(key, vec); break; @@ -787,14 +790,15 @@ class TreeliteFilTest : public BaseFilTest { int left = root + 2 * (node - root) + 1; int right = root + 2 * (node - root) + 2; bool default_left = dense_node.def_left(); - float threshold = dense_node.is_categorical() ? NAN : dense_node.thresh(); + real_t threshold = dense_node.is_categorical() ? std::numeric_limits::quiet_NaN() + : dense_node.thresh(); if (dense_node.is_categorical()) { uint8_t byte = 0; for (int category = 0; - category < static_cast(cat_sets_h.fid_num_cats[dense_node.fid()]); + category < static_cast(this->cat_sets_h.fid_num_cats[dense_node.fid()]); ++category) { if (category % BITS_PER_BYTE == 0) { - byte = cat_sets_h.bits[dense_node.set() + category / BITS_PER_BYTE]; + byte = this->cat_sets_h.bits[dense_node.set() + category / BITS_PER_BYTE]; } if ((byte & (1 << (category % BITS_PER_BYTE))) != 0) { left_categories.push_back(category); @@ -815,10 +819,10 @@ class TreeliteFilTest : public BaseFilTest { builder->SetCategoricalTestNode( key, dense_node.fid(), left_categories, default_left, left_key, right_key); } else { - adjust_threshold_to_treelite(&threshold, &left_key, &right_key, &default_left, ps.op); + adjust_threshold_to_treelite(&threshold, &left_key, &right_key, &default_left, this->ps.op); builder->SetNumericalTestNode(key, dense_node.fid(), - ps.op, + this->ps.op, tlf::Value::Create(threshold), default_left, left_key, @@ -828,28 +832,27 @@ class TreeliteFilTest : public BaseFilTest { return key; } - void init_forest_impl(fil::forest_t* pforest, fil::storage_type_t storage_type) + void init_forest_impl(fil::forest_t* pforest, fil::storage_type_t storage_type) { - auto stream = handle.get_stream(); - bool random_forest_flag = (ps.output & fil::output_t::AVG) != 0; + auto stream = this->handle.get_stream(); + bool random_forest_flag = (this->ps.output & fil::output_t::AVG) != 0; + tl::TypeInfo tl_type_info = + std::is_same_v ? tl::TypeInfo::kFloat32 : tl::TypeInfo::kFloat64; int treelite_num_classes = - ps.leaf_algo == fil::leaf_algo_t::FLOAT_UNARY_BINARY ? 1 : ps.num_classes; - std::unique_ptr model_builder(new tlf::ModelBuilder(ps.num_cols, - treelite_num_classes, - random_forest_flag, - tl::TypeInfo::kFloat32, - tl::TypeInfo::kFloat32)); + this->ps.leaf_algo == fil::leaf_algo_t::FLOAT_UNARY_BINARY ? 1 : this->ps.num_classes; + std::unique_ptr model_builder(new tlf::ModelBuilder( + this->ps.num_cols, treelite_num_classes, random_forest_flag, tl_type_info, tl_type_info)); // prediction transform - if ((ps.output & fil::output_t::SIGMOID) != 0) { - if (ps.num_classes > 2) + if ((this->ps.output & fil::output_t::SIGMOID) != 0) { + if (this->ps.num_classes > 2) model_builder->SetModelParam("pred_transform", "multiclass_ova"); else model_builder->SetModelParam("pred_transform", "sigmoid"); - } else if (ps.leaf_algo != fil::leaf_algo_t::FLOAT_UNARY_BINARY) { + } else if (this->ps.leaf_algo != fil::leaf_algo_t::FLOAT_UNARY_BINARY) { model_builder->SetModelParam("pred_transform", "max_index"); - ps.output = fil::output_t(ps.output | fil::output_t::CLASS); - } else if (ps.leaf_algo == GROVE_PER_CLASS) { + this->ps.output = fil::output_t(this->ps.output | fil::output_t::CLASS); + } else if (this->ps.leaf_algo == GROVE_PER_CLASS) { model_builder->SetModelParam("pred_transform", "identity_multiclass"); } else { model_builder->SetModelParam("pred_transform", "identity"); @@ -857,18 +860,17 @@ class TreeliteFilTest : public BaseFilTest { // global bias char* global_bias_str = nullptr; - ASSERT(asprintf(&global_bias_str, "%f", double(ps.global_bias)) > 0, + ASSERT(asprintf(&global_bias_str, "%f", double(this->ps.global_bias)) > 0, "cannot convert global_bias into a string"); model_builder->SetModelParam("global_bias", global_bias_str); ::free(global_bias_str); // build the trees - for (int i_tree = 0; i_tree < ps.num_trees; ++i_tree) { - tlf::TreeBuilder* tree_builder = - new tlf::TreeBuilder(tl::TypeInfo::kFloat32, tl::TypeInfo::kFloat32); - int key_counter = 0; - int root = i_tree * tree_num_nodes(); - int root_key = node_to_treelite(tree_builder, &key_counter, root, root); + for (int i_tree = 0; i_tree < this->ps.num_trees; ++i_tree) { + tlf::TreeBuilder* tree_builder = new tlf::TreeBuilder(tl_type_info, tl_type_info); + int key_counter = 0; + int root = i_tree * this->tree_num_nodes(); + int root_key = node_to_treelite(tree_builder, &key_counter, root, root); tree_builder->SetRootNode(root_key); // InsertTree() consumes tree_builder TL_CPP_CHECK(model_builder->InsertTree(tree_builder)); @@ -880,17 +882,19 @@ class TreeliteFilTest : public BaseFilTest { // init FIL forest with the model char* forest_shape_str = nullptr; fil::treelite_params_t params; - params.algo = ps.algo; - params.threshold = ps.threshold; - params.output_class = (ps.output & fil::output_t::CLASS) != 0; + params.algo = this->ps.algo; + params.threshold = this->ps.threshold; + params.output_class = (this->ps.output & fil::output_t::CLASS) != 0; params.storage_type = storage_type; - params.blocks_per_sm = ps.blocks_per_sm; - params.threads_per_tree = ps.threads_per_tree; - params.n_items = ps.n_items; - params.pforest_shape_str = ps.print_forest_shape ? &forest_shape_str : nullptr; - fil::from_treelite(handle, pforest, (ModelHandle)model.get(), ¶ms); - handle.sync_stream(stream); - if (ps.print_forest_shape) { + params.blocks_per_sm = this->ps.blocks_per_sm; + params.threads_per_tree = this->ps.threads_per_tree; + params.n_items = this->ps.n_items; + params.pforest_shape_str = this->ps.print_forest_shape ? &forest_shape_str : nullptr; + fil::forest_variant forest_variant; + fil::from_treelite(this->handle, &forest_variant, (ModelHandle)model.get(), ¶ms); + *pforest = std::get>(forest_variant); + this->handle.sync_stream(stream); + if (this->ps.print_forest_shape) { std::string str(forest_shape_str); for (const char* substr : {"model size", " MB", @@ -908,38 +912,48 @@ class TreeliteFilTest : public BaseFilTest { } }; -class TreeliteDenseFilTest : public TreeliteFilTest { +template +class TreeliteDenseFilTest : public TreeliteFilTest { protected: - void init_forest(fil::forest_t* pforest) override + void init_forest(fil::forest_t* pforest) override { - init_forest_impl(pforest, fil::storage_type_t::DENSE); + this->init_forest_impl(pforest, fil::storage_type_t::DENSE); } }; -class TreeliteSparse16FilTest : public TreeliteFilTest { +template +class TreeliteSparse16FilTest : public TreeliteFilTest { protected: - void init_forest(fil::forest_t* pforest) override + void init_forest(fil::forest_t* pforest) override { - init_forest_impl(pforest, fil::storage_type_t::SPARSE); + this->init_forest_impl(pforest, fil::storage_type_t::SPARSE); } }; -class TreeliteSparse8FilTest : public TreeliteFilTest { +class TreeliteSparse8FilTest : public TreeliteFilTest { protected: void init_forest(fil::forest_t* pforest) override { - init_forest_impl(pforest, fil::storage_type_t::SPARSE8); + this->init_forest_impl(pforest, fil::storage_type_t::SPARSE8); } }; -class TreeliteAutoFilTest : public TreeliteFilTest { +template +class TreeliteAutoFilTest : public TreeliteFilTest { protected: - void init_forest(fil::forest_t* pforest) override + void init_forest(fil::forest_t* pforest) override { - init_forest_impl(pforest, fil::storage_type_t::AUTO); + this->init_forest_impl(pforest, fil::storage_type_t::AUTO); } }; +using TreeliteDenseFloat32FilTest = TreeliteDenseFilTest; +using TreeliteDenseFloat64FilTest = TreeliteDenseFilTest; +using TreeliteSparse16Float32FilTest = TreeliteDenseFilTest; +using TreeliteSparse16Float64FilTest = TreeliteDenseFilTest; +using TreeliteAutoFloat32FilTest = TreeliteAutoFilTest; +using TreeliteAutoFloat64FilTest = TreeliteAutoFilTest; + // test for failures; currently only supported for sparse8 nodes class TreeliteThrowSparse8FilTest : public TreeliteSparse8FilTest { protected: @@ -1300,9 +1314,15 @@ std::vector import_dense_inputs = { max_magnitude_of_matching_cat = 5), }; -TEST_P(TreeliteDenseFilTest, Import) { compare(); } +TEST_P(TreeliteDenseFloat32FilTest, Import) { compare(); } +TEST_P(TreeliteDenseFloat64FilTest, Import) { compare(); } -INSTANTIATE_TEST_CASE_P(FilTests, TreeliteDenseFilTest, testing::ValuesIn(import_dense_inputs)); +INSTANTIATE_TEST_CASE_P(FilTests, + TreeliteDenseFloat32FilTest, + testing::ValuesIn(import_dense_inputs)); +INSTANTIATE_TEST_CASE_P(FilTests, + TreeliteDenseFloat64FilTest, + testing::ValuesIn(import_dense_inputs)); std::vector import_sparse_inputs = { FIL_TEST_PARAMS(), @@ -1353,9 +1373,15 @@ std::vector import_sparse_inputs = { max_magnitude_of_matching_cat = 5), }; -TEST_P(TreeliteSparse16FilTest, Import) { compare(); } +TEST_P(TreeliteSparse16Float32FilTest, Import) { compare(); } +TEST_P(TreeliteSparse16Float64FilTest, Import) { compare(); } -INSTANTIATE_TEST_CASE_P(FilTests, TreeliteSparse16FilTest, testing::ValuesIn(import_sparse_inputs)); +INSTANTIATE_TEST_CASE_P(FilTests, + TreeliteSparse16Float32FilTest, + testing::ValuesIn(import_sparse_inputs)); +INSTANTIATE_TEST_CASE_P(FilTests, + TreeliteSparse16Float64FilTest, + testing::ValuesIn(import_sparse_inputs)); TEST_P(TreeliteSparse8FilTest, Import) { compare(); } @@ -1381,9 +1407,15 @@ std::vector import_auto_inputs = { #endif }; -TEST_P(TreeliteAutoFilTest, Import) { compare(); } +TEST_P(TreeliteAutoFloat32FilTest, Import) { compare(); } +TEST_P(TreeliteAutoFloat64FilTest, Import) { compare(); } -INSTANTIATE_TEST_CASE_P(FilTests, TreeliteAutoFilTest, testing::ValuesIn(import_auto_inputs)); +INSTANTIATE_TEST_CASE_P(FilTests, + TreeliteAutoFloat32FilTest, + testing::ValuesIn(import_auto_inputs)); +INSTANTIATE_TEST_CASE_P(FilTests, + TreeliteAutoFloat64FilTest, + testing::ValuesIn(import_auto_inputs)); // adjust test parameters if the sparse8 format changes std::vector import_throw_sparse8_inputs = { diff --git a/cpp/test/sg/rf_test.cu b/cpp/test/sg/rf_test.cu index 345770efa1..18923c4baa 100644 --- a/cpp/test/sg/rf_test.cu +++ b/cpp/test/sg/rf_test.cu @@ -172,8 +172,9 @@ auto FilPredict(const raft::handle_t& handle, 1, 0, nullptr}; - fil::forest_t fil_forest; - fil::from_treelite(handle, &fil_forest, model, &tl_params); + fil::forest_variant forest_variant; + fil::from_treelite(handle, &forest_variant, model, &tl_params); + fil::forest_t fil_forest = std::get>(forest_variant); fil::predict(handle, fil_forest, pred->data().get(), X_transpose, params.n_rows, false); return pred; } @@ -191,8 +192,9 @@ auto FilPredictProba(const raft::handle_t& handle, build_treelite_forest(&model, forest, params.n_cols); fil::treelite_params_t tl_params{ fil::algo_t::ALGO_AUTO, 0, 0.0f, fil::storage_type_t::AUTO, 8, 1, 0, nullptr}; - fil::forest_t fil_forest; - fil::from_treelite(handle, &fil_forest, model, &tl_params); + fil::forest_variant forest_variant; + fil::from_treelite(handle, &forest_variant, model, &tl_params); + fil::forest_t fil_forest = std::get>(forest_variant); fil::predict(handle, fil_forest, pred->data().get(), X_transpose, params.n_rows, true); return pred; } @@ -557,8 +559,9 @@ TEST(RfTests, IntegerOverflow) 1, 0, nullptr}; - fil::forest_t fil_forest; - fil::from_treelite(handle, &fil_forest, model, &tl_params); + fil::forest_variant forest_variant; + fil::from_treelite(handle, &forest_variant, model, &tl_params); + fil::forest_t fil_forest = std::get>(forest_variant); fil::predict(handle, fil_forest, pred.data().get(), X.data().get(), m, false); } diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index a9e0b79e1a..8908b0fc4a 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -177,6 +177,14 @@ cdef class TreeliteModel(): model.set_handle(handle) return model +cdef extern from "variant" namespace "std": + cdef cppclass variant[T1, T2]: + variant() + variant(T1) + size_t index() + + cdef T& get[T, T1, T2](variant[T1, T2]& v) + cdef extern from "cuml/fil/fil.h" namespace "ML::fil": cdef enum algo_t: ALGO_AUTO, @@ -193,6 +201,10 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil": cdef cppclass forest[real_t]: pass + ctypedef forest[float]* forest32_t + ctypedef forest[double]* forest64_t + ctypedef variant[forest32_t, forest64_t] forest_variant + # TODO(canonizer): use something like # ctypedef forest[real_t]* forest_t[real_t] # once it is supported in Cython @@ -227,23 +239,29 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil": size_t, bool) except + - cdef forest[float]* from_treelite(handle_t& handle, - forest[float]**, - ModelHandle, - treelite_params_t*) except + + cdef void from_treelite(handle_t& handle, + forest_variant*, + ModelHandle, + treelite_params_t*) except + cdef class ForestInference_impl(): cdef object handle - cdef forest[float]* forest_data + cdef forest_variant forest_data cdef size_t num_class cdef bool output_class cdef char* shape_str + cdef forest32_t get_forest32(self): + return get[forest32_t, forest32_t, forest64_t](self.forest_data) + + cdef forest64_t get_forest64(self): + return get[forest64_t, forest32_t, forest64_t](self.forest_data) + def __cinit__(self, handle=None): self.handle = handle - self.forest_data = NULL + self.forest_data = forest_variant( NULL) self.shape_str = NULL def get_shape_str(self): @@ -251,6 +269,10 @@ cdef class ForestInference_impl(): return unicode(self.shape_str, 'utf-8') return None + def get_dtype(self): + dtype_array = [np.float32, np.float64] + return dtype_array[self.forest_data.index()] + def get_algo(self, algo_str): algo_dict={'AUTO': algo_t.ALGO_AUTO, 'auto': algo_t.ALGO_AUTO, @@ -327,12 +349,13 @@ cdef class ForestInference_impl(): " using a Classification model, please " " set `output_class=True` while creating" " the FIL model.") + fil_dtype = self.get_dtype() cdef uintptr_t X_ptr X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, order='C', - convert_to_dtype=np.float32, + convert_to_dtype=fil_dtype, safe_dtype_conversion=safe_dtype_conversion, - check_dtype=np.float32) + check_dtype=fil_dtype) X_ptr = X_m.ptr cdef handle_t* handle_ =\ @@ -345,7 +368,7 @@ cdef class ForestInference_impl(): shape += (2,) else: shape += (self.num_class,) - preds = CumlArray.empty(shape=shape, dtype=np.float32, order='C', + preds = CumlArray.empty(shape=shape, dtype=fil_dtype, order='C', index=X_m.index) else: if not hasattr(preds, "__cuda_array_interface__"): @@ -356,12 +379,24 @@ cdef class ForestInference_impl(): cdef uintptr_t preds_ptr preds_ptr = preds.ptr - predict(handle_[0], - self.forest_data, - preds_ptr, - X_ptr, - n_rows, - predict_proba) + if fil_dtype == np.float32: + predict(handle_[0], + self.get_forest32(), + preds_ptr, + X_ptr, + n_rows, + predict_proba) + elif fil_dtype == np.float64: + predict(handle_[0], + self.get_forest64(), + preds_ptr, + X_ptr, + n_rows, + predict_proba) + else: + # should not reach here + assert False, 'invalid fil_dtype, must be np.float32 or np.float64' + self.handle.sync() # special case due to predict and predict_proba @@ -372,7 +407,7 @@ cdef class ForestInference_impl(): return preds def load_from_treelite_model_handle(self, **kwargs): - self.forest_data = NULL + self.forest_data = forest_variant( NULL) return self.load_using_treelite_handle(**kwargs) def load_from_treelite_model(self, **kwargs): @@ -413,9 +448,16 @@ cdef class ForestInference_impl(): def __dealloc__(self): cdef handle_t* handle_ = self.handle.getHandle() - - if self.forest_data !=NULL: - free(handle_[0], self.forest_data) + fil_dtype = self.get_dtype() + if fil_dtype == np.float32: + if self.get_forest32() != NULL: + free[float](handle_[0], self.get_forest32()) + elif fil_dtype == np.float64: + if self.get_forest64() != NULL: + free[double](handle_[0], self.get_forest64()) + else: + # should not reach here + assert False, 'invalid fil_dtype, must be np.float32 or np.float64' class ForestInference(Base,