Skip to content

Commit

Permalink
Fix erniegen no model_config_file (#3321)
Browse files Browse the repository at this point in the history
* fix

* rm save_pretrained
  • Loading branch information
gongenlei authored Sep 26, 2022
1 parent 0e159ad commit adf6fca
Showing 1 changed file with 4 additions and 32 deletions.
36 changes: 4 additions & 32 deletions paddlenlp/transformers/ernie_gen/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.log import logger
from paddlenlp.transformers import BertPretrainedModel, ElectraPretrainedModel, RobertaPretrainedModel, ErniePretrainedModel
from .. import PretrainedModel, register_base_model

from ..utils import InitTrackerMeta, fn_args_to_dict

Expand Down Expand Up @@ -216,7 +217,7 @@ def forward(self, inputs, attn_bias=None, past_cache=None):


@six.add_metaclass(InitTrackerMeta)
class ErnieGenPretrainedModel(object):
class ErnieGenPretrainedModel(PretrainedModel):
r"""
An abstract class for pretrained ErnieGen models. It provides ErnieGen related
`model_config_file`, `pretrained_init_configuration`, `resource_files_names`,
Expand Down Expand Up @@ -389,36 +390,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
weight_path)
return model

def save_pretrained(self, save_directory):
"""
Save model configuration and related resources (model state) to files
under `save_directory`.
Args:
save_directory (str): Directory to save files into.
"""
assert os.path.isdir(
save_directory
), "Saving directory ({}) should be a directory".format(save_directory)
# save model config
model_config_file = os.path.join(save_directory, self.model_config_file)
model_config = self.init_config
# If init_config contains a Layer, use the layer's init_config to save
for key, value in model_config.items():
if key == "init_args":
args = []
for arg in value:
args.append(arg.init_config if isinstance(
arg, ErnieGenPretrainedModel) else arg)
model_config[key] = tuple(args)
elif isinstance(value, ErnieGenPretrainedModel):
model_config[key] = value.init_config
with io.open(model_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(model_config, ensure_ascii=False))
# save model
file_name = os.path.join(save_directory,
list(self.resource_files_names.values())[0])
paddle.save(self.state_dict(), file_name)

def _post_init(self, original_init, *args, **kwargs):
"""
It would be hooked after `__init__` to add a dict including arguments of
Expand All @@ -428,7 +399,8 @@ def _post_init(self, original_init, *args, **kwargs):
self.config = init_dict


class ErnieModel(nn.Layer, ErnieGenPretrainedModel):
@register_base_model
class ErnieModel(ErnieGenPretrainedModel):

def __init__(self, cfg, name=None):
"""
Expand Down

0 comments on commit adf6fca

Please sign in to comment.