From fdddfc4a51458138309d4ddb5111a952bb95d8f7 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 5 Nov 2021 14:44:06 +0800 Subject: [PATCH 1/2] Using FsaClass for ctc decoding & HLG decoding --- k2/torch/bin/decode.cu | 47 +++++++-------- k2/torch/csrc/CMakeLists.txt | 1 + k2/torch/csrc/decode.cu | 99 +++++++------------------------- k2/torch/csrc/decode.h | 35 ++--------- k2/torch/csrc/deserialization.cu | 17 +++--- k2/torch/csrc/deserialization.h | 4 +- k2/torch/csrc/fsa_algo.cu | 71 +++++++++++++++++++++++ k2/torch/csrc/fsa_algo.h | 38 ++++++++++++ k2/torch/csrc/fsa_class.cu | 59 +++++++++++++++---- k2/torch/csrc/fsa_class.h | 48 ++++++++++------ 10 files changed, 247 insertions(+), 172 deletions(-) create mode 100644 k2/torch/csrc/fsa_algo.cu create mode 100644 k2/torch/csrc/fsa_algo.h diff --git a/k2/torch/bin/decode.cu b/k2/torch/bin/decode.cu index 62c9dcfad..d4d641bd5 100644 --- a/k2/torch/bin/decode.cu +++ b/k2/torch/bin/decode.cu @@ -21,11 +21,11 @@ #include #include -#include "k2/csrc/fsa_algo.h" #include "k2/torch/csrc/decode.h" #include "k2/torch/csrc/dense_fsa_vec.h" #include "k2/torch/csrc/deserialization.h" #include "k2/torch/csrc/features.h" +#include "k2/torch/csrc/fsa_algo.h" #include "k2/torch/csrc/symbol_table.h" #include "k2/torch/csrc/utils.h" #include "k2/torch/csrc/wave_reader.h" @@ -36,7 +36,7 @@ #include "torch/utils.h" C10_DEFINE_bool(use_gpu, false, "True to use GPU. False to use CPU"); -C10_DEFINE_string(jit_pt, "", "Path to exported jit filename."); +C10_DEFINE_string(jit_pt, "", "Path to exported jit file."); C10_DEFINE_string( bpe_model, "", "Path to a pretrained BPE model. Needed if --use_ctc_decoding is true"); @@ -45,7 +45,7 @@ C10_DEFINE_string(hlg, "", "Path to HLG.pt. Needed if --use_ctc_decoding is false"); C10_DEFINE_string(word_table, "", "Path to words.txt. Needed if --use_ctc_decoding is false"); -// +// Fsa decoding related C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned"); C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned"); C10_DEFINE_int(min_activate_states, 30, @@ -120,7 +120,7 @@ int main(int argc, char *argv[]) { --use_ctc_decoding false \ --jit_pt \ --hlg \ - --word-table \ + --word_table \ /path/to/foo.wav \ /path/to/bar.wav \ @@ -141,7 +141,7 @@ int main(int argc, char *argv[]) { K2_LOG(INFO) << "Device: " << device; int32_t num_waves = argc - 1; - K2_CHECK_GE(num_waves, 1) << "You have to provided at least one wave file"; + K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file"; std::vector wave_filenames(num_waves); for (int32_t i = 0; i != num_waves; ++i) { wave_filenames[i] = argv[i + 1]; @@ -205,50 +205,43 @@ int main(int argc, char *argv[]) { torch::Tensor supervision_segments = k2::GetSupervisionSegments(supervisions, subsampling_factor); - k2::ContextPtr ctx = k2::ContextFromTensor(nnet_output); - - k2::Fsa decoding_graph; - - k2::Array1 aux_labels; // only one of the two aux_labels is used - k2::Ragged ragged_aux_labels; + k2::FsaClass decoding_graph; if (FLAGS_use_ctc_decoding) { K2_LOG(INFO) << "Build CTC topo"; - k2::Fsa ctc_topo = k2::CtcTopo(ctx, nnet_output.size(2) - 1, - /*modified*/ false, &aux_labels); - decoding_graph = k2::FsaToFsaVec(ctc_topo); + decoding_graph = + k2::CtcTopo(nnet_output.size(2) - 1, /*modified*/ false, device); } else { K2_LOG(INFO) << "Load HLG.pt"; - // TODO(fangjun): We will eventually use an FSA wrapper to - // associate attributes with an FSA. 
-    decoding_graph = k2::LoadFsa(FLAGS_hlg, &ragged_aux_labels);
-    decoding_graph = decoding_graph.To(ctx);
-    ragged_aux_labels = ragged_aux_labels.To(ctx);
+    decoding_graph = k2::LoadFsa(FLAGS_hlg);
+    decoding_graph = decoding_graph.To(device);
   }

   K2_LOG(INFO) << "Decoding";
-  k2::FsaVec lattice = k2::GetLattice(
+  k2::FsaClass lattice = k2::GetLattice(
       nnet_output, decoding_graph, supervision_segments, FLAGS_search_beam,
       FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states,
-      subsampling_factor, aux_labels, ragged_aux_labels, &aux_labels,
-      &ragged_aux_labels);
+      subsampling_factor);

-  lattice = k2::OneBestDecoding(lattice, aux_labels, ragged_aux_labels,
-                                &aux_labels, &ragged_aux_labels);
+  lattice = k2::ShortestPath(lattice);
+
+  auto ragged_aux_labels = k2::GetTexts(lattice);

-  ragged_aux_labels = k2::GetTexts(lattice, aux_labels, ragged_aux_labels);
   auto aux_labels_vec = ragged_aux_labels.ToVecVec();

   std::vector<std::string> texts;
   if (FLAGS_use_ctc_decoding) {
     sentencepiece::SentencePieceProcessor processor;
-    const auto status = processor.Load(FLAGS_bpe_model);
+    auto status = processor.Load(FLAGS_bpe_model);
     if (!status.ok()) {
       K2_LOG(FATAL) << status.ToString();
     }
     for (const auto &ids : aux_labels_vec) {
       std::string text;
-      processor.Decode(ids, &text);
+      status = processor.Decode(ids, &text);
+      if (!status.ok()) {
+        K2_LOG(FATAL) << status.ToString();
+      }
       texts.emplace_back(std::move(text));
     }
   } else {
diff --git a/k2/torch/csrc/CMakeLists.txt b/k2/torch/csrc/CMakeLists.txt
index 7988a03e8..9ac896365 100644
--- a/k2/torch/csrc/CMakeLists.txt
+++ b/k2/torch/csrc/CMakeLists.txt
@@ -7,6 +7,7 @@ set(k2_torch_srcs
   dense_fsa_vec.cu
   deserialization.cu
   features.cu
+  fsa_algo.cu
   fsa_class.cu
   symbol_table.cu
   utils.cu
diff --git a/k2/torch/csrc/decode.cu b/k2/torch/csrc/decode.cu
index 7738b3282..c8b37b357 100644
--- a/k2/torch/csrc/decode.cu
+++ b/k2/torch/csrc/decode.cu
@@ -20,95 +20,36 @@
 #include "k2/csrc/ragged_ops.h"
 #include "k2/torch/csrc/decode.h"
 #include "k2/torch/csrc/dense_fsa_vec.h"
+#include "k2/torch/csrc/fsa_algo.h"
+#include "k2/torch/csrc/utils.h"

 namespace k2 {

-FsaVec GetLattice(torch::Tensor nnet_output, FsaVec decoding_graph,
-                  torch::Tensor supervision_segments, float search_beam,
-                  float output_beam, int32_t min_activate_states,
-                  int32_t max_activate_states, int32_t subsampling_factor,
-                  Array1<int32_t> &in_aux_labels,
-                  Ragged<int32_t> &in_ragged_aux_labels,
-                  Array1<int32_t> *out_aux_labels,
-                  Ragged<int32_t> *out_ragged_aux_labels) {
-  if (in_aux_labels.Dim() != 0) {
-    K2_CHECK_EQ(in_ragged_aux_labels.values.Dim(), 0);
-  } else {
-    K2_CHECK_NE(in_ragged_aux_labels.values.Dim(), 0);
-  }
-
+FsaClass GetLattice(torch::Tensor nnet_output, FsaClass &decoding_graph,
+                    torch::Tensor supervision_segments, float search_beam,
+                    float output_beam, int32_t min_activate_states,
+                    int32_t max_activate_states, int32_t subsampling_factor) {
   DenseFsaVec dense_fsa_vec = CreateDenseFsaVec(
       nnet_output, supervision_segments, subsampling_factor - 1);
-
-  FsaVec lattice;
-  Array1<int32_t> arc_map_a;
-  Array1<int32_t> arc_map_b;
-  IntersectDensePruned(decoding_graph, dense_fsa_vec, search_beam, output_beam,
-                       min_activate_states, max_activate_states, &lattice,
-                       &arc_map_a, &arc_map_b);
-  if (in_aux_labels.Dim() > 0) {
-    // see Index() in array_ops.h
-    *out_aux_labels = Index(in_aux_labels, arc_map_a, /*allow_minus_one*/ false,
-                            /*default_value*/ 0);
-  } else {
-    // See Index() in ragged_ops.h
-    *out_ragged_aux_labels = Index(in_ragged_aux_labels, /*axis*/ 0, arc_map_a);
-  }
-  return lattice;
-}
-
-FsaVec OneBestDecoding(FsaVec &lattice, Array1<int32_t> &in_aux_labels,
-                       Ragged<int32_t> &in_ragged_aux_labels,
-                       Array1<int32_t> *out_aux_labels,
-                       Ragged<int32_t> *out_ragged_aux_labels) {
-  if (in_aux_labels.Dim() != 0) {
-    K2_CHECK_EQ(in_ragged_aux_labels.values.Dim(), 0);
-  } else {
-    K2_CHECK_NE(in_ragged_aux_labels.values.Dim(), 0);
-  }
-
-  Ragged<int32_t> state_batches = GetStateBatches(lattice, true);
-  Array1<int32_t> dest_states = GetDestStates(lattice, true);
-  Ragged<int32_t> incoming_arcs = GetIncomingArcs(lattice, dest_states);
-  Ragged<int32_t> entering_arc_batches =
-      GetEnteringArcIndexBatches(lattice, incoming_arcs, state_batches);
-
-  bool log_semiring = false;
-  Array1<int32_t> entering_arcs;
-  GetForwardScores<float>(lattice, state_batches, entering_arc_batches,
-                          log_semiring, &entering_arcs);
-
-  Ragged<int32_t> best_path_arc_indexes = ShortestPath(lattice, entering_arcs);
-
-  if (in_aux_labels.Dim() > 0) {
-    *out_aux_labels = Index(in_aux_labels, best_path_arc_indexes.values,
-                            /*allow_minus_one*/ false,
-                            /*default_value*/ 0);
-  } else {
-    *out_ragged_aux_labels =
-        Index(in_ragged_aux_labels, /*axis*/ 0, best_path_arc_indexes.values);
-  }
-
-  FsaVec out = FsaVecFromArcIndexes(lattice, best_path_arc_indexes);
-  return out;
+  return IntersectDensePruned(decoding_graph, dense_fsa_vec, search_beam,
+                              output_beam, min_activate_states,
+                              max_activate_states);
 }

-Ragged<int32_t> GetTexts(FsaVec &lattice, Array1<int32_t> &in_aux_labels,
-                         Ragged<int32_t> &in_ragged_aux_labels) {
-  if (in_aux_labels.Dim() != 0) {
-    K2_CHECK_EQ(in_ragged_aux_labels.values.Dim(), 0);
-  } else {
-    K2_CHECK_NE(in_ragged_aux_labels.values.Dim(), 0);
-  }
-
+Ragged<int32_t> GetTexts(FsaClass &lattice) {
+  K2_CHECK(lattice.HasAttr("aux_labels"));
   Ragged<int32_t> ragged_aux_labels;
-  if (in_aux_labels.Dim() != 0) {
-    // [utt][state][arc] -> [utt][arc]
-    RaggedShape aux_labels_shape = RemoveAxis(lattice.shape, 1);
-    ragged_aux_labels = Ragged<int32_t>(aux_labels_shape, in_aux_labels);
+  torch::IValue aux_labels = lattice.GetAttr("aux_labels");
+  if (aux_labels.isTensor()) {
+    Array1<int32_t> aux_labels_array =
+        Array1FromTorch<int32_t>(aux_labels.toTensor());
+    RaggedShape aux_labels_shape = RemoveAxis(lattice.fsa.shape, 1);
+    ragged_aux_labels = Ragged<int32_t>(aux_labels_shape, aux_labels_array);
   } else {
+    K2_CHECK(IsRaggedInt(aux_labels));
+    Ragged<int32_t> in_ragged_aux_labels = ToRaggedInt(aux_labels);
     RaggedShape aux_labels_shape =
-        ComposeRaggedShapes(lattice.shape, in_ragged_aux_labels.shape);
+        ComposeRaggedShapes(lattice.fsa.shape, in_ragged_aux_labels.shape);
     aux_labels_shape = RemoveAxis(aux_labels_shape, 1);
     aux_labels_shape = RemoveAxis(aux_labels_shape, 1);
     ragged_aux_labels =
diff --git a/k2/torch/csrc/decode.h b/k2/torch/csrc/decode.h
index 35ebaf8de..5bef2d4d8 100644
--- a/k2/torch/csrc/decode.h
+++ b/k2/torch/csrc/decode.h
@@ -22,6 +22,7 @@
 #include "k2/csrc/array.h"
 #include "k2/csrc/fsa.h"
 #include "k2/csrc/ragged.h"
+#include "k2/torch/csrc/fsa_class.h"
 #include "torch/script.h"

 namespace k2 {

@@ -68,33 +69,10 @@ attributes.
    @return Return an FsaVec, which is the intersection of decoding graph and
            the FSA constructed from `nnet_output`.
  */
-FsaVec GetLattice(torch::Tensor nnet_output, FsaVec decoding_graph,
-                  torch::Tensor supervision_segments, float search_beam,
-                  float output_beam, int32_t min_activate_states,
-                  int32_t max_activate_states, int32_t subsampling_factor,
-                  Array1<int32_t> &in_aux_labels,
-                  Ragged<int32_t> &in_ragged_aux_labels,
-                  Array1<int32_t> *out_aux_labels,
-                  Ragged<int32_t> *out_ragged_aux_labels);
-
-/** Extract the best path from a lattice.
-
-    @param lattice  It can be the return value of `GetLattice()`.
-    @param in_aux_labels  If not empty, it associates an extra label with each
-                          arc in the input lattice.
-    @param in_ragged_aux_labels  If not empty, it associates an extra label with
-                                 each arc in the input lattice.
-    @param out_aux_labels  If in_aux_labels is not empty, it contains the
-                           aux_labels for the returned FSA.
-    @param out_ragged_aux_labels  If in_aux_labels is not empty, it contains
-                                  the aux_labels for the returned FSA.
-
-    @return Return a FsaVec containing linear FSAs.
- */
-FsaVec OneBestDecoding(FsaVec &lattice, Array1<int32_t> &in_aux_labels,
-                       Ragged<int32_t> &in_ragged_aux_labels,
-                       Array1<int32_t> *out_aux_labels,
-                       Ragged<int32_t> *out_ragged_aux_labels);
+FsaClass GetLattice(torch::Tensor nnet_output, FsaClass &decoding_graph,
+                    torch::Tensor supervision_segments, float search_beam,
+                    float output_beam, int32_t min_activate_states,
+                    int32_t max_activate_states, int32_t subsampling_factor);

 /** Get aux labels of each FSA contained in the lattice.

@@ -109,8 +87,7 @@ FsaVec OneBestDecoding(FsaVec &lattice, Array1<int32_t> &in_aux_labels,
                                  with each arc in the `lattice.
    @return Return a ragged array with two axes [utt][aux_label].
  */
-Ragged<int32_t> GetTexts(FsaVec &lattice, Array1<int32_t> &in_aux_labels,
-                         Ragged<int32_t> &in_ragged_aux_labels);
+Ragged<int32_t> GetTexts(FsaClass &lattice);

 } // namespace k2

diff --git a/k2/torch/csrc/deserialization.cu b/k2/torch/csrc/deserialization.cu
index 8bac51f4a..203a29464 100644
--- a/k2/torch/csrc/deserialization.cu
+++ b/k2/torch/csrc/deserialization.cu
@@ -296,8 +296,7 @@ static void RegisterRaggedInt() {
 // This function is modified from torch::jit::load()
 // See torch/csrc/jit/serialization/import.cpp
 //
-k2::FsaOrVec LoadFsa(const std::string &filename,
-                     Ragged<int32_t> *ragged_aux_labels /*=nullptr*/) {
+k2::FsaClass LoadFsa(const std::string &filename) {
   auto rai = std::make_unique(filename);
   // Verify that we're loading a zip archive and not a torch.save pickle archive

@@ -391,19 +390,19 @@ k2::FsaOrVec LoadFsa(const std::string &filename,
   //
   // We are using this function to load HLG.pt, whose aux_labels are ragged
   // tensors.
-  if (ragged_aux_labels != nullptr && dict.contains("aux_labels") &&
+  torch::IValue ragged_aux_labels;
+  if (dict.contains("aux_labels") &&
       dict.at("aux_labels").type() ==
          c10::getCustomClassType<c10::intrusive_ptr<RaggedIntHelper>>()) {
-    *ragged_aux_labels =
-        *dict.at("aux_labels").toCustomClass<RaggedIntHelper>();
+    ragged_aux_labels = dict.at("aux_labels");
   }
-  // todo: attach aux_labels to the returned FSA
-  // K2_LOG(INFO) << "aux_labels:" << aux_labels;

   bool error = false;
   Fsa fsa = FsaFromArray1(arcs, &error);
   K2_CHECK_EQ(error, false);
-
-  return fsa;
+  FsaClass dest(fsa);
+  if (!ragged_aux_labels.isNone())
+    dest.SetAttr("aux_labels", ragged_aux_labels);
+  return dest;
 }

 } // namespace k2
diff --git a/k2/torch/csrc/deserialization.h b/k2/torch/csrc/deserialization.h
index 46c007af2..ee2f5fa24 100644
--- a/k2/torch/csrc/deserialization.h
+++ b/k2/torch/csrc/deserialization.h
@@ -22,6 +22,7 @@
 #include

 #include "k2/csrc/fsa.h"
+#include "k2/torch/csrc/fsa_class.h"
 #include "torch/script.h"

 namespace k2 {
@@ -47,8 +48,7 @@ struct RaggedIntHelper : public Ragged<int32_t>,
              ragged tensors, then return it via this parameter.
   @return Return the FSA contained in the filename.
 */
-k2::FsaOrVec LoadFsa(const std::string &filename,
-                     Ragged<int32_t> *ragged_aux_labels = nullptr);
+k2::FsaClass LoadFsa(const std::string &filename);

 } // namespace k2

diff --git a/k2/torch/csrc/fsa_algo.cu b/k2/torch/csrc/fsa_algo.cu
new file mode 100644
index 000000000..ed6549163
+/**
+ * Copyright 2021 Xiaomi Corporation (authors: Wei Kang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "k2/csrc/fsa_algo.h"
+#include "k2/csrc/fsa_utils.h"
+#include "k2/torch/csrc/fsa_algo.h"
+#include "k2/torch/csrc/utils.h"
+
+namespace k2 {
+
+FsaClass CtcTopo(int32_t max_token, bool modified /*= false*/,
+                 torch::Device device /*= torch::kCPU*/) {
+  Array1<int32_t> aux_labels;
+  auto ctx = ContextFromDevice(device);
+  Fsa fsa = CtcTopo(ctx, max_token, modified, &aux_labels);
+  FsaClass dest(fsa);
+  dest.SetAttr("aux_labels", torch::IValue(Array1ToTorch(aux_labels)));
+  return dest;
+}
+
+FsaClass IntersectDensePruned(FsaClass &graph, DenseFsaVec &dense,
+                              float search_beam, float output_beam,
+                              int32_t min_activate_states,
+                              int32_t max_activate_states) {
+  Array1<int32_t> graph_arc_map;
+  Array1<int32_t> dense_arc_map;
+  FsaVec fsa;
+  IntersectDensePruned(graph.fsa, dense, search_beam, output_beam,
+                       min_activate_states, max_activate_states, &fsa,
+                       &graph_arc_map, &dense_arc_map);
+  FsaClass dest(fsa);
+  dest.CopyAttrs(graph, Array1ToTorch(graph_arc_map));
+  return dest;
+}
+
+FsaClass ShortestPath(FsaClass &lattice) {
+  Ragged<int32_t> state_batches = GetStateBatches(lattice.fsa, true);
+  Array1<int32_t> dest_states = GetDestStates(lattice.fsa, true);
+  Ragged<int32_t> incoming_arcs = GetIncomingArcs(lattice.fsa, dest_states);
+  Ragged<int32_t> entering_arc_batches =
+      GetEnteringArcIndexBatches(lattice.fsa, incoming_arcs, state_batches);
+
+  bool log_semiring = false;
+  Array1<int32_t> entering_arcs;
+  GetForwardScores<float>(lattice.fsa, state_batches, entering_arc_batches,
+                          log_semiring, &entering_arcs);
+
+  Ragged<int32_t> best_path_arc_indexes =
+      ShortestPath(lattice.fsa, entering_arcs);
+
+  FsaVec out = FsaVecFromArcIndexes(lattice.fsa, best_path_arc_indexes);
+  torch::Tensor arc_map = Array1ToTorch(best_path_arc_indexes.values);
+  return FsaClass::FromUnaryFunctionTensor(lattice, out, arc_map);
+}
+
+} // namespace k2
diff --git a/k2/torch/csrc/fsa_algo.h b/k2/torch/csrc/fsa_algo.h
new file mode 100644
index 000000000..221212c8a
+/**
+ * Copyright 2021 Xiaomi Corporation (authors: Wei Kang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef K2_TORCH_CSRC_FSA_ALGO_H_
+#define K2_TORCH_CSRC_FSA_ALGO_H_
+
+#include "k2/csrc/fsa.h"
+#include "k2/torch/csrc/fsa_class.h"
+
+namespace k2 {
+
+FsaClass CtcTopo(int32_t max_token, bool modified = false,
+                 torch::Device device = torch::kCPU);
+
+FsaClass IntersectDensePruned(FsaClass &graph, DenseFsaVec &dense,
+                              float search_beam, float output_beam,
+                              int32_t min_activate_states,
+                              int32_t max_activate_states);
+
+FsaClass ShortestPath(FsaClass &lattice);
+} // namespace k2
+
+#endif // K2_TORCH_CSRC_FSA_ALGO_H_
diff --git a/k2/torch/csrc/fsa_class.cu b/k2/torch/csrc/fsa_class.cu
index cf19ebba7..ec4972ae2 100644
--- a/k2/torch/csrc/fsa_class.cu
+++ b/k2/torch/csrc/fsa_class.cu
@@ -45,6 +45,11 @@ FsaClass FsaClass::FromUnaryFunctionTensor(FsaClass &src, const FsaOrVec &arcs,
   return dest;
 }

+void FsaClass::CopyAttrs(FsaClass &src, torch::Tensor arc_map) {
+  CopyTensorAttrs(src, arc_map);
+  CopyRaggedTensorAttrs(src, arc_map);
+}
+
 void FsaClass::CopyTensorAttrs(FsaClass &src, torch::Tensor arc_map) {
   for (const auto &iter : src.tensor_attrs) {
     if (!HasAttr(iter.first)) {
@@ -68,16 +73,6 @@ void FsaClass::CopyRaggedTensorAttrs(FsaClass &src, torch::Tensor arc_map) {
   }
 }

-void FsaClass::CopyRaggedTensorAttrs(FsaClass &src, Ragged<int32_t> &arc_map) {
-  for (auto &iter : src.ragged_tensor_attrs) {
-    if (!HasAttr(iter.first)) {
-      Ragged<int32_t> new_value =
-          Index(iter.second, arc_map, /*remove_axis*/ true);
-      SetRaggedTensorAttr(iter.first, new_value);
-    }
-  }
-}
-
 void FsaClass::SetScores(torch::Tensor scores) {
   K2_CHECK_EQ(scores.numel(), fsa.NumElements());
   K2_CHECK_EQ(scores.scalar_type(), torch::kFloat32);
@@ -243,4 +238,48 @@ bool FsaClass::HasAttr(const std::string &name) const {
   return all_attr_names.count(name) > 0;
 }

+FsaClass FsaClass::ToOtherContext(const ContextPtr &context) const {
+  K2_CHECK(!context->IsCompatible(*fsa.Context()));
+  FsaClass dest(fsa.To(context));
+  auto device = DeviceFromContext(context);
+  for (const auto &iter : tensor_attrs) {
+    dest.SetTensorAttr(iter.first, (iter.second).to(device));
+  }
+  for (const auto &iter : ragged_tensor_attrs) {
+    dest.SetRaggedTensorAttr(iter.first, (iter.second).To(context));
+  }
+  return dest;
+}
+
+FsaClass FsaClass::To(torch::Device device) const {
+  ContextPtr context = fsa.Context();
+  if (device.is_cpu()) {
+    // CPU -> CPU
+    if (context->GetDeviceType() == kCpu) return *this;
+
+    // CUDA -> CPU
+    DeviceGuard guard(context);
+    return this->ToOtherContext(GetCpuContext());
+  }
+
+  K2_CHECK(device.is_cuda()) << device.str();
+
+  int32_t device_index = device.index();
+
+  if (context->GetDeviceType() == kCuda &&
+      context->GetDeviceId() == device_index)
+    // CUDA to CUDA, and it's the same device
+    return *this;
+
+  // CPU to CUDA
+  // or from one GPU to another GPU
+  DeviceGuard guard(device_index);
+  return this->ToOtherContext(GetCudaContext(device_index));
+}
+
+FsaClass FsaClass::To(const std::string &device) const {
+  torch::Device d(device);
+  return this->To(d);
+}
+
 } // namespace k2
diff --git a/k2/torch/csrc/fsa_class.h b/k2/torch/csrc/fsa_class.h
index e5606f7f1..9e581573e 100644
--- a/k2/torch/csrc/fsa_class.h
+++ b/k2/torch/csrc/fsa_class.h
@@ -52,7 +52,11 @@ struct FsaClass {
   // The default constructor initializes an invalid FSA.
   FsaClass() = default;

-  explicit FsaClass(const FsaOrVec &fsa) : fsa(fsa) {}
+  explicit FsaClass(const FsaOrVec &fsa) : fsa(fsa) {
+    // Check the validity of the fsa; this triggers a fatal error if the fsa
+    // is not valid.
+    Properties();
+  }

   FsaClass(const FsaClass &other) = default;

@@ -133,6 +137,21 @@ struct FsaClass {
    */
   bool HasAttr(const std::string &name) const;

+  /** Propagate attributes from the source FsaClass via tensor arc_map.
+
+      Caution: If the source FsaClass has attributes whose names conflict
+      with those of the current FsaClass, we skip the source attributes and
+      keep the current ones.
+
+      @param src The source FsaClass.
+      @param arc_map The arc_map (as idx012) to select items in attributes.
+   */
+  void CopyAttrs(FsaClass &src, torch::Tensor arc_map);
+
+  // Transfer the current FSA to another device.
+  FsaClass To(torch::Device device) const;
+  FsaClass To(const std::string &device) const;
+
  private:
   /** Associate a tensor attribute with a value directly.

@@ -171,13 +190,13 @@ struct FsaClass {

   /** Propagate tensor attributes from the source FsaClass via tensor arc_map.

-     Caution: If there are attributes in source FsaClass with the name
-     conflicting with current FsaClass, we will skip the attributes in source
-     FsaClass and keep the current one.
+      Caution: If the source FsaClass has attributes whose names conflict
+      with those of the current FsaClass, we skip the source attributes and
+      keep the current ones.

-     @param src The source FsaClass.
-     @param arc_map The arc_map (as idx012) to select items in attributes.
-   */
+      @param src The source FsaClass.
+      @param arc_map The arc_map (as idx012) to select items in attributes.
+   */
   void CopyTensorAttrs(FsaClass &src, torch::Tensor arc_map);

   /** Propagate ragged tensor attributes from the source FsaClass via tensor
      arc_map.

      Caution: If the source FsaClass has attributes whose names conflict
      with those of the current FsaClass, we skip the source attributes and
      keep the current ones.

      @param src The source FsaClass.
      @param arc_map The arc_map (as idx012) to select items in attributes.
   */
   void CopyRaggedTensorAttrs(FsaClass &src, torch::Tensor arc_map);

-  /** Propagate ragged tensor attributes from source FsaClass via
-      Ragged<int32_t> arc_map.
-
-      Caution: If there are attributes in source FsaClass
-      with the name conflicting with current FsaClass, we will skip the attributes
-      in source FsaClass and keep the current one.
-
-      @param src The source FsaClass.
-      @param arc_map The arc_map (arc_map.values as idx012) to select items in
-                     attributes.
-   */
-  void CopyRaggedTensorAttrs(FsaClass &src, Ragged<int32_t> &arc_map);
+  /** Transfer the current FsaClass to another device.
+
+      Note: This function assumes that the target context is different from
+      the current context. It crashes if you call it with a context that is
+      the same as the current one.
+
+      @param context The target context.
+   */
+  FsaClass ToOtherContext(const ContextPtr &context) const;
 };

 } // namespace k2

From 9711beb1f2c4797bea2f69721919250e188d468b Mon Sep 17 00:00:00 2001
From: pkufool
Date: Fri, 5 Nov 2021 16:03:16 +0800
Subject: [PATCH 2/2] Update docs

---
 k2/torch/csrc/decode.h     | 36 ++++------------------
 k2/torch/csrc/fsa_algo.h   | 58 +++++++++++++++++++++++++++++++++++++-
 k2/torch/csrc/fsa_class.cu |  4 +--
 3 files changed, 65 insertions(+), 33 deletions(-)

diff --git a/k2/torch/csrc/decode.h b/k2/torch/csrc/decode.h
index 5bef2d4d8..4076e77ff 100644
--- a/k2/torch/csrc/decode.h
+++ b/k2/torch/csrc/decode.h
@@ -27,21 +27,14 @@

 namespace k2 {

-/*
-Note: Several functions in this file takes as inputs two kinds of aux_labels:
-a linear array and a ragged array. Only one of them is used. We will refactor
-the code once `FsaClass` is implemented, which wraps an FsaVec and its
-attributes.
-*/
-
 /** Get decoding lattice from a neural network output and a decoding graph.

     @param nnet_output  A 3-D tensor with dtype torch.float32. It is usually
                         the last layer of the neural network model, e.g., the
                         output of the `log-softmax` layer. It has shape
                         `(N, T, C)`.
-    @param decoding_graph  It is an FsaVec. It usually contains only only
-                           on graph. For instance, when using CTC decoding,
+    @param decoding_graph  It is an FsaClass. It usually contains only one
+                           graph. For instance, when using CTC decoding,
                            it contains a single CTC topo graph; when using
                            HLG decoding, it contains a single HLG graph.
@@ -53,21 +46,10 @@ attributes.
    @param min_activate_states  See `k2::IntersectDensePruned()` for its meaning.
    @param max_activate_states  See `k2::IntersectDensePruned()` for its meaning.
    @param subsampling_factor  The subsampling factor of the model.
-    @param in_aux_labels  If not empty, it associates an extra label with each
-                          arc in decoding graph.
-                          in_aux_labels.Dim() == decoding_graph.NumElements()
-                          if in_aux_labels is not empty.
-    @param in_ragged_aux_labels  If not empty, it must have 2 axes and it
-                                 associates an extra label with each arc in
-                                 decoding graph.
-                                 in_ragged_aux_labels.tot_size(0) ==
-                                 decoding_graph.NumElements()
-                                 if in_ragged_aux_labels is not empty.
-    @param out_aux_labels  If in_aux_labels is not empty, it associates an extra
-                           label for each arc in the returned FSA
-    @param out_ragged_aux_labels  If in_aux_labels is not empty, it associates an
-                                  extra label for each arc in the returned FSA
-    @return Return an FsaVec, which is the intersection of decoding graph and
-            the FSA constructed from `nnet_output`.
+    @return Return an FsaClass, which contains the intersection of the decoding
+            graph and the FSA constructed from `nnet_output`. All the
+            attributes of the decoding_graph are propagated to the returned
+            FsaClass as well.
  */
 FsaClass GetLattice(torch::Tensor nnet_output, FsaClass &decoding_graph,
                     torch::Tensor supervision_segments, float search_beam,
@@ -76,15 +58,9 @@ FsaClass GetLattice(torch::Tensor nnet_output, FsaClass &decoding_graph,

 /** Get aux labels of each FSA contained in the lattice.

-    Note: The input aux labels are for each arc in the lattice, while
-    the output aux_labels are for each FSA in the lattice.
-
-    @param lattice  An FsaVec containing linear FSAs. It can be the return
-                    value of `OneBestDecoding()`.
-    @param in_aux_labels  If not empty, it associates an extra label with each
-                          arc in the `lattice.
-    @param in_ragged_aux_labels  If not empty, it associates an extra label
-                                 with each arc in the `lattice.
+    @param lattice  An FsaClass containing linear FSAs. It can be the return
+                    value of `ShortestPath()`.
+
     @return Return a ragged array with two axes [utt][aux_label].
  */
 Ragged<int32_t> GetTexts(FsaClass &lattice);

diff --git a/k2/torch/csrc/fsa_algo.h b/k2/torch/csrc/fsa_algo.h
index 221212c8a..dc36ab53f 100644
--- a/k2/torch/csrc/fsa_algo.h
+++ b/k2/torch/csrc/fsa_algo.h
@@ -24,14 +24,70 @@

 namespace k2 {

+/* Create a CTC topology.
+
+   Note:
+     A standard CTC topology is the conventional one, where there
+     is a mandatory blank between two repeated neighboring symbols.
+     A non-standard, i.e., modified, CTC topology imposes no such constraint.
+
+   @param max_token  The maximum token ID (inclusive). We assume that token
+                     IDs are contiguous (from 1 to `max_token`). 0 represents
+                     blank.
+   @param modified  If false, create a standard CTC topology. Otherwise, create
+                    a modified CTC topology.
+   @param device  A torch.device indicating which device the returned FSA
+                  will be on. Default: torch::kCPU.
+   @return Return either a standard or a modified CTC topology as an FSA,
+           depending on whether `modified` is false or true.
+ */
 FsaClass CtcTopo(int32_t max_token, bool modified = false,
                  torch::Device device = torch::kCPU);

+/* Intersect a DenseFsaVec constructed from nnet_output with an FsaClass,
+   i.e., with decoding graphs.
+
+   @param graphs  Input FsaClass containing decoding graphs and the associated
+                  attributes. The decoding graph might just be a linear
+                  sequence of phones, or might be something more complicated.
+                  Must have either `graphs.fsa.shape[0] == dense.dim0()`, or
+                  `graphs.fsa.shape[0] == 1` in which case the graph is shared.
+   @param dense  Input FSAs that correspond to the neural network output.
+   @param search_beam  Decoding beam, e.g. 20. Smaller is faster, larger is
+                       more exact (less pruning). This is the default value;
+                       it may be modified by `min_activate_states` and
+                       `max_activate_states`.
+   @param output_beam  Pruning beam for the output of intersection (vs. best
+                       path); equivalent to kaldi's lattice-beam. E.g. 8.
+   @param max_activate_states  Maximum number of FSA states that are allowed
+                               to be active on any given frame for any given
+                               intersection/composition task. This is
+                               advisory, in that it will try not to exceed
+                               that but may not always succeed. You can use a
+                               very large number if no constraint is needed.
+   @param min_activate_states  Minimum number of FSA states that are allowed
+                               to be active on any given frame for any given
+                               intersection/composition task. This is
+                               advisory, in that it will try not to have
+                               fewer than this number active. Set it to zero
+                               if there is no constraint.
+   @return Return an FsaClass containing the intersection of the DenseFsaVec
+           and the decoding graphs, with the attributes propagated.
+ */
-FsaClass IntersectDensePruned(FsaClass &graph, DenseFsaVec &dense,
+FsaClass IntersectDensePruned(FsaClass &graphs, DenseFsaVec &dense,
                               float search_beam, float output_beam,
                               int32_t min_activate_states,
                               int32_t max_activate_states);

+/* Return the shortest paths as linear FSAs from the start state
+   to the final state in the tropical semiring.
+
+   Note:
+     It uses the opposite sign, i.e., `max` instead of `min`.
+
+   @param lattice  The input FsaClass.
+   @return An FsaClass containing the best paths as linear FSAs, with the
+           attributes propagated.
+ */
 FsaClass ShortestPath(FsaClass &lattice);
 } // namespace k2

diff --git a/k2/torch/csrc/fsa_class.cu b/k2/torch/csrc/fsa_class.cu
index ec4972ae2..0735ef23e 100644
--- a/k2/torch/csrc/fsa_class.cu
+++ b/k2/torch/csrc/fsa_class.cu
@@ -243,10 +243,10 @@ FsaClass FsaClass::ToOtherContext(const ContextPtr &context) const {
   FsaClass dest(fsa.To(context));
   auto device = DeviceFromContext(context);
   for (const auto &iter : tensor_attrs) {
-    dest.SetTensorAttr(iter.first, (iter.second).to(device));
+    dest.SetTensorAttr(iter.first, iter.second.to(device));
   }
   for (const auto &iter : ragged_tensor_attrs) {
-    dest.SetRaggedTensorAttr(iter.first, (iter.second).To(context));
+    dest.SetRaggedTensorAttr(iter.first, iter.second.To(context));
   }
   return dest;
 }
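
Illustrative usage (a minimal sketch, not part of the patch): the snippet below shows how the FsaClass-based API introduced above might be called end to end for HLG decoding. The function name `DecodeWithHlg`, the "HLG.pt" path, the `max_activate_states` value, and the subsampling factor of 4 are hypothetical; the other beam values are the defaults defined in k2/torch/bin/decode.cu above. It assumes `nnet_output` (an `(N, T, C)` float32 tensor) and `supervision_segments` have already been computed, as in that binary.

#include "k2/torch/csrc/decode.h"           // GetLattice(), GetTexts()
#include "k2/torch/csrc/deserialization.h"  // LoadFsa()
#include "k2/torch/csrc/fsa_algo.h"         // ShortestPath()

// Hypothetical driver mirroring the flow of k2/torch/bin/decode.cu.
k2::Ragged<int32_t> DecodeWithHlg(torch::Tensor nnet_output,
                                  torch::Tensor supervision_segments,
                                  torch::Device device) {
  // LoadFsa() attaches the ragged "aux_labels" attribute to the returned
  // FsaClass; To() moves the FSA and all of its attributes to `device`.
  k2::FsaClass decoding_graph = k2::LoadFsa("HLG.pt");  // assumed path
  decoding_graph = decoding_graph.To(device);

  // Intersect the decoding graph with the dense FSA built from nnet_output.
  // The graph's attributes are propagated through the arc map by CopyAttrs().
  k2::FsaClass lattice = k2::GetLattice(
      nnet_output, decoding_graph, supervision_segments,
      /*search_beam*/ 20, /*output_beam*/ 8,
      /*min_activate_states*/ 30, /*max_activate_states*/ 10000,  // assumed max
      /*subsampling_factor*/ 4);  // assumed model subsampling

  // Best path in the tropical semiring; aux_labels follow the selected arcs.
  lattice = k2::ShortestPath(lattice);

  // Returns a ragged array with two axes: [utt][aux_label].
  return k2::GetTexts(lattice);
}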