Refactor hub interface for batched inference (facebookresearch#1539)

Summary:
# Before submitting

- [x] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes facebookresearch#1508.
Pull Request resolved: facebookresearch#1539

Pulled By: myleott

Differential Revision: D19216104

fbshipit-source-id: 14917c1459b8794eeb74c09a16b9899c366242d2
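In short, the hub wrappers' `translate()`, `sample()`, and `score()` now accept either a single string or a list of strings. A minimal usage sketch of the new surface, following the README examples updated in this diff (the example sentences are illustrative):

```python
import torch

# Load a pretrained translation model via torch.hub, as in the README diff below.
en2de = torch.hub.load(
    'pytorch/fairseq', 'transformer.wmt16.en-de',
    tokenizer='moses', bpe='subword_nmt',
)
en2de.eval()  # disable dropout

# A list of sentences is translated as one batched call and
# returns a list of translations...
translations = en2de.translate(['Hello world!', 'The cat sat on the mat.'])

# ...while a plain string still returns a plain string, via the
# isinstance(sentences, str) fallback added in sample().
single = en2de.translate('Hello world!')
```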
sai-prasanna authored and facebook-github-bot committed Dec 26, 2019
1 parent e82ffe4 commit a10cee8
Showing 3 changed files with 80 additions and 41 deletions.
4 changes: 4 additions & 0 deletions examples/language_model/README.md
@@ -26,6 +26,10 @@ torch.hub.list('pytorch/fairseq') # [..., 'transformer_lm.wmt19.en', ...]

 # Load an English LM trained on WMT'19 News Crawl data
 en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')
+en_lm.eval()  # disable dropout
+
+# Move model to GPU
+en_lm.cuda()
 
 # Sample from the language model
 en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10, temperature=0.8)
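The language-model wrapper goes through the same `GeneratorHubInterface.sample()` changed in this commit, so sampling should batch the same way; a sketch under that assumption, with illustrative prompts:

```python
# A list of prompts yields a list of continuations, one per prompt;
# a single string still returns a single string.
en_lm.sample(
    ['Barack Obama', 'The weather in Berlin'],
    beam=1, sampling=True, sampling_topk=10, temperature=0.8,
)
```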
8 changes: 8 additions & 0 deletions examples/translation/README.md
@@ -34,13 +34,21 @@ torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt16.en-de', ... ]

 # Load a transformer trained on WMT'16 En-De
 en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de', tokenizer='moses', bpe='subword_nmt')
+en2de.eval()  # disable dropout
 
 # The underlying model is available under the *models* attribute
 assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)
 
+# Move model to GPU for faster translation
+en2de.cuda()
+
 # Translate a sentence
 en2de.translate('Hello world!')
 # 'Hallo Welt!'
+
+# Batched translation
+en2de.translate(['Hello world!', 'The cat sat on the mat.'])
+# ['Hallo Welt!', 'Die Katze saß auf der Matte.']
 ```
 
 Loading custom models:
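A note on the batched `translate()` added above: the refactored `generate()` keeps fairseq's verbose S/H/P/A debug output and now emits it for every sentence in the batch. A sketch of what a verbose call might print; the scores are made up:

```python
en2de.translate(['Hello world!'], verbose=True)
# For each input sentence, generate() prints the source with unknown-word
# markers (S), each hypothesis with its score (H), per-token positional
# scores (P), and an alignment line (A) when alignments are available and
# print_alignment is set. Illustrative output:
#   S   Hello world !
#   H   -0.22   Hallo Welt !
#   P   -0.3000 -0.1500 -0.2100
```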
109 changes: 68 additions & 41 deletions fairseq/hub_utils.py
@@ -7,6 +7,7 @@
 import argparse
 import copy
 import os
+from typing import List, Dict, Iterator, Tuple, Any
 
 import torch
 from torch import nn
@@ -106,28 +107,46 @@ def __init__(self, args, task, models):
         self.tokenizer = encoders.build_tokenizer(args)
         self.bpe = encoders.build_bpe(args)
 
+        self.max_positions = utils.resolve_max_positions(
+            self.task.max_positions(), *[model.max_positions() for model in models]
+        )
+
         # this is useful for determining the device
         self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
 
     @property
     def device(self):
         return self._float_tensor.device
 
-    def translate(self, sentence: str, beam: int = 5, verbose: bool = False, **kwargs) -> str:
-        return self.sample(sentence, beam, verbose, **kwargs)
+    def translate(self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs) -> List[str]:
+        return self.sample(sentences, beam, verbose, **kwargs)
 
-    def sample(self, sentence: str, beam: int = 1, verbose: bool = False, **kwargs) -> str:
-        input = self.encode(sentence)
-        hypo = self.generate(input, beam, verbose, **kwargs)[0]['tokens']
-        return self.decode(hypo)
+    def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
+        if isinstance(sentences, str):
+            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
+        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
+        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
+        return [self.decode(hypos[0]['tokens']) for hypos in batched_hypos]
 
-    def score(self, sentence: str, **kwargs):
+    def score(self, sentences: List[str], **kwargs):
+        if isinstance(sentences, str):
+            return self.score([sentences], **kwargs)[0]
         # NOTE: this doesn't support translation tasks currently
-        input = self.encode(sentence)
-        return self.generate(input, score_reference=True, **kwargs)[0]
+        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
+        return [hypos[0] for hypos in self.generate(tokenized_sentences, score_reference=True, **kwargs)]
 
-    def generate(self, tokens: torch.LongTensor, beam: int = 5, verbose: bool = False, **kwargs) -> torch.LongTensor:
-        sample = self._build_sample(tokens)
+    def generate(
+        self,
+        tokenized_sentences: List[torch.LongTensor],
+        beam: int = 5,
+        verbose: bool = False,
+        skip_invalid_size_inputs=False,
+        **kwargs
+    ) -> List[List[Dict[str, torch.Tensor]]]:
+        if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
+            return self.generate(
+                tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs
+            )[0]
 
         # build generator using current args as well as any kwargs
         gen_args = copy.copy(self.args)
@@ -136,30 +155,35 @@ def generate(self, tokens: torch.LongTensor, beam: int = 5, verbose: bool = Fals
             setattr(gen_args, k, v)
         generator = self.task.build_generator(gen_args)
 
-        translations = self.task.inference_step(generator, self.models, sample)
-
-        if verbose:
-            src_str_with_unk = self.string(tokens)
-            print('S\t{}'.format(src_str_with_unk))
+        results = []
+        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
+            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
+            translations = self.task.inference_step(generator, self.models, batch)
+            for id, hypos in zip(batch["id"].tolist(), translations):
+                results.append((id, hypos))
 
-        def getarg(name, default):
-            return getattr(gen_args, name, getattr(self.args, name, default))
+        # sort output to match input order
+        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
 
-        # Process top predictions
-        hypos = translations[0]
         if verbose:
-            for hypo in hypos:
-                hypo_str = self.decode(hypo['tokens'])
-                print('H\t{}\t{}'.format(hypo['score'], hypo_str))
-                print('P\t{}'.format(
-                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
-                ))
-                if hypo['alignment'] is not None and getarg('print_alignment', False):
-                    print('A\t{}'.format(
-                        ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
-                    ))
-
-        return hypos
+
+            def getarg(name, default):
+                return getattr(gen_args, name, getattr(self.args, name, default))
+
+            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
+                src_str_with_unk = self.string(source_tokens)
+                print('S\t{}'.format(src_str_with_unk))
+                for hypo in target_hypotheses:
+                    hypo_str = self.decode(hypo['tokens'])
+                    print('H\t{}\t{}'.format(hypo['score'], hypo_str))
+                    print('P\t{}'.format(
+                        ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
+                    ))
+                    if hypo['alignment'] is not None and getarg('print_alignment', False):
+                        print('A\t{}'.format(
+                            ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
+                        ))
+        return outputs
 
     def encode(self, sentence: str) -> torch.LongTensor:
         sentence = self.tokenize(sentence)
@@ -197,15 +221,18 @@ def binarize(self, sentence: str) -> torch.LongTensor:
     def string(self, tokens: torch.LongTensor) -> str:
         return self.tgt_dict.string(tokens)
 
-    def _build_sample(self, src_tokens: torch.LongTensor):
-        assert torch.is_tensor(src_tokens)
-        dataset = self.task.build_dataset_for_inference([src_tokens], [src_tokens.numel()])
-        sample = dataset.collater([dataset[0]])
-        sample = utils.apply_to_sample(
-            lambda tensor: tensor.to(self.device),
-            sample
-        )
-        return sample
+    def _build_batches(
+        self, tokens: List[List[int]], skip_invalid_size_inputs: bool
+    ) -> Iterator[Dict[str, Any]]:
+        lengths = torch.LongTensor([t.numel() for t in tokens])
+        batch_iterator = self.task.get_batch_iterator(
+            dataset=self.task.build_dataset_for_inference(tokens, lengths),
+            max_tokens=self.args.max_tokens,
+            max_sentences=self.args.max_sentences,
+            max_positions=self.max_positions,
+            ignore_invalid_inputs=skip_invalid_size_inputs,
+        ).next_epoch_itr(shuffle=False)
+        return batch_iterator
 
 
 class BPEHubInterface(object):
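One practical consequence of the new `_build_batches()`: batch sizes follow the loaded model's `max_tokens`/`max_sentences` args, and inputs longer than the resolved `max_positions` raise unless `skip_invalid_size_inputs=True` is passed, in which case the batch iterator filters them out. A hedged sketch of driving `generate()` directly with pre-encoded inputs, reusing the `en2de` model from the README examples:

```python
# Encode sentences to tensors of token ids, then generate hypotheses in one
# batched call; generate() returns one hypothesis list per input sentence,
# restored to input order.
tokenized = [en2de.encode(s) for s in ['Hello world!', 'Good morning.']]
batched_hypos = en2de.generate(tokenized, beam=5)
best = [en2de.decode(hypos[0]['tokens']) for hypos in batched_hypos]
```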
