Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent undefined behavior when passing handle from Treelite to cuML FIL #5849

Merged
merged 9 commits into from
Apr 20, 2024
2 changes: 1 addition & 1 deletion ci/test_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ rapids-logger "pytest cuml single GPU"
./ci/run_cuml_singlegpu_pytests.sh \
--numprocesses=8 \
--dist=worksteal \
-k 'not test_sparse_pca_inputs and not test_fil_skl_classification' \
-k 'not test_sparse_pca_inputs' \
--junitxml="${RAPIDS_TESTS_DIR}/junit-cuml.xml"

# Run test_sparse_pca_inputs separately
Expand Down
29 changes: 19 additions & 10 deletions python/cuml/experimental/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)

# C API declarations from Treelite's public header that this module calls.
cdef extern from "treelite/c_api.h":
# Opaque pointer type used to refer to a Treelite model across the C API.
ctypedef void* TreeliteModelHandle
# Builds a model from a serialized byte sequence of length `len` and stores
# the resulting handle in `out`. Callers in this file treat a negative
# return value as failure.
cdef int TreeliteDeserializeModelFromBytes(const char* bytes_seq, size_t len,
TreeliteModelHandle* out) except +
# Called in this file on a handle once it is no longer needed
# (presumably releases the underlying model -- confirm against Treelite docs).
cdef int TreeliteFreeModel(TreeliteModelHandle handle) except +
# Callers decode the result as UTF-8 and embed it in the error they raise
# after a failed Treelite call.
cdef const char* TreeliteGetLastError()


cdef raft_proto_device_t get_device_type(arr):
Expand Down Expand Up @@ -137,16 +141,19 @@ cdef class ForestInference_impl():
use_double_precision_bool = use_double_precision
use_double_precision_c = use_double_precision_bool

try:
model_handle = tl_model.handle.value
except AttributeError:
try:
model_handle = tl_model.handle
except AttributeError:
try:
model_handle = tl_model.value
except AttributeError:
model_handle = tl_model
if not isinstance(tl_model, treelite.Model):
raise ValueError("tl_model must be a treelite.Model object")
# Serialize Treelite model object and de-serialize again,
# to get around C++ ABI incompatibilities (due to different compilers
# being used to build cuML pip wheel vs. Treelite pip wheel)
bytes_seq = tl_model.serialize_bytes()
cdef TreeliteModelHandle model_handle = NULL
cdef int res = TreeliteDeserializeModelFromBytes(bytes_seq, len(bytes_seq),
&model_handle)
cdef str err_msg
if res < 0:
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to load Treelite model from bytes ({err_msg})")

cdef raft_proto_device_t dev_type
if mem_type.is_device_accessible:
Expand All @@ -169,6 +176,8 @@ cdef class ForestInference_impl():
self.raft_proto_handle.get_next_usable_stream()
)

TreeliteFreeModel(model_handle)

def get_dtype(self):
return [np.float32, np.float64][self.model.is_double_precision()]

Expand Down
65 changes: 50 additions & 15 deletions python/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ cdef extern from "treelite/c_api.h":
TreeliteModelHandle* out) except +
cdef int TreeliteSerializeModelToFile(TreeliteModelHandle handle,
const char* filename) except +
cdef int TreeliteDeserializeModelFromBytes(const char* bytes_seq, size_t len,
TreeliteModelHandle* out) except +
cdef int TreeliteGetHeaderField(
TreeliteModelHandle model, const char * name, TreelitePyBufferFrame* out_frame) except +
cdef const char* TreeliteGetLastError()
Expand Down Expand Up @@ -164,6 +166,27 @@ cdef class TreeliteModel():
cdef uintptr_t model_ptr = <uintptr_t>model_handle
TreeliteFreeModel(<TreeliteModelHandle> model_ptr)

@classmethod
def from_treelite_bytes(cls, bytes bytes_seq):
    """
    Returns a TreeliteModel object loaded from bytes representing a
    serialized Treelite model object.

    Parameters
    ----------
    bytes_seq: bytes
        bytes representing a serialized Treelite model

    Returns
    -------
    model : TreeliteModel
        New object wrapping the handle of the deserialized model

    Raises
    ------
    RuntimeError
        If Treelite returns a negative status code while deserializing
        ``bytes_seq``; the message includes Treelite's last error string
    """
    # Start from NULL so the handle never holds an indeterminate value,
    # even if deserialization fails before writing to it.
    cdef TreeliteModelHandle handle = NULL
    cdef int res = TreeliteDeserializeModelFromBytes(bytes_seq, len(bytes_seq), &handle)
    cdef str err_msg
    if res < 0:
        err_msg = TreeliteGetLastError().decode("UTF-8")
        raise RuntimeError(f"Failed to load Treelite model from bytes ({err_msg})")
    cdef TreeliteModel model = TreeliteModel()
    model.set_handle(handle)
    return model

