
RF while_loop API extensions #1324

Closed
albertz opened this issue May 8, 2023 · 1 comment
albertz commented May 8, 2023

The current RETURNN-frontend (RF, #1120) while_loop has a very bare API, mapping directly to the underlying control flow logic, e.g. a Python while loop with a conditional ending.

The RecLayer on the other side is a bit higher level, e.g. it can automatically iterate over some given input, automatically collect the results from each frame for the output, and then also handles dimension tags better.

We probably also want to have such higher-level functions in the RF.

What exactly should the API look like?

Other frameworks like Theano, TF or JAX also have scan, which does the automatic iteration and automatic stacking.

TF provides the TensorArray to make automatic iteration and stacking efficient. The naive variant with gathering and concatenating would be inefficient with backprop. I was also checking JAX while_loop and scan, and it seems JAX does not have such a mechanism but instead uses dynamic_index_in_dim/dynamic_slice and dynamic_update_index_in_dim/dynamic_update_slice. I wonder how that can be made efficient with backprop. (StackOverflow question; related: jax-ml/jax#3106.)

JAX scan also supports pytrees for xs and ys. xs can even be None when you specify length. So it is somewhat similar to our RecLayer, but still less generic: in our RecLayer, we can also have a dynamic ending condition.
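To illustrate these scan semantics, here is a hypothetical plain-Python sketch (not JAX code; scan_sketch and all names are made up) of the carry threading and automatic stacking that scan performs:

```python
# Hypothetical plain-Python sketch of jax.lax.scan-style semantics:
# thread a carry through the loop and stack the per-step outputs.
def scan_sketch(f, init, xs=None, length=None):
    """f maps (carry, x) -> (new_carry, y); returns (final_carry, ys)."""
    if xs is None:
        xs = [None] * length  # xs may be None if a length is specified
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)  # automatic stacking of the per-frame outputs
    return carry, ys

# Example: running sum; the carry is the cumulative total.
final, ys = scan_sketch(lambda c, x: (c + x, c + x), 0, xs=[1, 2, 3])
# final == 6, ys == [1, 3, 6]
```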

But we don't necessarily need to replicate RecLayer exactly for the RF. What cases do we actually need, and what would be the most convenient API?

Some wanted features:

  • Iterate over existing dim. spatial_dim option.
  • Scan over some input (or inputs) over some dimension. xs and spatial_dim options.
  • Collect and stack outputs. In that case, the outputs ys and the spatial_dim would be returned.
  • Potential dynamic continue condition (dynamic end). Not just a scalar like cond in while_loop, but possibly a more complex tensor, e.g. of shape [B]. Needed if the given spatial_dim is None or undefined.

This extended functionality should be implemented on top of the lower-level while_loop.

The extended implementation maybe should have a separate API, to keep while_loop simple. The while_loop docstring can refer to it.

Note: xs and ys can be any nested structure. Thus we again need #1314 here.


So, potential new API:

S = TypeVar("S")
X = TypeVar("X")  # any nested structure, can be None
Y = TypeVar("Y")  # any nested structure, can be None

def scan(
    *,
    spatial_dim: Optional[Dim] = None,
    initial: S = None,
    xs: X = None,
    cond: Optional[Callable[[S, X], Tensor]] = None,
    body: Callable[[S, X], Tuple[S, Y]],
) -> Tuple[S, Y, Dim]:
    ...
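To pin down the intended semantics, here is a hypothetical eager reference implementation of this proposed API in plain Python, with lists standing in for tensors and TensorArray, and an int length standing in for spatial_dim. This is only a sketch of the semantics, not the real RF implementation:

```python
# Hypothetical eager reference implementation of the proposed scan API.
# Plain Python lists stand in for tensors and TensorArray, and an int
# length stands in for spatial_dim; this only pins down the semantics.
def scan(*, spatial_dim=None, initial=None, xs=None, cond=None, body):
    s, ys_acc, i = initial, [], 0
    while True:
        if spatial_dim is not None and i >= spatial_dim:
            break  # iterate over the given (existing) dim
        x = xs[i] if xs is not None else None
        if cond is not None and not cond(s, x):
            break  # dynamic continue condition, checked at the beginning
        s, y = body(s, x)
        ys_acc.append(y)  # collect and stack the outputs
        i += 1
    return s, ys_acc, i  # final state, stacked ys, resulting spatial dim

# Example: running maximum over a given input and dim.
final, ys, n = scan(
    spatial_dim=4, initial=0, xs=[3, 1, 4, 1],
    body=lambda s, x: (max(s, x), max(s, x)))
# final == 4, ys == [3, 3, 4, 4], n == 4
```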

It is somewhat unfortunate that xs and ys are disconnected from the body function. But maybe that is not such a problem.

Note that end in RecLayer is different than cond in while_loop in multiple ways:

  • It signals whether the sequence ended, not whether to continue the loop.
  • It is not just a scalar but normally e.g. of shape [B].
  • It is evaluated at the middle or end of the iteration, not at the beginning. (include_eos)

The cond here is also different:

  • It is not just a scalar but normally e.g. of shape [B].
  • It is like a mask. Like while_loop cond, it signals whether to continue the sequence.
  • It is evaluated at the beginning, just like while_loop cond.

I think we do not need such an include_eos option here. We just always include all frames until cond is False.
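A hypothetical plain-Python sketch of these [B]-shaped cond semantics (lists stand in for tensors; masked_loop and max_steps are made-up names): the loop runs while any sequence still continues, entries whose cond turned False keep their state, and all frames until cond is False are counted:

```python
# Hypothetical sketch of the [B]-shaped cond semantics, with plain
# Python lists standing in for tensors. The loop runs while any
# sequence still continues; entries whose cond turned False keep
# their state, and all frames until cond is False are counted.
def masked_loop(*, cond, body, state, batch_size, max_steps=100):
    seq_lens = [0] * batch_size
    active = [True] * batch_size
    for _ in range(max_steps):
        c = cond(state)  # per-batch continue mask of "shape" [B]
        active = [a and ci for a, ci in zip(active, c)]
        if not any(active):
            break
        new_state = body(state)
        # only still-active entries get updated (masking)
        state = [ns if a else s for ns, s, a in zip(new_state, state, active)]
        seq_lens = [l + 1 if a else l for l, a in zip(seq_lens, active)]
    return state, seq_lens

# Example: per-batch counters, continue while the value is below 3.
state, seq_lens = masked_loop(
    cond=lambda s: [v < 3 for v in s],
    body=lambda s: [v + 1 for v in s],
    state=[0, 2], batch_size=2)
# state == [3, 3], seq_lens == [3, 1]
```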

Or alternatively, we can put most of the things into a LoopCtx helper object, which simplifies the API:

def scan(
    *,
    spatial_dim: Optional[Dim] = None,
    initial: S = None,
    body: Callable[[LoopCtx], None],
) -> Tuple[LoopCtx, Dim]:
    ...

class LoopCtx(Generic[S]):
    def unstack(self, source: Tensor, *, key: Any) -> Tensor:
        ...

    def stack(self, source: Tensor, *, key: Any) -> None:
        ...

    @property
    def state(self) -> S:
        ...

    def cond(self, source: Tensor) -> None:
        ...

    @property
    def stacked(self) -> Dict[Any, Tensor]:
        ...

The user can call unstack to get values from outside, like xs above. The key is needed to be able to identify the same tensor across different iterations. For an eager-based backend, the first call would unstack the input, and then in each iteration it would return the corresponding value.
The user can call stack to accumulate frames for the outputs, like ys. The key is for the dict key in the final stacked.
The user can access state to get the current state, and also write to that to update it.
The user can call cond to signal the loop condition.
The final returned loop ctx can be used to get the final state, and the stacked values in stacked.
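A hypothetical eager sketch of this LoopCtx proposal (plain Python lists stand in for tensors, state is a plain attribute rather than a property, and everything here is just for illustration):

```python
# Hypothetical eager sketch of the LoopCtx proposal. Plain Python
# lists stand in for tensors; state is a plain attribute rather than
# a property; all names follow the API sketched above.
class LoopCtx:
    def __init__(self, initial):
        self.state = initial  # current loop state, readable and writable
        self._step = 0
        self._stacked = {}
        self._continue = True

    def unstack(self, source, *, key):
        # key identifies the same tensor across iterations; in this
        # eager sketch we simply index the current frame
        return source[self._step]

    def stack(self, source, *, key):
        # accumulate one frame for the output under the given dict key
        self._stacked.setdefault(key, []).append(source)

    def cond(self, source):
        self._continue = bool(source)  # signal whether to continue

    @property
    def stacked(self):
        return self._stacked


def scan(*, spatial_dim=None, initial=None, body):
    ctx = LoopCtx(initial)
    while ctx._continue and (spatial_dim is None or ctx._step < spatial_dim):
        body(ctx)
        ctx._step += 1
    return ctx, ctx._step  # final ctx and the resulting spatial dim

# Example: cumulative sum over xs, stacked under the key "cumsum".
xs = [1, 2, 3]

def _body(ctx):
    x = ctx.unstack(xs, key="x")
    ctx.state = ctx.state + x
    ctx.stack(ctx.state, key="cumsum")

ctx, n = scan(spatial_dim=len(xs), initial=0, body=_body)
# ctx.state == 6, ctx.stacked["cumsum"] == [1, 3, 6], n == 3
```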

The underlying backend is also relevant. E.g. when using JAX, we probably need to map to scan for efficiency (jax-ml/jax#15906), and for that, we need to know the xs in advance. I'm not sure if the second proposed API here is really possible. The first API is in any case more straightforward, even if maybe less nice.


Here are some use cases:

Training of an attention-based encoder-decoder model. This uses a loop over the decoder targets (not needed for a Transformer decoder, but required once there is recurrent state). In each iteration, it gets the previous target (or its embedding) as input and calculates the logits, and in the end we want the cross entropy (CE) to the current target.

In this example, we know the spatial_dim in advance. xs would be the prev target embeddings. ys would be the readout input.

Example code:

    def _body(state: rf.State):
        new_state = rf.State()
        logits, new_state.decoder = model.decode(
            **enc_args,
            enc_spatial_dim=enc_spatial_dim,
            prev_nb_target=state.label,
            prev_nb_target_spatial_dim=single_step_dim,
            state=state.decoder,
        )

        new_state.label = loop.unstack(targets)
        new_state.logits_stack = state.logits_stack.stack(logits)  # TODO?
        return new_state

    final_state = rf.while_loop(
        cond=lambda i, _: i < targets_spatial_dim.get_dim_value_tensor(),  # TODO?
        body=_body,
        loop_vars=rf.State(
            decoder=model.decoder_default_initial_state(batch_dims=batch_dims, enc_spatial_dim=enc_spatial_dim),
            label=rf.constant(model.bos_idx, shape=batch_dims, sparse_dim=model.nb_target_dim),
            logits_stack=rf.TensorArray(),  # TODO?
        ),
    )

In the case of recognition with beam search for AED, the end condition would be when the last label is EOS.
Whether you want to include the EOS label in the output or not (include_eos) is somewhat up to the user. If the user cares about the logits or log-probs of each label, then it is probably relevant to also have it for that label.
This end would be evaluated at the end of an iteration, on the predicted label. Or it could check at the beginning of an iteration (like while_loop cond), in which case it would check the last predicted label. So both variants are possible. When checking the previous label, we need to distinguish BOS from EOS.

Another example is transducer. In search, the ending condition at the end of an iteration is when t == T - 1 and y == BLANK, or the ending condition at the beginning of an iteration is t >= T.

Another example is CTC or RNA. We have xs as the encoder output, so no ending condition is needed, as the spatial dim is given.
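For this CTC/RNA case, scanning degenerates to mapping a per-frame step over the given encoder frames while threading state; a trivial plain-Python sketch with made-up toy integer values:

```python
# Trivial sketch of the CTC/RNA case with made-up toy integer values:
# the spatial dim is given by the encoder output, so scanning is just
# mapping a per-frame step over the frames while threading state.
enc = [[1, 9], [8, 2], [5, 5]]  # toy "encoder output" of shape [T, F]

def step(state, frame):
    # dummy per-frame computation standing in for the output projection;
    # state is unused for pure CTC but kept for the general scan form
    return state, [2 * v for v in frame]

state, logits = 0, []
for frame in enc:        # iterate over the existing encoder dim; no cond
    state, y = step(state, frame)
    logits.append(y)     # stacked ys share the encoder spatial dim
# logits == [[2, 18], [16, 4], [10, 10]]
```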


albertz commented May 10, 2023

I now implemented the first proposed API. I also introduced TensorArray. And I fixed #1314 along the way.
