Workaround for openblas + openmp threading layer #114

Merged: 16 commits, Jan 28, 2022
2 changes: 1 addition & 1 deletion .azure_pipeline.yml
@@ -81,7 +81,7 @@ stages:
VERSION_PYTHON: '3.8'
CC_OUTER_LOOP: 'gcc'
CC_INNER_LOOP: 'clang-10'
# Linux environment with numpy from conda-forge channel
# Linux environment with numpy from conda-forge channel and openblas-openmp
pylatest_conda_forge:
PACKAGER: 'conda-forge'
VERSION_PYTHON: '*'
8 changes: 8 additions & 0 deletions CHANGES.md
@@ -4,6 +4,14 @@
- Fixed a detection issue of the BLAS libraries packaged by conda-forge on Windows.
https://github.com/joblib/threadpoolctl/pull/112

- New helper function `threadpoolctl.get_params_for_sequential_blas_under_openmp` and
  new method `ThreadpoolController.get_params_for_sequential_blas_under_openmp` that
  return the appropriate params to pass to `threadpool_limits` or
  `ThreadpoolController.limit` when one wants sequential BLAS calls within an OpenMP
  parallel region. These helpers take into account the unexpected behavior of
  OpenBLAS with the OpenMP threading layer.
  https://github.com/joblib/threadpoolctl/pull/114

3.0.0 (2021-10-01)
==================

3 changes: 3 additions & 0 deletions continuous_integration/install.sh
@@ -50,6 +50,9 @@ elif [[ "$PACKAGER" == "conda-forge" ]]; then
conda config --prepend channels conda-forge
conda config --set channel_priority strict
TO_INSTALL="python=$VERSION_PYTHON numpy scipy blas[build=$BLAS]"
if [[ "$BLAS" == "openblas" ]]; then
TO_INSTALL="$TO_INSTALL libopenblas=*=*openmp*"
fi
make_conda $TO_INSTALL

elif [[ "$PACKAGER" == "pip" ]]; then
26 changes: 24 additions & 2 deletions tests/test_threadpoolctl.py
@@ -208,8 +208,30 @@ def test_threadpool_controller_limit():
        for lib_controller in blas_controller.lib_controllers
    )
    # original_blas_controller contains only blas libraries so no openmp library
    # should be impacted.
    assert openmp_info == original_openmp_info
    # should be impacted. This is not true for OpenBLAS with the OpenMP threading
    # layer.
    if not any(
        lib_controller.internal_api == "openblas"
        and lib_controller.threading_layer == "openmp"
        for lib_controller in blas_controller.lib_controllers
    ):
        assert openmp_info == original_openmp_info


def test_get_params_for_sequential_blas_under_openmp():
    controller = ThreadpoolController()
    original_info = controller.info()

    params = controller.get_params_for_sequential_blas_under_openmp()

    if controller.select(
        internal_api="openblas", threading_layer="openmp"
    ).lib_controllers:
        assert params["limits"] is None
        assert "user_api" not in params
    else:
        assert params["limits"] == 1
        assert params["user_api"] == "blas"


def test_nested_limits():
21 changes: 21 additions & 0 deletions threadpoolctl.py
@@ -310,6 +310,15 @@ def __enter__(self):
return self


def get_params_for_sequential_blas_under_openmp():
    """Return appropriate params to use for a sequential BLAS call in an OpenMP loop

    This function takes into account the unexpected behavior of OpenBLAS with the
    OpenMP threading layer.
    """
    return ThreadpoolController().get_params_for_sequential_blas_under_openmp()


@_format_docstring(
USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
@@ -428,6 +437,18 @@ def select(self, **kwargs):

        return ThreadpoolController._from_controllers(lib_controllers)

    def get_params_for_sequential_blas_under_openmp(self):
        """Return appropriate params to use for a sequential BLAS call in an OpenMP loop

        This function takes into account the unexpected behavior of OpenBLAS with the
        OpenMP threading layer.
        """
        if self.select(
            internal_api="openblas", threading_layer="openmp"
        ).lib_controllers:
            return {"limits": None}
        return {"limits": 1, "user_api": "blas"}

@_format_docstring(
USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
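
For readers who want to try the selection rule without a particular BLAS build installed, here is a minimal standalone sketch of the decision made by `get_params_for_sequential_blas_under_openmp` in this diff. It is not the library code: the function name `params_for_sequential_blas` and the sample library descriptions are hypothetical, and the input is assumed to be a list of dicts with `internal_api` and `threading_layer` keys, similar in shape to what `threadpool_info()` reports. The reason for returning `limits=None` is inferred from the test change in this PR, which suggests that limiting OpenBLAS built with the OpenMP threading layer can also affect the OpenMP thread pool.

```python
def params_for_sequential_blas(info):
    """Pick limit params from a list of library-description dicts.

    `info` is assumed to look like the output of threadpool_info():
    one dict per detected library (hypothetical stand-in for this sketch).
    """
    has_openblas_openmp = any(
        lib.get("internal_api") == "openblas"
        and lib.get("threading_layer") == "openmp"
        for lib in info
    )
    if has_openblas_openmp:
        # Same rule as the diff: do not set any limit in this configuration.
        return {"limits": None}
    # Otherwise restrict only the BLAS thread pools to a single thread.
    return {"limits": 1, "user_api": "blas"}


# Hypothetical library descriptions, for illustration only.
openblas_pthreads = [
    {"user_api": "blas", "internal_api": "openblas", "threading_layer": "pthreads"},
]
openblas_openmp = [
    {"user_api": "blas", "internal_api": "openblas", "threading_layer": "openmp"},
    {"user_api": "openmp", "internal_api": "openmp"},
]

print(params_for_sequential_blas(openblas_pthreads))  # {'limits': 1, 'user_api': 'blas'}
print(params_for_sequential_blas(openblas_openmp))    # {'limits': None}
```

In both cases the returned dict is meant to be unpacked as keyword arguments into the limiting call, which is why the OpenBLAS + OpenMP branch omits `user_api` entirely.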