Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Accelerate framework blas__ldflags tests #1056

Merged
merged 2 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,17 @@ jobs:
python-version: ${{ matrix.python-version }}
- uses: pre-commit/[email protected]

test_ubuntu:
name: "Test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
test:
name: "${{ matrix.os }} test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
needs:
- changes
- style
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
python-version: ["3.10", "3.12"]
fast-compile: [0, 1]
float32: [0, 1]
Expand Down Expand Up @@ -103,30 +104,44 @@ jobs:
fast-compile: 1
include:
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.10"
fast-compile: 0
float32: 0
part: "tests/link/numba"
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.12"
fast-compile: 0
float32: 0
part: "tests/link/numba"
- install-jax: 1
os: "ubuntu-latest"
python-version: "3.10"
fast-compile: 0
float32: 0
part: "tests/link/jax"
- install-jax: 1
os: "ubuntu-latest"
python-version: "3.12"
fast-compile: 0
float32: 0
part: "tests/link/jax"
- install-torch: 1
os: "ubuntu-latest"
python-version: "3.10"
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
- os: macos-latest
python-version: "3.12"
fast-compile: 0
float32: 0
install-numba: 0
install-jax: 0
install-torch: 0
part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"

steps:
- uses: actions/checkout@v4
with:
Expand All @@ -146,15 +161,19 @@ jobs:
MATRIX_CONTEXT: ${{ toJson(matrix) }}
run: |
echo $MATRIX_CONTEXT
export MATRIX_ID=`echo $MATRIX_CONTEXT | md5sum | cut -c 1-32`
export MATRIX_ID=`echo $MATRIX_CONTEXT | sha256sum | cut -c 1-32`
echo $MATRIX_ID
echo "id=$MATRIX_ID" >> $GITHUB_OUTPUT

- name: Install dependencies
shell: micromamba-shell {0}
run: |

micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock
if [[ $OS == "macos-latest" ]]; then
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" numpy scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate;
else
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
fi
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
Expand All @@ -163,12 +182,17 @@ jobs:
pip install -e ./
micromamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
if [[ $OS == "macos-latest" ]]; then
python -c 'import pytensor; assert pytensor.config.blas__ldflags.startswith("-framework Accelerate"), "Blas flags are not set to MacOS Accelerate"';
else
python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"';
fi
env:
PYTHON_VERSION: ${{ matrix.python-version }}
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
OS: ${{ matrix.os}}

- name: Run tests
shell: micromamba-shell {0}
Expand Down Expand Up @@ -249,10 +273,10 @@ jobs:
if: ${{ always() }}
runs-on: ubuntu-latest
name: "All tests"
needs: [changes, style, test_ubuntu]
needs: [changes, style, test]
steps:
- name: Check build matrix status
if: ${{ needs.changes.outputs.changes == 'true' && (needs.style.result != 'success' || needs.test_ubuntu.result != 'success') }}
if: ${{ needs.changes.outputs.changes == 'true' && (needs.style.result != 'success' || needs.test.result != 'success') }}
run: exit 1

upload-coverage:
Expand Down
38 changes: 34 additions & 4 deletions pytensor/link/c/cmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2458,14 +2458,32 @@ def patch_ldflags(flag_list: list[str]) -> list[str]:
@staticmethod
def linking_patch(lib_dirs: list[str], libs: list[str]) -> list[str]:
if sys.platform != "win32":
return [f"-l{l}" for l in libs]
patched_libs = []
framework = False
for lib in libs:
# The clang framework flag is handled differently.
# The flag will have the format -framework framework_name
# If we find a lib that is called -framework, we keep it and the following
# entry in the lib list unchanged. Anything else, we add the standard
# -l library prefix.
if lib == "-framework":
framework = True
patched_libs.append(lib)
elif framework:
framework = False
patched_libs.append(lib)
else:
patched_libs.append(f"-l{lib}")
return patched_libs
else:
# In explicit else because of https://github.com/python/mypy/issues/10773
def sort_key(lib):
name, *numbers, extension = lib.split(".")
return (extension == "dll", tuple(map(int, numbers)))

