Extend interface of SubsampleRagged. #900

Closed
wants to merge 2 commits into from
129 changes: 115 additions & 14 deletions k2/csrc/ragged_ops.h
@@ -699,14 +699,32 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false,
int32_t max_num_elements = 2000);

/*
Return ragged shape with only a subset of the bottom-level elements kept.
Require renumbering.NumOldElems() == src.NumElements(). Note: all
dimensions and tot-sizes preceding the final axis will remain the same, which
might give rise to empty lists.
Return ragged shape with only a subset of the elements or sub-lists
on the specified axis kept. (This is not regular sampling; it is
irregular subsampling in which specified elements are kept.)

@param [in] src The ragged shape that we are subsampling
@param [in] renumbering The renumbering object that dictates
which elements of `src` we keep; we require
renumbering.NumOldElems() == src.TotSize(axis2)
where axis2 = (axis < 0 ? src.NumAxes() + axis : axis).
@param [in] axis The axis to subsample; if negative, will be
interpreted as an offset from src.NumAxes().
@param [out] elems_renumbering If supplied, this function will
output to this location a renumbering object that
dictates how the elements of a ragged tensor
with shape `src` would be renumbered.
@return Returns the subsampled shape. All dimensions and tot-sizes
preceding the subsampled axis will remain the same, which might give
rise to empty lists on those axes; these can be removed if
necessary with RemoveEmptyLists().

Notice the other version of this function below.
*/
RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering);
RaggedShape SubsampleRaggedShape(RaggedShape &src,
Renumbering &renumbering,
int32_t axis = -1,
Renumbering *elems_renumbering = nullptr);
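
The effect described in the comment above (rows keep their count, some may become empty) can be illustrated with a minimal standalone sketch. It does not use k2: `ToyShape` and `SubsampleLastAxis` are hypothetical names, the shape is restricted to two axes, only the last axis is subsampled, and a plain vector of keep flags stands in for the `Renumbering` object.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical toy 2-axis ragged shape (not a k2 type): row i owns the
// element range [row_splits[i], row_splits[i + 1]).
struct ToyShape {
  std::vector<int32_t> row_splits;
};

// Keeps only the elements whose flag is non-zero. keep.size() must equal the
// total number of elements, mirroring the NumOldElems() requirement.
ToyShape SubsampleLastAxis(const ToyShape &src, const std::vector<char> &keep) {
  assert(!src.row_splits.empty());
  assert(static_cast<int32_t>(keep.size()) == src.row_splits.back());
  ToyShape ans;
  ans.row_splits.reserve(src.row_splits.size());
  ans.row_splits.push_back(0);
  int32_t num_kept = 0;
  for (std::size_t row = 0; row + 1 < src.row_splits.size(); ++row) {
    for (int32_t i = src.row_splits[row]; i < src.row_splits[row + 1]; ++i)
      if (keep[i]) ++num_kept;
    // The number of rows is unchanged; a row may simply become empty.
    ans.row_splits.push_back(num_kept);
  }
  return ans;
}
```

For the shape of [ [a b c d] [e f] [ ] ] with keep flags 1 1 1 0 1 0, the sketch yields row splits 0 3 4 4, i.e. the shape of [ [a b c] [e] [ ] ].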


/*
@@ -804,18 +822,102 @@ RaggedShape RemoveEmptyListsAxis0(RaggedShape &src_shape,
RaggedShape RenumberAxis0Simple(RaggedShape &src_shape,
Renumbering &renumbering);



/*
Return ragged array with only a subset of the bottom-level elements kept.
Require renumbering.NumOldElems() == src.NumElements(). Note: all
dimensions and tot-sizes preceding the final axis will remain the same, which
might give rise to empty lists.
Return ragged array with only a subset of the elements or sub-lists
on the specified axis kept. (This is not regular sampling; it is
irregular subsampling in which specified elements are kept.)

@param [in] src The ragged array that we are subsampling
@param [in] renumbering The renumbering object that dictates
which elements of `src` we keep; we require
renumbering.NumOldElems() == src.TotSize(axis2)
where axis2 = (axis < 0 ? src.NumAxes() + axis : axis).
@param [in] axis The axis to subsample; if negative, will be
interpreted as an offset from src.NumAxes().
@param [out] elems_renumbering If supplied, this function will
output to this location a renumbering object that
dictates how the elements of a ragged tensor
with shape `src` would be renumbered.
@return Returns the subsampled array. All dimensions and tot-sizes
preceding the subsampled axis will remain the same, which might give
rise to empty lists on those axes; these can be removed if
necessary with RemoveEmptyLists().

Notice the other version of this function above.
*/
template <typename T>
Ragged<T> SubsampleRagged(Ragged<T> &src, Renumbering &renumbering) {
return Ragged<T>(SubsampleRaggedShape(src.shape, renumbering),
src.values[renumbering.New2Old()]);
Ragged<T> SubsampleRagged(Ragged<T> &src, Renumbering &renumbering,
int32_t axis = -1,
Renumbering *elems_renumbering = nullptr) {
Renumbering tmp;
if (elems_renumbering == nullptr)
elems_renumbering = &tmp;
RaggedShape shape = SubsampleRaggedShape(src.shape, renumbering,
axis, elems_renumbering);
return Ragged<T>(shape, src.values[elems_renumbering->New2Old()]);
}
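
The values of the result are gathered through the new-to-old map of the element renumbering. A minimal standalone sketch of that gather step follows. It is not k2 code: `New2OldFromKeep` and `GatherValues` are hypothetical helpers, and plain vectors stand in for the k2 array and renumbering types.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a renumbering's new-to-old map: entry j is the
// old index of the j-th kept element.
std::vector<int32_t> New2OldFromKeep(const std::vector<char> &keep) {
  std::vector<int32_t> new2old;
  for (std::size_t i = 0; i < keep.size(); ++i)
    if (keep[i]) new2old.push_back(static_cast<int32_t>(i));
  return new2old;
}

// Gathers values[new2old[j]] for every j, which is the role played by
// indexing src.values with the new-to-old map.
template <typename T>
std::vector<T> GatherValues(const std::vector<T> &values,
                            const std::vector<int32_t> &new2old) {
  std::vector<T> ans;
  ans.reserve(new2old.size());
  for (int32_t old_index : new2old) {
    assert(old_index >= 0 && old_index < static_cast<int32_t>(values.size()));
    ans.push_back(values[old_index]);
  }
  return ans;
}
```

With keep flags 1 1 1 0 1 0, the map is 0 1 2 4, so the values 0 -1 -2 -3 -10 -20 become 0 -1 -2 -10.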


/*
This function creates a Renumbering object that can be used to obtain subsets
of ragged arrays via SubsampleRaggedShape(). It implements beam pruning as
used in pruned Viterbi search and similar algorithms, where there is both a
beam and a max-active (`max_elems`) constraint. T will probably be float or
double, interpreted in a "positive-is-better" sense, i.e. as scores.

@param [in] src The ragged object to be subsampled.
@param [in] axis The axis to be subsampled, must satisfy
0 <= axis < src.NumAxes(). The axis before `axis`, if axis > 0,
will be interpreted as a "batch" axis.
@param [in] beam The main pruning beam. The sub-lists of elements on axis
`axis` will be removed if their maximum element (or the element
itself, if axis + 1 == src.NumAxes()) is less than
this_best_elem - beam, where this_best_elem
is the maximum element taken over axis `axis-1` (or over the
entire array, if axis == 0). Think of axis `axis-1`, if
present, as the "batch" axis, and axis `axis` as the axis that we
actually remove elements or sub-lists on. Empty sub-lists on axis
`axis` will always be pruned, as their score would be treated
as -infinity.
@param [in] max_elems If max_elems > 0, it is the maximum number of sub-lists
or elements that are allowed within any sub-list on axis `axis-1`
(or the maximum number of top-level sub-lists after subsampling,
if axis == 0). We keep the best ones, but behavior in case of ties is
undefined (TODO: check whether SortSublists() is a stable sort, and
change this doc if it is). If max_elems <= 0, there is no such
constraint.
@return Returns the renumbering object to be used to actually
prune/subsample the specified axis.

Example:
PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 1, 5.0, 3)
would create a Renumbering object that would prune the
ragged tensor to [ [0 -1 -2], [ -10 ], [ ] ]

PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 0, 5.0, 0)
would create a Renumbering object that would prune the
ragged tensor to [ [0 -1 -2 -3] ]


TODO: don't forget to change subsample->subset when we rename
SubsampleRaggedShape().
IMPLEMENTATION NOTES (please delete later):
- We might want/need to treat certain cases specially, e.g.
the case when axis == src.NumAxes() - 1, and/or when
axis == 0.
- If `max_elems` is <= 0, we might want to choose a different
implementation, e.g. using max on the sub-lists rather
than sorting.
*/
template <typename T>
Renumbering PruneRagged(const Ragged<T> &src,
int32_t axis,
T beam,
int32_t max_elems);
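
The pruning rule documented above can be sketched without k2 for the simplest case: a 2-axis ragged array pruned on its last axis (axis == 1). `PruneLastAxis` is a hypothetical name, it returns keep flags rather than a `Renumbering`, and it uses a stable sort, so ties keep the earlier element; that tie-breaking is an assumption of this sketch, since the comment above leaves it undefined.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Returns one keep flag per element. Within each row, an element is kept if
// its score is at least (row maximum - beam) and, when max_elems > 0, it is
// among the max_elems best of that row.
template <typename T>
std::vector<char> PruneLastAxis(const std::vector<int32_t> &row_splits,
                                const std::vector<T> &values,
                                T beam, int32_t max_elems) {
  assert(!row_splits.empty());
  assert(static_cast<int32_t>(values.size()) == row_splits.back());
  std::vector<char> keep(values.size(), 0);
  for (std::size_t row = 0; row + 1 < row_splits.size(); ++row) {
    int32_t begin = row_splits[row], end = row_splits[row + 1];
    if (begin == end) continue;  // empty row: nothing to keep
    // Indexes of this row's elements, sorted from best to worst score.
    std::vector<int32_t> idx(end - begin);
    std::iota(idx.begin(), idx.end(), begin);
    std::stable_sort(idx.begin(), idx.end(), [&values](int32_t a, int32_t b) {
      return values[a] > values[b];
    });
    T cutoff = values[idx[0]] - beam;
    int32_t num_kept = 0;
    for (int32_t i : idx) {
      if (values[i] < cutoff) break;                      // outside the beam
      if (max_elems > 0 && num_kept >= max_elems) break;  // max-active limit
      keep[i] = 1;
      ++num_kept;
    }
  }
  return keep;
}
```

On the first example from the comment (row splits 0 4 6 6, values 0 -1 -2 -3 -10 -20, beam 5.0, max_elems 3) the flags are 1 1 1 0 1 0, which matches the pruned result [ [0 -1 -2], [ -10 ], [ ] ]. Sorting each row is the simplest way to honor `max_elems`; when `max_elems` <= 0 a per-row maximum would be enough, as the implementation notes suggest.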


/*
Stack a list of Ragged arrays to create a Ragged array with one more axis.
Similar to TF/PyTorch's Stack. The result will have Dim0 == num_srcs. All
@@ -855,8 +957,7 @@ Ragged<T> Stack(int32_t axis, int32_t num_srcs, Ragged<T> *src,
/*
Concatenate a list of Ragged<T> to form a single Ragged<T>.

@param [in] axis Axis to append them on. Currently
we only support axis == 0 or axis == 1.
@param [in] axis Axis to append them on.
Previous axes must
have the same shape, i.e. if axis == 1
then `src[i]->Dim0()` must all have the