@classmethod
def from_filename(cls, filename, model_type="xgboost"):
"""
Expand All @@ -177,30 +200,32 @@ cdef class TreeliteModel():
model_type : string
Type of model: 'xgboost', 'xgboost_json', or 'lightgbm'
"""
filename_bytes = filename.encode("UTF-8")
config_bytes = "{}".encode("UTF-8")
cdef bytes filename_bytes = filename.encode("UTF-8")
cdef bytes config_bytes = b"{}"
cdef TreeliteModelHandle handle
cdef int res
cdef str err_msg
if model_type == "xgboost":
res = TreeliteLoadXGBoostModelLegacyBinary(filename_bytes, config_bytes, &handle)
if res < 0:
err = TreeliteGetLastError()
raise RuntimeError("Failed to load %s (%s)" % (filename, err))
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to load {filename} ({err_msg})")
elif model_type == "xgboost_json":
res = TreeliteLoadXGBoostModel(filename_bytes, config_bytes, &handle)
if res < 0:
err = TreeliteGetLastError()
raise RuntimeError("Failed to load %s (%s)" % (filename, err))
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to load {filename} ({err_msg})")
elif model_type == "lightgbm":
logger.warn("Treelite currently does not support float64 model"
" parameters. Accuracy may degrade slightly relative"
" to native LightGBM invocation.")
res = TreeliteLoadLightGBMModel(filename_bytes, config_bytes, &handle)
if res < 0:
err = TreeliteGetLastError()
raise RuntimeError("Failed to load %s (%s)" % (filename, err))
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to load {filename} ({err_msg})")
else:
raise ValueError("Unknown model type %s" % model_type)
model = TreeliteModel()
raise ValueError(f"Unknown model type {model_type}")
cdef TreeliteModel model = TreeliteModel()
model.set_handle(handle)
return model

Expand All @@ -215,7 +240,11 @@ cdef class TreeliteModel():
"""
assert self.handle != NULL
filename_bytes = filename.encode("UTF-8")
TreeliteSerializeModelToFile(self.handle, filename_bytes)
cdef int res = TreeliteSerializeModelToFile(self.handle, filename_bytes)
cdef str err_msg
if res < 0:
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to serialize Treelite model ({err_msg})")

@classmethod
def from_treelite_model_handle(cls,
Expand Down Expand Up @@ -514,10 +543,11 @@ cdef class ForestInference_impl():
&treelite_params)
# Get num_class
cdef TreelitePyBufferFrame frame
res = TreeliteGetHeaderField(<TreeliteModelHandle> model_ptr, "num_class", &frame)
cdef int res = TreeliteGetHeaderField(<TreeliteModelHandle> model_ptr, "num_class", &frame)
cdef str err_msg
if res < 0:
err = TreeliteGetLastError()
raise RuntimeError(f"Failed to fetch num_class: {err}")
err_msg = TreeliteGetLastError().decode("UTF-8")
raise RuntimeError(f"Failed to fetch num_class: {err_msg}")
view = memoryview(MakePyBufferFrameWrapper(frame))
self.num_class = np.asarray(view).copy()
if len(self.num_class) > 1:
Expand Down Expand Up @@ -882,8 +912,13 @@ class ForestInference(Base,
" parameters. Accuracy may degrade slightly relative to"
" native sklearn invocation.")
tl_model = tl_skl.import_model(skl_model)
# Serialize Treelite model object and de-serialize again,
# to get around C++ ABI incompatibilities (due to different compilers
# being used to build cuML pip wheel vs. Treelite pip wheel)
cdef bytes bytes_seq = tl_model.serialize_bytes()
cdef TreeliteModel tl_model2 = TreeliteModel.from_treelite_bytes(bytes_seq)
cuml_fm.load_from_treelite_model(
model=tl_model,
model=tl_model2,
output_class=output_class,
threshold=threshold,
algo=algo,
Expand Down
Loading