patched_lib_ldflags = []
# Should we also support the framework flag on Windows? It is not handled here
# because clang is not intended to be used there at the moment.
for lib in libs:
ldflag = f"-l{lib}"
for lib_dir in lib_dirs:
Expand Down Expand Up @@ -2873,9 +2891,21 @@ def check_libs(
)
except Exception as e:
_logger.debug(e)
try:
# 3. Mac Accelerate framework
_logger.debug("Checking Accelerate framework")
flags = ["-framework", "Accelerate"]
if rpath:
flags = [*flags, f"-Wl,-rpath,{rpath}"]
validated_flags = try_blas_flag(flags)
if validated_flags == "":
raise Exception("Accelerate framework flag failed ")
return validated_flags
except Exception as e:
_logger.debug(e)
try:
_logger.debug("Checking Lapack + blas")
# 3. Try to use LAPACK + BLAS
# 4. Try to use LAPACK + BLAS
return check_libs(
all_libs,
required_libs=["lapack", "blas", "cblas", "m"],
Expand All @@ -2885,7 +2915,7 @@ def check_libs(
except Exception as e:
_logger.debug(e)
try:
# 4. Try to use BLAS alone
# 5. Try to use BLAS alone
_logger.debug("Checking blas alone")
return check_libs(
all_libs,
Expand All @@ -2896,7 +2926,7 @@ def check_libs(
except Exception as e:
_logger.debug(e)
try:
# 5. Try to use openblas
# 6. Try to use openblas
_logger.debug("Checking openblas")
return check_libs(
all_libs,
Expand Down
34 changes: 30 additions & 4 deletions pytensor/tensor/blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@
import functools
import logging
import os
import shlex
import time
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -396,7 +398,7 @@ def _ldflags(
rval = []
if libs_dir:
found_dyn = False
dirs = [x[2:] for x in ldflags_str.split() if x.startswith("-L")]
dirs = [x[2:] for x in shlex.split(ldflags_str) if x.startswith("-L")]
l = _ldflags(
ldflags_str=ldflags_str,
libs=True,
Expand All @@ -409,14 +411,22 @@ def _ldflags(
if f.endswith(".so") or f.endswith(".dylib") or f.endswith(".dll"):
if any(f.find(ll) >= 0 for ll in l):
found_dyn = True
# Special treatment of clang framework. Specifically for MacOS Accelerate
if "-framework" in l and "Accelerate" in l:
found_dyn = True
if not found_dyn and dirs:
_logger.warning(
"We did not find a dynamic library in the "
"library_dir of the library we use for blas. If you use "
"ATLAS, make sure to compile it with dynamics library."
)

for t in ldflags_str.split():
split_flags = shlex.split(ldflags_str)
skip = False
for pos, t in enumerate(split_flags):
if skip:
skip = False
continue
# Remove extra quote.
if (t.startswith("'") and t.endswith("'")) or (
t.startswith('"') and t.endswith('"')
Expand All @@ -425,10 +435,26 @@ def _ldflags(

try:
t0, t1 = t[0], t[1]
assert t0 == "-"
assert t0 == "-" or Path(t).exists()
except Exception:
raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"')
if libs_dir and t1 == "L":
if t == "-framework":
skip = True
# Special treatment of clang framework. Specifically for MacOS Accelerate
# The clang framework implicitly adds: header dirs, libraries, and library dirs.
# If we choose to always return these flags, we run into a large number of
# incompatibilities. For this reason, we only return the framework if libs are
# requested.
if (
libs
and len(split_flags) >= pos
and split_flags[pos + 1] == "Accelerate"
):
# We only add the Accelerate framework, but in the future we could extend it to
# other frameworks
rval.append(t)
rval.append(split_flags[pos + 1])
elif libs_dir and t1 == "L":
rval.append(t[2:])
elif include_dir and t1 == "I":
raise ValueError(
Expand Down
Loading
Loading