The current RETURNN-frontend (RF, #1120) `while_loop` has a very bare API, mapping directly to the underlying control-flow logic, i.e. a Python while loop with a conditional ending.
The `RecLayer`, on the other hand, is a bit higher level: e.g. it can automatically iterate over some given input, automatically collect the results from each frame for the output, and it also handles dimension tags better.
We probably also want to have such higher-level functions in the RF.
What exactly should it look like?
Other frameworks like Theano, TF or JAX also have `scan`, which does the automatic iteration and automatic stacking.
TF provides the `TensorArray` to make automatic iteration and stacking efficient; the naive variant with gathering and concatenating would be inefficient with backprop. I was also checking JAX `while_loop` and `scan`, and it seems JAX does not have such a mechanism but instead uses `dynamic_index_in_dim`/`dynamic_slice` and `dynamic_update_index_in_dim`/`dynamic_update_slice`. I wonder how it can make that efficient with backprop. (StackOverflow question; related: jax-ml/jax#3106.)
JAX `scan` also supports pytrees for `xs` and `ys`. `xs` can even be None when you specify `length`. So it is somewhat similar to our `RecLayer`, but still less generic. In our `RecLayer`, we can also have a dynamic ending condition.
But we don't necessarily need to replicate `RecLayer` exactly for the RF. What cases do we actually need, and what would be the most convenient API?
Some wanted features:

- Iterate over an existing dim: `spatial_dim` option.
- Scan over some input (or inputs) over some dimension: `xs` and `spatial_dim` options.
- Collect and stack outputs. In that case, the outputs `ys` and the `spatial_dim` would be returned.
- Potential dynamic continue condition (dynamic end). Not just like `cond` in `while_loop` (returning a scalar) but it might also return a more complex tensor, e.g. of shape [B]. Needed if the given `spatial_dim` is None or undefined.
This extended implementation should be a wrapper around the lower-level `while_loop`.
The extended implementation should maybe have a separate API, to keep `while_loop` simple. The `while_loop` docstring can refer to it.
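To make the relationship concrete, here is a plain-Python mock of how such an extended `scan` could wrap the bare `while_loop`. No RETURNN is involved: Python lists stand in for tensors and TensorArrays, and the exact signature is only an illustrative sketch, not the actual RF API.

```python
from typing import Any, Callable, List, Optional, Tuple


def while_loop(cond: Callable[[Any], bool], body: Callable[[Any], Any], initial: Any) -> Any:
    """Bare while-loop primitive, in the spirit of the existing RF API (mocked here)."""
    state = initial
    while cond(state):
        state = body(state)
    return state


def scan(
    *,
    xs: Optional[List[Any]] = None,
    initial: Any = None,
    cond: Optional[Callable[[Any, Any], bool]] = None,
    body: Callable[[Any, Any], Tuple[Any, Any]] = None,
) -> Tuple[Any, List[Any]]:
    """Higher-level scan wrapping while_loop: iterates, threads state, stacks outputs."""
    n = len(xs) if xs is not None else None

    def loop_cond(loop_state):
        i, s, _ = loop_state
        if n is not None and i >= n:
            return False  # iterated over the whole given input
        x = xs[i] if xs is not None else None
        return cond(s, x) if cond is not None else True  # dynamic ending condition

    def loop_body(loop_state):
        i, s, ys = loop_state
        s, y = body(s, xs[i] if xs is not None else None)
        return i + 1, s, ys + [y]  # stack the per-frame output

    _, final_state, ys = while_loop(loop_cond, loop_body, (0, initial, []))
    return final_state, ys


# Iterating over a given input (known length): cumulative sum.
state, ys = scan(xs=[1, 2, 3, 4], initial=0, body=lambda s, x: (s + x, s + x))
# Dynamic ending condition instead of a given length:
state2, ys2 = scan(initial=0, cond=lambda s, x: s < 3, body=lambda s, x: (s + 1, s))
```

The point is only that the loop bookkeeping (index, stacking, stop logic) lives in the wrapper, while the user writes just `body` (and optionally `cond`).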
Note: `xs` and `ys` can be any nested structure. Thus we again need #1314 here.
So, potential new API:
```python
S = TypeVar("S")
X = TypeVar("X")  # any nested structure, can be None
Y = TypeVar("Y")  # any nested structure, can be None


def scan(
    *,
    spatial_dim: Optional[Dim] = None,
    initial: S = None,
    xs: X = None,
    cond: Optional[Callable[[S, X], Tensor]] = None,
    body: Callable[[S, X], Tuple[S, Y]],
) -> Tuple[S, Y, Dim]:
    ...
```
It's a bit bad that `xs` and `ys` are somewhat disconnected from the `body` function. But maybe that's not much of a problem.
Note that `end` in `RecLayer` is different from `cond` in `while_loop` in multiple ways:

- It signals whether the sequence has ended, not whether to continue the loop.
- It is not just a scalar but normally e.g. of shape [B].
- It is evaluated in the middle or at the end of an iteration, not at the beginning (`include_eos`).
The `cond` here is also different:

- It is not just a scalar but normally e.g. of shape [B].
- It is like a mask. Like `while_loop`'s `cond`, it signals whether to continue the sequence.
- It is evaluated at the beginning, just like `while_loop`'s `cond`.
I think we don't need such an `include_eos` option here. We just always include all frames until `cond` is False.
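A toy plain-Python illustration of this mask-like, begin-of-iteration `cond` of shape [B] (the batch is just a Python list here; all names and the `max_steps` safeguard are made up): each batch entry collects frames while its mask entry is still True, and the scalar loop condition is whether any entry is still active.

```python
def batched_scan(initial, cond, body, max_steps=100):
    """Toy batched loop: cond returns a per-batch mask [B], checked at the
    start of each iteration; frames are collected only while the mask is True."""
    state = list(initial)
    active = [True] * len(state)  # mask of shape [B]
    ys = [[] for _ in state]      # collected frames per batch entry
    for _ in range(max_steps):
        active = [a and c for a, c in zip(active, cond(state))]
        if not any(active):       # scalar loop condition: any entry still active?
            break
        for b in range(len(state)):
            if active[b]:
                state[b], y = body(state[b])
                ys[b].append(y)
    return state, ys


# Two "sequences" end at different steps: entry 0 counts to 3, entry 1 to 5,
# so the collected outputs have different lengths (a dynamic spatial dim).
limits = [3, 5]
state, ys = batched_scan(
    initial=[0, 0],
    cond=lambda s: [s[b] < limits[b] for b in range(len(s))],
    body=lambda s: (s + 1, s + 1),
)
```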
Alternatively, we can put most of these things into a `LoopCtx` helper object, which simplifies the API:

- The user can call `unstack` to get values from outside, like `xs` above. The `key` is needed to be able to identify the same tensor in different iterations. For an eager-based backend, the first call would unstack the input, and then each iteration would return the corresponding value.
- The user can call `stack` to accumulate frames for the outputs, like `ys`. The `key` is used as the dict key in the final `stacked`.
- The user can access `state` to get the current state, and can also write to it to update it.
- The user can call `cond` to signal the loop condition.
- The final returned loop ctx can be used to get the final state, and the stacked values in `stacked`.
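What such a `LoopCtx` could look like for an eager backend, as a purely hypothetical plain-Python sketch (the method names follow the proposal above, but the implementation and the `run_loop` driver are made up):

```python
class LoopCtx:
    """Hypothetical eager-mode loop context (plain-Python sketch)."""

    def __init__(self, xs=None):
        self._xs = xs or {}  # key -> full input sequence (already "unstacked")
        self._i = 0          # current iteration index
        self.state = {}      # read/write loop state
        self.stacked = {}    # key -> accumulated output frames
        self._running = True

    def unstack(self, key):
        """Return the current frame of the input registered under key."""
        return self._xs[key][self._i]

    def stack(self, key, value):
        """Accumulate one output frame under key."""
        self.stacked.setdefault(key, []).append(value)

    def cond(self, keep_going):
        """Signal whether the loop should continue after this iteration."""
        self._running = self._running and keep_going


def run_loop(body, xs=None, max_steps=100):
    """Drive body(ctx) until ctx.cond(False) is signaled or inputs run out."""
    ctx = LoopCtx(xs=xs)
    n = min(len(v) for v in ctx._xs.values()) if ctx._xs else max_steps
    while ctx._running and ctx._i < n:
        body(ctx)
        ctx._i += 1
    return ctx


# Example: running sum over an input sequence via the ctx API.
def body(ctx):
    x = ctx.unstack("x")
    s = ctx.state.get("sum", 0) + x
    ctx.state["sum"] = s
    ctx.stack("y", s)


ctx = run_loop(body, xs={"x": [1, 2, 3]})
```

Note how here the inputs, outputs and state all flow through the ctx object instead of the function signature, which is exactly what makes this variant harder to map onto a graph-mode `scan` that needs `xs` declared up front.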
The underlying backend is also relevant. E.g. when using JAX, we probably need to map to `scan` for efficiency (jax-ml/jax#15906), and for that, we need to know the `xs` in advance. I'm not sure whether the second proposed API here is really possible then. The first API is in any case more straightforward, even if maybe less nice.
Here are some use cases:
Training of an attention-based encoder-decoder (AED) model. The loop goes over the targets for the decoder (not needed for a Transformer decoder, but once there is some state, it is needed). In one iteration, it gets the prev target (or its embedding) as input, then calculates the logits, and in the end we want to get the CE loss w.r.t. the current target.
In this example, we know the `spatial_dim` in advance. `xs` would be the prev target embeddings. `ys` would be the readout input.
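A plain-Python toy of this training use case in the shape of the proposed `scan` signature (the numbers and the internals of `body` are made up; a scalar stands in for the decoder state, and the loop is written out since `spatial_dim` is known):

```python
def body(state, x):
    """Toy decoder step: x is the prev-target 'embedding', state is the
    recurrent state, and the returned readout is the per-frame output y."""
    new_state = 0.5 * state + x  # stand-in for the recurrent update
    readout = new_state * 2.0    # stand-in for the readout network
    return new_state, readout


prev_target_embed = [1.0, 0.0, 2.0]  # xs; length = known spatial_dim
state, ys = 0.0, []
for x in prev_target_embed:          # no dynamic cond needed here
    state, y = body(state, x)
    ys.append(y)                     # ys: the stacked readout inputs
```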
In the case of recognition with beam search for AED, the `end` condition would be that the last label is EOS.
Whether you want to include the EOS label in the output or not (`include_eos`) is somewhat up to the user. If the user cares about the logits or log-probs of each label, then it is probably relevant to have them for that label as well.
This `end` condition would be evaluated at the end of an iteration, on the predicted label. Alternatively, it could be checked at the beginning of an iteration (like `while_loop`'s `cond`), on the last prev label. So both variants are possible. When checking the prev label, we need to distinguish BOS from EOS.
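Both placements can be sketched in plain Python with greedy decoding (the EOS/BOS values and the toy `predict` model are made up for illustration); note that both variants produce the same output, and the begin-of-iteration variant only works because BOS and EOS are distinct:

```python
EOS, BOS = 0, -1


def decode_end_check(predict):
    """Variant 1: check the predicted label at the end of each iteration."""
    ys, prev = [], BOS
    while True:
        y = predict(prev)
        ys.append(y)  # include_eos behavior: the EOS frame is collected too
        if y == EOS:
            break
        prev = y
    return ys


def decode_begin_check(predict, max_len=100):
    """Variant 2: check the prev label at the beginning of each iteration.
    BOS must be distinguishable from EOS, else the loop would never start."""
    ys, prev = [], BOS
    for _ in range(max_len):
        if prev == EOS:
            break
        y = predict(prev)
        ys.append(y)
        prev = y
    return ys


# Toy "model": after BOS emit 3, then 2, then 1, then EOS.
script = {BOS: 3, 3: 2, 2: 1, 1: EOS}
predict = script.__getitem__
```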
Another example is the transducer. In search, the ending condition at the end of an iteration is `t == T - 1` and `y == BLANK`, or equivalently, the ending condition at the beginning of an iteration is `t >= T`.
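A plain-Python toy of the two transducer stopping variants (the "model" is a fixed script of emissions, and all names are illustrative), showing that both visit exactly the same frames, with blank advancing the time pointer:

```python
BLANK = 0


def transducer_search_begin(emit, T):
    """Stop at the beginning of an iteration once t >= T."""
    t, ys = 0, []
    while t < T:
        y = emit(t, len(ys))
        ys.append(y)
        if y == BLANK:
            t += 1  # blank advances the time pointer
    return ys


def transducer_search_end(emit, T):
    """Stop at the end of an iteration once t == T - 1 and y == BLANK."""
    t, ys = 0, []
    while True:
        y = emit(t, len(ys))
        ys.append(y)
        if t == T - 1 and y == BLANK:
            break
        if y == BLANK:
            t += 1
    return ys


# Toy deterministic "model": a fixed script of outputs per step, T = 3 frames.
script = [5, BLANK, 6, 7, BLANK, BLANK]
emit = lambda t, step: script[step]
```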
Another example is CTC or RNA. We have `xs` for the encoder output, so no ending condition is needed, as the spatial dim is given.