SliceNdLayer now uses GatherLayer to get the slices. #635

Merged
136 changes: 101 additions & 35 deletions returnn/tf/layers/basic.py
@@ -842,9 +842,13 @@ def get_out_data_from_opts(

 class SliceNdLayer(_ConcatInputLayer):
   """
-  This takes out a slice-range from some axis,
+  This takes out a slice-range from the time axis,
   e.g. ``x[start:start + size]``.
-  This layers allows a different start slice point for each batch,
+  If the input is of shape (B,T,F) and start is of shape (B,),
+  then the output will be of shape (B,size,F).
+  If the input is of shape (B,T,F) and start is of shape (B,T),
+  then the output will be of shape (B,T,size,F).
+  This layer allows a different start slice point for each batch,
   in contrast to :class:`SliceLayer`, and the start is variable.
   See also :class:`GatherNdLayer`.
   :class:`PrefixInTimeLayer` can recover the original shape (by zero-padding).
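
The new docstring corresponds to a net-dict usage along these lines; a minimal sketch with illustrative layer/data names that are not part of this PR ("size": None makes the slice axis dynamic, sized max(seq_lens - start, min_size)):

network = {
  "start": {"class": "copy", "from": "data:start"},  # start indices, shape (B,)
  "slices": {"class": "slice_nd", "from": "data", "start": "start", "size": None},  # output (B,size,F)
}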
@@ -854,44 +858,93 @@ class SliceNdLayer(_ConcatInputLayer):

   def __init__(self, start, size, min_size=None, **kwargs):
     """
-    :param LayerBase start:
+    :param LayerBase start: (B,...)
     :param int|None size: if None, it uses the max possible size, and it becomes a dynamic axis
-    :param int|None min_size: if size is None, but we want to have a min-size, set this
+    :param int|None min_size: if size is None, but we want to have a min-size
     """
     super(SliceNdLayer, self).__init__(**kwargs)
-    from returnn.tf.util.basic import slice_nd, where_bc, expand_multiple_dims, DimensionTag
+    from returnn.tf.util.basic import where_bc, expand_multiple_dims
     x = self.input_data.copy_as_batch_major()
     assert x.time_dim_axis == 1, "currently only time-axis==1 supported"
-    seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None
+    seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None  # (B,) or None
     self.start = start
-    assert start.output.have_batch_axis() and start.output.batch_shape == (None,)
-    start = start.output.get_placeholder_as_batch_major()
+    start_data = start.output.copy_as_batch_major()  # e.g. (B,) or (B,T)
+    start_t = start_data.placeholder
     if size is None:
+      if min_size is None:
+        min_size = 0
       if seq_lens is None:
-        size = tf.maximum(tf.reduce_max(x.batch_shape[1] - start), 0)
+        assert isinstance(x.batch_shape[x.time_dim_axis], int)
+        size = tf.maximum(tf.reduce_max(x.batch_shape[x.time_dim_axis] - start_t), min_size)  # scalar
       else:
-        size = tf.maximum(tf.reduce_max(seq_lens - start), 0)
-      if min_size is not None:
-        size = tf.maximum(size, min_size)
+        # make seq_lens compatible with start_t
+        seq_lens = expand_multiple_dims(  # e.g. (B,) or (B,1)
+          x=seq_lens,
+          axes=[-1] * (len(start_t.shape) - len(seq_lens.shape)))
+        size = tf.maximum(tf.reduce_max(seq_lens - start_t), min_size)  # scalar
     self.size = size
-    start = tf.expand_dims(start, axis=1)  # (B, T)
-    slices = slice_nd(x.placeholder, start=tf.cast(start, tf.int32), size=size)  # (B,size, ...)
+    # for each start index in start_data, we want to gather a slice
+    # therefore, the output's first axes are the same as the ones from start_data
+    # and the next axis will therefore be the slice axis
+    slice_tag = self.output.dim_tags[start_data.batch_ndim]
+    assert slice_tag.description.startswith("sliced-time:")
+    if not isinstance(size, int):
+      # in this case, size is not known before runtime and becomes dynamic and we need to set dyn_size
+      if seq_lens is None:
+        dyn_size = tf.maximum(x.batch_shape[x.time_dim_axis] - start_t, min_size)  # (B,) or (B,T)
+      else:
+        dyn_size = tf.maximum(seq_lens - start_t, min_size)  # (B,) or (B,T)
+      dyn_size_ext = Data(
+        name=("%s:dyn_size" % slice_tag.description),
+        dtype=Data.size_dtype,
+        placeholder=dyn_size,
+        dim_tags=start_data.dim_tags,
+        batch=slice_tag.batch,
+        beam=slice_tag.batch.beam if slice_tag.batch else self.output.beam,
+        control_flow_ctx=slice_tag.control_flow_ctx)
+      slice_tag.dyn_size_ext = dyn_size_ext
+    gather_positions_data = start_data.copy_template(name="%s_gather_positions" % self.name)
+    gather_positions_data = gather_positions_data.copy_add_dim_by_tag(
+      slice_tag,
+      unbroadcast=True,
+      axis=start_data.batch_ndim)
+    # [start+0, start+1, ...]
+    gather_positions = tf.expand_dims(start_t, -1) + tf.range(0, size)  # e.g. (B, size) or (B, T, size)
     if seq_lens is not None:
-      mask = tf.greater_equal(tf.range(size)[None, :] + start, seq_lens[:, None])  # (B,T)
-      mask = expand_multiple_dims(mask, list(range(2, x.batch_ndim)))
-      slices = where_bc(mask, tf.zeros_like(slices), slices)
-    size_placeholder = x.size_placeholder.copy()
-    if isinstance(size, tf.Tensor):
-      size_placeholder[0] = tf.maximum(seq_lens - tf.reshape(start, tf.shape(seq_lens)), 0)
-      tag = DimensionTag(
-        description="sliced-time:%s" % self.get_absolute_name(),
-        kind=DimensionTag.Types.Spatial, batch=self.output.batch)
-      tag.set_tag_on_size_tensor(size_placeholder[0])
+      # broadcast from (B,) to the shape of the indices
+      seq_lens = expand_multiple_dims(  # e.g. (B,1) or (B,1,1)
+        x=seq_lens,
+        axes=[-1] * (len(gather_positions.shape) - len(seq_lens.shape)))
+      pad_mask = tf.logical_or(  # shape like gather_positions
+        tf.greater(gather_positions, seq_lens - 1),
+        tf.less(gather_positions, 0))
+      gather_positions = tf.clip_by_value(gather_positions, 0, seq_lens - 1)
     else:
-      assert isinstance(size, int)
-      size_placeholder.pop(0, None)  # static time axis
-    self.output.size_placeholder = size_placeholder
-    self.output.placeholder = slices
+      pad_mask = tf.logical_or(  # shape like gather_positions
+        tf.greater(gather_positions, x.batch_shape[1] - 1),
+        tf.less(gather_positions, 0))
+      gather_positions = tf.clip_by_value(gather_positions, 0, x.batch_shape[1] - 1)
+    gather_positions_data.placeholder = gather_positions
+    position = InternalLayer(
+      network=self.network,
+      name="%s_internal" % gather_positions_data.name,
+      output=gather_positions_data)
+    gather_layer = GatherLayer(
+      name="%s_gather" % self.name,
+      network=self.network,
+      output=self.output,
+      sources=self.sources,
+      position=position,
+      axis=x.get_time_dim_tag())
+    placeholder = gather_layer.output.placeholder
+    # In principle, the padded frames are being ignored
+    # (unless get_padding_info_dict_ref et al are used).
+    # However, you can still end up with gradients for them
+    # in unexpected ways.
+    # Due to our gather implementation,
+    # the gradient flow would go into wrong frames
+    # and might lead to unexpected behavior.
+    # So to be on the safe side, we do the masking here.
+    pad_mask = expand_multiple_dims(pad_mask, [-1] * (len(placeholder.shape) - len(pad_mask.shape)))
+    self.output.placeholder = where_bc(pad_mask, tf.zeros_like(placeholder), placeholder)
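
Outside of RETURNN's Data/DimensionTag machinery, the gather trick implemented above boils down to a few TensorFlow ops; a minimal standalone sketch for a (B,) start, assuming plain TF2 (illustrative, not the layer's actual code path):

import tensorflow as tf

x = tf.random.normal([3, 10, 5])  # input, (B,T,F)
start = tf.constant([2, 0, 7])  # start indices, (B,)
seq_lens = tf.constant([10, 8, 9])  # (B,)
size = tf.reduce_max(seq_lens - start)  # scalar: max possible slice length
positions = start[:, None] + tf.range(size)  # (B,size): [start+0, start+1, ...]
pad_mask = positions > seq_lens[:, None] - 1  # True on frames past the sequence end
positions = tf.clip_by_value(positions, 0, seq_lens[:, None] - 1)  # keep indices valid for the gather
slices = tf.gather(x, positions, batch_dims=1)  # (B,size,F)
slices = tf.where(pad_mask[:, :, None], tf.zeros_like(slices), slices)  # zero out padded frames, as the layer does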

   def get_dep_layers(self):
     """
@@ -909,11 +962,24 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwarg
     :rtype: Data
     """
     from ..util.data import DimensionTag
-    input_data = get_concat_sources_data_template(sources).copy_as_batch_spatial_major()
-    if start:
-      input_data.beam = SearchBeam.get_combined_beam(input_data.beam, start.output.beam)
-    new_dim_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="%s:slice_nd" % name, dimension=size)
-    return input_data.copy_template_replace_dim_tag(axis=1, new_dim_tag=new_dim_tag, name="%s_output" % name)
+    start_data = start.output.copy_as_batch_major()
+    input_data = sources[0].output.copy_as_batch_major()
+    gather_positions_data = start_data.copy_template(name="%s_gather_positions" % name)
+    # size might be None here in which case we set the dyn_size in __init__
+    tag = DimensionTag(
+      kind=DimensionTag.Types.Spatial,
+      description="sliced-time:%s" % name,
+      dimension=size)
+    gather_positions_data = gather_positions_data.copy_add_dim_by_tag(tag, unbroadcast=True, axis=start_data.batch_ndim)
+    position = InternalLayer(
+      network=sources[0].network,
+      name="%s_internal" % gather_positions_data.name,
+      output=gather_positions_data)
+    return GatherLayer.get_out_data_from_opts(
+      name="%s_gather" % name,
+      sources=sources,
+      position=position,
+      axis=input_data.get_time_dim_tag())
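
For the (B,T) start case that this template logic enables, the claimed output shape (B,T,size,F) can be sanity-checked with a small numpy sketch (illustrative only: fixed size, no seq-len masking):

import numpy as np

B, T, F = 3, 10, 5
x = np.random.randn(B, T, F)  # input, (B,T,F)
start = np.random.randint(0, T, size=(B, T))  # one start index per (batch, frame)
size = 4  # fixed slice size, for simplicity
positions = start[:, :, None] + np.arange(size)  # (B,T,size)
positions = np.clip(positions, 0, T - 1)  # keep indices in range
out = np.take_along_axis(x[:, None, :, :], positions[..., None], axis=2)  # (B,T,size,F)
assert out.shape == (B, T, size, F)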

   @classmethod
   def transform_config_dict(cls, d, network, get_layer):
48 changes: 48 additions & 0 deletions tests/test_TFNetworkLayer.py
@@ -2662,6 +2662,54 @@ def test_SliceNdLayer_dyn_size():
         numpy.testing.assert_equal(orig_seq[t], out[b, t])


+def test_SliceNdLayer_multidimensional_start():
+  with make_scope() as session:
+    n_out = 5
+    n_batch = 3
+    max_seq_len = 10
+    config = Config({
+      "debug_print_layer_output_template": True,
+      "extern_data": {
+        "data": {"dim": n_out},
+        "classes": {"dim": n_out, "sparse": True}
+      }})
+    net = TFNetwork(config=config, train_flag=True)
+    net.construct_from_dict({
+      "output": {
+        "class": "rec", "from": "data:data", "unit": {
+          "start": {"class": "copy", "from": "prev:choice"},
+          "slices": {"class": "slice_nd", "from": "base:data:data", "start": "start", "size": None},
+          "output": {"class": "reduce", "from": "slices", "mode": "max", "axes": "dyn:-1"},
+          "prob": {"class": "softmax", "from": "data:source", "target": "classes", "loss": "ce"},
+          'choice': {
+            'class': 'choice', 'target': "classes", 'beam_size': 3, 'from': "prob", "input_type": "prob",
+            "initial_output": 0}}}})
+    session.run(tf_compat.v1.global_variables_initializer())
+    output_layer = net.layers["output"]
+    starts = output_layer.cell.output_layers_net.layers["start"].output.get_placeholder_as_batch_major()
+    segments = output_layer.cell.output_layers_net.layers["slices"].output.get_placeholder_as_batch_major()
+    feed = make_feed_dict(net.extern_data.data.values(), n_batch=n_batch, n_time=max_seq_len, same_time=True)
+    starts = session.run(starts, feed_dict=feed)
+    segments = session.run(segments, feed_dict=feed)
+    seq_lens = feed[net.extern_data.data["data"].size_placeholder[0]]
+    input_data = feed[net.extern_data.data["data"].placeholder]
+    max_size = numpy.amax(seq_lens[:, None] - starts)
+    max_size = max(max_size, 0)
+    assert segments.shape == (n_batch, max_seq_len, max_size, n_out)
+    for b in range(n_batch):
+      for t in range(max_seq_len):
+        s = starts[b, t]
+        orig_seq = input_data[b, s:]
+        if len(orig_seq) < max_size:
+          orig_seq = numpy.pad(orig_seq, [(0, max_size - len(orig_seq)), (0, 0)], "constant")
+        elif len(orig_seq) > max_size:
+          orig_seq = orig_seq[:max_size]
+        assert orig_seq.shape == (max_size, n_out)
+        orig_seq = numpy.where((numpy.arange(s, s + max_size) >= seq_lens[b])[:, None], 0.0, orig_seq)
+        for t2 in range(max_size):
+          numpy.testing.assert_equal(orig_seq[t2], segments[b, t, t2])
+
+
 def test_WindowLayer_output_placeholder():
   with make_scope() as session:
     net = TFNetwork(extern_data=ExternData())