diff --git a/k2/csrc/algorithms.h b/k2/csrc/algorithms.h index 439ac5ba3..6e11a31cb 100644 --- a/k2/csrc/algorithms.h +++ b/k2/csrc/algorithms.h @@ -111,9 +111,7 @@ class Renumbering { (pre-renumbering) indexes. Its dimension is the number of new indexes (i.e. the number of 1 in keep_), but internally it has one extra element which contains the number of old - elements, so it's OK to read one past the end. (We may - later make it possible to access the array with the one-larger - dimension). + elements, so it's OK to read one past the end. */ Array1 &New2Old() { NVTX_RANGE(K2_FUNC); @@ -121,17 +119,40 @@ class Renumbering { return new2old_; } + /* Return a mapping from new index to old index, with one extra element + containing the total number of kept elements if extra_element == true. + If Keep() can be interpreted as a tails vector, i.e. with 1 at the end + of sub-lists of elements, then New2Old(true) would correspond to a + row-splits array and Old2New(false) would correspond to a row-ids + array. + */ + Array1 New2Old(bool extra_element) { + Array1 &new2old_part = New2Old(); + if (!extra_element) { + return new2old_part; + } else { + // This is a little perverse, using low-level interfaces to increase the + // dimension of the array; but we know it does have one more element. + // Because we normally use New2Old() with no arg (equivalent to false), + // the overloaded version of this function returns a reference for + // efficiency. + return Array1(new2old_part.Dim() + 1, + new2old_part.GetRegion(), 0); + } + } + + /* Return a mapping from old index to new index. This is created on demand (must only be called after the Keep() array has been populated). @param [in] extra_element If true, will return the array of size NumOldElems() + 1, which includes one more element; otherwise it will return an array of size NumOldElems(). + + + @return Returns an array mapping the old indexes to the new indexes. This array is just the exclusive sum of Keep(). 
It gives the mapping for indexes that are kept; element i is kept if `Old2New()[i+1] > Old2New()[i]`. - - @return Returns an array mapping the old indexes to the new indexes. */ Array1 Old2New(bool extra_element = false) { NVTX_RANGE(K2_FUNC); diff --git a/k2/csrc/algorithms_test.cu b/k2/csrc/algorithms_test.cu index 5edc01cac..bbf310d12 100644 --- a/k2/csrc/algorithms_test.cu +++ b/k2/csrc/algorithms_test.cu @@ -45,6 +45,9 @@ TEST(AlgorithmsTest, TestRenumbering) { Array1 new2old = numbering.New2Old(); EXPECT_EQ(new2old.Dim(), 0); EXPECT_EQ(numbering.NumNewElems(), 0); + new2old = numbering.New2Old(true); + EXPECT_EQ(new2old.Dim(), 1); + EXPECT_EQ(new2old.Back(), 0); } { @@ -67,6 +70,9 @@ TEST(AlgorithmsTest, TestRenumbering) { Array1 new2old = numbering.New2Old(); EXPECT_EQ(new2old.Dim(), 0); EXPECT_EQ(numbering.NumNewElems(), 0); + new2old = numbering.New2Old(true); + EXPECT_EQ(new2old.Dim(), 1); + EXPECT_EQ(new2old.Back(), 5); } { @@ -93,6 +99,9 @@ TEST(AlgorithmsTest, TestRenumbering) { std::vector cpu_new2old(new2old.Data(), new2old.Data() + new2old.Dim()); EXPECT_THAT(cpu_new2old, ::testing::ElementsAre(0, 2, 3, 6)); + new2old = numbering.New2Old(true); + EXPECT_EQ(new2old.Dim(), 5); + EXPECT_EQ(new2old.Back(), 7); } } } diff --git a/k2/csrc/ragged_ops.cu b/k2/csrc/ragged_ops.cu index 6120428cc..808ad41e2 100644 --- a/k2/csrc/ragged_ops.cu +++ b/k2/csrc/ragged_ops.cu @@ -569,7 +569,7 @@ RaggedShape Index(RaggedShape &src, int32_t axis, if (axis == 0) { return IndexAxis0(src, indexes, elem_indexes); } else if (axis == src.NumAxes() - 1) { - // This code is related to SubsampleRaggedShape(). `indexes` corresponds + // This code is related to SubsetRaggedShape(). `indexes` corresponds // to `new2old`. 
Array1 last_row_ids = src.RowIds(num_axes - 1)[indexes]; #ifndef NDEBUG @@ -1944,21 +1944,16 @@ Ragged AddPrefixToRagged(Ragged &src, return Ragged(dst_shape, dst_values); } -RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering) { +RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &renumbering, + int32_t axis, Array1 *elems_new2old) { NVTX_RANGE(K2_FUNC); - K2_CHECK_EQ(renumbering.NumOldElems(), src.NumElements()); - - // Make sure final row-ids are populated. - src.RowIds(src.NumAxes() - 1); - std::vector axes = src.Layers(); - axes.back().row_ids = axes.back().row_ids[renumbering.New2Old()]; - axes.back().row_splits = renumbering.Old2New()[axes.back().row_splits]; - axes.back().cached_tot_size = axes.back().row_ids.Dim(); - return RaggedShape(axes); + axis = axis < 0 ? src.NumAxes() + axis : axis; + K2_CHECK_EQ(renumbering.NumOldElems(), src.TotSize(axis)); + return Index(src, axis, renumbering.New2Old(), elems_new2old); } -RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &r_before_last, - Renumbering &r_last) { +RaggedShape SubsetRaggedShape(RaggedShape &src, Renumbering &r_before_last, + Renumbering &r_last) { NVTX_RANGE(K2_FUNC); K2_CHECK_EQ(r_before_last.NumOldElems(), src.TotSize(src.NumAxes() - 2)); K2_CHECK_EQ(r_last.NumOldElems(), src.NumElements()); @@ -2103,7 +2098,7 @@ RaggedShape RemoveEmptyLists(RaggedShape &src_shape, int32_t axis, Renumbering r_temp; if (!renumbering_out) renumbering_out = &r_temp; bottom_shape = RemoveEmptyListsAxis0(bottom_shape, renumbering_out); - top_shape = SubsampleRaggedShape(top_shape, *renumbering_out); + top_shape = SubsetRaggedShape(top_shape, *renumbering_out); return ComposeRaggedShapes(top_shape, bottom_shape); } @@ -2117,7 +2112,7 @@ RaggedShape RemoveSomeEmptyLists(RaggedShape &src_shape, int32_t axis, DecomposeRaggedShape(src_shape, axis, &top_shape, &bottom_shape); bottom_shape = RenumberAxis0Simple(bottom_shape, renumbering); - top_shape = 
SubsampleRaggedShape(top_shape, renumbering); + top_shape = SubsetRaggedShape(top_shape, renumbering); return ComposeRaggedShapes(top_shape, bottom_shape); } diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index 62e8b640c..83c2ca238 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -759,14 +759,32 @@ RaggedShape RandomRaggedShape(bool set_row_ids = false, int32_t max_num_elements = 2000); /* - Return ragged shape with only a subset of the bottom-level elements kept. - Require renumbering.NumOldElems() == src.NumElements(). Note: all - dimensions and tot-sizes preceding the final axis will remain the same, which - might give rise to empty lists. + Return ragged shape with only a subset of the elements or sub-lists + on the specified axis kept. (This is not regular sampling, it is + irregular subsampling with specified elements kept). + + @param [in] src The ragged shape that we are subsampling + @param [in] renumbering The renumbering object that dictates + which elements of `src` we keep; we require + renumbering.NumOldElems() == src.TotSize(axis2) + where axis2 = (axis < 0 ? src.NumAxes() + axis : axis). + @param [in] axis The axis to subsample; if negative, will be + interpreted as an offset from src.NumAxes(). + @param [out] elems_new2old If supplied, this function will + output to this location a new2old vector that + dictates how the elements of a ragged tensor + with shape `src` would be renumbered. + @return Returns the subsampled shape. All dimensions and tot-sizes + preceding the axis `axis` will remain the same, which might give + rise to empty lists on those axes; these can be removed if + necessary with RemoveEmptyLists(). Notice the other version of this function below. 
*/ -RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering); +RaggedShape SubsetRaggedShape(RaggedShape &src, + Renumbering &renumbering, + int32_t axis = -1, + Array1 *elems_new2old = nullptr); /* Return ragged shape with only a subset of the elements on the last @@ -777,9 +795,9 @@ RaggedShape SubsampleRaggedShape(RaggedShape &src, Renumbering &renumbering); Note: all dimensions and tot-sizes preceding the last two axes will remain the same, which might give rise to empty lists. */ -RaggedShape SubsampleRaggedShape(RaggedShape &src, - Renumbering &renumbering_before_last, - Renumbering &renumbering_last); +RaggedShape SubsetRaggedShape(RaggedShape &src, + Renumbering &renumbering_before_last, + Renumbering &renumbering_last); /* Removes empty lists on a particular axis (not last axis) of a RaggedShape, @@ -866,17 +884,82 @@ RaggedShape RenumberAxis0Simple(RaggedShape &src_shape, /* - Return ragged array with only a subset of the bottom-level elements kept. - Require renumbering.NumOldElems() == src.NumElements(). Note: all - dimensions and tot-sizes preceding the final axis will remain the same, which - might give rise to empty lists. + Return ragged array with only a subset of the elements or sub-lists + on the specified axis kept. (This is not regular sampling, it is + irregular subsampling with specified elements kept). + + @param [in] src The ragged shape that we are subsampling + @param [in] renumbering The renumbering object that dictates + which elements of `src` we keep; we require + renumbering.NumOldElems() == src.TotSize(axis2) + where axis2 = (axis < 0 ? src.NumAxes() + axis : axis). + @param [in] axis The axis to subsample; if negative, will be + interpreted as an offset from src.NumAxes(). + @param [out] elems_new2old If supplied, this function will + output to this location a new2old array that + dictates how the elements of a ragged tensor + with shape `src` would be renumbered. + @return Returns the subsampled shape. 
All dimensions and tot-sizes + preceding the axis `axis` will remain the same, which might give + rise to empty lists on those axes; these can be removed if + necessary with RemoveEmptyLists(). */ template -Ragged SubsampleRagged(Ragged &src, Renumbering &renumbering) { - return Ragged(SubsampleRaggedShape(src.shape, renumbering), - src.values[renumbering.New2Old()]); +Ragged SubsetRagged(Ragged &src, Renumbering &renumbering, + int32_t axis = -1, + Array1 *elems_new2old = nullptr) { + Array1 tmp; + if (elems_new2old == nullptr) + elems_new2old = &tmp; + RaggedShape shape = SubsetRaggedShape(src.shape, renumbering, + axis, elems_new2old); + return Ragged(shape, src.values[*elems_new2old]); } +/* + This function creates a Renumbering object that can be used to obtain subsets + of ragged arrays via SubsetRaggedShape(). It implements beam pruning as + used in pruned Viterbi search and similar algorithms, where there is both a + beam and a max-active (`max_elems`) constraint. T will probably be float or + double, interpreted as a "positive-is-better" sense, i.e. as scores. + + @param [in] src The ragged object to be subsampled. + @param [in] axis The axis to be subsampled, must satisfy + 0 <= axis < src.NumAxes(). The axis before `axis`, if axis > 0, + will be interpreted as a "batch" axis. + @param [in] beam The main pruning beam. The sub-lists of elements on axis + `axis` will be removed if their maximum element (or the element + itself, if axis + 1 == src.NumAxes()) is less than + this_best_elem - beam, where this_best_elem + is the maximum element taken over axis `axis-1` (or over the + entire array, if axis == 0). Think of axis `axis-1`, if + axis > 0, as the "batch" axis, and axis `axis` as the axis that we + actually remove elements or sub-lists on. Empty sub-lists on axis + `axis` will always be pruned, as their score would be treated + as -infinity. 
+ @param [in] max_elems If max_elems > 0, it is the maximum number of + sub-lists or elements that are allowed within any sub-list + on axis `axis-1` (or the maximum number of top-level sub-lists + after subsampling, if axis == 0). We keep the best ones. + If max_elems <= 0, there is no such constraint. + @return Returns the renumbering object to be used to actually + prune/subsample the specified axis. + + Example: + PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 1, 5.0, 3) + would create a Renumbering object that would prune the + ragged tensor to [ [0 -1 -2], [ -10 ], [ ] ] + + PruneRagged([ [0 -1 -2 -3], [ -10, -20 ], [ ] ], 0, 5.0, 0) + would create a Renumbering object that would prune the + ragged tensor to [ [0 -1 -2 -3] ] + */ +template +Renumbering PruneRagged(Ragged &src, + int32_t axis, + T beam, + int32_t max_elems); + /* Stack a list of Ragged arrays to create a Ragged array with one more axis. Similar to TF/PyTorch's Stack. The result will have Dim0 == num_srcs. All @@ -974,8 +1057,7 @@ void Unstack(Ragged src, int32_t axis, std::vector> *out, /* Concatenate a list of Ragged to form a single Ragged. - @param [in] axis Axis to append them on. Currently - we only support axis == 0 or axis == 1. + @param [in] axis Axis to append them on. Previous axes must have the same shape, i.e. if axis == 1 then `src[i]->Dim0()` must all have the @@ -1368,7 +1450,7 @@ Ragged Merge(int32_t num_srcs, Ragged **src, /* Returns a ragged tensor after removing all 'values' that were <= a provided cutoff. Leaves all layers of the shape except for the last one unaffected. - Equivalent to SubsampleRaggedShape with a numbering given by (src.values[i] <= + Equivalent to SubsetRaggedShape with a numbering given by (src.values[i] <= cutoff). */ template @@ -1377,7 +1459,7 @@ Ragged RemoveValuesLeq(Ragged &src, T cutoff); /* Returns a ragged tensor after removing all 'values' that equal a provided target. Leaves all layers of the shape except for the last one unaffected. 
- Equivalent to SubsampleRaggedShape with a numbering given by (src.values[i] == + Equivalent to SubsetRaggedShape with a numbering given by (src.values[i] == target). */ template diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h index bee7efb21..47297fab4 100644 --- a/k2/csrc/ragged_ops_inl.h +++ b/k2/csrc/ragged_ops_inl.h @@ -288,7 +288,7 @@ Ragged RemoveValuesLeq(Ragged &src, T cutoff) { K2_EVAL( c, src.NumElements(), lambda_set_keep, (int32_t i)->void { keep[i] = (char)(values_data[i] > cutoff); }); - return SubsampleRagged(src, r); + return SubsetRagged(src, r); } template @@ -301,7 +301,7 @@ Ragged RemoveValuesEq(Ragged &src, T target) { K2_EVAL( c, src.NumElements(), lambda_set_keep, (int32_t i)->void { keep[i] = (char)(values_data[i] != target); }); - return SubsampleRagged(src, r); + return SubsetRagged(src, r); } // Recursive function that prints (part of) a ragged shape. @@ -373,9 +373,9 @@ static void SortSublistsCpu(Ragged *src, Array1 *order) { int32_t cur = row_splits[i]; int32_t next = row_splits[i + 1]; if (order != nullptr) - std::sort(order->Data() + cur, order->Data() + next, lambda_comp); + std::stable_sort(order->Data() + cur, order->Data() + next, lambda_comp); - std::sort(p + cur, p + next, comp); + std::stable_sort(p + cur, p + next, comp); } } @@ -815,6 +815,137 @@ Array2 PadRagged(Ragged &src, const std::string &mode, T padding_value) { return res; } +/* Prune a two axes ragged tensor on axis0. + * This is a special case of PruneRagged with axis == 0 and src.NumAxes() == 2, + * To get more details, please refer to the docs for PruneRagged in + * ragged_ops.h. 
+ */ +template +Renumbering PruneRaggedAxis0(Ragged &src, T beam, int32_t max_elems) { + K2_CHECK_EQ(src.NumAxes(), 2); + const ContextPtr &c = src.Context(); + int32_t total_elements = src.TotSize(0); + Renumbering renumbering(c, total_elements); + + T negative_infinity = -std::numeric_limits::infinity(); + Array1 sub_max(c, total_elements); + MaxPerSublist(src, negative_infinity, &sub_max); + + T max_value = MaxValue(src.values); + + bool prune_with_max_elems = + max_elems > 0 && max_elems < total_elements; + + Array1 order_map; + const int32_t *order_map_data; + if (prune_with_max_elems) { + order_map = Array1(c, total_elements); + Sort>(&sub_max, &order_map); + order_map_data = order_map.Data(); + } + + char *keep_data = renumbering.Keep().Data(); + const T *sub_max_data = sub_max.Data(); + + // prune_with_max_elems means we have sorted the source ragged tensor + if (prune_with_max_elems) { + K2_EVAL(c, total_elements, lambda_set_keep_sorted, (int32_t i) { + bool pruned_by_beam = sub_max_data[i] < max_value - beam; + bool pruned_by_max_elems = i >= max_elems; + keep_data[order_map_data[i]] = + !(pruned_by_max_elems || pruned_by_beam); + }); + } else { + K2_EVAL(c, total_elements, lambda_set_keep, (int32_t i) { + keep_data[i] = sub_max_data[i] >= max_value - beam; + }); + } + return renumbering; +} + +/* Prune a two axes ragged tensor on axis1 + * This is a special case of PruneRagged with axis == 1 and src.NumAxes() == 2, + * To get more details, please refer to the docs for PruneRagged in + * ragged_ops.h. 
*/ +template +Renumbering PruneRaggedAxis1(Ragged &src, T beam, + int32_t max_elems) { + K2_CHECK_EQ(src.NumAxes(), 2); + const ContextPtr &c = src.Context(); + int32_t total_elements = src.TotSize(1); + Renumbering renumbering(c, total_elements); + + T negative_infinity = -std::numeric_limits::infinity(); + Array1 sub_max(c, src.TotSize(0)); + MaxPerSublist(src, negative_infinity, &sub_max); + + bool prune_with_max_elems = + max_elems > 0 && max_elems < total_elements; + + Array1 order_map; + const int32_t *order_map_data; + if (prune_with_max_elems) { + Ragged sorted_src = src.Clone(); + order_map = Array1(c, total_elements); + SortSublists>(&sorted_src, &order_map); + order_map_data = order_map.Data(); + } + + char *keep_data = renumbering.Keep().Data(); + const T *sub_max_data = sub_max.Data(), + *src_data = src.values.Data(); + const int32_t *row_ids1_data = src.RowIds(1).Data(), + *row_splits1_data = src.RowSplits(1).Data(); + // prune_with_max_elems means we have sorted the source ragged tensor + if (prune_with_max_elems) { + K2_EVAL(c, total_elements, lambda_set_keep_sorted, (int32_t idx01) { + // idx01 is the index after sorting + int32_t original_idx01 = order_map_data[idx01], + // SortSublists wouldn't change idx0 & idx0x + idx0 = row_ids1_data[original_idx01], + idx0x = row_splits1_data[idx0], + // idx1 is the index after sorting + idx1 = idx01 - idx0x; + bool pruned_by_max_elems = idx1 >= max_elems, + pruned_by_beam = + src_data[original_idx01] < sub_max_data[idx0] - beam; + keep_data[original_idx01] = + !(pruned_by_max_elems || pruned_by_beam); + }); + } else { + K2_EVAL(c, total_elements, lambda_set_keep, (int32_t idx01) { + int32_t idx0 = row_ids1_data[idx01]; + keep_data[idx01] = src_data[idx01] >= sub_max_data[idx0] - beam; + }); + } + return renumbering; + } + +template +Renumbering PruneRagged(Ragged &src, int32_t axis, T beam, + int32_t max_elems) { + NVTX_RANGE(K2_FUNC); + if (axis == 0) { + auto reduced_src = src; + while 
(reduced_src.NumAxes() > 2) { + reduced_src = RemoveAxis(reduced_src, reduced_src.NumAxes() - 2); + } + return PruneRaggedAxis0(reduced_src, beam, max_elems); + } else if (axis == src.NumAxes() - 1) { + auto reduced_src = src; + while (reduced_src.NumAxes() > 2) { + reduced_src = RemoveAxis(reduced_src, 0); + } + return PruneRaggedAxis1(reduced_src, beam, max_elems); + } else { + RaggedShape top, bottom; + DecomposeRaggedShape(src.shape, axis, &top, &bottom); + Ragged bottom_ragged(bottom, src.values); + return PruneRagged(bottom_ragged, 0, beam, max_elems); + } +} + } // namespace k2 #endif // K2_CSRC_RAGGED_OPS_INL_H_ diff --git a/k2/csrc/ragged_shape_test.cu b/k2/csrc/ragged_shape_test.cu index f09d7edcd..451a2cf60 100644 --- a/k2/csrc/ragged_shape_test.cu +++ b/k2/csrc/ragged_shape_test.cu @@ -360,11 +360,6 @@ TEST(RaggedShapeTest, RemoveEmptyLists) { } } - - - - - TEST(RaggedShapeTest, RaggedShapeIterator) { // note RaggedShapeIndexIterator works only for CPU ContextPtr context = GetCpuContext(); @@ -425,4 +420,58 @@ TEST(RaggedShapeTest, RandomRaggedShape) { } } +TEST(RaggedShapeTest, SubsetRaggedShape) { + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + RaggedShape src(c, + "[ [ [ x x ] [ x ] ] [ [ x x x ] [ ] ] [ [ x ] [ x x ] ] ]"); + // axis = 2 + Array1 ref_keep(c, std::vector({1, 0, 1, 0, 0, 1, 0, 1, 0})); + Renumbering renumbering(c, src.NumElements()); + auto &keep = renumbering.Keep(); + keep.CopyFrom(ref_keep); + Array1 new2old; + auto dest = SubsetRaggedShape(src, renumbering, 2, &new2old); + Array1 ref_new2old(c, std::vector({0, 2, 5, 7})); + RaggedShape ref_dest(c, + "[ [ [ x ] [ x ] ] [ [ x ] [ ] ] [ [ ] [ x ] ] ]"); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + // test axis = -1 + dest = SubsetRaggedShape(src, renumbering, -1, &new2old); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + + // axis = 1 + ref_keep = Array1(c, std::vector({1, 0, 0, 1, 1, 1})); + 
renumbering = Renumbering(c, src.TotSize(1)); + keep = renumbering.Keep(); + keep.CopyFrom(ref_keep); + dest = SubsetRaggedShape(src, renumbering, 1, &new2old); + ref_new2old = Array1(c, std::vector({0, 1, 6, 7, 8})); + ref_dest = RaggedShape(c, "[ [ [ x x ] ] [ [ ] ] [ [ x ] [ x x ] ] ]"); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + // test axis = -2 + dest = SubsetRaggedShape(src, renumbering, -2, &new2old); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + + // axis = 0 + ref_keep = Array1(c, std::vector({1, 0, 1})); + renumbering = Renumbering(c, src.TotSize(0)); + keep = renumbering.Keep(); + keep.CopyFrom(ref_keep); + dest = SubsetRaggedShape(src, renumbering, 0, &new2old); + ref_new2old = Array1(c, + std::vector({0, 1, 2, 6, 7, 8})); + ref_dest = RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]"); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + // test axis = -3 + dest = SubsetRaggedShape(src, renumbering, -3, &new2old); + EXPECT_TRUE(Equal(dest, ref_dest)); + EXPECT_TRUE(Equal(new2old, ref_new2old)); + } +} + } // namespace k2 diff --git a/k2/csrc/ragged_test.cu b/k2/csrc/ragged_test.cu index 7fe9c5fae..801e51f5c 100644 --- a/k2/csrc/ragged_test.cu +++ b/k2/csrc/ragged_test.cu @@ -41,22 +41,25 @@ namespace k2 { - TEST(RaggedShapeOpsTest, CatMoreAxes) { for (auto &c : {GetCpuContext(), GetCudaContext()}) { - RaggedShape shape1 = RaggedShape(c, "[ [ [ [ x x ] ] [ [ x ] ] ]" - " [ [ [ x ] ] ] ]"), - shape2 = RaggedShape(c, "[ [ [ [ x ] ] [ [ x ] ] ]" - " [ [ [ x x ] ] ] ]"), - shape3 = RaggedShape(c, "[ [ [ [ ] ] [ [ x ] ] ]" - " [ [ [ ] ] ] ]"); + RaggedShape shape1 = RaggedShape(c, + "[ [ [ [ x x ] ] [ [ x ] ] ]" + " [ [ [ x ] ] ] ]"), + shape2 = RaggedShape(c, + "[ [ [ [ x ] ] [ [ x ] ] ]" + " [ [ [ x x ] ] ] ]"), + shape3 = RaggedShape(c, + "[ [ [ [ ] ] [ [ x ] ] ]" + " [ [ [ ] ] ] ]"); RaggedShape cat_axis2_ref = - RaggedShape(c, "[ [ [ [ x 
x ] [ x ] [ ] ] [ [ x ] [ x ] [ x ] ] ]" - " [ [ [ x ] [ x x ] [ ] ] ] ]"); - RaggedShape cat_axis3_ref = - RaggedShape(c, "[ [ [ [ x x x ] ] [ [ x x x ] ] ]" - " [ [ [ x x x ] ] ] ]"); + RaggedShape(c, + "[ [ [ [ x x ] [ x ] [ ] ] [ [ x ] [ x ] [ x ] ] ]" + " [ [ [ x ] [ x x ] [ ] ] ] ]"); + RaggedShape cat_axis3_ref = RaggedShape(c, + "[ [ [ [ x x x ] ] [ [ x x x ] ] ]" + " [ [ [ x x x ] ] ] ]"); RaggedShape *srcs[] = {&shape1, &shape2, &shape3}; Array1 merge_map2; Array1 merge_map3; @@ -74,21 +77,25 @@ TEST(RaggedShapeOpsTest, CatMoreAxes) { TEST(RaggedShapeOpsTest, StackMoreAxes) { for (auto &c : {GetCpuContext(), GetCudaContext()}) { - RaggedShape shape1 = RaggedShape(c, "[ [ [ [ x x ] ] [ [ x ] ] ]" - " [ [ [ x ] ] ] ]"), - shape2 = RaggedShape(c, "[ [ [ [ x ] ] [ [ x ] ] ]" - " [ [ [ x x ] ] ] ]"), - shape3 = RaggedShape(c, "[ [ [ [ ] ] [ [ x ] ] ]" - " [ [ [ ] ] ] ]"); + RaggedShape shape1 = RaggedShape(c, + "[ [ [ [ x x ] ] [ [ x ] ] ]" + " [ [ [ x ] ] ] ]"), + shape2 = RaggedShape(c, + "[ [ [ [ x ] ] [ [ x ] ] ]" + " [ [ [ x x ] ] ] ]"), + shape3 = RaggedShape(c, + "[ [ [ [ ] ] [ [ x ] ] ]" + " [ [ [ ] ] ] ]"); RaggedShape stacked2_ref = - RaggedShape(c, "[ [ [ [ [ x x ] ] [ [ x ] ] [ [ ] ] ]" - " [ [ [ x ] ] [ [ x ] ] [ [ x ] ] ] ]" - " [ [ [ [ x ] ] [ [ x x ] ] [ [ ] ] ] ] ]"); - RaggedShape stacked3_ref = - RaggedShape(c, "[ [ [ [ [ x x ] [ x ] [ ] ] ]" - " [ [ [ x ] [ x ] [ x ] ] ] ]" - " [ [ [ [ x ] [ x x ] [ ] ] ] ] ]"); + RaggedShape(c, + "[ [ [ [ [ x x ] ] [ [ x ] ] [ [ ] ] ]" + " [ [ [ x ] ] [ [ x ] ] [ [ x ] ] ] ]" + " [ [ [ [ x ] ] [ [ x x ] ] [ [ ] ] ] ] ]"); + RaggedShape stacked3_ref = RaggedShape(c, + "[ [ [ [ [ x x ] [ x ] [ ] ] ]" + " [ [ [ x ] [ x ] [ x ] ] ] ]" + " [ [ [ [ x ] [ x x ] [ ] ] ] ] ]"); RaggedShape *srcs[] = {&shape1, &shape2, &shape3}; Array1 merge_map2; Array1 merge_map3; @@ -113,18 +120,13 @@ TEST(RaggedShapeOpsTest, Unstack2Axes) { // axis = 0 Unstack(shape, 0, &out, &out_map); - K2_CHECK(Equal(out[0], - RaggedShape(c, 
"[ [ x x ] ]"))); - K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 1}))); - K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ x x x ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{2, 3, 4}))); - K2_CHECK(Equal(out[2], - RaggedShape(c, "[ [ x ] ]"))); - K2_CHECK(Equal(out_map[2], - Array1(c, std::vector{5}))); + K2_CHECK(Equal(out[0], RaggedShape(c, "[ [ x x ] ]"))); + K2_CHECK(Equal(out_map[0], Array1(c, std::vector{0, 1}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ x x x ] ]"))); + K2_CHECK( + Equal(out_map[1], Array1(c, std::vector{2, 3, 4}))); + K2_CHECK(Equal(out[2], RaggedShape(c, "[ [ x ] ]"))); + K2_CHECK(Equal(out_map[2], Array1(c, std::vector{5}))); std::vector out_ptr; out_ptr.clear(); @@ -135,18 +137,13 @@ TEST(RaggedShapeOpsTest, Unstack2Axes) { // axis = 1 Unstack(shape, 1, &out, &out_map); - K2_CHECK(Equal(out[0], - RaggedShape(c, "[ [ x x x ] ]"))); - K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 2, 5}))); - K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ x x ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{1, 3}))); - K2_CHECK(Equal(out[2], - RaggedShape(c, "[ [ x ] ]"))); - K2_CHECK(Equal(out_map[2], - Array1(c, std::vector{4}))); + K2_CHECK(Equal(out[0], RaggedShape(c, "[ [ x x x ] ]"))); + K2_CHECK( + Equal(out_map[0], Array1(c, std::vector{0, 2, 5}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ x x ] ]"))); + K2_CHECK(Equal(out_map[1], Array1(c, std::vector{1, 3}))); + K2_CHECK(Equal(out[2], RaggedShape(c, "[ [ x ] ]"))); + K2_CHECK(Equal(out_map[2], Array1(c, std::vector{4}))); // can not test Stack here, because the element numbers of axis 1 is not // the same } @@ -154,21 +151,21 @@ TEST(RaggedShapeOpsTest, Unstack2Axes) { TEST(RaggedShapeOpsTest, Unstack) { for (auto &c : {GetCpuContext(), GetCudaContext()}) { - RaggedShape shape(c, "[ [ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]" - " [ [ [ x x x ] ] ] ]"); + RaggedShape shape(c, + "[ [ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]" + " [ [ [ x x x ] ] ] ]"); 
std::vector out; std::vector> out_map; Unstack(shape, 0, &out, &out_map); // axis = 0 K2_CHECK(Equal(out[0], - RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]"))); + RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x ] [ x x ] ] ]"))); K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 1, 2, 3, 4, 5}))); - K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ [ x x x ] ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{6, 7, 8}))); + Array1(c, std::vector{0, 1, 2, 3, 4, 5}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ [ x x x ] ] ]"))); + K2_CHECK( + Equal(out_map[1], Array1(c, std::vector{6, 7, 8}))); std::vector out_ptr; for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i])); @@ -177,14 +174,13 @@ TEST(RaggedShapeOpsTest, Unstack) { // axis = 1 Unstack(shape, 1, &out, &out_map); - K2_CHECK(Equal(out[0], - RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] ] ]"))); + K2_CHECK( + Equal(out[0], RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] ] ]"))); K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 1, 2, 6, 7, 8}))); - K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ [ x ] [ x x ] ] [ ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{3, 4, 5}))); + Array1(c, std::vector{0, 1, 2, 6, 7, 8}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ [ x ] [ x x ] ] [ ] ]"))); + K2_CHECK( + Equal(out_map[1], Array1(c, std::vector{3, 4, 5}))); out_ptr.clear(); for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i])); @@ -194,14 +190,13 @@ TEST(RaggedShapeOpsTest, Unstack) { // axis = 2 Unstack(shape, 2, &out, &out_map); - K2_CHECK(Equal(out[0], - RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] ] ]"))); + K2_CHECK( + Equal(out[0], RaggedShape(c, "[ [ [ x x ] [ x ] ] [ [ x x x ] ] ]"))); K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 1, 3, 6, 7, 8}))); - K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ [ x ] [ x x ] ] [ [ ] ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{2, 4, 5}))); + Array1(c, std::vector{0, 1, 3, 6, 
7, 8}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ [ x ] [ x x ] ] [ [ ] ] ]"))); + K2_CHECK( + Equal(out_map[1], Array1(c, std::vector{2, 4, 5}))); out_ptr.clear(); for (size_t i = 0; i < out.size(); ++i) out_ptr.emplace_back(&(out[i])); @@ -211,18 +206,15 @@ TEST(RaggedShapeOpsTest, Unstack) { // axis = 3 Unstack(shape, 3, &out, &out_map); - K2_CHECK(Equal(out[0], - RaggedShape(c, "[ [ [ x x ] [ x x ] ] [ [ x ] ] ]"))); + K2_CHECK( + Equal(out[0], RaggedShape(c, "[ [ [ x x ] [ x x ] ] [ [ x ] ] ]"))); K2_CHECK(Equal(out_map[0], - Array1(c, std::vector{0, 2, 3, 4, 6}))); - K2_CHECK(Equal(out[1], - RaggedShape(c, "[ [ [ x ] [ x ] ] [ [ x ] ] ]"))); - K2_CHECK(Equal(out_map[1], - Array1(c, std::vector{1, 5, 7}))); - K2_CHECK(Equal(out[2], - RaggedShape(c, "[ [ [ ] [ ] ] [ [ x ] ] ]"))); - K2_CHECK(Equal(out_map[2], - Array1(c, std::vector{8}))); + Array1(c, std::vector{0, 2, 3, 4, 6}))); + K2_CHECK(Equal(out[1], RaggedShape(c, "[ [ [ x ] [ x ] ] [ [ x ] ] ]"))); + K2_CHECK( + Equal(out_map[1], Array1(c, std::vector{1, 5, 7}))); + K2_CHECK(Equal(out[2], RaggedShape(c, "[ [ [ ] [ ] ] [ [ x ] ] ]"))); + K2_CHECK(Equal(out_map[2], Array1(c, std::vector{8}))); // can not test Stack here, because the element numbers of axis 3 is not // the same } @@ -230,12 +222,13 @@ TEST(RaggedShapeOpsTest, Unstack) { TEST(RaggedShapeOpsTest, UnstackMoreAxes) { for (auto &c : {GetCpuContext(), GetCudaContext()}) { - RaggedShape shape(c, "[ [ [ [ [ x ] [ ] ] [ [ x x x ] ] ] ]" - " [ [ [ [ x x x ] ] [ [ x x ] ] [ [ x ] ] ]" - " [ [ [ x x ] [ x ] [ ] [ x ] ] ]" - " [ [ [ x ] ] [ [ x ] [ x x x x ] ] ] ]" - " [ [ [ [ x ] ] [ ] ]" - " [ [ [ x x ] ] ] ] ]"); + RaggedShape shape(c, + "[ [ [ [ [ x ] [ ] ] [ [ x x x ] ] ] ]" + " [ [ [ [ x x x ] ] [ [ x x ] ] [ [ x ] ] ]" + " [ [ [ x x ] [ x ] [ ] [ x ] ] ]" + " [ [ [ x ] ] [ [ x ] [ x x x x ] ] ] ]" + " [ [ [ [ x ] ] [ ] ]" + " [ [ [ x x ] ] ] ] ]"); std::vector out; std::vector> out_map; @@ -254,11 +247,11 @@ TEST(RaggedShapeOpsTest, 
UnstackMoreAxes) { } TEST(RaggedShapeOpsTest, UnstackRandom) { - RaggedShape random_shape_ = RandomRaggedShape(true, // set_row_ids - 5, // min_num_axes - 5, // max_num_axes - 1, // min_num_elements - 100); // max_num_elements + RaggedShape random_shape_ = RandomRaggedShape(true, // set_row_ids + 5, // min_num_axes + 5, // max_num_axes + 1, // min_num_elements + 100); // max_num_elements for (auto &c : {GetCpuContext(), GetCudaContext()}) { auto random_shape0 = random_shape_.To(c); std::vector out; @@ -1733,9 +1726,13 @@ TEST(RaggedShapeOpsTest, TestIndex) { TEST(RaggedShapeOpsTest, TestIndexAxis1) { for (auto &context : {GetCpuContext(), GetCudaContext()}) { { - Ragged input = Ragged(" [ [ 1 2 ] [ 3 4 5 ] [ 6 7 ] [ ] ]").To(context); // NOLINT + Ragged input = + Ragged(" [ [ 1 2 ] [ 3 4 5 ] [ 6 7 ] [ ] ]") + .To(context); // NOLINT Array1 indexes = Array1(" [ 1 0 4 2 6 5 ]").To(context); - Ragged output = Ragged(" [ [ 2 1 ] [ 5 3 ] [ 7 6 ] [ ] ]").To(context); // NOLINT + Ragged output = + Ragged(" [ [ 2 1 ] [ 5 3 ] [ 7 6 ] [ ] ]") + .To(context); // NOLINT Ragged indexed = Index(input, 1, indexes); EXPECT_EQ(Equal(output, indexed), true); @@ -1743,8 +1740,6 @@ TEST(RaggedShapeOpsTest, TestIndexAxis1) { } } - - TEST(GetTransposeReordering, NoDuplicates) { // col0 col1 col2 col3 col4 col5 // row0 a0 b1 @@ -2241,41 +2236,39 @@ void TestUnstackRagged() { K2_CHECK(Equal(out[2], Ragged(c, "[ [ 50 ] ]"))); // more axes - ragged = Ragged(c, "[ [ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ]" - " [ [ 41 ] [ 51 ] ] ]" - " [ [ [ 61 62 63 ] ] ] ]"); + ragged = Ragged(c, + "[ [ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ]" + " [ [ 41 ] [ 51 ] ] ]" + " [ [ [ 61 62 63 ] ] ] ]"); // axis = 0 Unstack(ragged, 0, &out); - K2_CHECK(Equal(out[0], Ragged(c, - "[ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ] [ [ 41 ] [ 51 ] ] ]"))); - K2_CHECK(Equal(out[1], - Ragged(c, "[ [ [ 61 62 63 ] ] ]"))); + K2_CHECK(Equal( + out[0], + Ragged(c, + "[ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ] [ [ 41 ] [ 51 ] ] ]"))); + K2_CHECK(Equal(out[1], 
Ragged(c, "[ [ [ 61 62 63 ] ] ]"))); // axis = 1 Unstack(ragged, 1, &out); - K2_CHECK(Equal(out[0], Ragged(c, - "[ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ] [ [ 61 62 63 ] ] ]"))); - K2_CHECK(Equal(out[1], - Ragged(c, "[ [ [ 41 ] [ 51 ] ] [ ] ]"))); + K2_CHECK(Equal( + out[0], + Ragged(c, "[ [ [ 1 11 21 ] [ 21 22 ] [ 31 ] ] [ [ 61 62 63 ] ] ]"))); + K2_CHECK(Equal(out[1], Ragged(c, "[ [ [ 41 ] [ 51 ] ] [ ] ]"))); // axis = 2 Unstack(ragged, 2, &out); - K2_CHECK(Equal(out[0], - Ragged(c, "[ [ [ 1 11 21 ] [ 41 ] ] [ [ 61 62 63 ] ] ]"))); - K2_CHECK(Equal(out[1], - Ragged(c, "[ [ [ 21 22 ] [ 51 ] ] [ [ ] ] ]"))); - K2_CHECK(Equal(out[2], - Ragged(c, "[ [ [ 31 ] [ ] ] [ [ ] ] ]"))); + K2_CHECK(Equal( + out[0], Ragged(c, "[ [ [ 1 11 21 ] [ 41 ] ] [ [ 61 62 63 ] ] ]"))); + K2_CHECK(Equal(out[1], Ragged(c, "[ [ [ 21 22 ] [ 51 ] ] [ [ ] ] ]"))); + K2_CHECK(Equal(out[2], Ragged(c, "[ [ [ 31 ] [ ] ] [ [ ] ] ]"))); // axis = 3 Unstack(ragged, 3, &out); K2_CHECK(Equal(out[0], - Ragged(c, "[ [ [ 1 21 31 ] [ 41 51 ] ] [ [ 61 ] ] ]"))); - K2_CHECK(Equal(out[1], - Ragged(c, "[ [ [ 11 22 ] [ ] ] [ [ 62 ] ] ]"))); - K2_CHECK(Equal(out[2], - Ragged(c, "[ [ [ 21 ] [ ] ] [ [ 63 ] ] ]"))); + Ragged(c, "[ [ [ 1 21 31 ] [ 41 51 ] ] [ [ 61 ] ] ]"))); + K2_CHECK(Equal(out[1], Ragged(c, "[ [ [ 11 22 ] [ ] ] [ [ 62 ] ] ]"))); + K2_CHECK(Equal(out[2], Ragged(c, "[ [ [ 21 ] [ ] ] [ [ 63 ] ] ]"))); } } @@ -2916,8 +2909,6 @@ TEST(RaggedOpsTest, TestComputeHash) { } } - - TEST(RaggedOpsTest, TestUniqueSequences) { for (int32_t i = 0; i < 20; i++) { for (auto &c : {GetCpuContext(), GetCudaContext()}) { @@ -2932,7 +2923,7 @@ TEST(RaggedOpsTest, TestUniqueSequences) { ContextPtr cpu = GetCpuContext(); Array1 hash_src = ComputeHash(src).To(cpu), - hash_unique = ComputeHash(unique).To(cpu); + hash_unique = ComputeHash(unique).To(cpu); RaggedShape src_hash_shape = RemoveAxis(src.shape, src.NumAxes() - 1).To(cpu); @@ -2946,9 +2937,10 @@ TEST(RaggedOpsTest, TestUniqueSequences) { K2_CHECK_EQ(src_hash_shape.Dim0(), 
unique_hash_shape.Dim0()); const int32_t *src_hash_row_splits = src_hash_shape.RowSplits(1).Data(), - *unique_hash_row_splits = unique_hash_shape.RowSplits(1).Data(); + *unique_hash_row_splits = + unique_hash_shape.RowSplits(1).Data(); const int32_t *src_hash_data = hash_src.Data(), - *unique_hash_data = hash_unique.Data(); + *unique_hash_data = hash_unique.Data(); for (int32_t r = 0; r < src_hash_shape.Dim0(); r++) { int32_t src_begin = src_hash_row_splits[r], @@ -2979,7 +2971,6 @@ TEST(RaggedIntTest, TestCreateRagged2Int) { K2_CHECK(Equal(r, r2)); } - TEST(RaggedFloatTest, TestCreateRagged2Float) { std::vector> vecs{{1.2, 2.3}, {}, {3.4, 5.6}}; std::vector expected_values{1.2, 2.3, 3.4, 5.6}; @@ -3004,10 +2995,8 @@ static void TestPadRagged() { T padding_value = 0; Array2 res = PadRagged(src, "constant", padding_value); Array1 dst = res.Flatten(); - std::vector expected = {1, 2, 0, 0, - 3, 4, 3, 0, - 0, 0, 0, 0, - 5, 6, 7, 8}; + std::vector expected = {1, 2, 0, 0, 3, 4, 3, 0, + 0, 0, 0, 0, 5, 6, 7, 8}; CheckArrayData(dst, expected); } { @@ -3015,10 +3004,8 @@ static void TestPadRagged() { T padding_value = -1; Array2 res = PadRagged(src, "constant", padding_value); Array1 dst = res.Flatten(); - std::vector expected = {1, 2, -1, -1, - 3, 4, 3, -1, - -1, -1, -1, -1, - 5, 6, 7, 8}; + std::vector expected = {1, 2, -1, -1, 3, 4, 3, -1, + -1, -1, -1, -1, 5, 6, 7, 8}; CheckArrayData(dst, expected); } { @@ -3026,10 +3013,8 @@ static void TestPadRagged() { T padding_value = 100; Array2 res = PadRagged(src, "replicate", padding_value); Array1 dst = res.Flatten(); - std::vector expected = {1, 2, 2, 2, - 3, 4, 3, 3, - 100, 100, 100, 100, - 5, 6, 7, 8}; + std::vector expected = {1, 2, 2, 2, 3, 4, 3, 3, + 100, 100, 100, 100, 5, 6, 7, 8}; CheckArrayData(dst, expected); } } @@ -3039,4 +3024,125 @@ TEST(RaggedTest, TestPadRagged) { TestPadRagged(); TestPadRagged(); } + +template +static void TestPruneRagged() { + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + Ragged 
src(c, + "[ [ [ 1.1 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] " + " [ [ 1.2 ] [ 2.2 6.3 ] [ ] ] " + " [ [ 1.3 4.4 ] [ 2.3 5.0 ] ] ]"); + + T beam = 2.0; + auto renumbering = PruneRagged(src, 0, beam, 2); + // best_score=6.3, best scores for sublists are [6.1, 6.3, 5.0] + // no sublist is pruned by beam, 5.0 is pruned by max-elems + // keep : [ [ [ 1.1 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] + // [ [ 1.2 ] [ 2.2 6.3 ] [ ] ] ] + Array1 keep_ref(c, std::vector{1, 1, 0}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 0.1; + renumbering = PruneRagged(src, 0, beam, 3); + // best_score=6.3, best scores for sublists are [6.1, 6.3, 5.0] + // 6.1 & 5.0 are pruned by beam + // keep : [ [ [ 1.2 ] [ 2.2 6.3 ] [ ] ] ] + keep_ref = Array1(c, std::vector{0, 1, 0}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 2.0; + renumbering = PruneRagged(src, 1, beam, 5); + // best_score=6.3, best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // 1.2 & -inf are pruned by beam, 4.4 is pruned by max-elems. + // keep : [ [ [ 1.1 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] [ [ 2.2 6.3 ] ] + // [ [ 2.3 5.0 ] ] ] + keep_ref = Array1(c, std::vector{1, 1, 1, 0, 1, 0, 0, 1}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 1.0; + renumbering = PruneRagged(src, 1, beam, 5); + // best_score=6.3, best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // all sublists are pruned by beam, except 6.1 & 6.3 + // keep : [ [ [ 6.1 ] ] [ [ 2.2 6.3 ] ] ] + keep_ref = Array1(c, std::vector{0, 0, 1, 0, 1, 0, 0, 0}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 4.0; + renumbering = PruneRagged(src, 2, beam, 3); + // best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // 1.1, 1.0, 2.2 are pruned by beam. 
+ // keep : [ [ [ 2.1 5.2 ] [ 5.1 ] [ 6.1 ] ] [ [ 1.2 ] [ 6.3 ] ] + // [ [ 1.3 4.4 ] [ 2.3 5.0 ] ] ] + keep_ref = Array1( + c, std::vector{0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + + beam = 5.0; + renumbering = PruneRagged(src, 2, beam, 2); + // best scores for sublists are + // [5.2, 5.1, 6.1, 1.2, 6.3, -inf, 4.4, 5.0] + // 1.1 is pruned by max-elems. + // keep : [ [ [ 2.1 5.2 ] [ 1.0 5.1 ] [ 6.1 ] ] [ [ 1.2 ] [ 2.2 6.3 ] ] + // [ [ 1.3 4.4 ] [ 2.3 5.0 ] ] ] + keep_ref = Array1( + c, std::vector{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + K2_CHECK(Equal(renumbering.Keep(), keep_ref)); + } +} + +TEST(RaggedTest, TestPruneRagged) { + TestPruneRagged(); + TestPruneRagged(); +} + +template +static void TestPruneRaggedAndSubsetRagged() { + for (auto &c : {GetCpuContext(), GetCudaContext()}) { + Ragged src(c, + "[ [ [ 1.1 4.2 2.1 1.8 ] [ 5.0 3.1 ] ] " + " [ [ 1.2 ] [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] " + " [ [ 1.3 4.4 ] [ 1.4 0.8 2.3 5.2 3.6 ] ] ]"); + T beam = 1.0; + auto renumbering = PruneRagged(src, 0, beam, 3); + Array1 new2old; + auto dest = SubsetRagged(src, renumbering, 0, &new2old); + Ragged dest_ref(c, "[ [ [ 1.2 ] [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] ]"); + Array1 new2old_ref(c, std::vector{6, 7, 8, 9, 10, 11}); + K2_CHECK(Equal(dest, dest_ref)); + K2_CHECK(Equal(new2old, new2old_ref)); + + beam = 2.0; + renumbering = PruneRagged(src, 1, beam, 5); + dest = SubsetRagged(src, renumbering, 1, &new2old); + dest_ref = + Ragged(c, + "[ [ [ 5.0 3.1 ] ] [ [ 2.2 6.3 ] [ 2.4 6.1 ] [ 5.1 ] ] " + " [ [ 1.4 0.8 2.3 5.2 3.6 ] ] ]"); + new2old_ref = Array1( + c, std::vector{4, 5, 7, 8, 9, 10, 11, 14, 15, 16, 17, 18}); + K2_CHECK(Equal(dest, dest_ref)); + K2_CHECK(Equal(new2old, new2old_ref)); + + beam = 3.0; + renumbering = PruneRagged(src, 2, beam, 3); + dest = SubsetRagged(src, renumbering, 2, &new2old); + dest_ref = Ragged( + c, + "[ [ [ 4.2 2.1 1.8 ] [ 5.0 3.1 ] ] [ [ 1.2 ] [ 6.3 ] [ 6.1 ] [ 5.1 ] ]" + " [ [ 4.4 ] [ 2.3 5.2 
3.6 ] ] ]"); + new2old_ref = Array1( + c, std::vector{1, 2, 3, 4, 5, 6, 8, 10, 11, 13, 16, 17, 18}); + K2_CHECK(Equal(dest, dest_ref)); + K2_CHECK(Equal(new2old, new2old_ref)); + } +} + +TEST(RaggedTest, TestPruneRaggedAndSubsetRagged) { + TestPruneRaggedAndSubsetRagged(); + TestPruneRaggedAndSubsetRagged(); +} + } // namespace k2 diff --git a/k2/csrc/rm_epsilon.cu b/k2/csrc/rm_epsilon.cu index b8774bf53..c20775c1a 100644 --- a/k2/csrc/rm_epsilon.cu +++ b/k2/csrc/rm_epsilon.cu @@ -795,7 +795,7 @@ void ComputeEpsilonClosure(FsaVec &epsilon_fsa, FsaVec *closure_fsa, const Arc &cur_arc = arcs_data[arc_idx012]; arc_keep_data[arc_idx012] = (cur_arc.src_state != cur_arc.dest_state); }); - *closure_fsa = SubsampleRagged(*closure_fsa, arc_renumbering); + *closure_fsa = SubsetRagged(*closure_fsa, arc_renumbering); *arc_map = Index(*arc_map, 0, arc_renumbering.New2Old()); } @@ -1081,7 +1081,7 @@ void RemoveEpsilonDevice(FsaOrVec &src_fsa, FsaOrVec *dest_fsa, non_epsilon_arc_map, foll_shape, &combined_foll, &combined_foll_arc_map); FsaVec epsilon_closure_prec = - SubsampleRagged(epsilon_closure_mapped, epsilon_prec_renumbering); + SubsetRagged(epsilon_closure_mapped, epsilon_prec_renumbering); Ragged epsilon_closure_prec_arc_map = Index( epsilon_closure_mapped_arc_map, 0, epsilon_prec_renumbering.New2Old()); // `combined_prec` will be set to an FSA, with the same state numbering as