Skip to content

Commit

Permalink
Fix typo in pmap docstring
Browse files Browse the repository at this point in the history
Docstring states:
>  If the pmapped function is called with fewer positional arguments than indicated by **`static_argnums`** then an error is raised.

However `static_argnums` is not an argument that exists - I believe this should be corrected to `static_broadcasted_argnums`.

PiperOrigin-RevId: 595731210
  • Loading branch information
tomcobley authored and jax authors committed Jan 4, 2024
1 parent 326d1d2 commit ebc7af9
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,10 +1415,10 @@ def pmap(
Operations that only depend on static arguments will be constant-folded.
Calling the pmapped function with different values for these constants
will trigger recompilation. If the pmapped function is called with fewer
positional arguments than indicated by ``static_argnums`` then an error is
raised. Each of the static arguments will be broadcasted to all devices.
Arguments that are not arrays or containers thereof must be marked as
static. Defaults to ().
positional arguments than indicated by ``static_broadcasted_argnums`` then
an error is raised. Each of the static arguments will be broadcasted to
all devices. Arguments that are not arrays or containers thereof must be
marked as static. Defaults to ().
Static arguments must be hashable, meaning both ``__hash__`` and
``__eq__`` are implemented, and should be immutable.
Expand Down

0 comments on commit ebc7af9

Please sign in to comment.