Skip to content

Commit

Permalink
MaskedComputationLayer, better dim logic, unmask if rec optimized
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Mar 8, 2022
1 parent 3f938d5 commit 97c612b
Showing 1 changed file with 104 additions and 57 deletions.
161 changes: 104 additions & 57 deletions returnn/tf/layers/rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7578,19 +7578,20 @@ class MaskedComputationLayer(LayerBase):

def __init__(self, mask, unit, masked_from,
_layer_class, _layer_desc,
in_spatial_dim=None,
out_spatial_dim=None,
_parent_layer_cache=None,
**kwargs):
"""
:param LayerBase|None mask:
:param dict[str] unit:
:param LayerBase|None masked_from:
:param Dim|None out_spatial_dim:
:param Dim|None in_spatial_dim:
:param Dim|None out_spatial_dim: the masked dim
:param type[LayerBase] _layer_class:
:param dict[str] _layer_desc:
:param dict[str,LayerBase]|None _parent_layer_cache:
"""
out_spatial_dim # noqa # handled in transform_config_dict
from returnn.tf.network import get_layer_class
from .base import WrappedInternalLayer
from returnn.tf.util.basic import where_bc, get_shape, nd_indices
Expand All @@ -7611,16 +7612,15 @@ def __init__(self, mask, unit, masked_from,
else:
assert mask.output.have_time_axis() and mask.output.shape == (None,) and mask.output.dtype == "bool", (
"%s: invalid mask %s (outside rec loop)" % (self, mask))
assert in_spatial_dim and out_spatial_dim
mask_data = mask.output.copy_as_time_major()
mask_t = where_bc(mask_data.placeholder, mask_data.get_sequence_mask(), tf.convert_to_tensor(False))
idxs = tf.cumsum(tf.cast(mask_t, tf.int32), axis=0) # [T,B] -> idx in T' + 1
if masked_from:
new_size = masked_from.output.get_sequence_lengths()
new_dim_tag = masked_from.output.get_time_dim_tag()
else:
new_size = idxs[-1] # [B]
new_dim_tag = self.output.get_time_dim_tag()
new_dim_tag.dyn_size = new_size
out_spatial_dim.dyn_size = new_size
new_time = tf.reduce_max(new_size) # T'
idxs = where_bc(mask_t, idxs - 1, new_time)

Expand All @@ -7636,12 +7636,17 @@ def get_masked_layer(source):
# We can just leave it as-is. The state will handled below.
return source
else:
source_data = source.output.copy_as_time_major()
assert in_spatial_dim
if in_spatial_dim not in source.output.dim_tags:
return source
axis = source.output.get_axis_from_description(in_spatial_dim)
source_data = source.output.copy_move_axis(old_axis=axis, new_axis=0) # time-major
source_data.time_dim_axis = 0
assert source_data.is_same_time_dim(mask_data), "%s mask and source time dim do not match" % self
tmp_shape = get_shape(source_data.placeholder)
tmp_shape[0] = new_time + 1 # one more for the padded data
res = tf.scatter_nd(nd_indices(idxs, batch_axis=1), source_data.placeholder, shape=tmp_shape)
res_data = source_data.copy_template().copy_template_replace_dim_tag(axis=0, new_dim_tag=new_dim_tag)
res_data = source_data.copy_template().copy_template_replace_dim_tag(axis=0, new_dim_tag=out_spatial_dim)
res_data.placeholder = res[:new_time]
res_data.beam = SearchBeam.get_combined_beam(res_data.beam, mask.output.beam)
layer_desc = dict(base_layer=source, network=source.network, name=source.name, output=res_data)
Expand All @@ -7661,7 +7666,7 @@ def get_masked_layer(source):
source.sources.append(masked_from) # add dep
sub_layers["data"] = source

else:
elif len(self.sources) >= 1:
assert len(self.sources) == 1
sub_layers["data"] = get_masked_layer(self.sources[0])

Expand Down Expand Up @@ -7698,7 +7703,19 @@ def sub_get_layer(sub_layer_name):
assert isinstance(self.sub_layer, LayerBase)
self.sub_layer.post_init(layer_desc)
self.sub_layer.output.sanity_check()
self.output = self.sub_layer.output.copy(name="%s_output" % self.name)
inside_rec_time_dim = self.network.get_inside_rec_time_dim(inside_loop=True)
over_rec_time_dim = self.network.get_inside_rec_time_dim(inside_loop=False)
self._unmask_layer = None
# In case we are in rec layer and optimize out of loop and mask over this axis, and don't have masked_from:
if not masked_from and over_rec_time_dim == in_spatial_dim and over_rec_time_dim and not inside_rec_time_dim:
# We will automatically unmask again (https://github.com/rwth-i6/returnn/pull/976).
assert self.mask
self._unmask_layer = UnmaskLayer(
name="%s(internal-unmask)" % self.name, mask=self.mask, sources=[self.sub_layer], network=self.network,
output=self.output.copy_template())
self.output = self._unmask_layer.output.copy(name="%s_output" % self.name)
else:
self.output = self.sub_layer.output.copy(name="%s_output" % self.name)
self.rec_vars_outputs = self.sub_layer.rec_vars_outputs.copy()
self.params = self.sub_layer.params

Expand Down Expand Up @@ -7773,10 +7790,11 @@ def transform_config_dict(cls, d, network, get_layer):
super(MaskedComputationLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
# Just call it for dep resolution.
parent_layer_cache = d.setdefault("_parent_layer_cache", {})
d["_layer_class"], d["_layer_desc"] = cls._create_template(
d["_layer_class"], d["_layer_desc"], d["in_spatial_dim"], d["out_spatial_dim"] = cls._create_template(
name=d["_name"], network=network, sources=d["sources"],
masked_from=masked_from,
unit=d["unit"],
in_spatial_dim=d.get("in_spatial_dim", None),
out_spatial_dim=d.get("out_spatial_dim", None),
get_layer=get_layer, _parent_layer_cache=parent_layer_cache)
if masked_from and not parent_layer_cache:
Expand All @@ -7788,82 +7806,99 @@ def transform_config_dict(cls, d, network, get_layer):
# noinspection PyUnusedLocal
@classmethod
def _create_template(cls, name, network, sources, masked_from, unit,
out_spatial_dim=None,
in_spatial_dim=None, out_spatial_dim=None,
get_layer=None, _parent_layer_cache=None, **kwargs):
"""
:param str name:
:param returnn.tf.network.TFNetwork network:
:param list[LayerBase] sources:
:param LayerBase masked_from:
:param dict[str] unit:
:param Dim|None out_spatial_dim:
:param Dim|None in_spatial_dim:
:param Dim|None out_spatial_dim: the masked dim
:param (str)->LayerBase get_layer:
:param dict[str,LayerBase]|None parent_layer_cache:
:return: layer_class, layer_desc
:return: layer_class, layer_desc, in_spatial_dim, out_spatial_dim
"""
from returnn.tf.network import get_layer_class
from .base import WrappedInternalLayer
if not get_layer:
get_layer = network.get_layer
# We don't care about the right masked input here, but just about deriving the right output shape.
if masked_from:
if out_spatial_dim:
masked_from.output.get_time_dim_tag().declare_same_as(out_spatial_dim)
if network.is_inside_rec_layer(inside_loop=True):
source_data = (
masked_from.output
.copy_template_excluding_time_dim(
name="%s_%s_masked_input_frame" % (masked_from.output.name, name))
.copy_template_set_ctx(network.get_control_flow_ctx()))
else:
source_data = masked_from.output.copy_template(
name="%s_%s_masked_input" % (masked_from.output.name, name))
source_data.available_for_inference = True # we would make sure that this works at inference
source = WrappedInternalLayer(
base_layer=masked_from, network=masked_from.network, name=masked_from.name, output=source_data)
else:
assert len(sources) == 1
source, = sources
assert isinstance(source, LayerBase) or not source
if not network.is_inside_rec_layer() and source:
source_data = source.output.copy_template().copy_as_time_major()
# Create own time dim tag, to make sure we have some own custom.
if not out_spatial_dim:
out_spatial_dim = Dim(
kind=Dim.Types.Spatial, description="%s:masked:time" % name,
derived_from_tag=source_data.get_time_dim_tag(), auto_generated=True)
source_data = source_data.copy_template_replace_dim_tag(
axis=0,
new_dim_tag=out_spatial_dim)
source = WrappedInternalLayer(
base_layer=source, network=source.network, name=source.name,
output=source_data)

over_rec_time_dim = network.get_inside_rec_time_dim(inside_loop=False)
if over_rec_time_dim and not in_spatial_dim:
in_spatial_dim = over_rec_time_dim

class _Locals: # Python 2, and modify this maybe via sub_get_layer
in_spatial_dim_ = in_spatial_dim
out_spatial_dim_ = out_spatial_dim

_layer_cache = {}

def sub_get_layer(sub_layer_name):
"""
:param str sub_layer_name:
:rtype: LayerBase
"""
if sub_layer_name == "data":
if sub_layer_name in _layer_cache:
return _layer_cache[sub_layer_name]

if sub_layer_name == "data" and masked_from:
if _Locals.out_spatial_dim_:
if _Locals.out_spatial_dim_ not in masked_from.output.dim_tags:
masked_from.output.get_time_dim_tag().declare_same_as(_Locals.out_spatial_dim_)
else:
_Locals.out_spatial_dim_ = masked_from.output.get_time_dim_tag()
if network.is_inside_rec_layer():
source_data = (
masked_from.output
.copy_template_excluding_time_dim(
name="%s_%s_masked_input_frame" % (masked_from.output.name, name))
.copy_template_set_ctx(network.get_control_flow_ctx()))
else:
source_data = masked_from.output.copy_template(
name="%s_%s_masked_input" % (masked_from.output.name, name))
source_data.available_for_inference = True # we would make sure that this works at inference
source = WrappedInternalLayer(
base_layer=masked_from, network=masked_from.network, name=masked_from.name, output=source_data)
_layer_cache[sub_layer_name] = source
return source
if _parent_layer_cache and sub_layer_name in _parent_layer_cache:

elif sub_layer_name == "data" and len(sources) >= 1:
assert len(sources) == 1
layer = sources[0]
elif _parent_layer_cache and sub_layer_name in _parent_layer_cache:
layer = _parent_layer_cache[sub_layer_name]
else:
assert not sub_layer_name.startswith(extra_net.extra_name_prefix + ":")
layer = get_layer(sub_layer_name)
if not layer:
return layer
return None
if _parent_layer_cache is not None:
_parent_layer_cache[sub_layer_name] = layer
if not network.is_inside_rec_layer():
# noinspection PyShadowingNames
source_data_ = layer.output.copy_template().copy_as_time_major()
source_data_ = source_data_.copy_template_replace_dim_tag(axis=0, new_dim_tag=source.output.get_time_dim_tag())
source_data = layer.output.copy_template()
if not _Locals.in_spatial_dim_:
_Locals.in_spatial_dim_ = source_data.get_time_dim_tag()

if _Locals.in_spatial_dim_ in source_data.dim_tags:
axis = source_data.get_axis_from_description(_Locals.in_spatial_dim_)
source_data = source_data.copy_move_axis(old_axis=axis, new_axis=0) # time-major
# Create own time dim tag, to make sure we have some own custom.
if not _Locals.out_spatial_dim_:
_Locals.out_spatial_dim_ = Dim(
kind=Dim.Types.Spatial, description="%s:masked:time" % name,
derived_from_tag=_Locals.in_spatial_dim_, auto_generated=True)
source_data = source_data.copy_template_replace_dim_tag(axis=0, new_dim_tag=_Locals.out_spatial_dim_)
layer = WrappedInternalLayer(
base_layer=layer, network=layer.network, name=layer.name,
output=source_data_)
output=source_data)
_layer_cache[sub_layer_name] = layer
return layer

if sources or masked_from:
# Make sure to resolve this first, to maybe declare in_spatial_dim/out_spatial_dim.
sub_get_layer("data")

extra_net = network.make_extra_net(
prefix_name="extra._internal_template.masked(%s)" % name,
# There should be no need by the base get_layer to get back to us.
Expand All @@ -7879,22 +7914,34 @@ def sub_get_layer(sub_layer_name):
layer_class.transform_config_dict(layer_desc, network=extra_net, get_layer=sub_get_layer)
# noinspection PyProtectedMember
layer_desc = extra_net._create_layer_layer_desc(name=name, layer_desc=layer_desc)
return layer_class, layer_desc
return layer_class, layer_desc, _Locals.in_spatial_dim_, _Locals.out_spatial_dim_

@classmethod
def get_out_data_from_opts(cls, network, out_spatial_dim=None, **kwargs):
def get_out_data_from_opts(cls, network, masked_from=None, in_spatial_dim=None, out_spatial_dim=None, **kwargs):
"""
:param returnn.tf.network.TFNetwork network:
:param LayerBase|Noen masked_from:
:param Dim|None in_spatial_dim:
:param Dim|None out_spatial_dim:
:rtype: Data
"""
out_spatial_dim # noqa # handled in transform_config_dict
layer_class, layer_desc = kwargs["_layer_class"], kwargs["_layer_desc"]
assert issubclass(layer_class, LayerBase)
output = layer_class.get_out_data_from_opts(**layer_desc)
assert isinstance(output, Data)
output.sanity_check()
if out_spatial_dim and out_spatial_dim in output.dim_tags:
axis = output.get_axis_from_description(out_spatial_dim)
output = output.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_spatial_dim)
inside_rec_time_dim = network.get_inside_rec_time_dim(inside_loop=True)
over_rec_time_dim = network.get_inside_rec_time_dim(inside_loop=False)
if network.is_inside_rec_layer():
output = output.copy_as_batch_major()
elif not masked_from and in_spatial_dim == over_rec_time_dim and over_rec_time_dim and not inside_rec_time_dim:
# We will automatically unmask again (https://github.com/rwth-i6/returnn/pull/976).
assert out_spatial_dim and out_spatial_dim in output.dim_tags
axis = output.get_axis_from_description(out_spatial_dim)
output = output.copy_template_replace_dim_tag(axis=axis, new_dim_tag=in_spatial_dim)
return output

def get_constraints_value(self):
Expand Down

0 comments on commit 97c612b

Please sign in to comment.