Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workaround for openblas + openmp threading layer #114

Merged
merged 16 commits into from
Jan 28, 2022
Merged
3 changes: 2 additions & 1 deletion .azure_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,12 @@ stages:
VERSION_PYTHON: '3.8'
CC_OUTER_LOOP: 'gcc'
CC_INNER_LOOP: 'clang-10'
# Linux environment with numpy from conda-forge channel
# Linux environment with numpy from conda-forge channel and openblas-openmp
ogrisel marked this conversation as resolved.
Show resolved Hide resolved
pylatest_conda_forge:
PACKAGER: 'conda-forge'
VERSION_PYTHON: '*'
BLAS: 'openblas'
OPENBLAS_THREADING_LAYER: 'openmp'
CC_OUTER_LOOP: 'gcc'
CC_INNER_LOOP: 'gcc'
LINT: 'true'
Expand Down
7 changes: 7 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
- Fixed a detection issue of the BLAS libraires packaged by conda-forge on Windows.
https://github.com/joblib/threadpoolctl/pull/112

- `threadpool_info` and `ThreadpoolController.limit` accept a new value for the `limits`
parameter: the string "sequential_blas_under_openmp". It should only be used for the
ogrisel marked this conversation as resolved.
Show resolved Hide resolved
specific case when one wants to have sequential BLAS calls within an OpenMP parallel
region. It takes into account the unexpected behavior of OpenBLAS with the OpenMP
threading layer.
https://github.com/joblib/threadpoolctl/pull/114

