Enforce order-independent axis descriptions (#795)
This is required to ensure that layers can reorder axes in whatever way they like, to allow for potential optimizations. Nothing should depend on the order of axes. See the [RETURNN principles](https://github.com/rwth-i6/returnn/wiki/RETURNN-principles).

This is also for #792 to allow for an easier transition.

This introduces a new behavior version (#508).

While it probably requires changes for many configs, the changes should still be quite simple.
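
For illustration, a minimal before/after sketch of what this means for a config. The layer entry and tag names below are made up and not part of this commit:

```python
# Hypothetical config snippet, for illustration only.

# Before: counts spatial axes by position; rejected with behavior version >= 7.
net_old = {
    "reduce": {"class": "reduce", "mode": "max", "axes": "spatial:1", "from": "data"},
}

# After: refer to the axis via its dim tag name or its static dimension,
# which stays valid no matter how layers internally order the axes.
net_new = {
    "reduce": {"class": "reduce", "mode": "max", "axes": "stag:extern_data:data", "from": "data"},
    # or, for a static axis of known size, e.g. "axes": "dim:512"
}
```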
albertz authored Nov 29, 2021
1 parent 7e6d9d3 commit 6aad62b
Showing 7 changed files with 116 additions and 85 deletions.
11 changes: 11 additions & 0 deletions docs/configuration_reference/behavior_version.rst
@@ -22,6 +22,17 @@ and not listing legacy/deprecated parameters.
Version History
---------------

Behavior version 7 (2021-11-29)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Specifying ``axes`` or ``axis`` arguments in a way that depends on the order of the axes is not allowed anymore.
E.g. something like ``axis="spatial:1"`` is rejected.

To fix this, use dimension tags, i.e. :class:`DimensionTag` instances.
To fix older configs without too much effort,
you might also want to use ``"stag:<name>"`` or ``"stag-single:<idx>:<name>"``
or ``"dim:<static-dim>"``.

Behavior version 6 (2021-11-27)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

39 changes: 29 additions & 10 deletions returnn/tf/util/data.py
@@ -3437,6 +3437,19 @@ def _verify_axis_int_from_description(cls, allow_int=NotSpecified):
return
raise Exception(msg)

@classmethod
def _verify_axis_order_dependent(cls):
"""
Call this when you have the case that ``axis`` or ``axes``
in :func:`get_axes_from_description` or :func:`get_axis_from_description`
depends on the order of the axes.
"""
from returnn.util import BehaviorVersion
BehaviorVersion.require(
condition=False,
message="Do not specify axis or axes in a way that depends on the order of the axes.",
version=7)

def _make_valid_int_axis(self, axis):
"""
:param int axis: counted with batch. anything in [-ndim,ndim-1]
@@ -3480,6 +3493,7 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified):
elif axes == "spatial":
return self.get_spatial_batch_axes()
elif re.match("(s|spatial):-?\\d+$", axes):
self._verify_axis_order_dependent()
s = int(axes.split(":")[1])
spatial_axes = self.get_spatial_batch_axes()
if s < 0:
@@ -3489,6 +3503,7 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified):
elif axes in ["dyn", "dynamic"]:
return self.get_dynamic_axes()
elif re.match("(d|dyn|dynamic):-?\\d+$", axes):
self._verify_axis_order_dependent()
s = int(axes.split(":")[1])
dyn_axes = self.get_dynamic_axes()
if s < 0:
@@ -3516,6 +3531,7 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified):
axes.remove(self.batch_dim_axis)
return axes
elif re.match("(except_batch):-?\\d+$", axes):
self._verify_axis_order_dependent()
s = int(axes.split(":")[1])
non_batch_axes = list(range(self.batch_ndim))
if self.batch_dim_axis is not None:
@@ -3529,6 +3545,7 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified):
elif axes == "static":
return self.get_static_axes()
elif re.match("(static):-?\\d+$", axes):
self._verify_axis_order_dependent()
s = int(axes.split(":")[1])
static_axes = self.get_static_axes()
if s < 0:
@@ -3549,6 +3566,9 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified):
elif axes.startswith("stag-single:"): # spatial tag which possibly matches multiple spatial axes
# in this case, a name of the form "stag-single:<idx>:<name>" is expected.
# idx is relative to the matching stags, i.e., it is the index among the list of spatial dims matching the name
# Note: no _verify_axis_order_dependent here because as long as we do not enforce unique dim tags
# (https://github.com/rwth-i6/returnn/issues/632), we can have multiple axes with the same tag,
# and then we need to be able to differentiate between them by order.
_, idx_s, name = axes.split(":", 2) # stag-single:<idx>:<name>
idx = int(idx_s)
return [self.get_axes_by_tag_name(name, spatial_only=True)[idx]]
@@ -3612,16 +3632,15 @@ def get_description_from_axis(self, axis):
if len(matching_tags) == 1:
# Fallback with dim tag
return dim_tag
# Fallback without dim tag
if dim_tag.dimension is not None: # static
kind = "static"
axes = self.get_static_axes()
else: # dynamic
kind = "dynamic"
axes = self.get_dynamic_axes()
assert axis in axes, "%s: %s axes %s do not contain axis %i" % (self, kind, axes, axis)
i = axes.index(axis)
return "%s:%i" % (kind, i - len(axes)) # negative because this is likely more robust
# Do not use indexed static or dynamic because we want to avoid relying on the axis order as much as possible.
# However, as we do not have unique dim tags in this case, we have to rely at least on the order among the axes matching this dim tag.
# Use stag-single.
name = dim_tag.description
matching_axes = self.get_axes_by_tag_name(name, spatial_only=True)
assert axis in matching_axes
return (
"stag-single:%i:%s" % (
matching_axes.index(axis) - len(matching_axes), name)) # negative because this is likely more robust

def has_axis(self, axis):
"""
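To make the ``stag-single`` handling above concrete, here is a simplified, hypothetical sketch (not the actual RETURNN implementation) of how a ``"stag-single:<idx>:<name>"`` description resolves to an axis, and why the negative index produced by ``get_description_from_axis`` is robust against reordering of unrelated axes:

```python
# Simplified sketch, for illustration only; the real logic lives in
# Data.get_axes_from_description / get_axes_by_tag_name and differs in detail.

def resolve_stag_single(description, axis_tag_names):
    # description: e.g. "stag-single:-1:time"
    # axis_tag_names: tag name per axis, in the current axis order
    _, idx_s, name = description.split(":", 2)
    matching_axes = [i for i, tag in enumerate(axis_tag_names) if name in tag]
    # The index is taken among the matching axes only, so it does not depend
    # on where unrelated axes sit in the global order.
    return matching_axes[int(idx_s)]

assert resolve_stag_single("stag-single:-1:time", ["batch", "time", "time", "feature"]) == 2
```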
2 changes: 1 addition & 1 deletion returnn/util/basic.py
@@ -209,7 +209,7 @@ class BehaviorVersion:
The version will be set after the config is defined at __main__.init_config() or Engine.__init__()
"""

_latest_behavior_version = 6
_latest_behavior_version = 7
_behavior_version = None # type: typing.Optional[int]

@classmethod
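For context, configs opt in by setting ``behavior_version = 7`` in the config (see #508). A rough sketch of the intended semantics of ``BehaviorVersion.require`` as used above (simplified, not the actual implementation):

```python
# Rough sketch, for illustration only; the real BehaviorVersion class in
# returnn/util/basic.py differs in detail (e.g. it reads the version from the
# config and handles the not-yet-set case).

class BehaviorVersionSketch:
    _behavior_version = 7  # as set via `behavior_version = 7` in the config

    @classmethod
    def require(cls, condition, message, version):
        # From the given behavior version on, an unmet requirement is a hard
        # error; configs with an older behavior version keep working as before.
        if not condition and cls._behavior_version >= version:
            raise Exception("%s (required since behavior version %i)" % (message, version))
```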
13 changes: 6 additions & 7 deletions tests/test_TFEngine.py
@@ -1320,21 +1320,20 @@ def test_attention_convolutional_feedback_variant1():
# basic attention
"s_transformed": {'class': 'linear', 'from': ['s'], 'n_out': 6, 'activation': None},
"att_energy_tanh": {'class': 'activation', 'from': ['att_energy_in'], 'activation': 'tanh'},
"att_energy": {'class': 'linear', 'from': ['att_energy_tanh'], 'n_out': 1, 'activation': None,
},
"att_energy": {'class': 'linear', 'from': ['att_energy_tanh'], 'n_out': 1, 'activation': None},
"att_weights": {'class': 'softmax_over_spatial', 'from': ['att_energy']},
'accum_att_weights': {'class': 'eval',
'eval': 'source(0) + source(1)',
'from': ['prev:accum_att_weights', 'att_weights'],
'is_output_layer': True,
'out_type': {'dim': 1, 'shape': (None, 1)}},
'feedback_pad_left': {'axes': 's:0',
'feedback_pad_left': {'axes': 'stag:extern_data:data',
'class': 'pad',
'from': ['prev:accum_att_weights'],
'mode': 'constant',
'padding': ((2, 0),),
'value': 1},
'feedback_pad_right': {'axes': 's:0',
'feedback_pad_right': {'axes': 'stag:pad',
'class': 'pad',
'from': ['feedback_pad_left'],
'mode': 'constant',
@@ -1395,13 +1394,13 @@ def test_attention_convolutional_feedback_variant3():
"accum_att_weights": {'class': 'combine', 'kind': 'add',
'from': ['att_weights', 'prev:accum_att_weights'],
},
'feedback_pad_left': {'axes': 's:0',
'feedback_pad_left': {'axes': 'stag:extern_data:data',
'class': 'pad',
'from': ['prev:accum_att_weights'],
'mode': 'constant',
'padding': ((2, 0),),
'value': 1},
'feedback_pad_right': {'axes': 's:0',
'feedback_pad_right': {'axes': 'stag:pad',
'class': 'pad',
'from': ['feedback_pad_left'],
'mode': 'constant',
@@ -2328,7 +2327,7 @@ def test_rec_subnet_eval_init_out_apply0():
"from": ["base:enc_ctx", "att_query"], "debug": True}, # (B, enc-T, H, 1)

"att_weights": {"class": "softmax_over_spatial", "from": ["energy"], "energy_factor": EncKeyPerHeadDim ** -0.5},
"att_weights_avg": {"class": "reduce", "axes": "static:0", "mode": "avg", "from": ["att_weights"]}, # (B, enc-T, 1)
"att_weights_avg": {"class": "reduce", "axes": "dim:%i" % AttNumHeads, "mode": "avg", "from": ["att_weights"]}, # (B, enc-T, 1)
"accum_att_weights": {"class": "eval",
"from": ["prev:accum_att_weights", "att_weights_avg", "base:inv_fertility"],
"eval": "source(0) + source(1) * source(2) * 0.5",
65 changes: 25 additions & 40 deletions tests/test_TFNetworkLayer.py
@@ -733,7 +733,8 @@ def test_cnn_building_block():
"num_outputs": filters,
"network": {
"split": {"class": "split_dims", "axis": "f", "dims": (channel_num, feature_dim), "from": ["data"]},
"swap_axes": {"class": "swap_axes", "axis1": "s:1", "axis2": "f", "from": ["split"]},
"swap_axes": {
"class": "swap_axes", "axis1": "dim:%i" % channel_num, "axis2": "dim:%i" % feature_dim, "from": "split"},
"c1": {"class": "conv", "n_out": filters, "filter_size": filter_size, "auto_use_channel_first": False,
"strides": (1, 1), "dilation_rate": (1, 1), "padding": "SAME", "activation": None, "with_bias": False,
"from": "swap_axes"},
Expand All @@ -746,7 +747,8 @@ def test_cnn_building_block():
"bn2": {"class": "batch_norm", "from": "p"},
"y2": {"class": "activation", "activation": "relu", "batch_norm": False, "from": "bn2"},

"out_pool": {"class": "reduce", "mode": "avg", "axes": "s:1", "keep_dims": False, "from": "y2"},
"out_pool": {
"class": "reduce", "mode": "avg", "axes": "dim:%i" % feature_dim, "keep_dims": False, "from": "y2"},
"output": {"class": "copy", "from": ["out_pool"], "is_output_layer": True}
}})
network = TFNetwork(config=config, train_flag=True)
@@ -1248,8 +1250,8 @@ def test_dot_layer_shuffled_remaining_dims_static():
with make_scope() as session:
import numpy as np
net_dict = {
"a": {"class": "split_dims", "axis": "static:0", "dims": (2, 3, 5), "from": "data:data"},
"b": {"class": "transpose", "from": ["a"], "perm": {"static:0": "static:1", "static:1": "static:0"}},
"a": {"class": "split_dims", "axis": "F", "dims": (2, 3, 5), "from": "data:data"},
"b": {"class": "transpose", "from": ["a"], "perm": {"dim:2": "dim:3", "dim:3": "dim:2"}},
"dot": {
"class": "dot", "from": ["a", "b"],
"red1": "dim:5", "red2": "dim:5", "var1": None, "var2": None,
@@ -1504,7 +1506,7 @@ def test_ReduceLayer_reduce4d():
name="src", network=network, output=Data(name="src", shape=(None, 4, 512), auto_create_placeholders=True))
print("src:", src_layer)
opts = {
'axes': "s:1",
'axes': "dim:4",
'keep_dims': True,
'mode': 'mean',
'name': 'c_out_reduce',
@@ -1966,7 +1968,7 @@ def test_MergeDimsLayer_dim_tags():
src_data.get_axis_by_tag_name('key-chunk') == 1 and src_data.get_axis_by_tag_name('key-window') == 2 and
src_data.get_axis_by_tag_name('att-heads') == 3)

merge_axes = ['stag:key-window', 'spatial:-1']
merge_axes = ['stag:key-window', 'dim:1']
print('merge axes:', merge_axes)

src = InternalLayer(name="src", network=net, output=src_data)
@@ -2099,7 +2101,7 @@ def test_MergeDimsLayer_2d_dynamic_merge_axis():
"start1": {"class": "range_in_axis", "from": "data", "axis": "t", "out_shape": {time_dim, ImplicitDynSizeDim(BatchDim)}},
"start": {"class": "combine", "from": ["start0", "start1"], "kind": "add", "out_shape": {BatchDim, time_dim}},
"slices": {"class": "slice_nd", "from": "data", "start": "start", "size": None}, # [B,T[B],slice[B,T],D]
"output": {"class": "merge_dims", "from": "slices", "axes": ["f", "dyn:-1"]} # [B,T[B],merge[B,T]]
"output": {"class": "merge_dims", "from": "slices", "axes": ["f", "stag:slice"]} # [B,T[B],merge[B,T]]
})
slices_layer = net.get_layer("slices")
assert isinstance(slices_layer, SliceNdLayer)
@@ -2671,7 +2673,7 @@ def test_ScatterNdLayer_pos_batch_last_dim():
auto_create_placeholders=True))
scatter_opts = dict(
name="scatter", network=network,
sources=[val], position=pos, position_axis="except_batch:-1",
sources=[val], position=pos, position_axis="dim:6",
output_dim_via_time_from=data, filter_invalid_indices=True)
scatter_out_template = ScatterNdLayer.get_out_data_from_opts(**scatter_opts)
print("scatter out:", scatter_out_template)
@@ -3116,9 +3118,9 @@ def test_GatherLayer():
# should become [B, T, F2, F1]
layer = GatherLayer(
name="gather", network=net,
sources=[values], position=position, axis="static:0",
sources=[values], position=position, axis="dim:%i" % gather_dim,
output=GatherLayer.get_out_data_from_opts(
name="gather", sources=[values], position=position, axis="static:0"))
name="gather", sources=[values], position=position, axis="dim:%i" % gather_dim))
layer.output.sanity_check()
out_seqs, size = session.run([layer.output.placeholder, layer.output.size_placeholder.as_dict()])
assert isinstance(out_seqs, numpy.ndarray)
@@ -3165,9 +3167,9 @@ def test_GatherLayer_constant_position():
# should become [B, F1, F2]
layer = GatherLayer(
name="gather", network=net,
sources=[values], position=position, axis="static:-2",
sources=[values], position=position, axis="dim:%i" % gather_dim,
output=GatherLayer.get_out_data_from_opts(
name="gather", sources=[values], position=position, axis="static:-2"))
name="gather", sources=[values], position=position, axis="dim:%i" % gather_dim))
layer.output.sanity_check()
out_seqs = session.run(layer.output.placeholder)
assert isinstance(out_seqs, numpy.ndarray)
@@ -3317,7 +3319,7 @@ def test_SliceNdLayer_multidimensional_start():
"class": "rec", "from": "data:data", "unit": {
"start": {"class": "copy", "from": "prev:choice"},
"slices": {"class": "slice_nd", "from": "base:data:data", "start": "start", "size": None},
"output": {"class": "reduce", "from": "slices", "mode": "max", "axes": "dyn:-1"},
"output": {"class": "reduce", "from": "slices", "mode": "max", "axes": "stag:slice"},
"prob": {"class": "softmax", "from": "data:source", "target": "classes", "loss": "ce"},
'choice': {
'class': 'choice', 'target': "classes", 'beam_size': 3, 'from': "prob", "input_type": "prob",
@@ -3367,7 +3369,7 @@ def test_SliceNdLayer_multidimensional_size():
"start": {"class": "reinterpret_data", "from": "prev:choice", "set_sparse": False},
"size": {"class": "combine", "from": ["const1", "start"], "kind": "add"},
"slices": {"class": "slice_nd", "from": "base:data:data", "start": "start", "size": "size"},
"output": {"class": "reduce", "from": "slices", "mode": "max", "axes": "dyn:-1"},
"output": {"class": "reduce", "from": "slices", "mode": "max", "axes": "stag:slice"},
"prob": {"class": "softmax", "from": "data:source", "target": "classes", "loss": "ce"},
'choice': {
'class': 'choice', 'target': "classes", 'beam_size': 3, 'from': "prob", "input_type": "prob",
@@ -3707,52 +3709,35 @@ def test_pool_layer_NCHW():
pool_size=pool_size, padding=padding, strides=strides,
use_channel_first=True,
network=net, sources=[src_nchw]))
with tf_compat.v1.variable_scope("pool_nhwc_from_nchw"):
pool_nhwc_from_nchw = PoolLayer(
name="pool_nhwc_from_nchw", network=net, mode="max", pool_size=pool_size,
padding=padding, strides=strides, use_channel_first=False, sources=[src_nchw],
output=PoolLayer.get_out_data_from_opts(name="pool_nhwc_from_nchw",
pool_size=pool_size, padding=padding, strides=strides,
use_channel_first=False,
network=net, sources=[src_nchw]))
tf_compat.v1.global_variables_initializer().run()
out, seq_lens = session.run([pool_nhwc_from_nhwc.output.placeholder,
pool_nhwc_from_nhwc.output.size_placeholder[0]],
pool_nhwc_from_nhwc.output.get_sequence_lengths()],
feed_dict={src_nhwc.output.placeholder: np.random.rand(10, 11, 16, 16),
src_nhwc.output.size_placeholder[0]: np.full(shape=(10,), fill_value=11)}
src_nhwc.output.get_sequence_lengths(): np.full(shape=(10,), fill_value=11)}
)
print(out.shape)
assert_equal(out.shape, (10, 7, 6, 16))
print(seq_lens)
time_dim_axis = 1 if tf_util.is_gpu_available() else 0
out, seq_lens = session.run([pool_nchw_from_nhwc.output.placeholder,
pool_nchw_from_nhwc.output.size_placeholder[time_dim_axis]],
pool_nchw_from_nhwc.output.get_sequence_lengths()],
feed_dict={src_nhwc.output.placeholder: np.random.rand(10, 11, 16, 16),
src_nhwc.output.size_placeholder[0]: np.full(shape=(10,), fill_value=11)
src_nhwc.output.get_sequence_lengths(): np.full(shape=(10,), fill_value=11)
})
print(out.shape)
if time_dim_axis == 1:
print(pool_nchw_from_nhwc.output, out.shape)
if pool_nchw_from_nhwc.output.feature_dim_axis == 1:
assert_equal(out.shape, (10, 16, 7, 6))
else:
assert_equal(out.shape, (10, 7, 6, 16))
print(seq_lens)
if tf_util.is_gpu_available():
out, seq_lens = session.run([pool_nchw_from_nchw.output.placeholder,
pool_nchw_from_nchw.output.size_placeholder[1]],
pool_nchw_from_nchw.output.get_sequence_lengths()],
feed_dict={src_nchw.output.placeholder: np.random.rand(10, 16, 11, 16),
src_nchw.output.size_placeholder[1]: np.full(shape=(10,), fill_value=11)
src_nchw.output.get_sequence_lengths(): np.full(shape=(10,), fill_value=11)
})
print(out.shape)
assert_equal(out.shape, (10, 16, 7, 6))
print(seq_lens)
out, seq_lens = session.run([pool_nhwc_from_nchw.output.placeholder,
pool_nhwc_from_nchw.output.size_placeholder[0]],
feed_dict={src_nchw.output.placeholder: np.random.rand(10, 16, 11, 16),
src_nchw.output.size_placeholder[1]: np.full(shape=(10,), fill_value=11)}
)
print(out.shape)
assert_equal(out.shape, (10, 7, 6, 16))
print(seq_lens)


def test_TransposedConvLayer_2d_simple():
@@ -4050,7 +4035,7 @@ def test_DotLayer2():

kwargs = dict(
name="dot", network=net, sources=[a, b], debug=True,
red1='F', red2='spatial:-1', var1='B', var2='F')
red1='F', red2='dim:%i' % R, var1='B', var2='F')
layer = DotLayer(output=DotLayer.get_out_data_from_opts(**kwargs), **kwargs)
print(layer, layer.output)
assert layer.output.batch_dim_axis == 2