Skip to content

Commit

Permalink
Use defensive flag
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger committed Sep 7, 2021
1 parent 54e320d commit 0bd8b2f
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions src/transformers/models/auto/auto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
allow_custom_model (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repository you trust and in which you have read the code, as it
will execute code present on the Hub in your local machine.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
Expand Down Expand Up @@ -214,6 +218,10 @@
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
allow_custom_model (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repository you trust and in which you have read the code, as it
will execute code present on the Hub in your local machine.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
Expand Down Expand Up @@ -303,6 +311,10 @@
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
allow_custom_model (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to :obj:`True` for repository you trust and in which you have read the code, as it
will execute code present on the Hub in your local machine.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
Expand Down Expand Up @@ -368,7 +380,21 @@ def __init__(self, *args, **kwargs):

@classmethod
def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys():
allow_custom_model = kwargs.pop("allow_custom_model", False)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not allow_custom_model:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `allow_custom_model=True` to remove this error."
)
class_ref = config.auto_map[cls.__name__]
module_file, class_name = class_ref.split(".")
model_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class._from_config(config, **kwargs)

Expand All @@ -380,19 +406,23 @@ def from_config(cls, config, **kwargs):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop("config", None)
allow_custom_model = kwargs.pop("allow_custom_model", False)
kwargs["_from_auto"] = True
if not isinstance(config, PretrainedConfig):
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not allow_custom_model:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `allow_custom_model=True` to remove this error."
)
class_ref = config.auto_map[cls.__name__]
module_file, class_name = class_ref.split(".")
model_class = get_class_from_dynamic_module(
pretrained_model_name_or_path,
module_file + ".py",
class_name,
**kwargs
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif type(config) in cls._model_mapping.keys():
Expand Down

0 comments on commit 0bd8b2f

Please sign in to comment.