Skip to content

Commit

Permalink
DotLayer, single reduce argument
Browse files Browse the repository at this point in the history
Fix #636.
  • Loading branch information
albertz committed Dec 6, 2021
1 parent c76487f commit f7f6680
Showing 1 changed file with 50 additions and 17 deletions.
67 changes: 50 additions & 17 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6178,13 +6178,16 @@ class DotLayer(LayerBase):
"""
layer_class = "dot"

def __init__(self, red1=NotSpecified, red2=NotSpecified, var1=NotSpecified, var2=NotSpecified,
def __init__(self,
reduce=NotSpecified,
red1=NotSpecified, red2=NotSpecified, var1=NotSpecified, var2=NotSpecified,
add_var2_if_empty=NotSpecified, debug=False, **kwargs):
"""
:param str|Dim|tuple[str|DimensionTag]|list[str|DimensionTag] red1: reduce axes of first source
:param str|DimensionTag|tuple[str|DimensionTag]|list[str|DimensionTag] red2: reduce axes of second source
:param str|DimensionTag|tuple[str|DimensionTag]|list[str|DimensionTag]|None var1: var axes of first source
:param str|DimensionTag|tuple[str|DimensionTag]|list[str|DimensionTag]|None var2: var axes of second source
:param str|Dim|tuple[str|Dim]|list[str|Dim] reduce: reduce axes of both sources
:param str|Dim|tuple[str|Dim]|list[str|Dim] red1: reduce axes of first source
:param str|Dim|tuple[str|Dim]|list[str|Dim] red2: reduce axes of second source
:param str|Dim|tuple[str|Dim]|list[str|Dim]|None var1: var axes of first source
:param str|Dim|tuple[str|Dim]|list[str|Dim]|None var2: var axes of second source
:param bool add_var2_if_empty: if var2=None, add dim=1 at the end
:param bool debug: will print debug shapes, etc.
Expand All @@ -6196,6 +6199,9 @@ def __init__(self, red1=NotSpecified, red2=NotSpecified, var1=NotSpecified, var2
from returnn.util import BehaviorVersion
from returnn.tf.util.basic import prod, get_shape, get_padding_info_dict_ref, mask_dyn_seq_len_nd
super(DotLayer, self).__init__(**kwargs)
if reduce is not NotSpecified:
assert red1 is NotSpecified and red2 is NotSpecified
red1 = red2 = reduce
BehaviorVersion.require(
condition=all(not isinstance(a, int) for a in (red1, red2, var1, var2)),
message="DotLayer: Axes must be referenced by tag or special specified, not by int.",
Expand Down Expand Up @@ -6398,33 +6404,60 @@ def _add(dims, val, d_key):
_add(dims2, var2, "var2")

@classmethod
def get_out_data_from_opts(cls, name, sources, red1=-1, red2=-2, var1=-2, var2=-1,
def get_out_data_from_opts(cls, name, sources,
reduce=NotSpecified,
red1=NotSpecified, red2=NotSpecified, var1=NotSpecified, var2=NotSpecified,
add_var2_if_empty=NotSpecified, **kwargs):
"""
:param str name:
:param list[LayerBase] sources:
:param str|int|tuple[str|int]|list[str|int] red1: reduce axes of first source
:param str|int|tuple[str|int]|list[str|int] red2: reduce axes of second source
:param str|int|tuple[str|int]|list[str|int]|None var1: var axes of first source
:param str|int|tuple[str|int]|list[str|int]|None var2: var axes of second source
:param str|Dim|tuple[str|Dim]|list[str|Dim] reduce: reduce axes of both sources
:param str|Dim|tuple[str|Dim]|list[str|Dim] red1: reduce axes of first source
:param str|Dim|tuple[str|Dim]|list[str|Dim] red2: reduce axes of second source
:param str|Dim|tuple[str|Dim]|list[str|Dim]|None var1: var axes of first source
:param str|Dim|tuple[str|Dim]|list[str|Dim]|None var2: var axes of second source
:param bool add_var2_if_empty:
:rtype: Data
"""
from returnn.util import BehaviorVersion
from ..util.data import BatchInfo
assert len(sources) == 2, "dot-layer %r: needs exactly two sources" % (name,)
# See __init__.
# As usual, do as little error checking as possible here.
if add_var2_if_empty is NotSpecified:
add_var2_if_empty = True if BehaviorVersion.get() < 3 else False
if reduce is not NotSpecified:
assert red1 is NotSpecified and red2 is NotSpecified
red1 = red2 = reduce
BehaviorVersion.require(
condition=all(not isinstance(a, int) for a in (red1, red2, var1, var2)),
message="DotLayer: Axes must be referenced by tag or special specified, not by int.",
version=3)
BehaviorVersion.require(
condition=all(a is not NotSpecified for a in (red1, red2, var1, var2)),
message="DotLayer: Axes must be specified explicitly. There is no default.",
version=3)
BehaviorVersion.require(
condition=add_var2_if_empty is NotSpecified or not add_var2_if_empty,
message="DotLayer: add_var2_if_empty not allowed",
version=3)
if BehaviorVersion.get() < 3:
# Earlier defaults: red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True.
red1 = -1 if red1 is NotSpecified else red1
red2 = -2 if red2 is NotSpecified else red2
var1 = -2 if var1 is NotSpecified else var1
var2 = -1 if var2 is NotSpecified else var2
add_var2_if_empty = True if add_var2_if_empty is NotSpecified else add_var2_if_empty
axis_desc_allow_int = True
else:
# add_var2_if_empty not supported anymore.
add_var2_if_empty = False
axis_desc_allow_int = False
a_out = sources[0].output.copy()
a_reduce_axes = a_out.get_axes_from_description(red1)
a_reduce_axes = a_out.get_axes_from_description(red1, allow_int=axis_desc_allow_int)
b_out = sources[1].output.copy()
assert not a_out.beam or not b_out.beam or a_out.beam == b_out.beam
b_reduce_axes = b_out.get_axes_from_description(red2)
b_reduce_axes = b_out.get_axes_from_description(red2, allow_int=axis_desc_allow_int)
assert a_reduce_axes and b_reduce_axes, "%s: sources %r, red1 %r, red2 %r" % (name, sources, red1, red2)
a_var_axes = a_out.get_axes_from_description(var1)
b_var_axes = b_out.get_axes_from_description(var2)
a_var_axes = a_out.get_axes_from_description(var1, allow_int=axis_desc_allow_int)
b_var_axes = b_out.get_axes_from_description(var2, allow_int=axis_desc_allow_int)
assert not set(a_reduce_axes).intersection(a_var_axes)
assert not set(b_reduce_axes).intersection(b_var_axes)
a_rem_axes = [i for i in range(a_out.batch_ndim) if i not in a_var_axes + a_reduce_axes]
Expand Down

0 comments on commit f7f6680

Please sign in to comment.