-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Porting RNN from Linen to NNX #4272
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
return_carry: bool = False, | ||
reverse: bool = False, | ||
keep_order: bool = False, | ||
unroll: int = 1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we could accepts rngs
to get a default that can be optionally override during __call__
?
unroll: int = 1, | |
unroll: int = 1, | |
rngs: rnglib.Rngs | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed and set default to rngs(0)
flax/nnx/nn/recurrent.py
Outdated
def scan_fn(carry: Carry, x: Array) -> tuple[Carry, Array]: | ||
carry, y = self.cell(carry, x) | ||
if slice_carry: | ||
return carry, (carry, y) | ||
return carry, y | ||
|
||
scan = nnx.scan( | ||
scan_fn, | ||
in_axes=(Carry, time_axis), | ||
out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis), | ||
unroll=self.unroll, | ||
) | ||
|
||
scan_output = scan(carry, inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently self
is being passed as a capture, we need to pass cell
as an explicit input.
def scan_fn(carry: Carry, x: Array) -> tuple[Carry, Array]: | |
carry, y = self.cell(carry, x) | |
if slice_carry: | |
return carry, (carry, y) | |
return carry, y | |
scan = nnx.scan( | |
scan_fn, | |
in_axes=(Carry, time_axis), | |
out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis), | |
unroll=self.unroll, | |
) | |
scan_output = scan(carry, inputs) | |
def scan_fn(cell: RNNCellBase, carry: Carry, x: Array) -> tuple[Carry, Array]: | |
carry, y = cell(carry, x) | |
if slice_carry: | |
return carry, (carry, y) | |
return carry, y | |
state_axes = nnx.StateAxes({...: Carry}) | |
scan = nnx.scan( | |
scan_fn, | |
in_axes=(state_axes, Carry, time_axis), | |
out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis), | |
unroll=self.unroll, | |
) | |
scan_output = scan(self.cell, carry, inputs) |
*, | ||
merge_fn: Callable[[Array, Array], Array] = _concatenate, | ||
time_major: bool = False, | ||
return_carry: bool = False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accept rngs
here as well.
Thanks @zinccat for doing this, this is amazing! |
Thanks for the review! Will fix it soon |
Thank you for making the change! You probably need to rebase to the current head to resolve the Read the Docs build error. |
hi, any updates on this? |
Sorry about the delay! Merging them rn |
thanks! |
Porting RNN from Linen to NNX
Fixes # (4259), #4259
Checklist
checks if that's the case).
discussion
documentation guidelines.
(No quality testing = no merge!)