Skip to content

Commit

Permalink
Style
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger committed Sep 7, 2021
1 parent 0bd8b2f commit de17667
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 21 deletions.
18 changes: 1 addition & 17 deletions src/transformers/models/auto/auto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import importlib
from collections import OrderedDict

from transformers.models.auto.dynamic import get_class_from_dynamic_module

from ...configuration_utils import PretrainedConfig
from ...file_utils import copy_func
from ...utils import logging
Expand Down Expand Up @@ -380,21 +378,7 @@ def __init__(self, *args, **kwargs):

@classmethod
def from_config(cls, config, **kwargs):
allow_custom_model = kwargs.pop("allow_custom_model", False)
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not allow_custom_model:
raise ValueError(
f"Loading {pretrained_model_name_or_path} requires you to execute the modeling file in that repo "
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `allow_custom_model=True` to remove this error."
)
class_ref = config.auto_map[cls.__name__]
module_file, class_name = class_ref.split(".")
model_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
)
return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys():
if type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class._from_config(config, **kwargs)

Expand Down
14 changes: 10 additions & 4 deletions src/transformers/models/auto/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from pathlib import Path
from typing import Dict, Optional, Union

from ...file_utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_path, hf_bucket_url, is_offline_mode
from ...file_utils import (
HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
cached_path,
hf_bucket_url,
is_offline_mode,
)
from ...utils import logging


Expand Down Expand Up @@ -65,7 +71,7 @@ def check_imports(filename):
"""
with open(filename, "r", encoding="utf-8") as f:
content = f.read()

# Imports of the form `import xxx`
imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from xxx import yyy`
Expand All @@ -81,7 +87,7 @@ def check_imports(filename):
importlib.import_module(imp)
except ImportError:
missing_packages.append(imp)

if len(missing_packages) > 0:
raise ImportError(
"This modeling file requires the following packages that were not found in your environment: "
Expand Down Expand Up @@ -193,7 +199,7 @@ def get_class_from_dynamic_module(
except EnvironmentError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise

# Check we have all the requirements in our environment
check_imports(resolved_module_file)

Expand Down

0 comments on commit de17667

Please sign in to comment.