diff --git a/CHANGES.md b/CHANGES.md
index 0d34112b..8219ddd5 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,6 +1,13 @@
-2.3.0 (in development)
+3.0.0 (in development)
 ======================
 
+- New object `threadpoolctl.ThreadpoolController` which holds controllers for all the
+  supported native libraries. The state of these libraries is accessible through the
+  `info` method (equivalent to `threadpoolctl.threadpool_info()`) and their number of
+  threads can be limited with the `limit` method, which can be used as a context
+  manager (equivalent to `threadpoolctl.threadpool_limits()`). This is especially useful
+  to avoid searching through all loaded shared libraries each time.
+
 - Fixed an attribute error when using old versions of OpenBLAS or BLIS that are
   missing version query functions.
   https://github.com/joblib/threadpoolctl/pull/88
diff --git a/README.md b/README.md
index 5a40d799..624c4922 100644
--- a/README.md
+++ b/README.md
@@ -113,10 +113,31 @@ that are loaded when importing Python packages:
   'version': None}]
 ```
 
-In the above example, `numpy` was installed from the default anaconda channel and
-comes with the MKL and its Intel OpenMP (`libiomp5`) implementation while
-`xgboost` was installed from pypi.org and links against GNU OpenMP (`libgomp`)
-so both OpenMP runtimes are loaded in the same Python program.
+In the above example, `numpy` was installed from the default anaconda channel and comes
+with MKL and its Intel OpenMP (`libiomp5`) implementation while `xgboost` was installed
+from pypi.org and links against GNU OpenMP (`libgomp`) so both OpenMP runtimes are
+loaded in the same Python program.
+
+The state of these libraries is also accessible through the object-oriented API:
+
+```python
+>>> from threadpoolctl import ThreadpoolController, threadpool_info
+>>> from pprint import pprint
+>>> import numpy
+>>> controller = ThreadpoolController()
+>>> pprint(controller.info())
+[{'architecture': 'Haswell',
+  'filepath': '/home/jeremie/miniconda/envs/dev/lib/libopenblasp-r0.3.17.so',
+  'internal_api': 'openblas',
+  'num_threads': 4,
+  'prefix': 'libopenblas',
+  'threading_layer': 'pthreads',
+  'user_api': 'blas',
+  'version': '0.3.17'}]
+
+>>> controller.info() == threadpool_info()
+True
+```
 
 ### Setting the Maximum Size of Thread-Pools
 
@@ -124,16 +145,30 @@ Control the number of threads used by the underlying runtime
 libraries in specific sections of your Python program:
 
 ```python
-from threadpoolctl import threadpool_limits
-import numpy as np
+>>> from threadpoolctl import threadpool_limits
+>>> import numpy as np
+
+>>> with threadpool_limits(limits=1, user_api='blas'):
+...     # In this block, calls to blas implementation (like openblas or MKL)
+...     # will be limited to use only one thread. They can thus be used jointly
+...     # with thread-parallelism.
+...     a = np.random.randn(1000, 1000)
+...     a_squared = a @ a
+```
+
+The threadpools can also be controlled via the object-oriented API, which is especially
+useful to avoid searching through all the loaded shared libraries each time. It will,
+however, not act on libraries loaded after the instantiation of the
+``ThreadpoolController``:
+```python
+>>> from threadpoolctl import ThreadpoolController
+>>> import numpy as np
+>>> controller = ThreadpoolController()
 
