Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Rnn-T LM nbest rescoring #471

Merged
merged 4 commits into from
Jul 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 183 additions & 4 deletions egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Dict, List, Optional

import k2
import sentencepiece as spm
import torch
from model import Transducer

Expand All @@ -34,6 +35,7 @@ def fast_beam_search_one_best(
beam: float,
max_states: int,
max_contexts: int,
temperature: float = 1.0,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.

Expand All @@ -56,6 +58,8 @@ def fast_beam_search_one_best(
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
temperature:
Softmax temperature.
Returns:
Return the decoded result.
"""
Expand All @@ -67,6 +71,7 @@ def fast_beam_search_one_best(
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)

best_path = one_best_decoding(lattice)
Expand All @@ -85,6 +90,7 @@ def fast_beam_search_nbest_LG(
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.

Expand Down Expand Up @@ -120,6 +126,8 @@ def fast_beam_search_nbest_LG(
use_double_scores:
True to use double precision for computation. False to use
single precision.
temperature:
Softmax temperature.
Returns:
Return the decoded result.
"""
Expand All @@ -131,6 +139,7 @@ def fast_beam_search_nbest_LG(
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)

nbest = Nbest.from_lattice(
Expand Down Expand Up @@ -201,6 +210,7 @@ def fast_beam_search_nbest(
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.

Expand Down Expand Up @@ -236,6 +246,8 @@ def fast_beam_search_nbest(
use_double_scores:
True to use double precision for computation. False to use
single precision.
temperature:
Softmax temperature.
Returns:
Return the decoded result.
"""
Expand All @@ -247,6 +259,7 @@ def fast_beam_search_nbest(
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)

nbest = Nbest.from_lattice(
Expand Down Expand Up @@ -282,6 +295,7 @@ def fast_beam_search_nbest_oracle(
ref_texts: List[List[int]],
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.

Expand Down Expand Up @@ -321,7 +335,8 @@ def fast_beam_search_nbest_oracle(
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.

temperature:
Softmax temperature.
Returns:
Return the decoded result.
"""
Expand All @@ -333,6 +348,7 @@ def fast_beam_search_nbest_oracle(
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)

nbest = Nbest.from_lattice(
Expand Down Expand Up @@ -373,6 +389,7 @@ def fast_beam_search(
beam: float,
max_states: int,
max_contexts: int,
temperature: float = 1.0,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.

Expand All @@ -392,6 +409,8 @@ def fast_beam_search(
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
temperature:
Softmax temperature.
Returns:
Return an FsaVec with axes [utt][state][arc] containing the decoded
lattice. Note: When the input graph is a TrivialGraph, the returned
Expand Down Expand Up @@ -440,7 +459,7 @@ def fast_beam_search(
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
log_probs = (logits / temperature).log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
Expand Down Expand Up @@ -783,6 +802,7 @@ def modified_beam_search(
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

Expand All @@ -796,6 +816,8 @@ def modified_beam_search(
encoder_out before padding.
beam:
Number of active paths during the beam search.
temperature:
Softmax temperature.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
Expand Down Expand Up @@ -879,7 +901,9 @@ def modified_beam_search(

logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)

log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs = (logits / temperature).log_softmax(
dim=-1
) # (num_hyps, vocab_size)

log_probs.add_(ys_log_probs)

Expand Down Expand Up @@ -1043,6 +1067,7 @@ def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
Expand All @@ -1056,6 +1081,8 @@ def beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
temperature:
Softmax temperature.
Returns:
Return the decoded result.
"""
Expand Down Expand Up @@ -1132,7 +1159,7 @@ def beam_search(
)

# TODO(fangjun): Scale the blank posterior
log_prob = logits.log_softmax(dim=-1)
log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,)
Expand Down Expand Up @@ -1171,3 +1198,155 @@ def beam_search(
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys


def fast_beam_search_with_nbest_rescoring(
    model: Transducer,
    decoding_graph: k2.Fsa,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
    ngram_lm_scale_list: List[float],
    num_paths: int,
    G: k2.Fsa,
    sp: spm.SentencePieceProcessor,
    word_table: k2.SymbolTable,
    oov_word: str = "<UNK>",
    use_double_scores: bool = True,
    nbest_scale: float = 0.5,
    temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
    """Fast beam search followed by n-best rescoring with an n-gram LM.

    It limits the maximum number of symbols per frame to 1.

    A lattice is first obtained using fast beam search. `num_paths` paths
    are then extracted from the lattice, and each path is scored both
    acoustically (by intersecting it with the lattice) and with the word
    level n-gram LM `G`. For every scale `s` in `ngram_lm_scale_list` the
    two scores are combined as `am_score + s * ngram_lm_score`, and the
    path with the highest combined score of each utterance is returned.

    Args:
      model:
        An instance of `Transducer`.
      decoding_graph:
        Decoding graph used for decoding, may be a TrivialGraph or a HLG.
        Note: token IDs are read from the labels of the n-best paths,
        which assumes that the decoding graph (and hence the lattice) is
        an acceptor.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      encoder_out_lens:
        A tensor of shape (N,) containing the number of frames in
        `encoder_out` before padding.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
      ngram_lm_scale_list:
        A list of floats representing LM score scales. One set of
        decoding results is produced per scale.
      num_paths:
        Number of paths to extract from the decoded lattice.
      G:
        An FsaVec containing only a single FSA. It is an n-gram LM.
      sp:
        The BPE model. It is used to turn the token IDs of each path
        into text, which is then split into words.
      word_table:
        The word symbol table, mapping words to the word IDs used by `G`.
      oov_word:
        OOV words are replaced with this word. It must be present in
        `word_table`.
      use_double_scores:
        True to use double precision for computation. False to use
        single precision. It is forwarded to the n-best extraction only;
        the n-gram LM scores are always computed in double precision.
      nbest_scale:
        It's the scale applied to the lattice.scores. A smaller value
        yields more unique paths.
      temperature:
        Softmax temperature.
    Returns:
      Return the decoded result in a dict, where the key has the form
      'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
      ngram LM scale value used during decoding, i.e., 0.1.
    """
    lattice = fast_beam_search(
        model=model,
        decoding_graph=decoding_graph,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        beam=beam,
        max_states=max_states,
        max_contexts=max_contexts,
        temperature=temperature,
    )

    # Right after extraction, nbest.fsa.scores are all zeros; intersecting
    # with the lattice fills in the acoustic scores.
    nbest = Nbest.from_lattice(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=use_double_scores,
        nbest_scale=nbest_scale,
    ).intersect(lattice)

    am_scores = nbest.tot_scores()

    # Collect the token IDs of every path. Dropping the state axis leaves
    # a ragged tensor with axes [path][arc].
    path_arc_shape = nbest.fsa.arcs.shape().remove_axis(1)
    path_tokens = k2.RaggedTensor(
        path_arc_shape, nbest.fsa.labels.contiguous()
    ).remove_values_leq(0)  # remove -1 and 0

    # One decoded string per path.
    texts: List[str] = sp.decode(path_tokens.tolist())

    assert isinstance(oov_word, str), oov_word
    assert oov_word in word_table, oov_word
    oov_word_id = word_table[oov_word]

    # Map every word of every path to its ID, falling back to the OOV ID.
    word_ids_list: List[List[int]] = [
        [
            word_table[w] if w in word_table else oov_word_id
            for w in text.split()
        ]
        for text in texts
    ]

    device = lattice.device
    word_fsas_with_self_loops = k2.add_epsilon_self_loops(
        k2.linear_fsa(word_ids_list, device=device)
    )

    # All zeros: every word sequence is intersected with the single FSA
    # contained in G.
    b_to_a_map = torch.zeros(
        len(word_ids_list),
        dtype=torch.int32,
        device=device,
    )

    rescored_word_fsas = k2.intersect_device(
        a_fsas=G,
        b_fsas=word_fsas_with_self_loops,
        b_to_a_map=b_to_a_map,
        sorted_match_a=True,
        ret_arc_maps=False,
    )
    rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
    rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))

    # One n-gram LM score per path.
    ngram_lm_scores = rescored_word_fsas.get_tot_scores(
        use_double_scores=True,
        log_semiring=False,
    )

    ans: Dict[str, List[List[int]]] = {}
    for scale in ngram_lm_scale_list:
        combined = am_scores.values + scale * ngram_lm_scores
        # Index of the best-scoring path of each utterance.
        best_indexes = k2.RaggedTensor(nbest.shape, combined).argmax()
        best_path = k2.index_fsa(nbest.fsa, best_indexes)
        ans[f"ngram_lm_scale_{scale}"] = get_texts(best_path)

    return ans
Loading