
Error when using nn.scan with negative output_axes #3460

Closed
lucaslingle opened this issue Nov 4, 2023 · 1 comment · Fixed by #3540

lucaslingle commented Nov 4, 2023

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): N/A
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax==0.6.11, jax==0.4.9, jaxlib==0.4.9
  • Python version: 3.8
  • GPU/TPU model and memory: N/A
  • CUDA version (if applicable): N/A

Problem you have encountered:

When flax.linen.scan is used with a negative out_axes, an unexpected AssertionError is raised. If I have understood the source code correctly, it is due to a typo in flax/core/axes_scan.py: a minus sign is used instead of a plus sign when a negative axis is converted to a non-negative one.
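The sign error can be seen in isolation. The two functions below are hypothetical illustrations, not Flax APIs. They assume the axis conversion works as described later in this thread, where a negative axis `ax` on an array with `ndim` dimensions is converted with `ndim - ax` instead of the conventional `ndim + ax`.

```python
def normalize_axis_buggy(ax: int, ndim: int) -> int:
    # Hypothetical mirror of the reported behaviour: subtracting a negative
    # axis yields an index past the last valid axis.
    return ndim - ax if ax < 0 else ax


def normalize_axis_fixed(ax: int, ndim: int) -> int:
    # Conventional rule: -1 is the last axis, -ndim is the first.
    return ndim + ax if ax < 0 else ax


ndim = 3
for ax in (-1, -2, -3):
    print(f"ax={ax}: buggy={normalize_axis_buggy(ax, ndim)}, "
          f"fixed={normalize_axis_fixed(ax, ndim)}")
# ax=-1: buggy=4, fixed=2
# ax=-2: buggy=5, fixed=1
# ax=-3: buggy=6, fixed=0
```

Every result of the buggy version is at least ndim, so a check of the form `pax < ndim` cannot pass for any negative axis.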

What you expected to happen:

nn.scan should run as usual, stacking the outputs along the specified axis (for out_axes=-1, the last axis).
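For concreteness, with the shapes in the reproduction below (8 scan steps over an input of shape [8, 128, 16], each step producing an nn.Dense(100) output of shape [128, 100]), out_axes=-1 should put the scan axis last. The sketch below assumes out_axes follows the usual NumPy convention for negative axes; it uses NumPy so it runs without JAX (jax.numpy provides the same stack and moveaxis functions), and random arrays stand in for the per-step Dense outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
num_steps = 8
# Stand-ins for the per-step outputs of nn.Dense(100) on a [128, 16] slice.
per_step = [rng.normal(size=(128, 100)) for _ in range(num_steps)]

# Outputs stacked along a new leading axis, as jax.lax.scan does.
stacked_front = np.stack(per_step, axis=0)     # shape (8, 128, 100)

# Expected result for out_axes=-1: the scan axis moved to the last position.
expected = np.moveaxis(stacked_front, 0, -1)   # shape (128, 100, 8)

# This is the same as stacking directly along the last axis.
assert np.array_equal(expected, np.stack(per_step, axis=-1))
print(expected.shape)  # (128, 100, 8)
```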

Logs, error messages, etc:

(projectabcde) lucaslingle@Lucass-MacBook-Pro projectabcde % python3 scripts/scan_issue.py
Traceback (most recent call last):
  File "scripts/scan_issue.py", line 39, in <module>
    main()
  File "scripts/scan_issue.py", line 32, in main
    params = cls().init(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 1689, in init
    _, v_out = self.init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 1594, in init_with_output
    return init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/scope.py", line 968, in wrapper
    return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs,
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/scope.py", line 936, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 2170, in scope_fn
    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 432, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/module.py", line 868, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "scripts/scan_issue.py", line 18, in __call__
    _, outputs = nn.scan(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/linen/transforms.py", line 323, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/lift.py", line 219, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/lift.py", line 806, in inner
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 151, in scan_fn
    ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 106, in transpose_from_front
    return jax.tree_util.tree_map(trans, xs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 103, in trans
    assert pax < x.ndim
jax._src.traceback_util.UnfilteredStackTrace: AssertionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "scripts/scan_issue.py", line 39, in <module>
    main()
  File "scripts/scan_issue.py", line 32, in main
    params = cls().init(
  File "scripts/scan_issue.py", line 18, in __call__
    _, outputs = nn.scan(
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 151, in scan_fn
    ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 106, in transpose_from_front
    return jax.tree_util.tree_map(trans, xs)
  File "/Users/lucaslingle/opt/miniconda3/envs/projectabcde/lib/python3.8/site-packages/flax/core/axes_scan.py", line 103, in trans
    assert pax < x.ndim
AssertionError

Steps to reproduce:

# issue appears to be at https://github.com/google/flax/blob/main/flax/core/axes_scan.py#L101

import flax.linen as nn
import jax.random


class Foo(nn.Module):
    unused_config: int

    @nn.compact
    def __call__(self, state, input_dict):
        return state, nn.Dense(100)(input_dict["x"])


class Bar(nn.Module):
    @nn.compact
    def __call__(self, x):
        _, outputs = nn.scan(
            Foo,
            variable_broadcast="params",
            split_rngs=dict(
                params=False,
            ),
            in_axes=0,
            out_axes=-1,
        )(unused_config=123)(dict(unused_state_item=None), dict(x=x))
        return outputs


def main():
    cls = Bar
    params = cls().init(
        {"params": jax.random.PRNGKey(0)},
        jax.random.normal(jax.random.PRNGKey(1), shape=[8, 128, 16])
    )["params"]


if __name__ == "__main__":
    main()

Thank you for your attention to this matter!

lucaslingle changed the title Error when using nn.scan with negative output_axis Error when using nn.scan with negative output_axes Nov 4, 2023
Collaborator

chiamp commented Dec 4, 2023

I traced the failing assertion pax < x.ndim in the debugger: pax has a value of 4, whereas x.ndim has a value of 3. x.ndim is 3 because the stacked output is 3-dimensional (shape [8, 128, 100]: 8 scan steps over your [8, 128, 16] input, each producing a [128, 100] Dense output). pax is derived in the trans function in flax/core/axes_scan.py from ax, the out_axes argument, which is -1 in this case. So pax = x.ndim - ax = 3 - (-1) = 4. Given that line of code, the assertion will fail whenever out_axes is negative (out_axes=0 is not affected, since it gives pax = 0). If you use out_axes=2 instead, it should work. However, I suspect this is a bug in the code: whenever the ax < 0 branch is taken, the assertion is guaranteed to fail.
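The arithmetic above can be checked with a standalone sketch of the transposition step. `move_front_axis` is a hypothetical re-creation for illustration, not Flax's code: it moves axis 0 of an array to position `pax`, guarded by the same `pax < x.ndim` check seen in the traceback, using either the reported `x.ndim - ax` conversion or the corrected `x.ndim + ax`. NumPy stands in for jax.numpy so it runs without JAX.

```python
import numpy as np


def move_front_axis(x: np.ndarray, ax: int, fixed: bool) -> np.ndarray:
    """Move axis 0 of x to position ax (hypothetical sketch)."""
    if ax < 0:
        # Reported behaviour subtracts the negative axis; the fix adds it.
        pax = x.ndim + ax if fixed else x.ndim - ax
    else:
        pax = ax
    assert pax < x.ndim, f"pax={pax} is out of range for ndim={x.ndim}"
    # Permutation sending axis 0 to position pax, other axes keep their order.
    perm = list(range(1, pax + 1)) + [0] + list(range(pax + 1, x.ndim))
    return np.transpose(x, perm)


stacked = np.zeros((8, 128, 100))  # scan outputs stacked along axis 0

try:
    move_front_axis(stacked, -1, fixed=False)
except AssertionError as err:
    print("buggy:", err)  # buggy: pax=4 is out of range for ndim=3

print("fixed:", move_front_axis(stacked, -1, fixed=True).shape)  # (128, 100, 8)
print("out_axes=2:", move_front_axis(stacked, 2, fixed=True).shape)  # (128, 100, 8)
```

With the corrected conversion the result agrees with `np.moveaxis(x, 0, ax)` for every valid negative axis, and for a 3-D output ax=-1 behaves exactly like the suggested workaround out_axes=2.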

cc: @marcvanzee @cgarciae
