From 6afb7e79913e56d85a0b55f0f16bd80e1d039656 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 12 Jan 2022 17:21:09 +0800 Subject: [PATCH 1/8] Extend interface of SubsampleRagged. --- k2/csrc/ragged_ops.h | 70 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 14 deletions(-) diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index 441434819..2076d8e4a 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -699,14 +699,32 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false, int32_t max_num_elements = 2000); /* - Return ragged shape with only a subset of the bottom-level elements kept. - Require renumbering.NumOldElems() == src.NumElements(). Note: all - dimensions and tot-sizes preceding the final axis will remain the same, which - might give rise to empty lists. + Return ragged shape with only a subset of the elements or sub-lists + on the specified axis kept. (This is not regular sampling, it is + irregular subsampling with specified elements kept). + + @param [in] src The ragged shape that we are subsampling + @param [in] renumbering The renumbering object that dictates + which elements of `src` we keep; we require + renumbering.NumOldElems() == src.TotSize(axis2) + where axis2 = (axis < 0 ? src.NumAxes() - axis : axis). + @param [in] axis The axis to subsample; if negative, will be + interpreted as an offset from src.NumAxes(). + @param [out] elems_renumbering If supplied, this function will + output to this location a renumbering object that + dictates how the elements of a ragged tensor + with shape `src` would be renumbered. + @return Returns the subsampled shape. All dimensions and tot-sizes + preceding the final axis will remain the same, which might give + rise to empty lists on those axes; these can be removed if + necessary with RemoveEmptyLists(). Notice the other version of this function below. 
*/ -RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering); +RaggedShape SubsampleRaggedShape(RaggedShape &src, + Renumbering &renumbering, + int32_t axis = -1, + Renumbering *elems_renumbering = nullptr); /* @@ -804,16 +822,41 @@ RaggedShape RemoveEmptyListsAxis0(RaggedShape &src_shape, RaggedShape RenumberAxis0Simple(RaggedShape &src_shape, Renumbering &renumbering); + + /* - Return ragged array with only a subset of the bottom-level elements kept. - Require renumbering.NumOldElems() == src.NumElements(). Note: all - dimensions and tot-sizes preceding the final axis will remain the same, which - might give rise to empty lists. + Return ragged array with only a subset of the elements or sub-lists + on the specified axis kept. (This is not regular sampling, it is + irregular subsampling with specified elements kept). + + @param [in] src The ragged shape that we are subsampling + @param [in] renumbering The renumbering object that dictates + which elements of `src` we keep; we require + renumbering.NumOldElems() == src.TotSize(axis2) + where axis2 = (axis < 0 ? src.NumAxes() - axis : axis). + @param [in] axis The axis to subsample; if negative, will be + interpreted as an offset from src.NumAxes(). + @param [out] elems_renumbering If supplied, this function will + output to this location a renumbering object that + dictates how the elements of a ragged tensor + with shape `src` would be renumbered. + @return Returns the subsampled shape. All dimensions and tot-sizes + preceding the final axis will remain the same, which might give + rise to empty lists on those axes; these can be removed if + necessary with RemoveEmptyLists(). + + Notice the other version of this function below. 
*/ template -Ragged SubsampleRagged(Ragged &src, Renumbering &renumbering) { - return Ragged(SubsampleRaggedShape(src.shape, renumbering), - src.values[renumbering.New2Old()]); +Ragged SubsampleRagged(Ragged &src, Renumbering &renumbering, + int32_t axis = -1, + Renumbering *elems_renumbering = nullptr) { + Renumbering tmp; + if (elems_renumbering == nullptr) + elems_renumbering = &tmp; + RaggedShape shape = SubsampleRaggedShape(src.shape, renumbering, + axis, elems_renumbering); + return Ragged(src.values[*elems_renumbering.New2Old()]); } /* @@ -855,8 +898,7 @@ Ragged Stack(int32_t axis, int32_t num_srcs, Ragged *src, /* Concatenate a list of Ragged to form a single Ragged. - @param [in] axis Axis to append them on. Currently - we only support axis == 0 or axis == 1. + @param [in] axis Axis to append them on. Previous axes must have the same shape, i.e. if axis == 1 then `src[i]->Dim0()` must all have the From 7341450bc021bda479d2ebbf4b8e5255e158c7fe Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 12 Jan 2022 17:54:41 +0800 Subject: [PATCH 2/8] Add interface for pruning ragged tensor. --- k2/csrc/ragged_ops.h | 59 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index 2076d8e4a..1c9905304 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -859,6 +859,65 @@ Ragged SubsampleRagged(Ragged &src, Renumbering &renumbering, return Ragged(src.values[*elems_renumbering.New2Old()]); } + +/* + This function creates a Renumbering object that can be used to obtain subsets + of ragged arrays via SubsampleRaggedShape(). It implements beam pruning as + used in pruned Viterbi search and similar algorithms, where there is both a + beam and a max-active (`max_elems`) constraint. T will probably be float or + double, interpreted as a "positive-is-better" sense, i.e. as scores. + + @param [in] src The ragged object to be subsampled. 
+ @param [in] axis The axis to be subsampled, must satisfy + 0 <= axis < src.NumAxes(). The axis before `axis`, if axis < 0, + will be interpreted as a "batch" axis. + @param [in] beam The main pruning beam. The sub-lists of elements on axis + `axis` will be removed if their maximum element (or the element + itself, if axis + 1 == src.NumAxes()) is less than + this_best_elem - beam, where this_best_elem + is the maximum element taken over axis `axis-1` (or over the + entire array, if axis == 0). Think of axis `axis-1`, if + present, as the "batch" axis, and axis `axis` as the axis that we + actually remove elements or sub-lists on. Empty sub-lists on axis + `axis` will always be pruned, as their score would be treated + as -infinity. + @param [in] max_elems If max_elems > 0, it is the maximum number of sub-lists + or elements that are allowed within any sub-list on axis `axis-1` + (or the maximum number of top-level sub-lists after subsampling, + if axis == 0). We keep the best ones, but behavior in case of ties is + undefined (TODO: check whether SortSublists() is a stable sort, and + change this doc if it is). If max_elems <= 0, there is no such + constraint. + @return Returns the renumbering object to be used to actually + prune/subsample the specified axis. + + Example: + PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 1, 5.0, 3) + would create a Renumbering object that would prune the + ragged tensor to [ [0 -1 -2], [ -10 ], [ ] ] + + PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 0, 5.0, 0) + would create a Renumbering object that would prune the + ragged tensor to [ [0 -1 -2 -3] ] + + + TODO: don't forget to change subsample->subset when we rename + SubsampleRaggedShape(). + IMPLEMENTATION NOTES (please delete later): + - We might want/need to treat certain cases specially, e.g. + the case when axis == src.NumAxes() - 1, and/or when + axis == 0. + - If `max_elems` is <= 0, we might want to choose a different + implementation, e.g. 
using max on the sub-lists rather + than sorting. + */ +template +Renumbering PruneRagged(const Ragged &src, + int32_t axis, + T beam, + int32_t max_elems); + + /* Stack a list of Ragged arrays to create a Ragged array with one more axis. Similar to TF/PyTorch's Stack. The result will have Dim0 == num_srcs. All From 280e2c259c45a4bce628ba25f49489f9898e05d2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 22 Jan 2022 22:25:14 +0800 Subject: [PATCH 3/8] Draft of new RNN-T decoding method --- k2/csrc/algorithms.h | 27 +++++- k2/csrc/array_of_ragged.h | 177 ++++++++++++++++++++++++++++++++++++++ k2/csrc/ragged.h | 26 ++++++ k2/csrc/ragged_ops.h | 20 ++--- 4 files changed, 238 insertions(+), 12 deletions(-) create mode 100644 k2/csrc/array_of_ragged.h diff --git a/k2/csrc/algorithms.h b/k2/csrc/algorithms.h index 439ac5ba3..8b1ba5cdb 100644 --- a/k2/csrc/algorithms.h +++ b/k2/csrc/algorithms.h @@ -121,17 +121,40 @@ class Renumbering { return new2old_; } + /* Return a mapping from new index to old index, with one extra element + containing the total number of kept elements if extra_element == true. + If Keep() can be interpreted as a tails vector, i.e. with 1 at the end + of sub-lists of elements, then New2Old(true) would corresponds to a + row-splits array and Old2New(false) would correspond to a row-ids + array. + */ + Array1 New2Old(bool extra_element) { + Array1 &new2old_part = New2Old(); + if (!extra_element) { + return new2old_part; + } else { + // This is a little perverse, using low-level interfaces to increase the + // dimension of the array; but we know it does have one more element. + // Because we normally use New2Old() with no arg (equivalent to false), + // the overloaded version of this function returns a reference for + // efficiency. + return Array1(new2old_part.Dim() + 1, + new2old_part.GetRegion(), 0); + } + } + /* Return a mapping from old index to new index. 
This is created on demand (must only be called after the Keep() array has been populated). @param [in] extra_element If true, will return the array of size NumOldElems() + 1, which includes one more element; otherwise it will return an array of size NumOldElems(). + + + @return Returns an array mapping the old indexes to the new indexes. This array is just the exclusive sum of Keep(). It gives the mapping for indexes that are kept; element i is kept if `Old2New()[i+1] > Old2New()[i]`. - - @return Returns an array mapping the old indexes to the new indexes. */ Array1 Old2New(bool extra_element = false) { NVTX_RANGE(K2_FUNC); diff --git a/k2/csrc/array_of_ragged.h b/k2/csrc/array_of_ragged.h new file mode 100644 index 000000000..8f02aefcc --- /dev/null +++ b/k2/csrc/array_of_ragged.h @@ -0,0 +1,177 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Daniel Povey) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_CSRC_ARRAY_OF_RAGGED_H_ +#define K2_CSRC_ARRAY_OF_RAGGED_H_ + +#include +#include +#include + +#include "k2/csrc/array.h" +#include "k2/csrc/context.h" +#include "k2/csrc/log.h" + +namespace k2 { + +/* + ArrayOfRagged is a 1-dimensional array of Ragged. + It is intended for situations where you want to do some operations on + arrays of ragged arrays, without explicitly concatenating them (e.g. to + save time). 
This is a fairly low-level interface, intended to + be used mostly by CUDA/C++ implementation code. It is a convenience + wrapper that saves you the trouble of creating arrays of pointers. + */ + + +/* + ArrayOfRaggedShape is a convenience function that gives you easy access + to pointers-of-pointers for an array of ragged shapes. + */ +class Array1OfRaggedShape { + public: + + /* + Constructor. + Args: + srcs: pointers to the source shapes, a CPU pointer + num_srcs: the number of source shapes. All shapes must have the + same NumAxes() and must be on the same device. + + TODO: we'll likely, later, add optional args which dictate which of + the MetaRowSplits() and MetaRowIds() are to be pre-populated; this should + enable us to save kernels by combining certain operations across the + axes. + + */ + ArrayOfRaggedShape(RaggedShape *src, + int32_t num_srcs); + + + int32_t NumSrcs() const; + int32_t NumAxes() const; + + // Returns device-accessible array of row-splits for the individual shapes, + // indexed [axis-1][src], with 0 <= src < num_srcs. The shape of this + // Array2 is [NumAxes() - 1][NumSrcs()]. + Array2 *RowSplits(); + + // Returns device-accessible vector of row-splits for a particular + // axis, indexed by 0 <= src < num_srcs. + int32_t **RowSplits(int32_t axis) { + } + + // Returns device-accessible array of row-ids for the individual shapes + // indexed [axis-1][src], with 0 <= src < num_srcs. The shape of this + // Array2 is [NumAxes() - 1][NumSrcs()]. + Array2 *RowIds(); + + + // Returns device-accessible vector of row-splits for a particular + // axis, indexed by 0 <= src < num_srcs. + int32_t **RowIds(int32_t axis) { + } + + + /* Return the total size on this axis, which is the sum of the TotSize() of + the individual shapes. Requires 0 <= axis < NumAxes() and + for axis=0 the returned value is the same as Dim0(). + */ + int32_t TotSize(int32_t axis) const; + + // equivalent to TotSize(0). 
+ int32_t Dim0() const { return TotSize(0); } + + + /* Return the device-accessible meta-row-splits, which is the cumulative sum, + along the src axis, of the tot-sizes of the individual arrays. + This Array2 is of shape [NumAxes()][NumSrcs() + 1], indexed [axis][src]; + caution, the indexing is different from RowSplits(), there is no offset. + Also, the meta_row_splits0 is a thing, unlike with regular row-splits + which start from 1. + + Caution: the lengths of the arrays pointed to by the elements of this + Array2 (which contains pointers!) are of course all different, and + these lengths are currently only available + + Implementation note: we can probably just populate this on CPU and transfer + to GPU, this will be faster than invoking an extra kernel in normal cases + when the NumSrcs() is small. [Also: see GetRowInfoMulti()]. + */ + Array2 MetaRowSplits(); + + // could POSSIBLY add this so this code could be used in functions like Stack(). + // would be like MetaRowSplits but with an extra 1st row containing 0,1,2,... + // We could perhaps create it with 1 extra initial row so this is always + // convenient to output. + Array2 Offsets(); + + /* + Returns the meta-row-splits for a particular axis, with 0 <= axis < NumAxes(); + this is the cumulative sum of the TotSize(axis) for all of the sources, + with MetaRowSplits(axis).Dim() == NumSrcs() + 1. + + Note: in ragged_opts.cu we refer to this as composed_row_splits + */ + Array1 MetaRowSplits(int32_t axis); + + /* Return the device-accessible meta-row-ids, which are the row-ids corresponding + to MetaRowSplits(); this tells us, for indexes into the appended/concatenated + array, which source array they belong to, i.e. elements are in [0,NumSrcs()-1]. + + This cannot be an Array2 because unlike the MetaRowSplits(), all the row-ids + arrays are of different lengths. + + Note: in ragged_ops.cu we refer to this as composed_row_ids. 
+ */ + Array1 MetaRowIds(); + + /* + Returns the meta-row-ids for a particular axis, with 0 <= axis < NumAxes(); + this is the row-ids corresponding to MetaRowSplits(axis), and its elements + gives, for indexes into the concatentated shape (concatenated on axis 0),m + which source they come from. E.g. element 100 of MetaRowIds(2) + would tell us which source an idx012 with value 100 into axis 2 of + concatenated array would come from. + */ + Array1 MetaRowIds(int32_t axis); +}; + + + +template +struct Array1OfRagged { + + Array1OfRaggedShape shape; + + // Array of the individual values pointers of the source arrays, indexed by + // shape + Array1 values; + + int32_t NumSrcs() { return values.Dim(); } + + ArrayOfRagged(Ragged *srcs, + int32_t num_srcs); + +} + + + +} // namespace k2 + +#endif // K2_CSRC_ARRAY_OF_RAGGED_H_ diff --git a/k2/csrc/ragged.h b/k2/csrc/ragged.h index 4e2db522d..45fec7601 100644 --- a/k2/csrc/ragged.h +++ b/k2/csrc/ragged.h @@ -12,6 +12,32 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, + // `states` contains int64_t which represents the decoder state; this is: + // context_state * num_graph_states + graph_state. + // the num_graph_states is specific to the decoding stream; + // it is held at an outer level, in the RnntDecodingStream + // or RnntDecodingStreams object. + // + // For a single RNN-T stream, `states` would be indexed + // [context_state][state], i.e. the states are grouped first + // by context_state (they are sorted, to make this possible). + // If this object represents multiple RNN-T streams, + // the `states` object is indexed [stream][context_state][state]. + Ragged states; + + // `scores` contains the forward scores of the states in `states`; + // it has the same shape as `states`. 
+ Ragged scores; + + // forward_scores contains the forward scores of the states in `states.values` + // (best score from start state to here); the shape is the same as + // `states` + Ragged forward_scores; + + // frames contains the arc information, for previously decoded + // frames, that we can later use to create a lattice. + // It contains Ragged with 2 axes (state, arc). + std::vector > prev_frames; * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index 1c9905304..a63aff203 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -710,8 +710,8 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false, where axis2 = (axis < 0 ? src.NumAxes() - axis : axis). @param [in] axis The axis to subsample; if negative, will be interpreted as an offset from src.NumAxes(). - @param [out] elems_renumbering If supplied, this function will - output to this location a renumbering object that + @param [out] elems_new2old If supplied, this function will + output to this location a new2old vector that dictates how the elements of a ragged tensor with shape `src` would be renumbered. @return Returns the subsampled shape. All dimensions and tot-sizes @@ -724,7 +724,7 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false, RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering, int32_t axis = -1, - Renumbering *elems_renumbering = nullptr); + Array1 *elems_new2old = nullptr); /* @@ -836,8 +836,8 @@ RaggedShape RenumberAxis0Simple(RaggedShape &src_shape, where axis2 = (axis < 0 ? src.NumAxes() - axis : axis). @param [in] axis The axis to subsample; if negative, will be interpreted as an offset from src.NumAxes(). 
- @param [out] elems_renumbering If supplied, this function will - output to this location a renumbering object that + @param [out] elems_new2old If supplied, this function will + output to this location a new2old array that dictates how the elements of a ragged tensor with shape `src` would be renumbered. @return Returns the subsampled shape. All dimensions and tot-sizes @@ -850,13 +850,13 @@ RaggedShape RenumberAxis0Simple(RaggedShape &src_shape, template Ragged SubsampleRagged(Ragged &src, Renumbering &renumbering, int32_t axis = -1, - Renumbering *elems_renumbering = nullptr) { + Array1 *elems_new2old = nullptr); Renumbering tmp; - if (elems_renumbering == nullptr) - elems_renumbering = &tmp; + if (elems_new2old == nullptr) + elems_new2old = &tmp; RaggedShape shape = SubsampleRaggedShape(src.shape, renumbering, - axis, elems_renumbering); - return Ragged(src.values[*elems_renumbering.New2Old()]); + axis, elems_new2old); + return Ragged(src.values[*elems_new2old]); } From 294976f9a8cf6b37db11efd4df10bf9d17edf0e8 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 4 Feb 2022 10:24:35 +0800 Subject: [PATCH 4/8] Implements SubsampleRaggedShape --- k2/csrc/algorithms.h | 2 +- k2/csrc/algorithms_test.cu | 9 ++++++ k2/csrc/ragged.h | 26 ---------------- k2/csrc/ragged_ops.cu | 15 ++++------ k2/csrc/ragged_ops.h | 9 ++---- k2/csrc/ragged_shape_test.cu | 58 ++++++++++++++++++++++++++++++++---- 6 files changed, 71 insertions(+), 48 deletions(-) diff --git a/k2/csrc/algorithms.h b/k2/csrc/algorithms.h index 8b1ba5cdb..c058c533a 100644 --- a/k2/csrc/algorithms.h +++ b/k2/csrc/algorithms.h @@ -129,7 +129,7 @@ class Renumbering { array. 
*/ Array1 New2Old(bool extra_element) { - Array1 &new2old_part = New2Old(); + Array1 &new2old_part = New2Old(); if (!extra_element) { return new2old_part; } else { diff --git a/k2/csrc/algorithms_test.cu b/k2/csrc/algorithms_test.cu index 5edc01cac..bbf310d12 100644 --- a/k2/csrc/algorithms_test.cu +++ b/k2/csrc/algorithms_test.cu @@ -45,6 +45,9 @@ TEST(AlgorithmsTest, TestRenumbering) { Array1 new2old = numbering.New2Old(); EXPECT_EQ(new2old.Dim(), 0); EXPECT_EQ(numbering.NumNewElems(), 0); + new2old = numbering.New2Old(true); + EXPECT_EQ(new2old.Dim(), 1); + EXPECT_EQ(new2old.Back(), 0); } { @@ -67,6 +70,9 @@ TEST(AlgorithmsTest, TestRenumbering) { Array1 new2old = numbering.New2Old(); EXPECT_EQ(new2old.Dim(), 0); EXPECT_EQ(numbering.NumNewElems(), 0); + new2old = numbering.New2Old(true); + EXPECT_EQ(new2old.Dim(), 1); + EXPECT_EQ(new2old.Back(), 5); } { @@ -93,6 +99,9 @@ TEST(AlgorithmsTest, TestRenumbering) { std::vector cpu_new2old(new2old.Data(), new2old.Data() + new2old.Dim()); EXPECT_THAT(cpu_new2old, ::testing::ElementsAre(0, 2, 3, 6)); + new2old = numbering.New2Old(true); + EXPECT_EQ(new2old.Dim(), 5); + EXPECT_EQ(new2old.Back(), 7); } } } diff --git a/k2/csrc/ragged.h b/k2/csrc/ragged.h index 45fec7601..4e2db522d 100644 --- a/k2/csrc/ragged.h +++ b/k2/csrc/ragged.h @@ -12,32 +12,6 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - // `states` contains int64_t which represents the decoder state; this is: - // context_state * num_graph_states + graph_state. - // the num_graph_states is specific to the decoding stream; - // it is held at an outer level, in the RnntDecodingStream - // or RnntDecodingStreams object. - // - // For a single RNN-T stream, `states` would be indexed - // [context_state][state], i.e. the states are grouped first - // by context_state (they are sorted, to make this possible). 
- // If this object represents multiple RNN-T streams, - // the `states` object is indexed [stream][context_state][state]. - Ragged states; - - // `scores` contains the forward scores of the states in `states`; - // it has the same shape as `states`. - Ragged scores; - - // forward_scores contains the forward scores of the states in `states.values` - // (best score from start state to here); the shape is the same as - // `states` - Ragged forward_scores; - - // frames contains the arc information, for previously decoded - // frames, that we can later use to create a lattice. - // It contains Ragged with 2 axes (state, arc). - std::vector > prev_frames; * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. diff --git a/k2/csrc/ragged_ops.cu b/k2/csrc/ragged_ops.cu index efd546a59..80eee3f84 100644 --- a/k2/csrc/ragged_ops.cu +++ b/k2/csrc/ragged_ops.cu @@ -1632,17 +1632,12 @@ Ragged AddPrefixToRagged(Ragged &src, return Ragged(dst_shape, dst_values); } -RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering) { +RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering, + int32_t axis, Array1 *elems_new2old) { NVTX_RANGE(K2_FUNC); - K2_CHECK_EQ(renumbering.NumOldElems(), src.NumElements()); - - // Make sure final row-ids are populated. - src.RowIds(src.NumAxes() - 1); - std::vector axes = src.Layers(); - axes.back().row_ids = axes.back().row_ids[renumbering.New2Old()]; - axes.back().row_splits = renumbering.Old2New()[axes.back().row_splits]; - axes.back().cached_tot_size = axes.back().row_ids.Dim(); - return RaggedShape(axes); + axis = axis < 0 ? 
src.NumAxes() + axis : axis; + K2_CHECK_EQ(renumbering.NumOldElems(), src.TotSize(axis)); + return Index(src, axis, renumbering.New2Old(), elems_new2old); } RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &r_before_last, diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index a63aff203..74e13460b 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -844,22 +844,19 @@ RaggedShape RenumberAxis0Simple(RaggedShape &src_shape, preceding the final axis will remain the same, which might give rise to empty lists on those axes; these can be removed if necessary with RemoveEmptyLists(). - - Notice the other version of this function below. */ template Ragged SubsampleRagged(Ragged &src, Renumbering &renumbering, int32_t axis = -1, - Array1 *elems_new2old = nullptr); - Renumbering tmp; + Array1 *elems_new2old = nullptr) { + Array1 tmp; if (elems_new2old == nullptr) elems_new2old = &tmp; RaggedShape shape = SubsampleRaggedShape(src.shape, renumbering, axis, elems_new2old); - return Ragged(src.values[*elems_new2old]); + return Ragged(shape, src.values[*elems_new2old]); } - /* This function creates a Renumbering object that can be used to obtain subsets of ragged arrays via SubsampleRaggedShape(). 
It implements beam pruning as diff --git a/k2/csrc/ragged_shape_test.cu b/k2/csrc/ragged_shape_test.cu index f09d7edcd..7476c9c9b 100644 --- a/k2/csrc/ragged_shape_test.cu +++ b/k2/csrc/ragged_shape_test.cu @@ -360,11 +360,6 @@ TEST(RaggedShapeTest, RemoveEmptyLists) { } } - - - - - TEST(RaggedShapeTest, RaggedShapeIterator) { // note RaggedShapeIndexIterator works only for CPU ContextPtr context = GetCpuContext(); @@ -425,4 +420,57 @@ TEST(RaggedShapeTest, RandomRaggedShape) { } } +TEST(RaggedShapeTest, SubsampleRaggedShape) { + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + RaggedShape src(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] [ ] ] [ [ x ] [ x x ] ] ]"); + // axis = 2 + Array1 ref_keep(c, std::vector({1, 0, 1, 0, 0, 1, 0, 1, 0})); + Renumbering renumbering(c, src.NumElements()); + auto &keep = renumbering.Keep(); + keep.CopyFrom(ref_keep); + Array1 new2old; + auto dest = SubsampleRaggedShape(src, renumbering, 2, &new2old); + Array1 ref_new2old(c, std::vector({0, 2, 5, 7})); + RaggedShape ref_dest(c, "[ [ [ x ] [ x ] ] [ [ x ] [ ] ] [ [ ] [ x ] ] ]"); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + // test axis = -1 + dest = SubsampleRaggedShape(src, renumbering, -1, &new2old); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + + + // axis = 1 + ref_keep = Array1(c, std::vector({1, 0, 0, 1, 1, 1})); + renumbering = Renumbering(c, src.TotSize(1)); + keep = renumbering.Keep(); + keep.CopyFrom(ref_keep); + dest = SubsampleRaggedShape(src, renumbering, 1, &new2old); + ref_new2old = Array1(c, std::vector({0, 1, 6, 7, 8})); + ref_dest = RaggedShape(c, "[ [ [ x x ] ] [ [ ] ] [ [ x ] [ x x ] ] ]"); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + // test axis = -2 + dest = SubsampleRaggedShape(src, renumbering, -2, &new2old); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + + + // axis = 0 + ref_keep = Array1(c, std::vector({1, 
0, 1})); + renumbering = Renumbering(c, src.TotSize(0)); + keep = renumbering.Keep(); + keep.CopyFrom(ref_keep); + dest = SubsampleRaggedShape(src, renumbering, 0, &new2old); + ref_new2old = Array1(c, std::vector({0, 1, 2, 6, 7, 8})); + ref_dest = RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]"); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + // test axis = -3 + dest = SubsampleRaggedShape(src, renumbering, -3, &new2old); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + } +} + } // namespace k2 From 2ba62e1dc160a0d3e57f9dad65bb8b03dce21c39 Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 6 Feb 2022 13:08:36 +0800 Subject: [PATCH 5/8] Implements PruneRagged --- k2/csrc/ragged_ops.h | 2 +- k2/csrc/ragged_ops_inl.h | 123 +++++++++++++++++++++++++++++++++++ k2/csrc/ragged_shape_test.cu | 9 ++- k2/csrc/ragged_test.cu | 118 +++++++++++++++++++++++++++++++++ 4 files changed, 248 insertions(+), 4 deletions(-) diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index 74e13460b..9d2e1d91b 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -909,7 +909,7 @@ Ragged SubsampleRagged(Ragged &src, Renumbering &renumbering, than sorting. 
*/ template -Renumbering PruneRagged(const Ragged &src, +Renumbering PruneRagged(Ragged &src, int32_t axis, T beam, int32_t max_elems); diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h index 4d44761cd..f1cff92f7 100644 --- a/k2/csrc/ragged_ops_inl.h +++ b/k2/csrc/ragged_ops_inl.h @@ -759,6 +759,129 @@ Array2 PadRagged(Ragged &src, const std::string &mode, T padding_value) { return res; } +// Prune a two axes ragged tensor on axis0 +template +Renumbering PruneRaggedAxis0(Ragged &src, T beam, int32_t max_elems) { + K2_CHECK_EQ(src.NumAxes(), 2); + const ContextPtr &c = src.Context(); + int32_t total_elements = src.TotSize(0); + Renumbering renumbering(c, total_elements); + + T negative_infinity = -std::numeric_limits::infinity(); + Array1 sub_max(c, total_elements); + MaxPerSublist(src, negative_infinity, &sub_max); + + T max_value = MaxValue(src.values); + + bool prune_with_max_elems = + max_elems > 0 && max_elems < total_elements; + + Array1 order_map; + const int32_t *order_map_data; + if (prune_with_max_elems) { + order_map = Array1(c, total_elements); + Sort>(&sub_max, &order_map); + order_map_data = order_map.Data(); + } + + char *keep_data = renumbering.Keep().Data(); + const T *sub_max_data = sub_max.Data(); + + // prune_with_max_elems means we have sorted the source ragged tensor + if (prune_with_max_elems) { + K2_EVAL(c, total_elements, lambda_set_keep_sorted, (int32_t i) { + bool pruned_by_beam = sub_max_data[i] < max_value - beam; + bool pruned_by_max_elems = i >= max_elems; + keep_data[order_map_data[i]] = + !(pruned_by_max_elems || pruned_by_beam); + }); + } else { + K2_EVAL(c, total_elements, lambda_set_keep, (int32_t i) { + keep_data[i] = sub_max_data[i] >= max_value - beam; + }); + } + return renumbering; +} + +// Prune a two axes ragged tensor on axis1 +template +Renumbering PruneRaggedLastAxis(Ragged &src, T beam, + int32_t max_elems) { + K2_CHECK_EQ(src.NumAxes(), 2); + const ContextPtr &c = src.Context(); + int32_t total_elements 
= src.TotSize(1); + Renumbering renumbering(c, total_elements); + + T negative_infinity = -std::numeric_limits::infinity(); + Array1 sub_max(c, src.TotSize(0)); + MaxPerSublist(src, negative_infinity, &sub_max); + + bool prune_with_max_elems = + max_elems > 0 && max_elems < total_elements; + + Array1 order_map; + const int32_t *order_map_data; + if (prune_with_max_elems) { + Ragged sorted_src = src.Clone(); + order_map = Array1(c, total_elements); + SortSublists>(&sorted_src, &order_map); + order_map_data = order_map.Data(); + } + + char *keep_data = renumbering.Keep().Data(); + const T *sub_max_data = sub_max.Data(), + *src_data = src.values.Data(); + const int32_t *row_ids1_data = src.RowIds(1).Data(), + *row_splits1_data = src.RowSplits(1).Data(); + // prune_with_max_elems means we have sorted the source ragged tensor + if (prune_with_max_elems) { + K2_EVAL(c, total_elements, lambda_set_keep_sorted, (int32_t idx01) { + // idx01 is the index after sorting + int32_t original_idx01 = order_map_data[idx01], + // SortSublists wouldn't chaneg idx0 & idx0x + idx0 = row_ids1_data[original_idx01], + idx0x = row_splits1_data[idx0], + // idx1 is the index after sorting + idx1 = idx01 - idx0x; + bool pruned_by_max_elems = idx1 >= max_elems, + pruned_by_beam = + src_data[original_idx01] < sub_max_data[idx0] - beam; + keep_data[original_idx01] = + !(pruned_by_max_elems || pruned_by_beam); + }); + } else { + K2_EVAL(c, total_elements, lambda_set_keep, (int32_t idx01) { + int32_t idx0 = row_ids1_data[idx01]; + keep_data[idx01] = src_data[idx01] >= sub_max_data[idx0] - beam; + }); + } + return renumbering; +} + +template +Renumbering PruneRagged(Ragged &src, int32_t axis, T beam, + int32_t max_elems) { + NVTX_RANGE(K2_FUNC); + if (axis == 0) { + auto reduced_src = src; + while (reduced_src.NumAxes() > 2) { + reduced_src = RemoveAxis(reduced_src, reduced_src.NumAxes() - 2); + } + return PruneRaggedAxis0(reduced_src, beam, max_elems); + } else if (axis == src.NumAxes() - 1) { + 
auto reduced_src = src; + while (reduced_src.NumAxes() > 2) { + reduced_src = RemoveAxis(reduced_src, 0); + } + return PruneRaggedLastAxis(reduced_src, beam, max_elems); + } else { + RaggedShape top, bottom; + DecomposeRaggedShape(src.shape, axis, &top, &bottom); + Ragged bottom_ragged(bottom, src.values); + return PruneRagged(bottom_ragged, 0, beam, max_elems); + } +} + } // namespace k2 #endif // K2_CSRC_RAGGED_OPS_INL_H_ diff --git a/k2/csrc/ragged_shape_test.cu b/k2/csrc/ragged_shape_test.cu index 7476c9c9b..e7fe67bf9 100644 --- a/k2/csrc/ragged_shape_test.cu +++ b/k2/csrc/ragged_shape_test.cu @@ -422,7 +422,8 @@ TEST(RaggedShapeTest, RandomRaggedShape) { TEST(RaggedShapeTest, SubsampleRaggedShape) { for (auto &c : {GetCpuContext(), GetCudaContext()}) { - RaggedShape src(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] [ ] ] [ [ x ] [ x x ] ] ]"); + RaggedShape src(c, + "[ [ [ x x ] [ x ] ] [ [ x x x ] [ ] ] [ [ x ] [ x x ] ] ]"); // axis = 2 Array1 ref_keep(c, std::vector({1, 0, 1, 0, 0, 1, 0, 1, 0})); Renumbering renumbering(c, src.NumElements()); @@ -431,7 +432,8 @@ TEST(RaggedShapeTest, SubsampleRaggedShape) { Array1 new2old; auto dest = SubsampleRaggedShape(src, renumbering, 2, &new2old); Array1 ref_new2old(c, std::vector({0, 2, 5, 7})); - RaggedShape ref_dest(c, "[ [ [ x ] [ x ] ] [ [ x ] [ ] ] [ [ ] [ x ] ] ]"); + RaggedShape ref_dest(c, + "[ [ [ x ] [ x ] ] [ [ x ] [ ] ] [ [ ] [ x ] ] ]"); EXPECT_TRUE(Equal(dest, ref_dest)); EXPECT_TRUE(Equal(new2old, ref_new2old)); // test axis = -1 @@ -462,7 +464,8 @@ TEST(RaggedShapeTest, SubsampleRaggedShape) { keep = renumbering.Keep(); keep.CopyFrom(ref_keep); dest = SubsampleRaggedShape(src, renumbering, 0, &new2old); - ref_new2old = Array1(c, std::vector({0, 1, 2, 6, 7, 8})); + ref_new2old = Array1(c, + std::vector({0, 1, 2, 6, 7, 8})); ref_dest = RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]"); EXPECT_TRUE(Equal(dest, ref_dest)); EXPECT_TRUE(Equal(new2old, ref_new2old)); diff --git a/k2/csrc/ragged_test.cu 
b/k2/csrc/ragged_test.cu index a8fdd1286..14a488a0f 100644 --- a/k2/csrc/ragged_test.cu +++ b/k2/csrc/ragged_test.cu @@ -2796,4 +2796,122 @@ TEST(RaggedTest, TestPadRagged) { TestPadRagged(); } + +template +static void TestPruneRagged() { + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + Ragged src(c, "[ [ [ 1.1 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] " + " [ [ 1.2 ] [ 2.2 6.3 ] [ ] ] " + " [ [ 1.3 4.4 ] [ 2.3 5.0 ] ] ]"); + + T beam = 2.0; + auto renumbering = PruneRagged(src, 0, beam, 2); + // best_score=6.3, best scores for sublists are [6.1, 6.3, 5.0] + // no sublist is pruned by beam, 5.0 is pruned by max-elems + // keep : [ [ [ 1.1 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] + // [ [ 1.2 ] [ 2.2 6.3 ] [ ] ] ] + Array1 keep_ref(c, std::vector{1, 1, 0}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 0.1; + renumbering = PruneRagged(src, 0, beam, 3); + // best_score=6.3, best scores for sublists are [6.1, 6.3, 5.0] + // 6.1 & 5.0 are pruned by beam + // keep : [ [ [ 1.2 ] [ 2.2 6.3 ] [ ] ] ] + keep_ref = Array1(c, std::vector{0, 1, 0}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 2.0; + renumbering = PruneRagged(src, 1, beam, 5); + // best_score=6.3, best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // 1.2 & -inf are pruned by beam, 4.4 is pruned by max-elems. 
+ // keep : [ [ [ 1.1 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] [ [ 2.2 6.3 ] ] + // [ [ 2.3 5.0 ] ] ] + keep_ref = Array1(c, std::vector{1, 1, 1, 0, 1, 0, 0, 1}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 1.0; + renumbering = PruneRagged(src, 1, beam, 5); + // best_score=6.3, best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // all sublists are pruned by beam, except 6.1 & 6.3 + // keep : [ [ [ 6.1 ] ] [ [ 2.2 6.3 ] ] ] + keep_ref = Array1(c, std::vector{0, 0, 1, 0, 1, 0, 0, 0}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 4.0; + renumbering = PruneRagged(src, 2, beam, 3); + // best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // 1.1, 1.0, 2.2 are pruned by beam. + // keep : [ [ [ 2.1 5.2 ] [ 5.1 ] [ 6.1 ] ] [ [ 1.2 ] [ 6.3 ] ] + // [ [ 1.3 4.4 ] [ 2.3 5.0 ] ] ] + keep_ref = Array1(c, + std::vector{0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 5.0; + renumbering = PruneRagged(src, 2, beam, 2); + // best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // 1.1 is pruned by max-elems. 
+ // keep : [ [ [ 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] [ [ 1.2 ] [ 2.2 6.3 ] ] + // [ [ 1.3 4.4 ] [ 2.3 5.0 ] ] ] + keep_ref = Array1(c, + std::vector{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + } +} + +TEST(RaggedTest, TestPruneRagged) { + TestPruneRagged(); + TestPruneRagged(); +} + +template +static void TestPruneRaggedAndSubsampleRagged() { + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + Ragged src(c, "[ [ [ 1.1 4.2 2.1 1.8 ] [ 5.0 3.1 ] ] " + " [ [ 1.2 ] [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] " + " [ [ 1.3 4.4 ] [ 1.4 0.8 2.3 5.2 3.6 ] ] ]"); + T beam = 1.0; + auto renumbering = PruneRagged(src, 0, beam, 3); + Array1 new2old; + auto dest = SubsampleRagged(src, renumbering, 0, &new2old); + Ragged dest_ref(c, + "[ [ [ 1.2 ] [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] ]"); + Array1 new2old_ref(c, std::vector{6, 7, 8, 9, 10, 11}); + K2_CHECK(Equal(dest, dest_ref)); + K2_CHECK(Equal(new2old, new2old_ref)); + + beam = 2.0; + renumbering = PruneRagged(src, 1, beam, 5); + dest = SubsampleRagged(src, renumbering, 1, &new2old); + dest_ref = Ragged(c, + "[ [ [ 5.0 3.1 ] ] [ [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] " + " [ [ 1.4 0.8 2.3 5.2 3.6 ] ] ]"); + new2old_ref = Array1(c, + std::vector{4, 5, 7, 8, 9, 10, 11, 14, 15, 16, 17, 18}); + K2_CHECK(Equal(dest, dest_ref)); + K2_CHECK(Equal(new2old, new2old_ref)); + + beam = 3.0; + renumbering = PruneRagged(src, 2, beam, 3); + dest = SubsampleRagged(src, renumbering, 2, &new2old); + dest_ref = Ragged(c, + "[ [ [ 4.2 2.1 1.8 ] [ 5.0 3.1 ] ] [ [ 1.2 ] [ 6.3 ] [ 6.1 ] [ 5.1 ] ]" + " [ [ 4.4 ] [ 2.3 5.2 3.6 ] ] ]"); + new2old_ref = Array1(c, + std::vector{1, 2, 3, 4, 5, 6, 8, 10, 11, 13, 16, 17, 18}); + K2_CHECK(Equal(dest, dest_ref)); + K2_CHECK(Equal(new2old, new2old_ref)); + } +} + +TEST(RaggedTest, TestPruneRaggedAndSubsampleRagged) { + TestPruneRaggedAndSubsampleRagged(); + TestPruneRaggedAndSubsampleRagged(); +} + } // namespace k2 From efb786d68ac4c840b592431958579e83cc94d202 Mon Sep 17 
00:00:00 2001 From: pkufool Date: Mon, 14 Feb 2022 11:37:01 +0800 Subject: [PATCH 6/8] Rename subsample-> subset --- k2/csrc/array_of_ragged.h | 177 ----------------------------------- k2/csrc/ragged_ops.cu | 10 +- k2/csrc/ragged_ops.h | 25 ++--- k2/csrc/ragged_ops_inl.h | 4 +- k2/csrc/ragged_shape_test.cu | 14 +-- k2/csrc/ragged_test.cu | 14 +-- k2/csrc/rm_epsilon.cu | 4 +- 7 files changed, 30 insertions(+), 218 deletions(-) delete mode 100644 k2/csrc/array_of_ragged.h diff --git a/k2/csrc/array_of_ragged.h b/k2/csrc/array_of_ragged.h deleted file mode 100644 index 8f02aefcc..000000000 --- a/k2/csrc/array_of_ragged.h +++ /dev/null @@ -1,177 +0,0 @@ -/** - * Copyright 2022 Xiaomi Corporation (authors: Daniel Povey) - * - * See LICENSE for clarification regarding multiple authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef K2_CSRC_ARRAY_OF_RAGGED_H_ -#define K2_CSRC_ARRAY_OF_RAGGED_H_ - -#include -#include -#include - -#include "k2/csrc/array.h" -#include "k2/csrc/context.h" -#include "k2/csrc/log.h" - -namespace k2 { - -/* - ArrayOfRagged is a 1-dimensional array of Ragged. - It is intended for situations where you want to do some operations on - arrays of ragged arrays, without explicitly concatenating them (e.g. to - save time). This is a fairly low-level interface, intended to - be used mostly by CUDA/C++ implementation code. It is a convenience - wrapper that saves you the trouble of creating arrays of pointers. 
- */ - - -/* - ArrayOfRaggedShape is a convenience function that gives you easy access - to pointers-of-pointers for an array of ragged shapes. - */ -class Array1OfRaggedShape { - public: - - /* - Constructor. - Args: - srcs: pointers to the source shapes, a CPU pointer - num_srcs: the number of source shapes. All shapes must have the - same NumAxes() and must be on the same device. - - TODO: we'll likely, later, add optional args which dictate which of - the MetaRowSplits() and MetaRowIds() are to be pre-populated; this should - enable us to save kernels by combining certain operations across the - axes. - - */ - ArrayOfRaggedShape(RaggedShape *src, - int32_t num_srcs); - - - int32_t NumSrcs() const; - int32_t NumAxes() const; - - // Returns device-accessible array of row-splits for the individual shapes, - // indexed [axis-1][src], with 0 <= src < num_srcs. The shape of this - // Array2 is [NumAxes() - 1][NumSrcs()]. - Array2 *RowSplits(); - - // Returns device-accessible vector of row-splits for a particular - // axis, indexed by 0 <= src < num_srcs. - int32_t **RowSplits(int32_t axis) { - } - - // Returns device-accessible array of row-ids for the individual shapes - // indexed [axis-1][src], with 0 <= src < num_srcs. The shape of this - // Array2 is [NumAxes() - 1][NumSrcs()]. - Array2 *RowIds(); - - - // Returns device-accessible vector of row-splits for a particular - // axis, indexed by 0 <= src < num_srcs. - int32_t **RowIds(int32_t axis) { - } - - - /* Return the total size on this axis, which is the sum of the TotSize() of - the individual shapes. Requires 0 <= axis < NumAxes() and - for axis=0 the returned value is the same as Dim0(). - */ - int32_t TotSize(int32_t axis) const; - - // equivalent to TotSize(0). - int32_t Dim0() const { return TotSize(0); } - - - /* Return the device-accessible meta-row-splits, which is the cumulative sum, - along the src axis, of the tot-sizes of the individual arrays. 
- This Array2 is of shape [NumAxes()][NumSrcs() + 1], indexed [axis][src]; - caution, the indexing is different from RowSplits(), there is no offset. - Also, the meta_row_splits0 is a thing, unlike with regular row-splits - which start from 1. - - Caution: the lengths of the arrays pointed to by the elements of this - Array2 (which contains pointers!) are of course all different, and - these lengths are currently only available - - Implementation note: we can probably just populate this on CPU and transfer - to GPU, this will be faster than invoking an extra kernel in normal cases - when the NumSrcs() is small. [Also: see GetRowInfoMulti()]. - */ - Array2 MetaRowSplits(); - - // could POSSIBLY add this so this code could be used in functions like Stack(). - // would be like MetaRowSplits but with an extra 1st row containing 0,1,2,... - // We could perhaps create it with 1 extra initial row so this is always - // convenient to output. - Array2 Offsets(); - - /* - Returns the meta-row-splits for a particular axis, with 0 <= axis < NumAxes(); - this is the cumulative sum of the TotSize(axis) for all of the sources, - with MetaRowSplits(axis).Dim() == NumSrcs() + 1. - - Note: in ragged_opts.cu we refer to this as composed_row_splits - */ - Array1 MetaRowSplits(int32_t axis); - - /* Return the device-accessible meta-row-ids, which are the row-ids corresponding - to MetaRowSplits(); this tells us, for indexes into the appended/concatenated - array, which source array they belong to, i.e. elements are in [0,NumSrcs()-1]. - - This cannot be an Array2 because unlike the MetaRowSplits(), all the row-ids - arrays are of different lengths. - - Note: in ragged_ops.cu we refer to this as composed_row_ids. 
- */ - Array1 MetaRowIds(); - - /* - Returns the meta-row-ids for a particular axis, with 0 <= axis < NumAxes(); - this is the row-ids corresponding to MetaRowSplits(axis), and its elements - gives, for indexes into the concatentated shape (concatenated on axis 0),m - which source they come from. E.g. element 100 of MetaRowIds(2) - would tell us which source an idx012 with value 100 into axis 2 of - concatenated array would come from. - */ - Array1 MetaRowIds(int32_t axis); -}; - - - -template -struct Array1OfRagged { - - Array1OfRaggedShape shape; - - // Array of the individual values pointers of the source arrays, indexed by - // shape - Array1 values; - - int32_t NumSrcs() { return values.Dim(); } - - ArrayOfRagged(Ragged *srcs, - int32_t num_srcs); - -} - - - -} // namespace k2 - -#endif // K2_CSRC_ARRAY_OF_RAGGED_H_ diff --git a/k2/csrc/ragged_ops.cu b/k2/csrc/ragged_ops.cu index 80eee3f84..81b637092 100644 --- a/k2/csrc/ragged_ops.cu +++ b/k2/csrc/ragged_ops.cu @@ -564,7 +564,7 @@ RaggedShape Index(RaggedShape &src, int32_t axis, if (axis == 0) { return IndexAxis0(src, indexes, elem_indexes); } else if (axis == src.NumAxes() - 1) { - // This code is related to SubsampleRaggedShape(). `indexes` corresponds + // This code is related to SubsetRaggedShape(). `indexes` corresponds // to `new2old`. Array1 last_row_ids = src.RowIds(num_axes - 1)[indexes]; #ifndef NDEBUG @@ -1632,7 +1632,7 @@ Ragged AddPrefixToRagged(Ragged &src, return Ragged(dst_shape, dst_values); } -RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering, +RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &renumbering, int32_t axis, Array1 *elems_new2old) { NVTX_RANGE(K2_FUNC); axis = axis < 0 ? 
src.NumAxes() + axis : axis; @@ -1640,7 +1640,7 @@ RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering, return Index(src, axis, renumbering.New2Old(), elems_new2old); } -RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &r_before_last, +RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &r_before_last, Renumbering &r_last) { NVTX_RANGE(K2_FUNC); K2_CHECK_EQ(r_before_last.NumOldElems(), src.TotSize(src.NumAxes() - 2)); @@ -1786,7 +1786,7 @@ RaggedShape RemoveEmptyLists(RaggedShape &src_shape, int32_t axis, Renumbering r_temp; if (!renumbering_out) renumbering_out = &r_temp; bottom_shape = RemoveEmptyListsAxis0(bottom_shape, renumbering_out); - top_shape = SubsampleRaggedShape(top_shape, *renumbering_out); + top_shape = SubsetRaggedShape(top_shape, *renumbering_out); return ComposeRaggedShapes(top_shape, bottom_shape); } @@ -1800,7 +1800,7 @@ RaggedShape RemoveSomeEmptyLists(RaggedShape &src_shape, int32_t axis, DecomposeRaggedShape(src_shape, axis, &top_shape, &bottom_shape); bottom_shape = RenumberAxis0Simple(bottom_shape, renumbering); - top_shape = SubsampleRaggedShape(top_shape, renumbering); + top_shape = SubsetRaggedShape(top_shape, renumbering); return ComposeRaggedShapes(top_shape, bottom_shape); } diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index 9d2e1d91b..cd2d4ebae 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -721,7 +721,7 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false, Notice the other version of this function below. */ -RaggedShape SubsampleRaggedShape(RaggedShape &src, +RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &renumbering, int32_t axis = -1, Array1 *elems_new2old = nullptr); @@ -736,7 +736,7 @@ RaggedShape SubsampleRaggedShape(RaggedShape &src, Note: all dimensions and tot-sizes preceding the last two axes will remain the same, which might give rise to empty lists. 
*/ -RaggedShape SubsampleRaggedShape(RaggedShape &src, +RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &renumbering_before_last, Renumbering &renumbering_last); @@ -846,20 +846,20 @@ RaggedShape RenumberAxis0Simple(RaggedShape &src_shape, necessary with RemoveEmptyLists(). */ template -Ragged SubsampleRagged(Ragged &src, Renumbering &renumbering, +Ragged SubsetRagged(Ragged &src, Renumbering &renumbering, int32_t axis = -1, Array1 *elems_new2old = nullptr) { Array1 tmp; if (elems_new2old == nullptr) elems_new2old = &tmp; - RaggedShape shape = SubsampleRaggedShape(src.shape, renumbering, + RaggedShape shape = SubsetRaggedShape(src.shape, renumbering, axis, elems_new2old); return Ragged(shape, src.values[*elems_new2old]); } /* This function creates a Renumbering object that can be used to obtain subsets - of ragged arrays via SubsampleRaggedShape(). It implements beam pruning as + of ragged arrays via SubsetRaggedShape(). It implements beam pruning as used in pruned Viterbi search and similar algorithms, where there is both a beam and a max-active (`max_elems`) constraint. T will probably be float or double, interpreted as a "positive-is-better" sense, i.e. as scores. @@ -896,17 +896,6 @@ Ragged SubsampleRagged(Ragged &src, Renumbering &renumbering, PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 0, 5.0, 0) would create a Renumbering object that would prune the ragged tensor to [ [0 -1 -2 -3] ] - - - TODO: don't forget to change subsample->subset when we rename - SubsampleRaggedShape(). - IMPLEMENTATION NOTES (please delete later): - - We might want/need to treat certain cases specially, e.g. - the case when axis == src.NumAxes() - 1, and/or when - axis == 0. - - If `max_elems` is <= 0, we might want to choose a different - implementation, e.g. using max on the sub-lists rather - than sorting. 
*/ template Renumbering PruneRagged(Ragged &src, @@ -1347,7 +1336,7 @@ Ragged Merge(int32_t num_srcs, Ragged **src, /* Returns a ragged tensor after removing all 'values' that were <= a provided cutoff. Leaves all layers of the shape except for the last one unaffected. - Equivalent to SubsampleRaggedShape with a numbering given by (src.values[i] <= + Equivalent to SubsetRaggedShape with a numbering given by (src.values[i] <= cutoff). */ template @@ -1356,7 +1345,7 @@ Ragged RemoveValuesLeq(Ragged &src, T cutoff); /* Returns a ragged tensor after removing all 'values' that equal a provided target. Leaves all layers of the shape except for the last one unaffected. - Equivalent to SubsampleRaggedShape with a numbering given by (src.values[i] == + Equivalent to SubsetRaggedShape with a numbering given by (src.values[i] == target). */ template diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h index f1cff92f7..84bf5983d 100644 --- a/k2/csrc/ragged_ops_inl.h +++ b/k2/csrc/ragged_ops_inl.h @@ -232,7 +232,7 @@ Ragged RemoveValuesLeq(Ragged &src, T cutoff) { K2_EVAL( c, src.NumElements(), lambda_set_keep, (int32_t i)->void { keep[i] = (char)(values_data[i] > cutoff); }); - return SubsampleRagged(src, r); + return SubsetRagged(src, r); } template @@ -245,7 +245,7 @@ Ragged RemoveValuesEq(Ragged &src, T target) { K2_EVAL( c, src.NumElements(), lambda_set_keep, (int32_t i)->void { keep[i] = (char)(values_data[i] != target); }); - return SubsampleRagged(src, r); + return SubsetRagged(src, r); } // Recursive function that prints (part of) a ragged shape. 
diff --git a/k2/csrc/ragged_shape_test.cu b/k2/csrc/ragged_shape_test.cu index e7fe67bf9..d365b62aa 100644 --- a/k2/csrc/ragged_shape_test.cu +++ b/k2/csrc/ragged_shape_test.cu @@ -420,7 +420,7 @@ TEST(RaggedShapeTest, RandomRaggedShape) { } } -TEST(RaggedShapeTest, SubsampleRaggedShape) { +TEST(RaggedShapeTest, SubsetRaggedShape) { for (auto &c : {GetCpuContext(), GetCudaContext()}) { RaggedShape src(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] [ ] ] [ [ x ] [ x x ] ] ]"); @@ -430,14 +430,14 @@ TEST(RaggedShapeTest, SubsampleRaggedShape) { auto &keep = renumbering.Keep(); keep.CopyFrom(ref_keep); Array1 new2old; - auto dest = SubsampleRaggedShape(src, renumbering, 2, &new2old); + auto dest = SubsetRaggedShape(src, renumbering, 2, &new2old); Array1 ref_new2old(c, std::vector({0, 2, 5, 7})); RaggedShape ref_dest(c, "[ [ [ x ] [ x ] ] [ [ x ] [ ] ] [ [ ] [ x ] ] ]"); EXPECT_TRUE(Equal(dest, ref_dest)); EXPECT_TRUE(Equal(new2old, ref_new2old)); // test axis = -1 - dest = SubsampleRaggedShape(src, renumbering, -1, &new2old); + dest = SubsetRaggedShape(src, renumbering, -1, &new2old); EXPECT_TRUE(Equal(dest, ref_dest)); EXPECT_TRUE(Equal(new2old, ref_new2old)); @@ -447,13 +447,13 @@ TEST(RaggedShapeTest, SubsampleRaggedShape) { renumbering = Renumbering(c, src.TotSize(1)); keep = renumbering.Keep(); keep.CopyFrom(ref_keep); - dest = SubsampleRaggedShape(src, renumbering, 1, &new2old); + dest = SubsetRaggedShape(src, renumbering, 1, &new2old); ref_new2old = Array1(c, std::vector({0, 1, 6, 7, 8})); ref_dest = RaggedShape(c, "[ [ [ x x ] ] [ [ ] ] [ [ x ] [ x x ] ] ]"); EXPECT_TRUE(Equal(dest, ref_dest)); EXPECT_TRUE(Equal(new2old, ref_new2old)); // test axis = -2 - dest = SubsampleRaggedShape(src, renumbering, -2, &new2old); + dest = SubsetRaggedShape(src, renumbering, -2, &new2old); EXPECT_TRUE(Equal(dest, ref_dest)); EXPECT_TRUE(Equal(new2old, ref_new2old)); @@ -463,14 +463,14 @@ TEST(RaggedShapeTest, SubsampleRaggedShape) { renumbering = Renumbering(c, src.TotSize(0)); keep = 
renumbering.Keep(); keep.CopyFrom(ref_keep); - dest = SubsampleRaggedShape(src, renumbering, 0, &new2old); + dest = SubsetRaggedShape(src, renumbering, 0, &new2old); ref_new2old = Array1(c, std::vector({0, 1, 2, 6, 7, 8})); ref_dest = RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]"); EXPECT_TRUE(Equal(dest, ref_dest)); EXPECT_TRUE(Equal(new2old, ref_new2old)); // test axis = -3 - dest = SubsampleRaggedShape(src, renumbering, -3, &new2old); + dest = SubsetRaggedShape(src, renumbering, -3, &new2old); EXPECT_TRUE(Equal(dest, ref_dest)); EXPECT_TRUE(Equal(new2old, ref_new2old)); } diff --git a/k2/csrc/ragged_test.cu b/k2/csrc/ragged_test.cu index 14a488a0f..011a66178 100644 --- a/k2/csrc/ragged_test.cu +++ b/k2/csrc/ragged_test.cu @@ -2870,7 +2870,7 @@ TEST(RaggedTest, TestPruneRagged) { } template -static void TestPruneRaggedAndSubsampleRagged() { +static void TestPruneRaggedAndSubsetRagged() { for (auto &c : {GetCpuContext(), GetCudaContext()}) { Ragged src(c, "[ [ [ 1.1 4.2 2.1 1.8 ] [ 5.0 3.1 ] ] " " [ [ 1.2 ] [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] " @@ -2878,7 +2878,7 @@ static void TestPruneRaggedAndSubsampleRagged() { T beam = 1.0; auto renumbering = PruneRagged(src, 0, beam, 3); Array1 new2old; - auto dest = SubsampleRagged(src, renumbering, 0, &new2old); + auto dest = SubsetRagged(src, renumbering, 0, &new2old); Ragged dest_ref(c, "[ [ [ 1.2 ] [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] ]"); Array1 new2old_ref(c, std::vector{6, 7, 8, 9, 10, 11}); @@ -2887,7 +2887,7 @@ static void TestPruneRaggedAndSubsampleRagged() { beam = 2.0; renumbering = PruneRagged(src, 1, beam, 5); - dest = SubsampleRagged(src, renumbering, 1, &new2old); + dest = SubsetRagged(src, renumbering, 1, &new2old); dest_ref = Ragged(c, "[ [ [ 5.0 3.1 ] ] [ [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] " " [ [ 1.4 0.8 2.3 5.2 3.6 ] ] ]"); @@ -2898,7 +2898,7 @@ static void TestPruneRaggedAndSubsampleRagged() { beam = 3.0; renumbering = PruneRagged(src, 2, beam, 3); - dest = SubsampleRagged(src, renumbering, 2, 
&new2old); + dest = SubsetRagged(src, renumbering, 2, &new2old); dest_ref = Ragged(c, "[ [ [ 4.2 2.1 1.8 ] [ 5.0 3.1 ] ] [ [ 1.2 ] [ 6.3 ] [ 6.1 ] [ 5.1 ] ]" " [ [ 4.4 ] [ 2.3 5.2 3.6 ] ] ]"); @@ -2909,9 +2909,9 @@ static void TestPruneRaggedAndSubsampleRagged() { } } -TEST(RaggedTest, TestPruneRaggedAndSubsampleRagged) { - TestPruneRaggedAndSubsampleRagged(); - TestPruneRaggedAndSubsampleRagged(); +TEST(RaggedTest, TestPruneRaggedAndSubsetRagged) { + TestPruneRaggedAndSubsetRagged(); + TestPruneRaggedAndSubsetRagged(); } } // namespace k2 diff --git a/k2/csrc/rm_epsilon.cu b/k2/csrc/rm_epsilon.cu index b8774bf53..c20775c1a 100644 --- a/k2/csrc/rm_epsilon.cu +++ b/k2/csrc/rm_epsilon.cu @@ -795,7 +795,7 @@ void ComputeEpsilonClosure(FsaVec &epsilon_fsa, FsaVec *closure_fsa, const Arc &cur_arc = arcs_data[arc_idx012]; arc_keep_data[arc_idx012] = (cur_arc.src_state != cur_arc.dest_state); }); - *closure_fsa = SubsampleRagged(*closure_fsa, arc_renumbering); + *closure_fsa = SubsetRagged(*closure_fsa, arc_renumbering); *arc_map = Index(*arc_map, 0, arc_renumbering.New2Old()); } @@ -1081,7 +1081,7 @@ void RemoveEpsilonDevice(FsaOrVec &src_fsa, FsaOrVec *dest_fsa, non_epsilon_arc_map, foll_shape, &combined_foll, &combined_foll_arc_map); FsaVec epsilon_closure_prec = - SubsampleRagged(epsilon_closure_mapped, epsilon_prec_renumbering); + SubsetRagged(epsilon_closure_mapped, epsilon_prec_renumbering); Ragged epsilon_closure_prec_arc_map = Index( epsilon_closure_mapped_arc_map, 0, epsilon_prec_renumbering.New2Old()); // `combined_prec` will be set to an FSA, with the same state numbering as From f9a07de77df0fb91d3073550738df8eb42f10404 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 14 Feb 2022 11:47:36 +0800 Subject: [PATCH 7/8] Minor fixes --- k2/csrc/ragged_ops.cu | 4 ++-- k2/csrc/ragged_ops.h | 14 +++++++------- k2/csrc/ragged_ops_inl.h | 6 +++--- k2/csrc/ragged_shape_test.cu | 2 -- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/k2/csrc/ragged_ops.cu 
b/k2/csrc/ragged_ops.cu index 81b637092..4db7ac121 100644 --- a/k2/csrc/ragged_ops.cu +++ b/k2/csrc/ragged_ops.cu @@ -1633,7 +1633,7 @@ Ragged AddPrefixToRagged(Ragged &src, } RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &renumbering, - int32_t axis, Array1 *elems_new2old) { + int32_t axis, Array1 *elems_new2old) { NVTX_RANGE(K2_FUNC); axis = axis < 0 ? src.NumAxes() + axis : axis; K2_CHECK_EQ(renumbering.NumOldElems(), src.TotSize(axis)); @@ -1641,7 +1641,7 @@ RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &renumbering, } RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &r_before_last, - Renumbering &r_last) { + Renumbering &r_last) { NVTX_RANGE(K2_FUNC); K2_CHECK_EQ(r_before_last.NumOldElems(), src.TotSize(src.NumAxes() - 2)); K2_CHECK_EQ(r_last.NumOldElems(), src.NumElements()); diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index cd2d4ebae..96a05c1e7 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -722,9 +722,9 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false, Notice the other version of this function below. */ RaggedShape SubsetRaggedShape(RaggedShape &src, - Renumbering &renumbering, - int32_t axis = -1, - Array1 *elems_new2old = nullptr); + Renumbering &renumbering, + int32_t axis = -1, + Array1 *elems_new2old = nullptr); /* @@ -737,8 +737,8 @@ RaggedShape SubsetRaggedShape(RaggedShape &src, same, which might give rise to empty lists. 
*/ RaggedShape SubsetRaggedShape(RaggedShape &src, - Renumbering &renumbering_before_last, - Renumbering &renumbering_last); + Renumbering &renumbering_before_last, + Renumbering &renumbering_last); /* Removes empty lists on a particular axis (not last axis) of a RaggedShape, @@ -847,8 +847,8 @@ RaggedShape RenumberAxis0Simple(RaggedShape &src_shape, */ template Ragged SubsetRagged(Ragged &src, Renumbering &renumbering, - int32_t axis = -1, - Array1 *elems_new2old = nullptr) { + int32_t axis = -1, + Array1 *elems_new2old = nullptr) { Array1 tmp; if (elems_new2old == nullptr) elems_new2old = &tmp; diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h index 84bf5983d..ce8c2dda7 100644 --- a/k2/csrc/ragged_ops_inl.h +++ b/k2/csrc/ragged_ops_inl.h @@ -805,8 +805,8 @@ Renumbering PruneRaggedAxis0(Ragged &src, T beam, int32_t max_elems) { // Prune a two axes ragged tensor on axis1 template -Renumbering PruneRaggedLastAxis(Ragged &src, T beam, - int32_t max_elems) { +Renumbering PruneRaggedAxis1(Ragged &src, T beam, + int32_t max_elems) { K2_CHECK_EQ(src.NumAxes(), 2); const ContextPtr &c = src.Context(); int32_t total_elements = src.TotSize(1); @@ -873,7 +873,7 @@ Renumbering PruneRagged(Ragged &src, int32_t axis, T beam, while (reduced_src.NumAxes() > 2) { reduced_src = RemoveAxis(reduced_src, 0); } - return PruneRaggedLastAxis(reduced_src, beam, max_elems); + return PruneRaggedAxis1(reduced_src, beam, max_elems); } else { RaggedShape top, bottom; DecomposeRaggedShape(src.shape, axis, &top, &bottom); diff --git a/k2/csrc/ragged_shape_test.cu b/k2/csrc/ragged_shape_test.cu index d365b62aa..451a2cf60 100644 --- a/k2/csrc/ragged_shape_test.cu +++ b/k2/csrc/ragged_shape_test.cu @@ -441,7 +441,6 @@ TEST(RaggedShapeTest, SubsetRaggedShape) { EXPECT_TRUE(Equal(dest, ref_dest)); EXPECT_TRUE(Equal(new2old, ref_new2old)); - // axis = 1 ref_keep = Array1(c, std::vector({1, 0, 0, 1, 1, 1})); renumbering = Renumbering(c, src.TotSize(1)); @@ -457,7 +456,6 @@ 
TEST(RaggedShapeTest, SubsetRaggedShape) { EXPECT_TRUE(Equal(dest, ref_dest)); EXPECT_TRUE(Equal(new2old, ref_new2old)); - // axis = 0 ref_keep = Array1(c, std::vector({1, 0, 1})); renumbering = Renumbering(c, src.TotSize(0)); From ccb29e01cc433dd7eb97689cccb1a6dbbf75708b Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 16 Feb 2022 19:23:45 +0800 Subject: [PATCH 8/8] Fix comments --- k2/csrc/algorithms.h | 4 +--- k2/csrc/ragged_ops.h | 23 ++++++++++------------- k2/csrc/ragged_ops_inl.h | 16 ++++++++++++---- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/k2/csrc/algorithms.h b/k2/csrc/algorithms.h index c058c533a..6e11a31cb 100644 --- a/k2/csrc/algorithms.h +++ b/k2/csrc/algorithms.h @@ -111,9 +111,7 @@ class Renumbering { (pre-renumbering) indexes. Its dimension is the number of new indexes (i.e. the number of 1 in keep_), but internally it has one extra element which contains the number of old - elements, so it's OK to read one past the end. (We may - later make it possible to access the array with the one-larger - dimension). + elements, so it's OK to read one past the end. */ Array1 &New2Old() { NVTX_RANGE(K2_FUNC); diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index 96a05c1e7..ffe7a7413 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -707,7 +707,7 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false, @param [in] renumbering The renumbering object that dictates which elements of `src` we keep; we require renumbering.NumOldElems() == src.TotSize(axis2) - where axis2 = (axis < 0 ? src.NumAxes() - axis : axis). + where axis2 = (axis < 0 ? src.NumAxes() + axis : axis). @param [in] axis The axis to subsample; if negative, will be interpreted as an offset from src.NumAxes(). @param [out] elems_new2old If supplied, this function will @@ -715,7 +715,7 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false, dictates how the elements of a ragged tensor with shape `src` would be renumbered. 
@return Returns the subsampled shape. All dimensions and tot-sizes - preceding the final axis will remain the same, which might give + preceding the axis `axis` will remain the same, which might give rise to empty lists on those axes; these can be removed if necessary with RemoveEmptyLists(). @@ -841,7 +841,7 @@ RaggedShape RenumberAxis0Simple(RaggedShape &src_shape, dictates how the elements of a ragged tensor with shape `src` would be renumbered. @return Returns the subsampled shape. All dimensions and tot-sizes - preceding the final axis will remain the same, which might give + preceding the axis `axis` will remain the same, which might give rise to empty lists on those axes; these can be removed if necessary with RemoveEmptyLists(). */ @@ -866,7 +866,7 @@ Ragged SubsetRagged(Ragged &src, Renumbering &renumbering, @param [in] src The ragged object to be subsampled. @param [in] axis The axis to be subsampled, must satisfy - 0 <= axis < src.NumAxes(). The axis before `axis`, if axis < 0, + 0 <= axis < src.NumAxes(). The axis before `axis`, if axis > 0, will be interpreted as a "batch" axis. @param [in] beam The main pruning beam. The sub-lists of elements on axis `axis` will be removed if their maximum element (or the element @@ -874,17 +874,15 @@ Ragged SubsetRagged(Ragged &src, Renumbering &renumbering, this_best_elem - beam, where this_best_elem is the maximum element taken over axis `axis-1` (or over the entire array, if axis == 0). Think of axis `axis-1`, if - present, as the "batch" axis, and axis `axis` as the axis that we + axis > 0, as the "batch" axis, and axis `axis` as the axis that we actually remove elements or sub-lists on. Empty sub-lists on axis `axis` will always be pruned, as their score would be treated as -infinity. - @param [in] max_elems If max_elems > 0, it is the maximum number of sub-lists - or elements that are allowed within any sub-list on axis `axis-1` - (or the maximum number of top-level sub-lists after subsampling, - if axis == 0). 
We keep the best ones, but behavior in case of ties is - undefined (TODO: check whether SortSublists() is a stable sort, and - change this doc if it is). If max_elems <= 0, there is no such - constraint. + @param [in] max_elems If max_elems > 0, it is the maximum number of + sub-lists or elements that are allowed within any sub-list + on axis `axis-1` (or the maximum number of top-level sub-lists + after subsampling, if axis == 0). We keep the best ones. + If max_elems <= 0, there is no such constraint. @return Returns the renumbering object to be used to actually prune/subsample the specified axis. @@ -903,7 +901,6 @@ Renumbering PruneRagged(Ragged &src, T beam, int32_t max_elems); - /* Stack a list of Ragged arrays to create a Ragged array with one more axis. Similar to TF/PyTorch's Stack. The result will have Dim0 == num_srcs. All diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h index ce8c2dda7..e274e96bb 100644 --- a/k2/csrc/ragged_ops_inl.h +++ b/k2/csrc/ragged_ops_inl.h @@ -317,9 +317,9 @@ static void SortSublistsCpu(Ragged *src, Array1 *order) { int32_t cur = row_splits[i]; int32_t next = row_splits[i + 1]; if (order != nullptr) - std::sort(order->Data() + cur, order->Data() + next, lambda_comp); + std::stable_sort(order->Data() + cur, order->Data() + next, lambda_comp); - std::sort(p + cur, p + next, comp); + std::stable_sort(p + cur, p + next, comp); } } @@ -759,7 +759,11 @@ Array2 PadRagged(Ragged &src, const std::string &mode, T padding_value) { return res; } -// Prune a two axes ragged tensor on axis0 +/* Prune a two axes ragged tensor on axis0. + * This is a special case of PruneRagged with axis == 0 and src.NumAxes() == 2. + * To get more details, please refer to the docs for PruneRagged in + * ragged_ops.h.
 + */ template Renumbering PruneRaggedAxis0(Ragged &src, T beam, int32_t max_elems) { K2_CHECK_EQ(src.NumAxes(), 2); @@ -803,7 +807,11 @@ Renumbering PruneRaggedAxis0(Ragged &src, T beam, int32_t max_elems) { return renumbering; } -// Prune a two axes ragged tensor on axis1 +/* Prune a two axes ragged tensor on axis1. + * This is a special case of PruneRagged with axis == 1 and src.NumAxes() == 2. + * To get more details, please refer to the docs for PruneRagged in + * ragged_ops.h. + */ template Renumbering PruneRaggedAxis1(Ragged &src, T beam, int32_t max_elems) {