From a6d62aaba01ce4ff1b2ee8705bf113904672c345 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 5 Aug 2021 10:12:13 +0200
Subject: [PATCH] GPT-Neo ONNX export (#12911)

GPT-Neo ONNX export and task / feature refactoring

Authored-by: Michael Benayoun
---
 src/transformers/models/gpt_neo/__init__.py   |   4 +-
 .../models/gpt_neo/configuration_gpt_neo.py   | 142 ++++++++++++++++++
 .../models/gpt_neo/modeling_gpt_neo.py        |   2 +-
 src/transformers/onnx/__init__.py             |   2 +-
 src/transformers/onnx/__main__.py             |  91 +---------
 src/transformers/onnx/config.py               |  84 +++++++++--
 src/transformers/onnx/convert.py              |  14 +-
 src/transformers/onnx/features.py             | 135 +++++++++++++++++
 tests/test_onnx_v2.py                         |   9 +-
 9 files changed, 380 insertions(+), 103 deletions(-)
 create mode 100644 src/transformers/onnx/features.py

diff --git a/src/transformers/models/gpt_neo/__init__.py b/src/transformers/models/gpt_neo/__init__.py
index 898b9a0df78..3112d9a36e3 100644
--- a/src/transformers/models/gpt_neo/__init__.py
+++ b/src/transformers/models/gpt_neo/__init__.py
@@ -21,7 +21,7 @@

 _import_structure = {
-    "configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
+    "configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig", "GPTNeoOnnxConfig"],
 }

 if is_torch_available():
@@ -43,7 +43,7 @@

 if TYPE_CHECKING:
-    from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
+    from .configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig, GPTNeoOnnxConfig

 if is_torch_available():
     from .modeling_gpt_neo import (
diff --git a/src/transformers/models/gpt_neo/configuration_gpt_neo.py b/src/transformers/models/gpt_neo/configuration_gpt_neo.py
index 4ad22eaa1c5..1e9995dfc7b 100644
--- a/src/transformers/models/gpt_neo/configuration_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/configuration_gpt_neo.py
@@ -14,7 +14,12 @@
 # limitations under the License.
 """ GPT Neo model configuration """

+from collections import OrderedDict
+from typing import Any, Mapping, Optional
+
+from ... import PreTrainedTokenizer, TensorType, is_torch_available
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast, PatchingSpec
 from ...utils import logging

@@ -173,3 +178,140 @@ def num_attention_heads(self):
     @property
     def num_hidden_layers(self):
         return self.num_layers
+
+
+def custom_unfold(input, dimension, size, step):
+    """Custom torch.Tensor.unfold implementation to enable the export to ONNX."""
+    import torch
+
+    shape = input.size()
+    rank = len(shape)
+    sizedim = shape[dimension]
+
+    # indices[i, j] = i * step + j: one row of source indices per sliding window
+    low_indices = torch.arange(0, sizedim, step)
+    min_length = torch.div(sizedim - size, step, rounding_mode="floor") + 1
+    indices = torch.arange(size) + low_indices[:min_length][:, None]
+
+    s = [slice(None)] * rank
+    s[dimension] = indices
+    sliced = input[tuple(s)]
+
+    # Move the window-content axis to the end, as torch.Tensor.unfold does
+    perm = list(range(0, rank + 1))
+    perm.append(perm.pop(dimension + 1))
+
+    return sliced.permute(perm)
+
+
+def custom_get_block_length_and_num_blocks(seq_length, window_size):
+    """
+    Custom implementation for GPTNeoAttentionMixin._get_block_length_and_num_blocks to enable the export to ONNX as
+    the original implementation uses Python variables and control flow.
+ """ + import torch + + candidates = torch.arange(1, window_size) + remainders = torch.remainder(seq_length, candidates) + divisor_indices = remainders == 0 + divisors = candidates[divisor_indices] + largest_divisor = torch.max(divisors) + return largest_divisor, torch.div(seq_length, largest_divisor, rounding_mode="floor") + + +class GPTNeoOnnxConfig(OnnxConfigWithPast): + def __init__(self, config: PretrainedConfig, task: str = "default", use_past: bool = False): + if is_torch_available(): + import torch + + from .modeling_gpt_neo import GPTNeoAttentionMixin + + patching_specs = [ + PatchingSpec(torch.Tensor, name="unfold", custom_op=custom_unfold), + PatchingSpec( + GPTNeoAttentionMixin, + name="_get_block_length_and_num_blocks", + custom_op=custom_get_block_length_and_num_blocks, + op_wrapper=staticmethod, + ), + ] + + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + + self._num_local_attention = len([type_ for type_ in self._config.attention_layers if type_ == "local"]) + self._key_values_dynamic_axis = [] + for i in range(self._config.num_layers): + if self._config.attention_layers[i] == "local": + self._key_values_dynamic_axis.append({0: "batch", 1: "sequence"}) + else: + self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"}) + self._key_values_dynamic_axis.append({0: "batch", 2: "sequence"}) + + @property + def _number_key_values(self): + return (self._config.num_layers * 2) - self._num_local_attention + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + for i in range(self._number_key_values): + common_inputs[f"past_key_values.{i}"] = self._key_values_dynamic_axis[i] + + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + common_outputs = super().outputs + if self.use_past: + for i in range(self._number_key_values): + common_outputs[f"present.{i}"] = self._key_values_dynamic_axis[i] + + return common_outputs + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + batch = common_inputs["input_ids"].shape[0] + past_shapes = { + "global": (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_attention_heads), + "local": (batch, 1, self._config.hidden_size), + } + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + ordered_inputs["past_key_values"] = [] + for i in range(self._config.num_layers): + attention_type = self._config.attention_layers[i] + if attention_type == "global": + ordered_inputs["past_key_values"].append( + ( + torch.zeros(past_shapes[attention_type]), + torch.zeros(past_shapes[attention_type]), + ) + ) + else: + ordered_inputs["past_key_values"].append((torch.zeros(past_shapes[attention_type]),)) + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + ordered_inputs["attention_mask"] = 
+                [ordered_inputs["attention_mask"], torch.zeros(batch, 1)], dim=1
+            )
+
+        return ordered_inputs
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 44c6dd6583a..05e5b1ce281 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -1121,7 +1121,7 @@ def forward(
                 f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
             )

