# Copied from the GitHub Actions page for run "[nnx] use jax-style transforms API in nnx_basics #13891" ("Workflow file for this run").
# GitHub notice: this file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below.
# To review, open the file in an editor that reveals hidden Unicode characters.
# (See GitHub's "Learn more about bidirectional Unicode characters" help page.)
# This workflow installs Python dependencies, then runs tests and linters
# across a matrix of Python versions.
# For more information see:
# https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
name: Build
# Triggers: pushes to `main` or any `test_*` branch, and pull requests
# targeting `main`.
on:
  push:
    branches:
      - main
      - 'test_*'
  pull_request:
    branches:
      - main
jobs:
  # Cancels earlier still-running runs of this workflow on the same ref so
  # only the newest push is tested. The step is skipped on `main`.
  cancel-previous:
    name: Cancel Previous Runs
    runs-on: ubuntu-latest
    steps:
      - name: Cancel previous
        # NOTE(review): the version tag was obfuscated in the scraped copy of
        # this file; 0.12.1 is a reconstruction — confirm against the repo.
        uses: styfle/cancel-workflow-action@0.12.1
        # Branch refs have the form `refs/heads/<branch>`. The previous
        # `refs/head/main` never matched, so runs on main were cancelled too.
        if: ${{ github.ref != 'refs/heads/main' }}
        with:
          access_token: ${{ github.token }}

  # Runs the repository's pre-commit hooks on Python 3.10.
  pre-commit:
    name: Test pre-commit hooks
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      # NOTE(review): the version tag was obfuscated in the scraped copy of
      # this file; v3.0.1 is a reconstruction — confirm against the repo.
      - uses: pre-commit/action@v3.0.1

  # Fails the build when the branch carries too many commits relative to main.
  commit-count:
    name: Check commit count
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      # We allow at most 5 commits in a branch to ensure our CI doesn't break.
      - name: Check commit count in PR
        if: always()
        shell: bash
        run: |
          set -x
          # $GITHUB_REF is `refs/heads/<branch_name>` for pushes and the
          # synthetic merge ref `refs/pull/<n>/merge` for pull requests. We
          # fetch it under the local name `commit-count` to refer to it below.
          # Do an unshallow fetch so we retrieve all commits (this is necessary
          # because actions/checkout@v3 fetches a shallow copy).
          git fetch origin --unshallow "$GITHUB_REF:commit-count"
          git fetch origin main
          diff=$(git rev-list --count origin/main...commit-count)
          # For pull requests, the merge ref adds one extra commit to the
          # commit tree, so $diff is one higher than the PR's own commit count.
          # NOTE(review): pushes have no merge commit, so the effective limit
          # there is 6 commits.
          if (( diff > 6 )); then
            echo "ERROR! More than 5 commits in PR -- please squash your commits."
            url=https://flax.readthedocs.io/en/latest/contributing.html#too-many-commits-in-a-pull-request
            echo "See $url for help on how to resolve this."
            exit 1
          fi

  # Checks that `import flax` works after a locked `uv sync --extra all`.
  test-import:
    name: Test import standalone
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.10', '3.11']
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - uses: yezz123/setup-uv@v4
        with:
          uv-version: "0.3.0"
      - name: Install standalone dependencies only
        run: |
          uv sync --locked --extra all
      - name: Test importing Flax
        run: |
          uv run python -c "import flax"

  # Main test matrix. Runs only after all of the jobs above have succeeded.
  tests:
    name: Run Tests
    needs: [cancel-previous, pre-commit, commit-count, test-import]
    runs-on: ubuntu-20.04-16core
    strategy:
      matrix:
        python-version: ['3.10', '3.11']
        test-type: [doctest, pytest, pytype, mypy]
        jax-version: [newest]
        # pytype runs only on 3.11; mypy runs only on 3.10.
        exclude:
          - test-type: pytype
            python-version: '3.10'
          - test-type: mypy
            python-version: '3.11'
        # Extra combination: pytest against the minimum supported JAX.
        include:
          - python-version: '3.10'
            test-type: pytest
            # Keep in sync with the jax pin in pyproject.toml.
            jax-version: '0.4.27'
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        id: setup_python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - uses: yezz123/setup-uv@v4
        with:
          uv-version: "0.3.0"
      # The virtualenv is cached per Python version and uv.lock contents.
      - name: Cached virtual environment
        id: venv_cache
        uses: actions/cache@v3
        with:
          path: .venv
          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('uv.lock') }}
      - name: Install Dependencies for cache
        if: steps.venv_cache.outputs.cache-hit != 'true'
        run: |
          if [ -d ".venv" ]; then rm -rf .venv; fi
          uv sync --locked --all-extras
      # `--locked` fails if uv.lock is out of date with pyproject.toml.
      - name: Check lockfile
        run: |
          uv sync --locked --all-extras
      - name: Install JAX
        run: |
          if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
            uv pip install -U jax jaxlib
          else
            uv pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
          fi
      # NOTE(review): `uv run` may re-sync .venv to uv.lock before executing,
      # which could undo the JAX version installed above — confirm.
      - name: Test with ${{ matrix.test-type }}
        run: |
          if [[ "${{ matrix.test-type }}" == "doctest" ]]; then
            uv run tests/run_all_tests.sh --only-doctest
          elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then
            uv run tests/run_all_tests.sh --only-pytest
          elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then
            uv run tests/run_all_tests.sh --only-pytype
          elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then
            uv run tests/run_all_tests.sh --only-mypy
          else
            echo "Unknown test type: ${{ matrix.test-type }}"
            exit 1
          fi
      - name: Upload coverage to Codecov
        if: matrix.test-type == 'pytest'
        uses: codecov/codecov-action@v4
        env:
          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
        with:
          file: ./coverage.xml
      # The below step just reports the success or failure of tests as a
      # "commit status". This is needed for copybara integration.
      - name: Report success or failure as github status
        if: always()
        shell: bash
        run: |
          status="${{ job.status }}"
          lowercase_status=$(echo "$status" | tr '[:upper:]' '[:lower:]')
          # job.status may be "cancelled", but the commit-status API only
          # accepts error/failure/pending/success for `state`.
          if [[ "$lowercase_status" == "cancelled" ]]; then
            lowercase_status=error
          fi
          curl -sS --request POST \
          --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
          --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
          --header 'content-type: application/json' \
          --data '{
             "state": "'$lowercase_status'",
             "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
             "description": "'$status'",
             "context": "github-actions/Build"
             }'