Skip to content

jax.lax: deprecate inadvertent exports & internal utilities #40468

jax.lax: deprecate inadvertent exports & internal utilities

jax.lax: deprecate inadvertent exports & internal utilities #40468

Workflow file for this run

name: CI
# We test all supported Python versions as follows:
# - 3.9 : Documentation build
# - 3.9 : Part of Matrix with NumPy dispatch
# - 3.10 : Part of Matrix
# - 3.11 : Part of Matrix
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main
permissions:
contents: read # to fetch code
actions: write # to cancel previous workflows
jobs:
lint_and_typecheck:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Cancel previous
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # ratchet:actions/checkout@v4
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11
- uses: pre-commit/[email protected]
build:
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
matrix:
include:
- name-prefix: "with numpy-dispatch"
python-version: "3.9"
os: ubuntu-20.04-16core
enable-x64: 1
prng-upgrade: 1
# Test experimental NumPy dispatch
package-overrides: "git+https://github.com/seberg/numpy-dispatch.git"
num_generated_cases: 1
use-latest-jaxlib: false
- name-prefix: "with 3.10"
python-version: "3.10"
os: ubuntu-20.04-16core
enable-x64: 0
prng-upgrade: 0
package-overrides: "none"
num_generated_cases: 1
use-latest-jaxlib: false
- name-prefix: "with 3.11"
python-version: "3.11"
os: ubuntu-20.04-16core
enable-x64: 0
prng-upgrade: 0
package-overrides: "none"
num_generated_cases: 1
use-latest-jaxlib: false
steps:
- name: Cancel previous
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@v3
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
- name: Install dependencies
run: |
pip install -r build/test-requirements.txt
if [ "${{ matrix.package-overrides }}" != "none" ]; then
pip install ${{ matrix.package-overrides }}
fi
if [ "${{ matrix.use-latest-jaxlib }}" == "true" ]; then
pip install .[cpu]
else
pip install .[minimum-jaxlib]
fi
- name: Run tests
env:
JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
JAX_ENABLE_CUSTOM_PRNG: ${{ matrix.prng-upgrade }}
JAX_THREEFRY_PARTITIONABLE: ${{ matrix.prng-upgrade }}
JAX_ENABLE_CHECKS: true
JAX_SKIP_SLOW_TESTS: true
PY_COLORS: 1
run: |
pip install -e .
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG"
echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
pytest -n auto --tb=short --maxfail=20 tests examples
documentation:
name: Documentation - test code snippets
runs-on: ubuntu-latest
timeout-minutes: 10
strategy:
matrix:
python-version: [3.9]
steps:
- name: Cancel previous
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@v3
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
- name: Install dependencies
run: |
pip install -r docs/requirements.txt
- name: Test documentation
env:
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
JAX_ARRAY: 1
PY_COLORS: 1
run: |
pytest -n auto --tb=short docs
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
documentation_render:
name: Documentation - render documentation
runs-on: ubuntu-latest
timeout-minutes: 10
strategy:
matrix:
python-version: [3.9]
steps:
- name: Cancel previous
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/heads/main'}}
- uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # ratchet:actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@v3
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
- name: Install dependencies
run: |
pip install -r docs/requirements.txt
- name: Render documentation
run: |
sphinx-build --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html -j auto