diff --git a/CHANGES.md b/CHANGES.md index 579cbdff..3dbb7e03 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,11 @@ +3.2.1 (TBD) +=========== + +- Fixed a bug where an unsupported library would be detected because it shares a common + prefix with one of the supported libraries. Now the symbols are also checked to + identify the supported libraries. + https://github.com/joblib/threadpoolctl/pull/151 + 3.2.0 (2023-07-13) ================== diff --git a/continuous_integration/install_with_blis.sh b/continuous_integration/install_with_blis.sh index 6463ac86..89984f1e 100755 --- a/continuous_integration/install_with_blis.sh +++ b/continuous_integration/install_with_blis.sh @@ -35,6 +35,7 @@ popd # build & install numpy git clone https://github.com/numpy/numpy.git pushd numpy +git checkout v1.26.0 # pin numpy < 2 for now git submodule update --init echo "[blis] libraries = blis diff --git a/tests/_pyMylib/__init__.py b/tests/_pyMylib/__init__.py index d9b60c3e..af2867e5 100644 --- a/tests/_pyMylib/__init__.py +++ b/tests/_pyMylib/__init__.py @@ -19,6 +19,14 @@ class MyThreadedLibController(LibController): # instance. filename_prefixes = ("my_threaded_lib",) + # (Optional) Symbols that the linked library is expected to expose. It is used along + # with the `filename_prefixes` to make sure that the correct library is identified. + check_symbols = ( + "mylib_get_num_threads", + "mylib_set_num_threads", + "mylib_get_version", + ) + def get_num_threads(self): # This function should return the current maximum number of threads, # which is reported as "num_threads" by `ThreadpoolController.info`. diff --git a/threadpoolctl.py b/threadpoolctl.py index 55b4decc..2e231101 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -157,7 +157,18 @@ class OpenBLASController(LibController): user_api = "blas" internal_api = "openblas" filename_prefixes = ("libopenblas", "libblas") - check_symbols = ("openblas_get_num_threads", "openblas_get_num_threads64_") + check_symbols = ( + "openblas_get_num_threads", + "openblas_get_num_threads64_", + "openblas_set_num_threads", + "openblas_set_num_threads64_", + "openblas_get_config", + "openblas_get_config64_", + "openblas_get_parallel", + "openblas_get_parallel64_", + "openblas_get_corename", + "openblas_get_corename64_", + ) def set_additional_attributes(self): self.threading_layer = self._get_threading_layer() @@ -237,7 +248,15 @@ class BLISController(LibController): user_api = "blas" internal_api = "blis" filename_prefixes = ("libblis", "libblas") - check_symbols = ("bli_thread_get_num_threads",) + check_symbols = ( + "bli_thread_get_num_threads", + "bli_thread_set_num_threads", + "bli_info_get_version_str", + "bli_info_get_enable_openmp", + "bli_info_get_enable_pthreads", + "bli_arch_query_id", + "bli_arch_string", + ) def set_additional_attributes(self): self.threading_layer = self._get_threading_layer() @@ -266,9 +285,9 @@ def get_version(self): def _get_threading_layer(self): """Return the threading layer of BLIS""" - if self.dynlib.bli_info_get_enable_openmp(): + if getattr(self.dynlib, "bli_info_get_enable_openmp", lambda: False)(): return "openmp" - elif self.dynlib.bli_info_get_enable_pthreads(): + elif getattr(self.dynlib, "bli_info_get_enable_pthreads", lambda: False)(): return "pthreads" return "disabled" @@ -292,7 +311,12 @@ class MKLController(LibController): user_api = "blas" internal_api = "mkl" filename_prefixes = ("libmkl_rt", "mkl_rt", "libblas") - check_symbols = ("MKL_Get_Max_Threads",) + check_symbols = ( + "MKL_Get_Max_Threads", + "MKL_Set_Num_Threads", + "MKL_Get_Version_String", + "MKL_Set_Threading_Layer", + ) def set_additional_attributes(self): self.threading_layer = self._get_threading_layer() @@ -343,6 +367,10 @@ class OpenMPController(LibController): user_api = "openmp" internal_api = "openmp" filename_prefixes = ("libiomp", "libgomp", "libomp", "vcomp") + check_symbols = ( + "omp_get_max_threads", + "omp_get_num_threads", + ) def get_num_threads(self): get_func = getattr(self.dynlib, "omp_get_max_threads", lambda: None) @@ -978,11 +1006,17 @@ def _make_controller_from_path(self, filepath): # duplicate entry in threadpool_info. continue - # filename matches a prefix. Create and store the library + # filename matches a prefix. Now we check if the library has the symbols we + # are looking for. If none of the symbols exists, it's very likely not the + # expected library (e.g. a library having a common prefix with one of the + # our supported libraries). Otherwise, create and store the library # controller. - lib_controller = controller_class(filepath=filepath, prefix=prefix) - self.lib_controllers.append(lib_controller) + if not hasattr(controller_class, "check_symbols") or any( + hasattr(lib_controller.dynlib, func) + for func in controller_class.check_symbols + ): + self.lib_controllers.append(lib_controller) def _check_prefix(self, library_basename, filename_prefixes): """Return the prefix library_basename starts with