Skip to content

Commit

Permalink
Add ci jobs to test BLIS
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedbb committed Jul 3, 2019
1 parent 8cc1bad commit 2236f63
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 22 deletions.
13 changes: 13 additions & 0 deletions .azure_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ jobs:
VERSION_PYTHON: '*'
CC_OUTER_LOOP: 'gcc'
CC_INNER_LOOP: 'clang-8'
# Linux environment with numpy linked to BLIS
pylatest_blis_gcc_clang:
PACKAGER: 'conda'
VERSION_PYTHON: '*'
INSTALL_BLIS: 'true'
CC_OUTER_LOOP: 'gcc'
BLIS_CC: 'clang-8'
pylatest_blis_clang_gcc:
PACKAGER: 'conda'
VERSION_PYTHON: '*'
INSTALL_BLIS: 'true'
CC_OUTER_LOOP: 'clang-8'
BLIS_CC: 'gcc'

- template: continuous_integration/posix.yml
parameters:
Expand Down
53 changes: 53 additions & 0 deletions continuous_integration/install_with_blis.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#!/bin/bash
#
# CI install script: build BLIS from source, build numpy against it, then
# build the test extensions and install the package in the conda env.
#
# Expected environment variables (set by the CI job definition):
#   VIRTUALENV      name of the conda environment to create
#   VERSION_PYTHON  python version spec passed to conda
#   BLIS_CC         C compiler used to build BLIS
#
# Must be run from the root of the repository.

# Abort on the first failing command, including failures on the left-hand
# side of a pipeline (e.g. a failed key download piped into apt-key).
set -e
set -o pipefail

# Everything built from source lives next to the repository checkout.
ABS_PATH=$(cd .. && pwd)
BLIS_PREFIX="$ABS_PATH/BLIS_install"

# Assume Ubuntu: install a recent version of clang and libomp
echo "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-8 main" | sudo tee -a /etc/apt/sources.list.d/llvm.list
echo "deb-src http://apt.llvm.org/xenial/ llvm-toolchain-xenial-8 main" | sudo tee -a /etc/apt/sources.list.d/llvm.list
wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
# apt-get has a stable command line interface for scripts; -y answers the
# confirmation prompt so that the job cannot stall or abort waiting for input.
sudo apt-get update
sudo apt-get install -y clang-8 libomp-8-dev

# create conda env
conda create -n "$VIRTUALENV" -q --yes python="$VERSION_PYTHON" pip cython
source activate "$VIRTUALENV"

pushd ..

# build & install blis
mkdir -p "$BLIS_PREFIX"
git clone https://github.com/flame/blis.git
pushd blis
./configure --prefix="$BLIS_PREFIX" --enable-cblas --enable-threading=openmp CC="$BLIS_CC" auto
make -j4
make install
popd

# build & install numpy
git clone https://github.com/numpy/numpy.git
pushd numpy
# Point the numpy build at the freshly installed BLIS.
cat > site.cfg <<EOF
[blis]
libraries = blis
library_dirs = $BLIS_PREFIX/lib
include_dirs = $BLIS_PREFIX/include/blis
runtime_library_dirs = $BLIS_PREFIX/lib
EOF
python setup.py build_ext -i
pip install -e .
popd

popd

python -m pip install -q -r dev-requirements.txt
# The test extensions include cblas.h and link against libblis, so expose the
# BLIS headers and libraries to the compiler and the linker.
CFLAGS="-I$BLIS_PREFIX/include/blis" LDFLAGS="-L$BLIS_PREFIX/lib" \
    bash ./continuous_integration/build_test_ext.sh

python --version
python -c "import numpy; print('numpy %s' % numpy.__version__)" || echo "no numpy"
python -c "import scipy; print('scipy %s' % scipy.__version__)" || echo "no scipy"

python -m flit install --symlink
5 changes: 5 additions & 0 deletions continuous_integration/posix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ jobs:
- script: |
continuous_integration/install.sh
displayName: 'Install'
condition: ne(variables['INSTALL_BLIS'], 'true')
- script: |
continuous_integration/install_with_blis.sh
displayName: 'Install'
condition: eq(variables['INSTALL_BLIS'], 'true')
- script: |
continuous_integration/test_script.sh
displayName: 'Test Library'
Expand Down
5 changes: 5 additions & 0 deletions continuous_integration/test_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ elif [[ "$PACKAGER" == "ubuntu" ]]; then
source $VIRTUALENV/bin/activate
fi

# By default BLIS is single-threaded. Enable multi-threading to run the tests.
if [[ "$INSTALL_BLIS" == "true" ]]; then
export BLIS_NUM_THREADS=4
fi

set -x
PYTHONPATH="." python continuous_integration/display_versions.py

Expand Down
43 changes: 32 additions & 11 deletions tests/_openmp_test_helper/nested_prange_blas.pyx
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
cimport openmp
from cython.parallel import prange
from cython.parallel import parallel, prange

import numpy as np
from scipy.linalg.cython_blas cimport dgemm

