Skip to content

Commit

Permalink
SliceNdLayer: small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
robin-p-schmitt committed Sep 14, 2021
1 parent d38c94c commit a4f220d
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,9 +880,11 @@ def __init__(self, start, size, min_size=None, **kwargs):
size = tf.maximum(tf.reduce_max(seq_lens - start_t), 0) # scalar
if min_size is not None:
size = tf.maximum(size, min_size)
# build Data object for the position argument of GatherLayer
indices_data = start_data.copy_template(name="%s_gather_indices" % self.name)
# For each start index in start_data, we want to gather a slice.
# Therefore, the output's first axes are the same as the ones from start_data,
# and the next axis is the slice axis.
slice_tag = self.output.dim_tags[start_data.batch_ndim]
assert slice_tag.description.startswith("sliced-time:")
if not isinstance(size, int):
# in this case, size is not known before runtime, so it is dynamic and we need to set dyn_size
if len(seq_lens.shape) == 1:
Expand All @@ -893,21 +895,32 @@ def __init__(self, start, size, min_size=None, **kwargs):
reduce_axes = range(1, len(seq_lens.shape))
dyn_size = tf.maximum(tf.reduce_max(seq_lens - start_t, axis=reduce_axes), 0) # (B,)
slice_tag.dyn_size = dyn_size
indices_data = indices_data.copy_add_dim_by_tag(slice_tag, unbroadcast=True, axis=start_data.batch_ndim)
gather_positions_data = start_data.copy_template(name="%s_gather_positions" % self.name)
gather_positions_data = gather_positions_data.copy_add_dim_by_tag(
slice_tag,
unbroadcast=True,
axis=start_data.batch_ndim)
# [start+0, start+1, ...]
indices = tf.expand_dims(start_t, -1) + tf.range(0, size) # e.g. (B, size) or (B, T, size)
gather_positions = tf.expand_dims(start_t, -1) + tf.range(0, size) # e.g. (B, size) or (B, T, size)
if seq_lens is not None:
# broadcast from (B,) to the shape of the indices
seq_lens = expand_multiple_dims( # e.g. (B,1) or (B,1,1)
x=seq_lens,
axes=[-1] * (len(indices.shape) - len(seq_lens.shape)))
pad_mask = tf.logical_or(tf.greater(indices, seq_lens - 1), tf.less(indices, 0)) # shape like indices
indices = tf.clip_by_value(indices, 0, seq_lens - 1)
axes=[-1] * (len(gather_positions.shape) - len(seq_lens.shape)))
pad_mask = tf.logical_or( # shape like gather_positions
tf.greater(gather_positions, seq_lens - 1),
tf.less(gather_positions, 0))
gather_positions = tf.clip_by_value(gather_positions, 0, seq_lens - 1)
else:
pad_mask = tf.logical_or(tf.greater(indices, x.batch_shape[1] - 1), tf.less(indices, 0)) # shape like indices
indices = tf.clip_by_value(indices, 0, x.batch_shape[1] - 1)
indices_data.placeholder = indices
position = InternalLayer(network=self.network, name="%s_internal" % indices_data.name, output=indices_data)
pad_mask = tf.logical_or( # shape like gather_positions
tf.greater(gather_positions, x.batch_shape[1] - 1),
tf.less(gather_positions, 0))
gather_positions = tf.clip_by_value(gather_positions, 0, x.batch_shape[1] - 1)
gather_positions_data.placeholder = gather_positions
position = InternalLayer(
network=self.network,
name="%s_internal" % gather_positions_data.name,
output=gather_positions_data)
gather_layer = GatherLayer(
name="%s_gather" % self.name,
network=self.network,
Expand All @@ -916,8 +929,6 @@ def __init__(self, start, size, min_size=None, **kwargs):
position=position,
axis=x.get_time_dim_tag())
placeholder = gather_layer.output.placeholder
self.output.size_placeholder = gather_layer.output.size_placeholder
# zero padding
# In principle, the padded frames are being ignored
# (unless get_padding_info_dict_ref et al are used).
# However, you can still end up with gradients for them
Expand Down Expand Up @@ -947,14 +958,17 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwarg
from ..util.data import DimensionTag
start_data = start.output.copy_as_batch_major()
input_data = sources[0].output.copy_as_batch_major()
indices_data = start_data.copy_template(name="%s_gather_indices" % name)
gather_positions_data = start_data.copy_template(name="%s_gather_positions" % name)
# size might be None here, in which case we set the dyn_size in __init__
tag = DimensionTag(
kind=DimensionTag.Types.Spatial,
description="sliced-time:%s" % name,
dimension=size)
indices_data = indices_data.copy_add_dim_by_tag(tag, unbroadcast=True, axis=start_data.batch_ndim)
position = InternalLayer(network=sources[0].network, name="%s_internal" % indices_data.name, output=indices_data)
gather_positions_data = gather_positions_data.copy_add_dim_by_tag(tag, unbroadcast=True, axis=start_data.batch_ndim)
position = InternalLayer(
network=sources[0].network,
name="%s_internal" % gather_positions_data.name,
output=gather_positions_data)
return GatherLayer.get_out_data_from_opts(
name="%s_gather" % name,
sources=sources,
Expand Down

0 comments on commit a4f220d

Please sign in to comment.