Masked computation wrapper #23
Note that you can very easily and straightforwardly write this using … It is not actually so clear whether we really need to wrap …
Related is also: rwth-i6/returnn#769 (…)
Also see rwth-i6/returnn#976. Edit: merged now. Note that the …
I'm not exactly sure whether …
The design is actually not really so clear here. The code without such an explicit masked computation wrapper would look like:

```python
loop = nn.Loop(...)
loop.state.y = ...  # some initial output
loop.state.h = ...  # some initial state
with loop:
  mask = ...  # dtype bool, shape [batch] or whatever, for current frame
  y_, h_ = slow_rnn(x, loop.state.h)
  loop.state.y = nest.map(lambda a, b: nn.where(cond=mask, x=a, y=b), y_, loop.state.y)
  loop.state.h = nest.map(lambda a, b: nn.where(cond=mask, x=a, y=b), h_, loop.state.h)
y = loop.state.y
```

But how exactly would we handle state, initial state and initial output with `nn.MaskedComputation`? It should be straight-forward and non-ambiguous. Maybe:

```python
loop = nn.Loop(...)
loop.state.y = ...  # some initial output
loop.state.h = ...  # some initial state
with loop:
  mask = ...  # dtype bool, shape [batch] or whatever, for current (fast) frame
  with nn.MaskedComputation(mask=mask):
    loop.state.y, loop.state.h = slow_rnn(x, loop.state.h)
y = loop.state.y  # access from outside
```
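As a sanity check of the intended semantics, here is a minimal pure-Python sketch of the manual `nn.where`-style masked update from the first variant above. Everything in it (including `slow_rnn_step`) is a hypothetical stand-in, not returnn_common API; it only illustrates that masked-out frames carry the previous output and state forward.

```python
import math

def slow_rnn_step(x, h):
    # Hypothetical stand-in for one slow_rnn step: returns (output, new state).
    h_new = math.tanh(x + h)
    return h_new, h_new

def masked_loop(xs, mask, y0, h0):
    # Per frame: run the step, then keep its result only where the mask is
    # True, otherwise carry the previous output/state (the nn.where pattern).
    y, h = y0, h0
    ys = []
    for x, m in zip(xs, mask):
        y_, h_ = slow_rnn_step(x, h)
        y = y_ if m else y
        h = h_ if m else h
        ys.append(y)
    return ys, y, h

ys, y, h = masked_loop([0.5, -0.2, 1.0, 0.3], [True, False, True, False], 0.0, 0.0)
# Masked-out frames simply repeat the previous output:
assert ys[1] == ys[0] and ys[3] == ys[2]
```

Note that this eager version still executes `slow_rnn_step` on every frame and then discards the masked-out results; the point of an explicit `nn.MaskedComputation` is that RETURNN can avoid that work entirely.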
The test runs through now. While some aspects might need some better handling (e.g. …)
Similar to the rec loop design (#16), for the masked computation, we could have some API like `with MaskedComputation(mask=...)`. This would wrap `MaskedComputationLayer`, and also automatically apply `UnmaskLayer`.

Note that while it is quite trivial to implement such masking logic by hand (using `nn.where` given the mask, to update the output and state or take the previous output and state), such an explicit `nn.MaskedComputation` allows for efficiency optimization on the RETURNN side. Specifically, when RETURNN can optimize this part out of the loop, it can calculate it much more efficiently by only going over the relevant frames.

Example for transducer using SlowRNN and FastRNN:
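The example code itself was not preserved in this extraction. As a hedged sketch of the pattern being described (not the original code, and `fast_rnn_step`/`slow_rnn_step` are hypothetical stand-ins): the FastRNN runs on every frame, while the SlowRNN is only evaluated on frames where the mask is True, which is exactly what `nn.MaskedComputation` would express inside the loop.

```python
import math

def fast_rnn_step(x, y_slow, h):
    # Hypothetical FastRNN step: runs every frame, consumes the input frame
    # and the current (possibly held) SlowRNN output.
    h_new = math.tanh(x + 0.5 * y_slow + h)
    return h_new, h_new

def slow_rnn_step(x, h):
    # Hypothetical SlowRNN step: only evaluated on masked frames.
    h_new = math.tanh(x + h)
    return h_new, h_new

def transducer_loop(xs, mask):
    y_slow = h_slow = h_fast = 0.0
    outs = []
    for x, m in zip(xs, mask):
        # Corresponds to `with nn.MaskedComputation(mask=mask): ...`:
        # the SlowRNN output and state are updated only where the mask is
        # True; otherwise the previous values are held.
        if m:
            y_slow, h_slow = slow_rnn_step(x, h_slow)
        y_fast, h_fast = fast_rnn_step(x, y_slow, h_fast)
        outs.append(y_fast)
    return outs

outs = transducer_loop([0.1, 0.4, -0.3, 0.7], [True, False, True, True])
```

The efficiency argument from above shows up here directly: the slow step runs only on the three masked frames, which is what RETURNN achieves automatically when it can move the masked computation out of the loop.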