OS Platform and Distribution (e.g., Linux Ubuntu 16.04): N/A
Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`): flax==0.6.11, jax==0.4.9, jaxlib==0.4.9
Python version: 3.8
GPU/TPU model and memory: N/A
CUDA version (if applicable): N/A
Problem you have encountered:
When using `flax.linen.scan` with a negative `out_axes`, there is an unexpected `AssertionError`. If I have understood the source code correctly, it is due to a typo in `flax/core/axes_scan.py` (https://github.com/google/flax/blob/main/flax/core/axes_scan.py#L101): a minus sign where there should be a plus sign.
What you expected to happen:
Apply scan as usual, stacking the outputs along the specified axis.
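In the reproduction script below, `x` has shape [8, 128, 16] and each scan step applies `Dense(100)` to a [128, 16] slice. The sketch shows which output shapes "as usual" would produce. It is written in pure Python and assumes that a negative `out_axes` follows the NumPy convention of counting back from the end of the stacked output. `stacked_shape` is a hypothetical helper, not part of Flax.

```python
def stacked_shape(step_shape, length, out_axis):
    """Hypothetical helper: shape of `length` per-step outputs of shape
    `step_shape` stacked along `out_axis`. Assumes a negative axis counts
    back from the end of the stacked array (NumPy convention)."""
    ndim = len(step_shape) + 1  # stacking adds one (scan) axis
    axis = out_axis + ndim if out_axis < 0 else out_axis
    if not 0 <= axis < ndim:
        raise ValueError(f"out_axis {out_axis} is out of range for a {ndim}-D output")
    shape = list(step_shape)
    shape.insert(axis, length)  # the scan axis lands at position `axis`
    return tuple(shape)


# Reproduction script: 8 scan steps, each producing a (128, 100) Dense output.
print(stacked_shape((128, 100), 8, -1))  # (128, 100, 8)
print(stacked_shape((128, 100), 8, 2))   # (128, 100, 8), same axis as -1
print(stacked_shape((128, 100), 8, 0))   # (8, 128, 100)
```

Under this convention, `out_axes=-1` and `out_axes=2` are the same axis for a 3-D stacked output.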
Logs, error messages, etc:
```
(projectabcde) lucaslingle@Lucass-MacBook-Pro projectabcde % python3 scripts/scan_issue.py
Traceback (most recent call last):
  File "scripts/scan_issue.py", line 39, in <module>
    main()
  File "scripts/scan_issue.py", line 32, in main
    params = cls().init(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 1689, in init
    _, v_out = self.init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 1594, in init_with_output
    return init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/scope.py", line 968, in wrapper
    return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs,
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/scope.py", line 936, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 2170, in scope_fn
    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 432, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 868, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "scripts/scan_issue.py", line 18, in __call__
    _, outputs = nn.scan(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/transforms.py", line 323, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/lift.py", line 219, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/lift.py", line 806, in inner
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 151, in scan_fn
    ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 106, in transpose_from_front
    return jax.tree_util.tree_map(trans, xs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 103, in trans
    assert pax < x.ndim
jax._src.traceback_util.UnfilteredStackTrace: AssertionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "scripts/scan_issue.py", line 39, in <module>
    main()
  File "scripts/scan_issue.py", line 32, in main
    params = cls().init(
  File "scripts/scan_issue.py", line 18, in __call__
    _, outputs = nn.scan(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 151, in scan_fn
    ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 106, in transpose_from_front
    return jax.tree_util.tree_map(trans, xs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 103, in trans
    assert pax < x.ndim
AssertionError
```
Steps to reproduce:
```python
# issue appears to be at https://github.com/google/flax/blob/main/flax/core/axes_scan.py#L101

import flax.linen as nn
import jax.random


class Foo(nn.Module):
    unused_config: int

    @nn.compact
    def __call__(self, state, input_dict):
        return state, nn.Dense(100)(input_dict["x"])


class Bar(nn.Module):
    @nn.compact
    def __call__(self, x):
        _, outputs = nn.scan(
            Foo,
            variable_broadcast="params",
            split_rngs=dict(
                params=False,
            ),
            in_axes=0,
            out_axes=-1,
        )(unused_config=123)(dict(unused_state_item=None), dict(x=x))
        return outputs


def main():
    cls = Bar
    params = cls().init(
        {"params": jax.random.PRNGKey(0)},
        jax.random.normal(jax.random.PRNGKey(1), shape=[8, 128, 16])
    )["params"]


if __name__ == "__main__":
    main()
```
Thank you for your attention to this matter!
lucaslingle changed the title from "Error when using nn.scan with negative output_axis" to "Error when using nn.scan with negative output_axes" on Nov 4, 2023.
I traced the assertion error `pax < x.ndim` in the debugger, and it seems that `pax` has a value of 4 while `x.ndim` is 3. In that frame, `x` is the stacked scan output with shape [8, 128, 100]. That shape comes from 8 scan steps, each producing a `Dense(100)` output of shape [128, 100], so `x` is 3-dimensional. `pax` is computed in `transpose_from_front` in `flax/core/axes_scan.py`, where `ax` is the `out_axes` argument, which is -1 in this case. So `pax = x.ndim - ax = 3 - (-1) = 4`.

With that line of code, the assertion error will be raised whenever `out_axes` is negative: for `ax < 0`, `x.ndim - ax` is always greater than `x.ndim`. If you use `out_axes=2` instead, it should work. However, I wonder if this is a bug. Whenever the `ax < 0` branch is taken, the assertion is guaranteed to fail. The usual convention `x.ndim + ax` would instead give `pax = 2`, the same axis as `out_axes=2`.
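The arithmetic above can be checked with a small pure-Python sketch. The function names are hypothetical, and `front_to_axis_perm` only paraphrases what the transpose has to do: move the scan axis from the front to position `pax`. None of this is the Flax source. The shape (8, 128, 100) is taken from the reproduction script.

```python
def pax_buggy(ax, ndim):
    # Negative-axis handling as described above: pax = ndim - ax (overshoots).
    return ndim - ax if ax < 0 else ax


def pax_fixed(ax, ndim):
    # Usual convention: a negative axis counts back from the end, pax = ndim + ax.
    return ndim + ax if ax < 0 else ax


def front_to_axis_perm(pax, ndim):
    """Hypothetical helper: permutation that moves axis 0 (where scan stacks
    its outputs) to position pax, keeping the other axes in order."""
    assert pax < ndim  # the same check that fails in the traceback
    return tuple(range(1, pax + 1)) + (0,) + tuple(range(pax + 1, ndim))


shape = (8, 128, 100)  # assumed stacked scan output from the reproduction script
ndim = len(shape)
for ax in (-1, -2, -3):
    print(ax, pax_buggy(ax, ndim), pax_fixed(ax, ndim))
# -1 4 2
# -2 5 1
# -3 6 0

perm = front_to_axis_perm(pax_fixed(-1, ndim), ndim)
print(perm, tuple(shape[p] for p in perm))
# (1, 2, 0) (128, 100, 8)
```

With the minus-sign formula, every negative `ax` yields `pax >= ndim`. Calling `front_to_axis_perm(pax_buggy(-1, ndim), ndim)` therefore raises `AssertionError`, mirroring the traceback. The plus-sign formula lands the scan axis last, matching `out_axes=2`.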