
Masked computation wrapper #23

Closed
albertz opened this issue Aug 9, 2021 · 6 comments
albertz commented Aug 9, 2021

Similar to the rec loop design (#16), for the masked computation, we could have some API like with MaskedComputation(mask=...). This would wrap MaskedComputationLayer, and also automatically apply UnmaskLayer.
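For illustration, a rough sketch of the RETURNN net dict fragment such a wrapper might generate inside the rec loop (the layer names "slow_masked"/"slow" and the sources are made up here; "masked_computation" and "unmask" are the actual layer classes of MaskedComputationLayer and UnmaskLayer):

"slow_masked": {
    "class": "masked_computation", "mask": "mask", "from": "x_t",
    "unit": {...},  # the wrapped computation, evaluated only on masked frames
},
# UnmaskLayer makes the output available again in every frame,
# repeating the last computed value in frames where the mask is False.
"slow": {"class": "unmask", "from": "slow_masked", "mask": "mask"},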

Note that while it is quite trivial to implement such masking logic by hand (using nn.where with the mask to either take the updated output and state or keep the previous output and state), an explicit nn.MaskedComputation allows for efficiency optimizations on the RETURNN side. Specifically, when RETURNN can optimize this part out of the loop, it can calculate it much more efficiently by only going over the relevant frames.


Example for transducer using SlowRNN and FastRNN:

x = ...  # encoder output, shape {batch,enc_time,dim}
slow_rnn = SlowRNN()
fast_rnn = FastRNN()
blank_pred = nn.Linear(1)
non_blank_pred = nn.Linear(...)
loop = nn.Loop()  # over alignment labels
loop.state.t = nn.zeros([nn.batch_dim], dtype="int32")  # current encoder frame index
loop.state.align_label = nn.zeros([nn.batch_dim], dtype="int32")  # last emitted label
with loop:
  x_t = x[loop.state.t]  # shape {batch,dim}, current encoder frame
  with nn.MaskedComputation(mask=(loop.state.align_label != BLANK)):
    slow = slow_rnn(loop.state.align_label, x_t)  # only updated on non-blank frames
  fast = fast_rnn(loop.state.align_label, x_t, slow)
  blank_pred_energy = blank_pred(fast)
  log_prob_blank = nn.log_sigmoid(blank_pred_energy)
  log_prob_not_blank = nn.log_sigmoid(-blank_pred_energy)
  log_prob_non_blank_labels = nn.log_softmax(non_blank_pred(fast))
  log_prob_combined = nn.concat(log_prob_non_blank_labels + log_prob_not_blank, log_prob_blank)
  loop.state.align_label = nn.choice(log_prob_combined, input_type="log_prob")
  loop.state.t = loop.state.t + nn.where(loop.state.align_label == BLANK, 1, 0)  # advance t on blank
  loop.end(loop.state.t >= x.seq_len)
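In this example, the slow RNN is only updated in frames where the previous alignment label was non-blank. If RETURNN can move this masked computation out of the loop, slow_rnn is evaluated only over that shorter sub-sequence of non-blank frames, which is where the efficiency gain described above comes from.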
albertz commented Oct 20, 2021

Note that you can write this very easily and straightforwardly using nn.where. However, the point of MaskedComputationLayer on the RETURNN side is that it can be more efficient.

It is not actually so clear whether we really need to wrap MaskedComputationLayer here, or whether this can be implemented more directly. In any case, we want to keep the efficient cases as efficient as before.

albertz commented Nov 26, 2021

Also related: rwth-i6/returnn#769 (MaskedComputationLayer violates the principle that the user should not need to think about rec automatic optimization).
Edit: This is fixed now by rwth-i6/returnn#976.

albertz commented Mar 8, 2022

Also see rwth-i6/returnn#976. Edit: Merged now.

Note that the masked_from option is not really needed here. Even when it is available, it is probably simpler and cleaner to keep that computation explicitly separate.

albertz commented Mar 8, 2022

I'm not exactly sure whether MaskedComputationLayer supports accessing arbitrary sub layers. I think it does not. This needs a corresponding fix on the RETURNN side, but I think it should be simple. Edit: rwth-i6/returnn#984. Edit: Merged.
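To make the sub layer question concrete, a hedged sketch (reusing slow_rnn and mask from the examples in this thread):

with nn.MaskedComputation(mask=mask):
    y, h = slow_rnn(x, loop.state.h)  # y and h are separate sub outputs of the wrapped layer
# Accessing y or h individually from outside the masked computation requires
# RETURNN to resolve sub layers of MaskedComputationLayer, which is what the fix addresses.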

albertz commented Mar 10, 2022

The design is actually not really so clear here.

The code without such explicit nn.MaskedComputation is simple and straight-forward:

loop = nn.Loop(...)
loop.state.y = ...  # some initial output
loop.state.h = ...  # some initial state
with loop:

  mask = ...  # dtype bool, shape [batch] or whatever, for current frame
  y_, h_ = slow_rnn(x, loop.state.h)
  # keep the previous output/state in frames where the mask is False
  loop.state.y = nest.map_structure(lambda a, b: nn.where(cond=mask, x=a, y=b), y_, loop.state.y)
  loop.state.h = nest.map_structure(lambda a, b: nn.where(cond=mask, x=a, y=b), h_, loop.state.h)
  y = loop.state.y

But how exactly would we handle the state, the initial state, and the initial output with nn.MaskedComputation?

It should be straight-forward and non-ambiguous.

Maybe:

loop = nn.Loop(...)
loop.state.y = ...  # some initial output
loop.state.h = ...  # some initial state
with loop:

  mask = ...  # dtype bool, shape [batch] or whatever, for current (fast) frame
  with nn.MaskedComputation(mask=mask):
    loop.state.y, loop.state.h = slow_rnn(x, loop.state.h)
  y = loop.state.y  # access from outside
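The intended semantics would then presumably match the explicit nn.where variant above: in frames where mask is False, loop.state.y and loop.state.h keep their previous values, and the initial output and initial state come from the regular loop state initialization before the loop.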

albertz commented Mar 13, 2022

The test runs through now. While some aspects might need better handling (e.g. initial_state vs initial_output of RecLayer; it needs the latest RETURNN and only works with a zero initial state), I think this is good enough for now.

albertz closed this as completed Mar 13, 2022