Using FsaClass for ctc decoding & HLG decoding #862

Merged: 2 commits, Nov 6, 2021
47 changes: 20 additions & 27 deletions k2/torch/bin/decode.cu
@@ -21,11 +21,11 @@
#include <string>
#include <vector>

#include "k2/csrc/fsa_algo.h"
#include "k2/torch/csrc/decode.h"
#include "k2/torch/csrc/dense_fsa_vec.h"
#include "k2/torch/csrc/deserialization.h"
#include "k2/torch/csrc/features.h"
#include "k2/torch/csrc/fsa_algo.h"
#include "k2/torch/csrc/symbol_table.h"
#include "k2/torch/csrc/utils.h"
#include "k2/torch/csrc/wave_reader.h"
@@ -36,7 +36,7 @@
#include "torch/utils.h"

C10_DEFINE_bool(use_gpu, false, "True to use GPU. False to use CPU");
C10_DEFINE_string(jit_pt, "", "Path to exported jit filename.");
C10_DEFINE_string(jit_pt, "", "Path to exported jit file.");
C10_DEFINE_string(
bpe_model, "",
"Path to a pretrained BPE model. Needed if --use_ctc_decoding is true");
@@ -45,7 +45,7 @@ C10_DEFINE_string(hlg, "",
"Path to HLG.pt. Needed if --use_ctc_decoding is false");
C10_DEFINE_string(word_table, "",
"Path to words.txt. Needed if --use_ctc_decoding is false");
//
// Fsa decoding related
C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned");
C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned");
C10_DEFINE_int(min_activate_states, 30,
@@ -120,7 +120,7 @@ int main(int argc, char *argv[]) {
--use_ctc_decoding false \
--jit_pt <path to exported torch script pt file> \
--hlg <path to HLG.pt> \
--word-table <path to words.txt> \
--word_table <path to words.txt> \
/path/to/foo.wav \
/path/to/bar.wav \
<more wave files if any>
@@ -141,7 +141,7 @@ int main(int argc, char *argv[]) {
K2_LOG(INFO) << "Device: " << device;

int32_t num_waves = argc - 1;
K2_CHECK_GE(num_waves, 1) << "You have to provided at least one wave file";
K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file";
std::vector<std::string> wave_filenames(num_waves);
for (int32_t i = 0; i != num_waves; ++i) {
wave_filenames[i] = argv[i + 1];
@@ -205,50 +205,43 @@ int main(int argc, char *argv[]) {
torch::Tensor supervision_segments =
k2::GetSupervisionSegments(supervisions, subsampling_factor);

k2::ContextPtr ctx = k2::ContextFromTensor(nnet_output);

k2::Fsa decoding_graph;

k2::Array1<int32_t> aux_labels; // only one of the two aux_labels is used
k2::Ragged<int32_t> ragged_aux_labels;
k2::FsaClass decoding_graph;

if (FLAGS_use_ctc_decoding) {
K2_LOG(INFO) << "Build CTC topo";
k2::Fsa ctc_topo = k2::CtcTopo(ctx, nnet_output.size(2) - 1,
/*modified*/ false, &aux_labels);
decoding_graph = k2::FsaToFsaVec(ctc_topo);
decoding_graph =
k2::CtcTopo(nnet_output.size(2) - 1, /*modified*/ false, device);
} else {
K2_LOG(INFO) << "Load HLG.pt";
// TODO(fangjun): We will eventually use an FSA wrapper to
// associate attributes with an FSA.
decoding_graph = k2::LoadFsa(FLAGS_hlg, &ragged_aux_labels);
decoding_graph = decoding_graph.To(ctx);
ragged_aux_labels = ragged_aux_labels.To(ctx);
decoding_graph = k2::LoadFsa(FLAGS_hlg);
decoding_graph = decoding_graph.To(device);
}

K2_LOG(INFO) << "Decoding";
k2::FsaVec lattice = k2::GetLattice(
k2::FsaClass lattice = k2::GetLattice(
nnet_output, decoding_graph, supervision_segments, FLAGS_search_beam,
FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states,
subsampling_factor, aux_labels, ragged_aux_labels, &aux_labels,
&ragged_aux_labels);
subsampling_factor);

lattice = k2::OneBestDecoding(lattice, aux_labels, ragged_aux_labels,
&aux_labels, &ragged_aux_labels);
lattice = k2::ShortestPath(lattice);

auto ragged_aux_labels = k2::GetTexts(lattice);

ragged_aux_labels = k2::GetTexts(lattice, aux_labels, ragged_aux_labels);
auto aux_labels_vec = ragged_aux_labels.ToVecVec();

std::vector<std::string> texts;
if (FLAGS_use_ctc_decoding) {
sentencepiece::SentencePieceProcessor processor;
const auto status = processor.Load(FLAGS_bpe_model);
auto status = processor.Load(FLAGS_bpe_model);
if (!status.ok()) {
K2_LOG(FATAL) << status.ToString();
}
for (const auto &ids : aux_labels_vec) {
std::string text;
processor.Decode(ids, &text);
status = processor.Decode(ids, &text);
if (!status.ok()) {
K2_LOG(FATAL) << status.ToString();
}
texts.emplace_back(std::move(text));
}
} else {
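
The collapsed tail of this file is the word-table branch. For orientation, here is a condensed sketch of how the refactored pipeline in decode.cu fits together, stitched from the hunks above (not a drop-in program; feature extraction, the jit forward pass, and error handling are elided):

torch::Tensor supervision_segments =
    k2::GetSupervisionSegments(supervisions, subsampling_factor);

k2::FsaClass decoding_graph;
if (FLAGS_use_ctc_decoding) {
  // CtcTopo now returns an FsaClass directly, with aux_labels attached.
  decoding_graph =
      k2::CtcTopo(nnet_output.size(2) - 1, /*modified*/ false, device);
} else {
  // LoadFsa now returns an FsaClass; its ragged aux_labels ride along as
  // an attribute instead of being threaded through by hand.
  decoding_graph = k2::LoadFsa(FLAGS_hlg);
  decoding_graph = decoding_graph.To(device);
}

k2::FsaClass lattice = k2::GetLattice(
    nnet_output, decoding_graph, supervision_segments, FLAGS_search_beam,
    FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states,
    subsampling_factor);
lattice = k2::ShortestPath(lattice);  // one-best path; attributes follow
auto aux_labels_vec = k2::GetTexts(lattice).ToVecVec();
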
1 change: 1 addition & 0 deletions k2/torch/csrc/CMakeLists.txt
@@ -7,6 +7,7 @@ set(k2_torch_srcs
dense_fsa_vec.cu
deserialization.cu
features.cu
fsa_algo.cu
fsa_class.cu
symbol_table.cu
utils.cu
99 changes: 20 additions & 79 deletions k2/torch/csrc/decode.cu
@@ -20,95 +20,36 @@
#include "k2/csrc/ragged_ops.h"
#include "k2/torch/csrc/decode.h"
#include "k2/torch/csrc/dense_fsa_vec.h"
#include "k2/torch/csrc/fsa_algo.h"
#include "k2/torch/csrc/utils.h"

namespace k2 {

FsaVec GetLattice(torch::Tensor nnet_output, FsaVec decoding_graph,
torch::Tensor supervision_segments, float search_beam,
float output_beam, int32_t min_activate_states,
int32_t max_activate_states, int32_t subsampling_factor,
Array1<int32_t> &in_aux_labels,
Ragged<int32_t> &in_ragged_aux_labels,
Array1<int32_t> *out_aux_labels,
Ragged<int32_t> *out_ragged_aux_labels) {
if (in_aux_labels.Dim() != 0) {
K2_CHECK_EQ(in_ragged_aux_labels.values.Dim(), 0);
} else {
K2_CHECK_NE(in_ragged_aux_labels.values.Dim(), 0);
}

FsaClass GetLattice(torch::Tensor nnet_output, FsaClass &decoding_graph,
torch::Tensor supervision_segments, float search_beam,
float output_beam, int32_t min_activate_states,
int32_t max_activate_states, int32_t subsampling_factor) {
DenseFsaVec dense_fsa_vec = CreateDenseFsaVec(
nnet_output, supervision_segments, subsampling_factor - 1);

FsaVec lattice;
Array1<int32_t> arc_map_a;
Array1<int32_t> arc_map_b;
IntersectDensePruned(decoding_graph, dense_fsa_vec, search_beam, output_beam,
min_activate_states, max_activate_states, &lattice,
&arc_map_a, &arc_map_b);
if (in_aux_labels.Dim() > 0) {
// see Index() in array_ops.h
*out_aux_labels = Index(in_aux_labels, arc_map_a, /*allow_minus_one*/ false,
/*default_value*/ 0);
} else {
// See Index() in ragged_ops.h
*out_ragged_aux_labels = Index(in_ragged_aux_labels, /*axis*/ 0, arc_map_a);
}
return lattice;
}

FsaVec OneBestDecoding(FsaVec &lattice, Array1<int32_t> &in_aux_labels,
Ragged<int32_t> &in_ragged_aux_labels,
Array1<int32_t> *out_aux_labels,
Ragged<int32_t> *out_ragged_aux_labels) {
if (in_aux_labels.Dim() != 0) {
K2_CHECK_EQ(in_ragged_aux_labels.values.Dim(), 0);
} else {
K2_CHECK_NE(in_ragged_aux_labels.values.Dim(), 0);
}

Ragged<int32_t> state_batches = GetStateBatches(lattice, true);
Array1<int32_t> dest_states = GetDestStates(lattice, true);
Ragged<int32_t> incoming_arcs = GetIncomingArcs(lattice, dest_states);
Ragged<int32_t> entering_arc_batches =
GetEnteringArcIndexBatches(lattice, incoming_arcs, state_batches);

bool log_semiring = false;
Array1<int32_t> entering_arcs;
GetForwardScores<float>(lattice, state_batches, entering_arc_batches,
log_semiring, &entering_arcs);

Ragged<int32_t> best_path_arc_indexes = ShortestPath(lattice, entering_arcs);

if (in_aux_labels.Dim() > 0) {
*out_aux_labels = Index(in_aux_labels, best_path_arc_indexes.values,
/*allow_minus_one*/ false,
/*default_value*/ 0);
} else {
*out_ragged_aux_labels =
Index(in_ragged_aux_labels, /*axis*/ 0, best_path_arc_indexes.values);
}

FsaVec out = FsaVecFromArcIndexes(lattice, best_path_arc_indexes);
return out;
return IntersectDensePruned(decoding_graph, dense_fsa_vec, search_beam,
output_beam, min_activate_states,
max_activate_states);
}

Ragged<int32_t> GetTexts(FsaVec &lattice, Array1<int32_t> &in_aux_labels,
Ragged<int32_t> &in_ragged_aux_labels) {
if (in_aux_labels.Dim() != 0) {
K2_CHECK_EQ(in_ragged_aux_labels.values.Dim(), 0);
} else {
K2_CHECK_NE(in_ragged_aux_labels.values.Dim(), 0);
}

Ragged<int32_t> GetTexts(FsaClass &lattice) {
K2_CHECK(lattice.HasAttr("aux_labels"));
Ragged<int32_t> ragged_aux_labels;
if (in_aux_labels.Dim() != 0) {
// [utt][state][arc] -> [utt][arc]
RaggedShape aux_labels_shape = RemoveAxis(lattice.shape, 1);
ragged_aux_labels = Ragged<int32_t>(aux_labels_shape, in_aux_labels);
torch::IValue aux_labels = lattice.GetAttr("aux_labels");
if (aux_labels.isTensor()) {
Array1<int32_t> aux_labels_array =
Array1FromTorch<int32_t>(aux_labels.toTensor());
RaggedShape aux_labels_shape = RemoveAxis(lattice.fsa.shape, 1);
ragged_aux_labels = Ragged<int32_t>(aux_labels_shape, aux_labels_array);
} else {
K2_CHECK(IsRaggedInt(aux_labels));
Ragged<int32_t> in_ragged_aux_labels = ToRaggedInt(aux_labels);
RaggedShape aux_labels_shape =
ComposeRaggedShapes(lattice.shape, in_ragged_aux_labels.shape);
ComposeRaggedShapes(lattice.fsa.shape, in_ragged_aux_labels.shape);
aux_labels_shape = RemoveAxis(aux_labels_shape, 1);
aux_labels_shape = RemoveAxis(aux_labels_shape, 1);
ragged_aux_labels =
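
The tensor/ragged dispatch above relies on FsaClass storing attributes as torch::IValue. A minimal sketch of that contract, using only the accessors that appear in this PR (exact signatures live in fsa_class.h):

k2::FsaClass graph = k2::LoadFsa("HLG.pt");  // attaches "aux_labels" if present
if (graph.HasAttr("aux_labels")) {
  torch::IValue v = graph.GetAttr("aux_labels");
  if (v.isTensor()) {
    // CTC-topo case: one integer aux label per arc, stored as a tensor.
    k2::Array1<int32_t> labels = k2::Array1FromTorch<int32_t>(v.toTensor());
  } else {
    // HLG case: a ragged tensor of word ids per arc.
    K2_CHECK(k2::IsRaggedInt(v));
    k2::Ragged<int32_t> labels = k2::ToRaggedInt(v);
  }
}
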
71 changes: 12 additions & 59 deletions k2/torch/csrc/decode.h
@@ -22,25 +22,19 @@
#include "k2/csrc/array.h"
#include "k2/csrc/fsa.h"
#include "k2/csrc/ragged.h"
#include "k2/torch/csrc/fsa_class.h"
#include "torch/script.h"

namespace k2 {

/*
Note: Several functions in this file take as inputs two kinds of aux_labels:
a linear array and a ragged array. Only one of them is used. We will refactor
the code once `FsaClass` is implemented, which wraps an FsaVec and its
attributes.
*/

/** Get decoding lattice from a neural network output and a decoding graph.

@param nnet_output A 3-D tensor with dtype torch.float32. It is usually
the last layer of the neural network model, e.g.,
the output of `log-softmax` layer. It has shape
`(N, T, C)`.
@param decoding_graph It is an FsaVec. It usually contains only only
on graph. For instance, when using CTC decoding,
@param decoding_graph It is an FsaClass. It usually contains only one
graph. For instance, when using CTC decoding,
it contains a single CTC topo graph; when using
HLG decoding, it contains a single HLG graph.

@@ -52,65 +46,24 @@ attributes.
@param min_activate_states See `k2::IntersectDensePruned()` for its meaning.
@param max_activate_states See `k2::IntersectDensePruned()` for its meaning.
@param subsampling_factor The subsampling factor of the model.
@param in_aux_labels If not empty, it associates an extra label with each
arc in decoding graph.
in_aux_labels.Dim() == decoding_graph.NumElements()
if in_aux_labels is not empty.
@param in_ragged_aux_labels If not empty, it must have 2 axes and it
associates an extra label with each arc in decoding graph.
in_ragged_aux_labels.tot_size(0) == decoding_graph.NumElements()
if in_ragged_aux_labels is not empty.
@param out_aux_labels If in_aux_labels is not empty, it associates an extra
label for each arc in the returned FSA
@param out_ragged_aux_labels If in_aux_labels is not empty, it associates an
extra label for each arc in the returned FSA

@return Return an FsaVec, which is the intersection of decoding graph and
the FSA constructed from `nnet_output`.
@return Return an FsaClass, which contains the intersection of decoding graph
and the FSA constructed from `nnet_output`. All the attributes of the
decoding_graph are propagated to the returned FsaClass as well.
*/
FsaVec GetLattice(torch::Tensor nnet_output, FsaVec decoding_graph,
torch::Tensor supervision_segments, float search_beam,
float output_beam, int32_t min_activate_states,
int32_t max_activate_states, int32_t subsampling_factor,
Array1<int32_t> &in_aux_labels,
Ragged<int32_t> &in_ragged_aux_labels,
Array1<int32_t> *out_aux_labels,
Ragged<int32_t> *out_ragged_aux_labels);

/** Extract the best path from a lattice.

@param lattice It can be the return value of `GetLattice()`.
@param in_aux_labels If not empty, it associates an extra label with each
arc in the input lattice.
@param in_ragged_aux_labels If not empty, it associates an extra label with
each arc in the input lattice.
@param out_aux_labels If in_aux_labels is not empty, it contains the
aux_labels for the returned FSA.
@param out_ragged_aux_labels If in_aux_labels is not empty, it contains
the aux_labels for the returned FSA.

@return Return a FsaVec containing linear FSAs.
*/
FsaVec OneBestDecoding(FsaVec &lattice, Array1<int32_t> &in_aux_labels,
Ragged<int32_t> &in_ragged_aux_labels,
Array1<int32_t> *out_aux_labels,
Ragged<int32_t> *out_ragged_aux_labels);
FsaClass GetLattice(torch::Tensor nnet_output, FsaClass &decoding_graph,
torch::Tensor supervision_segments, float search_beam,
float output_beam, int32_t min_activate_states,
int32_t max_activate_states, int32_t subsampling_factor);
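
A hedged sketch of the intended call pattern (the beams and min_activate_states are the binary's defaults from decode.cu; max_activate_states and the subsampling factor here are illustrative):

k2::FsaClass decoding_graph = k2::LoadFsa("HLG.pt");  // carries aux_labels
k2::FsaClass lattice = k2::GetLattice(
    nnet_output, decoding_graph, supervision_segments,
    /*search_beam*/ 20, /*output_beam*/ 8,
    /*min_activate_states*/ 30, /*max_activate_states*/ 10000,
    /*subsampling_factor*/ 4);
// Attributes of decoding_graph (e.g., aux_labels) are re-indexed through
// the intersection arc map and attached to the returned lattice.
K2_CHECK(lattice.HasAttr("aux_labels"));
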

/** Get aux labels of each FSA contained in the lattice.

Note: The input aux labels are for each arc in the lattice, while
the output aux_labels are for each FSA in the lattice.

@param lattice An FsaVec containing linear FSAs. It can be the return
value of `OneBestDecoding()`.
@param in_aux_labels If not empty, it associates an extra label with each
arc in the `lattice`.
@param in_ragged_aux_labels If not empty, it associates an extra label
with each arc in the `lattice`.

@return Return a ragged array with two axes [utt][aux_label].
*/
Ragged<int32_t> GetTexts(FsaVec &lattice, Array1<int32_t> &in_aux_labels,
Ragged<int32_t> &in_ragged_aux_labels);
Ragged<int32_t> GetTexts(FsaClass &lattice);
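
A short usage sketch; the axis comments restate the doc above:

// lattice: linear FSAs from ShortestPath(), with per-arc aux_labels attached.
k2::Ragged<int32_t> token_ids = k2::GetTexts(lattice);  // axes: [utt][aux_label]
auto ids = token_ids.ToVecVec();  // one vector of ids per utterance
// ids[i] holds the BPE-piece or word ids for utterance i, ready for
// SentencePiece decoding or word-table lookup as in decode.cu.
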

} // namespace k2

17 changes: 8 additions & 9 deletions k2/torch/csrc/deserialization.cu
@@ -296,8 +296,7 @@ static void RegisterRaggedInt() {
// This function is modified from torch::jit::load()
// See torch/csrc/jit/serialization/import.cpp
//
k2::FsaOrVec LoadFsa(const std::string &filename,
Ragged<int32_t> *ragged_aux_labels /*=nullptr*/) {
k2::FsaClass LoadFsa(const std::string &filename) {
auto rai = std::make_unique<caffe2::serialize::FileAdapter>(filename);

// Verify that we're loading a zip archive and not a torch.save pickle archive
@@ -391,19 +390,19 @@ k2::FsaOrVec LoadFsa(const std::string &filename,
//
// We are using this function to load HLG.pt, whose aux_labels are ragged
// tensors.
if (ragged_aux_labels != nullptr && dict.contains("aux_labels") &&
torch::IValue ragged_aux_labels;
if (dict.contains("aux_labels") &&
dict.at("aux_labels").type() ==
c10::getCustomClassType<c10::intrusive_ptr<RaggedIntHelper>>()) {
*ragged_aux_labels =
*dict.at("aux_labels").toCustomClass<RaggedIntHelper>();
ragged_aux_labels = dict.at("aux_labels");
}
// todo: attach aux_labels to the returned FSA
// K2_LOG(INFO) << "aux_labels:" << aux_labels;
bool error = false;
Fsa fsa = FsaFromArray1(arcs, &error);
K2_CHECK_EQ(error, false);

return fsa;
FsaClass dest(fsa);
if (!ragged_aux_labels.isNone())
dest.SetAttr("aux_labels", ragged_aux_labels);
return dest;
}

} // namespace k2
4 changes: 2 additions & 2 deletions k2/torch/csrc/deserialization.h
@@ -22,6 +22,7 @@
#include <string>

#include "k2/csrc/fsa.h"
#include "k2/torch/csrc/fsa_class.h"
#include "torch/script.h"

namespace k2 {
@@ -47,8 +48,7 @@ struct RaggedIntHelper : public Ragged<int32_t>,
ragged tensors, then return it via this parameter.
@return Return the FSA contained in the filename.
*/
k2::FsaOrVec LoadFsa(const std::string &filename,
Ragged<int32_t> *ragged_aux_labels = nullptr);
k2::FsaClass LoadFsa(const std::string &filename);
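
Callers no longer thread ragged aux_labels through an out-parameter; when the archive contains them, they come back attached to the result (a minimal before/after sketch):

// Before: Ragged<int32_t> aux; Fsa f = k2::LoadFsa("HLG.pt", &aux);
// Now:
k2::FsaClass hlg = k2::LoadFsa("HLG.pt");
if (hlg.HasAttr("aux_labels")) {
  torch::IValue aux = hlg.GetAttr("aux_labels");  // ragged ints for HLG
}
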

} // namespace k2
