From 761c06f85c6a16c51742e42df9c096ea0a9869e1 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 16 Apr 2024 13:47:15 -0700 Subject: [PATCH 1/7] Prevent undefined behavior when passing handle from Treelite to cuML FIL --- python/cuml/fil/fil.pyx | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index 16413b34ac..e7a1820a02 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -65,6 +65,8 @@ cdef extern from "treelite/c_api.h": TreeliteModelHandle* out) except + cdef int TreeliteSerializeModelToFile(TreeliteModelHandle handle, const char* filename) except + + cdef int TreeliteDeserializeModelFromBytes(const char* bytes_seq, size_t len, + TreeliteModelHandle* out) except + cdef int TreeliteGetHeaderField( TreeliteModelHandle model, const char * name, TreelitePyBufferFrame* out_frame) except + cdef const char* TreeliteGetLastError() @@ -164,6 +166,26 @@ cdef class TreeliteModel(): cdef uintptr_t model_ptr = model_handle TreeliteFreeModel( model_ptr) + @classmethod + def from_treelite_bytes(cls, bytes_seq): + """ + Returns a TreeliteModel object loaded from bytes representing a + serialized Treelite model object. + + Parameters + ---------- + bytes_seq: bytes + bytes representing a serialized Treelite model + """ + cdef TreeliteModelHandle handle + res = TreeliteDeserializeModelFromBytes(bytes_seq, len(bytes_seq), &handle) + if res < 0: + err = TreeliteGetLastError() + raise RuntimeError("Failed to load Treelite model from bytes (%s)" % (err)) + model = TreeliteModel() + model.set_handle(handle) + return model + @classmethod def from_filename(cls, filename, model_type="xgboost"): """ @@ -882,8 +904,13 @@ class ForestInference(Base, " parameters. Accuracy may degrade slightly relative to" " native sklearn invocation.") tl_model = tl_skl.import_model(skl_model) + # Serialize Treelite model object and de-serialize again, + # to get around C++ ABI incompatibilities (due to different compilers + # being used to build cuML pip wheel vs. Treelite pip wheel) + bytes_seq = tl_model.serialize_bytes() + tl_model2 = TreeliteModel.from_treelite_bytes(bytes_seq) cuml_fm.load_from_treelite_model( - model=tl_model, + model=tl_model2, output_class=output_class, threshold=threshold, algo=algo, From a314c1af159d2d1f5c4607005b1eb167665a6ccf Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 16 Apr 2024 13:54:30 -0700 Subject: [PATCH 2/7] Re-enable FIL tests --- ci/test_wheel.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/test_wheel.sh b/ci/test_wheel.sh index a0adddc6b7..86eef035cd 100755 --- a/ci/test_wheel.sh +++ b/ci/test_wheel.sh @@ -27,7 +27,7 @@ rapids-logger "pytest cuml single GPU" ./ci/run_cuml_singlegpu_pytests.sh \ --numprocesses=8 \ --dist=worksteal \ - -k 'not test_sparse_pca_inputs and not test_fil_skl_classification' \ + -k 'not test_sparse_pca_inputs' \ --junitxml="${RAPIDS_TESTS_DIR}/junit-cuml.xml" # Run test_sparse_pca_inputs separately From 86079d811000a6ef790e89e3209094f188c36fe1 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Tue, 16 Apr 2024 14:54:46 -0700 Subject: [PATCH 3/7] Apply suggestions from code review Co-authored-by: jakirkham --- python/cuml/fil/fil.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index e7a1820a02..d028d65795 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -167,7 +167,7 @@ cdef class TreeliteModel(): TreeliteFreeModel( model_ptr) @classmethod - def from_treelite_bytes(cls, bytes_seq): + def from_treelite_bytes(cls, bytes bytes_seq): """ Returns a TreeliteModel object loaded from bytes representing a serialized Treelite model object. @@ -907,7 +907,7 @@ class ForestInference(Base, # Serialize Treelite model object and de-serialize again, # to get around C++ ABI incompatibilities (due to different compilers # being used to build cuML pip wheel vs. Treelite pip wheel) - bytes_seq = tl_model.serialize_bytes() + cdef bytes bytes_seq = tl_model.serialize_bytes() tl_model2 = TreeliteModel.from_treelite_bytes(bytes_seq) cuml_fm.load_from_treelite_model( model=tl_model2, From 9d8f158222884174d1c62b9cfc2b740237a23c23 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 16 Apr 2024 16:35:04 -0700 Subject: [PATCH 4/7] Add more typing Co-authored-by: jakirkham --- python/cuml/fil/fil.pyx | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index d028d65795..63ddb99930 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -178,11 +178,12 @@ cdef class TreeliteModel(): bytes representing a serialized Treelite model """ cdef TreeliteModelHandle handle - res = TreeliteDeserializeModelFromBytes(bytes_seq, len(bytes_seq), &handle) + cdef int res = TreeliteDeserializeModelFromBytes(bytes_seq, len(bytes_seq), &handle) + cdef const char* err if res < 0: err = TreeliteGetLastError() raise RuntimeError("Failed to load Treelite model from bytes (%s)" % (err)) - model = TreeliteModel() + cdef TreeliteModel model = TreeliteModel() model.set_handle(handle) return model @@ -202,6 +203,8 @@ cdef class TreeliteModel(): filename_bytes = filename.encode("UTF-8") config_bytes = "{}".encode("UTF-8") cdef TreeliteModelHandle handle + cdef int res + cdef const char* err if model_type == "xgboost": res = TreeliteLoadXGBoostModelLegacyBinary(filename_bytes, config_bytes, &handle) if res < 0: @@ -222,7 +225,7 @@ cdef class TreeliteModel(): raise RuntimeError("Failed to load %s (%s)" % (filename, err)) else: raise ValueError("Unknown model type %s" % model_type) - model = TreeliteModel() + cdef TreeliteModel model = TreeliteModel() model.set_handle(handle) return model @@ -237,7 +240,11 @@ cdef class TreeliteModel(): """ assert self.handle != NULL filename_bytes = filename.encode("UTF-8") - TreeliteSerializeModelToFile(self.handle, filename_bytes) + cdef int res = TreeliteSerializeModelToFile(self.handle, filename_bytes) + cdef const char* err + if res < 0: + err = TreeliteGetLastError() + raise RuntimeError("Failed to serialize Treelite model (%s)" % (err)) @classmethod def from_treelite_model_handle(cls, @@ -908,7 +915,7 @@ class ForestInference(Base, # to get around C++ ABI incompatibilities (due to different compilers # being used to build cuML pip wheel vs. Treelite pip wheel) cdef bytes bytes_seq = tl_model.serialize_bytes() - tl_model2 = TreeliteModel.from_treelite_bytes(bytes_seq) + cdef TreeliteModel tl_model2 = TreeliteModel.from_treelite_bytes(bytes_seq) cuml_fm.load_from_treelite_model( model=tl_model2, output_class=output_class, From 0e15e167dd285b3fb5ab4ce78f2291dd0b19b805 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 16 Apr 2024 20:18:18 -0700 Subject: [PATCH 5/7] Improve formatting for errors Co-authored-by: jakirkham --- python/cuml/fil/fil.pyx | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index 63ddb99930..11331d1b9d 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -179,10 +179,10 @@ cdef class TreeliteModel(): """ cdef TreeliteModelHandle handle cdef int res = TreeliteDeserializeModelFromBytes(bytes_seq, len(bytes_seq), &handle) - cdef const char* err + cdef str err_msg if res < 0: - err = TreeliteGetLastError() - raise RuntimeError("Failed to load Treelite model from bytes (%s)" % (err)) + err_msg = TreeliteGetLastError().decode("UTF-8") + raise RuntimeError(f"Failed to load Treelite model from bytes ({err_msg})") cdef TreeliteModel model = TreeliteModel() model.set_handle(handle) return model @@ -200,31 +200,31 @@ cdef class TreeliteModel(): model_type : string Type of model: 'xgboost', 'xgboost_json', or 'lightgbm' """ - filename_bytes = filename.encode("UTF-8") - config_bytes = "{}".encode("UTF-8") + cdef bytes filename_bytes = filename.encode("UTF-8") + cdef bytes config_bytes = b"{}" cdef TreeliteModelHandle handle cdef int res - cdef const char* err + cdef str err_msg if model_type == "xgboost": res = TreeliteLoadXGBoostModelLegacyBinary(filename_bytes, config_bytes, &handle) if res < 0: - err = TreeliteGetLastError() - raise RuntimeError("Failed to load %s (%s)" % (filename, err)) + err_msg = TreeliteGetLastError().decode("UTF-8") + raise RuntimeError(f"Failed to load {filename} ({err_msg})") elif model_type == "xgboost_json": res = TreeliteLoadXGBoostModel(filename_bytes, config_bytes, &handle) if res < 0: - err = TreeliteGetLastError() - raise RuntimeError("Failed to load %s (%s)" % (filename, err)) + err_msg = TreeliteGetLastError().decode("UTF-8") + raise RuntimeError(f"Failed to load {filename} ({err_msg})") elif model_type == "lightgbm": logger.warn("Treelite currently does not support float64 model" " parameters. Accuracy may degrade slightly relative" " to native LightGBM invocation.") res = TreeliteLoadLightGBMModel(filename_bytes, config_bytes, &handle) if res < 0: - err = TreeliteGetLastError() - raise RuntimeError("Failed to load %s (%s)" % (filename, err)) + err_msg = TreeliteGetLastError().decode("UTF-8") + raise RuntimeError(f"Failed to load {filename} ({err_msg})") else: - raise ValueError("Unknown model type %s" % model_type) + raise ValueError(f"Unknown model type {model_type}") cdef TreeliteModel model = TreeliteModel() model.set_handle(handle) return model @@ -241,10 +241,10 @@ cdef class TreeliteModel(): assert self.handle != NULL filename_bytes = filename.encode("UTF-8") cdef int res = TreeliteSerializeModelToFile(self.handle, filename_bytes) - cdef const char* err + cdef str err_msg if res < 0: - err = TreeliteGetLastError() - raise RuntimeError("Failed to serialize Treelite model (%s)" % (err)) + err_msg = TreeliteGetLastError().decode("UTF-8") + raise RuntimeError("Failed to serialize Treelite model (%s)" % (err_msg)) @classmethod def from_treelite_model_handle(cls, @@ -543,10 +543,11 @@ cdef class ForestInference_impl(): &treelite_params) # Get num_class cdef TreelitePyBufferFrame frame - res = TreeliteGetHeaderField( model_ptr, "num_class", &frame) + cdef int res = TreeliteGetHeaderField( model_ptr, "num_class", &frame) + cdef str err_msg if res < 0: - err = TreeliteGetLastError() - raise RuntimeError(f"Failed to fetch num_class: {err}") + err_msg = TreeliteGetLastError().decode("UTF-8") + raise RuntimeError(f"Failed to fetch num_class: {err_msg}") view = memoryview(MakePyBufferFrameWrapper(frame)) self.num_class = np.asarray(view).copy() if len(self.num_class) > 1: From fb90c973c5074571e945b4e79e5e49a2a2a20582 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 16 Apr 2024 21:39:00 -0700 Subject: [PATCH 6/7] Apply the same fix to experimental FIL --- python/cuml/experimental/fil/fil.pyx | 26 ++++++++++++++++---------- python/cuml/fil/fil.pyx | 2 +- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/python/cuml/experimental/fil/fil.pyx b/python/cuml/experimental/fil/fil.pyx index 5057b22529..a256114e5e 100644 --- a/python/cuml/experimental/fil/fil.pyx +++ b/python/cuml/experimental/fil/fil.pyx @@ -55,6 +55,9 @@ nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator) cdef extern from "treelite/c_api.h": ctypedef void* TreeliteModelHandle + cdef int TreeliteDeserializeModelFromBytes(const char* bytes_seq, size_t len, + TreeliteModelHandle* out) except + + cdef const char* TreeliteGetLastError() cdef raft_proto_device_t get_device_type(arr): @@ -137,16 +140,19 @@ cdef class ForestInference_impl(): use_double_precision_bool = use_double_precision use_double_precision_c = use_double_precision_bool - try: - model_handle = tl_model.handle.value - except AttributeError: - try: - model_handle = tl_model.handle - except AttributeError: - try: - model_handle = tl_model.value - except AttributeError: - model_handle = tl_model + if not isinstance(tl_model, treelite.Model): + raise ValueError("tl_model must be a treelite.Model object") + # Serialize Treelite model object and de-serialize again, + # to get around C++ ABI incompatibilities (due to different compilers + # being used to build cuML pip wheel vs. Treelite pip wheel) + bytes_seq = tl_model.serialize_bytes() + cdef TreeliteModelHandle model_handle = NULL + cdef int res = TreeliteDeserializeModelFromBytes(bytes_seq, len(bytes_seq), + &model_handle) + cdef str err_msg + if res < 0: + err_msg = TreeliteGetLastError().decode("UTF-8") + raise RuntimeError(f"Failed to load Treelite model from bytes ({err_msg})") cdef raft_proto_device_t dev_type if mem_type.is_device_accessible: diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index 11331d1b9d..170f18992e 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -244,7 +244,7 @@ cdef class TreeliteModel(): cdef str err_msg if res < 0: err_msg = TreeliteGetLastError().decode("UTF-8") - raise RuntimeError("Failed to serialize Treelite model (%s)" % (err_msg)) + raise RuntimeError(f"Failed to serialize Treelite model ({err_msg})") @classmethod def from_treelite_model_handle(cls, From 6297d8f8b3fad866c57796d60e7abb848635ae0f Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 17 Apr 2024 14:36:36 -0700 Subject: [PATCH 7/7] Prevent memory leak --- python/cuml/experimental/fil/fil.pyx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/cuml/experimental/fil/fil.pyx b/python/cuml/experimental/fil/fil.pyx index a256114e5e..7fe59e43a1 100644 --- a/python/cuml/experimental/fil/fil.pyx +++ b/python/cuml/experimental/fil/fil.pyx @@ -57,6 +57,7 @@ cdef extern from "treelite/c_api.h": ctypedef void* TreeliteModelHandle cdef int TreeliteDeserializeModelFromBytes(const char* bytes_seq, size_t len, TreeliteModelHandle* out) except + + cdef int TreeliteFreeModel(TreeliteModelHandle handle) except + cdef const char* TreeliteGetLastError() @@ -175,6 +176,8 @@ cdef class ForestInference_impl(): self.raft_proto_handle.get_next_usable_stream() ) + TreeliteFreeModel(model_handle) + def get_dtype(self): return [np.float32, np.float64][self.model.is_double_precision()]