Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solution for ambiguous dim tags #871

Merged
merged 3 commits into from
Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
vocab=None,
dyn_size=None, dyn_size_ext=None,
undefined=False, generic=False, special=False,
match_priority=0,
derived_from_tag=None, derived_from_op=None,
batch=None, control_flow_ctx=None,
src_data=None, src_axis=None):
Expand All @@ -79,6 +80,10 @@ def __init__(self, kind=Types.Unspecified, description=None,
the behavior is to consider them as equal,
and assume that the chain of operations (e.g. padding + valid conv) results in the same dim.
:param Dim.Op|None derived_from_op:
:param int match_priority: when there is ambiguity between multiple dim tags, this value defines the order
in which the dimension are assigned to their matching counterparts.
A dimension tag with a higher priority value is assigned first.
E.g. for a square matrix used for a linear transformation, the reduce dim tag should have a higher priority.
:param BatchInfo|None batch: for batch-dim, or dynamic dims per batch
:param ControlFlowContext|None control_flow_ctx:
:param Data|None src_data:
Expand All @@ -95,6 +100,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
self.derived_from_op = derived_from_op
if derived_from_op and not derived_from_op.output:
derived_from_op.output = self
self.match_priority = match_priority
if src_data:
assert isinstance(src_data, Data) and isinstance(src_axis, int)
if not batch and dyn_size_ext:
Expand Down Expand Up @@ -170,18 +176,20 @@ def __deepcopy__(self, memo=None):
"""
return self

def copy(self, same_as_self, description=None, kind=None):
def copy(self, same_as_self=True, description=None, kind=None, match_priority=None):
"""
:param bool same_as_self:
:param str|None description: new description
:param Entity|None kind: if set, overwrites self.kind
:param int|None match_priority:
:return: copy, maybe as new kind. setting same_as to self
:rtype: Dim
"""
if not same_as_self:
assert description is not None, "%s copy with not same_as_self should have a new description" % self
tag = Dim(
kind=kind or self.kind, description=description or self.description,
match_priority=match_priority if match_priority is not None else self.match_priority,
dimension=self.dimension, dyn_size_ext=self.dyn_size_ext,
batch=self.batch,
src_data=self.src_data, src_axis=self.src_axis)
Expand Down Expand Up @@ -4366,7 +4374,12 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified):
# Once we have not guaranteed unique dim tags, multiple axes could match.
# https://github.com/rwth-i6/returnn/issues/632
dims = [i for (i, tag) in enumerate(self.dim_tags) if tag == axes]
assert len(dims) <= 1, "%s: matching dim %s must be unique" % (self, axes)
if len(dims) > 1:
max_match_priority = max(self.dim_tags[i].match_priority for i in dims)
dims = [i for i in dims if self.dim_tags[i].match_priority == max_match_priority]
assert len(dims) <= 1, (
"%s: matching dim %s must be unique,"
" use `match_priority` to resolve the matching order of ambiguous dimensions" % (self, axes))
return dims
if isinstance(axes, int):
self._verify_axis_int_from_description(allow_int=allow_int)
Expand Down
35 changes: 35 additions & 0 deletions tests/test_TFNetworkLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4293,6 +4293,41 @@ def test_DotLayer2():
assert_equal(out.shape, (S1, S2, B, V))


def test_DotLayer_linear_square_matrix():
from returnn.tf.util.data import batch_dim
time_dim = SpatialDim("time")
feat_dim = FeatureDim("feature", dimension=3)
config = Config({
"extern_data": {
"data": {"dim_tags": [batch_dim, time_dim, feat_dim]},
"matrix_ambiguous": {"dim_tags": [feat_dim, feat_dim], "available_for_inference": True},
"matrix_non_ambiguous": {
"dim_tags": [feat_dim.copy(match_priority=1), feat_dim], "available_for_inference": True},
},
})
with make_scope() as session:
net = TFNetwork(config=config)
try:
net.construct_from_dict({
"output": {
"class": "dot", "from": ["data:data", "data:matrix_ambiguous"], "reduce": feat_dim
},
})
except Exception as exc:
print("Expected exception: %r" % exc)
assert "must be unique" in str(exc)
else:
raise Exception("Expected exception but constructed layer: %s" % net.get_default_output_layer())
net.construct_from_dict({
"output": {
"class": "dot", "from": ["data:data", "data:matrix_non_ambiguous"], "reduce": feat_dim
},
})
out = net.get_default_output_layer().output
assert out.dim_tags == (batch_dim, time_dim, feat_dim)
session.run(out.placeholder, feed_dict=make_feed_dict(net.extern_data))


def test_DotLayer_mask_dyn_seq():
batch = Dim(kind=Dim.Types.Batch, description="batch")
time = SpatialDim("time")
Expand Down