Refactor hub interface for batched inference #1539

Closed · wants to merge 3 commits
README.md (2 changes: 1 addition & 1 deletion)
@@ -70,7 +70,7 @@ We also provide [pre-trained models for translation and language modeling](#pre-
with a convenient `torch.hub` interface:
```python
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
-en2de.translate('Hello world', beam=5)
+en2de.translate(['Hello world'], beam=5)
# 'Hallo Welt'
```
See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
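Since this PR changes `translate()` to take and return lists, several sentences can go through the model as one batch. A minimal sketch of the new usage (the second input sentence and both outputs are illustrative, and the hub download requires network access):

```python
import torch

# Same model as in the README snippet above.
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')

# translate() now consumes a list of sentences and returns a list of
# translations, so multiple inputs are batched through the model in one call.
en2de.translate(['Hello world!', 'The weather is nice today.'], beam=5)
# e.g. ['Hallo Welt!', 'Das Wetter ist heute schön.']
```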
examples/backtranslation/README.md (2 changes: 1 addition & 1 deletion)
@@ -33,7 +33,7 @@ len(en2de_ensemble.models)
# 5

# Translate
-en2de_ensemble.translate('Hello world!')
+en2de_ensemble.translate(['Hello world!'])
# 'Hallo Welt!'
```

examples/language_model/README.md (10 changes: 5 additions & 5 deletions)
@@ -28,18 +28,18 @@ torch.hub.list('pytorch/fairseq') # [..., 'transformer_lm.wmt19.en', ...]
en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')

# Sample from the language model
-en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10, temperature=0.8)
-# "Barack Obama is coming to Sydney and New Zealand (...)"
+en_lm.sample(['Barack Obama'], beam=1, sampling=True, sampling_topk=10, temperature=0.8)
+# ["Barack Obama is coming to Sydney and New Zealand (...)"]

# Compute perplexity for a sequence
-en_lm.score('Barack Obama is coming to Sydney and New Zealand')['positional_scores'].mean().neg().exp()
+en_lm.score(['Barack Obama is coming to Sydney and New Zealand'])[0]['positional_scores'].mean().neg().exp()
# tensor(15.1474)

# The same interface can be used with custom models as well
from fairseq.models.transformer_lm import TransformerLanguageModel
custom_lm = TransformerLanguageModel.from_pretrained('/path/to/model/dir', 'checkpoint100.pt', tokenizer='moses', bpe='fastbpe')
-custom_lm.sample('Barack Obama', beam=5)
-# "Barack Obama (...)"
+custom_lm.sample(['Barack Obama'], beam=5)
+# ["Barack Obama (...)"]
```
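With the list-based interface, `score()` returns one hypothesis dict per input sentence, so a batch of perplexities can be computed in one pass. A short sketch building on the `en_lm` model loaded above (the second sentence is illustrative):

```python
sentences = [
    'Barack Obama is coming to Sydney and New Zealand',
    'Machine learning is great',
]
# score() yields one dict per sentence; 'positional_scores' holds per-token
# log-probabilities, so mean().neg().exp() gives the perplexity.
perplexities = [
    hypo['positional_scores'].mean().neg().exp()
    for hypo in en_lm.score(sentences)
]
```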

## Training a transformer language model with the CLI tools
examples/pay_less_attention_paper/README.md (4 changes: 2 additions & 2 deletions)
@@ -70,7 +70,7 @@ zh2en = torch.hub.load('pytorch/fairseq', 'lightconv.glu.wmt17.zh-en', tokenizer
assert isinstance(zh2en.models[0], fairseq.models.lightconv.LightConvModel)

# Translate a sentence
-zh2en.translate('你好 世界')
+zh2en.translate(['你好 世界'])
# 'Hello World'
```

@@ -84,7 +84,7 @@ en2fr = LightConvModel.from_pretrained(
bpe='subword_nmt',
bpe_codes='data-bin/wmt14_en_fr/en.code'
)
-en2fr.translate('Hello world!')
+en2fr.translate(['Hello world!'])
# 'Bonjour le monde'
```

examples/translation/README.md (2 changes: 1 addition & 1 deletion)
@@ -39,7 +39,7 @@ en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de', tokenizer='
assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)

# Translate a sentence
-en2de.translate('Hello world!')
+en2de.translate(['Hello world!'])
# 'Hallo Welt!'
```

examples/wmt19/README.md (14 changes: 7 additions & 7 deletions)
@@ -31,38 +31,38 @@ import torch
# English to German translation
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
tokenizer='moses', bpe='fastbpe')
en2de.translate("Machine learning is great!") # 'Maschinelles Lernen ist großartig!'
en2de.translate(["Machine learning is great!"]) # ['Maschinelles Lernen ist großartig!']

# German to English translation
de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
tokenizer='moses', bpe='fastbpe')
de2en.translate("Maschinelles Lernen ist großartig!") # 'Machine learning is great!'
de2en.translate(["Maschinelles Lernen ist großartig!"]) # ['Machine learning is great!']

# English to Russian translation
en2ru = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-ru', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
tokenizer='moses', bpe='fastbpe')
en2ru.translate("Machine learning is great!") # 'Машинное обучение - это здорово!'
en2ru.translate(["Machine learning is great!"]) # ['Машинное обучение - это здорово!']

# Russian to English translation
ru2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.ru-en', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
tokenizer='moses', bpe='fastbpe')
ru2en.translate("Машинное обучение - это здорово!") # 'Machine learning is great!'
ru2en.translate(["Машинное обучение - это здорово!"]) # ['Machine learning is great!']
```
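Since both directions are loaded above, the list-based `translate()` also makes a quick round-trip sanity check easy. A sketch (exact outputs depend on beam settings, and the second sentence is illustrative):

```python
src = ["Machine learning is great!", "The weather is nice today."]
round_trip = de2en.translate(en2de.translate(src))
# With beam search the round trip should stay close to `src`, e.g.
# ["Machine learning is great!", "The weather is nice today."]
```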

#### Language Modeling

```python
# Sample from the English LM
en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')
en_lm.sample("Machine learning is") # 'Machine learning is the future of computing, says Microsoft boss Satya Nadella ...'
en_lm.sample(["Machine learning is"]) # ['Machine learning is the future of computing, says Microsoft boss Satya Nadella ...']

# Sample from the German LM
de_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.de', tokenizer='moses', bpe='fastbpe')
de_lm.sample("Maschinelles lernen ist") # 'Maschinelles lernen ist das A und O (neues-deutschland.de) Die Arbeitsbedingungen für Lehrerinnen und Lehrer sind seit Jahren verbesserungswürdig ...'
de_lm.sample(["Maschinelles lernen ist"]) # ['Maschinelles lernen ist das A und O (neues-deutschland.de) Die Arbeitsbedingungen für Lehrerinnen und Lehrer sind seit Jahren verbesserungswürdig ...']

# Sample from the Russian LM
ru_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.ru', tokenizer='moses', bpe='fastbpe')
ru_lm.sample("машинное обучение это") # 'машинное обучение это то, что мы называем "искусственным интеллектом".'
ru_lm.sample(["машинное обучение это"]) # ['машинное обучение это то, что мы называем "искусственным интеллектом".']
```

## Citation
fairseq/hub_utils.py (95 changes: 52 additions & 43 deletions)
@@ -7,6 +7,7 @@
import argparse
import copy
import os
+from typing import List, Dict, Iterator, Tuple, Any

import torch
from torch import nn
@@ -106,60 +107,68 @@ def __init__(self, args, task, models):
        self.tokenizer = encoders.build_tokenizer(args)
        self.bpe = encoders.build_bpe(args)

+        self.max_positions = utils.resolve_max_positions(
+            self.task.max_positions(), *[model.max_positions() for model in models]
+        )
+
        # this is useful for determining the device
        self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))

    @property
    def device(self):
        return self._float_tensor.device

-    def translate(self, sentence: str, beam: int = 5, verbose: bool = False, **kwargs) -> str:
-        return self.sample(sentence, beam, verbose, **kwargs)
+    def translate(self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs) -> List[str]:
+        return self.sample(sentences, beam, verbose, **kwargs)

-    def sample(self, sentence: str, beam: int = 1, verbose: bool = False, **kwargs) -> str:
-        input = self.encode(sentence)
-        hypo = self.generate(input, beam, verbose, **kwargs)[0]['tokens']
-        return self.decode(hypo)
+    def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
+        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
+        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
+        return [self.decode(hypos[0]['tokens']) for hypos in batched_hypos]

-    def score(self, sentence: str, **kwargs):
+    def score(self, sentences: List[str], **kwargs):
        # NOTE: this doesn't support translation tasks currently
-        input = self.encode(sentence)
-        return self.generate(input, score_reference=True, **kwargs)[0]
+        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
+        return [hypos[0] for hypos in self.generate(tokenized_sentences, score_reference=True, **kwargs)]

-    def generate(self, tokens: torch.LongTensor, beam: int = 5, verbose: bool = False, **kwargs) -> torch.LongTensor:
-        sample = self._build_sample(tokens)
+    def generate(self, tokenized_sentences: List[torch.LongTensor], beam: int = 5, verbose: bool = False, **kwargs) -> List[List[Dict[str, torch.Tensor]]]:
        # build generator using current args as well as any kwargs
        gen_args = copy.copy(self.args)
        gen_args.beam = beam
        for k, v in kwargs.items():
            setattr(gen_args, k, v)
        generator = self.task.build_generator(gen_args)

-        translations = self.task.inference_step(generator, self.models, sample)
-
-        if verbose:
-            src_str_with_unk = self.string(tokens)
-            print('S\t{}'.format(src_str_with_unk))
-
-        def getarg(name, default):
-            return getattr(gen_args, name, getattr(self.args, name, default))
-
-        # Process top predictions
-        hypos = translations[0]
-        if verbose:
-            for hypo in hypos:
-                hypo_str = self.decode(hypo['tokens'])
-                print('H\t{}\t{}'.format(hypo['score'], hypo_str))
-                print('P\t{}'.format(
-                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
-                ))
-                if hypo['alignment'] is not None and getarg('print_alignment', False):
-                    print('A\t{}'.format(
-                        ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
-                    ))
-
-        return hypos
+        results = []
+        for batch in self._build_batches(tokenized_sentences):
+            for k, input_tensor in batch["net_input"].items():
+                batch["net_input"][k] = input_tensor.to(self.device)
+            translations = self.task.inference_step(
+                generator, self.models, batch
+            )
+            for (iden, hypos) in zip(batch["id"].tolist(), translations):
+                results.append((iden, hypos))
+
+        # sort output to match input order
+        outputs = [hypos for (_, hypos) in sorted(results, key=lambda x: x[0])]
+
+        if verbose:
+            def getarg(name, default):
+                return getattr(gen_args, name, getattr(self.args, name, default))
+            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
+                src_str_with_unk = self.string(source_tokens)
+                print('S\t{}'.format(src_str_with_unk))
+                for hypo in target_hypotheses:
+                    hypo_str = self.decode(hypo['tokens'])
+                    print('H\t{}\t{}'.format(hypo['score'], hypo_str))
+                    print('P\t{}'.format(
+                        ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
+                    ))
+                    if hypo['alignment'] is not None and getarg('print_alignment', False):
+                        print('A\t{}'.format(
+                            ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
+                        ))
+        return outputs

    def encode(self, sentence: str) -> torch.LongTensor:
        sentence = self.tokenize(sentence)
@@ -196,16 +205,16 @@ def binarize(self, sentence: str) -> torch.LongTensor:

    def string(self, tokens: torch.LongTensor) -> str:
        return self.tgt_dict.string(tokens)

-    def _build_sample(self, src_tokens: torch.LongTensor):
-        assert torch.is_tensor(src_tokens)
-        dataset = self.task.build_dataset_for_inference([src_tokens], [src_tokens.numel()])
-        sample = dataset.collater([dataset[0]])
-        sample = utils.apply_to_sample(
-            lambda tensor: tensor.to(self.device),
-            sample
-        )
-        return sample
+    def _build_batches(self, tokens: List[List[int]]) -> Iterator[Dict[str, Any]]:
+        lengths = torch.LongTensor([t.numel() for t in tokens])
+        batch_iterator = self.task.get_batch_iterator(
+            dataset=self.task.build_dataset_for_inference(tokens, lengths),
+            max_tokens=self.args.max_tokens,
+            max_sentences=self.args.max_sentences,
+            max_positions=self.max_positions,
+        ).next_epoch_itr(shuffle=False)
+        return batch_iterator


class BPEHubInterface(object):
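The refactored `generate()` above follows a batch-then-restore pattern: the batch iterator may visit sentences grouped by length for efficiency, so each hypothesis list is tagged with its original index from `batch["id"]` and the results are re-sorted before returning. A toy standalone illustration of that pattern (not fairseq code):

```python
def process_in_buckets(inputs):
    # Stand-in for the length-bucketing batch iterator: visit the inputs
    # in a different order than they were given.
    order = sorted(range(len(inputs)), key=lambda i: len(inputs[i]))
    results = []
    for i in order:
        results.append((i, inputs[i].upper()))  # stand-in for inference_step
    # Re-sort by original index so outputs line up with inputs again.
    return [out for (_, out) in sorted(results, key=lambda x: x[0])]

assert process_in_buckets(['bbb', 'a', 'cc']) == ['BBB', 'A', 'CC']
```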
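And a hypothetical caller's view of the new `_build_batches()` helper: each yielded batch is a collated sample dict whose `"id"` tensor records the original sentence indices and whose `"net_input"` tensors are what `inference_step()` consumes. A sketch (assumes the hub model downloads successfully; `src_tokens` is the usual fairseq input key, and `_build_batches` is a private helper, so this is for illustration only):

```python
import torch

en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
tokenized = [en2de.encode(s) for s in ['Hello world!', 'How are you?']]
for batch in en2de._build_batches(tokenized):
    # Each batch pairs original indices with the collated model inputs.
    print(batch['id'], batch['net_input']['src_tokens'].shape)
```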