Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] ENH Add BLIS support #23

Merged
merged 13 commits into from
Sep 9, 2019
25 changes: 25 additions & 0 deletions .azure_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,31 @@ jobs:
VERSION_PYTHON: '*'
CC_OUTER_LOOP: 'gcc'
CC_INNER_LOOP: 'clang-8'
# Linux environment with numpy linked to BLIS
pylatest_blis_gcc_clang:
PACKAGER: 'conda'
VERSION_PYTHON: '*'
INSTALL_BLIS: 'true'
BLIS_NUM_THREADS: '4'
CC_OUTER_LOOP: 'gcc'
CC_INNER_LOOP: 'gcc'
BLIS_CC: 'clang-8'
pylatest_blis_clang_gcc:
PACKAGER: 'conda'
VERSION_PYTHON: '*'
INSTALL_BLIS: 'true'
BLIS_NUM_THREADS: '4'
CC_OUTER_LOOP: 'clang-8'
CC_INNER_LOOP: 'clang-8'
BLIS_CC: 'gcc'
pylatest_blis_sinlge_threaded:
PACKAGER: 'conda'
VERSION_PYTHON: '*'
INSTALL_BLIS: 'true'
BLIS_NUM_THREADS: '1'
CC_OUTER_LOOP: 'gcc'
CC_INNER_LOOP: 'gcc'
BLIS_CC: 'gcc'

- template: continuous_integration/posix.yml
parameters:
Expand Down
53 changes: 53 additions & 0 deletions continuous_integration/install_with_blis.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#!/bin/bash

set -e

pushd ..
ABS_PATH=$(pwd)
popd

# Assume Ubuntu: install a recent version of clang and libomp
echo "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-8 main" | sudo tee -a /etc/apt/sources.list.d/llvm.list
echo "deb-src http://apt.llvm.org/xenial/ llvm-toolchain-xenial-8 main" | sudo tee -a /etc/apt/sources.list.d/llvm.list
wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
sudo apt update
sudo apt install clang-8 libomp-8-dev

# create conda env
conda create -n $VIRTUALENV -q --yes python=$VERSION_PYTHON pip cython
source activate $VIRTUALENV

pushd ..

# build & install blis
mkdir BLIS_install
git clone https://github.com/flame/blis.git
pushd blis
./configure --prefix=$ABS_PATH/BLIS_install --enable-cblas --enable-threading=openmp CC=$BLIS_CC auto
make -j4
make install
popd

# build & install numpy
git clone https://github.com/numpy/numpy.git
pushd numpy
echo "[blis]
libraries = blis
library_dirs = $ABS_PATH/BLIS_install/lib
include_dirs = $ABS_PATH/BLIS_install/include/blis
runtime_library_dirs = $ABS_PATH/BLIS_install/lib" > site.cfg
python setup.py build_ext -i
pip install -e .
popd

popd

python -m pip install -q -r dev-requirements.txt
CFLAGS=-I$ABS_PATH/BLIS_install/include/blis LDFLAGS=-L$ABS_PATH/BLIS_install/lib \
bash ./continuous_integration/build_test_ext.sh

python --version
python -c "import numpy; print('numpy %s' % numpy.__version__)" || echo "no numpy"
python -c "import scipy; print('scipy %s' % scipy.__version__)" || echo "no scipy"

python -m flit install --symlink
7 changes: 6 additions & 1 deletion continuous_integration/posix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ jobs:
condition: eq('${{ parameters.name }}', 'macOS')
- script: |
continuous_integration/install.sh
displayName: 'Install'
displayName: 'Install without BLIS'
condition: ne(variables['INSTALL_BLIS'], 'true')
- script: |
continuous_integration/install_with_blis.sh
displayName: 'Install with BLIS'
condition: eq(variables['INSTALL_BLIS'], 'true')
- script: |
continuous_integration/test_script.sh
displayName: 'Test Library'
Expand Down
47 changes: 35 additions & 12 deletions tests/_openmp_test_helper/nested_prange_blas.pyx
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
cimport openmp
from cython.parallel import prange
from cython.parallel import parallel, prange

import numpy as np
from scipy.linalg.cython_blas cimport dgemm

IF USE_BLIS:
cdef extern from 'cblas.h' nogil:
ctypedef enum CBLAS_ORDER:
CblasRowMajor=101
CblasColMajor=102
ctypedef enum CBLAS_TRANSPOSE:
CblasNoTrans=111
CblasTrans=112
CblasConjTrans=113
void dgemm 'cblas_dgemm' (
CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB, int M, int N,
int K, double alpha, double *A, int lda,
double *B, int ldb, double beta, double *C, int ldc)
ELSE:
from scipy.linalg.cython_blas cimport dgemm

from threadpoolctl import threadpool_info

Expand All @@ -25,18 +41,25 @@ def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads):

int i
int prange_num_threads
int *prange_num_threads_ptr = &prange_num_threads

threadpool_infos = None
threadpool_infos = [None]

for i in prange(n_chunks, num_threads=nthreads, nogil=True):
dgemm(trans, no_trans, &n, &chunk_size, &k,
&alpha, &B[0, 0], &k, &A[i * chunk_size, 0], &k,
&beta, &C[i * chunk_size, 0], &n)
with nogil, parallel(num_threads=nthreads):
if openmp.omp_get_thread_num() == 0:
with gil:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figured out what was the cause of the hanging. Note exactly the cause but at least which line of the code was faulty. It happens if we try to call threadpool_info in another thread than the main thread (I was doing that in a previous commit).

