Skip to content

Commit

Permalink
use uv package manager
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Aug 22, 2024
1 parent a3ddb0f commit 33b85b9
Show file tree
Hide file tree
Showing 8 changed files with 3,829 additions and 91 deletions.
52 changes: 24 additions & 28 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,15 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: yezz123/setup-uv@v4
with:
uv-version: "0.3.0"
- name: Install standalone dependencies only
run: |
pip install -e .[all]
uv sync --locked --extra all
- name: Test importing Flax
run: |
python -c "import flax"
uv run python -c "import flax"
tests:
name: Run Tests
needs: [cancel-previous, pre-commit, commit-count, test-import]
Expand All @@ -101,51 +104,44 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- uses: yezz123/setup-uv@v4
with:
uv-version: "0.3.0"
- name: Get week and year
id: date_key
run: echo "DATE=$(date +%j)" >> $GITHUB_OUTPUT
# TODO(cgarciae): caching is breaking the install step, disabling for now.
# - name: Cached virtual environment
# id: venv_cache
# uses: actions/cache@v3
# with:
# path: venv
# key: pip-${{ steps.setup_python.outputs.python-version }}-${{ steps.date_key.outputs.DATE }}-${{ hashFiles('**/requirements.txt', 'pyproject.toml') }}
- name: Cached virtual environment
id: venv_cache
uses: actions/cache@v3
with:
path: venv
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ steps.date_key.outputs.DATE }}-${{ hashFiles('uv.lock') }}
- name: Install Dependencies for cache
if: steps.venv_cache.outputs.cache-hit != 'true'
run: |
if [ -d "venv" ]; then rm -rf venv; fi
python3 -m venv venv
venv/bin/python3 -m pip install .[all,testing]
venv/bin/python3 -m pip install tensorflow_datasets[dev]
venv/bin/python3 -m pip install -r docs/requirements.txt
- name: Install Flax
if [ -d ".venv" ]; then rm -rf .venv; fi
uv sync --locked --all-extras
- name: Check lockfile
run: |
venv/bin/python3 -m pip install -e .[all,testing]
uv sync --locked --all-extras
- name: Install JAX
run: |
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
venv/bin/python3 -m pip install -U jax jaxlib
uv pip install -U jax jaxlib
else
venv/bin/python3 -m pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
uv pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
fi
- name: Cached mypy cache
id: mypy_cache
uses: actions/cache@v3
if: matrix.test-type == 'mypy'
with:
path: .mypy_cache
key: mypy-${{ steps.setup_python.outputs.python-version }}-${{ steps.date_key.outputs.DATE }}
- name: Test with ${{ matrix.test-type }}
run: |
if [[ "${{ matrix.test-type }}" == "doctest" ]]; then
tests/run_all_tests.sh --no-pytest --no-pytype --no-mypy --use-venv
uv run tests/run_all_tests.sh --only-doctest
elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then
tests/run_all_tests.sh --no-doctest --no-pytype --no-mypy --with-cov --use-venv
uv run tests/run_all_tests.sh --only-pytest
elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then
tests/run_all_tests.sh --no-doctest --no-pytest --no-mypy --use-venv
uv run tests/run_all_tests.sh --only-pytype
elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then
tests/run_all_tests.sh --no-doctest --no-pytest --no-pytype --use-venv
uv run tests/run_all_tests.sh --only-mypy
else
echo "Unknown test type: ${{ matrix.test-type }}"
exit 1
Expand Down
7 changes: 6 additions & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,9 @@ formats:
# Optionally set the version of Python and requirements required to build your docs
python:
install:
- requirements: docs/requirements.txt
- method: pip
path: .
extra_requirements:
- all
- testing
- docs
25 changes: 22 additions & 3 deletions docs/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,14 @@ To contribute code to Flax on GitHub, follow these steps:
```bash
git clone https://github.com/YOUR_USERNAME/flax
cd flax
pip install -e ".[all]"
pip install -e ".[testing]"
pip install -r docs/requirements.txt
pip install -e ".[all,testing,docs]"
```

