Skip to content

Commit

Permalink
Data: add dim tag format stag-single:<idx>:<name> (#490)
Browse files Browse the repository at this point in the history
Co-authored-by: Albert Zeyer <[email protected]>
  • Loading branch information
vieting and albertz authored Apr 27, 2021
1 parent 07ed4a7 commit df3e1d0
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2571,6 +2571,12 @@ def get_axes_from_description(self, axes, allow_int=True):
return self.get_axes_from_description(list(axes))
elif axes.startswith("stag:"): # spatial tag
axes = self.get_axis_by_tag_name(axes[len("stag:"):], spatial_only=True)
elif axes.startswith("stag-single:"): # spatial tag which possibly matches multiple spatial axes
# in this case, a name of form "stag-single:<idx>:<name> is expected.
# idx is relative to the matching stags, i.e., it is the index among the list of spatial dims matching the name
_, idx_s, name = axes.split(":", 2) # stag-single:<idx>:<name>
idx = int(idx_s)
axes = self.get_axes_by_tag_name(name, spatial_only=True)[idx]
else:
raise Exception("invalid axis mode %r" % axes)
if isinstance(axes, int):
Expand Down Expand Up @@ -2601,18 +2607,28 @@ def get_axis_from_description(self, axis, allow_int=True):
assert len(axes) == 1, "%r: %r is not a unique axis but %r" % (self, axis, axes)
return axes[0]

def get_axis_by_tag_name(self, name, spatial_only=False):
def get_axes_by_tag_name(self, name, spatial_only=False):
"""
:param str name: the tag name, or part of it (must be unique, and must exist)
:param bool spatial_only:
:rtype: int
:rtype: list[int]
"""
dim_tags = self.get_batch_shape_dim_tags()
matching_dim_tags = [(axis, tag) for axis, tag in enumerate(dim_tags) if name.lower() in tag.description.lower()]
if spatial_only:
matching_dim_tags = [(axis, tag) for axis, tag in matching_dim_tags if tag.kind == DimensionTag.Types.Spatial]
assert len(matching_dim_tags) == 1, "%r: tag name %r is not unique in dim tags %r" % (self, name, dim_tags)
return matching_dim_tags[0][0]
return [ax for ax, _ in matching_dim_tags]

def get_axis_by_tag_name(self, name, spatial_only=False):
"""
:param str name: the tag name, or part of it (must be unique, and must exist)
:param bool spatial_only:
:rtype: int
"""
matching_dim_tags = self.get_axes_by_tag_name(name, spatial_only)
assert len(matching_dim_tags) == 1, "%r: tag name %r is not unique in dim tags %r" % (
self, name, self.get_batch_shape_dim_tags())
return matching_dim_tags[0]

def get_batch_axis_excluding_batch(self, axis):
"""
Expand Down

0 comments on commit df3e1d0

Please sign in to comment.