From 4b9bfd2beaa949df21c4dfaa0cbf7a2f3bde3243 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 16 Sep 2021 15:51:43 +0200 Subject: [PATCH 01/18] wip --- .../nested_prange_blas.pyx | 8 +- tests/test_threadpoolctl.py | 200 ++++----- threadpoolctl.py | 395 +++++++++--------- 3 files changed, 299 insertions(+), 304 deletions(-) diff --git a/tests/_openmp_test_helper/nested_prange_blas.pyx b/tests/_openmp_test_helper/nested_prange_blas.pyx index e327eee0..3746d2db 100644 --- a/tests/_openmp_test_helper/nested_prange_blas.pyx +++ b/tests/_openmp_test_helper/nested_prange_blas.pyx @@ -20,7 +20,7 @@ IF USE_BLIS: ELSE: from scipy.linalg.cython_blas cimport dgemm -from threadpoolctl import _ThreadpoolInfo, _ALL_USER_APIS +from threadpoolctl import ThreadpoolController def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads): @@ -43,12 +43,12 @@ def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads): int prange_num_threads int *prange_num_threads_ptr = &prange_num_threads - inner_info = [None] + inner_ctl = [None] with nogil, parallel(num_threads=nthreads): if openmp.omp_get_thread_num() == 0: with gil: - inner_info[0] = _ThreadpoolInfo(user_api=_ALL_USER_APIS) + inner_ctl[0] = ThreadpoolController() prange_num_threads_ptr[0] = openmp.omp_get_num_threads() @@ -62,4 +62,4 @@ def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads): &alpha, &B[0, 0], &k, &A[i * chunk_size, 0], &k, &beta, &C[i * chunk_size, 0], &n) - return np.asarray(C), prange_num_threads, inner_info[0] + return np.asarray(C), prange_num_threads, inner_ctl[0] diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index 544d86e3..a17677e2 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -5,7 +5,8 @@ import subprocess import sys -from threadpoolctl import threadpool_limits, threadpool_info, _ThreadpoolInfo +from threadpoolctl import threadpool_limits, threadpool_info +from 
threadpoolctl import ThreadpoolController from threadpoolctl import _ALL_PREFIXES, _ALL_USER_APIS from .utils import cython_extensions_compiled @@ -14,10 +15,10 @@ from .utils import threadpool_info_from_subprocess -def is_old_openblas(module): +def is_old_openblas(libctl): # Possible bug in getting maximum number of threads with OpenBLAS < 0.2.16 # and OpenBLAS does not expose its version before 0.3.4. - return module.internal_api == "openblas" and module.version is None + return libctl.internal_api == "openblas" and libctl.version is None def effective_num_threads(nthreads, max_threads): @@ -26,130 +27,128 @@ def effective_num_threads(nthreads, max_threads): return nthreads -def _threadpool_info(): - # Like threadpool_info but return the object instead of the list of dicts - return _ThreadpoolInfo(user_api=_ALL_USER_APIS) +# def _threadpool_info(): +# # Like threadpool_info but return the object instead of the list of dicts +# return ThreadpoolController() -def test_threadpool_limits_public_api(): - # Check consistency between threadpool_info and _ThreadpoolInfo - public_info = threadpool_info() - private_info = _threadpool_info() +def test_threadpool_info(): + # Check consistency between threadpool_info and ThreadpoolController + function_info = threadpool_info() + object_info = ThreadpoolController().lib_controllers - for module1, module2 in zip(public_info, private_info): - assert module1 == module2.todict() + for libctl1, libctl2 in zip(function_info, object_info): + assert libctl1 == libctl2.todict() -def test_ThreadpoolInfo_todicts(): - # Check all keys expected for the public api are in the dicts returned by - # the .todict(s) methods - info = _threadpool_info() +def test_ThreadpoolController_todicts(): + # Check that all keys expected for the private api are in the dicts + # returned by the todict(s) methods + controller = ThreadpoolController() - assert threadpool_info() == [module.todict() for module in info.modules] - assert info.todicts() == 
[module.todict() for module in info] - assert info.todicts() == [module.todict() for module in info.modules] + assert threadpool_info() == [libctl.todict() for libctl in controller] + assert controller.todicts() == [libctl.todict() for libctl in controller] - for module in info: - module_dict = module.todict() - assert "user_api" in module_dict - assert "internal_api" in module_dict - assert "prefix" in module_dict - assert "filepath" in module_dict - assert "version" in module_dict - assert "num_threads" in module_dict + for libctl_dict in controller.todicts(): + assert "user_api" in libctl_dict + assert "internal_api" in libctl_dict + assert "prefix" in libctl_dict + assert "filepath" in libctl_dict + assert "version" in libctl_dict + assert "num_threads" in libctl_dict - if module.internal_api in ("mkl", "blis", "openblas"): - assert "threading_layer" in module_dict + if libctl_dict["internal_api"] in ("mkl", "blis", "openblas"): + assert "threading_layer" in libctl_dict @pytest.mark.parametrize("prefix", _ALL_PREFIXES) @pytest.mark.parametrize("limit", [1, 3]) def test_threadpool_limits_by_prefix(prefix, limit): # Check that the maximum number of threads can be set by prefix - original_info = _threadpool_info() + original_ctl = ThreadpoolController() - modules_matching_prefix = original_info.get_modules("prefix", prefix) - if not modules_matching_prefix: + libctl_matching_prefix = original_ctl.select(prefix=prefix) + if not libctl_matching_prefix: pytest.skip("Requires {} runtime".format(prefix)) with threadpool_limits(limits={prefix: limit}): - for module in modules_matching_prefix: - if is_old_openblas(module): + for libctl in libctl_matching_prefix: + if is_old_openblas(libctl): continue # threadpool_limits only sets an upper bound on the number of # threads. 
- assert 0 < module.get_num_threads() <= limit - assert _threadpool_info() == original_info + assert 0 < libctl._get_num_threads() <= limit + assert ThreadpoolController() == original_ctl @pytest.mark.parametrize("user_api", (None, "blas", "openmp")) @pytest.mark.parametrize("limit", [1, 3]) def test_set_threadpool_limits_by_api(user_api, limit): # Check that the maximum number of threads can be set by user_api - original_info = _threadpool_info() + original_ctl = ThreadpoolController() - modules_matching_api = original_info.get_modules("user_api", user_api) - if not modules_matching_api: + libctl_matching_api = original_ctl.select(user_api=user_api) + if not libctl_matching_api: user_apis = _ALL_USER_APIS if user_api is None else [user_api] pytest.skip("Requires a library which api is in {}".format(user_apis)) with threadpool_limits(limits=limit, user_api=user_api): - for module in modules_matching_api: - if is_old_openblas(module): + for libctl in libctl_matching_api: + if is_old_openblas(libctl): continue # threadpool_limits only sets an upper bound on the number of # threads. - assert 0 < module.get_num_threads() <= limit + assert 0 < libctl._get_num_threads() <= limit - assert _threadpool_info() == original_info + assert ThreadpoolController() == original_ctl def test_threadpool_limits_function_with_side_effect(): # Check that threadpool_limits can be used as a function with # side effects instead of a context manager. - original_info = _threadpool_info() + original_ctl = ThreadpoolController() threadpool_limits(limits=1) try: - for module in _threadpool_info(): - if is_old_openblas(module): + for libctl in ThreadpoolController(): + if is_old_openblas(libctl): continue - assert module.num_threads == 1 + assert libctl.num_threads == 1 finally: # Restore the original limits so that this test does not have any # side-effect. 
- threadpool_limits(limits=original_info) + threadpool_limits(limits=original_ctl) - assert _threadpool_info() == original_info + assert ThreadpoolController() == original_ctl def test_set_threadpool_limits_no_limit(): # Check that limits=None does nothing. - original_info = _threadpool_info() + original_ctl = ThreadpoolController() with threadpool_limits(limits=None): - assert _threadpool_info() == original_info + assert ThreadpoolController() == original_ctl - assert _threadpool_info() == original_info + assert ThreadpoolController() == original_ctl def test_threadpool_limits_manual_unregister(): # Check that threadpool_limits can be used as an object which holds the # original state of the threadpools and that can be restored thanks to the # dedicated unregister method - original_info = _threadpool_info() + original_ctl = ThreadpoolController() limits = threadpool_limits(limits=1) try: - for module in _threadpool_info(): - if is_old_openblas(module): + for libctl in ThreadpoolController(): + if is_old_openblas(libctl): continue - assert module.num_threads == 1 + assert libctl.num_threads == 1 finally: # Restore the original limits so that this test does not have any # side-effect. limits.unregister() - assert _threadpool_info() == original_info + assert ThreadpoolController() == original_ctl def test_threadpool_limits_bad_input(): @@ -205,11 +204,11 @@ def test_openmp_nesting(nthreads_outer): # There are 2 openmp, the one from inner and the one from outer. assert len(outer_info) == 2 # We already know the one from inner. It has to be the other one. 
- prefixes = {module["prefix"] for module in outer_info} + prefixes = {lib_info["prefix"] for lib_info in outer_info} outer_omp = prefixes - {inner_omp} outer_num_threads, inner_num_threads = check_nested_openmp_loops(10) - original_info = _threadpool_info() + original_ctl = ThreadpoolController() if inner_omp == outer_omp: # The OpenMP runtime should be shared by default, meaning that the @@ -227,7 +226,7 @@ def test_openmp_nesting(nthreads_outer): # The state of the original state of all threadpools should have been # restored. - assert _threadpool_info() == original_info + assert ThreadpoolController() == original_ctl # The number of threads available in the outer loop should not have been # decreased: @@ -247,15 +246,15 @@ def test_openmp_nesting(nthreads_outer): def test_shipped_openblas(): # checks that OpenBLAS effectively uses the number of threads requested by # the context manager - original_info = _threadpool_info() + original_ctl = ThreadpoolController() - openblas_modules = original_info.get_modules("internal_api", "openblas") + openblas_controllers = original_ctl.select(internal_api="openblas") with threadpool_limits(1): - for module in openblas_modules: - assert module.get_num_threads() == 1 + for libctl in openblas_controllers: + assert libctl._get_num_threads() == 1 - assert original_info == _threadpool_info() + assert original_ctl == ThreadpoolController() @pytest.mark.skipif(len(libopenblas_paths) < 2, @@ -279,17 +278,18 @@ def test_nested_prange_blas(nthreads_outer): import tests._openmp_test_helper.nested_prange_blas as prange_blas check_nested_prange_blas = prange_blas.check_nested_prange_blas - original_info = _threadpool_info() + original_ctl = ThreadpoolController() - blas_info = original_info.get_modules("user_api", "blas") - blis_info = original_info.get_modules("internal_api", "blis") + blas_controllers = original_ctl.select(user_api="blas") + blis_controllers = original_ctl.select(internal_api="blis") # skip if the BLAS used by numpy is 
an old openblas. OpenBLAS 0.3.3 and # older are known to cause an unrecoverable deadlock at process shutdown # time (after pytest has exited). # numpy can be linked to BLIS for CBLAS and OpenBLAS for LAPACK. In that # case this test will run BLIS gemm so no need to skip. - if not blis_info and any(is_old_openblas(module) for module in blas_info): + if (not blis_controllers and + any(is_old_openblas(libctl) for libctl in blas_controllers)): pytest.skip("Old OpenBLAS: skipping test to avoid deadlock") A = np.ones((1000, 10)) @@ -300,17 +300,17 @@ def test_nested_prange_blas(nthreads_outer): nthreads = effective_num_threads(nthreads_outer, max_threads) result = check_nested_prange_blas(A, B, nthreads) - C, prange_num_threads, inner_info = result + C, prange_num_threads, inner_ctl = result assert np.allclose(C, np.dot(A, B.T)) assert prange_num_threads == nthreads - nested_blas_info = inner_info.get_modules("user_api", "blas") - assert len(nested_blas_info) == len(blas_info) - for module in nested_blas_info: - assert module.num_threads == 1 + nested_blas_controllers = inner_ctl.select(user_api="blas") + assert len(nested_blas_controllers) == len(blas_controllers) + for libctl in nested_blas_controllers: + assert libctl.num_threads == 1 - assert original_info == _threadpool_info() + assert original_ctl == ThreadpoolController() # the method `get_original_num_threads` raises a UserWarning due to different @@ -323,18 +323,18 @@ def test_get_original_num_threads(limit): # Tests the method get_original_num_threads of the context manager with threadpool_limits(limits=2, user_api="blas") as ctl: # set different blas num threads to start with (when multiple openblas) - if ctl._original_info: - ctl._original_info.modules[0].set_num_threads(1) + if ctl._controller: + ctl._controller.lib_controllers[0]._set_num_threads(1) - original_info = _threadpool_info() + original_ctl = ThreadpoolController() with threadpool_limits(limits=limit, user_api="blas") as threadpoolctx: 
original_num_threads = threadpoolctx.get_original_num_threads() assert "openmp" not in original_num_threads - blas_info = original_info.get_modules("user_api", "blas") - if blas_info: - expected = min(module.num_threads for module in blas_info) + blas_ctl = original_ctl.select(user_api="blas") + if blas_ctl: + expected = min(libctl.num_threads for libctl in blas_ctl) assert original_num_threads["blas"] == expected else: assert original_num_threads["blas"] is None @@ -347,30 +347,30 @@ def test_get_original_num_threads(limit): def test_mkl_threading_layer(): # Check that threadpool_info correctly recovers the threading layer used # by mkl - mkl_info = _threadpool_info().get_modules("internal_api", "mkl") + mkl_ctl = ThreadpoolController().select(internal_api="mkl") expected_layer = os.getenv("MKL_THREADING_LAYER") - if not (mkl_info and expected_layer): + if not (mkl_ctl and expected_layer): pytest.skip("requires MKL and the environment variable " "MKL_THREADING_LAYER set") - actual_layer = mkl_info.modules[0].threading_layer + actual_layer = mkl_ctl.lib_controllers[0].threading_layer assert actual_layer == expected_layer.lower() def test_blis_threading_layer(): # Check that threadpool_info correctly recovers the threading layer used # by blis - blis_info = _threadpool_info().get_modules("internal_api", "blis") + blis_ctl = ThreadpoolController().select(internal_api="blis") expected_layer = os.getenv("BLIS_ENABLE_THREADING") if expected_layer == "no": expected_layer = "disabled" - if not (blis_info and expected_layer): + if not (blis_ctl and expected_layer): pytest.skip("requires BLIS and the environment variable " "BLIS_ENABLE_THREADING set") - actual_layer = blis_info.modules[0].threading_layer + actual_layer = blis_ctl.lib_controllers[0].threading_layer assert actual_layer == expected_layer @@ -385,8 +385,8 @@ def test_libomp_libiomp_warning(recwarn): # Check that a warning is raised when both libomp and libiomp are loaded # It should happen in one CI job 
(pylatest_conda_mkl_clang_gcc). - info = _threadpool_info() - prefixes = [module.prefix for module in info] + ctl = ThreadpoolController() + prefixes = [libctl.prefix for libctl in ctl] if not ("libomp" in prefixes and "libiomp" in prefixes and sys.platform == "linux"): @@ -413,8 +413,8 @@ def test_command_line_command_flag(): cli_info = json.loads(output.decode("utf-8")) this_process_info = threadpool_info() - for module in cli_info: - assert module in this_process_info + for lib_info in cli_info: + assert lib_info in this_process_info @pytest.mark.skipif(sys.version_info < (3, 7), @@ -430,8 +430,8 @@ def test_command_line_import_flag(): cli_info = json.loads(result.stdout) this_process_info = threadpool_info() - for module in cli_info: - assert module in this_process_info + for lib_info in cli_info: + assert lib_info in this_process_info warnings = [w.strip() for w in result.stderr.splitlines()] assert "WARNING: could not import invalid_package" in warnings @@ -455,11 +455,11 @@ def test_architecture(): "skx", "haswell", ) - for module in threadpool_info(): - if module["internal_api"] == "openblas": - assert module["architecture"] in expected_openblas_architectures - elif module["internal_api"] == "blis": - assert module["architecture"] in expected_blis_architectures + for lib_info in threadpool_info(): + if lib_info["internal_api"] == "openblas": + assert lib_info["architecture"] in expected_openblas_architectures + elif lib_info["internal_api"] == "blis": + assert lib_info["architecture"] in expected_blis_architectures else: - # Not supported for other modules - assert "architecture" not in module + # Not supported for other libraries + assert "architecture" not in lib_info diff --git a/threadpoolctl.py b/threadpoolctl.py index 9aefe515..4df5bb18 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -59,26 +59,26 @@ class _dl_phdr_info(ctypes.Structure): # List of the supported libraries. 
The items are indexed by the name of the -# class to instanciate to create the module objects. The items hold the -# possible prefixes of loaded shared objects, the name of the internal_api to -# call and the name of the user_api. -_SUPPORTED_MODULES = { - "_OpenMPModule": { +# class to instanciate to create the library controller objects. The items hold +# the possible prefixes of loaded shared objects, the name of the internal_api +# to call and the name of the user_api. +_SUPPORTED_LIBRARIES = { + "OpenMPController": { "user_api": "openmp", "internal_api": "openmp", "filename_prefixes": ("libiomp", "libgomp", "libomp", "vcomp") }, - "_OpenBLASModule": { + "OpenBLASController": { "user_api": "blas", "internal_api": "openblas", "filename_prefixes": ("libopenblas",) }, - "_MKLModule": { + "MKLController": { "user_api": "blas", "internal_api": "mkl", "filename_prefixes": ("libmkl_rt", "mkl_rt") }, - "_BLISModule": { + "BLISController": { "user_api": "blas", "internal_api": "blis", "filename_prefixes": ("libblis",) @@ -86,14 +86,17 @@ class _dl_phdr_info(ctypes.Structure): } # Helpers for the doc and test names -_ALL_USER_APIS = list(set(m["user_api"] for m in _SUPPORTED_MODULES.values())) -_ALL_INTERNAL_APIS = [m["internal_api"] for m in _SUPPORTED_MODULES.values()] -_ALL_PREFIXES = [prefix for m in _SUPPORTED_MODULES.values() - for prefix in m["filename_prefixes"]] -_ALL_BLAS_LIBRARIES = [m["internal_api"] for m in _SUPPORTED_MODULES.values() - if m["user_api"] == "blas"] +_ALL_USER_APIS = list( + set(lib["user_api"] for lib in _SUPPORTED_LIBRARIES.values())) +_ALL_INTERNAL_APIS = [lib["internal_api"] for lib + in _SUPPORTED_LIBRARIES.values()] +_ALL_PREFIXES = [prefix for lib in _SUPPORTED_LIBRARIES.values() + for prefix in lib["filename_prefixes"]] +_ALL_BLAS_LIBRARIES = [lib["internal_api"] for lib + in _SUPPORTED_LIBRARIES.values() + if lib["user_api"] == "blas"] _ALL_OPENMP_LIBRARIES = list( - _SUPPORTED_MODULES["_OpenMPModule"]["filename_prefixes"]) + 
_SUPPORTED_LIBRARIES["OpenMPController"]["filename_prefixes"]) def _format_docstring(*args, **kwargs): @@ -110,19 +113,19 @@ def decorator(o): def threadpool_info(): """Return the maximal number of threads for each detected library. - Return a list with all the supported modules that have been found. Each - module is represented by a dict with the following information: + Return a list with all the supported libraries that have been found. Each + library is represented by a dict with the following information: - "user_api" : user API. Possible values are {USER_APIS}. - "internal_api": internal API. Possible values are {INTERNAL_APIS}. - "prefix" : filename prefix of the specific implementation. - - "filepath": path to the loaded module. + - "filepath": path to the loaded library. - "version": version of the library (if available). - "num_threads": the current thread limit. - In addition, each module may contain internal_api specific entries. + In addition, each library may contain internal_api specific entries. """ - return _ThreadpoolInfo(user_api=_ALL_USER_APIS).todicts() + return ThreadpoolController().todicts() @_format_docstring( @@ -164,12 +167,20 @@ class threadpool_limits: by the BLAS libraries if they rely on OpenMP. - If None, this function will apply to all supported libraries. + + controller : instance of ``ThreadpoolController`` or None (default=None) + The threadpool controller to use. If None, a new controller is created. 
""" - def __init__(self, limits=None, user_api=None): + def __init__(self, limits=None, user_api=None, controller=None): self._limits, self._user_api, self._prefixes = \ self._check_params(limits, user_api) - self._original_info = self._set_threadpool_limits() + if controller is not None: + self._controller = controller + else: + self._controller = ThreadpoolController() + + self._set_threadpool_limits() def __enter__(self): return self @@ -178,26 +189,22 @@ def __exit__(self, type, value, traceback): self.unregister() def unregister(self): - if self._original_info is not None: - for module in self._original_info: - module.set_num_threads(module.num_threads) + for libctl in self._controller.lib_controllers: + # Since we never call get_num_threads after instanciation of + # ThreadpoolController, num_threads holds the original value. + libctl._set_num_threads(libctl.num_threads) def get_original_num_threads(self): """Original num_threads from before calling threadpool_limits Return a dict `{user_api: num_threads}`. """ - if self._original_info is not None: - original_limits = self._original_info - else: - original_limits = _ThreadpoolInfo(user_api=self._user_api) - num_threads = {} warning_apis = [] for user_api in self._user_api: - limits = [module.num_threads for module in - original_limits.get_modules("user_api", user_api)] + limits = [libctl.num_threads for libctl in + self._controller.select(user_api=user_api)] limits = set(limits) n_limits = len(limits) @@ -236,22 +243,20 @@ def _check_params(self, limits, user_api): prefixes = [] else: if isinstance(limits, list): - # This should be a list of dicts of modules, for compatibility - # with the result from threadpool_info. - limits = {module["prefix"]: module["num_threads"] - for module in limits} - elif isinstance(limits, _ThreadpoolInfo): - # To set the limits from the modules of a _ThreadpoolInfo - # object. 
- limits = {module.prefix: module.num_threads - for module in limits} + # This should be a list of dicts of library controllers, for + # compatibility with the result from threadpool_info. + limits = {ctl["prefix"]: ctl["num_threads"] for ctl in limits} + elif isinstance(limits, ThreadpoolController): + # To set the limits from the library controllers of a + # ThreadpoolController object. + limits = {ctl.prefix: ctl.num_threads for ctl in limits} if not isinstance(limits, dict): raise TypeError("limits must either be an int, a list or a " "dict. Got {} instead".format(type(limits))) - # With a dictionary, can set both specific limit for given modules - # and global limit for user_api. Fetch each separately. + # With a dictionary, can set both specific limit for given + # libraries and global limit for user_api. Fetch each separately. prefixes = [prefix for prefix in limits if prefix in _ALL_PREFIXES] user_api = [api for api in limits if api in _ALL_USER_APIS] @@ -260,65 +265,47 @@ def _check_params(self, limits, user_api): def _set_threadpool_limits(self): """Change the maximal number of threads in selected thread pools. - Return a list with all the supported modules that have been found + Return a list with all the supported libraries that have been found matching `self._prefixes` and `self._user_api`. """ if self._limits is None: return None - modules = _ThreadpoolInfo(prefixes=self._prefixes, - user_api=self._user_api) - for module in modules: + for libctl in self._controller.lib_controllers: # self._limits is a dict {key: num_threads} where key is either - # a prefix or a user_api. If a module matches both, the limit - # corresponding to the prefix is chosed. - if module.prefix in self._limits: - num_threads = self._limits[module.prefix] + # a prefix or a user_api. If a library matches both, the limit + # corresponding to the prefix is chosen. 
+ if libctl.prefix in self._limits: + num_threads = self._limits[libctl.prefix] + elif libctl.user_api in self._limits: + num_threads = self._limits[libctl.user_api] else: - num_threads = self._limits[module.user_api] + continue if num_threads is not None: - module.set_num_threads(num_threads) - return modules + libctl._set_num_threads(num_threads) -# The object oriented API of _ThreadpoolInfo and its modules is private. -# The public API (i.e. the "threadpool_info" function) only exposes the -# "list of dicts" representation returned by the .todicts method. @_format_docstring( PREFIXES=", ".join('"{}"'.format(prefix) for prefix in _ALL_PREFIXES), USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS), BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES), OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES)) -class _ThreadpoolInfo(): - """Collection of all supported modules that have been found +class ThreadpoolController(): + """Collection of LibController objects for all loaded supported libraries Parameters ---------- - user_api : list of user APIs or None (default=None) - Select libraries matching the requested API. Ignored if `modules` is - not None. Supported user APIs are {USER_APIS}. - - - "blas" selects all BLAS supported libraries ({BLAS_LIBS}) - - "openmp" selects all OpenMP supported libraries ({OPENMP_LIBS}) - - If None, libraries are not selected by their `user_api`. - - prefixes : list of prefixes or None (default=None) - Select libraries matching the requested prefixes. Supported prefixes - are {PREFIXES}. - If None, libraries are not selected by their prefix. Ignored if - `modules` is not None. - - modules : list of _Module objects or None (default=None) - Wraps a list of _Module objects into a _ThreapoolInfo object. Does not - load or reload any shared library. If it is not None, `prefixes` and - `user_api` are ignored. - - Note - ---- - Is is possible to select libraries both by prefixes and by user_api. 
All - libraries matching one or the other will be selected. + controllers : list of ``LibController`` objects or None (default=None) + Wraps a list of library controllers into a ``ThreadpoolController`` + object. Does not load or reload any shared library. + + Attributes + ---------- + controllers : list of ``LibController`` objects + The list of library controllers of all loaded supported libraries that + match the selection from ``user_api`` and ``prefixes`` or provided by + ``controllers``. """ # Cache for libc under POSIX and a few system libraries under Windows. # We use a class level cache instead of an instance level cache because @@ -332,50 +319,59 @@ class _ThreadpoolInfo(): # never change during the lifetime of a program. _realpaths = dict() - def __init__(self, user_api=None, prefixes=None, modules=None): - if modules is None: - self.prefixes = [] if prefixes is None else prefixes - self.user_api = [] if user_api is None else user_api + def __init__(self, lib_controllers=None): + self.lib_controllers = lib_controllers - self.modules = [] - self._load_modules() + if self.lib_controllers is None: + self.lib_controllers = [] + self._load_libraries() self._warn_if_incompatible_openmp() - else: - self.modules = modules - - def get_modules(self, key, values): - """Return all modules such that values contains module[key]""" - if key == "user_api" and values is None: - values = list(_ALL_USER_APIS) - if not isinstance(values, list): - values = [values] - modules = [module for module in self.modules - if getattr(module, key) in values] - return _ThreadpoolInfo(modules=modules) + + @classmethod + def _from_controllers(cls, lib_controllers): + new_controller = cls.__new__(cls) + new_controller.lib_controllers = lib_controllers + return new_controller def todicts(self): """Return info as a list of dicts""" - return [module.todict() for module in self.modules] + return [libctl.todict() for libctl in self.lib_controllers] + + def select(self, **kwargs): + """Return a 
ThreadpoolController containing a subset of its libraries + + kwargs can be any number of pair (key, value) where key is a entry + """ + if not kwargs: + kwargs = {"user_api": _ALL_USER_APIS} + for key, vals in kwargs.items(): + kwargs[key] = [vals] if not isinstance(vals, list) else vals + + lib_controllers = [libctl for libctl in self.lib_controllers + if any(getattr(libctl, key, None) in vals + for key, vals in kwargs.items())] + + return ThreadpoolController._from_controllers(lib_controllers) def __len__(self): - return len(self.modules) + return len(self.lib_controllers) def __iter__(self): - yield from self.modules + yield from self.lib_controllers def __eq__(self, other): - return self.modules == other.modules + return self.lib_controllers == other.lib_controllers - def _load_modules(self): - """Loop through loaded libraries and store supported ones""" + def _load_libraries(self): + """Loop through loaded shared libraries and store the supported ones""" if sys.platform == "darwin": - self._find_modules_with_dyld() + self._find_libraries_with_dyld() elif sys.platform == "win32": - self._find_modules_with_enum_process_module_ex() + self._find_libraries_with_enum_process_module_ex() else: - self._find_modules_with_dl_iterate_phdr() + self._find_libraries_with_dl_iterate_phdr() - def _find_modules_with_dl_iterate_phdr(self): + def _find_libraries_with_dl_iterate_phdr(self): """Loop through loaded libraries and return binders on supported ones This function is expected to work on POSIX system only. @@ -390,26 +386,26 @@ def _find_modules_with_dl_iterate_phdr(self): return [] # Callback function for `dl_iterate_phdr` which is called for every - # module loaded in the current process until it returns 1. - def match_module_callback(info, size, data): - # Get the path of the current module + # library loaded in the current process until it returns 1. 
+ def match_library_callback(info, size, data): + # Get the path of the current library filepath = info.contents.dlpi_name if filepath: filepath = filepath.decode("utf-8") - # Store the module if it is supported and selected - self._make_module_from_path(filepath) + # Store the library controller if it is supported and selected + self._make_controller_from_path(filepath) return 0 c_func_signature = ctypes.CFUNCTYPE( ctypes.c_int, # Return type ctypes.POINTER(_dl_phdr_info), ctypes.c_size_t, ctypes.c_char_p) - c_match_module_callback = c_func_signature(match_module_callback) + c_match_library_callback = c_func_signature(match_library_callback) data = ctypes.c_char_p(b"") - libc.dl_iterate_phdr(c_match_module_callback, data) + libc.dl_iterate_phdr(c_match_library_callback, data) - def _find_modules_with_dyld(self): + def _find_libraries_with_dyld(self): """Loop through loaded libraries and return binders on supported ones This function is expected to work on OSX system only @@ -425,10 +421,10 @@ def _find_modules_with_dyld(self): filepath = ctypes.string_at(libc._dyld_get_image_name(i)) filepath = filepath.decode("utf-8") - # Store the module if it is supported and selected - self._make_module_from_path(filepath) + # Store the library controller if it is supported and selected + self._make_controller_from_path(filepath) - def _find_modules_with_enum_process_module_ex(self): + def _find_libraries_with_enum_process_module_ex(self): """Loop through loaded libraries and return binders on supported ones This function is expected to work on windows system only. 
@@ -440,7 +436,7 @@ def _find_modules_with_enum_process_module_ex(self): PROCESS_QUERY_INFORMATION = 0x0400 PROCESS_VM_READ = 0x0010 - LIST_MODULES_ALL = 0x03 + LIST_LIBRARIES_ALL = 0x03 ps_api = self._get_windll("Psapi") kernel_32 = self._get_windll("kernel32") @@ -461,7 +457,7 @@ def _find_modules_with_enum_process_module_ex(self): buf_size = ctypes.sizeof(buf) if not ps_api.EnumProcessModulesEx( h_process, ctypes.byref(buf), buf_size, - ctypes.byref(needed), LIST_MODULES_ALL): + ctypes.byref(needed), LIST_LIBRARIES_ALL): raise OSError("EnumProcessModulesEx failed") if buf_size >= needed.value: break @@ -470,7 +466,7 @@ def _find_modules_with_enum_process_module_ex(self): count = needed.value // (buf_size // buf_count) h_modules = map(HMODULE, buf[:count]) - # Loop through all the module headers and get the module path + # Loop through all the module headers and get the library path buf = ctypes.create_unicode_buffer(MAX_PATH) n_size = DWORD() for h_module in h_modules: @@ -482,39 +478,41 @@ def _find_modules_with_enum_process_module_ex(self): raise OSError("GetModuleFileNameEx failed") filepath = buf.value - # Store the module if it is supported and selected - self._make_module_from_path(filepath) + # Store the library controller if it is supported and selected + self._make_controller_from_path(filepath) finally: kernel_32.CloseHandle(h_process) - def _make_module_from_path(self, filepath): - """Store a module if it is supported and selected""" + def _make_controller_from_path(self, filepath): + """Store a library controller if it is supported and selected""" # Required to resolve symlinks filepath = self._realpath(filepath) # `lower` required to take account of OpenMP dll case on Windows # (vcomp, VCOMP, Vcomp, ...) filename = os.path.basename(filepath).lower() - # Loop through supported modules to find if this filename corresponds - # to a supported module. 
- for module_class, candidate_module in _SUPPORTED_MODULES.items(): + # Loop through supported libraries to find if this filename corresponds + # to a supported one. + for controller_class, candidate_lib in _SUPPORTED_LIBRARIES.items(): # check if filename matches a supported prefix prefix = self._check_prefix(filename, - candidate_module["filename_prefixes"]) + candidate_lib["filename_prefixes"]) # filename does not match any of the prefixes of the candidate - # module. move to next module. + # library. move to next library. if prefix is None: continue - # filename matches a prefix. Check if it matches the request. If - # so, create and store the module. - user_api = candidate_module["user_api"] - internal_api = candidate_module["internal_api"] - if prefix in self.prefixes or user_api in self.user_api: - module_class = globals()[module_class] - module = module_class(filepath, prefix, user_api, internal_api) - self.modules.append(module) + # filename matches a prefix. Create and store the library + # controller. 
+ user_api = candidate_lib["user_api"] + internal_api = candidate_lib["internal_api"] + + libctl_class = globals()[controller_class] + libctl = libctl_class( + filepath=filepath, prefix=prefix, user_api=user_api, + internal_api=internal_api) + self.lib_controllers.append(libctl) def _check_prefix(self, library_basename, filename_prefixes): """Return the prefix library_basename starts with @@ -532,7 +530,7 @@ def _warn_if_incompatible_openmp(self): # Only raise the warning on linux return - prefixes = [module.prefix for module in self.modules] + prefixes = [libctl.prefix for libctl in self.lib_controllers] msg = textwrap.dedent( """ Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at @@ -584,29 +582,29 @@ def _realpath(cls, filepath, cache_limit=10000): @_format_docstring( USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS), INTERNAL_APIS=", ".join('"{}"'.format(api) for api in _ALL_INTERNAL_APIS)) -class _Module(ABC): - """Abstract base class for the modules +class LibController(ABC): + """Abstract base class for the individual library controllers - A module is represented by the following information: + A library controller is represented by the following information: - "user_api" : user API. Possible values are {USER_APIS}. - "internal_api" : internal API. Possible values are {INTERNAL_APIS}. - "prefix" : prefix of the shared library's name. - - "filepath" : path to the loaded module. + - "filepath" : path to the loaded library. - "version" : version of the library (if available). - "num_threads" : the current thread limit. - In addition, each module may contain internal_api specific entries. + In addition, each library controller may contain internal_api specific + entries. 
""" - def __init__(self, filepath=None, prefix=None, user_api=None, + def __init__(self, *, filepath=None, prefix=None, user_api=None, internal_api=None): - self.filepath = filepath - self.prefix = prefix self.user_api = user_api self.internal_api = internal_api + self.prefix = prefix + self.filepath = filepath self._dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD) - self.version = self.get_version() - self.num_threads = self.get_num_threads() - self._get_extra_info() + self.version = self._get_version() + self.num_threads = self._get_num_threads() def __eq__(self, other): return self.todict() == other.todict() @@ -616,29 +614,29 @@ def todict(self): return {k: v for k, v in vars(self).items() if not k.startswith("_")} @abstractmethod - def get_version(self): + def _get_version(self): """Return the version of the shared library""" pass # pragma: no cover @abstractmethod - def get_num_threads(self): + def _get_num_threads(self): """Return the maximum number of threads available to use""" pass # pragma: no cover @abstractmethod - def set_num_threads(self, num_threads): + def _set_num_threads(self, num_threads): """Set the maximum number of threads to use""" pass # pragma: no cover - @abstractmethod - def _get_extra_info(self): - """Add additional module specific information""" - pass # pragma: no cover +class OpenBLASController(LibController): + """Controller class for OpenBLAS""" + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.threading_layer = self._get_threading_layer() + self.architecture = self._get_architecture() -class _OpenBLASModule(_Module): - """Module class for OpenBLAS""" - def get_version(self): + def _get_version(self): # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS # did not expose its version before that. 
get_config = getattr(self._dynlib, "openblas_get_config", None) @@ -651,21 +649,17 @@ def get_version(self): return config[1].decode("utf-8") return None - def get_num_threads(self): + def _get_num_threads(self): get_func = getattr(self._dynlib, "openblas_get_num_threads", lambda: None) return get_func() - def set_num_threads(self, num_threads): + def _set_num_threads(self, num_threads): set_func = getattr(self._dynlib, "openblas_set_num_threads", lambda num_threads: None) return set_func(num_threads) - def _get_extra_info(self): - self.threading_layer = self.get_threading_layer() - self.architecture = self.get_architecture() - - def get_threading_layer(self): + def _get_threading_layer(self): """Return the threading layer of OpenBLAS""" openblas_get_parallel = getattr(self._dynlib, "openblas_get_parallel", None) @@ -678,7 +672,8 @@ def get_threading_layer(self): return "pthreads" return "disabled" - def get_architecture(self): + def _get_architecture(self): + """Return the architecture detected by OpenBLAS""" get_corename = getattr(self._dynlib, "openblas_get_corename", None) if get_corename is None: return None @@ -687,9 +682,14 @@ def get_architecture(self): return get_corename().decode("utf-8") -class _BLISModule(_Module): - """Module class for BLIS""" - def get_version(self): +class BLISController(LibController): + """Controller class for BLIS""" + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.threading_layer = self._get_threading_layer() + self.architecture = self._get_architecture() + + def _get_version(self): get_version_ = getattr(self._dynlib, "bli_info_get_version_str", None) if get_version_ is None: return None @@ -697,7 +697,7 @@ def get_version(self): get_version_.restype = ctypes.c_char_p return get_version_().decode("utf-8") - def get_num_threads(self): + def _get_num_threads(self): get_func = getattr(self._dynlib, "bli_thread_get_num_threads", lambda: None) num_threads = get_func() @@ -705,16 +705,12 @@ def 
get_num_threads(self): # returns -1. We map it to 1 for consistency with other libraries. return 1 if num_threads == -1 else num_threads - def set_num_threads(self, num_threads): + def _set_num_threads(self, num_threads): set_func = getattr(self._dynlib, "bli_thread_set_num_threads", lambda num_threads: None) return set_func(num_threads) - def _get_extra_info(self): - self.threading_layer = self.get_threading_layer() - self.architecture = self.get_architecture() - - def get_threading_layer(self): + def _get_threading_layer(self): """Return the threading layer of BLIS""" if self._dynlib.bli_info_get_enable_openmp(): return "openmp" @@ -722,7 +718,8 @@ def get_threading_layer(self): return "pthreads" return "disabled" - def get_architecture(self): + def _get_architecture(self): + """Return the architecture detected by BLIS""" bli_arch_query_id = getattr(self._dynlib, "bli_arch_query_id", None) bli_arch_string = getattr(self._dynlib, "bli_arch_string", None) if bli_arch_query_id is None or bli_arch_string is None: @@ -735,9 +732,13 @@ def get_architecture(self): return bli_arch_string(bli_arch_query_id()).decode("utf-8") -class _MKLModule(_Module): - """Module class for MKL""" - def get_version(self): +class MKLController(LibController): + """Controller class for MKL""" + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.threading_layer = self._get_threading_layer() + + def _get_version(self): if not hasattr(self._dynlib, "MKL_Get_Version_String"): return None @@ -750,19 +751,16 @@ def get_version(self): version = group.groups()[0] return version.strip() - def get_num_threads(self): + def _get_num_threads(self): get_func = getattr(self._dynlib, "MKL_Get_Max_Threads", lambda: None) return get_func() - def set_num_threads(self, num_threads): + def _set_num_threads(self, num_threads): set_func = getattr(self._dynlib, "MKL_Set_Num_Threads", lambda num_threads: None) return set_func(num_threads) - def _get_extra_info(self): - self.threading_layer = 
self.get_threading_layer() - - def get_threading_layer(self): + def _get_threading_layer(self): """Return the threading layer of MKL""" # The function mkl_set_threading_layer returns the current threading # layer. Calling it with an invalid threading layer allows us to safely @@ -774,24 +772,21 @@ def get_threading_layer(self): return layer_map[set_threading_layer(-1)] -class _OpenMPModule(_Module): - """Module class for OpenMP""" - def get_version(self): +class OpenMPController(LibController): + """Controller class for OpenMP""" + def _get_version(self): # There is no way to get the version number programmatically in OpenMP. return None - def get_num_threads(self): + def _get_num_threads(self): get_func = getattr(self._dynlib, "omp_get_max_threads", lambda: None) return get_func() - def set_num_threads(self, num_threads): + def _set_num_threads(self, num_threads): set_func = getattr(self._dynlib, "omp_set_num_threads", lambda num_threads: None) return set_func(num_threads) - def _get_extra_info(self): - pass - def _main(): """Commandline interface to display thread-pool information and exit.""" From 7c213f17c5ce3251b0fa9ac04d3e54a0983636e6 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 16 Sep 2021 16:01:37 +0200 Subject: [PATCH 02/18] cln --- threadpoolctl.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/threadpoolctl.py b/threadpoolctl.py index 4df5bb18..622e992d 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -294,18 +294,10 @@ def _set_threadpool_limits(self): class ThreadpoolController(): """Collection of LibController objects for all loaded supported libraries - Parameters - ---------- - controllers : list of ``LibController`` objects or None (default=None) - Wraps a list of library controllers into a ``ThreadpoolController`` - object. Does not load or reload any shared library. 
- Attributes ---------- - controllers : list of ``LibController`` objects - The list of library controllers of all loaded supported libraries that - match the selection from ``user_api`` and ``prefixes`` or provided by - ``controllers``. + lib_controllers : list of ``LibController`` objects + The list of library controllers of all loaded supported libraries. """ # Cache for libc under POSIX and a few system libraries under Windows. # We use a class level cache instead of an instance level cache because @@ -319,13 +311,10 @@ class ThreadpoolController(): # never change during the lifetime of a program. _realpaths = dict() - def __init__(self, lib_controllers=None): - self.lib_controllers = lib_controllers - - if self.lib_controllers is None: - self.lib_controllers = [] - self._load_libraries() - self._warn_if_incompatible_openmp() + def __init__(self): + self.lib_controllers = [] + self._load_libraries() + self._warn_if_incompatible_openmp() @classmethod def _from_controllers(cls, lib_controllers): @@ -338,9 +327,11 @@ def todicts(self): return [libctl.todict() for libctl in self.lib_controllers] def select(self, **kwargs): - """Return a ThreadpoolController containing a subset of its libraries + """Return a ThreadpoolController containing a subset of its current + library controllers kwargs can be any number of pair (key, value) where key is a entry + TODO """ if not kwargs: kwargs = {"user_api": _ALL_USER_APIS} From 4ba0ccda6e8b72a9fc4d4d05fbb4e452d7dabc0f Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 16 Sep 2021 16:58:25 +0200 Subject: [PATCH 03/18] fix select --- threadpoolctl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/threadpoolctl.py b/threadpoolctl.py index 622e992d..e577e0dc 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -333,7 +333,7 @@ def select(self, **kwargs): kwargs can be any number of pair (key, value) where key is a entry TODO """ - if not kwargs: + if not kwargs or all(val is None for val in 
kwargs.values()): kwargs = {"user_api": _ALL_USER_APIS} for key, vals in kwargs.items(): kwargs[key] = [vals] if not isinstance(vals, list) else vals From 7a7a9fd7c57586a7e8f7f949522f372130f67eb6 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 16 Sep 2021 17:48:55 +0200 Subject: [PATCH 04/18] test for select --- tests/test_threadpoolctl.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index a17677e2..0fbb8993 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -27,11 +27,6 @@ def effective_num_threads(nthreads, max_threads): return nthreads -# def _threadpool_info(): -# # Like threadpool_info but return the object instead of the list of dicts -# return ThreadpoolController() - - def test_threadpool_info(): # Check consistency between threadpool_info and ThreadpoolController function_info = threadpool_info() @@ -41,7 +36,7 @@ def test_threadpool_info(): assert libctl1 == libctl2.todict() -def test_ThreadpoolController_todicts(): +def test_threadpool_controller_todicts(): # Check that all keys expected for the private api are in the dicts # returned by the todict(s) methods controller = ThreadpoolController() @@ -61,6 +56,25 @@ def test_ThreadpoolController_todicts(): assert "threading_layer" in libctl_dict +@pytest.mark.parametrize("kwargs", [ + {"user_api": "blas"}, + {"prefix": "libgomp"}, + {"internal_api": "openblas", "prefix": "libomp"}, + {"prefix": ["libgomp", "libomp", "libiomp"]}] +) +def test_threadpool_controller_select(kwargs): + # Check the behior of the select method of ThreadpoolController + controller = ThreadpoolController().select(**kwargs) + if not controller: + pytest.skip(f"Requires at least one of {list(kwargs.values())}.") + + for libctl in controller.lib_controllers: + assert any( + getattr(libctl, key) in (val if isinstance(val, list) else [val]) + for key, val in kwargs.items() + ) + + 
@pytest.mark.parametrize("prefix", _ALL_PREFIXES) @pytest.mark.parametrize("limit", [1, 3]) def test_threadpool_limits_by_prefix(prefix, limit): From 1f30714821ec1385d16e0affd108cccbaa834992 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 16 Sep 2021 18:03:26 +0200 Subject: [PATCH 05/18] add threadpool_limits + controller test --- tests/test_threadpoolctl.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index 0fbb8993..81428082 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -165,6 +165,22 @@ def test_threadpool_limits_manual_unregister(): assert ThreadpoolController() == original_ctl +def test_threadpool_limits_with_controller(): + # Check that threadpool_limits will only act on the libraries contained in + # the controller when provided + original_blas_ctl = ThreadpoolController().select(user_api="blas") + original_openmp_ctl = ThreadpoolController().select(user_api="openmp") + + with threadpool_limits(1, controller=original_blas_ctl): + blas_ctl = ThreadpoolController().select(user_api="blas") + openmp_ctl = ThreadpoolController().select(user_api="openmp") + + assert all(libctl.num_threads == 1 for libctl in blas_ctl) + # the provided controller contains only blas libraries so no opemp + # library should be impacted. + assert openmp_ctl == original_openmp_ctl + + def test_threadpool_limits_bad_input(): # Check that appropriate errors are raised for invalid arguments match = re.escape("user_api must be either in {} or None." 
From 83bbb1819eeb0781a12075154966fefdfffd2be1 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 17 Sep 2021 10:53:22 +0200 Subject: [PATCH 06/18] wip --- benchmarks/bench_context_manager_overhead.py | 18 +- tests/_openmp_test_helper/build_utils.py | 5 +- tests/_openmp_test_helper/setup_inner.py | 17 +- .../setup_nested_prange_blas.py | 21 +- tests/_openmp_test_helper/setup_outer.py | 15 +- tests/test_threadpoolctl.py | 119 ++++++---- tests/utils.py | 16 +- threadpoolctl.py | 211 +++++++++++------- 8 files changed, 251 insertions(+), 171 deletions(-) diff --git a/benchmarks/bench_context_manager_overhead.py b/benchmarks/bench_context_manager_overhead.py index 34c1d92e..d3b69c15 100644 --- a/benchmarks/bench_context_manager_overhead.py +++ b/benchmarks/bench_context_manager_overhead.py @@ -4,12 +4,15 @@ from statistics import mean, stdev from threadpoolctl import threadpool_info, threadpool_limits -parser = ArgumentParser(description='Measure threadpool_limits call overhead.') -parser.add_argument('--import', dest="packages", default=[], nargs='+', - help='Python packages to import to load threadpool enabled' - ' libraries.') -parser.add_argument("--n-calls", type=int, default=100, - help="Number of iterations") +parser = ArgumentParser(description="Measure threadpool_limits call overhead.") +parser.add_argument( + "--import", + dest="packages", + default=[], + nargs="+", + help="Python packages to import to load threadpool enabled libraries.", +) +parser.add_argument("--n-calls", type=int, default=100, help="Number of iterations") args = parser.parse_args() for package_name in args.packages: @@ -24,5 +27,4 @@ pass timings.append(time.time() - t) -print("Overhead per call: {:.3f} +/-{:.3f} ms" - .format(mean(timings) * 1e3, stdev(timings) * 1e3)) +print(f"Overhead per call: {mean(timings) * 1e3:.3f} +/-{stdev(timings) * 1e3:.3f} ms") diff --git a/tests/_openmp_test_helper/build_utils.py b/tests/_openmp_test_helper/build_utils.py index 
06cc52bb..90356b33 100644 --- a/tests/_openmp_test_helper/build_utils.py +++ b/tests/_openmp_test_helper/build_utils.py @@ -7,8 +7,7 @@ def set_cc_variables(var_name="CC"): if cc_var is not None: os.environ["CC"] = cc_var if sys.platform == "darwin": - os.environ["LDSHARED"] = ( - cc_var + " -bundle -undefined dynamic_lookup") + os.environ["LDSHARED"] = cc_var + " -bundle -undefined dynamic_lookup" else: os.environ["LDSHARED"] = cc_var + " -shared" @@ -18,6 +17,6 @@ def set_cc_variables(var_name="CC"): def get_openmp_flag(): if sys.platform == "win32": return ["/openmp"] - elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''): + elif sys.platform == "darwin" and "openmp" in os.getenv("CPPFLAGS", ""): return [] return ["-fopenmp"] diff --git a/tests/_openmp_test_helper/setup_inner.py b/tests/_openmp_test_helper/setup_inner.py index af3e40df..e6103141 100644 --- a/tests/_openmp_test_helper/setup_inner.py +++ b/tests/_openmp_test_helper/setup_inner.py @@ -19,18 +19,21 @@ "openmp_helpers_inner", ["openmp_helpers_inner.pyx"], extra_compile_args=openmp_flag, - extra_link_args=openmp_flag - ) + extra_link_args=openmp_flag, + ) ] setup( - name='_openmp_test_helper_inner', + name="_openmp_test_helper_inner", ext_modules=cythonize( ext_modules, - compiler_directives={'language_level': 3, - 'boundscheck': False, - 'wraparound': False}, - compile_time_env={"CC_INNER_LOOP": inner_loop_cc_var or "unknown"}) + compiler_directives={ + "language_level": 3, + "boundscheck": False, + "wraparound": False, + }, + compile_time_env={"CC_INNER_LOOP": inner_loop_cc_var or "unknown"}, + ), ) finally: diff --git a/tests/_openmp_test_helper/setup_nested_prange_blas.py b/tests/_openmp_test_helper/setup_nested_prange_blas.py index 54da058f..275a92ce 100644 --- a/tests/_openmp_test_helper/setup_nested_prange_blas.py +++ b/tests/_openmp_test_helper/setup_nested_prange_blas.py @@ -12,8 +12,8 @@ set_cc_variables("CC_OUTER_LOOP") openmp_flag = get_openmp_flag() - use_blis = 
os.getenv('INSTALL_BLIS', False) - libraries = ['blis'] if use_blis else [] + use_blis = os.getenv("INSTALL_BLIS", False) + libraries = ["blis"] if use_blis else [] ext_modules = [ Extension( @@ -21,18 +21,21 @@ ["nested_prange_blas.pyx"], extra_compile_args=openmp_flag, extra_link_args=openmp_flag, - libraries=libraries - ) + libraries=libraries, + ) ] setup( - name='_openmp_test_helper_nested_prange_blas', + name="_openmp_test_helper_nested_prange_blas", ext_modules=cythonize( ext_modules, - compile_time_env={'USE_BLIS': use_blis}, - compiler_directives={'language_level': 3, - 'boundscheck': False, - 'wraparound': False}) + compile_time_env={"USE_BLIS": use_blis}, + compiler_directives={ + "language_level": 3, + "boundscheck": False, + "wraparound": False, + }, + ), ) finally: diff --git a/tests/_openmp_test_helper/setup_outer.py b/tests/_openmp_test_helper/setup_outer.py index 56c9bd2a..8f875bdf 100644 --- a/tests/_openmp_test_helper/setup_outer.py +++ b/tests/_openmp_test_helper/setup_outer.py @@ -19,18 +19,21 @@ "openmp_helpers_outer", ["openmp_helpers_outer.pyx"], extra_compile_args=openmp_flag, - extra_link_args=openmp_flag - ) + extra_link_args=openmp_flag, + ) ] setup( name="_openmp_test_helper_outer", ext_modules=cythonize( ext_modules, - compiler_directives={'language_level': 3, - 'boundscheck': False, - 'wraparound': False}, - compile_time_env={"CC_OUTER_LOOP": outer_loop_cc_var or "unknown"}) + compiler_directives={ + "language_level": 3, + "boundscheck": False, + "wraparound": False, + }, + compile_time_env={"CC_OUTER_LOOP": outer_loop_cc_var or "unknown"}, + ), ) finally: diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index 81428082..68a523d8 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -56,11 +56,14 @@ def test_threadpool_controller_todicts(): assert "threading_layer" in libctl_dict -@pytest.mark.parametrize("kwargs", [ - {"user_api": "blas"}, - {"prefix": "libgomp"}, - {"internal_api": 
"openblas", "prefix": "libomp"}, - {"prefix": ["libgomp", "libomp", "libiomp"]}] +@pytest.mark.parametrize( + "kwargs", + [ + {"user_api": "blas"}, + {"prefix": "libgomp"}, + {"internal_api": "openblas", "prefix": "libomp"}, + {"prefix": ["libgomp", "libomp", "libiomp"]}, + ], ) def test_threadpool_controller_select(kwargs): # Check the behior of the select method of ThreadpoolController @@ -83,7 +86,7 @@ def test_threadpool_limits_by_prefix(prefix, limit): libctl_matching_prefix = original_ctl.select(prefix=prefix) if not libctl_matching_prefix: - pytest.skip("Requires {} runtime".format(prefix)) + pytest.skip(f"Requires {prefix} runtime") with threadpool_limits(limits={prefix: limit}): for libctl in libctl_matching_prefix: @@ -104,7 +107,7 @@ def test_set_threadpool_limits_by_api(user_api, limit): libctl_matching_api = original_ctl.select(user_api=user_api) if not libctl_matching_api: user_apis = _ALL_USER_APIS if user_api is None else [user_api] - pytest.skip("Requires a library which api is in {}".format(user_apis)) + pytest.skip(f"Requires a library which api is in {user_apis}") with threadpool_limits(limits=limit, user_api=user_api): for libctl in libctl_matching_api: @@ -183,23 +186,24 @@ def test_threadpool_limits_with_controller(): def test_threadpool_limits_bad_input(): # Check that appropriate errors are raised for invalid arguments - match = re.escape("user_api must be either in {} or None." 
- .format(_ALL_USER_APIS)) + match = re.escape(f"user_api must be either in {_ALL_USER_APIS} or None.") with pytest.raises(ValueError, match=match): threadpool_limits(limits=1, user_api="wrong") - with pytest.raises(TypeError, - match="limits must either be an int, a list or a dict"): + match = "limits must either be an int, a list or a dict" + with pytest.raises(TypeError, match=match): threadpool_limits(limits=(1, 2, 3)) -@pytest.mark.skipif(not cython_extensions_compiled, - reason='Requires cython extensions to be compiled') -@pytest.mark.parametrize('num_threads', [1, 2, 4]) +@pytest.mark.skipif( + not cython_extensions_compiled, reason="Requires cython extensions to be compiled" +) +@pytest.mark.parametrize("num_threads", [1, 2, 4]) def test_openmp_limit_num_threads(num_threads): # checks that OpenMP effectively uses the number of threads requested by # the context manager import tests._openmp_test_helper.openmp_helpers_inner as omp_inner + check_openmp_num_threads = omp_inner.check_openmp_num_threads old_num_threads = check_openmp_num_threads(100) @@ -209,24 +213,28 @@ def test_openmp_limit_num_threads(num_threads): assert check_openmp_num_threads(100) == old_num_threads -@pytest.mark.skipif(not cython_extensions_compiled, - reason="Requires cython extensions to be compiled") +@pytest.mark.skipif( + not cython_extensions_compiled, reason="Requires cython extensions to be compiled" +) @pytest.mark.parametrize("nthreads_outer", [None, 1, 2, 4]) def test_openmp_nesting(nthreads_outer): # checks that OpenMP effectively uses the number of threads requested by # the context manager when nested in an outer OpenMP loop. 
import tests._openmp_test_helper.openmp_helpers_outer as omp_outer + check_nested_openmp_loops = omp_outer.check_nested_openmp_loops # Find which OpenMP lib is used at runtime for inner loop inner_info = threadpool_info_from_subprocess( - "tests._openmp_test_helper.openmp_helpers_inner") + "tests._openmp_test_helper.openmp_helpers_inner" + ) assert len(inner_info) == 1 inner_omp = inner_info[0]["prefix"] # Find which OpenMP lib is used at runtime for outer loop outer_info = threadpool_info_from_subprocess( - "tests._openmp_test_helper.openmp_helpers_outer") + "tests._openmp_test_helper.openmp_helpers_outer" + ) if len(outer_info) == 1: # Only 1 openmp loaded. It has to be this one. outer_omp = outer_info[0]["prefix"] @@ -251,8 +259,7 @@ def test_openmp_nesting(nthreads_outer): # Ask outer loop to run on nthreads threads and inner loop run on 1 # thread - outer_num_threads, inner_num_threads = \ - check_nested_openmp_loops(10, nthreads) + outer_num_threads, inner_num_threads = check_nested_openmp_loops(10, nthreads) # The state of the original state of all threadpools should have been # restored. @@ -268,8 +275,9 @@ def test_openmp_nesting(nthreads_outer): if inner_num_threads != 1: # XXX: this does not always work when nesting independent openmp # implementations. 
See: https://github.com/jeremiedbb/Nested_OpenMP - pytest.xfail("Inner OpenMP num threads was %d instead of 1" - % inner_num_threads) + pytest.xfail( + f"Inner OpenMP num threads was {inner_num_threads} instead of 1" + ) assert inner_num_threads == 1 @@ -287,8 +295,9 @@ def test_shipped_openblas(): assert original_ctl == ThreadpoolController() -@pytest.mark.skipif(len(libopenblas_paths) < 2, - reason="need at least 2 shipped openblas library") +@pytest.mark.skipif( + len(libopenblas_paths) < 2, reason="need at least 2 shipped openblas library" +) def test_multiple_shipped_openblas(): # This redundant test is meant to make it easier to see if the system # has 2 or more active openblas runtimes available just be reading the @@ -297,8 +306,9 @@ def test_multiple_shipped_openblas(): @pytest.mark.skipif(scipy is None, reason="requires scipy") -@pytest.mark.skipif(not cython_extensions_compiled, - reason='Requires cython extensions to be compiled') +@pytest.mark.skipif( + not cython_extensions_compiled, reason="Requires cython extensions to be compiled" +) @pytest.mark.parametrize("nthreads_outer", [None, 1, 2, 4]) def test_nested_prange_blas(nthreads_outer): # Check that the BLAS linked to scipy effectively uses the number of @@ -306,6 +316,7 @@ def test_nested_prange_blas(nthreads_outer): # loop. import numpy as np import tests._openmp_test_helper.nested_prange_blas as prange_blas + check_nested_prange_blas = prange_blas.check_nested_prange_blas original_ctl = ThreadpoolController() @@ -318,8 +329,9 @@ def test_nested_prange_blas(nthreads_outer): # time (after pytest has exited). # numpy can be linked to BLIS for CBLAS and OpenBLAS for LAPACK. In that # case this test will run BLIS gemm so no need to skip. 
- if (not blis_controllers and - any(is_old_openblas(libctl) for libctl in blas_controllers)): + if not blis_controllers and any( + is_old_openblas(libctl) for libctl in blas_controllers + ): pytest.skip("Old OpenBLAS: skipping test to avoid deadlock") A = np.ones((1000, 10)) @@ -381,8 +393,9 @@ def test_mkl_threading_layer(): expected_layer = os.getenv("MKL_THREADING_LAYER") if not (mkl_ctl and expected_layer): - pytest.skip("requires MKL and the environment variable " - "MKL_THREADING_LAYER set") + pytest.skip( + "requires MKL and the environment variable MKL_THREADING_LAYER set" + ) actual_layer = mkl_ctl.lib_controllers[0].threading_layer assert actual_layer == expected_layer.lower() @@ -397,15 +410,17 @@ def test_blis_threading_layer(): expected_layer = "disabled" if not (blis_ctl and expected_layer): - pytest.skip("requires BLIS and the environment variable " - "BLIS_ENABLE_THREADING set") + pytest.skip( + "requires BLIS and the environment variable BLIS_ENABLE_THREADING set" + ) actual_layer = blis_ctl.lib_controllers[0].threading_layer assert actual_layer == expected_layer -@pytest.mark.skipif(not cython_extensions_compiled, - reason='Requires cython extensions to be compiled') +@pytest.mark.skipif( + not cython_extensions_compiled, reason="Requires cython extensions to be compiled" +) def test_libomp_libiomp_warning(recwarn): # Trigger the import of a potentially clang-compiled extension: import tests._openmp_test_helper.openmp_helpers_outer # noqa @@ -418,8 +433,7 @@ def test_libomp_libiomp_warning(recwarn): ctl = ThreadpoolController() prefixes = [libctl.prefix for libctl in ctl] - if not ("libomp" in prefixes and "libiomp" in prefixes and - sys.platform == "linux"): + if not ("libomp" in prefixes and "libiomp" in prefixes and sys.platform == "linux"): pytest.skip("Requires both libomp and libiomp loaded, on Linux") assert len(recwarn) == 1 @@ -431,15 +445,15 @@ def test_libomp_libiomp_warning(recwarn): def test_command_line_empty(): - output = 
subprocess.check_output( - (sys.executable + " -m threadpoolctl").split()) + output = subprocess.check_output((sys.executable + " -m threadpoolctl").split()) assert json.loads(output.decode("utf-8")) == [] def test_command_line_command_flag(): pytest.importorskip("numpy") output = subprocess.check_output( - [sys.executable, "-m", "threadpoolctl", "-c", "import numpy"]) + [sys.executable, "-m", "threadpoolctl", "-c", "import numpy"] + ) cli_info = json.loads(output.decode("utf-8")) this_process_info = threadpool_info() @@ -447,16 +461,25 @@ def test_command_line_command_flag(): assert lib_info in this_process_info -@pytest.mark.skipif(sys.version_info < (3, 7), - reason="need recent subprocess.run options") +@pytest.mark.skipif( + sys.version_info < (3, 7), reason="need recent subprocess.run options" +) def test_command_line_import_flag(): - result = subprocess.run([ - sys.executable, "-m", "threadpoolctl", "-i", - "numpy", - "scipy.linalg", - "invalid_package", - "numpy.invalid_sumodule", - ], capture_output=True, check=True, encoding="utf-8") + result = subprocess.run( + [ + sys.executable, + "-m", + "threadpoolctl", + "-i", + "numpy", + "scipy.linalg", + "invalid_package", + "numpy.invalid_sumodule", + ], + capture_output=True, + check=True, + encoding="utf-8", + ) cli_info = json.loads(result.stdout) this_process_info = threadpool_info() diff --git a/tests/utils.py b/tests/utils.py index a7efeb02..a0043da3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,10 +14,10 @@ try: # make sure the mkl/blas are loaded for test_threadpool_limits import numpy as np + np.dot(np.ones(1000), np.ones(1000)) - libopenblas_patterns.append(os.path.join(np.__path__[0], ".libs", - "libopenblas*")) + libopenblas_patterns.append(os.path.join(np.__path__[0], ".libs", "libopenblas*")) except ImportError: pass @@ -25,19 +25,23 @@ try: import scipy import scipy.linalg # noqa: F401 + scipy.linalg.svd([[1, 2], [3, 4]]) - libopenblas_patterns.append(os.path.join(scipy.__path__[0], 
".libs", - "libopenblas*")) + libopenblas_patterns.append( + os.path.join(scipy.__path__[0], ".libs", "libopenblas*") + ) except ImportError: scipy = None -libopenblas_paths = set(path for pattern in libopenblas_patterns - for path in glob(pattern)) +libopenblas_paths = set( + path for pattern in libopenblas_patterns for path in glob(pattern) +) try: import tests._openmp_test_helper.openmp_helpers_inner # noqa: F401 + cython_extensions_compiled = True except ImportError: cython_extensions_compiled = False diff --git a/threadpoolctl.py b/threadpoolctl.py index e577e0dc..a6636c1b 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod __version__ = "2.3.0.dev0" -__all__ = ["threadpool_limits", "threadpool_info"] +__all__ = ["threadpool_limits", "threadpool_info", "ThreadpoolController"] # One can get runtime errors or even segfaults due to multiple OpenMP libraries @@ -38,16 +38,16 @@ # Structure to cast the info on dynamically loaded library. See # https://linux.die.net/man/3/dl_iterate_phdr for more details. 
-_SYSTEM_UINT = ctypes.c_uint64 if sys.maxsize > 2**32 else ctypes.c_uint32 -_SYSTEM_UINT_HALF = ctypes.c_uint32 if sys.maxsize > 2**32 else ctypes.c_uint16 +_SYSTEM_UINT = ctypes.c_uint64 if sys.maxsize > 2 ** 32 else ctypes.c_uint32 +_SYSTEM_UINT_HALF = ctypes.c_uint32 if sys.maxsize > 2 ** 32 else ctypes.c_uint16 class _dl_phdr_info(ctypes.Structure): _fields_ = [ - ("dlpi_addr", _SYSTEM_UINT), # Base address of object - ("dlpi_name", ctypes.c_char_p), # path to the library - ("dlpi_phdr", ctypes.c_void_p), # pointer on dlpi_headers - ("dlpi_phnum", _SYSTEM_UINT_HALF) # number of elements in dlpi_phdr + ("dlpi_addr", _SYSTEM_UINT), # Base address of object + ("dlpi_name", ctypes.c_char_p), # path to the library + ("dlpi_phdr", ctypes.c_void_p), # pointer on dlpi_headers + ("dlpi_phnum", _SYSTEM_UINT_HALF), # number of elements in dlpi_phdr ] @@ -66,37 +66,41 @@ class _dl_phdr_info(ctypes.Structure): "OpenMPController": { "user_api": "openmp", "internal_api": "openmp", - "filename_prefixes": ("libiomp", "libgomp", "libomp", "vcomp") + "filename_prefixes": ("libiomp", "libgomp", "libomp", "vcomp"), }, "OpenBLASController": { "user_api": "blas", "internal_api": "openblas", - "filename_prefixes": ("libopenblas",) + "filename_prefixes": ("libopenblas",), }, "MKLController": { "user_api": "blas", "internal_api": "mkl", - "filename_prefixes": ("libmkl_rt", "mkl_rt") + "filename_prefixes": ("libmkl_rt", "mkl_rt"), }, "BLISController": { "user_api": "blas", "internal_api": "blis", - "filename_prefixes": ("libblis",) - } + "filename_prefixes": ("libblis",), + }, } # Helpers for the doc and test names -_ALL_USER_APIS = list( - set(lib["user_api"] for lib in _SUPPORTED_LIBRARIES.values())) -_ALL_INTERNAL_APIS = [lib["internal_api"] for lib - in _SUPPORTED_LIBRARIES.values()] -_ALL_PREFIXES = [prefix for lib in _SUPPORTED_LIBRARIES.values() - for prefix in lib["filename_prefixes"]] -_ALL_BLAS_LIBRARIES = [lib["internal_api"] for lib - in _SUPPORTED_LIBRARIES.values() - if 
lib["user_api"] == "blas"] +_ALL_USER_APIS = list(set(lib["user_api"] for lib in _SUPPORTED_LIBRARIES.values())) +_ALL_INTERNAL_APIS = [lib["internal_api"] for lib in _SUPPORTED_LIBRARIES.values()] +_ALL_PREFIXES = [ + prefix + for lib in _SUPPORTED_LIBRARIES.values() + for prefix in lib["filename_prefixes"] +] +_ALL_BLAS_LIBRARIES = [ + lib["internal_api"] + for lib in _SUPPORTED_LIBRARIES.values() + if lib["user_api"] == "blas" +] _ALL_OPENMP_LIBRARIES = list( - _SUPPORTED_LIBRARIES["OpenMPController"]["filename_prefixes"]) + _SUPPORTED_LIBRARIES["OpenMPController"]["filename_prefixes"] +) def _format_docstring(*args, **kwargs): @@ -108,8 +112,7 @@ def decorator(o): return decorator -@_format_docstring(USER_APIS=list(_ALL_USER_APIS), - INTERNAL_APIS=_ALL_INTERNAL_APIS) +@_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS) def threadpool_info(): """Return the maximal number of threads for each detected library. @@ -131,7 +134,8 @@ def threadpool_info(): @_format_docstring( USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS), BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES), - OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES)) + OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES), +) class threadpool_limits: """Change the maximal number of threads that can be used in thread pools. @@ -171,9 +175,11 @@ class threadpool_limits: controller : instance of ``ThreadpoolController`` or None (default=None) The threadpool controller to use. If None, a new controller is created. 
""" + def __init__(self, limits=None, user_api=None, controller=None): - self._limits, self._user_api, self._prefixes = \ - self._check_params(limits, user_api) + self._limits, self._user_api, self._prefixes = self._check_params( + limits, user_api + ) if controller is not None: self._controller = controller @@ -203,8 +209,10 @@ def get_original_num_threads(self): warning_apis = [] for user_api in self._user_api: - limits = [libctl.num_threads for libctl in - self._controller.select(user_api=user_api)] + limits = [ + libctl.num_threads + for libctl in self._controller.select(user_api=user_api) + ] limits = set(limits) n_limits = len(limits) @@ -220,14 +228,15 @@ def get_original_num_threads(self): if warning_apis: warnings.warn( - "Multiple value possible for following user apis: " + - ", ".join(warning_apis) + ". Returning the minimum.") + "Multiple value possible for following user apis: " + + ", ".join(warning_apis) + + ". Returning the minimum." + ) return num_threads def _check_params(self, limits, user_api): - """Suitable values for the _limits, _user_api and _prefixes attributes - """ + """Suitable values for the _limits, _user_api and _prefixes attributes""" if limits is None or isinstance(limits, int): if user_api is None: user_api = _ALL_USER_APIS @@ -235,8 +244,9 @@ def _check_params(self, limits, user_api): user_api = [user_api] else: raise ValueError( - "user_api must be either in {} or None. Got " - "{} instead.".format(_ALL_USER_APIS, user_api)) + f"user_api must be either in {_ALL_USER_APIS} or None. Got " + f"{user_api} instead." + ) if limits is not None: limits = {api: limits for api in user_api} @@ -252,8 +262,10 @@ def _check_params(self, limits, user_api): limits = {ctl.prefix: ctl.num_threads for ctl in limits} if not isinstance(limits, dict): - raise TypeError("limits must either be an int, a list or a " - "dict. Got {} instead".format(type(limits))) + raise TypeError( + "limits must either be an int, a list or a " + f"dict. 
Got {type(limits)} instead" + ) # With a dictionary, can set both specific limit for given # libraries and global limit for user_api. Fetch each separately. @@ -290,8 +302,9 @@ def _set_threadpool_limits(self): PREFIXES=", ".join('"{}"'.format(prefix) for prefix in _ALL_PREFIXES), USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS), BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES), - OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES)) -class ThreadpoolController(): + OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES), +) +class ThreadpoolController: """Collection of LibController objects for all loaded supported libraries Attributes @@ -299,6 +312,7 @@ class ThreadpoolController(): lib_controllers : list of ``LibController`` objects The list of library controllers of all loaded supported libraries. """ + # Cache for libc under POSIX and a few system libraries under Windows. # We use a class level cache instead of an instance level cache because # it's very unlikely that a shared library will be unloaded and reloaded @@ -338,9 +352,11 @@ def select(self, **kwargs): for key, vals in kwargs.items(): kwargs[key] = [vals] if not isinstance(vals, list) else vals - lib_controllers = [libctl for libctl in self.lib_controllers - if any(getattr(libctl, key, None) in vals - for key, vals in kwargs.items())] + lib_controllers = [ + libctl + for libctl in self.lib_controllers + if any(getattr(libctl, key, None) in vals for key, vals in kwargs.items()) + ] return ThreadpoolController._from_controllers(lib_controllers) @@ -390,7 +406,10 @@ def match_library_callback(info, size, data): c_func_signature = ctypes.CFUNCTYPE( ctypes.c_int, # Return type - ctypes.POINTER(_dl_phdr_info), ctypes.c_size_t, ctypes.c_char_p) + ctypes.POINTER(_dl_phdr_info), + ctypes.c_size_t, + ctypes.c_char_p, + ) c_match_library_callback = c_func_signature(match_library_callback) data = ctypes.c_char_p(b"") @@ -433,8 +452,8 @@ def _find_libraries_with_enum_process_module_ex(self): kernel_32 = 
self._get_windll("kernel32") h_process = kernel_32.OpenProcess( - PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, - False, os.getpid()) + PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, False, os.getpid() + ) if not h_process: # pragma: no cover raise OSError("Could not open PID %s" % os.getpid()) @@ -447,8 +466,12 @@ def _find_libraries_with_enum_process_module_ex(self): buf = (HMODULE * buf_count)() buf_size = ctypes.sizeof(buf) if not ps_api.EnumProcessModulesEx( - h_process, ctypes.byref(buf), buf_size, - ctypes.byref(needed), LIST_LIBRARIES_ALL): + h_process, + ctypes.byref(buf), + buf_size, + ctypes.byref(needed), + LIST_LIBRARIES_ALL, + ): raise OSError("EnumProcessModulesEx failed") if buf_size >= needed.value: break @@ -464,8 +487,8 @@ def _find_libraries_with_enum_process_module_ex(self): # Get the path of the current module if not ps_api.GetModuleFileNameExW( - h_process, h_module, ctypes.byref(buf), - ctypes.byref(n_size)): + h_process, h_module, ctypes.byref(buf), ctypes.byref(n_size) + ): raise OSError("GetModuleFileNameEx failed") filepath = buf.value @@ -486,8 +509,7 @@ def _make_controller_from_path(self, filepath): # to a supported one. for controller_class, candidate_lib in _SUPPORTED_LIBRARIES.items(): # check if filename matches a supported prefix - prefix = self._check_prefix(filename, - candidate_lib["filename_prefixes"]) + prefix = self._check_prefix(filename, candidate_lib["filename_prefixes"]) # filename does not match any of the prefixes of the candidate # library. move to next library. 
@@ -501,8 +523,11 @@ def _make_controller_from_path(self, filepath): libctl_class = globals()[controller_class] libctl = libctl_class( - filepath=filepath, prefix=prefix, user_api=user_api, - internal_api=internal_api) + filepath=filepath, + prefix=prefix, + user_api=user_api, + internal_api=internal_api, + ) self.lib_controllers.append(libctl) def _check_prefix(self, library_basename, filename_prefixes): @@ -517,7 +542,7 @@ def _check_prefix(self, library_basename, filename_prefixes): def _warn_if_incompatible_openmp(self): """Raise a warning if llvm-OpenMP and intel-OpenMP are both loaded""" - if sys.platform != 'linux': + if sys.platform != "linux": # Only raise the warning on linux return @@ -531,8 +556,9 @@ def _warn_if_incompatible_openmp(self): Using threadpoolctl may cause crashes or deadlocks. For more information and possible workarounds, please see https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md - """) - if 'libomp' in prefixes and 'libiomp' in prefixes: + """ + ) + if "libomp" in prefixes and "libiomp" in prefixes: warnings.warn(msg, RuntimeWarning) @classmethod @@ -558,8 +584,7 @@ def _get_windll(cls, dll_name): @classmethod def _realpath(cls, filepath, cache_limit=10000): - """Small caching wrapper around os.path.realpath to limit system calls - """ + """Small caching wrapper around os.path.realpath to limit system calls""" rpath = cls._realpaths.get(filepath) if rpath is None: rpath = os.path.realpath(filepath) @@ -572,7 +597,8 @@ def _realpath(cls, filepath, cache_limit=10000): @_format_docstring( USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS), - INTERNAL_APIS=", ".join('"{}"'.format(api) for api in _ALL_INTERNAL_APIS)) + INTERNAL_APIS=", ".join('"{}"'.format(api) for api in _ALL_INTERNAL_APIS), +) class LibController(ABC): """Abstract base class for the individual library controllers @@ -587,8 +613,8 @@ class LibController(ABC): In addition, each library controller may contain internal_api specific entries. 
""" - def __init__(self, *, filepath=None, prefix=None, user_api=None, - internal_api=None): + + def __init__(self, *, filepath=None, prefix=None, user_api=None, internal_api=None): self.user_api = user_api self.internal_api = internal_api self.prefix = prefix @@ -622,6 +648,7 @@ def _set_num_threads(self, num_threads): class OpenBLASController(LibController): """Controller class for OpenBLAS""" + def __init__(self, **kwargs): super().__init__(**kwargs) self.threading_layer = self._get_threading_layer() @@ -641,19 +668,18 @@ def _get_version(self): return None def _get_num_threads(self): - get_func = getattr(self._dynlib, "openblas_get_num_threads", - lambda: None) + get_func = getattr(self._dynlib, "openblas_get_num_threads", lambda: None) return get_func() def _set_num_threads(self, num_threads): - set_func = getattr(self._dynlib, "openblas_set_num_threads", - lambda num_threads: None) + set_func = getattr( + self._dynlib, "openblas_set_num_threads", lambda num_threads: None + ) return set_func(num_threads) def _get_threading_layer(self): """Return the threading layer of OpenBLAS""" - openblas_get_parallel = getattr(self._dynlib, "openblas_get_parallel", - None) + openblas_get_parallel = getattr(self._dynlib, "openblas_get_parallel", None) if openblas_get_parallel is None: return "unknown" threading_layer = openblas_get_parallel() @@ -675,6 +701,7 @@ def _get_architecture(self): class BLISController(LibController): """Controller class for BLIS""" + def __init__(self, **kwargs): super().__init__(**kwargs) self.threading_layer = self._get_threading_layer() @@ -689,16 +716,16 @@ def _get_version(self): return get_version_().decode("utf-8") def _get_num_threads(self): - get_func = getattr(self._dynlib, "bli_thread_get_num_threads", - lambda: None) + get_func = getattr(self._dynlib, "bli_thread_get_num_threads", lambda: None) num_threads = get_func() # by default BLIS is single-threaded and get_num_threads # returns -1. 
We map it to 1 for consistency with other libraries. return 1 if num_threads == -1 else num_threads def _set_num_threads(self, num_threads): - set_func = getattr(self._dynlib, "bli_thread_set_num_threads", - lambda num_threads: None) + set_func = getattr( + self._dynlib, "bli_thread_set_num_threads", lambda num_threads: None + ) return set_func(num_threads) def _get_threading_layer(self): @@ -725,6 +752,7 @@ def _get_architecture(self): class MKLController(LibController): """Controller class for MKL""" + def __init__(self, **kwargs): super().__init__(**kwargs) self.threading_layer = self._get_threading_layer() @@ -747,8 +775,9 @@ def _get_num_threads(self): return get_func() def _set_num_threads(self, num_threads): - set_func = getattr(self._dynlib, "MKL_Set_Num_Threads", - lambda num_threads: None) + set_func = getattr( + self._dynlib, "MKL_Set_Num_Threads", lambda num_threads: None + ) return set_func(num_threads) def _get_threading_layer(self): @@ -756,15 +785,23 @@ def _get_threading_layer(self): # The function mkl_set_threading_layer returns the current threading # layer. Calling it with an invalid threading layer allows us to safely # get the threading layer - set_threading_layer = getattr(self._dynlib, "MKL_Set_Threading_Layer", - lambda layer: -1) - layer_map = {0: "intel", 1: "sequential", 2: "pgi", - 3: "gnu", 4: "tbb", -1: "not specified"} + set_threading_layer = getattr( + self._dynlib, "MKL_Set_Threading_Layer", lambda layer: -1 + ) + layer_map = { + 0: "intel", + 1: "sequential", + 2: "pgi", + 3: "gnu", + 4: "tbb", + -1: "not specified", + } return layer_map[set_threading_layer(-1)] class OpenMPController(LibController): """Controller class for OpenMP""" + def _get_version(self): # There is no way to get the version number programmatically in OpenMP. 
return None @@ -774,8 +811,9 @@ def _get_num_threads(self): return get_func() def _set_num_threads(self, num_threads): - set_func = getattr(self._dynlib, "omp_set_num_threads", - lambda num_threads: None) + set_func = getattr( + self._dynlib, "omp_set_num_threads", lambda num_threads: None + ) return set_func(num_threads) @@ -791,13 +829,18 @@ def _main(): description="Display thread-pool information and exit.", ) parser.add_argument( - "-i", "--import", dest="modules", nargs="*", default=(), - help="Python modules to import before introspecting thread-pools." + "-i", + "--import", + dest="modules", + nargs="*", + default=(), + help="Python modules to import before introspecting thread-pools.", ) parser.add_argument( - "-c", "--command", - help="a Python statement to execute before introspecting" - " thread-pools.") + "-c", + "--command", + help="a Python statement to execute before introspecting" " thread-pools.", + ) options = parser.parse_args(sys.argv[1:]) for module in options.modules: From 20c4f65deab20e0358dbfcb68cabbf105f6dca47 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 17 Sep 2021 16:39:32 +0200 Subject: [PATCH 07/18] address some comments + doc select --- .../nested_prange_blas.pyx | 6 +- tests/test_threadpoolctl.py | 173 +++++++++--------- threadpoolctl.py | 54 +++--- 3 files changed, 119 insertions(+), 114 deletions(-) diff --git a/tests/_openmp_test_helper/nested_prange_blas.pyx b/tests/_openmp_test_helper/nested_prange_blas.pyx index 3746d2db..aec7f815 100644 --- a/tests/_openmp_test_helper/nested_prange_blas.pyx +++ b/tests/_openmp_test_helper/nested_prange_blas.pyx @@ -43,12 +43,12 @@ def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads): int prange_num_threads int *prange_num_threads_ptr = &prange_num_threads - inner_ctl = [None] + inner_controller = [None] with nogil, parallel(num_threads=nthreads): if openmp.omp_get_thread_num() == 0: with gil: - inner_ctl[0] = ThreadpoolController() + 
inner_controller[0] = ThreadpoolController() prange_num_threads_ptr[0] = openmp.omp_get_num_threads() @@ -62,4 +62,4 @@ def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads): &alpha, &B[0, 0], &k, &A[i * chunk_size, 0], &k, &beta, &C[i * chunk_size, 0], &n) - return np.asarray(C), prange_num_threads, inner_ctl[0] + return np.asarray(C), prange_num_threads, inner_controller[0] diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index 68a523d8..474de75f 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -15,10 +15,10 @@ from .utils import threadpool_info_from_subprocess -def is_old_openblas(libctl): +def is_old_openblas(lib_controller): # Possible bug in getting maximum number of threads with OpenBLAS < 0.2.16 # and OpenBLAS does not expose its version before 0.3.4. - return libctl.internal_api == "openblas" and libctl.version is None + return lib_controller.internal_api == "openblas" and lib_controller.version is None def effective_num_threads(nthreads, max_threads): @@ -32,8 +32,8 @@ def test_threadpool_info(): function_info = threadpool_info() object_info = ThreadpoolController().lib_controllers - for libctl1, libctl2 in zip(function_info, object_info): - assert libctl1 == libctl2.todict() + for lib_info, lib_controller in zip(function_info, object_info): + assert lib_info == lib_controller.todict() def test_threadpool_controller_todicts(): @@ -41,19 +41,19 @@ def test_threadpool_controller_todicts(): # returned by the todict(s) methods controller = ThreadpoolController() - assert threadpool_info() == [libctl.todict() for libctl in controller] - assert controller.todicts() == [libctl.todict() for libctl in controller] + assert threadpool_info() == [lib_controller.todict() for lib_controller in controller] + assert controller.todicts() == [lib_controller.todict() for lib_controller in controller] - for libctl_dict in controller.todicts(): - assert "user_api" in libctl_dict - assert 
"internal_api" in libctl_dict - assert "prefix" in libctl_dict - assert "filepath" in libctl_dict - assert "version" in libctl_dict - assert "num_threads" in libctl_dict + for lib_controller_dict in controller.todicts(): + assert "user_api" in lib_controller_dict + assert "internal_api" in lib_controller_dict + assert "prefix" in lib_controller_dict + assert "filepath" in lib_controller_dict + assert "version" in lib_controller_dict + assert "num_threads" in lib_controller_dict - if libctl_dict["internal_api"] in ("mkl", "blis", "openblas"): - assert "threading_layer" in libctl_dict + if lib_controller_dict["internal_api"] in ("mkl", "blis", "openblas"): + assert "threading_layer" in lib_controller_dict @pytest.mark.parametrize( @@ -71,9 +71,9 @@ def test_threadpool_controller_select(kwargs): if not controller: pytest.skip(f"Requires at least one of {list(kwargs.values())}.") - for libctl in controller.lib_controllers: + for lib_controller in controller.lib_controllers: assert any( - getattr(libctl, key) in (val if isinstance(val, list) else [val]) + getattr(lib_controller, key) in (val if isinstance(val, list) else [val]) for key, val in kwargs.items() ) @@ -82,106 +82,109 @@ def test_threadpool_controller_select(kwargs): @pytest.mark.parametrize("limit", [1, 3]) def test_threadpool_limits_by_prefix(prefix, limit): # Check that the maximum number of threads can be set by prefix - original_ctl = ThreadpoolController() + original_controller = ThreadpoolController() - libctl_matching_prefix = original_ctl.select(prefix=prefix) - if not libctl_matching_prefix: + lib_controller_matching_prefix = original_controller.select(prefix=prefix) + if not lib_controller_matching_prefix: pytest.skip(f"Requires {prefix} runtime") with threadpool_limits(limits={prefix: limit}): - for libctl in libctl_matching_prefix: - if is_old_openblas(libctl): + for lib_controller in lib_controller_matching_prefix: + if is_old_openblas(lib_controller): continue # threadpool_limits only sets an 
upper bound on the number of # threads. - assert 0 < libctl._get_num_threads() <= limit - assert ThreadpoolController() == original_ctl + assert 0 < lib_controller._get_num_threads() <= limit + assert ThreadpoolController() == original_controller @pytest.mark.parametrize("user_api", (None, "blas", "openmp")) @pytest.mark.parametrize("limit", [1, 3]) def test_set_threadpool_limits_by_api(user_api, limit): # Check that the maximum number of threads can be set by user_api - original_ctl = ThreadpoolController() + original_controller = ThreadpoolController() - libctl_matching_api = original_ctl.select(user_api=user_api) - if not libctl_matching_api: + if user_api is None: + lib_controller_matching_api = original_controller + else: + lib_controller_matching_api = original_controller.select(user_api=user_api) + if not lib_controller_matching_api: user_apis = _ALL_USER_APIS if user_api is None else [user_api] pytest.skip(f"Requires a library which api is in {user_apis}") with threadpool_limits(limits=limit, user_api=user_api): - for libctl in libctl_matching_api: - if is_old_openblas(libctl): + for lib_controller in lib_controller_matching_api: + if is_old_openblas(lib_controller): continue # threadpool_limits only sets an upper bound on the number of # threads. - assert 0 < libctl._get_num_threads() <= limit + assert 0 < lib_controller._get_num_threads() <= limit - assert ThreadpoolController() == original_ctl + assert ThreadpoolController() == original_controller def test_threadpool_limits_function_with_side_effect(): # Check that threadpool_limits can be used as a function with # side effects instead of a context manager. 
- original_ctl = ThreadpoolController() + original_controller = ThreadpoolController() threadpool_limits(limits=1) try: - for libctl in ThreadpoolController(): - if is_old_openblas(libctl): + for lib_controller in ThreadpoolController(): + if is_old_openblas(lib_controller): continue - assert libctl.num_threads == 1 + assert lib_controller.num_threads == 1 finally: # Restore the original limits so that this test does not have any # side-effect. - threadpool_limits(limits=original_ctl) + threadpool_limits(limits=original_controller) - assert ThreadpoolController() == original_ctl + assert ThreadpoolController() == original_controller def test_set_threadpool_limits_no_limit(): # Check that limits=None does nothing. - original_ctl = ThreadpoolController() + original_controller = ThreadpoolController() with threadpool_limits(limits=None): - assert ThreadpoolController() == original_ctl + assert ThreadpoolController() == original_controller - assert ThreadpoolController() == original_ctl + assert ThreadpoolController() == original_controller def test_threadpool_limits_manual_unregister(): # Check that threadpool_limits can be used as an object which holds the # original state of the threadpools and that can be restored thanks to the # dedicated unregister method - original_ctl = ThreadpoolController() + original_controller = ThreadpoolController() limits = threadpool_limits(limits=1) try: - for libctl in ThreadpoolController(): - if is_old_openblas(libctl): + for lib_controller in ThreadpoolController(): + if is_old_openblas(lib_controller): continue - assert libctl.num_threads == 1 + assert lib_controller.num_threads == 1 finally: # Restore the original limits so that this test does not have any # side-effect. 
limits.unregister() - assert ThreadpoolController() == original_ctl + assert ThreadpoolController() == original_controller def test_threadpool_limits_with_controller(): # Check that threadpool_limits will only act on the libraries contained in # the controller when provided - original_blas_ctl = ThreadpoolController().select(user_api="blas") - original_openmp_ctl = ThreadpoolController().select(user_api="openmp") + original_blas_controller = ThreadpoolController().select(user_api="blas") + original_openmp_controller = ThreadpoolController().select(user_api="openmp") - with threadpool_limits(1, controller=original_blas_ctl): - blas_ctl = ThreadpoolController().select(user_api="blas") - openmp_ctl = ThreadpoolController().select(user_api="openmp") + with threadpool_limits(1, controller=original_blas_controller): + blas_controller = ThreadpoolController().select(user_api="blas") + openmp_controller = ThreadpoolController().select(user_api="openmp") - assert all(libctl.num_threads == 1 for libctl in blas_ctl) + assert all(lib_controller.num_threads == 1 for lib_controller in blas_controller) # the provided controller contains only blas libraries so no opemp # library should be impacted. - assert openmp_ctl == original_openmp_ctl + assert openmp_controller == original_openmp_controller def test_threadpool_limits_bad_input(): @@ -246,7 +249,7 @@ def test_openmp_nesting(nthreads_outer): outer_omp = prefixes - {inner_omp} outer_num_threads, inner_num_threads = check_nested_openmp_loops(10) - original_ctl = ThreadpoolController() + original_controller = ThreadpoolController() if inner_omp == outer_omp: # The OpenMP runtime should be shared by default, meaning that the @@ -263,7 +266,7 @@ def test_openmp_nesting(nthreads_outer): # The state of the original state of all threadpools should have been # restored. 
- assert ThreadpoolController() == original_ctl + assert ThreadpoolController() == original_controller # The number of threads available in the outer loop should not have been # decreased: @@ -284,15 +287,15 @@ def test_openmp_nesting(nthreads_outer): def test_shipped_openblas(): # checks that OpenBLAS effectively uses the number of threads requested by # the context manager - original_ctl = ThreadpoolController() + original_controller = ThreadpoolController() - openblas_controllers = original_ctl.select(internal_api="openblas") + openblas_controllers = original_controller.select(internal_api="openblas") with threadpool_limits(1): - for libctl in openblas_controllers: - assert libctl._get_num_threads() == 1 + for lib_controller in openblas_controllers: + assert lib_controller._get_num_threads() == 1 - assert original_ctl == ThreadpoolController() + assert original_controller == ThreadpoolController() @pytest.mark.skipif( @@ -319,10 +322,10 @@ def test_nested_prange_blas(nthreads_outer): check_nested_prange_blas = prange_blas.check_nested_prange_blas - original_ctl = ThreadpoolController() + original_controller = ThreadpoolController() - blas_controllers = original_ctl.select(user_api="blas") - blis_controllers = original_ctl.select(internal_api="blis") + blas_controllers = original_controller.select(user_api="blas") + blis_controllers = original_controller.select(internal_api="blis") # skip if the BLAS used by numpy is an old openblas. OpenBLAS 0.3.3 and # older are known to cause an unrecoverable deadlock at process shutdown @@ -330,7 +333,7 @@ def test_nested_prange_blas(nthreads_outer): # numpy can be linked to BLIS for CBLAS and OpenBLAS for LAPACK. In that # case this test will run BLIS gemm so no need to skip. 
if not blis_controllers and any( - is_old_openblas(libctl) for libctl in blas_controllers + is_old_openblas(lib_controller) for lib_controller in blas_controllers ): pytest.skip("Old OpenBLAS: skipping test to avoid deadlock") @@ -342,17 +345,17 @@ def test_nested_prange_blas(nthreads_outer): nthreads = effective_num_threads(nthreads_outer, max_threads) result = check_nested_prange_blas(A, B, nthreads) - C, prange_num_threads, inner_ctl = result + C, prange_num_threads, inner_controller = result assert np.allclose(C, np.dot(A, B.T)) assert prange_num_threads == nthreads - nested_blas_controllers = inner_ctl.select(user_api="blas") + nested_blas_controllers = inner_controller.select(user_api="blas") assert len(nested_blas_controllers) == len(blas_controllers) - for libctl in nested_blas_controllers: - assert libctl.num_threads == 1 + for lib_controller in nested_blas_controllers: + assert lib_controller.num_threads == 1 - assert original_ctl == ThreadpoolController() + assert original_controller == ThreadpoolController() # the method `get_original_num_threads` raises a UserWarning due to different @@ -363,20 +366,20 @@ def test_nested_prange_blas(nthreads_outer): @pytest.mark.parametrize("limit", [1, None]) def test_get_original_num_threads(limit): # Tests the method get_original_num_threads of the context manager - with threadpool_limits(limits=2, user_api="blas") as ctl: + with threadpool_limits(limits=2, user_api="blas") as ctx: # set different blas num threads to start with (when multiple openblas) - if ctl._controller: - ctl._controller.lib_controllers[0]._set_num_threads(1) + if ctx._controller: + ctx._controller.lib_controllers[0]._set_num_threads(1) - original_ctl = ThreadpoolController() + original_controller = ThreadpoolController() with threadpool_limits(limits=limit, user_api="blas") as threadpoolctx: original_num_threads = threadpoolctx.get_original_num_threads() assert "openmp" not in original_num_threads - blas_ctl = 
original_ctl.select(user_api="blas") - if blas_ctl: - expected = min(libctl.num_threads for libctl in blas_ctl) + blas_controller = original_controller.select(user_api="blas") + if blas_controller: + expected = min(lib_controller.num_threads for lib_controller in blas_controller) assert original_num_threads["blas"] == expected else: assert original_num_threads["blas"] is None @@ -389,32 +392,30 @@ def test_get_original_num_threads(limit): def test_mkl_threading_layer(): # Check that threadpool_info correctly recovers the threading layer used # by mkl - mkl_ctl = ThreadpoolController().select(internal_api="mkl") + mkl_controller = ThreadpoolController().select(internal_api="mkl") expected_layer = os.getenv("MKL_THREADING_LAYER") - if not (mkl_ctl and expected_layer): - pytest.skip( - "requires MKL and the environment variable MKL_THREADING_LAYER set" - ) + if not (mkl_controller and expected_layer): + pytest.skip("requires MKL and the environment variable MKL_THREADING_LAYER set") - actual_layer = mkl_ctl.lib_controllers[0].threading_layer + actual_layer = mkl_controller.lib_controllers[0].threading_layer assert actual_layer == expected_layer.lower() def test_blis_threading_layer(): # Check that threadpool_info correctly recovers the threading layer used # by blis - blis_ctl = ThreadpoolController().select(internal_api="blis") + blis_controller = ThreadpoolController().select(internal_api="blis") expected_layer = os.getenv("BLIS_ENABLE_THREADING") if expected_layer == "no": expected_layer = "disabled" - if not (blis_ctl and expected_layer): + if not (blis_controller and expected_layer): pytest.skip( "requires BLIS and the environment variable BLIS_ENABLE_THREADING set" ) - actual_layer = blis_ctl.lib_controllers[0].threading_layer + actual_layer = blis_controller.lib_controllers[0].threading_layer assert actual_layer == expected_layer @@ -430,8 +431,8 @@ def test_libomp_libiomp_warning(recwarn): # Check that a warning is raised when both libomp and libiomp are 
loaded # It should happen in one CI job (pylatest_conda_mkl_clang_gcc). - ctl = ThreadpoolController() - prefixes = [libctl.prefix for libctl in ctl] + controller = ThreadpoolController() + prefixes = [lib_controller.prefix for lib_controller in controller] if not ("libomp" in prefixes and "libiomp" in prefixes and sys.platform == "linux"): pytest.skip("Requires both libomp and libiomp loaded, on Linux") diff --git a/threadpoolctl.py b/threadpoolctl.py index a6636c1b..af4a7ac7 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -195,10 +195,10 @@ def __exit__(self, type, value, traceback): self.unregister() def unregister(self): - for libctl in self._controller.lib_controllers: + for lib_controller in self._controller.lib_controllers: # Since we never call get_num_threads after instanciation of # ThreadpoolController, num_threads holds the original value. - libctl._set_num_threads(libctl.num_threads) + lib_controller._set_num_threads(lib_controller.num_threads) def get_original_num_threads(self): """Original num_threads from before calling threadpool_limits @@ -210,8 +210,8 @@ def get_original_num_threads(self): for user_api in self._user_api: limits = [ - libctl.num_threads - for libctl in self._controller.select(user_api=user_api) + lib_controller.num_threads + for lib_controller in self._controller.select(user_api=user_api) ] limits = set(limits) n_limits = len(limits) @@ -253,13 +253,13 @@ def _check_params(self, limits, user_api): prefixes = [] else: if isinstance(limits, list): - # This should be a list of dicts of library controllers, for + # This should be a list of dicts of library info, for # compatibility with the result from threadpool_info. - limits = {ctl["prefix"]: ctl["num_threads"] for ctl in limits} + limits = {lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits} elif isinstance(limits, ThreadpoolController): # To set the limits from the library controllers of a # ThreadpoolController object. 
- limits = {ctl.prefix: ctl.num_threads for ctl in limits} + limits = {lib_controller.prefix: lib_controller.num_threads for lib_controller in limits} if not isinstance(limits, dict): raise TypeError( @@ -283,19 +283,19 @@ def _set_threadpool_limits(self): if self._limits is None: return None - for libctl in self._controller.lib_controllers: + for lib_controller in self._controller.lib_controllers: # self._limits is a dict {key: num_threads} where key is either # a prefix or a user_api. If a library matches both, the limit # corresponding to the prefix is chosen. - if libctl.prefix in self._limits: - num_threads = self._limits[libctl.prefix] - elif libctl.user_api in self._limits: - num_threads = self._limits[libctl.user_api] + if lib_controller.prefix in self._limits: + num_threads = self._limits[lib_controller.prefix] + elif lib_controller.user_api in self._limits: + num_threads = self._limits[lib_controller.user_api] else: continue if num_threads is not None: - libctl._set_num_threads(num_threads) + lib_controller._set_num_threads(num_threads) @_format_docstring( @@ -338,24 +338,28 @@ def _from_controllers(cls, lib_controllers): def todicts(self): """Return info as a list of dicts""" - return [libctl.todict() for libctl in self.lib_controllers] + return [lib_controller.todict() for lib_controller in self.lib_controllers] def select(self, **kwargs): """Return a ThreadpoolController containing a subset of its current library controllers - kwargs can be any number of pair (key, value) where key is a entry - TODO + It will select all libraries matching at least one pair (key, value) from kwargs + where key is an entry of the library info dict (like "user_api", "internal_api", + "prefix", ...) and value is the value or a list of acceptable values for that + entry. + + For instance, `ThreadpoolController().select(internal_api=["blis", "openblas"])` + will select all library controllers whose internal_api is either "blis" or + "openblas". 
""" - if not kwargs or all(val is None for val in kwargs.values()): - kwargs = {"user_api": _ALL_USER_APIS} for key, vals in kwargs.items(): kwargs[key] = [vals] if not isinstance(vals, list) else vals lib_controllers = [ - libctl - for libctl in self.lib_controllers - if any(getattr(libctl, key, None) in vals for key, vals in kwargs.items()) + lib_controller + for lib_controller in self.lib_controllers + if any(getattr(lib_controller, key, None) in vals for key, vals in kwargs.items()) ] return ThreadpoolController._from_controllers(lib_controllers) @@ -521,14 +525,14 @@ def _make_controller_from_path(self, filepath): user_api = candidate_lib["user_api"] internal_api = candidate_lib["internal_api"] - libctl_class = globals()[controller_class] - libctl = libctl_class( + lib_controller_class = globals()[controller_class] + lib_controller = lib_controller_class( filepath=filepath, prefix=prefix, user_api=user_api, internal_api=internal_api, ) - self.lib_controllers.append(libctl) + self.lib_controllers.append(lib_controller) def _check_prefix(self, library_basename, filename_prefixes): """Return the prefix library_basename starts with @@ -546,7 +550,7 @@ def _warn_if_incompatible_openmp(self): # Only raise the warning on linux return - prefixes = [libctl.prefix for libctl in self.lib_controllers] + prefixes = [lib_controller.prefix for lib_controller in self.lib_controllers] msg = textwrap.dedent( """ Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at From e45c5bb93876faec5da8cb8cbedb4226905b639d Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 17 Sep 2021 16:48:00 +0200 Subject: [PATCH 08/18] keep get set public --- tests/test_threadpoolctl.py | 10 ++-- threadpoolctl.py | 91 +++++++++++++++++++------------------ 2 files changed, 51 insertions(+), 50 deletions(-) diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index 474de75f..e7c5ed1e 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py 
@@ -50,7 +50,7 @@ def test_threadpool_controller_todicts(): assert "prefix" in lib_controller_dict assert "filepath" in lib_controller_dict assert "version" in lib_controller_dict - assert "num_threads" in lib_contoller_dict + assert "num_threads" in lib_controller_dict if lib_controller_dict["internal_api"] in ("mkl", "blis", "openblas"): assert "threading_layer" in lib_controller_dict @@ -94,7 +94,7 @@ def test_threadpool_limits_by_prefix(prefix, limit): continue # threadpool_limits only sets an upper bound on the number of # threads. - assert 0 < lib_controller._get_num_threads() <= limit + assert 0 < lib_controller.get_num_threads() <= limit assert ThreadpoolController() == original_controller @@ -118,7 +118,7 @@ def test_set_threadpool_limits_by_api(user_api, limit): continue # threadpool_limits only sets an upper bound on the number of # threads. - assert 0 < lib_controller._get_num_threads() <= limit + assert 0 < lib_controller.get_num_threads() <= limit assert ThreadpoolController() == original_controller @@ -293,7 +293,7 @@ def test_shipped_openblas(): with threadpool_limits(1): for lib_controller in openblas_controllers: - assert lib_controller._get_num_threads() == 1 + assert lib_controller.get_num_threads() == 1 assert original_controller == ThreadpoolController() @@ -369,7 +369,7 @@ def test_get_original_num_threads(limit): with threadpool_limits(limits=2, user_api="blas") as ctx: # set different blas num threads to start with (when multiple openblas) if ctx._controller: - ctx._controller.lib_controllers[0]._set_num_threads(1) + ctx._controller.lib_controllers[0].set_num_threads(1) original_controller = ThreadpoolController() with threadpool_limits(limits=limit, user_api="blas") as threadpoolctx: diff --git a/threadpoolctl.py b/threadpoolctl.py index af4a7ac7..215b1305 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -198,7 +198,7 @@ def unregister(self): for lib_controller in self._controller.lib_controllers: # Since we never call 
get_num_threads after instanciation of # ThreadpoolController, num_threads holds the original value. - lib_controller._set_num_threads(lib_controller.num_threads) + lib_controller.set_num_threads(lib_controller.num_threads) def get_original_num_threads(self): """Original num_threads from before calling threadpool_limits @@ -295,7 +295,7 @@ def _set_threadpool_limits(self): continue if num_threads is not None: - lib_controller._set_num_threads(num_threads) + lib_controller.set_num_threads(num_threads) @_format_docstring( @@ -625,7 +625,7 @@ def __init__(self, *, filepath=None, prefix=None, user_api=None, internal_api=No self.filepath = filepath self._dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD) self.version = self._get_version() - self.num_threads = self._get_num_threads() + self.num_threads = self.get_num_threads() def __eq__(self, other): return self.todict() == other.todict() @@ -635,18 +635,18 @@ def todict(self): return {k: v for k, v in vars(self).items() if not k.startswith("_")} @abstractmethod - def _get_version(self): - """Return the version of the shared library""" + def get_num_threads(self): + """Return the maximum number of threads available to use""" pass # pragma: no cover @abstractmethod - def _get_num_threads(self): - """Return the maximum number of threads available to use""" + def set_num_threads(self, num_threads): + """Set the maximum number of threads to use""" pass # pragma: no cover @abstractmethod - def _set_num_threads(self, num_threads): - """Set the maximum number of threads to use""" + def _get_version(self): + """Return the version of the shared library""" pass # pragma: no cover @@ -658,6 +658,16 @@ def __init__(self, **kwargs): self.threading_layer = self._get_threading_layer() self.architecture = self._get_architecture() + def get_num_threads(self): + get_func = getattr(self._dynlib, "openblas_get_num_threads", lambda: None) + return get_func() + + def set_num_threads(self, num_threads): + set_func = getattr( + self._dynlib, 
"openblas_set_num_threads", lambda num_threads: None + ) + return set_func(num_threads) + def _get_version(self): # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS # did not expose its version before that. @@ -671,16 +681,6 @@ def _get_version(self): return config[1].decode("utf-8") return None - def _get_num_threads(self): - get_func = getattr(self._dynlib, "openblas_get_num_threads", lambda: None) - return get_func() - - def _set_num_threads(self, num_threads): - set_func = getattr( - self._dynlib, "openblas_set_num_threads", lambda num_threads: None - ) - return set_func(num_threads) - def _get_threading_layer(self): """Return the threading layer of OpenBLAS""" openblas_get_parallel = getattr(self._dynlib, "openblas_get_parallel", None) @@ -711,27 +711,27 @@ def __init__(self, **kwargs): self.threading_layer = self._get_threading_layer() self.architecture = self._get_architecture() - def _get_version(self): - get_version_ = getattr(self._dynlib, "bli_info_get_version_str", None) - if get_version_ is None: - return None - - get_version_.restype = ctypes.c_char_p - return get_version_().decode("utf-8") - - def _get_num_threads(self): + def get_num_threads(self): get_func = getattr(self._dynlib, "bli_thread_get_num_threads", lambda: None) num_threads = get_func() # by default BLIS is single-threaded and get_num_threads # returns -1. We map it to 1 for consistency with other libraries. 
return 1 if num_threads == -1 else num_threads - def _set_num_threads(self, num_threads): + def set_num_threads(self, num_threads): set_func = getattr( self._dynlib, "bli_thread_set_num_threads", lambda num_threads: None ) return set_func(num_threads) + def _get_version(self): + get_version_ = getattr(self._dynlib, "bli_info_get_version_str", None) + if get_version_ is None: + return None + + get_version_.restype = ctypes.c_char_p + return get_version_().decode("utf-8") + def _get_threading_layer(self): """Return the threading layer of BLIS""" if self._dynlib.bli_info_get_enable_openmp(): @@ -761,6 +761,16 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.threading_layer = self._get_threading_layer() + def get_num_threads(self): + get_func = getattr(self._dynlib, "MKL_Get_Max_Threads", lambda: None) + return get_func() + + def set_num_threads(self, num_threads): + set_func = getattr( + self._dynlib, "MKL_Set_Num_Threads", lambda num_threads: None + ) + return set_func(num_threads) + def _get_version(self): if not hasattr(self._dynlib, "MKL_Get_Version_String"): return None @@ -774,16 +784,6 @@ def _get_version(self): version = group.groups()[0] return version.strip() - def _get_num_threads(self): - get_func = getattr(self._dynlib, "MKL_Get_Max_Threads", lambda: None) - return get_func() - - def _set_num_threads(self, num_threads): - set_func = getattr( - self._dynlib, "MKL_Set_Num_Threads", lambda num_threads: None - ) - return set_func(num_threads) - def _get_threading_layer(self): """Return the threading layer of MKL""" # The function mkl_set_threading_layer returns the current threading @@ -806,20 +806,21 @@ def _get_threading_layer(self): class OpenMPController(LibController): """Controller class for OpenMP""" - def _get_version(self): - # There is no way to get the version number programmatically in OpenMP. 
- return None - - def _get_num_threads(self): + def get_num_threads(self): get_func = getattr(self._dynlib, "omp_get_max_threads", lambda: None) return get_func() - def _set_num_threads(self, num_threads): + def set_num_threads(self, num_threads): set_func = getattr( self._dynlib, "omp_set_num_threads", lambda num_threads: None ) return set_func(num_threads) + def _get_version(self): + # There is no way to get the version number programmatically in OpenMP. + return None + + def _main(): """Commandline interface to display thread-pool information and exit.""" From 8428a806c27bd8fabcb6df0ff08b4d5886a1ac9d Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 17 Sep 2021 16:48:39 +0200 Subject: [PATCH 09/18] black --- tests/test_threadpoolctl.py | 16 ++++++++++++---- threadpoolctl.py | 15 +++++++++++---- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index e7c5ed1e..80faf0f3 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -41,8 +41,12 @@ def test_threadpool_controller_todicts(): # returned by the todict(s) methods controller = ThreadpoolController() - assert threadpool_info() == [lib_controller.todict() for lib_controller in controller] - assert controller.todicts() == [lib_controller.todict() for lib_controller in controller] + assert threadpool_info() == [ + lib_controller.todict() for lib_controller in controller + ] + assert controller.todicts() == [ + lib_controller.todict() for lib_controller in controller + ] for lib_controller_dict in controller.todicts(): assert "user_api" in lib_controller_dict @@ -181,7 +185,9 @@ def test_threadpool_limits_with_controller(): blas_controller = ThreadpoolController().select(user_api="blas") openmp_controller = ThreadpoolController().select(user_api="openmp") - assert all(lib_controller.num_threads == 1 for lib_controller in blas_controller) + assert all( + lib_controller.num_threads == 1 for lib_controller in 
blas_controller + ) # the provided controller contains only blas libraries so no opemp # library should be impacted. assert openmp_controller == original_openmp_controller @@ -379,7 +385,9 @@ def test_get_original_num_threads(limit): blas_controller = original_controller.select(user_api="blas") if blas_controller: - expected = min(lib_controller.num_threads for lib_controller in blas_controller) + expected = min( + lib_controller.num_threads for lib_controller in blas_controller + ) assert original_num_threads["blas"] == expected else: assert original_num_threads["blas"] is None diff --git a/threadpoolctl.py b/threadpoolctl.py index 215b1305..c7a630e7 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -255,11 +255,16 @@ def _check_params(self, limits, user_api): if isinstance(limits, list): # This should be a list of dicts of library info, for # compatibility with the result from threadpool_info. - limits = {lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits} + limits = { + lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits + } elif isinstance(limits, ThreadpoolController): # To set the limits from the library controllers of a # ThreadpoolController object. 
- limits = {lib_controller.prefix: lib_controller.num_threads for lib_controller in limits} + limits = { + lib_controller.prefix: lib_controller.num_threads + for lib_controller in limits + } if not isinstance(limits, dict): raise TypeError( @@ -359,7 +364,10 @@ def select(self, **kwargs): lib_controllers = [ lib_controller for lib_controller in self.lib_controllers - if any(getattr(lib_controller, key, None) in vals for key, vals in kwargs.items()) + if any( + getattr(lib_controller, key, None) in vals + for key, vals in kwargs.items() + ) ] return ThreadpoolController._from_controllers(lib_controllers) @@ -821,7 +829,6 @@ def _get_version(self): return None - def _main(): """Commandline interface to display thread-pool information and exit.""" import argparse From c0c86e6afe891c80081f7c0c7a9f13076991f40c Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 17 Sep 2021 17:51:12 +0200 Subject: [PATCH 10/18] add a method to return the context manager --- tests/test_threadpoolctl.py | 31 +++++++++++--- threadpoolctl.py | 83 ++++++++++++++++++++++++++++++------- 2 files changed, 92 insertions(+), 22 deletions(-) diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index 80faf0f3..f345f606 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -175,24 +175,43 @@ def test_threadpool_limits_manual_unregister(): assert ThreadpoolController() == original_controller -def test_threadpool_limits_with_controller(): - # Check that threadpool_limits will only act on the libraries contained in - # the controller when provided +def test_threadpool_controller_limit(): + # Check that using the limit method of ThreadpoolController only impact its + # library controllers. 
original_blas_controller = ThreadpoolController().select(user_api="blas") original_openmp_controller = ThreadpoolController().select(user_api="openmp") - with threadpool_limits(1, controller=original_blas_controller): + with original_blas_controller.limit(limits=1): blas_controller = ThreadpoolController().select(user_api="blas") openmp_controller = ThreadpoolController().select(user_api="openmp") assert all( lib_controller.num_threads == 1 for lib_controller in blas_controller ) - # the provided controller contains only blas libraries so no opemp - # library should be impacted. + # original_blas_controller contains only blas libraries so no opemp library + # should be impacted. assert openmp_controller == original_openmp_controller +def test_threadpool_controller_restore(): + # Check that the restore_limits method of ThreadpoolController is able to set the + # limits back to their original values. Similar to + # test_threadpool_limits_function_with_side_effect but with the object api + controller = ThreadpoolController() + + controller.limit(limits=1) + try: + for lib_controller in ThreadpoolController(): + if is_old_openblas(lib_controller): + continue + assert lib_controller.num_threads == 1 + finally: + # Restore the original limits so that this test does not have any + # side-effect. + controller.restore_limits() + + assert ThreadpoolController() == controller + def test_threadpool_limits_bad_input(): # Check that appropriate errors are raised for invalid arguments match = re.escape(f"user_api must be either in {_ALL_USER_APIS} or None.") diff --git a/threadpoolctl.py b/threadpoolctl.py index c7a630e7..07996ea2 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -136,12 +136,12 @@ def threadpool_info(): BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES), OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES), ) -class threadpool_limits: +def threadpool_limits(limits=None, user_api=None): """Change the maximal number of threads that can be used in thread pools. 
- This class can be used either as a function (the construction of this - object limits the number of threads) or as a context manager, in a `with` - block. + This function returns a class that can be used either as a function (the + construction of this object limits the number of threads) or as a context manager, + in a `with` block. Set the maximal number of threads that can be used in thread pools used in the supported libraries to `limit`. This function works for libraries that @@ -171,21 +171,22 @@ class threadpool_limits: by the BLAS libraries if they rely on OpenMP. - If None, this function will apply to all supported libraries. - - controller : instance of ``ThreadpoolController`` or None (default=None) - The threadpool controller to use. If None, a new controller is created. """ + return ThreadpoolController().limit(limits=limits, user_api=user_api) + + +class _threadpool_limits: + """The guts of ThreadpoolController.limit + + Refer to the docstring of ThreadpoolController.limit for more details. - def __init__(self, limits=None, user_api=None, controller=None): + It will only act on the library controllers held by the provided `controller`. + """ + def __init__(self, controller, *, limits=None, user_api=None): self._limits, self._user_api, self._prefixes = self._check_params( limits, user_api ) - - if controller is not None: - self._controller = controller - else: - self._controller = ThreadpoolController() - + self._controller = controller self._set_threadpool_limits() def __enter__(self): @@ -314,10 +315,9 @@ class ThreadpoolController: Attributes ---------- - lib_controllers : list of ``LibController`` objects + lib_controllers : list of `LibController` objects The list of library controllers of all loaded supported libraries. """ - # Cache for libc under POSIX and a few system libraries under Windows. 
# We use a class level cache instead of an instance level cache because # it's very unlikely that a shared library will be unloaded and reloaded @@ -372,6 +372,57 @@ def select(self, **kwargs): return ThreadpoolController._from_controllers(lib_controllers) + @_format_docstring( + USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS), + BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES), + OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES), + ) + def limit(self, *, limits=None, user_api=None): + """Change the maximal number of threads that can be used in thread pools. + + This function returns a class that can be used either as a function (the + construction of this object limits the number of threads) or as a context + manager, in a `with` block. + + Set the maximal number of threads that can be used in thread pools used in + the supported libraries to `limits`. This function works for libraries that + are already loaded in the interpreter and can be changed dynamically. + + Parameters + ---------- + limits : int, dict or None (default=None) + The maximal number of threads that can be used in thread pools + + - If int, sets the maximum number of threads to `limits` for each + library selected by `user_api`. + + - If it is a dictionary `{{key: max_threads}}`, this function sets a + custom maximum number of threads for each `key` which can be either a + `user_api` or a `prefix` for a specific library. + + - If None, this function does not do anything. + + user_api : {USER_APIS} or None (default=None) + APIs of libraries to limit. Used only if `limits` is an int. + + - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}). + + - If "openmp", it will only limit OpenMP supported libraries + ({OPENMP_LIBS}). Note that it can affect the number of threads used + by the BLAS libraries if they rely on OpenMP. + + - If None, this function will apply to all supported libraries. 
+ """ + return _threadpool_limits(self, limits=limits, user_api=user_api) + + def restore_limits(self): + """Set the limits back to their original values + + Since get_num_threads is only called once at initialization, the instance keeps + the original num_threads during its whole lifetime. + """ + self.limit(limits=self) + def __len__(self): return len(self.lib_controllers) From a221220d067d47c889cfc7393cdf2fb4ac532d6c Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 17 Sep 2021 17:57:53 +0200 Subject: [PATCH 11/18] cln --- threadpoolctl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/threadpoolctl.py b/threadpoolctl.py index 07996ea2..ee4d519d 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -394,11 +394,11 @@ def limit(self, *, limits=None, user_api=None): The maximal number of threads that can be used in thread pools - If int, sets the maximum number of threads to `limits` for each - library selected by `user_api`. + library selected by `user_api`. - If it is a dictionary `{{key: max_threads}}`, this function sets a - custom maximum number of threads for each `key` which can be either a - `user_api` or a `prefix` for a specific library. + custom maximum number of threads for each `key` which can be either a + `user_api` or a `prefix` for a specific library. - If None, this function does not do anything. @@ -408,8 +408,8 @@ def limit(self, *, limits=None, user_api=None): - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}). - If "openmp", it will only limit OpenMP supported libraries - ({OPENMP_LIBS}). Note that it can affect the number of threads used - by the BLAS libraries if they rely on OpenMP. + ({OPENMP_LIBS}). Note that it can affect the number of threads used + by the BLAS libraries if they rely on OpenMP. - If None, this function will apply to all supported libraries. 
""" From 826eda8ecedbdc3e313e3bf831125729101d005e Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 23 Sep 2021 17:47:01 +0200 Subject: [PATCH 12/18] remove list-like api --- tests/test_threadpoolctl.py | 71 +++++++++++++++++++------------------ threadpoolctl.py | 35 +++++++----------- 2 files changed, 49 insertions(+), 57 deletions(-) diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index f345f606..c52f472f 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -33,22 +33,22 @@ def test_threadpool_info(): object_info = ThreadpoolController().lib_controllers for lib_info, lib_controller in zip(function_info, object_info): - assert lib_info == lib_controller.todict() + assert lib_info == lib_controller.info() -def test_threadpool_controller_todicts(): +def test_threadpool_controller_info(): # Check that all keys expected for the private api are in the dicts - # returned by the todict(s) methods + # returned by the `info` methods controller = ThreadpoolController() assert threadpool_info() == [ - lib_controller.todict() for lib_controller in controller + lib_controller.info() for lib_controller in controller.lib_controllers ] - assert controller.todicts() == [ - lib_controller.todict() for lib_controller in controller + assert controller.info() == [ + lib_controller.info() for lib_controller in controller.lib_controllers ] - for lib_controller_dict in controller.todicts(): + for lib_controller_dict in controller.info(): assert "user_api" in lib_controller_dict assert "internal_api" in lib_controller_dict assert "prefix" in lib_controller_dict @@ -88,18 +88,18 @@ def test_threadpool_limits_by_prefix(prefix, limit): # Check that the maximum number of threads can be set by prefix original_controller = ThreadpoolController() - lib_controller_matching_prefix = original_controller.select(prefix=prefix) - if not lib_controller_matching_prefix: + controller_matching_prefix = 
original_controller.select(prefix=prefix) + if not controller_matching_prefix: pytest.skip(f"Requires {prefix} runtime") with threadpool_limits(limits={prefix: limit}): - for lib_controller in lib_controller_matching_prefix: + for lib_controller in controller_matching_prefix.lib_controllers: if is_old_openblas(lib_controller): continue # threadpool_limits only sets an upper bound on the number of # threads. assert 0 < lib_controller.get_num_threads() <= limit - assert ThreadpoolController() == original_controller + assert ThreadpoolController().info() == original_controller.info() @pytest.mark.parametrize("user_api", (None, "blas", "openmp")) @@ -109,22 +109,22 @@ def test_set_threadpool_limits_by_api(user_api, limit): original_controller = ThreadpoolController() if user_api is None: - lib_controller_matching_api = original_controller + controller_matching_api = original_controller else: - lib_controller_matching_api = original_controller.select(user_api=user_api) - if not lib_controller_matching_api: + controller_matching_api = original_controller.select(user_api=user_api) + if not controller_matching_api: user_apis = _ALL_USER_APIS if user_api is None else [user_api] pytest.skip(f"Requires a library which api is in {user_apis}") with threadpool_limits(limits=limit, user_api=user_api): - for lib_controller in lib_controller_matching_api: + for lib_controller in controller_matching_api.lib_controllers: if is_old_openblas(lib_controller): continue # threadpool_limits only sets an upper bound on the number of # threads. 
assert 0 < lib_controller.get_num_threads() <= limit - assert ThreadpoolController() == original_controller + assert ThreadpoolController().info() == original_controller.info() def test_threadpool_limits_function_with_side_effect(): @@ -134,7 +134,7 @@ def test_threadpool_limits_function_with_side_effect(): threadpool_limits(limits=1) try: - for lib_controller in ThreadpoolController(): + for lib_controller in ThreadpoolController().lib_controllers: if is_old_openblas(lib_controller): continue assert lib_controller.num_threads == 1 @@ -143,16 +143,16 @@ def test_threadpool_limits_function_with_side_effect(): # side-effect. threadpool_limits(limits=original_controller) - assert ThreadpoolController() == original_controller + assert ThreadpoolController().info() == original_controller.info() def test_set_threadpool_limits_no_limit(): # Check that limits=None does nothing. original_controller = ThreadpoolController() with threadpool_limits(limits=None): - assert ThreadpoolController() == original_controller + assert ThreadpoolController().info() == original_controller.info() - assert ThreadpoolController() == original_controller + assert ThreadpoolController().info() == original_controller.info() def test_threadpool_limits_manual_unregister(): @@ -163,7 +163,7 @@ def test_threadpool_limits_manual_unregister(): limits = threadpool_limits(limits=1) try: - for lib_controller in ThreadpoolController(): + for lib_controller in ThreadpoolController().lib_controllers: if is_old_openblas(lib_controller): continue assert lib_controller.num_threads == 1 @@ -172,7 +172,7 @@ def test_threadpool_limits_manual_unregister(): # side-effect. 
limits.unregister() - assert ThreadpoolController() == original_controller + assert ThreadpoolController().info() == original_controller.info() def test_threadpool_controller_limit(): @@ -186,11 +186,11 @@ def test_threadpool_controller_limit(): openmp_controller = ThreadpoolController().select(user_api="openmp") assert all( - lib_controller.num_threads == 1 for lib_controller in blas_controller + lib_controller.num_threads == 1 for lib_controller in blas_controller.lib_controllers ) # original_blas_controller contains only blas libraries so no opemp library # should be impacted. - assert openmp_controller == original_openmp_controller + assert openmp_controller.info() == original_openmp_controller.info() def test_threadpool_controller_restore(): @@ -201,7 +201,7 @@ def test_threadpool_controller_restore(): controller.limit(limits=1) try: - for lib_controller in ThreadpoolController(): + for lib_controller in ThreadpoolController().lib_controllers: if is_old_openblas(lib_controller): continue assert lib_controller.num_threads == 1 @@ -210,7 +210,8 @@ def test_threadpool_controller_restore(): # side-effect. controller.restore_limits() - assert ThreadpoolController() == controller + assert ThreadpoolController().info() == controller.info() + def test_threadpool_limits_bad_input(): # Check that appropriate errors are raised for invalid arguments @@ -291,7 +292,7 @@ def test_openmp_nesting(nthreads_outer): # The state of the original state of all threadpools should have been # restored. 
- assert ThreadpoolController() == original_controller + assert ThreadpoolController().info() == original_controller.info() # The number of threads available in the outer loop should not have been # decreased: @@ -317,10 +318,10 @@ def test_shipped_openblas(): openblas_controllers = original_controller.select(internal_api="openblas") with threadpool_limits(1): - for lib_controller in openblas_controllers: + for lib_controller in openblas_controllers.lib_controllers: assert lib_controller.get_num_threads() == 1 - assert original_controller == ThreadpoolController() + assert original_controller.info() == ThreadpoolController().info() @pytest.mark.skipif( @@ -358,7 +359,7 @@ def test_nested_prange_blas(nthreads_outer): # numpy can be linked to BLIS for CBLAS and OpenBLAS for LAPACK. In that # case this test will run BLIS gemm so no need to skip. if not blis_controllers and any( - is_old_openblas(lib_controller) for lib_controller in blas_controllers + is_old_openblas(lib_controller) for lib_controller in blas_controllers.lib_controllers ): pytest.skip("Old OpenBLAS: skipping test to avoid deadlock") @@ -376,11 +377,11 @@ def test_nested_prange_blas(nthreads_outer): assert prange_num_threads == nthreads nested_blas_controllers = inner_controller.select(user_api="blas") - assert len(nested_blas_controllers) == len(blas_controllers) - for lib_controller in nested_blas_controllers: + assert len(nested_blas_controllers.lib_controllers) == len(blas_controllers.lib_controllers) + for lib_controller in nested_blas_controllers.lib_controllers: assert lib_controller.num_threads == 1 - assert original_controller == ThreadpoolController() + assert original_controller.info() == ThreadpoolController().info() # the method `get_original_num_threads` raises a UserWarning due to different @@ -405,7 +406,7 @@ def test_get_original_num_threads(limit): blas_controller = original_controller.select(user_api="blas") if blas_controller: expected = min( - lib_controller.num_threads for 
lib_controller in blas_controller + lib_controller.num_threads for lib_controller in blas_controller.lib_controllers ) assert original_num_threads["blas"] == expected else: @@ -459,7 +460,7 @@ def test_libomp_libiomp_warning(recwarn): # Check that a warning is raised when both libomp and libiomp are loaded # It should happen in one CI job (pylatest_conda_mkl_clang_gcc). controller = ThreadpoolController() - prefixes = [lib_controller.prefix for lib_controller in controller] + prefixes = [lib_controller.prefix for lib_controller in controller.lib_controllers] if not ("libomp" in prefixes and "libiomp" in prefixes and sys.platform == "linux"): pytest.skip("Requires both libomp and libiomp loaded, on Linux") diff --git a/threadpoolctl.py b/threadpoolctl.py index 4fe3b9ca..74950aef 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -136,7 +136,7 @@ def threadpool_info(): In addition, each library may contain internal_api specific entries. """ - return ThreadpoolController().todicts() + return ThreadpoolController().info() @_format_docstring( @@ -220,7 +220,7 @@ def get_original_num_threads(self): for user_api in self._user_api: limits = [ lib_controller.num_threads - for lib_controller in self._controller.select(user_api=user_api) + for lib_controller in self._controller.select(user_api=user_api).lib_controllers ] limits = set(limits) n_limits = len(limits) @@ -272,7 +272,7 @@ def _check_params(self, limits, user_api): # ThreadpoolController object. 
limits = { lib_controller.prefix: lib_controller.num_threads - for lib_controller in limits + for lib_controller in limits.lib_controllers } if not isinstance(limits, dict): @@ -343,9 +343,9 @@ def _from_controllers(cls, lib_controllers): new_controller.lib_controllers = lib_controllers return new_controller - def todicts(self): - """Return info as a list of dicts""" - return [lib_controller.todict() for lib_controller in self.lib_controllers] + def info(self): + """Return lib_controllers info as a list of dicts""" + return [lib_controller.info() for lib_controller in self.lib_controllers] def select(self, **kwargs): """Return a ThreadpoolController containing a subset of its current @@ -428,12 +428,6 @@ def restore_limits(self): def __len__(self): return len(self.lib_controllers) - def __iter__(self): - yield from self.lib_controllers - - def __eq__(self, other): - return self.lib_controllers == other.lib_controllers - def _load_libraries(self): """Loop through loaded shared libraries and store the supported ones""" if sys.platform == "darwin": @@ -673,13 +667,10 @@ def __init__(self, *, filepath=None, prefix=None, user_api=None, internal_api=No self.prefix = prefix self.filepath = filepath self._dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD) - self.version = self._get_version() + self.version = self.get_version() self.num_threads = self.get_num_threads() - def __eq__(self, other): - return self.todict() == other.todict() - - def todict(self): + def info(self): """Return relevant info wrapped in a dict""" return {k: v for k, v in vars(self).items() if not k.startswith("_")} @@ -694,7 +685,7 @@ def set_num_threads(self, num_threads): pass # pragma: no cover @abstractmethod - def _get_version(self): + def get_version(self): """Return the version of the shared library""" pass # pragma: no cover @@ -717,7 +708,7 @@ def set_num_threads(self, num_threads): ) return set_func(num_threads) - def _get_version(self): + def get_version(self): # None means OpenBLAS is not 
loaded or version < 0.3.4, since OpenBLAS # did not expose its version before that. get_config = getattr(self._dynlib, "openblas_get_config", None) @@ -773,7 +764,7 @@ def set_num_threads(self, num_threads): ) return set_func(num_threads) - def _get_version(self): + def get_version(self): get_version_ = getattr(self._dynlib, "bli_info_get_version_str", None) if get_version_ is None: return None @@ -820,7 +811,7 @@ def set_num_threads(self, num_threads): ) return set_func(num_threads) - def _get_version(self): + def get_version(self): if not hasattr(self._dynlib, "MKL_Get_Version_String"): return None @@ -865,7 +856,7 @@ def set_num_threads(self, num_threads): ) return set_func(num_threads) - def _get_version(self): + def get_version(self): # There is no way to get the version number programmatically in OpenMP. return None From 6621a8805afd46a7d2a1dd64e3e1fb5300d094c8 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 23 Sep 2021 18:27:24 +0200 Subject: [PATCH 13/18] readme --- README.md | 56 +++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 5a40d799..56cf3eff 100644 --- a/README.md +++ b/README.md @@ -113,10 +113,28 @@ that are loaded when importing Python packages: 'version': None}] ``` -In the above example, `numpy` was installed from the default anaconda channel and -comes with the MKL and its Intel OpenMP (`libiomp5`) implementation while -`xgboost` was installed from pypi.org and links against GNU OpenMP (`libgomp`) -so both OpenMP runtimes are loaded in the same Python program. +In the above example, `numpy` was installed from the default anaconda channel and comes +with MKL and its Intel OpenMP (`libiomp5`) implementation while `xgboost` was installed +from pypi.org and links against GNU OpenMP (`libgomp`) so both OpenMP runtimes are +loaded in the same Python program. 
+ +The state of these libraries is also accessible through the object oriented API: + +```python +>>> from threadpoolctl import threadpool_info +>>> from pprint import pprint +>>> import numpy +>>> controller = ThreadpoolController() +>>> pprint(controller.info()) +[{'architecture': 'Haswell', + 'filepath': '/home/jeremie/miniconda/envs/dev/lib/libopenblasp-r0.3.17.so', + 'internal_api': 'openblas', + 'num_threads': 4, + 'prefix': 'libopenblas', + 'threading_layer': 'pthreads', + 'user_api': 'blas', + 'version': '0.3.17'}] +``` ### Setting the Maximum Size of Thread-Pools @@ -124,16 +142,30 @@ Control the number of threads used by the underlying runtime libraries in specific sections of your Python program: ```python -from threadpoolctl import threadpool_limits -import numpy as np +>>> from threadpoolctl import threadpool_limits +>>> import numpy as np + +>>> with threadpool_limits(limits=1, user_api='blas'): +... # In this block, calls to blas implementation (like openblas or MKL) +... # will be limited to use only one thread. They can thus be used jointly +... # with thread-parallelism. +... a = np.random.randn(1000, 1000) +... a_squared = a @ a +``` + +The threadpools can also be controlled via the object oriented API, which is especially +useful to avoid searching through all the loaded shared libraries each time. It will +however not act on libraries loaded after the instantiation of the +``ThreadpoolController``: +```python +>>> from threadpoolctl import ThreadpoolController +>>> import numpy as np +>>> controller = ThreadpoolController() -with threadpool_limits(limits=1, user_api='blas'): - # In this block, calls to blas implementation (like openblas or MKL) - # will be limited to use only one thread. They can thus be used jointly - # with thread-parallelism. - a = np.random.randn(1000, 1000) - a_squared = a @ a +>>> with controller.limit(limits=1, user_api='blas'): +... a = np.random.randn(1000, 1000) +... 
a_squared = a @ a ``` ### Known Limitations From f78dcaaaaab786b94a3d41f204387068d1569728 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 23 Sep 2021 19:31:31 +0200 Subject: [PATCH 14/18] change log --- CHANGES.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 0d34112b..8219ddd5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,13 @@ -2.3.0 (in development) +3.0.0 (in development) ====================== +- New object `threadpoolctl.ThreadpoolController` which holds controllers for all the + supported native libraries. The state of these libraries is accessible through the + `info` method (equivalent to `threadpoolctl.threadpool_info()`) and their number of + threads can be limited with the `limit` method which can be used as a context + manager (equivalent to `threadpoolctl.threadpool_limits()`). This is especially useful + to avoid searching through all loaded shared libraries each time. + - Fixed an attribute error when using old versions of OpenBLAS or BLIS that are missing version query functions. https://github.com/joblib/threadpoolctl/pull/88 From e6ad48dde2da36793652239aa3a79f548ad9bb47 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 23 Sep 2021 19:34:46 +0200 Subject: [PATCH 15/18] address comment --- threadpoolctl.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/threadpoolctl.py b/threadpoolctl.py index 74950aef..356479fd 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -147,9 +147,10 @@ def threadpool_limits(limits=None, user_api=None): """Change the maximal number of threads that can be used in thread pools. - This function returns a class that can be used either as a function (the + This function returns an object that can be used either as a callable (the construction of this object limits the number of threads) or as a context manager, - in a `with` block. 
+ in a `with` block to automatically restore the original state of the controlled + libraries when exiting the block. Set the maximal number of threads that can be used in thread pools used in the supported libraries to `limit`. This function works for libraries that @@ -382,9 +383,10 @@ def select(self, **kwargs): def limit(self, *, limits=None, user_api=None): """Change the maximal number of threads that can be used in thread pools. - This function returns a class that can be used either as a function (the + This function returns an object that can be used either as a callable (the construction of this object limits the number of threads) or as a context - manager, in a `with` block. + manager, in a `with` block to automatically restore the original state of the + controlled libraries when exiting the block. Set the maximal number of threads that can be used in thread pools used in the supported libraries to `limits`. This function works for libraries that From 3745899d1431b6136d0c474a5667216d054a45d7 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 24 Sep 2021 10:58:09 +0200 Subject: [PATCH 16/18] black --- tests/test_threadpoolctl.py | 13 +++++++++---- threadpoolctl.py | 13 ++++++++----- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index c52f472f..0da45111 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -186,7 +186,8 @@ def test_threadpool_controller_limit(): openmp_controller = ThreadpoolController().select(user_api="openmp") assert all( - lib_controller.num_threads == 1 for lib_controller in blas_controller.lib_controllers + lib_controller.num_threads == 1 + for lib_controller in blas_controller.lib_controllers ) # original_blas_controller contains only blas libraries so no opemp library # should be impacted. @@ -359,7 +360,8 @@ def test_nested_prange_blas(nthreads_outer): # numpy can be linked to BLIS for CBLAS and OpenBLAS for LAPACK. 
In that # case this test will run BLIS gemm so no need to skip. if not blis_controllers and any( - is_old_openblas(lib_controller) for lib_controller in blas_controllers.lib_controllers + is_old_openblas(lib_controller) + for lib_controller in blas_controllers.lib_controllers ): pytest.skip("Old OpenBLAS: skipping test to avoid deadlock") @@ -377,7 +379,9 @@ def test_nested_prange_blas(nthreads_outer): assert prange_num_threads == nthreads nested_blas_controllers = inner_controller.select(user_api="blas") - assert len(nested_blas_controllers.lib_controllers) == len(blas_controllers.lib_controllers) + assert len(nested_blas_controllers.lib_controllers) == len( + blas_controllers.lib_controllers + ) for lib_controller in nested_blas_controllers.lib_controllers: assert lib_controller.num_threads == 1 @@ -406,7 +410,8 @@ def test_get_original_num_threads(limit): blas_controller = original_controller.select(user_api="blas") if blas_controller: expected = min( - lib_controller.num_threads for lib_controller in blas_controller.lib_controllers + lib_controller.num_threads + for lib_controller in blas_controller.lib_controllers ) assert original_num_threads["blas"] == expected else: diff --git a/threadpoolctl.py b/threadpoolctl.py index 356479fd..2e7bd749 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -119,8 +119,7 @@ def _realpath(filepath): return os.path.realpath(filepath) -@_format_docstring(USER_APIS=list(_ALL_USER_APIS), - INTERNAL_APIS=_ALL_INTERNAL_APIS) +@_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS) def threadpool_info(): """Return the maximal number of threads for each detected library. @@ -186,11 +185,12 @@ def threadpool_limits(limits=None, user_api=None): class _threadpool_limits: """The guts of ThreadpoolController.limit - + Refer to the docstring of ThreadpoolController.limit for more details. It will only act on the library controllers held by the provided `controller`. 
""" + def __init__(self, controller, *, limits=None, user_api=None): self._limits, self._user_api, self._prefixes = self._check_params( limits, user_api @@ -221,7 +221,9 @@ def get_original_num_threads(self): for user_api in self._user_api: limits = [ lib_controller.num_threads - for lib_controller in self._controller.select(user_api=user_api).lib_controllers + for lib_controller in self._controller.select( + user_api=user_api + ).lib_controllers ] limits = set(limits) n_limits = len(limits) @@ -327,6 +329,7 @@ class ThreadpoolController: lib_controllers : list of `LibController` objects The list of library controllers of all loaded supported libraries. """ + # Cache for libc under POSIX and a few system libraries under Windows. # We use a class level cache instead of an instance level cache because # it's very unlikely that a shared library will be unloaded and reloaded @@ -421,7 +424,7 @@ def limit(self, *, limits=None, user_api=None): def restore_limits(self): """Set the limits back to their original values - + Since get_num_threads is only called once at initialization, the instance keeps the original num_threads during its whole lifetime. """ From 9002c7265c0a6d3c28d809c9d8f5098d7a0bd297 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 24 Sep 2021 11:07:17 +0200 Subject: [PATCH 17/18] cln --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 56cf3eff..9370b711 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,7 @@ loaded in the same Python program. 
The state of these libraries is also accessible through the object oriented API: ```python ->>> from threadpoolctl import threadpool_info +>>> from threadpoolctl import ThreadpoolController >>> from pprint import pprint >>> import numpy >>> controller = ThreadpoolController() From de1078dc5fde83c0c03658cc2fb30b6b713d5cf6 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 24 Sep 2021 11:11:10 +0200 Subject: [PATCH 18/18] show equivalence in snippet --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9370b711..624c4922 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,7 @@ loaded in the same Python program. The state of these libraries is also accessible through the object oriented API: ```python ->>> from threadpoolctl import ThreadpoolController +>>> from threadpoolctl import ThreadpoolController, threadpool_info >>> from pprint import pprint >>> import numpy >>> controller = ThreadpoolController() @@ -134,6 +134,9 @@ The state of these libraries is also accessible through the object oriented API: 'threading_layer': 'pthreads', 'user_api': 'blas', 'version': '0.3.17'}] + +>>> controller.info() == threadpool_info() +True ``` ### Setting the Maximum Size of Thread-Pools