Porting RNN from Linen to NNX #4272

Merged
merged 5 commits into from
Oct 16, 2024
Contributor

@zinccat zinccat commented Oct 8, 2024

Porting RNN from Linen to NNX

Fixes #4259

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a GitHub issue/discussion.
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

google-cla bot commented Oct 8, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

flax/nnx/nn/recurrent.py (outdated)
return_carry: bool = False,
reverse: bool = False,
keep_order: bool = False,
unroll: int = 1,
Collaborator

Maybe we could accept rngs to get a default that can optionally be overridden during __call__?

Suggested change
- unroll: int = 1,
+ unroll: int = 1,
+ rngs: rnglib.Rngs | None = None,

Contributor Author

changed and set default to rngs(0)
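
The pattern discussed in this thread (a default rngs stored at construction that a call-time argument can override) can be sketched in plain Python. This is only an illustration under stated assumptions: ToyRNN is a hypothetical class, and random.Random stands in for rnglib.Rngs; it is not the code from this PR.

```python
import random
from typing import List, Optional


class ToyRNN:
    """Hypothetical stand-in for a layer that needs randomness at call time."""

    def __init__(self, hidden_size: int, rngs: Optional[random.Random] = None):
        self.hidden_size = hidden_size
        # Store a default stream at construction. Seeding with 0 mirrors the
        # "default to rngs(0)" choice mentioned in the reply above.
        self.rngs = rngs if rngs is not None else random.Random(0)

    def __call__(
        self, inputs: List[float], *, rngs: Optional[random.Random] = None
    ) -> List[float]:
        # A generator passed at call time takes precedence over the default.
        active = rngs if rngs is not None else self.rngs
        # Toy behaviour: draw a random initial carry and return it.
        return [active.random() for _ in range(self.hidden_size)]


layer = ToyRNN(hidden_size=2)
default_carry = layer([1.0, 2.0])                            # uses the stored default
custom_carry = layer([1.0, 2.0], rngs=random.Random(42))     # call-time override
```

Keeping the override keyword-only makes it hard to pass a generator by accident, while callers who never think about randomness still get reproducible behaviour from the stored default.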

Comment on lines 675 to 688
def scan_fn(carry: Carry, x: Array) -> tuple[Carry, Array]:
  carry, y = self.cell(carry, x)
  if slice_carry:
    return carry, (carry, y)
  return carry, y

scan = nnx.scan(
  scan_fn,
  in_axes=(Carry, time_axis),
  out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis),
  unroll=self.unroll,
)

scan_output = scan(carry, inputs)
Collaborator

Currently self is being passed as a capture; we need to pass cell as an explicit input.

Suggested change
- def scan_fn(carry: Carry, x: Array) -> tuple[Carry, Array]:
-   carry, y = self.cell(carry, x)
-   if slice_carry:
-     return carry, (carry, y)
-   return carry, y
- scan = nnx.scan(
-   scan_fn,
-   in_axes=(Carry, time_axis),
-   out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis),
-   unroll=self.unroll,
- )
- scan_output = scan(carry, inputs)
+ def scan_fn(cell: RNNCellBase, carry: Carry, x: Array) -> tuple[Carry, Array]:
+   carry, y = cell(carry, x)
+   if slice_carry:
+     return carry, (carry, y)
+   return carry, y
+ state_axes = nnx.StateAxes({...: Carry})
+ scan = nnx.scan(
+   scan_fn,
+   in_axes=(state_axes, Carry, time_axis),
+   out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis),
+   unroll=self.unroll,
+ )
+ scan_output = scan(self.cell, carry, inputs)
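
As a rough mental model of the suggested version, the scan can be read as a loop in which the cell is an explicit argument of the step function rather than something reached through self. The sketch below is plain Python and makes several simplifying assumptions: reference_scan, make_step, and RunningSumCell are hypothetical names, and NNX state handling, time_axis, and unroll are all ignored.

```python
def reference_scan(step, cell, carry, inputs):
    """Plain loop standing in for the scan; `cell` is passed in explicitly."""
    outputs = []
    for x in inputs:
        carry, out = step(cell, carry, x)
        outputs.append(out)
    return carry, outputs


def make_step(slice_carry: bool):
    """Build a step function shaped like scan_fn in the suggestion."""
    def step(cell, carry, x):
        carry, y = cell(carry, x)
        if slice_carry:
            # Also emit the carry at every step, not only the final one.
            return carry, (carry, y)
        return carry, y
    return step


class RunningSumCell:
    """Hypothetical toy cell: the carry is a running sum, the output is twice it."""

    def __call__(self, carry, x):
        new_carry = carry + x
        return new_carry, 2 * new_carry


cell = RunningSumCell()
final_carry, ys = reference_scan(make_step(False), cell, 0, [1, 2, 3])
_, carries_and_ys = reference_scan(make_step(True), cell, 0, [1, 2, 3])
```

Unlike this loop, which collects a list of per-step results, a real scan stacks per-step outputs into arrays, so with slice_carry the result would be a pair of stacked carries and stacked outputs rather than a list of pairs.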

*,
merge_fn: Callable[[Array, Array], Array] = _concatenate,
time_major: bool = False,
return_carry: bool = False,
Collaborator

Accept rngs here as well.

@cgarciae
Collaborator

Thanks @zinccat for doing this, this is amazing!
Left a few comments.

@zinccat
Contributor Author

zinccat commented Oct 10, 2024

Thanks for the review! Will fix it soon

@IvyZX
Collaborator

IvyZX commented Oct 10, 2024

Thank you for making the change! You probably need to rebase to the current head to resolve the Read the Docs build error.

@zinccat
Contributor Author

zinccat commented Oct 16, 2024

hi, any updates on this?

@IvyZX
Collaborator

IvyZX commented Oct 16, 2024

Sorry about the delay! Merging them rn

@zinccat
Contributor Author

zinccat commented Oct 16, 2024

thanks!

@copybara-service copybara-service bot merged commit 3bf732c into google:main Oct 16, 2024
17 checks passed