diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index c38a62d0e..40f48de97 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -2571,6 +2571,12 @@ def get_axes_from_description(self, axes, allow_int=True): return self.get_axes_from_description(list(axes)) elif axes.startswith("stag:"): # spatial tag axes = self.get_axis_by_tag_name(axes[len("stag:"):], spatial_only=True) + elif axes.startswith("stag-single:"): # spatial tag which possibly matches multiple spatial axes + # in this case, a name of form "stag-single:: is expected. + # idx is relative to the matching stags, i.e., it is the index among the list of spatial dims matching the name + _, idx_s, name = axes.split(":", 2) # stag-single:: + idx = int(idx_s) + axes = self.get_axes_by_tag_name(name, spatial_only=True)[idx] else: raise Exception("invalid axis mode %r" % axes) if isinstance(axes, int): @@ -2601,18 +2607,28 @@ def get_axis_from_description(self, axis, allow_int=True): assert len(axes) == 1, "%r: %r is not a unique axis but %r" % (self, axis, axes) return axes[0] - def get_axis_by_tag_name(self, name, spatial_only=False): + def get_axes_by_tag_name(self, name, spatial_only=False): """ :param str name: the tag name, or part of it (must be unique, and must exist) :param bool spatial_only: - :rtype: int + :rtype: list[int] """ dim_tags = self.get_batch_shape_dim_tags() matching_dim_tags = [(axis, tag) for axis, tag in enumerate(dim_tags) if name.lower() in tag.description.lower()] if spatial_only: matching_dim_tags = [(axis, tag) for axis, tag in matching_dim_tags if tag.kind == DimensionTag.Types.Spatial] - assert len(matching_dim_tags) == 1, "%r: tag name %r is not unique in dim tags %r" % (self, name, dim_tags) - return matching_dim_tags[0][0] + return [ax for ax, _ in matching_dim_tags] + + def get_axis_by_tag_name(self, name, spatial_only=False): + """ + :param str name: the tag name, or part of it (must be unique, and must exist) + :param bool spatial_only: + :rtype: int + """ + matching_dim_tags = self.get_axes_by_tag_name(name, spatial_only) + assert len(matching_dim_tags) == 1, "%r: tag name %r is not unique in dim tags %r" % ( + self, name, self.get_batch_shape_dim_tags()) + return matching_dim_tags[0] def get_batch_axis_excluding_batch(self, axis): """