Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix potential CUDA context poison when negative (invalid) categories provided to FIL model [21.10] #4315

Merged
24 changes: 15 additions & 9 deletions cpp/src/fil/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -802,8 +802,9 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
int tl_left = tree.LeftChild(tl_node_id), tl_right = tree.RightChild(tl_node_id);
val_t split = {.f = NAN}; // yes there's a default initializer already
int feature_id = tree.SplitIndex(tl_node_id);
bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical;
bool default_left = tree.DefaultLeft(tl_node_id);
bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical &&
tree.MatchingCategories(tl_node_id).size() > 0;
bool default_left = tree.DefaultLeft(tl_node_id);
if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kNumerical) {
split.f = static_cast<float>(tree.Threshold(tl_node_id));
adjust_threshold(&split.f, &tl_left, &tl_right, &default_left, tree.ComparisonOp(tl_node_id));
Expand All @@ -813,13 +814,18 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
std::swap(tl_left, tl_right);
default_left = !default_left;
}
int sizeof_mask = cat_sets->accessor().sizeof_mask(feature_id);
split.idx = *bit_pool_offset;
*bit_pool_offset += sizeof_mask;
// cat_sets->bits have been zero-initialized
uint8_t* bits = &cat_sets->bits[split.idx];
for (std::uint32_t category : tree.MatchingCategories(tl_node_id)) {
bits[category / BITS_PER_BYTE] |= 1 << (category % BITS_PER_BYTE);
if (tree.MatchingCategories(tl_node_id).size() > 0) {
int sizeof_mask = cat_sets->accessor().sizeof_mask(feature_id);
split.idx = *bit_pool_offset;
*bit_pool_offset += sizeof_mask;
// cat_sets->bits have been zero-initialized
uint8_t* bits = &cat_sets->bits[split.idx];
for (std::uint32_t category : tree.MatchingCategories(tl_node_id)) {
bits[category / BITS_PER_BYTE] |= 1 << (category % BITS_PER_BYTE);
}
} else {
// always branch left in FIL. Already accounted for Treelite branching direction above.
split.f = NAN;
}
} else {
ASSERT(false, "only numerical and categorical split nodes are supported");
Expand Down
20 changes: 15 additions & 5 deletions cpp/src/fil/internal.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,10 @@ struct forest_params_t {
/// FIL_TPB is the number of threads per block to use with FIL kernels
const int FIL_TPB = 256;

constexpr std::int32_t MAX_PRECISE_INT_FLOAT = 1 << 24; // 16'777'216
// as far as FIL is concerned, 16'777'214 is the most we can do.
constexpr std::int32_t MAX_PRECISE_INT_FLOAT = 1 << 24 - 2;
levsnv marked this conversation as resolved.
Show resolved Hide resolved

__host__ __device__ __forceinline__ int fetch_bit(const uint8_t* array, int bit)
__host__ __device__ __forceinline__ int fetch_bit(const uint8_t* array, uint32_t bit)
{
return (array[bit / BITS_PER_BYTE] >> (bit % BITS_PER_BYTE)) & 1;
}
Expand Down Expand Up @@ -337,15 +338,24 @@ struct categorical_sets {

// set count is due to tree_idx + node_within_tree_idx are both ints, hence uint32_t result
template <typename node_t>
__host__ __device__ __forceinline__ int category_matches(node_t node, int category) const
__host__ __device__ __forceinline__ int category_matches(node_t node, float category) const
{
// standard boolean packing. This layout has better ILP
// node.set() is global across feature IDs and is an offset (as opposed
// to set number). If we run out of uint32_t and we have hundreds of
// features with similar categorical feature count, we may consider
// storing node ID within nodes with same feature ID and look up
// {.max_matching, .first_node_offset} = ...[feature_id]
return category <= max_matching[node.fid()] && fetch_bit(bits + node.set(), category);

/* category < 0.0f or category > INT_MAX is equivalent to out-of-dictionary category
(not matching, branch left). -0.0f represents category 0.
If (float)(int)category != category, we will discard the fractional part.
E.g. 3.8f represents category 3 regardless of max_matching value.
FIL will reject a model where an integer within [0, max_matching] cannot be represented
precisely as a 32-bit float.
*/
return category < static_cast<float>(max_matching[node.fid()] + 1) && category >= 0.0f &&
fetch_bit(bits + node.set(), static_cast<int>(category));
}
static int sizeof_mask_from_max_matching(int max_matching)
{
Expand All @@ -372,7 +382,7 @@ struct tree_base {
if (isnan(val)) {
cond = !node.def_left();
} else if (CATS_SUPPORTED && node.is_categorical()) {
cond = cat_sets.category_matches(node, static_cast<int>(val));
cond = cat_sets.category_matches(node, val);
} else {
cond = val >= node.thresh();
}
Expand Down
55 changes: 49 additions & 6 deletions cpp/test/sg/fil_child_index_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -207,19 +207,50 @@ std::vector<ChildIndexTestParams> params = {
CHILD_INDEX_TEST_PARAMS(parent_node_idx = 4, input = NAN, correct = 10), // !def_left
CHILD_INDEX_TEST_PARAMS(
node = NODE(def_left = true), input = NAN, parent_node_idx = 4, correct = 9), // !def_left
// cannot match ( > max_matching)
// cannot match ( < 0 and realistic max_matching)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {},
cso.max_matching = {-1},
input = 0,
cso.max_matching = {10},
input = -5,
correct = 1),
// Skipping category < 0 and dummy categorical node: max_matching == -1. Prevented by FIL import.
// cannot match ( > INT_MAX)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b1111'1111},
cso.max_matching = {7},
input = (float)(1ll << 33ll),
correct = 1),
// cannot match ( > max_matching and integer)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b1111'1111},
cso.max_matching = {1},
input = 2,
correct = 1),
// matches ( > max_matching only due to fractional part)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b1111'1111},
cso.max_matching = {1},
input = 1.1f,
correct = 2),
levsnv marked this conversation as resolved.
Show resolved Hide resolved
// cannot match ( > max_matching not only due to fractional part)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b1111'1111},
cso.max_matching = {1},
input = 2.1f,
correct = 1),
// does not match (bits[category] == 0, category == 0)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b0000'0000},
cso.max_matching = {0},
input = 0,
correct = 1),
// matches
// matches (negative zero)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b0000'0001},
cso.max_matching = {0},
input = -0.0f,
correct = 2),
// matches (positive zero)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b0000'0001},
cso.max_matching = {0},
Expand All @@ -228,7 +259,7 @@ std::vector<ChildIndexTestParams> params = {
// matches
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b0000'0101},
cso.max_matching = {2, -1},
cso.max_matching = {2, 0},
input = 2,
correct = 2),
// does not match (bits[category] == 0, category > 0)
Expand All @@ -241,9 +272,21 @@ std::vector<ChildIndexTestParams> params = {
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
levsnv marked this conversation as resolved.
Show resolved Hide resolved
node.fid = 1,
cso.bits = {0b0000'0101},
cso.max_matching = {2, -1},
cso.max_matching = {2, 0},
input = 2,
correct = 1),
// default left
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true, def_left = true),
cso.bits = {0b0000'0101},
cso.max_matching = {2},
input = NAN,
correct = 1),
// default right
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true, def_left = false),
cso.bits = {0b0000'0101},
cso.max_matching = {2},
input = NAN,
correct = 2),
};

