Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix potential CUDA context poison when negative (invalid) categories provided to FIL model [21.10] #4315

Merged
24 changes: 15 additions & 9 deletions cpp/src/fil/fil.cu
Original file line number Diff line number Diff line change
Expand Up @@ -802,8 +802,9 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
int tl_left = tree.LeftChild(tl_node_id), tl_right = tree.RightChild(tl_node_id);
val_t split = {.f = NAN}; // yes there's a default initializer already
int feature_id = tree.SplitIndex(tl_node_id);
bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical;
bool default_left = tree.DefaultLeft(tl_node_id);
bool is_categorical = tree.SplitType(tl_node_id) == tl::SplitFeatureType::kCategorical &&
tree.MatchingCategories(tl_node_id).size() > 0;
bool default_left = tree.DefaultLeft(tl_node_id);
if (tree.SplitType(tl_node_id) == tl::SplitFeatureType::kNumerical) {
split.f = static_cast<float>(tree.Threshold(tl_node_id));
adjust_threshold(&split.f, &tl_left, &tl_right, &default_left, tree.ComparisonOp(tl_node_id));
Expand All @@ -813,13 +814,18 @@ conversion_state<fil_node_t> tl2fil_inner_node(int fil_left_child,
std::swap(tl_left, tl_right);
default_left = !default_left;
}
int sizeof_mask = cat_sets->accessor().sizeof_mask(feature_id);
split.idx = *bit_pool_offset;
*bit_pool_offset += sizeof_mask;
// cat_sets->bits have been zero-initialized
uint8_t* bits = &cat_sets->bits[split.idx];
for (std::uint32_t category : tree.MatchingCategories(tl_node_id)) {
bits[category / BITS_PER_BYTE] |= 1 << (category % BITS_PER_BYTE);
if (tree.MatchingCategories(tl_node_id).size() > 0) {
int sizeof_mask = cat_sets->accessor().sizeof_mask(feature_id);
split.idx = *bit_pool_offset;
*bit_pool_offset += sizeof_mask;
// cat_sets->bits have been zero-initialized
uint8_t* bits = &cat_sets->bits[split.idx];
for (std::uint32_t category : tree.MatchingCategories(tl_node_id)) {
bits[category / BITS_PER_BYTE] |= 1 << (category % BITS_PER_BYTE);
}
} else {
// always branch left in FIL. Already accounted for Treelite branching direction above.
split.f = NAN;
}
} else {
ASSERT(false, "only numerical and categorical split nodes are supported");
Expand Down
18 changes: 16 additions & 2 deletions cpp/src/fil/internal.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ const int FIL_TPB = 256;

constexpr std::int32_t MAX_PRECISE_INT_FLOAT = 1 << 24; // 16'777'216

__host__ __device__ __forceinline__ int fetch_bit(const uint8_t* array, int bit)
__host__ __device__ __forceinline__ int fetch_bit(const uint8_t* array, uint32_t bit)
{
return (array[bit / BITS_PER_BYTE] >> (bit % BITS_PER_BYTE)) & 1;
}
Expand Down Expand Up @@ -345,7 +345,13 @@ struct categorical_sets {
// features with similar categorical feature count, we may consider
// storing node ID within nodes with same feature ID and look up
// {.max_matching, .first_node_offset} = ...[feature_id]
return category <= max_matching[node.fid()] && fetch_bit(bits + node.set(), category);

/* category < 0 is equivalent to out-of-dictionary category (not matching, branch left)
importing ensured no categorical nodes (hence, features) with max_matching == -1 (or,
consequently, max_matching < 0)
*/
return static_cast<uint32_t>(category) <= static_cast<uint32_t>(max_matching[node.fid()]) &&
levsnv marked this conversation as resolved.
Show resolved Hide resolved
fetch_bit(bits + node.set(), category);
wphicks marked this conversation as resolved.
Show resolved Hide resolved
}
static int sizeof_mask_from_max_matching(int max_matching)
{
Expand All @@ -372,6 +378,14 @@ struct tree_base {
if (isnan(val)) {
cond = !node.def_left();
} else if (CATS_SUPPORTED && node.is_categorical()) {
/* cannot cast float directly to uint32_t since C++ standard doesn't mandate two's complement
in that case:
http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4296.pdf
levsnv marked this conversation as resolved.
Show resolved Hide resolved
4.9 Floating-integral conversions [conv.fpint]
1 A prvalue of a floating point type can be converted to a prvalue of an integer type. The
conversion truncates; that is, the fractional part is discarded. The behavior is undefined
if the truncated value cannot be represented in the destination type.
*/
cond = cat_sets.category_matches(node, static_cast<int>(val));
levsnv marked this conversation as resolved.
Show resolved Hide resolved
levsnv marked this conversation as resolved.
Show resolved Hide resolved
} else {
cond = val >= node.thresh();
Expand Down
15 changes: 11 additions & 4 deletions cpp/test/sg/fil_child_index_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,18 @@ std::vector<ChildIndexTestParams> params = {
CHILD_INDEX_TEST_PARAMS(parent_node_idx = 4, input = NAN, correct = 10), // !def_left
CHILD_INDEX_TEST_PARAMS(
node = NODE(def_left = true), input = NAN, parent_node_idx = 4, correct = 9), // !def_left
// cannot match ( < 0 and realistic max_matching)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {},
cso.max_matching = {10},
input = -5,
correct = 1),
// Skipping category < 0 and dummy categorical node: max_matching == -1. Prevented by FIL import.
// cannot match ( > max_matching)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {},
cso.max_matching = {-1},
input = 0,
cso.max_matching = {1},
input = 2,
correct = 1),
// does not match (bits[category] == 0, category == 0)
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
Expand All @@ -228,7 +235,7 @@ std::vector<ChildIndexTestParams> params = {
// matches
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
cso.bits = {0b0000'0101},
cso.max_matching = {2, -1},
cso.max_matching = {2, 0},
input = 2,
correct = 2),
// does not match (bits[category] == 0, category > 0)
Expand All @@ -241,7 +248,7 @@ std::vector<ChildIndexTestParams> params = {
CHILD_INDEX_TEST_PARAMS(node = NODE(is_categorical = true),
levsnv marked this conversation as resolved.
Show resolved Hide resolved
node.fid = 1,
cso.bits = {0b0000'0101},
cso.max_matching = {2, -1},
cso.max_matching = {2, 0},
input = 2,
correct = 1),
};
Expand Down
7 changes: 4 additions & 3 deletions cpp/test/sg/fil_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ struct replace_some_floating_with_categorical {
{
int max_matching_cat = max_matching_cat_d[data_idx % num_cols];
if (max_matching_cat == -1) return data;
return roundf((data * 0.5f + 0.5f) * max_matching_cat);
// also test invalid (negative) categories
return roundf((data * 0.5f + 0.5f) * max_matching_cat - 1.0);
levsnv marked this conversation as resolved.
Show resolved Hide resolved
}
};

Expand Down Expand Up @@ -305,8 +306,8 @@ class BaseFilTest : public testing::TestWithParam<FilTestParams> {
for (int fid = 0; fid < ps.num_cols; ++fid) {
feature_categorical[fid] = fc(gen);
if (feature_categorical[fid]) {
// even for some categorical features, we will have no matching categories
float mm = pow(10, mmc(gen)) - 1.0f;
// categorical features will never have max_matching == -1
float mm = pow(10, mmc(gen));
ASSERT(mm < INT_MAX,
"internal error: max_magnitude_of_matching_cat %f is too large",
ps.max_magnitude_of_matching_cat);
Expand Down