From 04d135a0b66c448830232d33965cc237719c52c0 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 6 Oct 2021 10:49:10 +0200 Subject: [PATCH 1/5] DimensionTag.is_dim_known --- returnn/tf/util/data.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index f4c76e974..3aaf9c4ca 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -361,6 +361,24 @@ def is_spatial_dim(self): """ return self.kind == DimensionTag.Types.Spatial + def is_dim_known(self): + """ + :return: whether we know the dimension; basically whether this is defined + (although `not self.undefined` is defined slightly differently) + :rtype: bool + """ + if self.is_batch_dim(): + return True + if self.dimension is not None: + return True + if self.dyn_size_ext: + return True + base = self.get_same_base() + for _, other in base._same_for_batch_ctx.items(): + if other.dyn_size_ext: + return True + return False + def is_same_size_tensor(self, x): """ :param tf.Tensor x: From 090f78ae128c07d62b2176037afb1ff064885fa5 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 6 Oct 2021 10:50:38 +0200 Subject: [PATCH 2/5] Data.get_dim_tag_from_description --- returnn/tf/util/data.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 3aaf9c4ca..809baa0ed 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -3192,6 +3192,15 @@ def get_axes_from_description(self, axes, allow_int=True): res.append(i) return res + def get_dim_tag_from_description(self, axis): + """ + :param str|DimensionTag axis: + :return: our matching dim tag. this assumes it exists. + :rtype: DimensionTag + """ + axis_int = self.get_axis_from_description(axis, allow_int=False) + return self.dim_tags[axis_int] + def get_axis_from_description(self, axis, allow_int=True): """ :param int|str|DimensionTag axis: From 03298f22c54b953615d8bbaaf3f30a61c6fba885 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 6 Oct 2021 10:50:55 +0200 Subject: [PATCH 3/5] Data.mark_same_time extended --- returnn/tf/util/data.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 809baa0ed..c266b55fa 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -3399,18 +3399,24 @@ def get_static_axes(self): return [axis for axis, dim in enumerate(self.batch_shape) if axis != self.batch_dim_axis and dim is not None] - def mark_same_time(self, tags): + def mark_same_time(self, tags, must_match=False): """ If the given dimension tag matches any of our axes, we set our time axis to the selected one. - :param set[DimensionTag] tags: + :param set[DimensionTag]|DimensionTag tags: + :param bool must_match: if True, throw an exception if not found :return: whether we have found the same :rtype: bool """ + if isinstance(tags, DimensionTag): + tags = {tags} + assert all(isinstance(tag, DimensionTag) for tag in tags) for axis, dim_tag in enumerate(self.dim_tags): if dim_tag in tags: self.time_dim_axis = axis return True + if must_match: + raise Exception("%s mark_same_time: %s not found" % (self, tags)) return False def is_same_time_dim(self, other): From d44464a7e1540eb50f2117dba2bff77f9646c239 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 28 Jun 2021 01:06:43 +0200 Subject: [PATCH 4/5] RecUnstackLayer --- returnn/tf/layers/rec.py | 56 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index d71cfb87c..a9612615b 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -4525,6 +4525,62 @@ def get_out_data_from_opts(cls, name, sources, sub_layer, **kwargs): return subnet.layers[sub_layer].output +class RecUnstackLayer(LayerBase): + """ + This is supposed to be used inside a :class:`RecLayer`. + The input is supposed to be outside the rec layer (i.e. via ``base:``). + Uses tf.TensorArray and then unstack on the inputs to make it available per-frame. + This is an alternative to making some input to the rec layer, + such that the rec layer can have multiple inputs (as long as they have the same time dim). + + Note that due to automatic optimization, this layer will be optimized out of the rec loop anyway, + and then the tf.TensorArray logic happens internally in RecLayer, + thus we do not need to care about this here. + (See get_input_moved_out for some internal handling.) + + Effectively, this layer is very similar to :class:`CopyLayer`, + with the only special behavior that it assigns the loop dimension of RecLayer. + """ + layer_class = "rec_unstack" + + def __init__(self, axis, **kwargs): + """ + :param str|DimensionTag axis: + """ + axis # noqa # unused here, used in get_out_data_from_opts + super(RecUnstackLayer, self).__init__(**kwargs) + assert len(self.sources) == 1 + src = self.sources[0].output + rec_time_dim = self.network.get_inside_rec_time_dim() + if rec_time_dim: + raise NotImplementedError("%s: We expect that this layer is always optimized out." % self) + self.output.placeholder = src.placeholder + + @classmethod + def get_out_data_from_opts(cls, name, axis, sources, network, **kwargs): + """ + :param str name: + :param str|DimensionTag axis: + :param list[LayerBase] sources: + :param returnn.tf.network.TFNetwork network: + :rtype: Data + """ + assert sources + out = sources[0].output.copy_template(name="%s_output" % name) + out_dim = out.get_dim_tag_from_description(axis) + rec_time_dim = network.get_inside_rec_time_dim() + if rec_time_dim: + if rec_time_dim.is_dim_known(): # defined + assert out_dim == rec_time_dim + else: + rec_time_dim.declare_same_as(out_dim) + out.mark_same_time(out_dim, must_match=True) + return out.copy_template_excluding_time_dim() + else: + out.mark_same_time(out_dim, must_match=True) + return out + + class BaseChoiceLayer(LayerBase): """ This is a base-class for any layer which defines a new search choice, From 3f179ebb6a08f2377154ee8a944f52ddd3b21868 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 6 Oct 2021 11:09:50 +0200 Subject: [PATCH 5/5] StackLayer, extend doc --- returnn/tf/layers/basic.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index d47dada17..80535df8f 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -5036,6 +5036,9 @@ def get_out_data_from_opts(cls, axis, enforce_batch_dim_axis=None, allow_no_op=F class StackLayer(LayerBase): """ Stacks multiple inputs together using :func:`tf.stack`. + This creates a new dimension for the stack. + + For concatenation (in feature dimension), see :class:`CopyLayer`. """ layer_class = "stack"