I don't know why but it's probably related to the bad state of OpenMP threadpool in non main thread when there are multiple openmp loaded.

Anyway, it's unrelated to this PR and should be investigated separately.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I see that you opened #27 to track this issue.

threadpool_infos[0] = threadpool_info()

prange_num_threads = openmp.omp_get_num_threads()
prange_num_threads_ptr[0] = openmp.omp_get_num_threads()

if i == 0:
with gil:
threadpool_infos = threadpool_info()
for i in prange(n_chunks):
IF USE_BLIS:
dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
chunk_size, n, k, alpha, &A[i * chunk_size, 0], k,
&B[0, 0], k, beta, &C[i * chunk_size, 0], n)
ELSE:
dgemm(trans, no_trans, &n, &chunk_size, &k,
&alpha, &B[0, 0], &k, &A[i * chunk_size, 0], &k,
&beta, &C[i * chunk_size, 0], &n)

return np.asarray(C), prange_num_threads, threadpool_infos
return np.asarray(C), prange_num_threads, threadpool_infos[0]
7 changes: 6 additions & 1 deletion tests/_openmp_test_helper/setup_nested_prange_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,24 @@
set_cc_variables("CC_OUTER_LOOP")
openmp_flag = get_openmp_flag()

use_blis = os.getenv('INSTALL_BLIS', False)
libraries = ['blis'] if use_blis else []

ext_modules = [
Extension(
"nested_prange_blas",
["nested_prange_blas.pyx"],
extra_compile_args=openmp_flag,
extra_link_args=openmp_flag
extra_link_args=openmp_flag,
libraries=libraries
)
]

setup(
name='_openmp_test_helper_nested_prange_blas',
ext_modules=cythonize(
ext_modules,
compile_time_env={'USE_BLIS': use_blis},
compiler_directives={'language_level': 3,
'boundscheck': False,
'wraparound': False})
Expand Down
16 changes: 11 additions & 5 deletions tests/test_threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,17 @@ def test_nested_prange_blas(nthreads_outer):

blas_info = [module for module in threadpool_info()
if module["user_api"] == "blas"]
for module in threadpool_info():
if is_old_openblas(module):
# OpenBLAS 0.3.3 and older are known to cause an unrecoverable
# deadlock at process shutdown time (after pytest has exited).
pytest.skip("Old OpenBLAS: skipping test to avoid deadlock")

blis_linked = any([module['internal_api'] == 'blis'
for module in threadpool_info()])
if not blis_linked:
# numpy can be linked to BLIS for CBLAS and OpenBLAS for LAPACK. In that
# case this test will run BLIS gemm so no need to skip.
for module in threadpool_info():
if is_old_openblas(module):
# OpenBLAS 0.3.3 and older are known to cause an unrecoverable
# deadlock at process shutdown time (after pytest has exited).
pytest.skip("Old OpenBLAS: skipping test to avoid deadlock")

from ._openmp_test_helper import check_nested_prange_blas
A = np.ones((1000, 10))
Expand Down
24 changes: 22 additions & 2 deletions threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ class _dl_phdr_info(ctypes.Structure):
"internal_api": "mkl",
"filename_prefixes": ("libmkl_rt", "mkl_rt",),
},
{
"user_api": "blas",
"internal_api": "blis",
"filename_prefixes": ("libblis",),
},
]

# map a internal_api (openmp, openblas, mkl) to set and get functions
Expand All @@ -88,6 +93,9 @@ class _dl_phdr_info(ctypes.Structure):
"mkl": {
"set_num_threads": "MKL_Set_Num_Threads",
"get_num_threads": "MKL_Get_Max_Threads"},
"blis": {
"set_num_threads": "bli_thread_set_num_threads",
"get_num_threads": "bli_thread_get_num_threads"}
}

# Helpers for the doc and test names
Expand All @@ -110,9 +118,8 @@ def decorator(o):
def _get_limit(prefix, user_api, limits):
if prefix in limits:
return limits[prefix]
if user_api in limits:
else:
return limits[user_api]
return None


@_format_docstring(ALL_PREFIXES=_ALL_PREFIXES,
Expand Down Expand Up @@ -210,6 +217,10 @@ def threadpool_info():
modules = _load_modules(user_api=_ALL_USER_APIS)
for module in modules:
module['num_threads'] = module['get_num_threads']()
# by default BLIS is single-threaded and get_num_threads returns -1.
# we map it to 1 for consistency with other libraries.
if module['num_threads'] == -1 and module['internal_api'] == 'blis':
module['num_threads'] = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is allegedly not covered by any test according to codecov. Do you think it would be possible to make sure it's covered? I believe the only way is to introduce a new build configuration without export BLIS_NUM_THREADS=4.

# Remove the wrapper for the module and its function
del module['set_num_threads'], module['get_num_threads']
del module['dynlib']
Expand All @@ -227,6 +238,8 @@ def _get_version(dynlib, internal_api):
return None
elif internal_api == "openblas":
return _get_openblas_version(dynlib)
elif internal_api == "blis":
return _get_blis_version(dynlib)
else:
raise NotImplementedError("Unsupported API {}".format(internal_api))

Expand Down Expand Up @@ -257,6 +270,13 @@ def _get_openblas_version(openblas_dynlib):
return None


def _get_blis_version(blis_dynlib):
"""Return the BLIS version"""
get_version = getattr(blis_dynlib, "bli_info_get_version_str")
get_version.restype = ctypes.c_char_p
return get_version().decode('utf-8')


# Loading utilities for dynamically linked shared objects

def _load_modules(prefixes=None, user_api=None):
Expand Down