-        pooled_logits = logits[range(batch_size), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size), sequence_lengths]

         loss = None
         if labels is not None:
diff --git a/src/transformers/onnx/__init__.py b/src/transformers/onnx/__init__.py
index a61e475b828..a80567e202b 100644
--- a/src/transformers/onnx/__init__.py
+++ b/src/transformers/onnx/__init__.py
@@ -13,6 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast
+from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec
 from .convert import export, validate_model_outputs
 from .utils import ParameterFormat, compute_serialized_parameters_size
diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py
index 2c7b2a69528..be724423316 100644
--- a/src/transformers/onnx/__main__.py
+++ b/src/transformers/onnx/__main__.py
@@ -14,101 +14,22 @@

 from argparse import ArgumentParser
 from pathlib import Path
-from typing import Callable, Tuple

-from transformers.models.albert import AlbertOnnxConfig
 from transformers.models.auto import AutoTokenizer
-from transformers.models.bart import BartOnnxConfig
-from transformers.models.bert import BertOnnxConfig
-from transformers.models.distilbert import DistilBertOnnxConfig
-from transformers.models.gpt2 import GPT2OnnxConfig
-from transformers.models.longformer import LongformerOnnxConfig
-from transformers.models.roberta import RobertaOnnxConfig
-from transformers.models.t5 import T5OnnxConfig
-from transformers.models.xlm_roberta import XLMRobertaOnnxConfig

-from .. import is_torch_available
 from ..utils import logging
 from .convert import export, validate_model_outputs
