Loading PreTrained Models (facebookresearch#406)
Summary:
Pull Request resolved: facebookresearch#406

Adds a static helper function to TranslationTask that loads pretrained models.

Reviewed By: myleott

Differential Revision: D13345276

fbshipit-source-id: 3a675ee1a144ceb8b010f30e1a6163ef670b53f3
Suvrat Bhooshan authored and facebook-github-bot committed Dec 11, 2018
1 parent 00e47d7 commit c37250a
Showing 2 changed files with 29 additions and 6 deletions.
20 changes: 19 additions & 1 deletion fairseq/tasks/translation.py
@@ -9,7 +9,7 @@
 import numpy as np
 import os

-from fairseq import options
+from fairseq import options, utils
 from fairseq.data import (
     data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
     IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset
@@ -63,6 +63,24 @@ def add_args(parser):
                             help='amount to upsample primary dataset')
         # fmt: on

+    @staticmethod
+    def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None):
+        model = utils.load_checkpoint_to_cpu(path)
+        args = model['args']
+        state_dict = model['model']
+        args = utils.override_model_args(args, arg_overrides)
+        src_dict = Dictionary.load(src_dict_path)
+        tgt_dict = Dictionary.load(tgt_dict_path)
+        assert src_dict.pad() == tgt_dict.pad()
+        assert src_dict.eos() == tgt_dict.eos()
+        assert src_dict.unk() == tgt_dict.unk()
+
+        task = TranslationTask(args, src_dict, tgt_dict)
+        model = task.build_model(args)
+        model.upgrade_state_dict(state_dict)
+        model.load_state_dict(state_dict, strict=True)
+        return model
+
     def __init__(self, args, src_dict, tgt_dict):
         super().__init__(args)
         self.src_dict = src_dict
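
For context, a minimal usage sketch of the new helper. The checkpoint and dictionary paths below are hypothetical placeholders, not files shipped with this commit; any trained fairseq translation checkpoint with its source/target dictionaries should work.

    from fairseq.tasks.translation import TranslationTask

    # Rebuild the model from the args stored in the checkpoint, load the
    # weights, and return it ready for inference on CPU.
    model = TranslationTask.load_pretrained_model(
        path='checkpoints/checkpoint_best.pt',         # hypothetical path
        src_dict_path='data-bin/dict.de.txt',          # hypothetical path
        tgt_dict_path='data-bin/dict.en.txt',          # hypothetical path
        arg_overrides={'max_source_positions': 1024},  # optional; applied on top of the stored args
    )
    model.eval()  # disable dropout for inference
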
15 changes: 10 additions & 5 deletions fairseq/utils.py
@@ -131,6 +131,12 @@ def _upgrade_state_dict(state):
     return state


+def load_checkpoint_to_cpu(path):
+    state = torch.load(path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
+    state = _upgrade_state_dict(state)
+    return state
+
+
 def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
     """Load an ensemble of models for inference.
@@ -143,16 +149,15 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
     for filename in filenames:
         if not os.path.exists(filename):
             raise IOError('Model file not found: {}'.format(filename))
-        state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
-        state = _upgrade_state_dict(state)
+        state = load_checkpoint_to_cpu(filename)
         states.append(state)

     ensemble = []
     for state in states:
         args = state['args']

         if model_arg_overrides is not None:
-            args = _override_model_args(args, model_arg_overrides)
+            args = override_model_args(args, model_arg_overrides)

         # build model for ensemble
         model = task.build_model(args)
@@ -162,12 +167,12 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):

     # some args (e.g., tokens_per_sample) might have been updated while building the model
     if model_arg_overrides is not None:
-        args = _override_model_args(args, model_arg_overrides)
+        args = override_model_args(args, model_arg_overrides)

     return ensemble, args


-def _override_model_args(args, model_arg_overrides):
+def override_model_args(args, model_arg_overrides):
     # Uses model_arg_overrides {'arg_name': arg} to override model args
     for arg_name, arg_val in model_arg_overrides.items():
         setattr(args, arg_name, arg_val)
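
For context, a minimal sketch of the refactored loading path (the checkpoint path is a hypothetical placeholder). load_checkpoint_to_cpu maps every tensor to CPU via default_restore_location and runs the existing _upgrade_state_dict migration, so the ensemble loader and the new TranslationTask.load_pretrained_model helper now share one code path:

    from fairseq import utils

    # Load a (possibly GPU-trained) checkpoint onto CPU.
    state = utils.load_checkpoint_to_cpu('checkpoints/checkpoint_best.pt')  # hypothetical path

    args = state['args']         # training-time arguments stored in the checkpoint
    state_dict = state['model']  # model weights, now on CPU
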