-with threadpool_limits(limits=1, user_api='blas'):
-    # In this block, calls to blas implementation (like openblas or MKL)
-    # will be limited to use only one thread. They can thus be used jointly
-    # with thread-parallelism.
- a = np.random.randn(1000, 1000) - a_squared = a @ a +>>> with controller.limit(limits=1, user_api='blas'): +... a = np.random.randn(1000, 1000) +... a_squared = a @ a ``` ### Known Limitations diff --git a/continuous_integration/install.sh b/continuous_integration/install.sh index fb643256..8b20fef0 100755 --- a/continuous_integration/install.sh +++ b/continuous_integration/install.sh @@ -76,7 +76,7 @@ python -m pip install -q -r dev-requirements.txt bash ./continuous_integration/build_test_ext.sh python --version -python -c "import numpy; print('numpy %s' % numpy.__version__)" || echo "no numpy" -python -c "import scipy; print('scipy %s' % scipy.__version__)" || echo "no scipy" +python -c "import numpy; print(f'numpy {numpy.__version__}')" || echo "no numpy" +python -c "import scipy; print(f'scipy {scipy.__version__}')" || echo "no scipy" python -m flit install --symlink diff --git a/continuous_integration/install_with_blis.sh b/continuous_integration/install_with_blis.sh index 4a10d95c..a2d6c962 100755 --- a/continuous_integration/install_with_blis.sh +++ b/continuous_integration/install_with_blis.sh @@ -51,7 +51,7 @@ CFLAGS=-I$ABS_PATH/BLIS_install/include/blis LDFLAGS=-L$ABS_PATH/BLIS_install/li bash ./continuous_integration/build_test_ext.sh python --version -python -c "import numpy; print('numpy %s' % numpy.__version__)" || echo "no numpy" -python -c "import scipy; print('scipy %s' % scipy.__version__)" || echo "no scipy" +python -c "import numpy; print(f'numpy {numpy.__version__}')" || echo "no numpy" +python -c "import scipy; print(f'scipy {scipy.__version__}')" || echo "no scipy" python -m flit install --symlink diff --git a/tests/_openmp_test_helper/nested_prange_blas.pyx b/tests/_openmp_test_helper/nested_prange_blas.pyx index e327eee0..aec7f815 100644 --- a/tests/_openmp_test_helper/nested_prange_blas.pyx +++ b/tests/_openmp_test_helper/nested_prange_blas.pyx @@ -20,7 +20,7 @@ IF USE_BLIS: ELSE: from scipy.linalg.cython_blas cimport dgemm -from threadpoolctl import _ThreadpoolInfo, _ALL_USER_APIS +from threadpoolctl import ThreadpoolController def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads): @@ -43,12 +43,12 @@ def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads): int prange_num_threads int *prange_num_threads_ptr = &prange_num_threads - inner_info = [None] + inner_controller = [None] with nogil, parallel(num_threads=nthreads): if openmp.omp_get_thread_num() == 0: with gil: - inner_info[0] = _ThreadpoolInfo(user_api=_ALL_USER_APIS) + inner_controller[0] = ThreadpoolController() prange_num_threads_ptr[0] = openmp.omp_get_num_threads() @@ -62,4 +62,4 @@ def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads): &alpha, &B[0, 0], &k, &A[i * chunk_size, 0], &k, &beta, &C[i * chunk_size, 0], &n) - return np.asarray(C), prange_num_threads, inner_info[0] + return np.asarray(C), prange_num_threads, inner_controller[0] diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index c4f5e18c..cc7fe06a 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -5,7 +5,8 @@ import subprocess import sys -from threadpoolctl import threadpool_limits, threadpool_info, _ThreadpoolInfo +from threadpoolctl import threadpool_limits, threadpool_info +from threadpoolctl import ThreadpoolController from threadpoolctl import _ALL_PREFIXES, _ALL_USER_APIS from .utils import cython_extensions_compiled @@ -14,10 +15,10 @@ from .utils import threadpool_info_from_subprocess -def 
is_old_openblas(module):
+def is_old_openblas(lib_controller):
     # Possible bug in getting maximum number of threads with OpenBLAS < 0.2.16
     # and OpenBLAS does not expose its version before 0.3.4.
-    return module.internal_api == "openblas" and module.version is None
+    return lib_controller.internal_api == "openblas" and lib_controller.version is None
 
 
 def effective_num_threads(nthreads, max_threads):
@@ -26,130 +27,191 @@ def effective_num_threads(nthreads, max_threads):
     return nthreads
 
 
-def _threadpool_info():
-    # Like threadpool_info but return the object instead of the list of dicts
-    return _ThreadpoolInfo(user_api=_ALL_USER_APIS)
+def test_threadpool_info():
+    # Check consistency between threadpool_info and ThreadpoolController
+    function_info = threadpool_info()
+    object_info = ThreadpoolController().lib_controllers
+
+    for lib_info, lib_controller in zip(function_info, object_info):
+        assert lib_info == lib_controller.info()
 
 
-def test_threadpool_limits_public_api():
-    # Check consistency between threadpool_info and _ThreadpoolInfo
-    public_info = threadpool_info()
-    private_info = _threadpool_info()
-
-    for module1, module2 in zip(public_info, private_info):
-        assert module1 == module2.todict()
+def test_threadpool_controller_info():
+    # Check that all keys expected for the private api are in the dicts
+    # returned by the `info` methods
+    controller = ThreadpoolController()
+
+    assert threadpool_info() == [
+        lib_controller.info() for lib_controller in controller.lib_controllers
+    ]
+    assert controller.info() == [
+        lib_controller.info() for lib_controller in controller.lib_controllers
+    ]
 
 
-def test_ThreadpoolInfo_todicts():
-    # Check all keys expected for the public api are in the dicts returned by
-    # the .todict(s) methods
-    info = _threadpool_info()
+    for lib_controller_dict in controller.info():
+        assert "user_api" in lib_controller_dict
+        assert "internal_api" in lib_controller_dict
+        assert "prefix" in lib_controller_dict
+        assert "filepath" in lib_controller_dict
+        assert "version" in lib_controller_dict
+        assert "num_threads" in lib_controller_dict
 
-    assert threadpool_info() == [module.todict() for module in info.modules]
-    assert info.todicts() == [module.todict() for module in info]
-    assert info.todicts() == [module.todict() for module in info.modules]
+        if lib_controller_dict["internal_api"] in ("mkl", "blis", "openblas"):
+            assert "threading_layer" in lib_controller_dict
 
-    for module in info:
-        module_dict = module.todict()
-        assert "user_api" in module_dict
-        assert "internal_api" in module_dict
-        assert "prefix" in module_dict
-        assert "filepath" in module_dict
-        assert "version" in module_dict
-        assert "num_threads" in module_dict
 
-        if module.internal_api in ("mkl", "blis", "openblas"):
-            assert "threading_layer" in module_dict
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        {"user_api": "blas"},
+        {"prefix": "libgomp"},
+        {"internal_api": "openblas", "prefix": "libomp"},
+        {"prefix": ["libgomp", "libomp", "libiomp"]},
+    ],
+)
+def test_threadpool_controller_select(kwargs):
+    # Check the behavior of the select method of ThreadpoolController
+    controller = ThreadpoolController().select(**kwargs)
+    if not controller:
+        pytest.skip(f"Requires at least one of {list(kwargs.values())}.")
+
+    for lib_controller in controller.lib_controllers:
+        assert any(
+            getattr(lib_controller, key) in (val if isinstance(val, list) else [val])
+            for key, val in kwargs.items()
+        )
 
 
 @pytest.mark.parametrize("prefix", _ALL_PREFIXES)
 @pytest.mark.parametrize("limit", [1, 3])
 def
test_threadpool_limits_by_prefix(prefix, limit): # Check that the maximum number of threads can be set by prefix - original_info = _threadpool_info() + original_controller = ThreadpoolController() - modules_matching_prefix = original_info.get_modules("prefix", prefix) - if not modules_matching_prefix: + controller_matching_prefix = original_controller.select(prefix=prefix) + if not controller_matching_prefix: pytest.skip(f"Requires {prefix} runtime") with threadpool_limits(limits={prefix: limit}): - for module in modules_matching_prefix: - if is_old_openblas(module): + for lib_controller in controller_matching_prefix.lib_controllers: + if is_old_openblas(lib_controller): continue # threadpool_limits only sets an upper bound on the number of # threads. - assert 0 < module.get_num_threads() <= limit - assert _threadpool_info() == original_info + assert 0 < lib_controller.get_num_threads() <= limit + assert ThreadpoolController().info() == original_controller.info() @pytest.mark.parametrize("user_api", (None, "blas", "openmp")) @pytest.mark.parametrize("limit", [1, 3]) def test_set_threadpool_limits_by_api(user_api, limit): # Check that the maximum number of threads can be set by user_api - original_info = _threadpool_info() + original_controller = ThreadpoolController() - modules_matching_api = original_info.get_modules("user_api", user_api) - if not modules_matching_api: + if user_api is None: + controller_matching_api = original_controller + else: + controller_matching_api = original_controller.select(user_api=user_api) + if not controller_matching_api: user_apis = _ALL_USER_APIS if user_api is None else [user_api] pytest.skip(f"Requires a library which api is in {user_apis}") with threadpool_limits(limits=limit, user_api=user_api): - for module in modules_matching_api: - if is_old_openblas(module): + for lib_controller in controller_matching_api.lib_controllers: + if is_old_openblas(lib_controller): continue # threadpool_limits only sets an upper bound on the number of # threads. - assert 0 < module.get_num_threads() <= limit + assert 0 < lib_controller.get_num_threads() <= limit - assert _threadpool_info() == original_info + assert ThreadpoolController().info() == original_controller.info() def test_threadpool_limits_function_with_side_effect(): # Check that threadpool_limits can be used as a function with # side effects instead of a context manager. - original_info = _threadpool_info() + original_controller = ThreadpoolController() threadpool_limits(limits=1) try: - for module in _threadpool_info(): - if is_old_openblas(module): + for lib_controller in ThreadpoolController().lib_controllers: + if is_old_openblas(lib_controller): continue - assert module.num_threads == 1 + assert lib_controller.num_threads == 1 finally: # Restore the original limits so that this test does not have any # side-effect. - threadpool_limits(limits=original_info) + threadpool_limits(limits=original_controller) - assert _threadpool_info() == original_info + assert ThreadpoolController().info() == original_controller.info() def test_set_threadpool_limits_no_limit(): # Check that limits=None does nothing. 
-    original_info = _threadpool_info()
+    original_controller = ThreadpoolController()
 
     with threadpool_limits(limits=None):
-        assert _threadpool_info() == original_info
+        assert ThreadpoolController().info() == original_controller.info()
 
-    assert _threadpool_info() == original_info
+    assert ThreadpoolController().info() == original_controller.info()
 
 
 def test_threadpool_limits_manual_unregister():
     # Check that threadpool_limits can be used as an object which holds the
     # original state of the threadpools and that can be restored thanks to the
     # dedicated unregister method
-    original_info = _threadpool_info()
+    original_controller = ThreadpoolController()
 
     limits = threadpool_limits(limits=1)
     try:
-        for module in _threadpool_info():
-            if is_old_openblas(module):
+        for lib_controller in ThreadpoolController().lib_controllers:
+            if is_old_openblas(lib_controller):
                 continue
-            assert module.num_threads == 1
+            assert lib_controller.num_threads == 1
     finally:
         # Restore the original limits so that this test does not have any
         # side-effect.
         limits.unregister()
 
-    assert _threadpool_info() == original_info
+    assert ThreadpoolController().info() == original_controller.info()
+
+
+def test_threadpool_controller_limit():
+    # Check that using the limit method of ThreadpoolController only impacts its
+    # library controllers.
+    original_blas_controller = ThreadpoolController().select(user_api="blas")
+    original_openmp_controller = ThreadpoolController().select(user_api="openmp")
+
+    with original_blas_controller.limit(limits=1):
+        blas_controller = ThreadpoolController().select(user_api="blas")
+        openmp_controller = ThreadpoolController().select(user_api="openmp")
+
+        assert all(
+            lib_controller.num_threads == 1
+            for lib_controller in blas_controller.lib_controllers
+        )
+        # original_blas_controller contains only blas libraries so no openmp library
+        # should be impacted.
+        assert openmp_controller.info() == original_openmp_controller.info()
+
+
+def test_threadpool_controller_restore():
+    # Check that the restore_limits method of ThreadpoolController is able to set the
+    # limits back to their original values. Similar to
+    # test_threadpool_limits_function_with_side_effect but with the object api
+    controller = ThreadpoolController()
+
+    controller.limit(limits=1)
+    try:
+        for lib_controller in ThreadpoolController().lib_controllers:
+            if is_old_openblas(lib_controller):
+                continue
+            assert lib_controller.num_threads == 1
+    finally:
+        # Restore the original limits so that this test does not have any
+        # side-effect.
+        controller.restore_limits()
+
+    assert ThreadpoolController().info() == controller.info()
 
 
 def test_threadpool_limits_bad_input():
@@ -211,11 +273,11 @@ def test_openmp_nesting(nthreads_outer):
     # There are 2 openmp, the one from inner and the one from outer.
     assert len(outer_info) == 2
     # We already know the one from inner. It has to be the other one.
-    prefixes = {module["prefix"] for module in outer_info}
+    prefixes = {lib_info["prefix"] for lib_info in outer_info}
 
     outer_omp = prefixes - {inner_omp}
     outer_num_threads, inner_num_threads = check_nested_openmp_loops(10)
-    original_info = _threadpool_info()
+    original_controller = ThreadpoolController()
 
     if inner_omp == outer_omp:
         # The OpenMP runtime should be shared by default, meaning that the
@@ -232,7 +294,7 @@ def test_openmp_nesting(nthreads_outer):
     # The state of the original state of all threadpools should have been
     # restored.
- assert _threadpool_info() == original_info + assert ThreadpoolController().info() == original_controller.info() # The number of threads available in the outer loop should not have been # decreased: @@ -245,7 +307,7 @@ def test_openmp_nesting(nthreads_outer): # XXX: this does not always work when nesting independent openmp # implementations. See: https://github.com/jeremiedbb/Nested_OpenMP pytest.xfail( - "Inner OpenMP num threads was %d instead of 1" % inner_num_threads + f"Inner OpenMP num threads was {inner_num_threads} instead of 1" ) assert inner_num_threads == 1 @@ -253,15 +315,15 @@ def test_openmp_nesting(nthreads_outer): def test_shipped_openblas(): # checks that OpenBLAS effectively uses the number of threads requested by # the context manager - original_info = _threadpool_info() + original_controller = ThreadpoolController() - openblas_modules = original_info.get_modules("internal_api", "openblas") + openblas_controllers = original_controller.select(internal_api="openblas") with threadpool_limits(1): - for module in openblas_modules: - assert module.get_num_threads() == 1 + for lib_controller in openblas_controllers.lib_controllers: + assert lib_controller.get_num_threads() == 1 - assert original_info == _threadpool_info() + assert original_controller.info() == ThreadpoolController().info() @pytest.mark.skipif( @@ -288,17 +350,20 @@ def test_nested_prange_blas(nthreads_outer): check_nested_prange_blas = prange_blas.check_nested_prange_blas - original_info = _threadpool_info() + original_controller = ThreadpoolController() - blas_info = original_info.get_modules("user_api", "blas") - blis_info = original_info.get_modules("internal_api", "blis") + blas_controllers = original_controller.select(user_api="blas") + blis_controllers = original_controller.select(internal_api="blis") # skip if the BLAS used by numpy is an old openblas. OpenBLAS 0.3.3 and # older are known to cause an unrecoverable deadlock at process shutdown # time (after pytest has exited). # numpy can be linked to BLIS for CBLAS and OpenBLAS for LAPACK. In that # case this test will run BLIS gemm so no need to skip. 
- if not blis_info and any(is_old_openblas(module) for module in blas_info): + if not blis_controllers and any( + is_old_openblas(lib_controller) + for lib_controller in blas_controllers.lib_controllers + ): pytest.skip("Old OpenBLAS: skipping test to avoid deadlock") A = np.ones((1000, 10)) @@ -309,17 +374,19 @@ def test_nested_prange_blas(nthreads_outer): nthreads = effective_num_threads(nthreads_outer, max_threads) result = check_nested_prange_blas(A, B, nthreads) - C, prange_num_threads, inner_info = result + C, prange_num_threads, inner_controller = result assert np.allclose(C, np.dot(A, B.T)) assert prange_num_threads == nthreads - nested_blas_info = inner_info.get_modules("user_api", "blas") - assert len(nested_blas_info) == len(blas_info) - for module in nested_blas_info: - assert module.num_threads == 1 + nested_blas_controllers = inner_controller.select(user_api="blas") + assert len(nested_blas_controllers.lib_controllers) == len( + blas_controllers.lib_controllers + ) + for lib_controller in nested_blas_controllers.lib_controllers: + assert lib_controller.num_threads == 1 - assert original_info == _threadpool_info() + assert original_controller.info() == ThreadpoolController().info() # the method `get_original_num_threads` raises a UserWarning due to different @@ -330,20 +397,23 @@ def test_nested_prange_blas(nthreads_outer): @pytest.mark.parametrize("limit", [1, None]) def test_get_original_num_threads(limit): # Tests the method get_original_num_threads of the context manager - with threadpool_limits(limits=2, user_api="blas") as ctl: + with threadpool_limits(limits=2, user_api="blas") as ctx: # set different blas num threads to start with (when multiple openblas) - if ctl._original_info: - ctl._original_info.modules[0].set_num_threads(1) + if ctx._controller: + ctx._controller.lib_controllers[0].set_num_threads(1) - original_info = _threadpool_info() + original_controller = ThreadpoolController() with threadpool_limits(limits=limit, user_api="blas") as threadpoolctx: original_num_threads = threadpoolctx.get_original_num_threads() assert "openmp" not in original_num_threads - blas_info = original_info.get_modules("user_api", "blas") - if blas_info: - expected = min(module.num_threads for module in blas_info) + blas_controller = original_controller.select(user_api="blas") + if blas_controller: + expected = min( + lib_controller.num_threads + for lib_controller in blas_controller.lib_controllers + ) assert original_num_threads["blas"] == expected else: assert original_num_threads["blas"] is None @@ -356,30 +426,30 @@ def test_get_original_num_threads(limit): def test_mkl_threading_layer(): # Check that threadpool_info correctly recovers the threading layer used # by mkl - mkl_info = _threadpool_info().get_modules("internal_api", "mkl") + mkl_controller = ThreadpoolController().select(internal_api="mkl") expected_layer = os.getenv("MKL_THREADING_LAYER") - if not (mkl_info and expected_layer): + if not (mkl_controller and expected_layer): pytest.skip("requires MKL and the environment variable MKL_THREADING_LAYER set") - actual_layer = mkl_info.modules[0].threading_layer + actual_layer = mkl_controller.lib_controllers[0].threading_layer assert actual_layer == expected_layer.lower() def test_blis_threading_layer(): # Check that threadpool_info correctly recovers the threading layer used # by blis - blis_info = _threadpool_info().get_modules("internal_api", "blis") + blis_controller = ThreadpoolController().select(internal_api="blis") expected_layer = 
os.getenv("BLIS_ENABLE_THREADING") if expected_layer == "no": expected_layer = "disabled" - if not (blis_info and expected_layer): + if not (blis_controller and expected_layer): pytest.skip( "requires BLIS and the environment variable BLIS_ENABLE_THREADING set" ) - actual_layer = blis_info.modules[0].threading_layer + actual_layer = blis_controller.lib_controllers[0].threading_layer assert actual_layer == expected_layer @@ -395,8 +465,8 @@ def test_libomp_libiomp_warning(recwarn): # Check that a warning is raised when both libomp and libiomp are loaded # It should happen in one CI job (pylatest_conda_mkl_clang_gcc). - info = _threadpool_info() - prefixes = [module.prefix for module in info] + controller = ThreadpoolController() + prefixes = [lib_controller.prefix for lib_controller in controller.lib_controllers] if not ("libomp" in prefixes and "libiomp" in prefixes and sys.platform == "linux"): pytest.skip("Requires both libomp and libiomp loaded, on Linux") @@ -422,8 +492,8 @@ def test_command_line_command_flag(): cli_info = json.loads(output.decode("utf-8")) this_process_info = threadpool_info() - for module in cli_info: - assert module in this_process_info + for lib_info in cli_info: + assert lib_info in this_process_info @pytest.mark.skipif( @@ -448,8 +518,8 @@ def test_command_line_import_flag(): cli_info = json.loads(result.stdout) this_process_info = threadpool_info() - for module in cli_info: - assert module in this_process_info + for lib_info in cli_info: + assert lib_info in this_process_info warnings = [w.strip() for w in result.stderr.splitlines()] assert "WARNING: could not import invalid_package" in warnings @@ -473,11 +543,11 @@ def test_architecture(): "skx", "haswell", ) - for module in threadpool_info(): - if module["internal_api"] == "openblas": - assert module["architecture"] in expected_openblas_architectures - elif module["internal_api"] == "blis": - assert module["architecture"] in expected_blis_architectures + for lib_info in threadpool_info(): + if lib_info["internal_api"] == "openblas": + assert lib_info["architecture"] in expected_openblas_architectures + elif lib_info["internal_api"] == "blis": + assert lib_info["architecture"] in expected_blis_architectures else: - # Not supported for other modules - assert "architecture" not in module + # Not supported for other libraries + assert "architecture" not in lib_info diff --git a/threadpoolctl.py b/threadpoolctl.py index 205db9e8..a7e368c3 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -21,7 +21,7 @@ from functools import lru_cache __version__ = "2.3.0.dev0" -__all__ = ["threadpool_limits", "threadpool_info"] +__all__ = ["threadpool_limits", "threadpool_info", "ThreadpoolController"] # One can get runtime errors or even segfaults due to multiple OpenMP libraries @@ -60,26 +60,26 @@ class _dl_phdr_info(ctypes.Structure): # List of the supported libraries. The items are indexed by the name of the -# class to instanciate to create the module objects. The items hold the -# possible prefixes of loaded shared objects, the name of the internal_api to -# call and the name of the user_api. -_SUPPORTED_MODULES = { - "_OpenMPModule": { +# class to instanciate to create the library controller objects. The items hold +# the possible prefixes of loaded shared objects, the name of the internal_api +# to call and the name of the user_api. 
+_SUPPORTED_LIBRARIES = { + "OpenMPController": { "user_api": "openmp", "internal_api": "openmp", "filename_prefixes": ("libiomp", "libgomp", "libomp", "vcomp"), }, - "_OpenBLASModule": { + "OpenBLASController": { "user_api": "blas", "internal_api": "openblas", "filename_prefixes": ("libopenblas",), }, - "_MKLModule": { + "MKLController": { "user_api": "blas", "internal_api": "mkl", "filename_prefixes": ("libmkl_rt", "mkl_rt"), }, - "_BLISModule": { + "BLISController": { "user_api": "blas", "internal_api": "blis", "filename_prefixes": ("libblis",), @@ -87,15 +87,21 @@ class _dl_phdr_info(ctypes.Structure): } # Helpers for the doc and test names -_ALL_USER_APIS = list(set(m["user_api"] for m in _SUPPORTED_MODULES.values())) -_ALL_INTERNAL_APIS = [m["internal_api"] for m in _SUPPORTED_MODULES.values()] +_ALL_USER_APIS = list(set(lib["user_api"] for lib in _SUPPORTED_LIBRARIES.values())) +_ALL_INTERNAL_APIS = [lib["internal_api"] for lib in _SUPPORTED_LIBRARIES.values()] _ALL_PREFIXES = [ - prefix for m in _SUPPORTED_MODULES.values() for prefix in m["filename_prefixes"] + prefix + for lib in _SUPPORTED_LIBRARIES.values() + for prefix in lib["filename_prefixes"] ] _ALL_BLAS_LIBRARIES = [ - m["internal_api"] for m in _SUPPORTED_MODULES.values() if m["user_api"] == "blas" + lib["internal_api"] + for lib in _SUPPORTED_LIBRARIES.values() + if lib["user_api"] == "blas" ] -_ALL_OPENMP_LIBRARIES = list(_SUPPORTED_MODULES["_OpenMPModule"]["filename_prefixes"]) +_ALL_OPENMP_LIBRARIES = list( + _SUPPORTED_LIBRARIES["OpenMPController"]["filename_prefixes"] +) def _format_docstring(*args, **kwargs): @@ -117,19 +123,19 @@ def _realpath(filepath): def threadpool_info(): """Return the maximal number of threads for each detected library. - Return a list with all the supported modules that have been found. Each - module is represented by a dict with the following information: + Return a list with all the supported libraries that have been found. Each + library is represented by a dict with the following information: - "user_api" : user API. Possible values are {USER_APIS}. - "internal_api": internal API. Possible values are {INTERNAL_APIS}. - "prefix" : filename prefix of the specific implementation. - - "filepath": path to the loaded module. + - "filepath": path to the loaded library. - "version": version of the library (if available). - "num_threads": the current thread limit. - In addition, each module may contain internal_api specific entries. + In addition, each library may contain internal_api specific entries. """ - return _ThreadpoolInfo(user_api=_ALL_USER_APIS).todicts() + return ThreadpoolController().info() @_format_docstring( @@ -137,12 +143,13 @@ def threadpool_info(): BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES), OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES), ) -class threadpool_limits: +def threadpool_limits(limits=None, user_api=None): """Change the maximal number of threads that can be used in thread pools. - This class can be used either as a function (the construction of this - object limits the number of threads) or as a context manager, in a `with` - block. + This function returns an object that can be used either as a callable (the + construction of this object limits the number of threads) or as a context manager, + in a `with` block to automatically restore the original state of the controlled + libraries when exiting the block. Set the maximal number of threads that can be used in thread pools used in the supported libraries to `limit`. 
This function works for libraries that @@ -173,13 +180,23 @@ class threadpool_limits: - If None, this function will apply to all supported libraries. """ + return ThreadpoolController().limit(limits=limits, user_api=user_api) + + +class _threadpool_limits: + """The guts of ThreadpoolController.limit + + Refer to the docstring of ThreadpoolController.limit for more details. - def __init__(self, limits=None, user_api=None): + It will only act on the library controllers held by the provided `controller`. + """ + + def __init__(self, controller, *, limits=None, user_api=None): self._limits, self._user_api, self._prefixes = self._check_params( limits, user_api ) - - self._original_info = self._set_threadpool_limits() + self._controller = controller + self._set_threadpool_limits() def __enter__(self): return self @@ -188,27 +205,25 @@ def __exit__(self, type, value, traceback): self.unregister() def unregister(self): - if self._original_info is not None: - for module in self._original_info: - module.set_num_threads(module.num_threads) + for lib_controller in self._controller.lib_controllers: + # Since we never call get_num_threads after instanciation of + # ThreadpoolController, num_threads holds the original value. + lib_controller.set_num_threads(lib_controller.num_threads) def get_original_num_threads(self): """Original num_threads from before calling threadpool_limits Return a dict `{user_api: num_threads}`. """ - if self._original_info is not None: - original_limits = self._original_info - else: - original_limits = _ThreadpoolInfo(user_api=self._user_api) - num_threads = {} warning_apis = [] for user_api in self._user_api: limits = [ - module.num_threads - for module in original_limits.get_modules("user_api", user_api) + lib_controller.num_threads + for lib_controller in self._controller.select( + user_api=user_api + ).lib_controllers ] limits = set(limits) n_limits = len(limits) @@ -250,22 +265,27 @@ def _check_params(self, limits, user_api): prefixes = [] else: if isinstance(limits, list): - # This should be a list of dicts of modules, for compatibility - # with the result from threadpool_info. - limits = {module["prefix"]: module["num_threads"] for module in limits} - elif isinstance(limits, _ThreadpoolInfo): - # To set the limits from the modules of a _ThreadpoolInfo - # object. - limits = {module.prefix: module.num_threads for module in limits} + # This should be a list of dicts of library info, for + # compatibility with the result from threadpool_info. + limits = { + lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits + } + elif isinstance(limits, ThreadpoolController): + # To set the limits from the library controllers of a + # ThreadpoolController object. + limits = { + lib_controller.prefix: lib_controller.num_threads + for lib_controller in limits.lib_controllers + } if not isinstance(limits, dict): raise TypeError( "limits must either be an int, a list or a " - f"dict. Got {type(limits)} instead." + f"dict. Got {type(limits)} instead" ) - # With a dictionary, can set both specific limit for given modules - # and global limit for user_api. Fetch each separately. + # With a dictionary, can set both specific limit for given + # libraries and global limit for user_api. Fetch each separately. 
prefixes = [prefix for prefix in limits if prefix in _ALL_PREFIXES] user_api = [api for api in limits if api in _ALL_USER_APIS] @@ -274,65 +294,40 @@ def _check_params(self, limits, user_api): def _set_threadpool_limits(self): """Change the maximal number of threads in selected thread pools. - Return a list with all the supported modules that have been found + Return a list with all the supported libraries that have been found matching `self._prefixes` and `self._user_api`. """ if self._limits is None: return None - modules = _ThreadpoolInfo(prefixes=self._prefixes, user_api=self._user_api) - for module in modules: + for lib_controller in self._controller.lib_controllers: # self._limits is a dict {key: num_threads} where key is either - # a prefix or a user_api. If a module matches both, the limit - # corresponding to the prefix is chosed. - if module.prefix in self._limits: - num_threads = self._limits[module.prefix] + # a prefix or a user_api. If a library matches both, the limit + # corresponding to the prefix is chosen. + if lib_controller.prefix in self._limits: + num_threads = self._limits[lib_controller.prefix] + elif lib_controller.user_api in self._limits: + num_threads = self._limits[lib_controller.user_api] else: - num_threads = self._limits[module.user_api] + continue if num_threads is not None: - module.set_num_threads(num_threads) - return modules + lib_controller.set_num_threads(num_threads) -# The object oriented API of _ThreadpoolInfo and its modules is private. -# The public API (i.e. the "threadpool_info" function) only exposes the -# "list of dicts" representation returned by the .todicts method. @_format_docstring( PREFIXES=", ".join(f'"{prefix}"' for prefix in _ALL_PREFIXES), USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS), BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES), OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES), ) -class _ThreadpoolInfo: - """Collection of all supported modules that have been found +class ThreadpoolController: + """Collection of LibController objects for all loaded supported libraries - Parameters + Attributes ---------- - user_api : list of user APIs or None (default=None) - Select libraries matching the requested API. Ignored if `modules` is - not None. Supported user APIs are {USER_APIS}. - - - "blas" selects all BLAS supported libraries ({BLAS_LIBS}) - - "openmp" selects all OpenMP supported libraries ({OPENMP_LIBS}) - - If None, libraries are not selected by their `user_api`. - - prefixes : list of prefixes or None (default=None) - Select libraries matching the requested prefixes. Supported prefixes - are {PREFIXES}. - If None, libraries are not selected by their prefix. Ignored if - `modules` is not None. - - modules : list of _Module objects or None (default=None) - Wraps a list of _Module objects into a _ThreapoolInfo object. Does not - load or reload any shared library. If it is not None, `prefixes` and - `user_api` are ignored. - - Note - ---- - Is is possible to select libraries both by prefixes and by user_api. All - libraries matching one or the other will be selected. + lib_controllers : list of `LibController` objects + The list of library controllers of all loaded supported libraries. """ # Cache for libc under POSIX and a few system libraries under Windows. @@ -341,49 +336,113 @@ class _ThreadpoolInfo: # during the lifetime of a program. 
_system_libraries = dict() - def __init__(self, user_api=None, prefixes=None, modules=None): - if modules is None: - self.prefixes = [] if prefixes is None else prefixes - self.user_api = [] if user_api is None else user_api + def __init__(self): + self.lib_controllers = [] + self._load_libraries() + self._warn_if_incompatible_openmp() - self.modules = [] - self._load_modules() - self._warn_if_incompatible_openmp() - else: - self.modules = modules + @classmethod + def _from_controllers(cls, lib_controllers): + new_controller = cls.__new__(cls) + new_controller.lib_controllers = lib_controllers + return new_controller + + def info(self): + """Return lib_controllers info as a list of dicts""" + return [lib_controller.info() for lib_controller in self.lib_controllers] + + def select(self, **kwargs): + """Return a ThreadpoolController containing a subset of its current + library controllers + + It will select all libraries matching at least one pair (key, value) from kwargs + where key is an entry of the library info dict (like "user_api", "internal_api", + "prefix", ...) and value is the value or a list of acceptable values for that + entry. + + For instance, `ThreadpoolController().select(internal_api=["blis", "openblas"])` + will select all library controllers whose internal_api is either "blis" or + "openblas". + """ + for key, vals in kwargs.items(): + kwargs[key] = [vals] if not isinstance(vals, list) else vals + + lib_controllers = [ + lib_controller + for lib_controller in self.lib_controllers + if any( + getattr(lib_controller, key, None) in vals + for key, vals in kwargs.items() + ) + ] + + return ThreadpoolController._from_controllers(lib_controllers) - def get_modules(self, key, values): - """Return all modules such that values contains module[key]""" - if key == "user_api" and values is None: - values = list(_ALL_USER_APIS) - if not isinstance(values, list): - values = [values] - modules = [module for module in self.modules if getattr(module, key) in values] - return _ThreadpoolInfo(modules=modules) + @_format_docstring( + USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS), + BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES), + OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES), + ) + def limit(self, *, limits=None, user_api=None): + """Change the maximal number of threads that can be used in thread pools. - def todicts(self): - """Return info as a list of dicts""" - return [module.todict() for module in self.modules] + This function returns an object that can be used either as a callable (the + construction of this object limits the number of threads) or as a context + manager, in a `with` block to automatically restore the original state of the + controlled libraries when exiting the block. - def __len__(self): - return len(self.modules) + Set the maximal number of threads that can be used in thread pools used in + the supported libraries to `limits`. This function works for libraries that + are already loaded in the interpreter and can be changed dynamically. + + Parameters + ---------- + limits : int, dict or None (default=None) + The maximal number of threads that can be used in thread pools + + - If int, sets the maximum number of threads to `limits` for each + library selected by `user_api`. + + - If it is a dictionary `{{key: max_threads}}`, this function sets a + custom maximum number of threads for each `key` which can be either a + `user_api` or a `prefix` for a specific library. + + - If None, this function does not do anything. 
+ + user_api : {USER_APIS} or None (default=None) + APIs of libraries to limit. Used only if `limits` is an int. + + - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}). + + - If "openmp", it will only limit OpenMP supported libraries + ({OPENMP_LIBS}). Note that it can affect the number of threads used + by the BLAS libraries if they rely on OpenMP. + + - If None, this function will apply to all supported libraries. + """ + return _threadpool_limits(self, limits=limits, user_api=user_api) - def __iter__(self): - yield from self.modules + def restore_limits(self): + """Set the limits back to their original values - def __eq__(self, other): - return self.modules == other.modules + Since get_num_threads is only called once at initialization, the instance keeps + the original num_threads during its whole lifetime. + """ + self.limit(limits=self) + + def __len__(self): + return len(self.lib_controllers) - def _load_modules(self): - """Loop through loaded libraries and store supported ones""" + def _load_libraries(self): + """Loop through loaded shared libraries and store the supported ones""" if sys.platform == "darwin": - self._find_modules_with_dyld() + self._find_libraries_with_dyld() elif sys.platform == "win32": - self._find_modules_with_enum_process_module_ex() + self._find_libraries_with_enum_process_module_ex() else: - self._find_modules_with_dl_iterate_phdr() + self._find_libraries_with_dl_iterate_phdr() - def _find_modules_with_dl_iterate_phdr(self): + def _find_libraries_with_dl_iterate_phdr(self): """Loop through loaded libraries and return binders on supported ones This function is expected to work on POSIX system only. @@ -398,15 +457,15 @@ def _find_modules_with_dl_iterate_phdr(self): return [] # Callback function for `dl_iterate_phdr` which is called for every - # module loaded in the current process until it returns 1. - def match_module_callback(info, size, data): - # Get the path of the current module + # library loaded in the current process until it returns 1. 
+ def match_library_callback(info, size, data): + # Get the path of the current library filepath = info.contents.dlpi_name if filepath: filepath = filepath.decode("utf-8") - # Store the module if it is supported and selected - self._make_module_from_path(filepath) + # Store the library controller if it is supported and selected + self._make_controller_from_path(filepath) return 0 c_func_signature = ctypes.CFUNCTYPE( @@ -415,12 +474,12 @@ def match_module_callback(info, size, data): ctypes.c_size_t, ctypes.c_char_p, ) - c_match_module_callback = c_func_signature(match_module_callback) + c_match_library_callback = c_func_signature(match_library_callback) data = ctypes.c_char_p(b"") - libc.dl_iterate_phdr(c_match_module_callback, data) + libc.dl_iterate_phdr(c_match_library_callback, data) - def _find_modules_with_dyld(self): + def _find_libraries_with_dyld(self): """Loop through loaded libraries and return binders on supported ones This function is expected to work on OSX system only @@ -436,10 +495,10 @@ def _find_modules_with_dyld(self): filepath = ctypes.string_at(libc._dyld_get_image_name(i)) filepath = filepath.decode("utf-8") - # Store the module if it is supported and selected - self._make_module_from_path(filepath) + # Store the library controller if it is supported and selected + self._make_controller_from_path(filepath) - def _find_modules_with_enum_process_module_ex(self): + def _find_libraries_with_enum_process_module_ex(self): """Loop through loaded libraries and return binders on supported ones This function is expected to work on windows system only. @@ -451,7 +510,7 @@ def _find_modules_with_enum_process_module_ex(self): PROCESS_QUERY_INFORMATION = 0x0400 PROCESS_VM_READ = 0x0010 - LIST_MODULES_ALL = 0x03 + LIST_LIBRARIES_ALL = 0x03 ps_api = self._get_windll("Psapi") kernel_32 = self._get_windll("kernel32") @@ -460,7 +519,7 @@ def _find_modules_with_enum_process_module_ex(self): PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, False, os.getpid() ) if not h_process: # pragma: no cover - raise OSError("Could not open PID %s" % os.getpid()) + raise OSError(f"Could not open PID {os.getpid()}") try: buf_count = 256 @@ -475,7 +534,7 @@ def _find_modules_with_enum_process_module_ex(self): ctypes.byref(buf), buf_size, ctypes.byref(needed), - LIST_MODULES_ALL, + LIST_LIBRARIES_ALL, ): raise OSError("EnumProcessModulesEx failed") if buf_size >= needed.value: @@ -485,7 +544,7 @@ def _find_modules_with_enum_process_module_ex(self): count = needed.value // (buf_size // buf_count) h_modules = map(HMODULE, buf[:count]) - # Loop through all the module headers and get the module path + # Loop through all the module headers and get the library path buf = ctypes.create_unicode_buffer(MAX_PATH) n_size = DWORD() for h_module in h_modules: @@ -497,38 +556,43 @@ def _find_modules_with_enum_process_module_ex(self): raise OSError("GetModuleFileNameEx failed") filepath = buf.value - # Store the module if it is supported and selected - self._make_module_from_path(filepath) + # Store the library controller if it is supported and selected + self._make_controller_from_path(filepath) finally: kernel_32.CloseHandle(h_process) - def _make_module_from_path(self, filepath): - """Store a module if it is supported and selected""" + def _make_controller_from_path(self, filepath): + """Store a library controller if it is supported and selected""" # Required to resolve symlinks filepath = _realpath(filepath) # `lower` required to take account of OpenMP dll case on Windows # (vcomp, VCOMP, Vcomp, ...) 
filename = os.path.basename(filepath).lower() - # Loop through supported modules to find if this filename corresponds - # to a supported module. - for module_class, candidate_module in _SUPPORTED_MODULES.items(): + # Loop through supported libraries to find if this filename corresponds + # to a supported one. + for controller_class, candidate_lib in _SUPPORTED_LIBRARIES.items(): # check if filename matches a supported prefix - prefix = self._check_prefix(filename, candidate_module["filename_prefixes"]) + prefix = self._check_prefix(filename, candidate_lib["filename_prefixes"]) # filename does not match any of the prefixes of the candidate - # module. move to next module. + # library. move to next library. if prefix is None: continue - # filename matches a prefix. Check if it matches the request. If - # so, create and store the module. - user_api = candidate_module["user_api"] - internal_api = candidate_module["internal_api"] - if prefix in self.prefixes or user_api in self.user_api: - module_class = globals()[module_class] - module = module_class(filepath, prefix, user_api, internal_api) - self.modules.append(module) + # filename matches a prefix. Create and store the library + # controller. + user_api = candidate_lib["user_api"] + internal_api = candidate_lib["internal_api"] + + lib_controller_class = globals()[controller_class] + lib_controller = lib_controller_class( + filepath=filepath, + prefix=prefix, + user_api=user_api, + internal_api=internal_api, + ) + self.lib_controllers.append(lib_controller) def _check_prefix(self, library_basename, filename_prefixes): """Return the prefix library_basename starts with @@ -546,7 +610,7 @@ def _warn_if_incompatible_openmp(self): # Only raise the warning on linux return - prefixes = [module.prefix for module in self.modules] + prefixes = [lib_controller.prefix for lib_controller in self.lib_controllers] msg = textwrap.dedent( """ Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at @@ -587,42 +651,34 @@ def _get_windll(cls, dll_name): USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS), INTERNAL_APIS=", ".join('"{}"'.format(api) for api in _ALL_INTERNAL_APIS), ) -class _Module(ABC): - """Abstract base class for the modules +class LibController(ABC): + """Abstract base class for the individual library controllers - A module is represented by the following information: + A library controller is represented by the following information: - "user_api" : user API. Possible values are {USER_APIS}. - "internal_api" : internal API. Possible values are {INTERNAL_APIS}. - "prefix" : prefix of the shared library's name. - - "filepath" : path to the loaded module. + - "filepath" : path to the loaded library. - "version" : version of the library (if available). - "num_threads" : the current thread limit. - In addition, each module may contain internal_api specific entries. + In addition, each library controller may contain internal_api specific + entries. 
""" - def __init__(self, filepath=None, prefix=None, user_api=None, internal_api=None): - self.filepath = filepath - self.prefix = prefix + def __init__(self, *, filepath=None, prefix=None, user_api=None, internal_api=None): self.user_api = user_api self.internal_api = internal_api + self.prefix = prefix + self.filepath = filepath self._dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD) self.version = self.get_version() self.num_threads = self.get_num_threads() - self._get_extra_info() - - def __eq__(self, other): - return self.todict() == other.todict() - def todict(self): + def info(self): """Return relevant info wrapped in a dict""" return {k: v for k, v in vars(self).items() if not k.startswith("_")} - @abstractmethod - def get_version(self): - """Return the version of the shared library""" - pass # pragma: no cover - @abstractmethod def get_num_threads(self): """Return the maximum number of threads available to use""" @@ -634,13 +690,28 @@ def set_num_threads(self, num_threads): pass # pragma: no cover @abstractmethod - def _get_extra_info(self): - """Add additional module specific information""" + def get_version(self): + """Return the version of the shared library""" pass # pragma: no cover -class _OpenBLASModule(_Module): - """Module class for OpenBLAS""" +class OpenBLASController(LibController): + """Controller class for OpenBLAS""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.threading_layer = self._get_threading_layer() + self.architecture = self._get_architecture() + + def get_num_threads(self): + get_func = getattr(self._dynlib, "openblas_get_num_threads", lambda: None) + return get_func() + + def set_num_threads(self, num_threads): + set_func = getattr( + self._dynlib, "openblas_set_num_threads", lambda num_threads: None + ) + return set_func(num_threads) def get_version(self): # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS @@ -655,21 +726,7 @@ def get_version(self): return config[1].decode("utf-8") return None - def get_num_threads(self): - get_func = getattr(self._dynlib, "openblas_get_num_threads", lambda: None) - return get_func() - - def set_num_threads(self, num_threads): - set_func = getattr( - self._dynlib, "openblas_set_num_threads", lambda num_threads: None - ) - return set_func(num_threads) - - def _get_extra_info(self): - self.threading_layer = self.get_threading_layer() - self.architecture = self.get_architecture() - - def get_threading_layer(self): + def _get_threading_layer(self): """Return the threading layer of OpenBLAS""" openblas_get_parallel = getattr(self._dynlib, "openblas_get_parallel", None) if openblas_get_parallel is None: @@ -681,7 +738,8 @@ def get_threading_layer(self): return "pthreads" return "disabled" - def get_architecture(self): + def _get_architecture(self): + """Return the architecture detected by OpenBLAS""" get_corename = getattr(self._dynlib, "openblas_get_corename", None) if get_corename is None: return None @@ -690,16 +748,13 @@ def get_architecture(self): return get_corename().decode("utf-8") -class _BLISModule(_Module): - """Module class for BLIS""" +class BLISController(LibController): + """Controller class for BLIS""" - def get_version(self): - get_version_ = getattr(self._dynlib, "bli_info_get_version_str", None) - if get_version_ is None: - return None - - get_version_.restype = ctypes.c_char_p - return get_version_().decode("utf-8") + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.threading_layer = self._get_threading_layer() + self.architecture = 
self._get_architecture() def get_num_threads(self): get_func = getattr(self._dynlib, "bli_thread_get_num_threads", lambda: None) @@ -714,11 +769,15 @@ def set_num_threads(self, num_threads): ) return set_func(num_threads) - def _get_extra_info(self): - self.threading_layer = self.get_threading_layer() - self.architecture = self.get_architecture() + def get_version(self): + get_version_ = getattr(self._dynlib, "bli_info_get_version_str", None) + if get_version_ is None: + return None + + get_version_.restype = ctypes.c_char_p + return get_version_().decode("utf-8") - def get_threading_layer(self): + def _get_threading_layer(self): """Return the threading layer of BLIS""" if self._dynlib.bli_info_get_enable_openmp(): return "openmp" @@ -726,7 +785,8 @@ def get_threading_layer(self): return "pthreads" return "disabled" - def get_architecture(self): + def _get_architecture(self): + """Return the architecture detected by BLIS""" bli_arch_query_id = getattr(self._dynlib, "bli_arch_query_id", None) bli_arch_string = getattr(self._dynlib, "bli_arch_string", None) if bli_arch_query_id is None or bli_arch_string is None: @@ -739,8 +799,22 @@ def get_architecture(self): return bli_arch_string(bli_arch_query_id()).decode("utf-8") -class _MKLModule(_Module): - """Module class for MKL""" +class MKLController(LibController): + """Controller class for MKL""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.threading_layer = self._get_threading_layer() + + def get_num_threads(self): + get_func = getattr(self._dynlib, "MKL_Get_Max_Threads", lambda: None) + return get_func() + + def set_num_threads(self, num_threads): + set_func = getattr( + self._dynlib, "MKL_Set_Num_Threads", lambda num_threads: None + ) + return set_func(num_threads) def get_version(self): if not hasattr(self._dynlib, "MKL_Get_Version_String"): @@ -755,20 +829,7 @@ def get_version(self): version = group.groups()[0] return version.strip() - def get_num_threads(self): - get_func = getattr(self._dynlib, "MKL_Get_Max_Threads", lambda: None) - return get_func() - - def set_num_threads(self, num_threads): - set_func = getattr( - self._dynlib, "MKL_Set_Num_Threads", lambda num_threads: None - ) - return set_func(num_threads) - - def _get_extra_info(self): - self.threading_layer = self.get_threading_layer() - - def get_threading_layer(self): + def _get_threading_layer(self): """Return the threading layer of MKL""" # The function mkl_set_threading_layer returns the current threading # layer. Calling it with an invalid threading layer allows us to safely @@ -787,12 +848,8 @@ def get_threading_layer(self): return layer_map[set_threading_layer(-1)] -class _OpenMPModule(_Module): - """Module class for OpenMP""" - - def get_version(self): - # There is no way to get the version number programmatically in OpenMP. - return None +class OpenMPController(LibController): + """Controller class for OpenMP""" def get_num_threads(self): get_func = getattr(self._dynlib, "omp_get_max_threads", lambda: None) @@ -804,8 +861,9 @@ def set_num_threads(self, num_threads): ) return set_func(num_threads) - def _get_extra_info(self): - pass + def get_version(self): + # There is no way to get the version number programmatically in OpenMP. + return None def _main():