Commit
Merge branch 'main' into add-rocbert-support-for-bettertransformer
shogohida authored Dec 7, 2022
2 parents cd94151 + f6eb417 commit 74af5d4
Showing 32 changed files with 1,729 additions and 1,088 deletions.
217 changes: 115 additions & 102 deletions README.md

Large diffs are not rendered by default.

29 changes: 23 additions & 6 deletions docs/combine_docs.py
@@ -6,6 +6,9 @@
import yaml


SECTIONS_AT_THE_END = ["Utilities"]


parser = argparse.ArgumentParser(
description="Script to combine doc builds from subpackages with base doc build of Optimum. "
"Assumes all subpackage doc builds are present in the root of the `optimum` repo."
@@ -83,6 +86,20 @@ def rename_copy_subpackage_html_paths(subpackage: str, subpackage_path: Path, op
def main():
args = parser.parse_args()
optimum_path = Path("optimum-doc-build")
# Load optimum table of contents
base_toc_path = next(optimum_path.rglob("_toctree.yml"))
with open(base_toc_path, "r") as f:
base_toc = yaml.safe_load(f)

# Pop specific sections to add them after subpackages
sections_to_pop = {title: None for title in SECTIONS_AT_THE_END}
for i, section in enumerate(base_toc[:]):
if section["title"] in SECTIONS_AT_THE_END:
sections_to_pop[section["title"]] = base_toc.pop(i)
# Raise an error if a section was not found
for key, value in sections_to_pop.items():
if value is None:
raise ValueError(f"No section was found for title '{key}'.")

# Copy and rename all files from subpackages' docs to Optimum doc
for subpackage in args.subpackages:
@@ -96,10 +113,6 @@ def main():
args.version,
)

# Load optimum table of contents
base_toc_path = next(optimum_path.rglob("_toctree.yml"))
with open(base_toc_path, "r") as f:
base_toc = yaml.safe_load(f)
# Load subpackage table of contents
subpackage_toc_path = next(subpackage_path.rglob("_toctree.yml"))
with open(subpackage_toc_path, "r") as f:
@@ -108,8 +121,12 @@
rename_subpackage_toc(subpackage, subpackage_toc)
# Update optimum table of contents
base_toc.extend(subpackage_toc)
with open(base_toc_path, "w") as f:
yaml.safe_dump(base_toc, f, allow_unicode=True)

# Add popped sections at the end
base_toc.extend(sections_to_pop.values())
# Write final table of contents
with open(base_toc_path, "w") as f:
yaml.safe_dump(base_toc, f, allow_unicode=True)


if __name__ == "__main__":
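The net effect of this change: any section named in SECTIONS_AT_THE_END (here, "Utilities") is popped from the base table of contents before the subpackage tables are appended, then re-added at the very end. A minimal, self-contained sketch of that pop-and-reappend pattern, using a made-up table of contents rather than the real Optimum one:

```python
# Sketch of the reordering logic: pop named sections, append subpackage sections,
# then restore the popped sections at the end.
SECTIONS_AT_THE_END = ["Utilities"]

base_toc = [{"title": "Overview"}, {"title": "Utilities"}]
subpackage_toc = [{"title": "Optimum Habana"}, {"title": "Optimum Intel"}]

# Pop the sections that must stay at the end.
sections_to_pop = {title: None for title in SECTIONS_AT_THE_END}
for i, section in enumerate(base_toc[:]):
    if section["title"] in SECTIONS_AT_THE_END:
        sections_to_pop[section["title"]] = base_toc.pop(i)
for title, section in sections_to_pop.items():
    if section is None:
        raise ValueError(f"No section was found for title '{title}'.")

# Append subpackage sections, then re-add the popped sections last.
base_toc.extend(subpackage_toc)
base_toc.extend(sections_to_pop.values())
print([s["title"] for s in base_toc])
# ['Overview', 'Optimum Habana', 'Optimum Intel', 'Utilities']
```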
2 changes: 1 addition & 1 deletion docs/source/_toctree.yml
@@ -114,7 +114,7 @@
- local: bettertransformer/tutorials/contribute
title: How to add support for new architectures?
title: Tutorials
title: BetterTransformer integration
title: BetterTransformer
isExpanded: false
- sections:
- local: utils/dummy_input_generators
1 change: 1 addition & 0 deletions docs/source/bettertransformer/overview.mdx
@@ -39,6 +39,7 @@ The list of supported model below:
- [MarkupLM](https://arxiv.org/abs/2110.08518)
- [MBart](https://arxiv.org/abs/2001.08210)
- [M2M100](https://arxiv.org/abs/2010.11125)
- [RemBERT](https://arxiv.org/abs/2010.12821)
- [RoBERTa](https://arxiv.org/abs/1907.11692)
- [RoCBert](https://aclanthology.org/2022.acl-long.65.pdf)
- [Splinter](https://arxiv.org/abs/2101.00438)
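For reference, converting any model in this list follows the same pattern; a short sketch below (the checkpoint name is illustrative, any supported architecture works):

```python
# Sketch: convert a supported transformers model to its BetterTransformer version.
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

model = AutoModel.from_pretrained("roberta-base")  # any architecture from the list above
model = BetterTransformer.transform(model, keep_original_model=False)
```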
4 changes: 4 additions & 0 deletions docs/source/onnxruntime/package_reference/modeling_ort.mdx
@@ -20,6 +20,10 @@ specific language governing permissions and limitations under the License.

[[autodoc]] onnxruntime.ORTModelForCausalLM

## ORTModelForCustomTasks

[[autodoc]] onnxruntime.ORTModelForCustomTasks

## ORTModelForFeatureExtraction

[[autodoc]] onnxruntime.ORTModelForFeatureExtraction
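As a rough usage sketch of the newly documented class (the checkpoint name and output keys are illustrative; ORTModelForCustomTasks returns whatever outputs the exported ONNX model defines):

```python
# Sketch: load an exported ONNX model whose outputs do not fit the standard task classes.
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCustomTasks

model = ORTModelForCustomTasks.from_pretrained("optimum/sbert-all-MiniLM-L6-with-pooler")
tokenizer = AutoTokenizer.from_pretrained("optimum/sbert-all-MiniLM-L6-with-pooler")

inputs = tokenizer("What am I using?", return_tensors="pt")
outputs = model(**inputs)  # e.g. last_hidden_state and pooler_output for this checkpoint
```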
17 changes: 9 additions & 8 deletions optimum/bettertransformer/models/__init__.py
@@ -29,17 +29,18 @@

BETTER_TRANFORMER_LAYERS_MAPPING_DICT = {
# Bert Family
"TapasLayer": BertLayerBetterTransformer,
"BertGenerationLayer": BertLayerBetterTransformer,
"BertLayer": BertLayerBetterTransformer,
"ElectraLayer": BertLayerBetterTransformer,
"Data2VecTextLayer": BertLayerBetterTransformer,
"CamembertLayer": BertLayerBetterTransformer,
"Data2VecTextLayer": BertLayerBetterTransformer,
"ElectraLayer": BertLayerBetterTransformer,
"ErnieLayer": BertLayerBetterTransformer,
"LayoutLMLayer": BertLayerBetterTransformer,
"MarkupLMLayer": BertLayerBetterTransformer,
"RemBertLayer": BertLayerBetterTransformer,
"RobertaLayer": BertLayerBetterTransformer,
"SplinterLayer": BertLayerBetterTransformer,
"ErnieLayer": BertLayerBetterTransformer,
"LayoutLMLayer": BertLayerBetterTransformer,
"BertGenerationLayer": BertLayerBetterTransformer,
"TapasLayer": BertLayerBetterTransformer,
"XLMRobertaLayer": BertLayerBetterTransformer,
"RoCBertLayer": BertLayerBetterTransformer,
# Albert Family
@@ -62,13 +63,13 @@
# WhisperModel
"WhisperEncoderLayer": WhisperEncoderLayerBetterTransformer,
# Wav2vec2 family:
"Wav2Vec2EncoderLayer": Wav2Vec2EncoderLayerBetterTransformer,
"HubertEncoderLayer": Wav2Vec2EncoderLayerBetterTransformer,
"Wav2Vec2EncoderLayer": Wav2Vec2EncoderLayerBetterTransformer,
# "UniSpeechEncoderLayer": Wav2Vec2EncoderLayerBetterTransformer,
# "Data2VecAudioEncoderLayer": Wav2Vec2EncoderLayerBetterTransformer,
# ViT Family:
"ViTLayer": ViTLayerBetterTransformer,
"DeiTLayer": ViTLayerBetterTransformer,
"ViTLayer": ViTLayerBetterTransformer,
"ViTMAELayer": ViTLayerBetterTransformer,
"ViTMSNLayer": ViTLayerBetterTransformer,
"YolosLayer": ViTLayerBetterTransformer,
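This mapping is keyed by the transformers layer class name; conversion code of this kind typically walks the model's modules and swaps any layer whose class name appears in the dictionary. A hypothetical sketch of that lookup, not the actual optimum implementation:

```python
# Hypothetical sketch: replace every module whose class name is in the mapping
# with its BetterTransformer counterpart.
import torch.nn as nn

def convert_layers(model: nn.Module, mapping: dict, config) -> nn.Module:
    for name, module in model.named_children():
        layer_cls_name = module.__class__.__name__
        if layer_cls_name in mapping:
            # Assumed constructor: each BetterTransformer layer wraps the original layer and its config.
            setattr(model, name, mapping[layer_cls_name](module, config))
        else:
            convert_layers(module, mapping, config)  # recurse into submodules
    return model
```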
39 changes: 33 additions & 6 deletions optimum/bettertransformer/models/encoder_models.py
@@ -302,6 +302,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
"""
super().forward_checker()

if not hasattr(hidden_states, "original_shape"):
original_shape = hidden_states.shape
else:
original_shape = hidden_states.original_shape

if hidden_states.is_nested:
attention_mask = None

@@ -339,8 +344,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)

if not self.is_last_layer:
hidden_states.original_shape = original_shape
elif hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
return (hidden_states,)


@@ -412,6 +420,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
"""
super().forward_checker()

if not hasattr(hidden_states, "original_shape"):
original_shape = hidden_states.shape
else:
original_shape = hidden_states.original_shape

if hidden_states.is_nested:
attention_mask = None

@@ -449,8 +462,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)

if not self.is_last_layer:
hidden_states.original_shape = original_shape
elif hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
return (hidden_states,)


@@ -1026,6 +1042,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
"""
super().forward_checker()

if not hasattr(hidden_states, "original_shape"):
original_shape = hidden_states.shape
else:
original_shape = hidden_states.original_shape

if hidden_states.is_nested:
attention_mask = None

@@ -1037,8 +1058,11 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
seqlen = attention_mask.shape[1]
lengths = torch.sum(~attention_mask, 1)

# FSMT swaps the first two axis before calling the encoder stack
# Reference: https://github.com/huggingface/transformers/blob/699e90437f984d69ad3c9b891dd2e9d0fc2cffe4/src/transformers/models/fsmt/modeling_fsmt.py#L508
if hidden_states.shape[0] != attention_mask.shape[0]:
hidden_states = hidden_states.transpose(1, 0)
original_shape = hidden_states.shape

if not all([l == seqlen for l in lengths]):
hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
@@ -1065,6 +1089,9 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)

if not self.is_last_layer:
hidden_states.original_shape = original_shape
elif hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
return (hidden_states, attention_mask)
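The recurring change in these forward passes is to capture the padded shape before the hidden states are converted to a nested tensor, carry it through intermediate layers via an original_shape attribute, and hand it back to to_padded_tensor on the last layer, so the output is padded to the original sequence length rather than to the longest remaining sequence. A standalone sketch of that last step with plain torch (shapes are made up):

```python
# Sketch: restoring the original padded shape when converting a nested tensor back.
import torch

original_shape = (2, 5, 4)  # (batch, seq_len, hidden) before padding was stripped

# Two sequences of different lengths, as a nested tensor (like the per-layer hidden states).
nt = torch.nested.nested_tensor([torch.randn(3, 4), torch.randn(4, 4)])

padded_default = nt.to_padded_tensor(0.0)                   # (2, 4, 4): padded only to the longest sequence
padded_restored = nt.to_padded_tensor(0.0, original_shape)  # (2, 5, 4): padded back to the original length
print(padded_default.shape, padded_restored.shape)
```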
18 changes: 4 additions & 14 deletions optimum/exporters/onnx/__main__.py
@@ -17,9 +17,10 @@
from argparse import ArgumentParser
from pathlib import Path

from transformers import AutoFeatureExtractor, AutoTokenizer
from transformers import AutoTokenizer

from ...utils import logging
from ...utils.save_utils import maybe_save_preprocessors
from ..tasks import TasksManager
from .base import OnnxConfigWithPast
from .convert import (
@@ -30,7 +31,7 @@
)


logger = logging.get_logger() # pylint: disable=invalid-name
logger = logging.get_logger()
logger.setLevel(logging.INFO)


@@ -143,18 +144,7 @@ def main():
# Saving the model config as this is needed sometimes.
model.config.save_pretrained(args.output.parent)

# Saving the tokenizer / feature extractor as well.
try:
tokenizer = AutoTokenizer.from_pretrained(args.model)
tokenizer.save_pretrained(args.output.parent)
except Exception:
pass

try:
feature_extractor = AutoFeatureExtractor.from_pretrained(args.model)
feature_extractor.save_pretrained(args.output.parent)
except Exception:
pass
maybe_save_preprocessors(args.model, args.output.parent)

if args.atol is None:
args.atol = onnx_config.ATOL_FOR_VALIDATION
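The two try/except blocks are folded into maybe_save_preprocessors from optimum.utils.save_utils. Roughly, that helper attempts to load and save each kind of preprocessor and silently skips the ones the model does not have; a sketch of the idea only, not the actual optimum implementation:

```python
# Sketch of the kind of helper this change consolidates into (illustrative, not the real code).
from pathlib import Path
from transformers import AutoFeatureExtractor, AutoTokenizer

def maybe_save_preprocessors_sketch(model_name_or_path: str, output_dir: Path) -> None:
    """Save whichever preprocessors exist for the model, ignoring the ones that don't."""
    for auto_cls in (AutoTokenizer, AutoFeatureExtractor):
        try:
            auto_cls.from_pretrained(model_name_or_path).save_pretrained(output_dir)
        except Exception:
            pass  # this model simply has no preprocessor of this kind
```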
31 changes: 25 additions & 6 deletions optimum/exporters/onnx/base.py
@@ -344,16 +344,36 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
Inherits from [`~exporters.onnx.OnnxConfig`]. A base class to handle the ONNX configuration of decoder-only models.
"""

PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH = True
PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH: bool = True
USE_PAST_IN_INPUTS: Optional[bool] = None
USE_PRESENT_IN_OUTPUTS: Optional[bool] = None

def __init__(
self,
config: "PretrainedConfig",
task: str = "default",
patching_specs: List[PatchingSpec] = None,
use_past: bool = False,
use_past_in_inputs: Optional[bool] = None,
use_present_in_outputs: Optional[bool] = None,
):
self.use_past = use_past
if use_past_in_inputs is None:
use_past_in_inputs = self.USE_PAST_IN_INPUTS
if use_present_in_outputs is None:
use_present_in_outputs = self.USE_PRESENT_IN_OUTPUTS
self.use_past_in_inputs = use_past if use_past_in_inputs is None else use_past_in_inputs
self.use_present_in_outputs = use_past if use_present_in_outputs is None else use_present_in_outputs
if use_past != self.use_past_in_inputs:
logger.warning(
f"use_past = {use_past} is different than use_past_in_inputs = {use_past_in_inputs}, the value of "
"use_past_in_inputs will used for the inputs."
)
if use_past != self.use_present_in_outputs:
logger.warning(
f"use_past = {use_past} is different than use_present_in_outputs = {use_present_in_outputs}, the value "
"of use_present_in_outputs value will used for the outputs."
)
super().__init__(config, task=task, patching_specs=patching_specs)

@classmethod
@@ -375,15 +395,14 @@ def with_past(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxCo
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = super().outputs
if self.use_past:
if self.use_present_in_outputs:
self.add_past_key_values(common_outputs, direction="outputs")

return common_outputs

@property
def values_override(self) -> Optional[Mapping[str, Any]]:
if hasattr(self._config, "use_cache"):
return {"use_cache": self.use_past}
return {"use_cache": self.use_past_in_inputs or self.use_present_in_outputs}

@add_dynamic_docstring(text=GENERATE_DUMMY_DOCSTRING, dynamic_elements=DEFAULT_DUMMY_SHAPES)
def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
@@ -407,7 +426,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):

if (
self.PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH
and self.use_past
and self.use_past_in_inputs
and "attention_mask" in dummy_inputs
):
past_length = dummy_inputs["past_key_values"][0][0].shape[2]
@@ -473,7 +492,7 @@ def outputs(self) -> Mapping[str, Mapping[int, str]]:
# We reset the value as the order in common_outputs (OrderedDict) is lost otherwise
else:
axes_names[axis_idx] = name
if self.use_past:
if self.use_present_in_outputs:
self.add_past_key_values(common_outputs, direction="outputs")

return common_outputs
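The constructor now lets a config decouple whether past key/value tensors appear in the exported inputs (use_past_in_inputs) from whether present key/value tensors appear in the outputs (use_present_in_outputs), while use_past alone keeps the old coupled behaviour. A small sketch of how the three flags resolve, mirroring the logic in the diff (function name is illustrative):

```python
# Sketch: how use_past, use_past_in_inputs and use_present_in_outputs resolve.
from typing import Optional

def resolve_past_flags(
    use_past: bool,
    use_past_in_inputs: Optional[bool] = None,
    use_present_in_outputs: Optional[bool] = None,
):
    # Unset flags fall back to use_past, preserving the old coupled behaviour.
    inputs = use_past if use_past_in_inputs is None else use_past_in_inputs
    outputs = use_past if use_present_in_outputs is None else use_present_in_outputs
    return inputs, outputs

print(resolve_past_flags(True))                                # (True, True): classic use_past behaviour
print(resolve_past_flags(False, use_present_in_outputs=True))  # (False, True): no past inputs,
                                                               # but present key/values in the outputs
```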
8 changes: 4 additions & 4 deletions optimum/exporters/onnx/config.py
@@ -53,7 +53,7 @@ class TextDecoderOnnxConfig(OnnxConfigWithPast):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
if self.use_past:
if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"}
else:
@@ -79,15 +79,15 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
"input_ids": {0: "batch_size", 1: "encoder_sequence_length"},
"attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
}
if self.use_past:
if self.use_past_in_inputs:
common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length"
common_inputs["decoder_input_ids"] = {0: "batch_size"}
# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
else:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
# common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}

if self.use_past:
if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")

return common_inputs
@@ -97,7 +97,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen
self.task, self._normalized_config, **kwargs
)

if self.use_past is True:
if self.use_past_in_inputs is True:
if "sequence_length" in kwargs and kwargs["sequence_length"] != 1:
logger.warning(
f"Asked a sequence length of {kwargs['sequence_length']}, but expecting a sequence length of 1 with use_past == True. Overriding the sequence length to 1."
