From adf6fcad61fb843fba3b1836424e59a319e9bed9 Mon Sep 17 00:00:00 2001
From: gongenlei
Date: Mon, 26 Sep 2022 20:20:15 +0800
Subject: [PATCH] Fix erniegen no model_config_file (#3321)

* fix

* rm save_pretrained
---
 paddlenlp/transformers/ernie_gen/modeling.py | 36 +++-----------------
 1 file changed, 4 insertions(+), 32 deletions(-)

diff --git a/paddlenlp/transformers/ernie_gen/modeling.py b/paddlenlp/transformers/ernie_gen/modeling.py
index 6aa6ceff3a2e..c758a77d7ec4 100644
--- a/paddlenlp/transformers/ernie_gen/modeling.py
+++ b/paddlenlp/transformers/ernie_gen/modeling.py
@@ -25,6 +25,7 @@ from paddle.utils.download import get_path_from_url
 from paddlenlp.utils.log import logger
 from paddlenlp.transformers import BertPretrainedModel, ElectraPretrainedModel, RobertaPretrainedModel, ErniePretrainedModel
+from .. import PretrainedModel, register_base_model
 from ..utils import InitTrackerMeta, fn_args_to_dict
@@ -216,7 +217,7 @@ def forward(self, inputs, attn_bias=None, past_cache=None):
 
 
 @six.add_metaclass(InitTrackerMeta)
-class ErnieGenPretrainedModel(object):
+class ErnieGenPretrainedModel(PretrainedModel):
     r"""
     An abstract class for pretrained ErnieGen models. It provides ErnieGen related
     `model_config_file`, `pretrained_init_configuration`, `resource_files_names`,
@@ -389,36 +390,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
                                      weight_path)
         return model
 
-    def save_pretrained(self, save_directory):
-        """
-        Save model configuration and related resources (model state) to files
-        under `save_directory`.
-        Args:
-            save_directory (str): Directory to save files into.
-        """
-        assert os.path.isdir(
-            save_directory
-        ), "Saving directory ({}) should be a directory".format(save_directory)
-        # save model config
-        model_config_file = os.path.join(save_directory, self.model_config_file)
-        model_config = self.init_config
-        # If init_config contains a Layer, use the layer's init_config to save
-        for key, value in model_config.items():
-            if key == "init_args":
-                args = []
-                for arg in value:
-                    args.append(arg.init_config if isinstance(
-                        arg, ErnieGenPretrainedModel) else arg)
-                model_config[key] = tuple(args)
-            elif isinstance(value, ErnieGenPretrainedModel):
-                model_config[key] = value.init_config
-        with io.open(model_config_file, "w", encoding="utf-8") as f:
-            f.write(json.dumps(model_config, ensure_ascii=False))
-        # save model
-        file_name = os.path.join(save_directory,
-                                 list(self.resource_files_names.values())[0])
-        paddle.save(self.state_dict(), file_name)
-
     def _post_init(self, original_init, *args, **kwargs):
         """
         It would be hooked after `__init__` to add a dict including arguments of
@@ -428,7 +399,8 @@ def _post_init(self, original_init, *args, **kwargs):
         self.config = init_dict
 
 
-class ErnieModel(nn.Layer, ErnieGenPretrainedModel):
+@register_base_model
+class ErnieModel(ErnieGenPretrainedModel):
 
     def __init__(self, cfg, name=None):
         """
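
Usage sketch (not part of the patch): with `ErnieGenPretrainedModel` now subclassing PaddleNLP's `PretrainedModel`, saving goes through the shared base-class `save_pretrained`, which also writes the model config file that the removed hand-rolled version handled, and `@register_base_model` marks `ErnieModel` as the base model of the hierarchy. The lines below assume `ErnieForGeneration` from the same module and the illustrative weight name "ernie-1.0" (any name registered in `pretrained_init_configuration` applies); whether the custom `from_pretrained` kept by this patch also accepts a local directory is an assumption, not something the patch states.

from paddlenlp.transformers.ernie_gen.modeling import ErnieForGeneration

# Load pretrained weights by a registered name ("ernie-1.0" is assumed here).
model = ErnieForGeneration.from_pretrained("ernie-1.0")

# save_pretrained is now inherited from PretrainedModel; it writes both the
# weights file and the model config file into the target directory.
model.save_pretrained("./ernie_gen_checkpoint")

# Reload from the saved directory (assumes from_pretrained accepts a local
# path, as the standard PretrainedModel API does).
restored = ErnieForGeneration.from_pretrained("./ernie_gen_checkpoint")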