3.0.0 (2021-10-01)
==================

Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@ decorators are accessible through their `wrap` method:
...
```

### Sequential BLAS within OpenMP parallel region

When one wants to have sequential BLAS calls within an OpenMP parallel region, it's
safer to set `limits="sequential_blas_under_openmp"` since setting `limits=1` and `user_api="blas"` might not lead to the expected behavior in some configurations
(e.g. OpenBLAS with the OpenMP threading layer
https://github.com/xianyi/OpenBLAS/issues/2985).

### Known Limitations

- `threadpool_limits` can fail to limit the number of inner threads when nesting
Expand Down
3 changes: 3 additions & 0 deletions continuous_integration/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ elif [[ "$PACKAGER" == "conda-forge" ]]; then
conda config --prepend channels conda-forge
conda config --set channel_priority strict
TO_INSTALL="python=$VERSION_PYTHON numpy scipy blas[build=$BLAS]"
if [[ "$BLAS" == "openblas" && "$OPENBLAS_THREADING_LAYER" == "openmp" ]]; then
TO_INSTALL="$TO_INSTALL libopenblas=*=*openmp*"
fi
ogrisel marked this conversation as resolved.
Show resolved Hide resolved
make_conda $TO_INSTALL

elif [[ "$PACKAGER" == "pip" ]]; then
Expand Down
40 changes: 37 additions & 3 deletions tests/test_threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,42 @@ def test_threadpool_controller_limit():
for lib_controller in blas_controller.lib_controllers
)
# original_blas_controller contains only blas libraries so no opemp library
# should be impacted.
assert openmp_info == original_openmp_info
# should be impacted. This is not True for OpenBLAS with the OpenMP threading
# layer.
if not any(
lib_controller.internal_api == "openblas"
and lib_controller.threading_layer == "openmp"
for lib_controller in blas_controller.lib_controllers
):
assert openmp_info == original_openmp_info


def test_get_params_for_sequential_blas_under_openmp():
# Test for the behavior of get_params_for_sequential_blas_under_openmp.
controller = ThreadpoolController()
original_info = controller.info()

params = controller._get_params_for_sequential_blas_under_openmp()

if controller.select(
internal_api="openblas", threading_layer="openmp"
).lib_controllers:
assert params["limits"] is None
assert params["user_api"] is None

with controller.limit(limits="sequential_blas_under_openmp"):
assert controller.info() == original_info

else:
assert params["limits"] == 1
assert params["user_api"] == "blas"

with controller.limit(limits="sequential_blas_under_openmp"):
assert all(
lib_info["num_threads"] == 1
for lib_info in controller.info()
if lib_info["user_api"] == "blas"
)


def test_nested_limits():
Expand Down Expand Up @@ -245,7 +279,7 @@ def test_threadpool_limits_bad_input():
threadpool_limits(limits=1, user_api="wrong")

with pytest.raises(
TypeError, match="limits must either be an int, a list or a dict"
TypeError, match="limits must either be an int, a list, a dict, or"
):
threadpool_limits(limits=(1, 2, 3))

Expand Down
39 changes: 34 additions & 5 deletions threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ class _ThreadpoolLimiter:
"""

def __init__(self, controller, *, limits=None, user_api=None):
self._controller = controller
self._limits, self._user_api, self._prefixes = self._check_params(
limits, user_api
)
self._controller = controller
self._original_info = self._controller.info()
self._set_threadpool_limits()

Expand Down Expand Up @@ -226,6 +226,13 @@ def get_original_num_threads(self):

def _check_params(self, limits, user_api):
"""Suitable values for the _limits, _user_api and _prefixes attributes"""

if isinstance(limits, str) and limits == "sequential_blas_under_openmp":
(
limits,
user_api,
) = self._controller._get_params_for_sequential_blas_under_openmp().values()

if limits is None or isinstance(limits, int):
if user_api is None:
user_api = _ALL_USER_APIS
Expand Down Expand Up @@ -257,8 +264,8 @@ def _check_params(self, limits, user_api):

if not isinstance(limits, dict):
raise TypeError(
"limits must either be an int, a list or a "
f"dict. Got {type(limits)} instead"
"limits must either be an int, a list, a dict, or "
f"'sequential_blas_under_openmp'. Got {type(limits)} instead"
)

# With a dictionary, can set both specific limit for given
Expand Down Expand Up @@ -333,7 +340,7 @@ class threadpool_limits(_ThreadpoolLimiter):

Parameters
----------
limits : int, dict or None (default=None)
limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
The maximal number of threads that can be used in thread pools

- If int, sets the maximum number of threads to `limits` for each
Expand All @@ -343,6 +350,11 @@ class threadpool_limits(_ThreadpoolLimiter):
custom maximum number of threads for each `key` which can be either a
`user_api` or a `prefix` for a specific library.

- If 'sequential_blas_under_openmp', it will chose the appropriate `limits`
and `user_api` parameters for the specific use case of sequential BLAS
calls within an OpenMP parallel region. The `user_api` parameter is
ignored.

- If None, this function does not do anything.

user_api : {USER_APIS} or None (default=None)
Expand Down Expand Up @@ -428,6 +440,18 @@ def select(self, **kwargs):

return ThreadpoolController._from_controllers(lib_controllers)

def _get_params_for_sequential_blas_under_openmp(self):
"""Return appropriate params to use for a sequential BLAS call in an OpenMP loop

This function takes into account the unexpected behavior of OpenBLAS with the
OpenMP threading layer.
"""
if self.select(
internal_api="openblas", threading_layer="openmp"
).lib_controllers:
return {"limits": None, "user_api": None}
return {"limits": 1, "user_api": "blas"}

@_format_docstring(
USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
Expand All @@ -451,7 +475,7 @@ def limit(self, *, limits=None, user_api=None):

Parameters
----------
limits : int, dict or None (default=None)
limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
The maximal number of threads that can be used in thread pools

- If int, sets the maximum number of threads to `limits` for each
Expand All @@ -461,6 +485,11 @@ def limit(self, *, limits=None, user_api=None):
custom maximum number of threads for each `key` which can be either a
`user_api` or a `prefix` for a specific library.

- If 'sequential_blas_under_openmp', it will chose the appropriate `limits`
and `user_api` parameters for the specific use case of sequential BLAS
calls within an OpenMP parallel region. The `user_api` parameter is
ignored.

- If None, this function does not do anything.

user_api : {USER_APIS} or None (default=None)
Expand Down