word level lm rescore

glynpu committed May 18, 2022
1 parent ac9655c commit 3a9ff31
Showing 4 changed files with 137 additions and 6 deletions.
35 changes: 35 additions & 0 deletions egs/librispeech/ASR/lm.sh
@@ -0,0 +1,35 @@
stage=1

text=data/local/lm/librispeech-lm-norm.txt.gz
text_dir=data/lm/text
all_train_text=$text_dir/librispeech.txt
# all_train_text contains 40,398,052 lines; tokenizing them with a single process takes about 50 MINUTES.
# use $train_pieces lines of data to validate the pipeline
# train_pieces=300000 # 15 times the size of dev.txt
# uncomment the following line to use all_train_text
train_pieces=
dev_text=$text_dir/dev.txt
if [ $stage -le 0 ]; then
# reference:
# https://github.com/kaldi-asr/kaldi/blob/pybind11/egs/librispeech/s5/local/rnnlm/tuning/run_tdnn_lstm_1a.sh#L75
# use the same data separation method as kaldi, so its result can be used as a baseline
if [ ! -f $text ]; then
wget http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz -P data/local/lm
fi
echo -n >$text_dir/dev.txt
# hold out one in every 2000 lines as dev data.
gunzip -c $text | cut -d ' ' -f2- | awk -v text_dir=$text_dir '{if(NR%2000 == 0) { print >text_dir"/dev.txt"; } else {print;}}' >$all_train_text
fi

if [ $stage -eq 1 ]; then
# for text_file in dev.txt librispeech.txt; do
# python ./vq_pruned_transducer_stateless2/tokenize_text.py \
# --tokenizer-path ./data/lang_bpe_500/bpe.model \
# --text-file ./data/lm/text/$text_file
# done
lmplz -o 4 --text data/lm/text/librispeech.txt --arpa train.arpa -S 10%
# lmplz -o 4 --text data/lm/text/librispeech.txt --arpa discount_train.arpa -S 10% \
# --discount_fallback
# lmplz -o 4 --text data/lm/text/librispeech.txt.tokens --arpa token_train.arpa -S 10% \
# --discount_fallback 0.5
fi
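
As a quick sanity check of the trained LM (a minimal sketch, assuming the Python kenlm bindings are installed and the ARPA file is named train.arpa as in the lmplz command above; decode.py below loads a binary model, which kenlm's build_binary tool can produce from the ARPA):

import kenlm

lm = kenlm.LanguageModel("train.arpa")  # a binary file such as train.bin loads with the same call
print(lm.order)  # should print 4 for the -o 4 model trained above
print(lm.score("HELLO WORLD", bos=True, eos=True))  # total log10 probability of the sentence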
16 changes: 13 additions & 3 deletions egs/librispeech/ASR/vq_pruned_transducer_stateless2/beam_search.py
@@ -18,6 +18,7 @@
from typing import Dict, List, Optional

import k2
import kenlm
import torch
from model import Transducer

@@ -267,6 +268,9 @@ class Hypothesis:
# It contains only one entry.
log_prob: torch.Tensor

last_start_idx: int
state: Optional[kenlm.State] = None  # LM state

@property
def key(self) -> str:
"""Return a string representation of self.ys"""
@@ -637,6 +641,7 @@ def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
lmr=None,  # LM rescorer
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@@ -677,7 +682,10 @@
t = 0

B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))

start_state = kenlm.State()
lmr.lm.BeginSentenceWrite(start_state)
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, last_start_idx=0, state=start_state))

max_sym_per_utt = 20000

@@ -738,7 +746,7 @@
new_y_star_log_prob = y_star.log_prob + skip_log_prob

# ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob, last_start_idx=y_star.last_start_idx, state=y_star.state))

# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
@@ -747,7 +755,9 @@
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
tmp = Hypothesis(ys=new_ys, log_prob=new_log_prob, last_start_idx=y_star.last_start_idx, state=y_star.state)
lmr.rescore(tmp)
A.add(tmp)

# Check whether B contains more than "beam" elements more probable
# than the most probable in A
25 changes: 22 additions & 3 deletions egs/librispeech/ASR/vq_pruned_transducer_stateless2/decode.py
@@ -63,6 +63,7 @@
from typing import Dict, List, Optional, Tuple

import k2
import kenlm
import sentencepiece as spm
import torch
import torch.nn as nn
@@ -88,12 +89,19 @@
write_error_stats,
)

from lm import LMRescorer

def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--lm-weight",
type=float,
default=0.0,
help="Weight of the word-level LM scores added during beam-search rescoring.",
)

parser.add_argument(
"--epoch",
type=int,
@@ -206,6 +214,7 @@ def decode_one_batch(
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
lmr=None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@@ -298,6 +307,7 @@
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
lmr=lmr,
)
else:
raise ValueError(
@@ -325,6 +335,7 @@ def decode_dataset(
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
lmr=None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@@ -369,6 +380,7 @@
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
lmr=lmr,
)

for name, hyps in hyps_dict.items():
@@ -399,19 +411,19 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"lm_weight-{params.lm_weight}-recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")

# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"lm_weight-{params.lm_weight}-errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
f, f"lm_weight-{params.lm_weight}-{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer

@@ -479,6 +491,8 @@ def main():
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
LM = "/ceph-data2/ly/kenlm/train_lm/train.bin"
params.lm_path = LM

logging.info(params)

@@ -506,6 +520,10 @@ def main():
model.eval()
model.device = device

lm_model = kenlm.LanguageModel(LM)

lmr = LMRescorer(Path("./data/lang_bpe_500/"), blank_id=model.decoder.blank_id, lm=lm_model, weight=params.lm_weight)

if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
@@ -532,6 +550,7 @@ def main():
model=model,
sp=sp,
decoding_graph=decoding_graph,
lmr=lmr,
)

save_results(
67 changes: 67 additions & 0 deletions egs/librispeech/ASR/vq_pruned_transducer_stateless2/lm.py
@@ -0,0 +1,67 @@
from pathlib import Path
from icefall.lexicon import read_lexicon, write_lexicon
import sentencepiece as spm
import kenlm

def extract_start_tokens(lang_dir: Path = Path("./data/lang_bpe_500/")):
tokens = read_lexicon(lang_dir / "tokens.txt")

# Get the leading '▁' of an entry like '▁THE 4'.
# It is not actually an underscore; it just looks similar to one.
word_start_char = tokens[4][0][0]

word_start_token = []
non_start_token = []

aux = ['<sos/eos>', '<unk>']
for t in tokens:
leading_char = t[0][0]
if leading_char == word_start_char or t[0] in aux:
word_start_token.append(t)
else:
non_start_token.append(t)

write_lexicon(lang_dir / "word_start_tokens.txt", word_start_token)
write_lexicon(lang_dir / "non_start_tokens.txt", non_start_token)

def lexicon_to_dict(lexicon):
token2idx = {}
idx2token = {}
for token, idx in lexicon:
assert len(idx) == 1
idx = idx[0]
token2idx[token] = int(idx)
idx2token[int(idx)] = token
return token2idx, idx2token


class LMRescorer:
def __init__(self, lang_dir, blank_id, lm, weight):
self.lm = lm
self.start_token2idx, self.start_idx2token = lexicon_to_dict(read_lexicon(lang_dir/"word_start_tokens.txt"))
self.nonstart_token2idx, self.nonstart_idx2token = lexicon_to_dict(read_lexicon(lang_dir/"non_start_tokens.txt"))
self.token2idx, self.idx2token = lexicon_to_dict(read_lexicon(lang_dir/"tokens.txt"))
self.sp = spm.SentencePieceProcessor()
self.sp.load(str(lang_dir/"bpe.model"))
self.blank_id = blank_id
self.weight = weight

def rescore(self, hyp):
if self.weight > 0 and hyp.ys[-1] in self.start_idx2token:
word = self.previous_word(hyp)
output_state = kenlm.State()
lm_score = self.lm.BaseScore(hyp.state, word, output_state)
hyp.state = output_state
hyp.log_prob += self.weight * lm_score
return hyp

def previous_word(self, hyp):
last_start_idx = hyp.last_start_idx
tokens_seq = hyp.ys[last_start_idx:-1]
tokens_seq = [t for t in tokens_seq if t != self.blank_id]
word = self.sp.decode(tokens_seq)
hyp.last_start_idx = len(hyp.ys) - 1
return word
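
For reference, here is a minimal sketch of the state-carrying kenlm calls that LMRescorer.rescore builds on (the model path is the one hard-coded in decode.py; the word sequence is only an illustration): each completed word is scored with BaseScore given the hypothesis's current LM state, and the returned state is carried forward on the hypothesis, which is what makes the rescoring word-level and incremental.

import kenlm

lm = kenlm.LanguageModel("/ceph-data2/ly/kenlm/train_lm/train.bin")

state = kenlm.State()
lm.BeginSentenceWrite(state)  # start-of-sentence context, as in beam_search()

total = 0.0
for word in ["HELLO", "WORLD"]:  # in decoding, words come from LMRescorer.previous_word()
    out_state = kenlm.State()
    total += lm.BaseScore(state, word, out_state)  # log10 P(word | state)
    state = out_state  # carry the LM state forward, as rescore() does on the Hypothesis

print(total)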