IF USE_BLIS:
cdef extern from 'cblas.h' nogil:
ctypedef enum CBLAS_ORDER:
CblasRowMajor=101
CblasColMajor=102
ctypedef enum CBLAS_TRANSPOSE:
CblasNoTrans=111
CblasTrans=112
CblasConjTrans=113
void dgemm 'cblas_dgemm' (
CBLAS_ORDER Order, CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB, int M, int N,
int K, double alpha, double *A, int lda,
double *B, int ldb, double beta, double *C, int ldc)
ELSE:
from scipy.linalg.cython_blas cimport dgemm

from threadpoolctl import threadpool_info

Expand All @@ -26,17 +42,22 @@ def check_nested_prange_blas(double[:, ::1] A, double[:, ::1] B, int nthreads):
int i
int prange_num_threads

threadpool_infos = None
threadpool_infos = []

for i in prange(n_chunks, num_threads=nthreads, nogil=True):
dgemm(trans, no_trans, &n, &chunk_size, &k,
&alpha, &B[0, 0], &k, &A[i * chunk_size, 0], &k,
&beta, &C[i * chunk_size, 0], &n)
with nogil, parallel(num_threads=nthreads):
with gil:
threadpool_infos.append(threadpool_info())

prange_num_threads = openmp.omp_get_num_threads()
for i in prange(n_chunks):
IF USE_BLIS:
dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
chunk_size, n, k, alpha, &A[i * chunk_size, 0], k,
&B[0, 0], k, beta, &C[i * chunk_size, 0], n)
ELSE:
dgemm(trans, no_trans, &n, &chunk_size, &k,
&alpha, &B[0, 0], &k, &A[i * chunk_size, 0], &k,
&beta, &C[i * chunk_size, 0], &n)

if i == 0:
with gil:
threadpool_infos = threadpool_info()
prange_num_threads = openmp.omp_get_num_threads()

return np.asarray(C), prange_num_threads, threadpool_infos
7 changes: 6 additions & 1 deletion tests/_openmp_test_helper/setup_nested_prange_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,24 @@
set_cc_variables("CC_OUTER_LOOP")
openmp_flag = get_openmp_flag()

# os.getenv returns a string, so any non-empty value (including 'false' or
# '0') would be truthy if used directly. Only the explicit value 'true'
# enables the BLIS build.
use_blis = os.getenv('INSTALL_BLIS', '').strip().lower() == 'true'
# Link the extension against libblis only for the BLIS variant of the build.
libraries = ['blis'] if use_blis else []

ext_modules = [
Extension(
"nested_prange_blas",
["nested_prange_blas.pyx"],
extra_compile_args=openmp_flag,
extra_link_args=openmp_flag
extra_link_args=openmp_flag,
libraries=libraries
)
]

setup(
name='_openmp_test_helper_nested_prange_blas',
ext_modules=cythonize(
ext_modules,
compile_time_env={'USE_BLIS': use_blis},
compiler_directives={'language_level': 3,
'boundscheck': False,
'wraparound': False})
Expand Down
27 changes: 17 additions & 10 deletions tests/test_threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,17 @@ def test_nested_prange_blas(nthreads_outer):

blas_info = [module for module in threadpool_info()
if module["user_api"] == "blas"]
for module in threadpool_info():
if is_old_openblas(module):
# OpenBLAS 0.3.3 and older are known to cause an unrecoverable
# deadlock at process shutdown time (after pytest has exited).
pytest.skip("Old OpenBLAS: skipping test to avoid deadlock")

blis_linked = any([module['internal_api'] == 'blis'
for module in threadpool_info()])
if not blis_linked:
# numpy can be linked to BLIS for CBLAS and OpenBLAS for LAPACK. In that
# case this test will run BLIS gemm so no need to skip.
for module in threadpool_info():
if is_old_openblas(module):
# OpenBLAS 0.3.3 and older are known to cause an unrecoverable
# deadlock at process shutdown time (after pytest has exited).
pytest.skip("Old OpenBLAS: skipping test to avoid deadlock")

from ._openmp_test_helper import check_nested_prange_blas
A = np.ones((1000, 10))
Expand All @@ -265,9 +271,10 @@ def test_nested_prange_blas(nthreads_outer):
assert np.allclose(C, np.dot(A, B.T))
assert prange_num_threads == nthreads

nested_blas_info = [module for module in threadpool_infos
if module["user_api"] == "blas"]
for thread_infos in threadpool_infos:
nested_blas_info = [module for module in thread_infos
if module["user_api"] == "blas"]

assert len(nested_blas_info) == len(blas_info)
for module in nested_blas_info:
assert module['num_threads'] == 1
assert len(nested_blas_info) == len(blas_info)
for module in nested_blas_info:
assert module['num_threads'] == 1
4 changes: 4 additions & 0 deletions threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ def threadpool_info():
modules = _load_modules(user_api=_ALL_USER_APIS)
for module in modules:
module['num_threads'] = module['get_num_threads']()
# By default BLIS is single-threaded and get_num_threads returns -1.
# We map it to 1 for consistency with other libraries.
if module['num_threads'] == -1 and module['internal_api'] == 'blis':
module['num_threads'] = 1
# Remove the wrapper for the module and its function
del module['set_num_threads'], module['get_num_threads']
del module['dynlib']
Expand Down

0 comments on commit 2236f63

Please sign in to comment.