Skip to content

Commit

Permalink
Refactor hub interface for batched inference (#1539) (#1539)
Browse files Browse the repository at this point in the history
Summary:
# Before submitting

- [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [] Did you write any new necessary tests?

## What does this PR do?
Fixes facebookresearch/fairseq#1508.
Pull Request resolved: facebookresearch/fairseq#1539

Pulled By: myleott

Differential Revision: D19216104

fbshipit-source-id: 14917c1459b8794eeb74c09a16b9899c366242d2
  • Loading branch information
sai-prasanna authored and yzpang committed Feb 19, 2021
1 parent 950dfc3 commit 7b002ea
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/language_model/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ torch.hub.list('pytorch/fairseq') # [..., 'transformer_lm.wmt19.en', ...]

# Load an English LM trained on WMT'19 News Crawl data
en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')
en_lm.eval() # disable dropout

# Move model to GPU
en_lm.cuda()

# Sample from the language model
en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10, temperature=0.8)
Expand Down
8 changes: 8 additions & 0 deletions examples/translation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,21 @@ torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt16.en-de', ... ]

# Load a transformer trained on WMT'16 En-De
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de', tokenizer='moses', bpe='subword_nmt')
en2de.eval() # disable dropout

# The underlying model is available under the *models* attribute
assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)

# Move model to GPU for faster translation
en2de.cuda()

# Translate a sentence
en2de.translate('Hello world!')
# 'Hallo Welt!'

# Batched translation
en2de.translate(['Hello world!', 'The cat sat on the mat.'])
# ['Hallo Welt!', 'Die Katze saß auf der Matte.']
```

Loading custom models:
Expand Down
1 change: 1 addition & 0 deletions fairseq/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import List, Dict, Iterator, Tuple, Any

import os
from typing import List, Dict, Iterator, Tuple, Any

import torch
from torch import nn
Expand Down

0 comments on commit 7b002ea

Please sign in to comment.