
RF: Accumulate Dim in rf.scan for beam search #1327

Closed
albertz opened this issue May 16, 2023 · 2 comments
albertz commented May 16, 2023

In case of an eager-based backend (e.g. PyTorch), it's all quite straightforward:

  • You have an initial state [Batch, InitialBeam(1), HiddenDim].
  • You get logits [Batch, InitialBeam, Vocab].
  • You take the top-k, getting [Batch, Beam(k)] labels.
  • You get a new state [Batch, Beam(k), HiddenDim].

And so on. I.e. there is a loop over the labels. Let's assume the generic case: in each loop iteration, we can have a different beam dimension.
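For illustration, one such eager beam-expansion step can be sketched in plain NumPy (this is only a sketch of the concept, not the actual rf API; the helper name and signature are hypothetical):

```python
import numpy as np

def beam_step(log_probs, beam_scores, k):
    """One eager beam-search expansion step (illustrative sketch).

    log_probs: [Batch, Beam, Vocab] per-label log probabilities.
    beam_scores: [Batch, Beam] accumulated scores.
    Returns new scores [Batch, k], back refs [Batch, k] -> prev beam index,
    and labels [Batch, k].
    """
    batch, beam, vocab = log_probs.shape
    combined = beam_scores[:, :, None] + log_probs   # [Batch, Beam, Vocab]
    flat = combined.reshape(batch, beam * vocab)     # [Batch, Beam*Vocab]
    top_idx = np.argsort(-flat, axis=1)[:, :k]       # [Batch, k], best first
    new_scores = np.take_along_axis(flat, top_idx, axis=1)
    back_refs = top_idx // vocab   # which prev beam entry each hyp came from
    labels = top_idx % vocab       # which vocab label was chosen
    return new_scores, back_refs, labels
```

Note that `k` (and thus the beam dim of the outputs) can differ per iteration, which is exactly the generic case discussed here.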

How to handle this? If the label were an output of rf.scan, would this be possible? Currently not, because you must specify the output template per iteration, so sth like [Batch, Beam], but this would be wrong, as you would have a different beam dim in each iteration.

The current logic of RETURNN dim tags would solve this via the get_for_batch_ctx. However, we actually wanted to avoid this for the RETURNN frontend because of its complexity (#975, see discussion in #1120). With get_for_batch_ctx, there would be a single base Dim object which can have different values depending on the context (loop iteration, or outside the loop). Then a template like [Batch, Beam] would be correct. But I think it is clear that we do not want that complexity. We want that each beam dim is really different.

Actually, we don't really need to get that stacked together, because we want to resolve (backtrack) the beams afterwards anyway, to only have the final beam. So we would use rf.scan with return_tensor_arrays=True and then first resolve the beams. Stacked together, we then get [NumLabels, Batch, FinalBeam].
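The backtracking over the per-step tensor arrays could look like this NumPy sketch (hypothetical helper, not the rf implementation; plain lists of arrays stand in for the TensorArrays):

```python
import numpy as np

def backtrack(labels, back_refs):
    """Resolve the final beam from per-step labels and back refs (sketch).

    labels, back_refs: lists of [Batch, Beam_t] arrays, one per step;
    back_refs[t] maps a beam index at step t to a beam index at step t-1.
    Returns resolved labels [NumLabels, Batch, FinalBeam].
    """
    num_steps = len(labels)
    batch = labels[0].shape[0]
    final_beam = labels[-1].shape[1]
    # Start from the final beam indices and walk backwards.
    beam = np.tile(np.arange(final_beam)[None, :], (batch, 1))
    out = []
    for t in range(num_steps - 1, -1, -1):
        out.append(np.take_along_axis(labels[t], beam, axis=1))
        beam = np.take_along_axis(back_refs[t], beam, axis=1)
    return np.stack(out[::-1], axis=0)  # [NumLabels, Batch, FinalBeam]
```

Since the per-step beam sizes can differ, the arrays in `labels`/`back_refs` need not all have the same shape, which is why this works on the tensor arrays directly rather than on a stacked-and-padded tensor.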

Currently the TensorArray would in any case need some relaxation to not always enforce a match to the template, or at least allow different batch dims.

So, then all is now clear for eager-based execution.

However, how does this fit together with graph-based execution?

We could actually have an OuterBeam dim with dyn_size_ext [NumLabels] -> beam size. In rf.scan, we could allow dim tags in ys, and then the beam dim per iteration would be an output, resulting in this OuterBeam dim. This would unambiguously define such an OuterBeam dim and allow accumulating the labels into [NumLabels, Batch, OuterBeam] (although, as said before, you never really want to stack and pad this together -- you want to resolve it directly on the tensor array to the final beam). So, while this could be a solution, it's not really what we want or need in the end.
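To make the OuterBeam idea concrete: stacking per-step labels of varying beam size into [NumLabels, Batch, OuterBeam] with padding could look like the following plain-NumPy sketch (hypothetical; the returned beam sizes play the role of the dyn_size_ext [NumLabels] -> beam size):

```python
import numpy as np

def stack_with_outer_beam(labels, pad_value=0):
    """Stack per-step labels with varying beam sizes, padding to the max.

    labels: list of [Batch, Beam_t] arrays (Beam_t may differ per step).
    Returns the padded stack [NumLabels, Batch, max_beam] and the
    per-step beam sizes [NumLabels] (the dynamic size of OuterBeam).
    """
    beam_sizes = np.array([x.shape[1] for x in labels])
    max_beam = int(beam_sizes.max())
    batch = labels[0].shape[0]
    out = np.full((len(labels), batch, max_beam), pad_value,
                  dtype=labels[0].dtype)
    for t, x in enumerate(labels):
        out[t, :, : x.shape[1]] = x
    return out, beam_sizes
```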

Another question for graph-based backends: when body is executed (exactly once, at compilation time), how do we represent the prev beam dim? We have the initial beam dim as input for the initial state. But in the formulation above, this would be static of size 1. When we pass this to body, it then looks like the prev state has a beam dim with static size 1. This is clearly not always the case for the prev beam dim, so this would be wrong. We need to somehow transform the initial beam dim into some generic prev-beam-dim placeholder. But how could we do this? Can we even do it automatically?

We would need to specify the initial beam dim not as a static dim but with dyn_size_ext being a scalar. Then it is clear that it is dynamic and can change. We still must somehow require that the user specifies (e.g. via the state) that the prev beam dim gets transformed into the new beam dim. This solves it partly, but not fully: we would still have the problem that the initial beam dim is treated the same as the prev beam dim, which is not really correct. Or maybe not: if the beam dim itself is a state var, then it can know that it needs to transform the initial value into some generic prev-state-var placeholder value. So it could work.

Can we make sure that the user does it correctly, in a way that works correctly with both eager-based and graph-based backends? Actually this is simple: we just need to check that the dims of state vars do not change across iterations. We currently don't do this in the rf.while_loop eager implementation (but we should). But we do check it in rf.scan, via the TensorArray, via the template for ys.
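The missing eager-mode check could be as simple as this sketch (hypothetical helper; plain shapes stand in for dim tags here):

```python
def check_state_dims(prev_state, new_state):
    """Sketch of the check rf.while_loop could do in eager mode:
    the dims (here: plain shapes) of state vars must not change across
    iterations; any dim that does change must itself be a state var."""
    for key, prev in prev_state.items():
        new = new_state[key]
        if prev.shape != new.shape:
            raise ValueError(
                f"state var {key!r}: dims changed across iterations "
                f"({prev.shape} vs {new.shape}); such dims must "
                f"themselves be state vars"
            )
```

If the user forgets to put a changing dim (like the beam dim) into the state, this check fails on the second iteration, so the error is easy to catch in both eager and graph mode.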

Is this a good solution? Ok, let's assume we have the beam dim also as a state var. Now getting back to the specific example with beam search, and accumulating the labels in rf.scan. We still would use return_tensor_arrays=True. We can automatically update the tensor array template by those per-iteration dim tags. It looks to me like this could work. But maybe I'm overlooking sth?

To summarize, for the user it just means: any dims which get updated in each iteration must be part of the state vars. Nothing else. If the user does it wrong, it is easy to catch. This is still reasonably straightforward. So all good?

albertz commented May 17, 2023

Some follow-up: While implementing an initial version of beam search, I realize there is still some ambiguity for the back refs. They are usually of shape [Batch, CurBeam] -> PrevBeam, and we would store them for all loop iterations, such that we can later go over them to resolve the final beam. The problem is that we have both CurBeam and PrevBeam here.

albertz commented May 17, 2023

One solution to this: Just be more relaxed for the sparse dim, allow either the current or the prev.
