Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Shallow fusion in modified_beam_search #630

Merged
merged 5 commits into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions egs/librispeech/ASR/generate-lm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/usr/bin/env bash

lang_dir=data/lang_bpe_500
if [ ! -f $lang_dir/2gram.arpa ]; then
ezerhouni marked this conversation as resolved.
Show resolved Hide resolved
./shared/make_kn_lm.py \
-ngram-order 2 \
-text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/2gram.arpa
fi

if [ ! -f $lang_dir/2gram.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="$lang_dir/tokens.txt" \
--disambig-symbol='#0' \
--max-order=2 \
$lang_dir/2gram.arpa > $lang_dir/2gram.fst.txt
fi

if [ ! -f $lang_dir/3gram.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 3 \
-text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/3gram.arpa
fi

if [ ! -f $lang_dir/3gram.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="$lang_dir/tokens.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$lang_dir/3gram.arpa > $lang_dir/3gram.fst.txt
fi


if [ ! -f $lang_dir/5gram.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 5 \
-text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/5gram.arpa
fi

if [ ! -f $lang_dir/5gram.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="$lang_dir/tokens.txt" \
--disambig-symbol='#0' \
--max-order=5 \
$lang_dir/5gram.arpa > $lang_dir/5gram.fst.txt
fi
33 changes: 33 additions & 0 deletions egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,12 @@
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_ngram_rescoring,
)
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model

from icefall import NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
Expand Down Expand Up @@ -214,6 +216,7 @@ def get_parser():
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_ngram_rescoring
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
Expand Down Expand Up @@ -315,6 +318,8 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
Expand Down Expand Up @@ -448,6 +453,17 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_ngram_rescoring":
hyp_tokens = modified_beam_search_ngram_rescoring(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)

Expand Down Expand Up @@ -497,6 +513,8 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.

Expand Down Expand Up @@ -546,6 +564,8 @@ def decode_dataset(
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
)

for name, hyps in hyps_dict.items():
Expand Down Expand Up @@ -631,6 +651,7 @@ def main():
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_ngram_rescoring",
)
params.res_dir = params.exp_dir / params.decoding_method

Expand All @@ -655,6 +676,7 @@ def main():
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"

if params.use_averaged_model:
params.suffix += "-use-averaged-model"
Expand Down Expand Up @@ -768,6 +790,15 @@ def main():
model.to(device)
model.eval()

lm_filename = "5ram.fst.txt"
ezerhouni marked this conversation as resolved.
Show resolved Hide resolved
logging.info(f"lm filename: {lm_filename}")
ngram_lm = NgramLm(
str(params.lang_dir / lm_filename),
backoff_id=500,
ezerhouni marked this conversation as resolved.
Show resolved Hide resolved
is_binary=False,
)
logging.info(f"num states: {ngram_lm.lm.num_states}")

if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
Expand Down Expand Up @@ -812,6 +843,8 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
ngram_lm=ngram_lm,
ngram_lm_scale=params.ngram_lm_scale,
)

save_results(
Expand Down
173 changes: 173 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch
from model import Transducer

from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import add_eos, add_sos, get_texts

Expand Down Expand Up @@ -656,6 +657,8 @@ class Hypothesis:
# It contains only one entry.
log_prob: torch.Tensor

state_cost: Optional[NgramLmStateCost] = None

@property
def key(self) -> str:
"""Return a string representation of self.ys"""
Expand Down Expand Up @@ -1539,3 +1542,173 @@ def fast_beam_search_with_nbest_rnn_rescoring(
ans[key] = hyps

return ans


def modified_beam_search_ngram_rescoring(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ngram_lm: NgramLm,
ngram_lm_scale: float,
beam: int = 4,
temperature: float = 1.0,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
beam:
Number of active paths during the beam search.
temperature:
Softmax temperature.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)

packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)

blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
lm_scale = ngram_lm_scale

batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)

B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state_cost=NgramLmStateCost(ngram_lm),
)
)

encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

offset = 0
finalized_B = []
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end

finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]

hyps_shape = get_hyps_shape(B).to(device)

A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]

ys_log_probs = torch.cat(
[
hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale
for hyps in A
for hyp in hyps
]
) # (num_hyps, 1)

decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)

decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)

logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)

logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)

log_probs = (logits / temperature).log_softmax(
dim=-1
) # (num_hyps, vocab_size)

log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)

row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)

for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()

for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]

new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
state_cost = hyp.state_cost.forward_one_step(new_token)
else:
state_cost = hyp.state_cost

# We only keep AM scores in new_hyp.log_prob
new_log_prob = (
topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale
)

new_hyp = Hypothesis(
ys=new_ys, log_prob=new_log_prob, state_cost=state_cost
)
B[i].add(new_hyp)

B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]

sorted_ans = [h.ys[context_size:] for h in best_hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])

return ans
2 changes: 2 additions & 0 deletions icefall/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,5 @@
subsequent_chunk_mask,
write_error_stats,
)

from .ngram_lm import NgramLm, NgramLmStateCost
Loading