You can also use [uv](https://docs.astral.sh/uv/) to setup
the development environment:

```bash
uv sync --all-extras
```

5. Set up pre-commit hooks, this will run some automated checks during each `git` commit and
Expand Down Expand Up @@ -130,6 +135,20 @@ To contribute code to Flax on GitHub, follow these steps:
You can learn more in GitHub's [Creating a pull request from a fork
](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork). documentation.
### Adding or updating dependencies
To add or update dependencies, you must use `uv` after
updating the `pyproject.toml` file to ensure that the `uv.lock` file is up-to-date.
```bash
uv sync --all-extras
```
Alternatively use can use `uv add` to add or update the dependencies automatically, for example:
```bash
uv add 'some-package>=1.2.3'
```
### Updating Jupyter Notebooks
We use [jupytext](https://jupytext.readthedocs.io/) to maintain two synced copies of docs
Expand Down
36 changes: 0 additions & 36 deletions docs/requirements.txt

This file was deleted.

6 changes: 4 additions & 2 deletions flax/nnx/examples/lm1b/input_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import pathlib
import tempfile

import input_pipeline
import tensorflow_datasets as tfds
from absl.testing import absltest
import tensorflow_datasets as tfds

from configs import default
import input_pipeline

# We just use different values here to verify that the input pipeline uses the
# the correct value for the 3 different datasets.
Expand All @@ -29,6 +30,7 @@


class InputPipelineTest(absltest.TestCase):

def setUp(self):
super().setUp()
self.train_ds, self.eval_ds, self.predict_ds = self._get_datasets()
Expand Down
47 changes: 39 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ authors = [
{name = "Flax team", email = "[email protected]"},
]
dependencies = [
"numpy>=1.22",
"numpy>=1.23.2; python_version>='3.11'",
"numpy>=1.26.0; python_version>='3.12'",
"jax>=0.4.27", # keep in sync with jax-version in .github/workflows/build.yml
# keep in sync with jax-version in .github/workflows/build.yml
"jax>=0.4.27",
"msgpack",
"optax",
"orbax-checkpoint",
Expand All @@ -40,7 +40,7 @@ all = [
"matplotlib", # only needed for tensorboard export
]
testing = [
"clu", # All examples.
"clu",
"clu<=0.0.9; python_version<'3.10'",
"einops",
"gymnasium[atari, accept-rom-license]",
Expand All @@ -55,16 +55,42 @@ testing = [
"pytest-custom_exit_code",
"pytest-xdist",
"pytype",
"sentencepiece", # WMT/LM1B examples
"tensorflow_text>=2.11.0", # WMT/LM1B examples
# WMT/LM1B examples
"sentencepiece",
"tensorflow_text>=2.11.0; platform_system!='Darwin'",
"tensorflow_datasets",
"tensorflow>=2.12.0", # to fix Numpy np.bool8 deprecation error
"torch",
"nbstripout",
"black[jupyter]==23.7.0",
# "pyink==23.5.0", # disabling pyink fow now
"treescope>=0.1.1; python_version>='3.10'",
]
docs = [
"sphinx>=3.3.1",
"sphinx-book-theme",
"Pygments>=2.6.1",
"ipykernel",
"myst_nb",
"nbstripout",
"recommonmark",
"ipython_genutils",
"sphinx-design",
"jupytext==1.13.8",
"dm-haiku",

# Need to pin docutils to 0.16 to make bulleted lists appear correctly on
# ReadTheDocs: https://stackoverflow.com/a/68008428
"docutils==0.16",

# The next packages are for notebooks.
"matplotlib",
"scikit-learn",
# The next packages are used in testcode blocks.
"ml_collections",
# notebooks
"einops",
]
dev = [
"pre-commit>=3.8.0",
]

[project.urls]
homepage = "https://github.com/google/flax"
Expand Down Expand Up @@ -151,6 +177,10 @@ filterwarnings = [
"ignore:.*jax.xla_computation is deprecated.*:DeprecationWarning",
# Orbax warnings inside deprecated `flax.training` package.
"ignore:.*Couldn't find sharding info under RestoreArgs.*:UserWarning",
# numpy: RuntimeWarning: invalid value encountered in cast
"ignore:.*invalid value encountered in cast.*:RuntimeWarning",
# numpy: RuntimeWarning: divide by zero encountered in {not_}equal
"ignore:.*divide by zero encountered in.*:RuntimeWarning",
]

[tool.coverage.report]
Expand Down Expand Up @@ -190,3 +220,4 @@ unfixable = []
[tool.ruff.format]
indent-style = "space"
quote-style = "single"

35 changes: 22 additions & 13 deletions tests/run_all_tests.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#!/bin/bash

# export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
PYTEST_OPTS=
RUN_DOCTEST=true
RUN_MYPY=true
RUN_PYTEST=true
RUN_PYTYPE=true
RUN_DOCTEST=false
RUN_MYPY=false
RUN_PYTEST=false
RUN_PYTYPE=false
GH_VENV=false

for flag in "$@"; do
Expand All @@ -17,17 +18,17 @@ case $flag in
echo " --with-cov: Also generate pytest coverage."
exit
;;
--no-doctest)
RUN_DOCTEST=false
--only-doctest)
RUN_DOCTEST=true
;;
--no-pytest)
RUN_PYTEST=false
--only-pytest)
RUN_PYTEST=true
;;
--no-pytype)
RUN_PYTYPE=false
--only-pytype)
RUN_PYTYPE=true
;;
--no-mypy)
RUN_MYPY=false
--only-mypy)
RUN_MYPY=true
;;
--use-venv)
GH_VENV=true
Expand All @@ -39,9 +40,17 @@ case $flag in
esac
done

# if neither --only-doctest, --only-pytest, --only-pytype, --only-mypy is set, run all tests
if ! $RUN_DOCTEST && ! $RUN_PYTEST && ! $RUN_PYTYPE && ! $RUN_MYPY; then
RUN_DOCTEST=true
RUN_PYTEST=true
RUN_PYTYPE=true
RUN_MYPY=true
fi

# Activate cached virtual env for github CI
if $GH_VENV; then
source $(dirname "$0")/../venv/bin/activate
source $(dirname "$0")/../.venv/bin/activate
fi

echo "====== test config ======="
Expand Down
Loading

0 comments on commit 33b85b9

Please sign in to comment.