Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Make it possible to set limits from already loaded modules #38

Closed
wants to merge 12 commits into from
2 changes: 1 addition & 1 deletion continuous_integration/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ make_conda() {
if [[ "$UNAMESTR" == "Darwin" ]]; then
if [[ "$INSTALL_LIBOMP" == "conda-forge" ]]; then
# Install an OpenMP-enabled clang/llvm from conda-forge
TO_INSTALL="$TO_INSTALL conda-forge::compilers"
TO_INSTALL="$TO_INSTALL conda-forge::compilers conda-forge::llvm-openmp"
export CFLAGS="$CFLAGS -I$CONDA/envs/$VIRTUALENV/include"
export LDFLAGS="$LDFLAGS -Wl,-rpath,$CONDA/envs/$VIRTUALENV/lib -L$CONDA/envs/$VIRTUALENV/lib"

Expand Down
2 changes: 1 addition & 1 deletion continuous_integration/test_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ fi
set -x
PYTHONPATH="." python continuous_integration/display_versions.py

pytest -vlrxXs --junitxml=$JUNITXML --cov=threadpoolctl
pytest -vlrxXs -vv --junitxml=$JUNITXML --cov=threadpoolctl
92 changes: 74 additions & 18 deletions tests/test_threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,24 @@ def effective_num_threads(nthreads, max_threads):
return nthreads


def removed_api(infos):
    """Return copies of the module info dicts without the API callables.

    Each dict of ``infos`` is copied with the ``"set_num_threads"`` and
    ``"get_num_threads"`` entries left out, so that the result can be
    compared with the output of a plain ``threadpool_info()`` call.
    The input dicts are not modified.
    """
    api_keys = ("set_num_threads", "get_num_threads")
    return [{key: val for key, val in module.items() if key not in api_keys}
            for module in infos]

def get_module_from_path(modules, path):
    """Return the first dict of ``modules`` whose 'filepath' equals ``path``.

    Returns None when no module matches.
    """
    for module in modules:
        if module['filepath'] == path:
            return module
    return None


@pytest.mark.parametrize("use_infos", [True, False])
@pytest.mark.parametrize("prefix", _ALL_PREFIXES)
def test_threadpool_limits_by_prefix(openblas_present, mkl_present, prefix):
original_infos = threadpool_info()
def test_threadpool_limits_by_prefix(openblas_present, mkl_present,
use_infos, prefix):
original_infos = threadpool_info(return_api=use_infos)
original_infos_or_none = original_infos if use_infos else None

mkl_found = any([True for info in original_infos
if info["prefix"] in ('mkl_rt', 'libmkl_rt')])
prefix_found = len([info["prefix"] for info in original_infos
Expand All @@ -38,57 +53,63 @@ def test_threadpool_limits_by_prefix(openblas_present, mkl_present, prefix):
else:
pytest.skip("{} runtime missing".format(prefix))

with threadpool_limits(limits={prefix: 1}):
with threadpool_limits(limits={prefix: 1}, infos=original_infos_or_none):
for module in threadpool_info():
if is_old_openblas(module):
continue
if module["prefix"] == prefix:
assert module["num_threads"] == 1

with threadpool_limits(limits={prefix: 3}):
with threadpool_limits(limits={prefix: 3}, infos=original_infos_or_none):
for module in threadpool_info():
if is_old_openblas(module):
continue
if module["prefix"] == prefix:
assert module["num_threads"] <= 3

assert threadpool_info() == original_infos
assert threadpool_info() == removed_api(original_infos)


@pytest.mark.parametrize("use_infos", [True, False])
@pytest.mark.parametrize("user_api", (None, "blas", "openmp"))
def test_set_threadpool_limits_by_api(user_api):
def test_set_threadpool_limits_by_api(use_infos, user_api):
# Check that the number of threads used by the multithreaded libraries can
# be modified dynamically.
if user_api is None:
user_apis = ("blas", "openmp")
else:
user_apis = (user_api,)

original_infos = threadpool_info()
original_infos = threadpool_info(return_api=use_infos)
original_infos_or_none = original_infos if use_infos else None

with threadpool_limits(limits=1, user_api=user_api):
with threadpool_limits(limits=1, user_api=user_api,
infos=original_infos_or_none):
for module in threadpool_info():
if is_old_openblas(module):
continue
if module["user_api"] in user_apis:
assert module["num_threads"] == 1

with threadpool_limits(limits=3, user_api=user_api):
with threadpool_limits(limits=3, user_api=user_api,
infos=original_infos_or_none):
for module in threadpool_info():
if is_old_openblas(module):
continue
if module["user_api"] in user_apis:
assert module["num_threads"] <= 3

assert threadpool_info() == original_infos
assert threadpool_info() == removed_api(original_infos)


def test_threadpool_limits_function_with_side_effect():
@pytest.mark.parametrize("use_infos", [True, False])
def test_threadpool_limits_function_with_side_effect(use_infos):
# Check that threadpool_limits can be used as a function with
# side effects instead of a context manager.
original_infos = threadpool_info()
original_infos = threadpool_info(return_api=use_infos)
original_infos_or_none = original_infos if use_infos else None

threadpool_limits(limits=1)
threadpool_limits(limits=1, infos=original_infos_or_none)
try:
for module in threadpool_info():
if is_old_openblas(module):
Expand All @@ -99,7 +120,7 @@ def test_threadpool_limits_function_with_side_effect():
# side-effect.
threadpool_limits(limits=original_infos)

assert threadpool_info() == original_infos
assert threadpool_info() == removed_api(original_infos)


def test_set_threadpool_limits_no_limit():
Expand All @@ -111,13 +132,15 @@ def test_set_threadpool_limits_no_limit():
assert threadpool_info() == original_infos


def test_threadpool_limits_manual_unregister():
@pytest.mark.parametrize("use_infos", [True, False])
def test_threadpool_limits_manual_unregister(use_infos):
# Check that threadpool_limits can be used as an object that holds
# the original state of the threadpools, which can be restored thanks to
# the dedicated unregister method
original_infos = threadpool_info()
original_infos = threadpool_info(return_api=use_infos)
original_infos_or_none = original_infos if use_infos else None

limits = threadpool_limits(limits=1)
limits = threadpool_limits(limits=1, infos=original_infos_or_none)
try:
for module in threadpool_info():
if is_old_openblas(module):
Expand All @@ -128,7 +151,7 @@ def test_threadpool_limits_manual_unregister():
# side-effect.
limits.unregister()

assert threadpool_info() == original_infos
assert threadpool_info() == removed_api(original_infos)


def test_threadpool_limits_bad_input():
Expand Down Expand Up @@ -301,3 +324,36 @@ def test_get_original_num_threads(limit):
with pytest.warns(None, match='Multiple value possible'):
expected = min([module['num_threads'] for module in original_infos])
assert original_num_threads['blas'] == expected


@pytest.mark.parametrize("user_api", [None, *_ALL_USER_APIS])
def test_threadpool_limits_in_python_threads(user_api):
    # Check that limits can be set from several Python threads at once when
    # the modules were loaded beforehand and passed through ``infos``.
    from multiprocessing.pool import ThreadPool

    if user_api is None:
        user_apis = ("blas", "openmp")
    else:
        user_apis = (user_api,)

    original_infos = threadpool_info(return_api=True)

    def func(i):
        with threadpool_limits(limits=1, user_api=user_api,
                               infos=original_infos):
            # Snapshot taken while the limits are active.
            new_infos = threadpool_info()

        for module in new_infos:
            if module['user_api'] in user_apis:
                assert module['num_threads'] == 1
            else:
                # Modules outside of the limited APIs must keep the number
                # of threads they had before entering the context manager.
                corresponding_module = get_module_from_path(
                    original_infos, module['filepath'])
                expected_num_threads = corresponding_module['num_threads']
                assert module['num_threads'] == expected_num_threads

    pool = ThreadPool(2)
    try:
        pool.map(func, range(2))
    finally:
        # Always release the worker threads, even when an assertion raised
        # in a worker is re-raised by ``map``.
        pool.close()
        pool.join()

    assert threadpool_info() == removed_api(original_infos)
38 changes: 28 additions & 10 deletions threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _get_limit(prefix, user_api, limits):

@_format_docstring(ALL_PREFIXES=_ALL_PREFIXES,
INTERNAL_APIS=_ALL_INTERNAL_APIS)
def _set_threadpool_limits(limits, user_api=None):
def _set_threadpool_limits(limits, user_api=None, infos=None):
"""Limit the maximal number of threads for threadpools in supported libs

Set the maximal number of threads that can be used in thread pools used in
Expand Down Expand Up @@ -201,13 +201,20 @@ def _set_threadpool_limits(limits, user_api=None):
prefixes = [module for module in limits if module in _ALL_PREFIXES]
user_api = [module for module in limits if module in _ALL_USER_APIS]

modules = _load_modules(prefixes=prefixes, user_api=user_api)
if infos is not None:
modules = [
module for module in infos
if _match_module(module, module['prefix'], prefixes, user_api)]
else:
modules = _load_modules(prefixes=prefixes, user_api=user_api)

for module in modules:
# Workaround clang bug (TODO: report it)
module['get_num_threads']()

for module in modules:
module['num_threads'] = module['get_num_threads']()
module['num_threads'] = _formatted_num_threads(module)
num_threads = _get_limit(module['prefix'], module['user_api'], limits)
if num_threads is not None:
set_func = module['set_num_threads']
Expand All @@ -217,7 +224,7 @@ def _set_threadpool_limits(limits, user_api=None):


@_format_docstring(INTERNAL_APIS=_ALL_INTERNAL_APIS)
def threadpool_info():
def threadpool_info(return_api=False):
"""Return the maximal number of threads for each detected library.

Return a list with all the supported modules that have been found. Each
Expand All @@ -227,23 +234,34 @@ def threadpool_info():
- 'internal_api': internal API. Possible values are {INTERNAL_APIS}.
- 'version': version of the library implemented (if available).
- 'num_threads': the current thread limit.

If ``return_api`` is True, each dict also contains pointers to the
internal API functions:
- 'set_num_threads'
- 'get_num_threads'
"""
infos = []
modules = _load_modules(user_api=_ALL_USER_APIS)
for module in modules:
module['num_threads'] = module['get_num_threads']()
# by default BLIS is single-threaded and get_num_threads returns -1.
# we map it to 1 for consistency with other libraries.
if module['num_threads'] == -1 and module['internal_api'] == 'blis':
module['num_threads'] = 1
module['num_threads'] = _formatted_num_threads(module)
# Remove the wrapper for the module and its function
del module['set_num_threads'], module['get_num_threads']
del module['dynlib']
del module['filename_prefixes']
if not return_api:
del module['set_num_threads'], module['get_num_threads']
infos.append(module)
return infos


def _formatted_num_threads(module):
# by default BLIS is single-threaded and get_num_threads returns -1.
# we map it to 1 for consistency with other libraries.
if module['num_threads'] == -1 and module['internal_api'] == 'blis':
return 1
return module['num_threads']


def _get_version(dynlib, internal_api):
if internal_api == "mkl":
return _get_mkl_version(dynlib)
Expand Down Expand Up @@ -533,12 +551,12 @@ class threadpool_limits:
limited. Note that the latter can affect the number of threads used by the
BLAS libraries if they rely on OpenMP.
"""
def __init__(self, limits=None, user_api=None):
def __init__(self, limits=None, user_api=None, infos=None):
self._user_api = _ALL_USER_APIS if user_api is None else [user_api]

if limits is not None:
self._original_limits = _set_threadpool_limits(
limits=limits, user_api=user_api)
limits=limits, user_api=user_api, infos=infos)
else:
self._original_limits = None

Expand Down