MaskedComputationLayer, out spatial dim with automatic unmasking info
albertz committed Mar 8, 2022
1 parent 2d3e641 commit 11aa978
Showing 1 changed file with 31 additions and 9 deletions.
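
In short: the layer's output spatial dim tag now carries the information needed to automatically unmask the output later. The dim tag is created with a derived_from_op of kind "mask" (with "unmask_type": "left"), and the actual mask tensor is filled into that op once the mask layer is resolved. For context, a minimal, hypothetical net-dict fragment using the layer (layer names like "input" and "output_emit" are placeholders, not part of this commit):

"masked": {
  "class": "masked_computation", "from": "input", "mask": "output_emit",
  "unit": {"class": "rec", "unit": "nativelstm2", "n_out": 512},
},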
returnn/tf/layers/rec.py: 31 additions, 9 deletions
@@ -7773,7 +7773,7 @@ def transform_config_dict(cls, d, network, get_layer):
     super(MaskedComputationLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
     # Just call it for dep resolution.
     parent_layer_cache = d.setdefault("_parent_layer_cache", {})
-    d["_layer_class"], d["_layer_desc"] = cls._create_template(
+    d["_layer_class"], d["_layer_desc"], out_spatial_dim = cls._create_template(
       name=d["_name"], network=network, sources=d["sources"],
       masked_from=masked_from,
       unit=d["unit"],
@@ -7783,7 +7783,16 @@ def transform_config_dict(cls, d, network, get_layer):
       # We explicitly do not want to have these as deps.
       d["mask"] = None
     else:
-      d["mask"] = get_layer(d["mask"])
+      mask_layer = get_layer(d["mask"])
+      d["mask"] = mask_layer
+      if out_spatial_dim:
+        out_spatial_dim = out_spatial_dim.get_same_base()
+        assert out_spatial_dim.derived_from_op and out_spatial_dim.derived_from_op.kind == "mask"
+        assert out_spatial_dim.derived_from_op.attribs["unmask_type"] == "left"
+        out_dim_mask = out_spatial_dim.derived_from_op.attribs.get("mask")
+        assert not out_dim_mask or isinstance(out_dim_mask, Data)
+        if not out_dim_mask or out_dim_mask.placeholder is None:
+          out_spatial_dim.derived_from_op.attribs["mask"] = mask_layer.output

   # noinspection PyUnusedLocal
   @classmethod
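
The key point in this hunk: the "mask" attrib of the dim tag's derived_from_op starts out as None when the template is created (see the later hunk) and is only filled in here, once get_layer has resolved the actual mask layer. A self-contained sketch of that deferred-fill pattern, using stand-in classes rather than RETURNN's real Dim/Dim.Op:

class Op:
  def __init__(self, kind, inputs, attribs):
    self.kind, self.inputs, self.attribs = kind, inputs, attribs

class DimTag:
  def __init__(self, description, derived_from_op=None):
    self.description = description
    self.derived_from_op = derived_from_op

# Phase 1, template construction: the mask tensor is not known yet.
time_tag = DimTag("data:time")
out_tag = DimTag(
  "masked:time",
  derived_from_op=Op(kind="mask", inputs=[time_tag],
                     attribs={"mask": None, "unmask_type": "left"}))

# Phase 2, config transformation: the mask layer is now resolved.
mask_output = object()  # stands in for mask_layer.output (a Data in RETURNN)
if out_tag.derived_from_op.attribs.get("mask") is None:
  out_tag.derived_from_op.attribs["mask"] = mask_output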
@@ -7799,7 +7808,7 @@ def _create_template(cls, name, network, sources, masked_from, unit,
     :param Dim|None out_spatial_dim:
     :param (str)->LayerBase get_layer:
     :param dict[str,LayerBase]|None parent_layer_cache:
-    :return: layer_class, layer_desc
+    :return: layer_class, layer_desc, out_spatial_dim
     """
     from returnn.tf.network import get_layer_class
     from .base import WrappedInternalLayer
@@ -7809,6 +7818,8 @@ def _create_template(cls, name, network, sources, masked_from, unit,
     if masked_from:
       if out_spatial_dim:
         masked_from.output.get_time_dim_tag().declare_same_as(out_spatial_dim)
+      else:
+        out_spatial_dim = masked_from.output.get_time_dim_tag()
       if network.is_inside_rec_layer(inside_loop=True):
         source_data = (
           masked_from.output
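
This makes the masked_from case symmetric: an explicitly passed out_spatial_dim is unified with the actual time dim tag via declare_same_as, and if none is passed, the layer adopts the time dim tag of masked_from. A hedged usage sketch (assuming RETURNN's SpatialDim helper; the layer names are placeholders):

from returnn.tf.util.data import SpatialDim

masked_time = SpatialDim("masked-time")
net_dict_fragment = {
  "masked": {
    "class": "masked_computation", "masked_from": "base:encoder",
    "out_spatial_dim": masked_time,
    "unit": {"class": "copy", "from": "data"}},
  # Other layers can refer to masked_time; declare_same_as above makes it
  # identical to the time dim tag of the masked_from source.
}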
@@ -7828,13 +7839,24 @@ def _create_template(cls, name, network, sources, masked_from, unit,
     if not network.is_inside_rec_layer() and source:
       source_data = source.output.copy_template().copy_as_time_major()
       # Create own time dim tag, to make sure we have some own custom.
-      if not out_spatial_dim:
-        out_spatial_dim = Dim(
-          kind=Dim.Types.Spatial, description="%s:masked:time" % name,
-          derived_from_tag=source_data.get_time_dim_tag(), auto_generated=True)
+      source_time_dim_tag = source_data.get_time_dim_tag()
+      out_spatial_dim_ = Dim(
+        kind=Dim.Types.Spatial, description="%s:masked:time" % name,
+        derived_from_tag=source_time_dim_tag,
+        derived_from_op=Dim.Op(
+          kind="mask",
+          inputs=[source_time_dim_tag],
+          attribs={
+            "mask": None,  # will be set later
+            "unmask_type": "left"}),
+        auto_generated=True)
+      if out_spatial_dim:
+        out_spatial_dim_.declare_same_as(out_spatial_dim)
+      else:
+        out_spatial_dim = out_spatial_dim_
       source_data = source_data.copy_template_replace_dim_tag(
         axis=0,
-        new_dim_tag=out_spatial_dim)
+        new_dim_tag=out_spatial_dim_)
       source = WrappedInternalLayer(
         base_layer=source, network=source.network, name=source.name,
         output=source_data)
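
With this, the dim tag records how it was derived: a Dim.Op of kind "mask" over the source time dim, with the mask itself and "unmask_type": "left" as attribs. To illustrate what left-unmasking plausibly means here (an assumption based on the attrib name, not RETURNN's actual UnmaskLayer code): expand the compressed output back to the full time axis by holding the most recently computed frame.

import numpy as np

def unmask_left(masked_values, mask):
  """masked_values: [T_masked, ...]; mask: bool [T]; returns [T, ...]."""
  idx = np.cumsum(mask.astype(np.int64)) - 1  # per-frame index into masked_values
  idx = np.maximum(idx, 0)  # frames before the first True reuse frame 0 here
  return masked_values[idx]

mask = np.array([False, True, False, False, True, False])
vals = np.array([[1.0], [2.0]])  # computed only on the 2 masked frames
print(unmask_left(vals, mask))  # frames 0..3 hold 1.0, frames 4..5 hold 2.0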
@@ -7879,7 +7901,7 @@ def sub_get_layer(sub_layer_name):
     layer_class.transform_config_dict(layer_desc, network=extra_net, get_layer=sub_get_layer)
     # noinspection PyProtectedMember
     layer_desc = extra_net._create_layer_layer_desc(name=name, layer_desc=layer_desc)
-    return layer_class, layer_desc
+    return layer_class, layer_desc, out_spatial_dim

   @classmethod
   def get_out_data_from_opts(cls, network, out_spatial_dim=None, **kwargs):