TEST_P(ChildIndexTestDense, Predict) { check(); }
Expand Down
7 changes: 4 additions & 3 deletions cpp/test/sg/fil_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ struct replace_some_floating_with_categorical {
{
int max_matching_cat = max_matching_cat_d[data_idx % num_cols];
if (max_matching_cat == -1) return data;
return roundf((data * 0.5f + 0.5f) * max_matching_cat);
// also test invalid (negative) categories
return roundf((data * 0.5f + 0.5f) * max_matching_cat - 1.0);
levsnv marked this conversation as resolved.
Show resolved Hide resolved
}
};

Expand Down Expand Up @@ -305,8 +306,8 @@ class BaseFilTest : public testing::TestWithParam<FilTestParams> {
for (int fid = 0; fid < ps.num_cols; ++fid) {
feature_categorical[fid] = fc(gen);
if (feature_categorical[fid]) {
// even for some categorical features, we will have no matching categories
float mm = pow(10, mmc(gen)) - 1.0f;
// categorical features will never have max_matching == -1
float mm = pow(10, mmc(gen));
ASSERT(mm < INT_MAX,
"internal error: max_magnitude_of_matching_cat %f is too large",
ps.max_magnitude_of_matching_cat);
Expand Down
12 changes: 11 additions & 1 deletion python/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,17 @@ cdef class ForestInference_impl():
Parameters
----------
X : float32 array-like (device or host) shape = (n_samples, n_features)
For optimal performance, pass a device array with C-style layout
For optimal performance, pass a device array with C-style layout.
For categorical features: category < 0.0 or category > 2.0**31-1 is
levsnv marked this conversation as resolved.
Show resolved Hide resolved
equivalent to out-of-dictionary category (not matching).
-0.0 represents category 0.
If float(int(category)) != category, we will discard the
fractional part. E.g. 3.8 represents category 3 regardless of
max_matching value. FIL will reject a model where an integer
within [0, max_matching] cannot be represented precisely
levsnv marked this conversation as resolved.
Show resolved Hide resolved
as a float32.
NANs work the same between numerical and categorical inputs:
they are missing values and follow Treelite's DefaultLeft.
preds : float32 device array, shape = n_samples
predict_proba : bool, whether to output class probabilities(vs classes)
Supported only for binary classification. output format
Expand Down