Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exposing the experimental _split_transpose JAX scan parameter in Flax. #3795

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions flax/core/axes_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def scan(
length: Optional[int] = None,
reverse: bool = False,
unroll: int = 1,
_split_transpose: bool = False
):
"""A wrapper around `jax.lax.scan` with in_axes/out_axes api.

Expand Down Expand Up @@ -74,6 +75,8 @@ def body_fn(b, c, x):
reverse: scan in reverse order from end to start.
unroll: how many scan iterations to unroll within a single
iteration of a loop (default: 1).
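_split_transpose: An experimental feature to split the transpose of a scan
into a scan and a map, backed by an experimental JAX lax.scan() feature.
The flag is only forwarded to lax.scan() on JAX versions newer than
0.4.25; on older versions it is ignored.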
Returns:
the function that performs the scan of the form:
(broadcast_in, carry_in, *args) -> (broadcast_out, carry_out, scan_out).
Expand Down Expand Up @@ -158,9 +161,15 @@ def body_fn(c, xs, init_mode=False):
out_tree(), out_flat
)

c, ys = lax.scan(
body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
)
if jax.version.__version_info__ > (0, 4, 25):
c, ys = lax.scan(
body_fn, init, xs, length=length, reverse=reverse, unroll=unroll,
_split_transpose=_split_transpose
)
else:
c, ys = lax.scan(
body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
)
ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
ys = jax.tree_util.tree_map(
lambda ax, const, y: (const if ax is broadcast else y),
Expand Down
4 changes: 4 additions & 0 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,7 @@ def scan(
length: Optional[int] = None,
reverse: bool = False,
unroll: int = 1,
_split_transpose: bool = False,
data_transform: Optional[Callable[..., Any]] = None,
metadata_params: Dict[Any, Any] = {},
) -> Callable[..., Any]:
Expand Down Expand Up @@ -935,6 +936,8 @@ def body_fn(scope, c, x):
reverse: If true, scan from end to start in reverse order.
unroll: how many scan iterations to unroll within a single
iteration of a loop (default: 1).
_split_transpose: An experimental feature to split the transpose of a scan
into a scan and a map, backed by an experimental JAX lax.scan() feature.
data_transform: optional function to transform raw variable and rng groups,
intended for inline SPMD annotations.
metadata_params: arguments dict passed to AxisMetadata instances in the
Expand Down Expand Up @@ -993,6 +996,7 @@ def find_length(axis, x):
length=length,
reverse=reverse,
unroll=unroll,
_split_transpose=_split_transpose
)
def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args):
carry_vars, c = carry
Expand Down
4 changes: 4 additions & 0 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,7 @@ def scan(
data_transform: Optional[Callable[..., Any]] = None,
metadata_params: Mapping[Any, Any] = {},
methods=None,
_split_transpose: bool = False,
) -> Target:
"""A lifted version of ``jax.lax.scan``.

Expand Down Expand Up @@ -1280,6 +1281,8 @@ def scan(
metadata_params: arguments dict passed to AxisMetadata instances in the
variable tree.
methods: If ``target`` is a ``Module``, the methods of ``Module`` to scan over.
_split_transpose: An experimental feature to split the transpose of a scan
into a scan and a map, backed by an experimental JAX lax.scan() feature.

Returns:
The scan function with the signature ``(module, carry, *xs) -> (carry,
Expand All @@ -1298,6 +1301,7 @@ def scan(
length=length,
reverse=reverse,
unroll=unroll,
_split_transpose=_split_transpose,
data_transform=data_transform,
metadata_params=metadata_params,
methods=methods,
Expand Down
Loading