Skip to content

Commit

Permalink
Correctly handle missing categorical data in experimental FIL (#6132)
Browse files Browse the repository at this point in the history
Correctly handle missing categorical data in FIL when 0 is one of the included categories.

Resolve #5578.

Authors:
  - William Hicks (https://github.com/wphicks)

Approvers:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #6132
  • Loading branch information
wphicks authored Nov 8, 2024
1 parent 9150279 commit 22c3ee8
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
2 changes: 1 addition & 1 deletion cpp/include/cuml/experimental/fil/detail/evaluate_tree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ HOST DEVICE auto evaluate_tree_impl(node_t const* __restrict__ node,
if (cur_node.is_categorical()) {
auto valid_categories = categorical_set_type{
&cur_node.index(), uint32_t(sizeof(typename node_t::index_type) * 8)};
condition = valid_categories.test(input_val);
condition = valid_categories.test(input_val) && !isnan(input_val);
} else {
condition = (input_val < cur_node.threshold());
}
Expand Down
50 changes: 49 additions & 1 deletion python/cuml/cuml/tests/experimental/test_filex.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -870,3 +870,51 @@ def test_apply(train_device, infer_device, n_classes, tmp_path):
expected_shape = (n_rows, num_boost_round * n_classes)
assert pred_leaf.shape == expected_shape
np.testing.assert_equal(pred_leaf, expected_pred_leaf)


def test_missing_categorical():
    """Check that FIL matches Treelite GTIL on a missing categorical value.

    A single-tree model is assembled by hand: the root performs a
    categorical test on feature 0 with category list ``[0, 2]`` and
    ``default_left=False``; its children (keys 1 and 2) are leaves with
    outputs 1.0 and 2.0. A single NaN sample is then scored with both
    GTIL and FIL, and the two predictions must be equal.
    """
    tl_builder = treelite.model_builder

    # Model-level configuration, constructed in the same order as the
    # keyword arguments are consumed by ModelBuilder.
    metadata = tl_builder.Metadata(
        num_feature=1,
        task_type="kBinaryClf",
        average_tree_output=False,
        num_target=1,
        num_class=[1],
        leaf_vector_shape=(1, 1),
    )
    annotation = tl_builder.TreeAnnotation(
        num_tree=1, target_id=[0], class_id=[0]
    )
    postprocessor = tl_builder.PostProcessorFunc(name="identity")

    builder = tl_builder.ModelBuilder(
        threshold_type="float32",
        leaf_output_type="float32",
        metadata=metadata,
        tree_annotation=annotation,
        postprocessor=postprocessor,
        base_scores=[0.0],
    )

    builder.start_tree()

    # Root node: categorical split whose category list includes 0.
    builder.start_node(0)
    builder.categorical_test(
        feature_id=0,
        category_list=[0, 2],
        default_left=False,
        category_list_right_child=False,
        left_child_key=1,
        right_child_key=2,
    )
    builder.end_node()

    # Leaf nodes: key 1 outputs 1.0, key 2 outputs 2.0.
    for node_key, leaf_value in ((1, 1.0), (2, 2.0)):
        builder.start_node(node_key)
        builder.leaf(leaf_value)
        builder.end_node()

    builder.end_tree()
    model = builder.commit()

    # One sample whose only feature is missing.
    nan_sample = np.array([[np.nan]])
    expected = treelite.gtil.predict(model, nan_sample)
    forest = ForestInference.load_from_treelite_model(model)
    actual = np.asarray(forest.predict(nan_sample))
    np.testing.assert_equal(actual.flatten(), expected.flatten())

0 comments on commit 22c3ee8

Please sign in to comment.