Add instructions for paraphrasing model (#1968)
Summary: Pull Request resolved: #1968

Reviewed By: ngoyal2707

Differential Revision: D20860682

Pulled By: myleott

fbshipit-source-id: b7dced493410a4b9e217e4735eb9cdd0370ad47e
myleott authored and facebook-github-bot committed Apr 7, 2020
1 parent 5feb564 commit 630701e
Showing 3 changed files with 127 additions and 1 deletion.
46 changes: 46 additions & 0 deletions examples/paraphraser/README.md
@@ -0,0 +1,46 @@
# Paraphrasing with round-trip translation and mixture of experts

Machine translation models can be used to paraphrase text by translating it to
an intermediate language and back (round-trip translation).

This example paraphrases text by first translating it into French with an
English-French model, then translating it back with a French-English [mixture of
experts translation model](/examples/translation_moe). Each expert in the
mixture produces its own back-translation, yielding several distinct paraphrases.
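
The pipeline has a simple shape: one forward translation, then one back-translation per expert. A minimal sketch of that shape, using stand-in callables rather than the real fairseq models (which are loaded in the steps below):

```python
# Toy sketch of round-trip paraphrasing. `en2fr` and `fr2en_experts` are
# hypothetical stand-ins for the real en-fr model and the fr-en experts.
def round_trip_paraphrases(sentence, en2fr, fr2en_experts):
    """Translate to the pivot language once, then back with each expert."""
    pivot = en2fr(sentence)
    return [fr2en(pivot) for fr2en in fr2en_experts]

# Stub "models" so the sketch runs stand-alone.
en2fr = lambda s: s.upper()  # pretend upper-casing is translation to French
fr2en_experts = [(lambda s, i=i: f"variant {i}: {s.lower()}") for i in range(3)]

for p in round_trip_paraphrases("The Games were postponed.", en2fr, fr2en_experts):
    print(p)
```

With the real models, the experts disagree in small ways (word order, synonyms), which is what makes the back-translations useful as paraphrases.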

##### 0. Setup

Clone fairseq from source and install necessary dependencies:
```bash
git clone https://github.com/pytorch/fairseq.git
cd fairseq
pip install --editable .
pip install sacremoses sentencepiece
```

##### 1. Download models
```bash
wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.en-fr.tar.gz
wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.fr-en.hMoEup.tar.gz
tar -xzvf paraphraser.en-fr.tar.gz
tar -xzvf paraphraser.fr-en.hMoEup.tar.gz
```

##### 2. Paraphrase
```bash
python examples/paraphraser/paraphrase.py \
--en2fr paraphraser.en-fr \
--fr2en paraphraser.fr-en.hMoEup
# Example input:
# The new date for the Games, postponed for a year in response to the coronavirus pandemic, gives athletes time to recalibrate their training schedules.
# Example outputs:
# Delayed one year in response to the coronavirus pandemic, the new date of the Games gives athletes time to rebalance their training schedule.
# The new date of the Games, which was rescheduled one year in response to the coronavirus (CV) pandemic, gives athletes time to rebalance their training schedule.
# The new date of the Games, postponed one year in response to the coronavirus pandemic, provides athletes with time to rebalance their training schedule.
# The Games' new date, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule.
# The new Games date, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule.
# The new date of the Games, which was postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule.
# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule.
# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to re-balance their training schedule.
# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their schedule of training.
# The new date of the Games, postponed one year in response to the pandemic of coronavirus, gives the athletes time to rebalance their training schedule.
```
76 changes: 76 additions & 0 deletions examples/paraphraser/paraphrase.py
@@ -0,0 +1,76 @@
#!/usr/bin/env python3 -u

import argparse
import fileinput
import logging
import os
import sys

from fairseq.models.transformer import TransformerModel


logging.getLogger().setLevel(logging.INFO)


def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--en2fr', required=True,
                        help='path to en2fr model')
    parser.add_argument('--fr2en', required=True,
                        help='path to fr2en mixture of experts model')
    parser.add_argument('--user-dir',
                        help='path to fairseq examples/translation_moe/src directory')
    parser.add_argument('--num-experts', type=int, default=10,
                        help='(keep at 10 unless using a different model)')
    parser.add_argument('files', nargs='*', default=['-'],
                        help='input files to paraphrase; "-" for stdin')
    args = parser.parse_args()

    if args.user_dir is None:
        args.user_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),  # examples/
            'translation_moe',
            'src',
        )
    if os.path.exists(args.user_dir):
        logging.info('found user_dir:' + args.user_dir)
    else:
        raise RuntimeError(
            'cannot find fairseq examples/translation_moe/src '
            '(tried looking here: {})'.format(args.user_dir)
        )

    logging.info('loading en2fr model from:' + args.en2fr)
    en2fr = TransformerModel.from_pretrained(
        model_name_or_path=args.en2fr,
        tokenizer='moses',
        bpe='sentencepiece',
    ).eval()

    logging.info('loading fr2en model from:' + args.fr2en)
    fr2en = TransformerModel.from_pretrained(
        model_name_or_path=args.fr2en,
        tokenizer='moses',
        bpe='sentencepiece',
        user_dir=args.user_dir,
        task='translation_moe',
    ).eval()

    def gen_paraphrases(en):
        fr = en2fr.translate(en)
        return [
            fr2en.translate(fr, inference_step_args={'expert': i})
            for i in range(args.num_experts)
        ]

    logging.info('Type the input sentence and press return:')
    for line in fileinput.input(args.files):
        line = line.strip()
        if len(line) == 0:
            continue
        for paraphrase in gen_paraphrases(line):
            print(paraphrase)


if __name__ == '__main__':
    main()
6 changes: 5 additions & 1 deletion fairseq/hub_utils.py
@@ -145,6 +145,7 @@ def generate(
         beam: int = 5,
         verbose: bool = False,
         skip_invalid_size_inputs=False,
+        inference_step_args=None,
         **kwargs
     ) -> List[List[Dict[str, torch.Tensor]]]:
         if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
@@ -159,10 +160,13 @@
             setattr(gen_args, k, v)
         generator = self.task.build_generator(self.models, gen_args)

+        inference_step_args = inference_step_args or {}
         results = []
         for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
             batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
-            translations = self.task.inference_step(generator, self.models, batch)
+            translations = self.task.inference_step(
+                generator, self.models, batch, **inference_step_args
+            )
             for id, hypos in zip(batch["id"].tolist(), translations):
                 results.append((id, hypos))
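
The hub_utils patch threads a new `inference_step_args` dict from `generate()` through to `task.inference_step`, which is how `paraphrase.py` selects a different expert on each decode. A minimal sketch of the forwarding mechanism, with toy stand-in classes (the names here are illustrative, not fairseq's):

```python
# Toy model of the hub_utils change: per-call keyword arguments are
# forwarded to the task's inference_step via **kwargs.
class ToyMoETask:
    def inference_step(self, generator, models, batch, expert=None):
        # The real translation_moe task uses `expert` to pick a decoder expert.
        return f"decoded {batch!r} with expert {expert}"

def generate(task, batch, inference_step_args=None):
    # Same pattern as the patched hub_utils.generate: default to an empty
    # dict, then splat into the task call.
    inference_step_args = inference_step_args or {}
    return task.inference_step(None, None, batch, **inference_step_args)

task = ToyMoETask()
print(generate(task, "b0"))                                    # no extra args
print(generate(task, "b0", inference_step_args={"expert": 2})) # expert chosen per call
```

Defaulting to `{}` inside the function (rather than a mutable default argument) keeps existing callers working unchanged while letting new callers pass task-specific options.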
