Skip to content

Commit

Permalink
[pallas] More simplification of grid mapping and calling convention
Browse files Browse the repository at this point in the history
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
  • Loading branch information
gnecula authored and Rifur13 committed Jul 29, 2024
1 parent bdc93d9 commit 95bd812
Show file tree
Hide file tree
Showing 13 changed files with 307 additions and 302 deletions.
3 changes: 2 additions & 1 deletion docs/jax.experimental.pallas.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Classes
:toctree: _autosummary

BlockSpec
GridSpec
Slice

Functions
Expand All @@ -34,4 +35,4 @@ Functions
atomic_or
atomic_xchg

debug_print
debug_print
7 changes: 7 additions & 0 deletions docs/pallas/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ Remember to align the itemized text with the first line of an item within a list
* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
be passed *before* `index_map`. The old argument order is deprecated and
will be removed in a future release.
* {class}`jax.experimental.pallas.GridSpec` does not have anymore the `in_specs_tree`,
and the `out_specs_tree` fields, and the `in_specs` and `out_specs` tree now
store the values as pytrees of BlockSpec. Previously, `in_specs` and
`out_specs` were flattened ({jax-issue}`#22552`).
* The method `compute_index` of {class}`jax.experimental.pallas.GridSpec` has
been removed because it is private. Similarly, the `get_grid_mapping` and
`unzip_dynamic_bounds` have been removed from `BlockSpec` ({jax-issue}`#22593`).
* Fixed the interpreter mode to work with BlockSpec that involve padding
({jax-issue}`#22275`).
Padding in interpreter mode will be with NaN, to help debug out-of-bounds
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1705,7 +1705,7 @@ def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
"""Canonicalizes and checks for errors in a user-provided shape dimension value.
Args:
f: a Python value that represents a dimension.
d: a Python value that represents a dimension.
Returns:
A canonical dimension value.
Expand Down
Loading

0 comments on commit 95bd812

Please sign in to comment.