From c37250ab1c845919af721cd3f5c4cec2993aefe1 Mon Sep 17 00:00:00 2001
From: Suvrat Bhooshan
Date: Mon, 10 Dec 2018 23:49:09 -0800
Subject: [PATCH] Loading PreTrained Models (#406)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/406

Static helper function in TranslationTask to load pretrained models

Reviewed By: myleott

Differential Revision: D13345276

fbshipit-source-id: 3a675ee1a144ceb8b010f30e1a6163ef670b53f3
---
 fairseq/tasks/translation.py | 20 +++++++++++++++++++-
 fairseq/utils.py             | 15 ++++++++++-----
 2 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py
index 924d6fb7e6..7d4e62f063 100644
--- a/fairseq/tasks/translation.py
+++ b/fairseq/tasks/translation.py
@@ -9,7 +9,7 @@
 import numpy as np
 import os
 
-from fairseq import options
+from fairseq import options, utils
 from fairseq.data import (
     data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
     IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset
@@ -63,6 +63,24 @@ def add_args(parser):
                             help='amount to upsample primary dataset')
         # fmt: on
 
+    @staticmethod
+    def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None):
+        model = utils.load_checkpoint_to_cpu(path)
+        args = model['args']
+        state_dict = model['model']
+        args = utils.override_model_args(args, arg_overrides)
+        src_dict = Dictionary.load(src_dict_path)
+        tgt_dict = Dictionary.load(tgt_dict_path)
+        assert src_dict.pad() == tgt_dict.pad()
+        assert src_dict.eos() == tgt_dict.eos()
+        assert src_dict.unk() == tgt_dict.unk()
+
+        task = TranslationTask(args, src_dict, tgt_dict)
+        model = task.build_model(args)
+        model.upgrade_state_dict(state_dict)
+        model.load_state_dict(state_dict, strict=True)
+        return model
+
     def __init__(self, args, src_dict, tgt_dict):
         super().__init__(args)
         self.src_dict = src_dict
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 350958fd40..6a40e0e26b 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -131,6 +131,12 @@ def _upgrade_state_dict(state):
     return state
 
 
+def load_checkpoint_to_cpu(path):
+    state = torch.load(path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
+    state = _upgrade_state_dict(state)
+    return state
+
+
 def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
     """Load an ensemble of models for inference.
 
@@ -143,8 +149,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
     for filename in filenames:
         if not os.path.exists(filename):
             raise IOError('Model file not found: {}'.format(filename))
-        state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
-        state = _upgrade_state_dict(state)
+        state = load_checkpoint_to_cpu(filename)
         states.append(state)
 
     ensemble = []
@@ -152,7 +157,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
         args = state['args']
 
         if model_arg_overrides is not None:
-            args = _override_model_args(args, model_arg_overrides)
+            args = override_model_args(args, model_arg_overrides)
 
         # build model for ensemble
         model = task.build_model(args)
@@ -162,12 +167,12 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
 
     # some args (e.g., tokens_per_sample) might have been updated while building the model
     if model_arg_overrides is not None:
-        args = _override_model_args(args, model_arg_overrides)
+        args = override_model_args(args, model_arg_overrides)
 
     return ensemble, args
 
 
-def _override_model_args(args, model_arg_overrides):
+def override_model_args(args, model_arg_overrides):
     # Uses model_arg_overrides {'arg_name': arg} to override model args
     for arg_name, arg_val in model_arg_overrides.items():
         setattr(args, arg_name, arg_val)
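
Usage note (not part of the patch): a minimal sketch of how the new
TranslationTask.load_pretrained_model helper might be called. The checkpoint
and dictionary paths, and the override key, are hypothetical placeholders.
Also note that the helper passes arg_overrides straight to
override_model_args, which iterates .items() on it, so passing a dict
(possibly empty) is safer than relying on the None default.

    from fairseq.tasks.translation import TranslationTask

    # Build a ready-to-use translation model from a saved checkpoint plus its
    # source/target dictionaries (paths here are hypothetical placeholders).
    model = TranslationTask.load_pretrained_model(
        path='checkpoints/checkpoint_best.pt',
        src_dict_path='data-bin/dict.de.txt',
        tgt_dict_path='data-bin/dict.en.txt',
        arg_overrides={'max_source_positions': 1024},  # dict, not None
    )
    model.eval()  # switch to inference mode (disables dropout)

The helper asserts that the two dictionaries agree on the pad, eos, and unk
indices before building the model, so mismatched dictionary files fail fast.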
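For context, a sketch of the two refactored fairseq.utils helpers used in
isolation; the checkpoint path and the 'dropout' override are assumptions for
illustration.

    from fairseq import utils

    # load_checkpoint_to_cpu maps all tensors to CPU and runs the existing
    # _upgrade_state_dict fixups, replacing the torch.load boilerplate that
    # load_ensemble_for_inference previously inlined.
    state = utils.load_checkpoint_to_cpu('checkpoints/checkpoint_best.pt')
    args = state['args']           # training-time hyperparameters
    state_dict = state['model']    # parameter tensors

    # override_model_args (now public, previously _override_model_args) sets
    # each key of the dict as an attribute on args and returns it.
    args = utils.override_model_args(args, {'dropout': 0.0})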