From bc63a6adf45f42362d6ad1f5e3b864b143c8a1c9 Mon Sep 17 00:00:00 2001
From: Levs Dolgovs <36520083+levsnv@users.noreply.github.com>
Date: Wed, 10 Nov 2021 08:14:19 -0800
Subject: [PATCH] Fix potential CUDA context poison when negative (invalid)
 categories provided to FIL model (#4315)

Fix a potential CUDA context poison due to an invalid global read when
negative categories are provided at inference: they are now equivalent to
non-matching categories. The same applies to `+-Inf` categories. NaN
categories still take the `!def_left` branch, and fractional categories are
truncated (via a typecast), since this is what Treelite does.

FIL now converts dummy categorical nodes to numerical nodes on import and
never generates categorical features with max_matching == -1 in tests. Tests
will still generate empty categorical nodes (a non-empty bits vector which
contains only zeros), export them as dummy numerical nodes, and re-import
them as dummy numerical nodes. If a feature contains only dummy numerical
nodes, it is deemed a numerical feature (the same as for non-dummy numerical
nodes or a mix thereof). Therefore, a categorical feature with
max_matching == -1 is still prevented.
---
 cpp/src/fil/fil.cu                  | 27 ++++++++-----
 cpp/src/fil/internal.cuh            | 20 ++++++---
 cpp/test/sg/fil_child_index_test.cu | 63 +++++++++++++++++++++++++----
 cpp/test/sg/fil_test.cu             |  7 ++--
 python/cuml/fil/fil.pyx             | 12 +++++-
 python/cuml/test/test_fil.py        | 38 ++++++++++++-----
 6 files changed, 130 insertions(+), 37 deletions(-)

diff --git a/cpp/src/fil/fil.cu b/cpp/src/fil/fil.cu
index fafb66eec9..ad1a4741b6 100644
--- a/cpp/src/fil/fil.cu
+++ b/cpp/src/fil/fil.cu
@@ -660,7 +660,8 @@ inline std::size_t bit_pool_size(const tl::Tree<T, L>& tree, const categorical_sets& cat_sets)
     int node_id = stack.top();
     stack.pop();
     while (!tree.IsLeaf(node_id)) {
-      if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical) {
+      if (tree.SplitType(node_id) == tl::SplitFeatureType::kCategorical &&
+          tree.MatchingCategories(node_id).size() > 0) {
         int fid = tree.SplitIndex(node_id);
         size += cat_sets.sizeof_mask(fid);
       }
@@ -802,8 +803,9 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
   int tl_left = tree.LeftChild(tl_node_id), tl_right = tree.RightChild(tl_node_id);
   val_t split = {.f = NAN};  // yes there's a default initializer already
   int feature_id = tree.SplitIndex(tl_node_id);
-  bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical;
-  bool default_left = tree.DefaultLeft(tl_node_id);
+  bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical &&
+                        tree.MatchingCategories(tl_node_id).size() > 0;
+  bool default_left   = tree.DefaultLeft(tl_node_id);
   if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kNumerical) {
     split.f = static_cast<float>(tree.Threshold(tl_node_id));
     adjust_threshold(&split.f, &tl_left, &tl_right, &default_left, tree.ComparisonOp(tl_node_id));
@@ -813,13 +815,18 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
       std::swap(tl_left, tl_right);
       default_left = !default_left;
     }
-    int sizeof_mask = cat_sets->accessor().sizeof_mask(feature_id);
-    split.idx = *bit_pool_offset;
-    *bit_pool_offset += sizeof_mask;
-    // cat_sets->bits have been zero-initialized
-    uint8_t* bits = &cat_sets->bits[split.idx];
-    for (std::uint32_t category : tree.MatchingCategories(tl_node_id)) {
-      bits[category / BITS_PER_BYTE] |= 1 << (category % BITS_PER_BYTE);
+    if (tree.MatchingCategories(tl_node_id).size() > 0) {
+      int sizeof_mask = cat_sets->accessor().sizeof_mask(feature_id);
+      split.idx = *bit_pool_offset;
+      *bit_pool_offset += sizeof_mask;
+      // cat_sets->bits have been zero-initialized
+      uint8_t* bits = &cat_sets->bits[split.idx];
+      for (std::uint32_t category : tree.MatchingCategories(tl_node_id)) {
+        bits[category / BITS_PER_BYTE] |= 1 << (category % BITS_PER_BYTE);
+      }
+    } else {
+      // always branch left in FIL. Already accounted for Treelite branching direction above.
+      split.f = NAN;
     }
   } else {
     ASSERT(false, "only numerical and categorical split nodes are supported");
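Note on the fil.cu change above: an empty categorical node is imported as a dummy numerical node with split.f = NAN, and FIL evaluates numerical splits as cond = val >= node.thresh() (see internal.cuh below), so such a node always branches left. Below is a minimal host-only sketch of that behavior under just this comparison rule; numerical_branch_right is an illustrative stand-in, not a FIL function.

#include <cmath>
#include <cstdio>

// Illustrative stand-in for FIL's numerical split rule:
// cond == true means "branch right".
bool numerical_branch_right(float val, float thresh) { return val >= thresh; }

int main()
{
  float dummy_thresh = NAN;  // an empty categorical node is imported with split.f = NAN
  for (float val : {-1.0f, 0.0f, 3.5f, 1e30f}) {
    // val >= NAN is false for every val, so every non-NaN value branches left
    printf("val = %g -> branch_right = %d\n", val, numerical_branch_right(val, dummy_thresh));
  }
  return 0;
}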
diff --git a/cpp/src/fil/internal.cuh b/cpp/src/fil/internal.cuh
index 1cce74bdb1..7aa6363f58 100644
--- a/cpp/src/fil/internal.cuh
+++ b/cpp/src/fil/internal.cuh
@@ -307,9 +307,10 @@ struct forest_params_t {
 /// FIL_TPB is the number of threads per block to use with FIL kernels
 const int FIL_TPB = 256;
 
-constexpr std::int32_t MAX_PRECISE_INT_FLOAT = 1 << 24;  // 16'777'216
+// as far as FIL is concerned, 16'777'214 is the most we can do: every integer in [0, max_matching + 1] must be exactly representable as a float32
+constexpr std::int32_t MAX_PRECISE_INT_FLOAT = (1 << 24) - 2;
 
-__host__ __device__ __forceinline__ int fetch_bit(const uint8_t* array, int bit)
+__host__ __device__ __forceinline__ int fetch_bit(const uint8_t* array, uint32_t bit)
 {
   return (array[bit / BITS_PER_BYTE] >> (bit % BITS_PER_BYTE)) & 1;
 }
@@ -337,14 +338,23 @@ struct categorical_sets {
   // set count is due to tree_idx + node_within_tree_idx are both ints, hence uint32_t result
   template <typename node_t>
-  __host__ __device__ __forceinline__ int category_matches(node_t node, int category) const
+  __host__ __device__ __forceinline__ int category_matches(node_t node, float category) const
   {
     // standard boolean packing. This layout has better ILP
     // node.set() is global across feature IDs and is an offset (as opposed
     // to set number). If we are guaranteed to have neighboring
     // features with similar categorical feature count, we may consider
     // storing node ID within nodes with same feature ID and look up
     // {.max_matching, .first_node_offset} = ...[feature_id]
-    return category <= max_matching[node.fid()] && fetch_bit(bits + node.set(), category);
+
+    /* category < 0.0f or category > INT_MAX is equivalent to out-of-dictionary category
+       (not matching, branch left). -0.0f represents category 0.
+       If (float)(int)category != category, we will discard the fractional part.
+       E.g. 3.8f represents category 3 regardless of max_matching value.
+       FIL will reject a model where an integer within [0, max_matching + 1] cannot be represented
+       precisely as a 32-bit float.
+    */
+    return category < static_cast<float>(max_matching[node.fid()] + 1) && category >= 0.0f &&
+           fetch_bit(bits + node.set(), static_cast<uint32_t>(category));
   }
   static int sizeof_mask_from_max_matching(int max_matching)
   {
@@ -372,7 +382,7 @@ struct tree_base {
     if (isnan(val)) {
       cond = !node.def_left();
     } else if (CATS_SUPPORTED && node.is_categorical()) {
-      cond = cat_sets.category_matches(node, static_cast<int>(val));
+      cond = cat_sets.category_matches(node, val);
     } else {
       cond = val >= node.thresh();
     }
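To make the new matching rule concrete, here is a small self-contained sketch that mirrors the category_matches() logic above (same expressions, minus the CUDA qualifiers and the node/bit-pool indirection; passing bits and max_matching as plain arguments is a simplification for illustration):

#include <cstdint>
#include <cstdio>
#include <limits>

constexpr int BITS_PER_BYTE = 8;

int fetch_bit(const uint8_t* array, uint32_t bit)
{
  return (array[bit / BITS_PER_BYTE] >> (bit % BITS_PER_BYTE)) & 1;
}

// bits is one node's category mask; max_matching is the feature's largest
// category. Out-of-range floats short-circuit before fetch_bit(), so the
// invalid global read this patch fixes cannot happen.
int category_matches(const uint8_t* bits, int max_matching, float category)
{
  return category < static_cast<float>(max_matching + 1) && category >= 0.0f &&
         fetch_bit(bits, static_cast<uint32_t>(category));
}

int main()
{
  uint8_t bits[] = {0b0000'0101};  // categories 0 and 2 match
  int max_matching = 2;
  float inputs[]  = {-5.0f, -0.0f, 1.0f, 2.0f, 2.8f, 3.0f,
                     std::numeric_limits<float>::infinity()};
  for (float category : inputs) {
    printf("category %g matches: %d\n", category,
           category_matches(bits, max_matching, category));
  }
  // prints 0, 1, 0, 1, 1, 0, 0: -5 and +Inf never match, -0.0 is category 0,
  // and 2.8 is truncated to category 2
  return 0;
}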
diff --git a/cpp/test/sg/fil_child_index_test.cu b/cpp/test/sg/fil_child_index_test.cu
index fa963eff1f..00807a2766 100644
--- a/cpp/test/sg/fil_child_index_test.cu
+++ b/cpp/test/sg/fil_child_index_test.cu
@@ -207,11 +207,42 @@ std::vector<ChildIndexTestParams> params = {
   CHILD_INDEX_TEST_PARAMS(parent_node_idx = 4, input = NAN, correct = 10),  // !def_left
   CHILD_INDEX_TEST_PARAMS(
     node = NODE(def_left = true), input = NAN, parent_node_idx = 4, correct = 9),  // !def_left
-  // cannot match ( > max_matching)
+  // cannot match ( < 0 and realistic max_matching)
   CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
                           cso.bits = {},
-                          cso.max_matching = {-1},
-                          input = 0,
+                          cso.max_matching = {10},
+                          input = -5,
+                          correct = 1),
+  // Skipping category < 0 and dummy categorical node: max_matching == -1. Prevented by FIL import.
+  // cannot match ( > INT_MAX)
+  CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
+                          cso.bits = {0b1111'1111},
+                          cso.max_matching = {7},
+                          input = (float)(1ll << 33ll),
+                          correct = 1),
+  // cannot match ( > max_matching and integer)
+  CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
+                          cso.bits = {0b1111'1111},
+                          cso.max_matching = {1},
+                          input = 2,
+                          correct = 1),
+  // matches ( > max_matching only due to fractional part)
+  CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
+                          cso.bits = {0b1111'1111},
+                          cso.max_matching = {1},
+                          input = 1.8f,
+                          correct = 2),
+  // cannot match ( > max_matching not only due to fractional part)
+  CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
+                          cso.bits = {0b1111'1111},
+                          cso.max_matching = {1},
+                          input = 2.1f,
+                          correct = 1),
+  // cannot match ( > max_matching not only due to fractional part)
+  CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
+                          cso.bits = {0b1111'1111},
+                          cso.max_matching = {1},
+                          input = 2.8f,
                           correct = 1),
   // does not match (bits[category] == 0, category == 0)
   CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
@@ -219,7 +250,13 @@ std::vector<ChildIndexTestParams> params = {
                           cso.max_matching = {0},
                           input = 0,
                           correct = 1),
-  // matches
+  // matches (negative zero)
+  CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
+                          cso.bits = {0b0000'0001},
+                          cso.max_matching = {0},
+                          input = -0.0f,
+                          correct = 2),
+  // matches (positive zero)
   CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
                           cso.bits = {0b0000'0001},
                           cso.max_matching = {0},
@@ -228,7 +265,7 @@ std::vector<ChildIndexTestParams> params = {
   // matches
   CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
                           cso.bits = {0b0000'0101},
-                          cso.max_matching = {2, -1},
+                          cso.max_matching = {2, 0},
                           input = 2,
                           correct = 2),
   // does not match (bits[category] == 0, category > 0)
@@ -237,13 +274,25 @@ std::vector<ChildIndexTestParams> params = {
                           cso.max_matching = {2},
                           input = 1,
                           correct = 1),
-  // cannot match (max_matching[fid=1] == -1)
+  // cannot match (max_matching[fid=1] < input)
   CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
                           node.fid = 1,
                           cso.bits = {0b0000'0101},
-                          cso.max_matching = {2, -1},
+                          cso.max_matching = {2, 0},
                           input = 2,
                           correct = 1),
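For reading the expected values in these test params: correct is the child node index in the dense tree layout, as the cases imply (children of root 0 are 1 and 2; children of parent_node_idx = 4 are 9 and 10). A tiny sketch of that indexing rule, assuming the standard implicit binary-tree layout (child_index here is illustrative, inferred from the test data, not a helper from the test file):

#include <cassert>

// Dense (implicit) binary tree: children of node i are 2*i + 1 and 2*i + 2.
// branch_right is the split condition (threshold comparison or category match).
int child_index(int node, bool branch_right) { return 2 * node + 1 + (branch_right ? 1 : 0); }

int main()
{
  assert(child_index(0, false) == 1);  // the correct = 1 cases above
  assert(child_index(0, true) == 2);   // the correct = 2 cases above
  assert(child_index(4, false) == 9);  // def_left = true with input = NAN
  assert(child_index(4, true) == 10);  // !def_left with input = NAN
  return 0;
}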
+  // default left
+  CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true, def_left = true),
+                          cso.bits = {0b0000'0101},
+                          cso.max_matching = {2},
+                          input = NAN,
+                          correct = 1),
+  // default right
+  CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true, def_left = false),
+                          cso.bits = {0b0000'0101},
+                          cso.max_matching = {2},
+                          input = NAN,
+                          correct = 2),
 };
 
 TEST_P(ChildIndexTestDense, Predict) { check(); }
diff --git a/cpp/test/sg/fil_test.cu b/cpp/test/sg/fil_test.cu
index 293222667e..958624596a 100644
--- a/cpp/test/sg/fil_test.cu
+++ b/cpp/test/sg/fil_test.cu
@@ -161,7 +161,8 @@ struct replace_some_floating_with_categorical {
   {
     int max_matching_cat = max_matching_cat_d[data_idx % num_cols];
     if (max_matching_cat == -1) return data;
-    return roundf((data * 0.5f + 0.5f) * max_matching_cat);
+    // also test invalid (negative) categories
+    return roundf((data * 0.5f + 0.5f) * max_matching_cat - 1.0);
   }
 };
 
@@ -305,8 +306,8 @@ class BaseFilTest : public testing::TestWithParam<FilTestParams> {
     for (int fid = 0; fid < ps.num_cols; ++fid) {
       feature_categorical[fid] = fc(gen);
       if (feature_categorical[fid]) {
-        // even for some categorical features, we will have no matching categories
-        float mm = pow(10, mmc(gen)) - 1.0f;
+        // categorical features will never have max_matching == -1
+        float mm = pow(10, mmc(gen));
         ASSERT(mm < INT_MAX,
                "internal error: max_magnitude_of_matching_cat %f is too large",
                ps.max_magnitude_of_matching_cat);
diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx
index 7e977d195f..b18b8dfa34 100644
--- a/python/cuml/fil/fil.pyx
+++ b/python/cuml/fil/fil.pyx
@@ -289,7 +289,17 @@ cdef class ForestInference_impl():
         Parameters
         ----------
        X : float32 array-like (device or host) shape = (n_samples, n_features)
-            For optimal performance, pass a device array with C-style layout
+            For optimal performance, pass a device array with C-style layout.
+            For categorical features: category < 0.0 or category > 16'777'214
+            is equivalent to an out-of-dictionary category (not matching).
+            -0.0 represents category 0.
+            If float(int(category)) != category, we will discard the
+            fractional part. E.g. 3.8 represents category 3 regardless of
+            max_matching value. FIL will reject a model where an integer
+            within [0, max_matching + 1] cannot be represented precisely
+            as a float32.
+            NaNs work the same between numerical and categorical inputs:
+            they are missing values and follow Treelite's DefaultLeft.
         preds : float32 device array, shape = n_samples
         predict_proba : bool, whether to output class probabilities(vs classes)
            Supported only for binary classification. output format
diff --git a/python/cuml/test/test_fil.py b/python/cuml/test/test_fil.py
index be460e7b68..1737a6c8e8 100644
--- a/python/cuml/test/test_fil.py
+++ b/python/cuml/test/test_fil.py
@@ -17,7 +17,7 @@
 import pytest
 import os
 import pandas as pd
-from random import sample, seed
+from math import ceil
 
 from cuml import ForestInference
 from cuml.test.utils import array_equal, unit_param, \
@@ -491,17 +491,33 @@ def test_output_args(small_classifier_and_preds):
     assert array_equal(fil_preds, xgb_preds, 1e-3)
 
 
-def to_categorical(features, n_categorical):
+def to_categorical(features, n_categorical, invalid_pct, rng):
+    """ returns data in two formats: pandas (for LightGBM) and numpy (for FIL)
+    """
     # the main bottleneck (>80%) of to_categorical() is the pandas operations
     n_features = features.shape[1]
     df_cols = {}
     # all categorical columns
     cat_cols = features[:, :n_categorical]
-    cat_cols = cat_cols - cat_cols.min(axis=1, keepdims=True)  # range [0, ?]
-    cat_cols /= cat_cols.max(axis=1, keepdims=True)  # range [0, 1]
+    cat_cols = cat_cols - cat_cols.min(axis=0, keepdims=True)  # range [0, ?]
+    cat_cols /= cat_cols.max(axis=0, keepdims=True)  # range [0, 1]
     rough_n_categories = 100
     # round into rough_n_categories bins
     cat_cols = (cat_cols * rough_n_categories).astype(int)
+    # randomly inject invalid categories
+    invalid_idx = rng.choice(
+        a=cat_cols.size,
+        size=ceil(cat_cols.size * invalid_pct / 100),
+        replace=False,
+        shuffle=False)
+    cat_cols.flat[invalid_idx] += rough_n_categories
+
+    new_features = features.copy()
+    new_features[:, :n_categorical] = cat_cols
+
+    # shuffle the columns around
+    new_idx = rng.choice(n_features, n_features, replace=False, shuffle=True)
+    new_matrix = new_features[:, new_idx]
     for icol in range(n_categorical):
         col = cat_cols[:, icol]
         df_cols[icol] = pd.Series(pd.Categorical(col,
@@ -510,11 +526,9 @@ def to_categorical(features, n_categorical):
     for icol in range(n_categorical, n_features):
         df_cols[icol] = pd.Series(features[:, icol])
     # shuffle the columns around
-    seed(42)
-    new_idx = sample(range(n_features), k=n_features)
     df_cols = {i: df_cols[new_idx[i]] for i in range(n_features)}
 
-    return pd.DataFrame(df_cols)
+    return pd.DataFrame(df_cols), new_matrix
 
 
 @pytest.mark.parametrize('num_classes', [2, 5])
@@ -532,14 +546,16 @@ def test_lightgbm(tmp_path, num_classes, n_categorical):
     n_rows = 500
     n_informative = 'auto'
 
+    state = np.random.RandomState(43210)
     X, y = simulate_data(n_rows,
                          n_features,
                          num_classes,
                          n_informative=n_informative,
-                         random_state=43210,
+                         random_state=state,
                          classification=True)
 
+    rng = np.random.default_rng(hash(state))
     if n_categorical > 0:
-        X_fit = to_categorical(X, n_categorical)
+        X_fit, X = to_categorical(X, n_categorical, 10, rng)
     else:
         X_fit = X
@@ -560,7 +576,7 @@ def test_lightgbm(tmp_path, num_classes, n_categorical):
         # binary classification
         gbm_proba = bst.predict(X)
         fil_proba = fm.predict_proba(X)[:, 1]
-        gbm_preds = (gbm_proba > 0.5)
+        gbm_preds = (gbm_proba > 0.5).astype(int)
         fil_preds = fm.predict(X)
         assert array_equal(gbm_preds, fil_preds)
         np.testing.assert_allclose(gbm_proba, fil_proba,
@@ -572,11 +588,11 @@ def test_lightgbm(tmp_path, num_classes, n_categorical):
                                   n_estimators=num_round)
         lgm.fit(X_fit, y)
         lgm.booster_.save_model(model_path)
+        lgm_preds = lgm.predict(X).astype(int)
         fm = ForestInference.load(model_path,
                                   algo='TREE_REORG',
                                   output_class=True,
                                   model_type="lightgbm")
-        lgm_preds = lgm.predict(X)
         assert array_equal(lgm.booster_.predict(X).argmax(axis=1), lgm_preds)
         assert array_equal(lgm_preds, fm.predict(X))
         # lightgbm uses float64 thresholds, while FIL uses float32