lm rescoring attempt (#1242)
Summary:
CUDA_VISIBLE_DEVICES=1 PYTHONPATH=/private/home/abaevski/fairseq-py-master python fairseq_cli/generate.py /checkpoint/henryzhou7/dataset/libri/960h/raw3/decoder --task audio_pretraining --seed 1 --nbest 1 --gen-subset dev_other --max-tokens 600000 --path ~/models/wav2vec2/vox_960h_seq2seq_10kwp.pt --labels 10k --remove-bpe 'wordpiece' --quiet --beam 50 --temperature 1 --scoring wer --lm-path /checkpoint/henryzhou7/wp_lm/transformer_raw3_adam_cosine2node/lr_1e-4_updatefreq_8/checkpoint_best.pt --lm-weight 1

Results (WER on dev_other):

no lm: 4.30577896347444
lm (1.5): 24.691650853889943
lm (1): 10.884539582804846
lm (0.5): 4.894205665744457
lm (0.25): 4.012853671917862
lm (0.1): 4.087637055489084
lm (0.05): 4.194788887144875
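For context, the fusion these runs exercise is plain shallow fusion: at every beam step the external LM's log-probabilities are scaled by --lm-weight and added to the seq2seq log-probabilities (see the sequence_generator.py hunk below). A toy PyTorch sketch of that combination, not part of this commit's code:

import torch

vocab = 8
s2s_lprobs = torch.log_softmax(torch.randn(1, vocab), dim=-1)  # seq2seq next-token scores
lm_lprobs = torch.log_softmax(torch.randn(1, vocab), dim=-1)   # external LM next-token scores
lm_weight = 0.25                                               # best-performing weight in the sweep above
fused = s2s_lprobs + lm_weight * lm_lprobs
print(fused.argmax(dim=-1))                                    # token the fused beam search would favor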

Pull Request resolved: fairinternal/fairseq-py#1242

Reviewed By: kahne

Differential Revision: D23277386

Pulled By: alexeib

fbshipit-source-id: 062f483bd45ddd2dd5ff24a8a35cc1c4f34ce6ab
alexeib authored and facebook-github-bot committed Oct 7, 2020
1 parent b880744 commit 5379461
Showing 6 changed files with 62 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/speech_recognition/tasks/speech_recognition.py
@@ -113,7 +113,7 @@ def load_dataset(self, split, combine=False, **kwargs):
data_json_path = os.path.join(self.args.data, "{}.json".format(split))
self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)

- def build_generator(self, models, args):
+ def build_generator(self, models, args, **unused):
w2l_decoder = getattr(args, "w2l_decoder", None)
if w2l_decoder == "viterbi":
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
5 changes: 5 additions & 0 deletions fairseq/options.py
@@ -380,6 +380,11 @@ def add_generation_args(parser):
help='if set, uses attention feedback to compute and print alignment to source tokens')
group.add_argument('--print-step', action='store_true')

group.add_argument('--lm-path', default=None, type=str, metavar='PATH',
help='path to lm checkpoint for lm fusion')
group.add_argument('--lm-weight', default=0.0, type=float, metavar='N',
help='weight for lm probs for lm fusion')

# arguments for iterative refinement generator
group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N',
help='if > 0.0, it penalized early-stopping in decoding.')
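The two generation flags added above surface downstream as args.lm_path and args.lm_weight (argparse converts the hyphens). A standalone argparse sketch with illustrative values, not fairseq code:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('generation')
group.add_argument('--lm-path', default=None, type=str, metavar='PATH',
                   help='path to lm checkpoint for lm fusion')
group.add_argument('--lm-weight', default=0.0, type=float, metavar='N',
                   help='weight for lm probs for lm fusion')
args = parser.parse_args(['--lm-path', '/path/to/lm.pt', '--lm-weight', '0.25'])
print(args.lm_path, args.lm_weight)  # /path/to/lm.pt 0.25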
18 changes: 18 additions & 0 deletions fairseq/sequence_generator.py
@@ -33,6 +33,8 @@ def __init__(
search_strategy=None,
eos=None,
symbols_to_strip_from_output=None,
lm_model=None,
lm_weight=1.0
):
"""Generates translations of a given source sentence.
@@ -94,6 +96,11 @@ def __init__(

self.model.eval()

self.lm_model = lm_model
self.lm_weight = lm_weight
if self.lm_model is not None:
self.lm_model.eval()

def cuda(self):
self.model.cuda()
return self
@@ -292,6 +299,15 @@ def _generate(
incremental_states,
self.temperature,
)

if self.lm_model is not None:
lm_out = self.lm_model(tokens[:, : step + 1])
probs = self.lm_model.get_normalized_probs(
lm_out, log_probs=True, sample=None
)
probs = probs[:, -1, :] * self.lm_weight
lprobs += probs

lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)

lprobs[:, self.pad] = -math.inf # never select pad
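A minimal stand-alone sketch of the step above with toy tensors (not the fairseq LM interface): the LM is run over the tokens generated so far, only the last position's distribution is kept, it is scaled by lm_weight and added to lprobs, and any NaNs are then masked out.

import math
import torch

bsz_x_beam, vocab, step = 4, 8, 3
lprobs = torch.log_softmax(torch.randn(bsz_x_beam, vocab), dim=-1)   # seq2seq scores for this step

lm_out = torch.randn(bsz_x_beam, step + 1, vocab)                    # stand-in for lm_model(tokens[:, : step + 1])
lm_lprobs = torch.log_softmax(lm_out, dim=-1)[:, -1, :]              # keep the next-token distribution only
lprobs = lprobs + 1.0 * lm_lprobs                                    # lm_weight (defaults to 1.0 in the generator)

lprobs[lprobs != lprobs] = -math.inf                                 # NaN masking, as in the diff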
@@ -820,9 +836,11 @@ def forward_decoder(
avg_attn = attn
else:
avg_attn.add_(attn)

avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
self.models_size
)

if avg_attn is not None:
avg_attn.div_(self.models_size)
return avg_probs, avg_attn
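The averaging in forward_decoder above is a numerically stable log-space mean over the ensemble. A small self-contained check with toy tensors, not fairseq code:

import math
import torch

log_probs = [torch.log_softmax(torch.randn(2, 8), dim=-1) for _ in range(3)]
avg = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(len(log_probs))
reference = torch.log(torch.stack(log_probs, dim=0).exp().mean(dim=0))  # naive mean of probabilities
print(torch.allclose(avg, reference))  # True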
2 changes: 1 addition & 1 deletion fairseq/tasks/translation_from_pretrained_bart.py
@@ -84,7 +84,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
append_source_id=True
)

- def build_generator(self, models, args):
+ def build_generator(self, models, args, **unused):
if getattr(args, 'score_reference', False):
from fairseq.sequence_scorer import SequenceScorer
return SequenceScorer(
2 changes: 1 addition & 1 deletion fairseq/tasks/translation_lev.py
@@ -128,7 +128,7 @@ def _full_mask(target_tokens):
else:
raise NotImplementedError

- def build_generator(self, models, args):
+ def build_generator(self, models, args, **unused):
# add models input to match the API for SequenceGenerator
from fairseq.iterative_refinement_generator import IterativeRefinementGenerator
return IterativeRefinementGenerator(
41 changes: 36 additions & 5 deletions fairseq_cli/generate.py
@@ -7,6 +7,8 @@
Translate pre-processed data with a trained model.
"""

import ast
from itertools import chain
import logging
import math
import os
@@ -78,17 +80,39 @@ def _main(args, output_file):
src_dict = None
tgt_dict = task.target_dictionary

overrides = ast.literal_eval(args.model_overrides)

# Load ensemble
logger.info('loading model(s) from {}'.format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble(
utils.split_paths(args.path),
- arg_overrides=eval(args.model_overrides),
+ arg_overrides=overrides,
task=task,
suffix=getattr(args, "checkpoint_suffix", ""),
)

if args.lm_path is not None:
overrides['data'] = args.data

try:
lms, _ = checkpoint_utils.load_model_ensemble(
[args.lm_path],
arg_overrides=overrides,
task=None,
)
except:
logger.warning(f"Failed to load language model! Please make sure that the language model dict is the same "
f"as target dict and is located in the data dir ({args.data})")
raise

assert len(lms) == 1
else:
lms = [None]

# Optimize ensemble for generation
- for model in models:
+ for model in chain(models, lms):
+     if model is None:
+         continue
model.prepare_for_inference_(args)
if args.fp16:
model.half()
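Note that the hunk above also swaps eval(args.model_overrides) for ast.literal_eval, which only accepts Python literals. An illustrative example (the override values here are made up):

import ast

overrides = ast.literal_eval("{'beam': 5, 'data': '/path/to/dict/dir'}")
print(overrides['data'])  # /path/to/dict/dir
# Non-literal input such as "__import__('os')" raises ValueError instead of executing.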
@@ -124,7 +148,12 @@ def _main(args, output_file):

# Initialize generator
gen_timer = StopwatchMeter()
- generator = task.build_generator(models, args)
+ extra_gen_cls_kwargs = {
+     'lm_model': lms[0],
+     'lm_weight': args.lm_weight
+ }
+ generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs)

# Handle tokenization and BPE
tokenizer = encoders.build_tokenizer(args)
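A hedged sketch of the plumbing the extra_gen_cls_kwargs argument above enables (toy classes with hypothetical names, not the actual fairseq task or generator): the task forwards the dict into the generator's constructor, so the LM and its weight reach the generator without changing every build_generator call site.

import argparse

class ToyGenerator:
    # Stand-in for SequenceGenerator: just records the fusion settings.
    def __init__(self, models, tgt_dict, beam_size=5, lm_model=None, lm_weight=1.0):
        self.models, self.tgt_dict = models, tgt_dict
        self.beam_size, self.lm_model, self.lm_weight = beam_size, lm_model, lm_weight

class ToyTask:
    # Stand-in for a task's build_generator with the new keyword hook.
    target_dictionary = None
    def build_generator(self, models, args, extra_gen_cls_kwargs=None, **unused):
        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
        return ToyGenerator(models, self.target_dictionary,
                            beam_size=getattr(args, 'beam', 5), **extra_gen_cls_kwargs)

args = argparse.Namespace(beam=50, lm_weight=0.25)
gen = ToyTask().build_generator([], args,
                                extra_gen_cls_kwargs={'lm_model': None, 'lm_weight': args.lm_weight})
print(gen.beam_size, gen.lm_weight)  # 50 0.25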
@@ -269,9 +298,11 @@ def decode_fn(x):
if has_target:
if args.bpe and not args.sacrebleu:
if args.remove_bpe:
logger.warning("BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
logger.warning(
"BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
else:
logger.warning("If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
logger.warning(
"If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
# use print to be consistent with other main outputs: S-, H-, T-, D- and so on
print(
'Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()),