From ebc7af95dfaada7daa9a57ef7a8a76fcda7cdcf2 Mon Sep 17 00:00:00 2001 From: Tom Cobley Date: Thu, 4 Jan 2024 09:49:14 -0800 Subject: [PATCH] Fix typo in pmap docstring Docstring states: > If the pmapped function is called with fewer positional arguments than indicated by **`static_argnums`** then an error is raised. However `static_argnums` is not an argument that exists - I believe this should be corrected to `static_broadcasted_argnums`. PiperOrigin-RevId: 595731210 --- jax/_src/api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 7367a655c544..ea5bfd8ea86c 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1415,10 +1415,10 @@ def pmap( Operations that only depend on static arguments will be constant-folded. Calling the pmapped function with different values for these constants will trigger recompilation. If the pmapped function is called with fewer - positional arguments than indicated by ``static_argnums`` then an error is - raised. Each of the static arguments will be broadcasted to all devices. - Arguments that are not arrays or containers thereof must be marked as - static. Defaults to (). + positional arguments than indicated by ``static_broadcasted_argnums`` then + an error is raised. Each of the static arguments will be broadcasted to + all devices. Arguments that are not arrays or containers thereof must be + marked as static. Defaults to (). Static arguments must be hashable, meaning both ``__hash__`` and ``__eq__`` are implemented, and should be immutable.