RecLayer multiple inputs via explicit unstacking #552

Merged 5 commits on Oct 6, 2021
3 changes: 3 additions & 0 deletions returnn/tf/layers/basic.py
@@ -5036,6 +5036,9 @@ def get_out_data_from_opts(cls, axis, enforce_batch_dim_axis=None, allow_no_op=F
class StackLayer(LayerBase):
"""
Stacks multiple inputs together using :func:`tf.stack`.
This creates a new dimension for the stack.

For concatenation (in feature dimension), see :class:`CopyLayer`.
"""
layer_class = "stack"

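For illustration, a minimal config sketch of the distinction the added docstring draws (not part of this PR; the layer names "a" and "b" are hypothetical):

fragment = {
    # "stack" introduces a new dim over its inputs, as tf.stack does
    "stacked": {"class": "stack", "from": ["a", "b"]},
    # "copy" with multiple sources concatenates them in the feature dim
    "concat": {"class": "copy", "from": ["a", "b"]},
}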
56 changes: 56 additions & 0 deletions returnn/tf/layers/rec.py
@@ -4525,6 +4525,62 @@ def get_out_data_from_opts(cls, name, sources, sub_layer, **kwargs):
return subnet.layers[sub_layer].output


class RecUnstackLayer(LayerBase):
"""
This is supposed to be used inside a :class:`RecLayer`.
The input is supposed to be outside the rec layer (i.e. via ``base:``).
Uses tf.TensorArray and unstacks the input to make it available per-frame.
This is an alternative to feeding the input directly to the rec layer,
and it allows the rec layer to have multiple inputs (as long as they all have the same time dim).

Note that due to automatic optimization, this layer will be optimized out of the rec loop anyway,
and then the tf.TensorArray logic happens internally in RecLayer,
thus we do not need to care about this here.
(See get_input_moved_out for some internal handling.)

Effectively, this layer is very similar to :class:`CopyLayer`,
with the only special behavior being that it assigns the loop dimension of the :class:`RecLayer`.
"""
layer_class = "rec_unstack"

def __init__(self, axis, **kwargs):
"""
:param str|DimensionTag axis:
"""
axis # noqa # unused here, used in get_out_data_from_opts
super(RecUnstackLayer, self).__init__(**kwargs)
assert len(self.sources) == 1
src = self.sources[0].output
rec_time_dim = self.network.get_inside_rec_time_dim()
if rec_time_dim:
raise NotImplementedError("%s: We expect that this layer is always optimized out." % self)
self.output.placeholder = src.placeholder

@classmethod
def get_out_data_from_opts(cls, name, axis, sources, network, **kwargs):
"""
:param str name:
:param str|DimensionTag axis:
:param list[LayerBase] sources:
:param returnn.tf.network.TFNetwork network:
:rtype: Data
"""
assert sources
out = sources[0].output.copy_template(name="%s_output" % name)
out_dim = out.get_dim_tag_from_description(axis)
rec_time_dim = network.get_inside_rec_time_dim()
if rec_time_dim:
if rec_time_dim.is_dim_known(): # defined
assert out_dim == rec_time_dim
else:
rec_time_dim.declare_same_as(out_dim)
out.mark_same_time(out_dim, must_match=True)
return out.copy_template_excluding_time_dim()
else:
out.mark_same_time(out_dim, must_match=True)
return out


class BaseChoiceLayer(LayerBase):
"""
This is a base-class for any layer which defines a new search choice,
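For illustration, a hedged network-config sketch of how the new layer could be used (not part of this PR; apart from "rec_unstack", the "axis" option, and the "base:" prefix, all layer names, dims, and options here are assumptions):

network = {
    "input_a": {"class": "linear", "activation": "tanh", "n_out": 64, "from": "data"},
    "input_b": {"class": "linear", "activation": "tanh", "n_out": 32, "from": "data"},  # same time dim as input_a
    "output": {
        "class": "rec", "from": "input_a",
        "unit": {
            # unstack the outside layer along its time dim, one frame per loop iteration
            "b_frame": {"class": "rec_unstack", "axis": "T", "from": "base:input_b"},
            "joint": {"class": "linear", "activation": "relu", "n_out": 128,
                      "from": ["data:source", "b_frame"]},
            "output": {"class": "softmax", "n_out": 10, "from": "joint"},
        },
    },
}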
37 changes: 35 additions & 2 deletions returnn/tf/util/data.py
@@ -361,6 +361,24 @@ def is_spatial_dim(self):
"""
return self.kind == DimensionTag.Types.Spatial

def is_dim_known(self):
"""
:return: whether we know the dimension; basically whether this is defined
(although `not self.undefined` is defined slightly differently)
:rtype: bool
"""
if self.is_batch_dim():
return True
if self.dimension is not None:
return True
if self.dyn_size_ext:
return True
base = self.get_same_base()
for _, other in base._same_for_batch_ctx.items():
if other.dyn_size_ext:
return True
return False

def is_same_size_tensor(self, x):
"""
:param tf.Tensor x:
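A small hedged sketch of what is_dim_known distinguishes (not part of the diff; the Data templates are hypothetical):

from returnn.tf.util.data import Data

static = Data(name="static_example", shape=(3, 5))       # [B, 3, 5], all dims static
dynamic = Data(name="dynamic_example", shape=(None, 5))  # [B, T, 5], T dynamic, no sizes assigned

print(static.dim_tags[1].is_dim_known())   # True: static dimension 3
print(dynamic.dim_tags[1].is_dim_known())  # expected False: no static dim and no dyn_size_ext yet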
@@ -3174,6 +3192,15 @@ def get_axes_from_description(self, axes, allow_int=True):
res.append(i)
return res

def get_dim_tag_from_description(self, axis):
"""
:param str|DimensionTag axis:
:return: our matching dim tag. This assumes it exists.
:rtype: DimensionTag
"""
axis_int = self.get_axis_from_description(axis, allow_int=False)
return self.dim_tags[axis_int]

def get_axis_from_description(self, axis, allow_int=True):
"""
:param int|str|DimensionTag axis:
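A hedged sketch of the new helper (not part of the diff; the Data template is hypothetical):

from returnn.tf.util.data import Data

data = Data(name="example", shape=(None, 5))           # [B, T, 5]
time_tag = data.get_dim_tag_from_description("T")      # axis description -> matching DimensionTag
print(time_tag == data.dim_tags[data.time_dim_axis])   # True: same tag as our time axis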
@@ -3372,18 +3399,24 @@ def get_static_axes(self):
return [axis for axis, dim in enumerate(self.batch_shape)
if axis != self.batch_dim_axis and dim is not None]

def mark_same_time(self, tags):
def mark_same_time(self, tags, must_match=False):
"""
If the given dimension tag matches any of our axes, we set our time axis to the selected one.

:param set[DimensionTag] tags:
:param set[DimensionTag]|DimensionTag tags:
:param bool must_match: if True, raise an exception if not found
:return: whether we have found the same
:rtype: bool
"""
if isinstance(tags, DimensionTag):
tags = {tags}
assert all(isinstance(tag, DimensionTag) for tag in tags)
for axis, dim_tag in enumerate(self.dim_tags):
if dim_tag in tags:
self.time_dim_axis = axis
return True
if must_match:
raise Exception("%s mark_same_time: %s not found" % (self, tags))
return False

def is_same_time_dim(self, other):
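A hedged sketch of the extended mark_same_time (not part of the diff; the Data template and the extra tag are hypothetical):

from returnn.tf.util.data import Data, DimensionTag

data = Data(name="example", shape=(None, 5))  # [B, T, 5]
time_tag = data.get_dim_tag_from_description("T")
other_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="other_dim", dimension=7)

print(data.mark_same_time(time_tag))   # True: a single DimensionTag is now accepted, time axis is set
print(data.mark_same_time(other_tag))  # False: not one of our axes
# with must_match=True, the second call would raise an exception instead:
# data.mark_same_time(other_tag, must_match=True)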