-
-
-if is_torch_available():
-    from transformers import AutoModel, PreTrainedModel
-
-    FEATURES_TO_AUTOMODELS = {
-        "default": AutoModel,
-    }
-
-
-# Set of model topologies we support associated to the features supported by each topology and the factory
-SUPPORTED_MODEL_KIND = {
-    "albert": {"default": AlbertOnnxConfig.default},
-    "bart": {"default": BartOnnxConfig.default},
-    "bert": {"default": BertOnnxConfig.default},
-    "distilbert": {"default": DistilBertOnnxConfig.default},
-    "gpt2": {"default": GPT2OnnxConfig.default},
-    "longformer": {"default": LongformerOnnxConfig.default},
-    "roberta": {"default": RobertaOnnxConfig},
-    "t5": {"default": T5OnnxConfig.default},
-    "xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
-}
-
-
-def get_model_from_features(features: str, model: str):
-    """
-    Attempt to retrieve a model from a model's name and the features to be enabled.
-
-    Args:
-        features: The features required
-        model: The name of the model to export
-
-    Returns:
-
-    """
-    if features not in FEATURES_TO_AUTOMODELS:
-        raise KeyError(f"Unknown feature: {features}."
f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}") - - return FEATURES_TO_AUTOMODELS[features].from_pretrained(model) - - -def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]: - """ - Check whether or not the model has the requested features - - Args: - model: The model to export - features: The name of the features to check if they are avaiable - - Returns: - (str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties - - """ - if model.config.model_type not in SUPPORTED_MODEL_KIND: - raise KeyError( - f"{model.config.model_type} ({model.name}) is not supported yet. " - f"Only {SUPPORTED_MODEL_KIND} are supported. " - f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue." - ) - - # Look for the features - model_features = SUPPORTED_MODEL_KIND[model.config.model_type] - if features not in model_features: - raise ValueError( - f"{model.config.model_type} doesn't support features {features}. " - f"Supported values are: {list(model_features.keys())}" - ) - - return model.config.model_type, SUPPORTED_MODEL_KIND[model.config.model_type][features] +from .features import FeaturesManager def main(): parser = ArgumentParser("Hugging Face ONNX Exporter tool") parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.") parser.add_argument( - "--features", - choices=["default"], + "--feature", + choices=list(FeaturesManager.AVAILABLE_FEATURES), default="default", - help="Export the model with some additional features.", + help="Export the model with some additional feature.", ) parser.add_argument( "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)." @@ -127,8 +48,8 @@ def main(): # Allocate the model tokenizer = AutoTokenizer.from_pretrained(args.model) - model = get_model_from_features(args.features, args.model) - model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features) + model = FeaturesManager.get_model_from_feature(args.feature, args.model) + model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature) onnx_config = model_onnx_config(model.config) # Ensure the requested opset is sufficient diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py index 2cf11368a0c..56512d3652d 100644 --- a/src/transformers/onnx/config.py +++ b/src/transformers/onnx/config.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Any, Mapping, Optional +from typing import Any, Callable, List, Mapping, Optional from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType @@ -26,6 +27,27 @@ EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024 +@dataclasses.dataclass +class PatchingSpec: + """ + Data class that holds patching specifications. + + Args: + o: Module / object where the op to patch is located + name: Name of the op to monkey patch + custom_op: Custom op that patches the original op + orig_op: Original op that is being patched + op_wrapper: Wrapper (optional) that wraps both the original and custom ops. + It is useful for ops that are class or static methods for instance. 
+ """ + + o: Any + name: str + custom_op: Callable + orig_op: Optional[Callable] = None + op_wrapper: Optional[Callable] = None + + class OnnxConfig(ABC): """ Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format. @@ -34,11 +56,38 @@ class OnnxConfig(ABC): DEFAULT_FIXED_BATCH = 2 DEFAULT_FIXED_SEQUENCE = 8 - def __init__(self, config: PretrainedConfig): + _TASKS_TO_COMMON_OUTPUTS = { + "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}), + "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), + "sequence-classification": OrderedDict({"logits": {0: "batch"}}), + "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), + "multiple-choice": OrderedDict({"logits": {0: "batch"}}), + "question-answering": OrderedDict( + { + "start_logits": {0: "batch", 1: "sequence"}, + "end_logits": {0: "batch", 1: "sequence"}, + } + ), + } + + def __init__(self, config: PretrainedConfig, task: str = "default", patching_specs: List[PatchingSpec] = None): self._config = config + if task not in self._TASKS_TO_COMMON_OUTPUTS: + raise ValueError( + f"{task} is not a supported task, supported tasks: {self._TASKS_TO_COMMON_OUTPUTS.keys()}" + ) + self.task = task + + self._patching_specs = [] + for spec in patching_specs if patching_specs is not None else []: + final_spec = spec + if spec.orig_op is None: + final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name)) + self._patching_specs.append(final_spec) + @classmethod - def default(cls, config: PretrainedConfig) -> "OnnxConfig": + def from_model_config(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfig": """ Instantiate a OnnxConfig for a specific model @@ -48,7 +97,7 @@ def default(cls, config: PretrainedConfig) -> "OnnxConfig": Returns: OnnxConfig for this model """ - return cls(config) + return cls(config, task=task) @property @abstractmethod @@ -62,7 +111,6 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: raise NotImplementedError() @property - @abstractmethod def outputs(self) -> Mapping[str, Mapping[int, str]]: """ Mapping containing the axis definition of the output tensors to provide to the model @@ -70,7 +118,7 @@ def outputs(self) -> Mapping[str, Mapping[int, str]]: Returns: For each output: its name associated to the axes symbolic name and the axis position within the tensor """ - raise NotImplementedError() + return self._TASKS_TO_COMMON_OUTPUTS[self.task] @property def values_override(self) -> Optional[Mapping[str, Any]]: @@ -170,14 +218,30 @@ def generate_dummy_inputs( dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size return dict(tokenizer(dummy_input, return_tensors=framework)) + def patch_ops(self): + for spec in self._patching_specs: + custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op) + setattr(spec.o, spec.name, custom_op) + + def restore_ops(self): + for spec in self._patching_specs: + orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op) + setattr(spec.o, spec.name, orig_op) + class OnnxConfigWithPast(OnnxConfig, ABC): - def __init__(self, config: PretrainedConfig, use_past: bool = False): - super().__init__(config) + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs) self.use_past = use_past @classmethod - def with_past(cls, config: 
PretrainedConfig) -> "OnnxConfigWithPast": + def with_past(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConfigWithPast": """ Instantiate a OnnxConfig with `use_past` attribute set to True @@ -187,7 +251,7 @@ def with_past(cls, config: PretrainedConfig) -> "OnnxConfigWithPast": Returns: OnnxConfig with `.use_past = True` """ - return cls(config, use_past=True) + return cls(config, task=task, use_past=True) @property def values_override(self) -> Optional[Mapping[str, Any]]: diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py index cfb600dfa4e..86491a7f5b8 100644 --- a/src/transformers/onnx/convert.py +++ b/src/transformers/onnx/convert.py @@ -111,6 +111,8 @@ def export( if not inputs_match: raise ValueError("Model and config inputs doesn't match") + config.patch_ops() + # export can works with named args but the dict containing named args as to be last element of the args tuple export( model, @@ -125,6 +127,8 @@ def export( opset_version=opset, ) + config.restore_ops() + return matched_inputs, onnx_outputs @@ -140,6 +144,8 @@ def validate_model_outputs( logger.info("Validating ONNX model...") + # TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test + # dynamic input shapes. reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH) # Create ONNX Runtime session @@ -152,6 +158,10 @@ def validate_model_outputs( # We flatten potential collection of outputs (i.e. past_keys) to a flat structure for name, value in ref_outputs.items(): + # Overwriting the output name as "present" since it is the name used for the ONNX ouputs + # ("past_key_values" being taken for the ONNX inputs) + if name == "past_key_values": + name = "present" if isinstance(value, (list, tuple)): value = flatten_output_collection_property(name, value) ref_outputs_dict.update(value) @@ -186,7 +196,7 @@ def validate_model_outputs( # Check the shape and values match for name, ort_value in zip(onnx_named_outputs, onnx_outputs): - ref_value = ref_outputs_dict[name].numpy() + ref_value = ref_outputs_dict[name].detach().numpy() logger.info(f'\t- Validating ONNX Model output "{name}":') # Shape @@ -197,7 +207,7 @@ def validate_model_outputs( f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)" ) else: - logger.info(f"\t\t-[✓] {ort_value.shape} matchs {ref_value.shape}") + logger.info(f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}") # Values if not np.allclose(ref_value, ort_value, atol=atol): diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py new file mode 100644 index 00000000000..8d06176cbbd --- /dev/null +++ b/src/transformers/onnx/features.py @@ -0,0 +1,135 @@ +from functools import partial, reduce +from typing import Callable, Tuple + +from .. 
+from ..models.albert import AlbertOnnxConfig
+from ..models.bart import BartOnnxConfig
+from ..models.bert import BertOnnxConfig
+from ..models.distilbert import DistilBertOnnxConfig
+from ..models.gpt2 import GPT2OnnxConfig
+from ..models.gpt_neo import GPTNeoOnnxConfig
+from ..models.longformer import LongformerOnnxConfig
+from ..models.roberta import RobertaOnnxConfig
+from ..models.t5 import T5OnnxConfig
+from ..models.xlm_roberta import XLMRobertaOnnxConfig
+
+
+if is_torch_available():
+    from transformers import PreTrainedModel
+    from transformers.models.auto import (
+        AutoModel,
+        AutoModelForCausalLM,
+        AutoModelForMultipleChoice,
+        AutoModelForQuestionAnswering,
+        AutoModelForSequenceClassification,
+        AutoModelForTokenClassification,
+    )
+
+
+def supported_features_mapping(*supported_features, onnx_config_cls=None):
+    """Generates the mapping between supported features and their corresponding OnnxConfig."""
+    if onnx_config_cls is None:
+        raise ValueError("An OnnxConfig class must be provided")
+
+    mapping = {}
+    for feature in supported_features:
+        if "-with-past" in feature:
+            task = feature.replace("-with-past", "")
+            mapping[feature] = partial(onnx_config_cls.with_past, task=task)
+        else:
+            mapping[feature] = partial(onnx_config_cls.from_model_config, task=feature)
+
+    return mapping
+
+
+class FeaturesManager:
+    _TASKS_TO_AUTOMODELS = {
+        "default": AutoModel,
+        "causal-lm": AutoModelForCausalLM,
+        "sequence-classification": AutoModelForSequenceClassification,
+        "token-classification": AutoModelForTokenClassification,
+        "multiple-choice": AutoModelForMultipleChoice,
+        "question-answering": AutoModelForQuestionAnswering,
+    }
+
+    # Set of model topologies we support associated to the features supported by each topology and the factory
+    _SUPPORTED_MODEL_KIND = {
+        "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
+        "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
+        "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
+        "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
+        "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
+        "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
+        "roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig),
+        "t5": supported_features_mapping("default", onnx_config_cls=T5OnnxConfig),
+        "xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig),
+        "gpt-neo": supported_features_mapping(
+            "default",
+            "causal-lm",
+            "sequence-classification",
+            "default-with-past",
+            "causal-lm-with-past",
+            "sequence-classification-with-past",
+            onnx_config_cls=GPTNeoOnnxConfig,
+        ),
+    }
+
+    AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values())))
+
+    @staticmethod
+    def feature_to_task(feature: str) -> str:
+        return feature.replace("-with-past", "")
+
+    @staticmethod
+    def get_model_from_feature(feature: str, model: str):
+        """
+        Attempt to retrieve a model from a model's name and the feature to be enabled.
+
+        Args:
+            feature: The feature required
+            model: The name of the model to export
+
+        Returns:
+            The instance of the model to export, loaded through the AutoModel class matching the feature's task.
+
+        """
+        task = FeaturesManager.feature_to_task(feature)
+        if task not in FeaturesManager._TASKS_TO_AUTOMODELS:
+            raise KeyError(
+                f"Unknown task: {feature}. "
+ f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}" + ) + + return FeaturesManager._TASKS_TO_AUTOMODELS[task].from_pretrained(model) + + @staticmethod + def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]: + """ + Check whether or not the model has the requested features + + Args: + model: The model to export + feature: The name of the feature to check if it is avaiable + + Returns: + (str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties + + """ + model_type = model.config.model_type.replace("_", "-") + model_name = getattr(model, "name", "") + model_name = f"({model_name})" if model_name else "" + if model_type not in FeaturesManager._SUPPORTED_MODEL_KIND: + raise KeyError( + f"{model.config.model_type} ({model_name}) is not supported yet. " + f"Only {FeaturesManager._SUPPORTED_MODEL_KIND} are supported. " + f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue." + ) + + # Look for the features + model_features = FeaturesManager._SUPPORTED_MODEL_KIND[model_type] + if feature not in model_features: + raise ValueError( + f"{model.config.model_type} doesn't support feature {feature}. " + f"Supported values are: {list(model_features.keys())}" + ) + + return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_KIND[model_type][feature] diff --git a/tests/test_onnx_v2.py b/tests/test_onnx_v2.py index 9493e8e066b..6a2ce960d78 100644 --- a/tests/test_onnx_v2.py +++ b/tests/test_onnx_v2.py @@ -9,6 +9,7 @@ BartConfig, DistilBertConfig, GPT2Config, + GPTNeoConfig, RobertaConfig, XLMRobertaConfig, is_torch_available, @@ -20,6 +21,7 @@ # from transformers.models.longformer import LongformerOnnxConfig from transformers.models.gpt2 import GPT2OnnxConfig +from transformers.models.gpt_neo import GPTNeoOnnxConfig from transformers.models.roberta import RobertaOnnxConfig # from transformers.models.t5 import T5OnnxConfig @@ -151,7 +153,8 @@ def test_use_past(self): for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS: with self.subTest(name): self.assertFalse( - OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past" + OnnxConfigWithPast.from_model_config(config()).use_past, + "OnnxConfigWithPast.from_model_config() should not use_past", ) self.assertTrue( @@ -167,7 +170,7 @@ def test_values_override(self): with self.subTest(name): # without past - onnx_config_default = OnnxConfigWithPast.default(config()) + onnx_config_default = OnnxConfigWithPast.from_model_config(config()) self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None") self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present") self.assertFalse( @@ -190,6 +193,7 @@ def test_values_override(self): BertModel, DistilBertModel, GPT2Model, + GPTNeoModel, RobertaModel, XLMRobertaModel, ) @@ -200,6 +204,7 @@ def test_values_override(self): ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig), ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig), ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig), + ("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig), # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig), ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig), ("XLM-Roberta", 
"roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),