diff --git a/pyproject.toml b/pyproject.toml index 4a99cfaf8b..26f8e5ae6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,6 +138,8 @@ cu12 = [ jax = [ 'jax>=0.4.33;python_version>="3.10"', 'flax>=0.8.0;python_version>="3.10"', + 'orbax-checkpoint;python_version>="3.10"', + 'jax-ai-stack;python_version>="3.10"', ] [tool.deepmd_build_backend.scripts]