diff --git a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
index e201580dd53..2cf45a7627f 100644
--- a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
+++ b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
@@ -495,6 +495,7 @@ dnf
dnn
dnnl
DNNL
+DnnlExecutionProvider
Dockerfile
doclist
docstrings
@@ -563,6 +564,7 @@ enum
env
environ
ep
+eps
eq
erf
Erf
diff --git a/docs/source/mixed_precision.md b/docs/source/mixed_precision.md
index 560bcc41e33..70ea2d98a89 100644
--- a/docs/source/mixed_precision.md
+++ b/docs/source/mixed_precision.md
@@ -17,6 +17,7 @@ The recently launched 3rd Gen Intel® Xeon® Scalable processor (codenamed Coope
## Mixed Precision Support Matrix
+
@@ -48,7 +49,7 @@ The recently launched 3rd Gen Intel® Xeon® Scalable processor (codenamed Coope
:x: |
- ONNX Runtime |
+ ONNX Runtime |
CPUExecutionProvider |
MLAS |
"default" |
@@ -72,6 +73,14 @@ The recently launched 3rd Gen Intel® Xeon® Scalable processor (codenamed Coope
✔ |
✔ |
+
+ DnnlExecutionProvider |
+ OneDNN |
+ "onnxrt_dnnl_ep" |
+ cpu |
+ ✔ |
+ :x: |
+
Tensorflow |
Tensorflow |
@@ -162,4 +171,5 @@ converted_model.save('./path/to/save/')
- Quick started with [helloworld example](/examples/helloworld/tf_example3)
- PyTorch [ResNet18](/examples/pytorch/image_recognition/torchvision_models/mixed_precision/resnet18)
- IPEX [DistilBERT base](/examples/pytorch/nlp/huggingface_models/question-answering/mixed_precision/ipex)
-- Tensorflow [ResNet50](/examples/tensorflow/image_recognition/tensorflow_models/resnet50_v1/mixed_precision)
\ No newline at end of file
+- Tensorflow [ResNet50](/examples/tensorflow/image_recognition/tensorflow_models/resnet50_v1/mixed_precision)
+- ONNX Runtime [Bert base](/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision)
diff --git a/docs/source/quantization.md b/docs/source/quantization.md
index e52e5ee702f..a9c770e729b 100644
--- a/docs/source/quantization.md
+++ b/docs/source/quantization.md
@@ -452,7 +452,7 @@ Intel(R) Neural Compressor support multi-framework: PyTorch, Tensorflow, ONNX Ru
cpu |
- ONNX Runtime |
+ ONNX Runtime |
CPUExecutionProvider |
MLAS |
"default" |
@@ -470,6 +470,12 @@ Intel(R) Neural Compressor support multi-framework: PyTorch, Tensorflow, ONNX Ru
"onnxrt_cuda_ep" |
gpu |
+
+ DnnlExecutionProvider |
+ OneDNN |
+ "onnxrt_dnnl_ep" |
+ cpu |
+
Tensorflow |
Tensorflow |
diff --git a/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/README.md b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/README.md
new file mode 100644
index 00000000000..25857c31adc
--- /dev/null
+++ b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/README.md
@@ -0,0 +1,77 @@
+Step-by-Step
+============
+
+This example loads a text classification model and confirms its accuracy and speed based on [GLUE data](https://gluebenchmark.com/).
+
+# Prerequisite
+
+## 1. Environment
+```shell
+git clone -b dnnl_ep --depth 1 https://github.com/intel/neural-compressor.git
+cd neural-compressor
+pip install -e ./
+
+cd examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/
+pip install -r requirements.txt
+```
+> Note: Validated ONNX Runtime [Version](/docs/source/installation_guide.md#validated-software-environment).
+
+## 2. Prepare Model
+
+Supported model identifiers from [huggingface.co](https://huggingface.co/):
+
+| Model Identifier |
+|:-----------------------------------------------:|
+| Intel/bert-base-uncased-mrpc |
+| Intel/roberta-base-mrpc |
+| Intel/xlm-roberta-base-mrpc |
+| Intel/camembert-base-mrpc |
+| distilbert-base-uncased-finetuned-sst-2-english |
+| Alireza1044/albert-base-v2-sst2 |
+| Intel/MiniLM-L12-H384-uncased-mrpc |
+| philschmid/MiniLM-L6-H384-uncased-sst2 |
+| bert-base-cased-finetuned-mrpc |
+| Intel/electra-small-discriminator-mrpc |
+| M-FAC/bert-mini-finetuned-mrpc |
+| Intel/xlnet-base-cased-mrpc |
+| Intel/bart-large-mrpc |
+
+```bash
+optimum-cli export onnx --model Intel/bert-base-uncased-mrpc --task text-classification
+```
+
+## 3. Prepare Dataset
+Download the GLUE data with the `prepare_data.sh` script.
+
+```shell
+export GLUE_DIR=/path/to/glue_data
+export TASK_NAME=MRPC # or SST
+
+bash prepare_data.sh --data_dir=$GLUE_DIR --task_name=$TASK_NAME
+```
+
+# Run
+
+If the hardware doesn't support bf16 instructions, set the flag below to force bf16 conversion (this workaround will be deprecated):
+
+```shell
+export FORCE_BF16=1
+```
+
+## 1. Mixed precision conversion only
+
+```bash
+bash run.sh --input_model=path/to/model \ # model path as *.onnx
+ --output_model=path/to/model_tune \ # model path as *.onnx
+```
+
+## 2. Mixed precision conversion + accuracy evaluation
+
+Please make sure DnnlExecutionProvider is in the available providers list before running evaluation.
+
+```bash
+bash eval.sh --input_model=path/to/model \ # model path as *.onnx
+ --output_model=path/to/model_tune \ # model path as *.onnx
+ --dataset_location=path/to/glue/data \
+ --batch_size=batch_size \ # optional
+```
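+
+The conversion these scripts wrap is the Neural Compressor mixed precision API used in `main.py`; below is a minimal sketch with placeholder paths:
+
+```python
+import onnx
+from neural_compressor import MixedPrecisionConfig
+from neural_compressor.mix_precision import fit
+
+# placeholder path; point this to the exported *.onnx model
+model = onnx.load("path/to/model.onnx")
+config = MixedPrecisionConfig(backend="onnxrt_dnnl_ep", precision="bf16")
+converted_model = fit(model, config)
+converted_model.save("path/to/model_tune.onnx")
+```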
diff --git a/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/eval.sh b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/eval.sh
new file mode 100644
index 00000000000..9cc04b05e1e
--- /dev/null
+++ b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/eval.sh
@@ -0,0 +1,128 @@
+#!/bin/bash
+set -x
+
+function main {
+ init_params "$@"
+ run_tuning
+}
+
+# init params
+function init_params {
+ for var in "$@"
+ do
+ case $var in
+ --input_model=*)
+ input_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --output_model=*)
+ output_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --dataset_location=*)
+ dataset_location=$(echo $var |cut -f2 -d=)
+ ;;
+ --batch_size=*)
+ batch_size=$(echo $var |cut -f2 -d=)
+ ;;
+ esac
+ done
+
+}
+
+# run_tuning
+function run_tuning {
+
+ if [[ "${input_model}" =~ "bert-base-uncased" ]]; then
+ model_name_or_path="Intel/bert-base-uncased-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "roberta-base" ]]; then
+ model_name_or_path="Intel/roberta-base-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "xlm-roberta-base" ]]; then
+ model_name_or_path="Intel/xlm-roberta-base-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "camembert-base" ]]; then
+ model_name_or_path="Intel/camembert-base-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "distilbert-base" ]]; then
+ model_name_or_path="distilbert-base-uncased-finetuned-sst-2-english"
+ TASK_NAME='sst-2'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "albert-base" ]]; then
+ model_name_or_path="Alireza1044/albert-base-v2-sst2"
+ TASK_NAME='sst-2'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "MiniLM-L6" ]]; then
+ model_name_or_path="philschmid/MiniLM-L6-H384-uncased-sst2"
+ TASK_NAME='sst-2'
+ num_heads=12
+ hidden_size=384
+ fi
+ if [[ "${input_model}" =~ "MiniLM-L12" ]]; then
+ model_name_or_path="Intel/MiniLM-L12-H384-uncased-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=384
+ fi
+ if [[ "${input_model}" =~ "bert-base-cased" ]]; then
+ model_name_or_path="bert-base-cased-finetuned-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=384
+ fi
+ if [[ "${input_model}" =~ "xlnet-base-cased" ]]; then
+ model_name_or_path="Intel/xlnet-base-cased-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "bert-mini" ]]; then
+ model_name_or_path="M-FAC/bert-mini-finetuned-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=4
+ hidden_size=256
+ fi
+ if [[ "${input_model}" =~ "electra-small-discriminator" ]]; then
+ model_name_or_path="Intel/electra-small-discriminator-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=4
+ hidden_size=256
+ fi
+ if [[ "${input_model}" =~ "bart" ]]; then
+ model_name_or_path="Intel/bart-large-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=16
+ hidden_size=4096
+ fi
+
+ python main.py \
+ --model_name_or_path ${model_name_or_path} \
+ --model_path ${input_model} \
+ --output_model ${output_model} \
+ --data_path ${dataset_location} \
+ --batch_size ${batch_size-1} \
+ --task ${TASK_NAME} \
+ --num_heads ${num_heads} \
+ --hidden_size ${hidden_size} \
+ --do_eval
+}
+
+main "$@"
+
+
+
diff --git a/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/main.py b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/main.py
new file mode 100644
index 00000000000..fa5bd52f578
--- /dev/null
+++ b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/main.py
@@ -0,0 +1,405 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:disable=redefined-outer-name,logging-format-interpolation
+
+import logging
+import argparse
+import onnx
+import onnxruntime as ort
+import transformers
+import os
+import torch
+import numpy as np
+from dataclasses import dataclass
+from typing import List, Optional, Union
+from neural_compressor.data.dataloaders.onnxrt_dataloader import DefaultDataLoader
+from neural_compressor.data.datasets.dummy_dataset import DummyDataset
+
+
+class ONNXRTBertDataset:
+ """Dataset used for model Bert.
+ Args: data_dir (str): The input data dir.
+ model_name_or_path (str): Path to the pre-trained model or model identifier
+ from huggingface.co.
+ max_seq_length (int, default=128): The maximum length after tokenization.
+ Sequences longer than this will be truncated,
+ sequences shorter will be padded.
+ do_lower_case (bool, default=True): Whether to lowercase the input when tokenizing.
+ task (str, default=mrpc): The name of the task to fine-tune.
+ Choices include mrpc, qqp, qnli, rte,
+ sts-b, cola, mnli, wnli and sst-2.
+ model_type (str, default='bert'): model type, support 'distilbert', 'bert',
+ 'mobilebert', 'roberta'.
+ dynamic_length (bool, default=False): Whether to use dynamic sequence length.
+ evaluate (bool, default=True): Whether to do evaluation or training.
+ transform (transform object, default=None): transform to process input data.
+ filter (Filter objects, default=None): filter out examples according
+ to specific conditions.
+ """
+ def __init__(self, model, data_dir, model_name_or_path, max_seq_length=128,\
+ do_lower_case=True, task='mrpc', model_type='bert', dynamic_length=False,\
+ evaluate=True, transform=None, filter=None):
+ self.inputs = [inp.name for inp in onnx.load(model).graph.input]
+ task = task.lower()
+ model_type = model_type.lower()
+ assert task in ['mrpc', 'qqp', 'qnli', 'rte', 'sts-b', 'cola', \
+ 'mnli', 'wnli', 'sst-2'], 'Unsupported task type'
+ assert model_type in ['distilbert', 'bert', 'mobilebert', 'roberta'], 'Unsupported \
+ model type'
+ self.dynamic_length = dynamic_length
+ self.model_type = model_type
+ self.max_seq_length = max_seq_length
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path,
+ do_lower_case=do_lower_case)
+ self.dataset = load_and_cache_examples(data_dir, model_name_or_path, \
+ max_seq_length, task, model_type, tokenizer, evaluate)
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, index):
+ batch = tuple(t.detach().cpu().numpy() if not isinstance(t, np.ndarray) else t for t in self.dataset[index])
+ return batch[:len(self.inputs)], batch[-1]
+
+def load_and_cache_examples(data_dir, model_name_or_path, max_seq_length, task, \
+ model_type, tokenizer, evaluate):
+ from torch.utils.data import TensorDataset
+
+ processor = transformers.glue_processors[task]()
+ output_mode = transformers.glue_output_modes[task]
+ # Load data features from cache or dataset file
+ if not os.path.exists("./dataset_cached"):
+ os.makedirs("./dataset_cached")
+ cached_features_file = os.path.join("./dataset_cached", 'cached_{}_{}_{}_{}'.format(
+ 'dev' if evaluate else 'train',
+ list(filter(None, model_name_or_path.split('/'))).pop(),
+ str(max_seq_length),
+ str(task)))
+ if os.path.exists(cached_features_file):
+ logger.info("Load features from cached file {}.".format(cached_features_file))
+ features = torch.load(cached_features_file)
+ else:
+ logger.info("Create features from dataset file at {}.".format(data_dir))
+ label_list = processor.get_labels()
+ examples = processor.get_dev_examples(data_dir) if evaluate else \
+ processor.get_train_examples(data_dir)
+ features = convert_examples_to_features(examples,
+ tokenizer,
+ task=task,
+ label_list=label_list,
+ max_length=max_seq_length,
+ output_mode=output_mode,
+ )
+ logger.info("Save features into cached file {}.".format(cached_features_file))
+ torch.save(features, cached_features_file)
+ # Convert to Tensors and build dataset
+ all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+ all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
+ all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
+ all_seq_lengths = torch.tensor([f.seq_length for f in features], dtype=torch.long)
+ if output_mode == "classification":
+ all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
+ elif output_mode == "regression":
+ all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
+ dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, \
+ all_seq_lengths, all_labels)
+ return dataset
+
+def convert_examples_to_features(
+ examples,
+ tokenizer,
+ max_length=128,
+ task=None,
+ label_list=None,
+ output_mode="classification",
+ pad_token=0,
+ pad_token_segment_id=0,
+ mask_padding_with_zero=True,
+):
+ processor = transformers.glue_processors[task]()
+ if label_list is None:
+ label_list = processor.get_labels()
+ logger.info("Use label list {} for task {}.".format(label_list, task))
+ label_map = {label: i for i, label in enumerate(label_list)}
+ features = []
+ for (ex_index, example) in enumerate(examples):
+ inputs = tokenizer.encode_plus(
+ example.text_a,
+ example.text_b,
+ add_special_tokens=True,
+ max_length=max_length,
+ return_token_type_ids=True,
+ truncation=True,
+ )
+ input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
+ # tokens are attended to.
+ attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ seq_length = len(input_ids)
+ padding_length = max_length - len(input_ids)
+
+ input_ids = input_ids + ([pad_token] * padding_length)
+ attention_mask = attention_mask + \
+ ([0 if mask_padding_with_zero else 1] * padding_length)
+ token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
+
+ assert len(input_ids) == max_length, \
+ "Error with input_ids length {} vs {}".format(
+ len(input_ids), max_length)
+ assert len(attention_mask) == max_length, \
+ "Error with attention_mask length {} vs {}".format(
+ len(attention_mask), max_length
+ )
+ assert len(token_type_ids) == max_length, \
+ "Error with token_type_ids length {} vs {}".format(
+ len(token_type_ids), max_length
+ )
+ if output_mode == "classification":
+ label = label_map[example.label]
+ elif output_mode == "regression":
+ label = float(example.label)
+ else:
+ raise KeyError(output_mode)
+
+ feats = InputFeatures(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ label=label,
+ seq_length=seq_length,
+ )
+ features.append(feats)
+ return features
+
+@dataclass(frozen=True)
+class InputFeatures:
+ """
+ A single set of features of data.
+ Property names are the same names as the corresponding inputs to a model.
+ Args:
+ input_ids: Indices of input sequence tokens in the vocabulary.
+ attention_mask: Mask to avoid performing attention on padding token indices.
+ Mask values selected in ``[0, 1]``: Usually ``1`` for tokens that are NOT MASKED,
+ ``0`` for MASKED (padded) tokens.
+ token_type_ids: (Optional) Segment token indices to indicate first and second
+ portions of the inputs. Only some models use them.
+ label: (Optional) Label corresponding to the input. Int for classification problems,
+ float for regression problems.
+ seq_length: (Optional) The length of input sequence before padding.
+ """
+
+ input_ids: List[int]
+ attention_mask: Optional[List[int]] = None
+ token_type_ids: Optional[List[int]] = None
+ label: Optional[Union[int, float]] = None
+ seq_length: Optional[List[int]] = None
+
+class ONNXRTGLUE:
+ """Computes GLUE score.
+
+ Args:
+ task (str, default=mrpc): The name of the task.
+ Choices include mrpc, qqp, qnli, rte,
+ sts-b, cola, mnli, wnli and sst-2.
+
+ """
+ def __init__(self, task='mrpc'):
+ assert task in ['mrpc', 'qqp', 'qnli', 'rte', 'sts-b', 'cola', \
+ 'mnli', 'wnli', 'sst-2'], 'Unsupported task type'
+ self.pred_list = None
+ self.label_list = None
+ self.task = task
+ self.return_key = {
+ "cola": "mcc",
+ "mrpc": "f1",
+ "sts-b": "corr",
+ "qqp": "acc",
+ "mnli": "mnli/acc",
+ "qnli": "acc",
+ "rte": "acc",
+ "wnli": "acc",
+ "sst-2": "acc"
+ }
+
+ def update(self, preds, labels):
+ """add preds and labels to storage"""
+ if isinstance(preds, list) and len(preds) == 1:
+ preds = preds[0]
+ if isinstance(labels, list) and len(labels) == 1:
+ labels = labels[0]
+ if self.pred_list is None:
+ self.pred_list = preds
+ self.label_list = labels
+ else:
+ self.pred_list = np.append(self.pred_list, preds, axis=0)
+ self.label_list = np.append(self.label_list, labels, axis=0)
+
+ def reset(self):
+ """clear preds and labels storage"""
+ self.pred_list = None
+ self.label_list = None
+
+ def result(self):
+ """calculate metric"""
+ output_mode = transformers.glue_output_modes[self.task]
+
+ if output_mode == "classification":
+ processed_preds = np.argmax(self.pred_list, axis=1)
+ elif output_mode == "regression":
+ processed_preds = np.squeeze(self.pred_list)
+ result = transformers.glue_compute_metrics(\
+ self.task, processed_preds, self.label_list)
+ return result[self.return_key[self.task]]
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+ datefmt = '%m/%d/%Y %H:%M:%S',
+ level = logging.WARN)
+
+if __name__ == "__main__":
+ logger.info('Evaluating ONNXRuntime full precision accuracy and performance:')
+ parser = argparse.ArgumentParser(
+ description='BERT fine-tune examples for classification/regression tasks.',
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ '--model_path',
+ type=str,
+ help="Pre-trained resnet50 model on onnx file"
+ )
+ parser.add_argument(
+ '--do_eval',
+ action='store_true', \
+ default=False
+ )
+ parser.add_argument(
+ '--output_model',
+ type=str,
+ default=None,
+ help="output model path"
+ )
+ parser.add_argument(
+ '--data_path',
+ type=str,
+ default=None,
+ help="input data path"
+ )
+ parser.add_argument(
+ '--batch_size',
+ default=8,
+ type=int,
+ )
+ parser.add_argument(
+ '--model_name_or_path',
+ type=str,
+ choices=['Intel/bert-base-uncased-mrpc',
+ 'Intel/roberta-base-mrpc',
+ 'Intel/xlm-roberta-base-mrpc',
+ 'Intel/camembert-base-mrpc',
+ 'distilbert-base-uncased-finetuned-sst-2-english',
+ 'Alireza1044/albert-base-v2-sst2',
+ 'philschmid/MiniLM-L6-H384-uncased-sst2',
+ 'Intel/MiniLM-L12-H384-uncased-mrpc',
+ 'bert-base-cased-finetuned-mrpc',
+ 'Intel/electra-small-discriminator-mrpc',
+ 'M-FAC/bert-mini-finetuned-mrpc',
+ 'Intel/xlnet-base-cased-mrpc',
+ 'Intel/bart-large-mrpc'],
+ help="pretrained model name or path"
+ )
+ parser.add_argument(
+ '--task',
+ type=str,
+ choices=['mrpc', 'qqp', 'qnli', 'rte', 'sts-b', 'cola', \
+ 'mnli', 'wnli', 'sst-2'],
+ help="GLUE task name"
+ )
+ parser.add_argument(
+ '--num_heads',
+ default=12,
+ type=int,
+ )
+ parser.add_argument(
+ '--hidden_size',
+ default=768,
+ type=int,
+ )
+
+ args = parser.parse_args()
+
+ if ort.__version__ <= '1.13.1':
+ from onnxruntime.transformers import optimizer
+ from onnxruntime.transformers.fusion_options import FusionOptions
+ model_type = 'bart' if args.model_name_or_path == 'Intel/bart-large-mrpc' else 'bert'
+ opt_options = FusionOptions(model_type)
+ opt_options.enable_embed_layer_norm = False
+
+ model_optimizer = optimizer.optimize_model(
+ args.model_path,
+ model_type,
+ num_heads=args.num_heads,
+ hidden_size=args.hidden_size,
+ optimization_options=opt_options)
+ model = model_optimizer.model
+ else:
+ model = onnx.load(args.model_path)
+
+ from neural_compressor import MixedPrecisionConfig
+ from neural_compressor.mix_precision import fit
+ config = MixedPrecisionConfig(backend='onnxrt_dnnl_ep', precision='bf16')
+ converted_model = fit(model, config)
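+ # if the default opset is below 15, try to upgrade it so the converted model can run with the bf16 data type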
+ if any([i.domain in ['', 'ai.onnx'] and i.version < 15 for i in converted_model.model.opset_import]):
+ from onnx import version_converter
+ try:
+ new = version_converter.convert_version(converted_model.model, 15)
+ onnx.save(new, args.output_model)
+ except Exception:
+ logging.warning("Failed to upgrade opset_import to >= 15, "
+ "please upgrade it manually to run with bf16 data type")
+ else:
+ converted_model.save(args.output_model)
+
+ if args.do_eval:
+ dataset = ONNXRTBertDataset(args.model_path,
+ data_dir=args.data_path,
+ model_name_or_path=args.model_name_or_path,
+ task=args.task)
+ dataloader = DefaultDataLoader(dataset, args.batch_size)
+ metric = ONNXRTGLUE(args.task)
+
+ def eval_func(model, *args):
+ metric.reset()
+ session = ort.InferenceSession(model.SerializeToString(),
+ providers=ort.get_available_providers())
+ ort_inputs = {}
+ len_inputs = len(session.get_inputs())
+ inputs_names = [session.get_inputs()[i].name for i in range(len_inputs)]
+ for idx, (inputs, labels) in enumerate(dataloader):
+ if not isinstance(labels, list):
+ labels = [labels]
+ inputs = inputs[:len_inputs]
+ for i in range(len_inputs):
+ ort_inputs.update({inputs_names[i]: inputs[i]})
+ predictions = session.run(None, ort_inputs)
+ metric.update(predictions[0], labels)
+ return metric.result()
+
+ model = onnx.load(args.output_model)
+ acc_result = eval_func(model)
+ print("Batch size = %d" % args.batch_size)
+ print("Accuracy: %.5f" % acc_result)
diff --git a/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/prepare_data.sh b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/prepare_data.sh
new file mode 100644
index 00000000000..8e434a5c521
--- /dev/null
+++ b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/prepare_data.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+set -x
+
+function main {
+ init_params "$@"
+ download_data
+
+}
+
+# init params
+function init_params {
+
+ for var in "$@"
+ do
+ case $var in
+ --data_dir=*)
+ data_dir=$(echo $var |cut -f2 -d=)
+ ;;
+ --task_name=*)
+ task_name=$(echo $var |cut -f2 -d=)
+ ;;
+ esac
+ done
+
+}
+
+# run_tuning
+function download_data {
+ wget https://raw.githubusercontent.com/huggingface/transformers/f98ef14d161d7bcdc9808b5ec399981481411cc1/utils/download_glue_data.py
+ python download_glue_data.py --data_dir=${data_dir} --tasks=${task_name}
+}
+
+main "$@"
+
diff --git a/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/requirements.txt b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/requirements.txt
new file mode 100644
index 00000000000..89803cfb0d2
--- /dev/null
+++ b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/requirements.txt
@@ -0,0 +1,9 @@
+torch
+transformers
+onnx
+onnxruntime >= 1.14.0
+coloredlogs
+sympy
+onnxruntime-extensions; python_version < '3.10'
+numpy
+optimum[exporters]
diff --git a/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/run.sh b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/run.sh
new file mode 100644
index 00000000000..033907a1e3c
--- /dev/null
+++ b/examples/onnxrt/nlp/huggingface_model/text_classification/mix_precision/run.sh
@@ -0,0 +1,118 @@
+#!/bin/bash
+set -x
+
+function main {
+ init_params "$@"
+ run_tuning
+}
+
+# init params
+function init_params {
+ for var in "$@"
+ do
+ case $var in
+ --input_model=*)
+ input_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --output_model=*)
+ output_model=$(echo $var |cut -f2 -d=)
+ ;;
+ esac
+ done
+
+}
+
+# run_tuning
+function run_tuning {
+
+ if [[ "${input_model}" =~ "bert-base-uncased" ]]; then
+ model_name_or_path="Intel/bert-base-uncased-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "roberta-base" ]]; then
+ model_name_or_path="Intel/roberta-base-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "xlm-roberta-base" ]]; then
+ model_name_or_path="Intel/xlm-roberta-base-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "camembert-base" ]]; then
+ model_name_or_path="Intel/camembert-base-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "distilbert-base" ]]; then
+ model_name_or_path="distilbert-base-uncased-finetuned-sst-2-english"
+ TASK_NAME='sst-2'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "albert-base" ]]; then
+ model_name_or_path="Alireza1044/albert-base-v2-sst2"
+ TASK_NAME='sst-2'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "MiniLM-L6" ]]; then
+ model_name_or_path="philschmid/MiniLM-L6-H384-uncased-sst2"
+ TASK_NAME='sst-2'
+ num_heads=12
+ hidden_size=384
+ fi
+ if [[ "${input_model}" =~ "MiniLM-L12" ]]; then
+ model_name_or_path="Intel/MiniLM-L12-H384-uncased-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=384
+ fi
+ if [[ "${input_model}" =~ "bert-base-cased" ]]; then
+ model_name_or_path="bert-base-cased-finetuned-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=384
+ fi
+ if [[ "${input_model}" =~ "xlnet-base-cased" ]]; then
+ model_name_or_path="Intel/xlnet-base-cased-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=12
+ hidden_size=768
+ fi
+ if [[ "${input_model}" =~ "bert-mini" ]]; then
+ model_name_or_path="M-FAC/bert-mini-finetuned-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=4
+ hidden_size=256
+ fi
+ if [[ "${input_model}" =~ "electra-small-discriminator" ]]; then
+ model_name_or_path="Intel/electra-small-discriminator-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=4
+ hidden_size=256
+ fi
+ if [[ "${input_model}" =~ "bart" ]]; then
+ model_name_or_path="Intel/bart-large-mrpc"
+ TASK_NAME='mrpc'
+ num_heads=16
+ hidden_size=4096
+ fi
+
+ python main.py \
+ --model_name_or_path ${model_name_or_path} \
+ --model_path ${input_model} \
+ --output_model ${output_model} \
+ --num_heads ${num_heads} \
+ --hidden_size ${hidden_size}
+}
+
+main "$@"
+
+
+
diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py
index 0f62b88a5b4..ae0173979d4 100644
--- a/neural_compressor/adaptor/onnxrt.py
+++ b/neural_compressor/adaptor/onnxrt.py
@@ -65,6 +65,9 @@ def __init__(self, framework_specific_info):
self.recipes = framework_specific_info.get("recipes", {})
self.backend = PROVIDERS[framework_specific_info["backend"]]
self.performance_only = framework_specific_info.get("performance_only", False)
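+ # bf16 can only be used when the selected execution provider is available in the current onnxruntime build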
+ self.use_bf16 = framework_specific_info.get("use_bf16", False) and \
+ self.backend in ort.get_available_providers()
+ self.use_fp16 = framework_specific_info.get("use_fp16", False)
if self.backend not in ort.get_all_providers():
logger.warning("{} backend is not supported in current environment, "
@@ -93,6 +96,8 @@ def __init__(self, framework_specific_info):
config_file = 'onnxrt_trt.yaml'
elif self.backend == 'CUDAExecutionProvider':
config_file = 'onnxrt_cuda.yaml'
+ elif self.backend == 'DnnlExecutionProvider':
+ config_file = 'onnxrt_dnnl.yaml'
else: # pragma: no cover
assert False, "{} provider is not supported in current environment, " \
"supported providers: {}".format(self.backend,
@@ -207,6 +212,16 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
return model
if model.model.opset_import[0].version < 11: # pragma: no cover
logger.warning("Quantize input needs model opset 11 or newer.")
+ if self.backend == 'DnnlExecutionProvider' and \
+ any([i.domain in ['', 'ai.onnx'] and i.version < 15 for i in model.model.opset_import]):
+ from onnx import version_converter
+ try:
+ model.model = self._rename_node(version_converter.convert_version(model.model, 15))
+ except Exception:
+ logging.warning("Failed to upgrade model opset_import to >= 15, "\
+ "please upgrade it manually to run with bf16 data type")
+ exit(0)
+
from neural_compressor.adaptor.ox_utils.util import QuantizationMode
if self.format == "qlinearops":
format = QuantizationMode.QLinearOps
@@ -590,9 +605,9 @@ def _detect_domain(self, model):
# typically, NLP models have multiple inputs,
# and the dimension of each input is usually 2 (batch_size, max_seq_len)
if not model.is_large_model:
- sess = ort.InferenceSession(model.model.SerializeToString(), providers=[self.backend])
+ sess = ort.InferenceSession(model.model.SerializeToString(), providers=ort.get_available_providers())
elif model.model_path is not None: # pragma: no cover
- sess = ort.InferenceSession(model.model_path, providers=[self.backend])
+ sess = ort.InferenceSession(model.model_path, providers=ort.get_available_providers())
else: # pragma: no cover
assert False, "Please use model path instead of onnx model object to quantize."
input_shape_lens = [len(input.shape) for input in sess.get_inputs()]
@@ -621,10 +636,10 @@ def _pre_optimize(self, model, level=1):
remove_init_from_model_input(model)
sess_options = ort.SessionOptions()
optimization_levels = {
- 'DISABLE_ALL': ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
- 'ENABLE_BASIC': ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
- 'ENABLE_EXTENDED': ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
- 'ENABLE_ALL': ort.GraphOptimizationLevel.ORT_ENABLE_ALL}
+ 'DISABLE_ALL': ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
+ 'ENABLE_BASIC': ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
+ 'ENABLE_EXTENDED': ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
+ 'ENABLE_ALL': ort.GraphOptimizationLevel.ORT_ENABLE_ALL}
if not isinstance(self.query_handler.get_graph_optimization(), list):
level = self.query_handler.get_graph_optimization()
elif options.onnxrt.graph_optimization.level is not None:
@@ -644,19 +659,19 @@ def _pre_optimize(self, model, level=1):
if sys.version_info < (3,10) and find_spec('onnxruntime_extensions'): # pragma: no cover
from onnxruntime_extensions import get_library_path
sess_options.register_custom_ops_library(get_library_path())
- backend = self.backend if self.backend != 'TensorrtExecutionProvider' else 'CUDAExecutionProvider'
if not model.is_large_model:
ort.InferenceSession(model.model.SerializeToString(),
- sess_options,
- providers=[backend])
+ sess_options,
+ providers=['CPUExecutionProvider'])
elif model.model_path is not None: # pragma: no cover
ort.InferenceSession(model.model_path,
- sess_options,
- providers=[backend])
+ sess_options,
+ providers=['CPUExecutionProvider'])
else: # pragma: no cover
logger.warning('Please use model path instead of onnx model object to quantize')
tmp_model = onnx.load(sess_options.optimized_model_filepath, load_external_data=False)
+
if model.is_large_model: # pragma: no cover
from onnx.external_data_helper import load_external_data_for_model
load_external_data_for_model(tmp_model, os.path.split(model.model_path)[0])
@@ -859,10 +874,12 @@ def query_fw_capability(self, model):
precisions = query.get_precisions()
for precision in precisions:
- if precision in ['fp16', 'bf16'] and (self.device == 'cpu' or self.backend != 'CUDAExecutionProvider'):
+ if precision == 'fp16' and not self.use_fp16:
continue
- elif precision == 'bf16' and 'CUDAExecutionProvider' not in ort.get_available_providers():
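+ # skip bf16 unless the backend enables it and the CPU supports bf16 (or FORCE_BF16=1 is set)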
+ if precision == 'bf16' and \
+ (not self.use_bf16 or (not CpuInfo().bf16 and os.getenv('FORCE_BF16') != '1')):
continue
+
# get supported optype for target precision
optypes = query.get_op_types_by_precision(precision) if \
query.get_op_types_by_precision(precision) != ['*'] else \
@@ -1064,6 +1081,8 @@ def _optypewise_filter_for_qdq(self, optype_wise):
'1.11.0': ['Conv', 'Gather', 'MatMul', 'Gemm'],
'1.12.0': ['Conv', 'Gather', 'MatMul', 'Gemm']}
specific_cfg_version = self.query_handler.get_specific_cfg_version()
+ if Version(specific_cfg_version) > ONNXRT112_VERSION:
+ specific_cfg_version = '1.12.0'
for optype, caps in optype_wise.items():
if optype not in supported_perchannel_optypes[specific_cfg_version]:
for cap in caps:
diff --git a/neural_compressor/adaptor/onnxrt_dnnl.yaml b/neural_compressor/adaptor/onnxrt_dnnl.yaml
new file mode 100644
index 00000000000..2d2d718130c
--- /dev/null
+++ b/neural_compressor/adaptor/onnxrt_dnnl.yaml
@@ -0,0 +1,417 @@
+## Copyright (c) 2021 Intel Corporation
+##
+## Licensed under the Apache License, Version 2.0 (the "License");
+## you may not use this file except in compliance with the License.
+## You may obtain a copy of the License at
+##
+## http://www.apache.org/licenses/LICENSE-2.0
+##
+## Unless required by applicable law or agreed to in writing, software
+## distributed under the License is distributed on an "AS IS" BASIS,
+## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+## See the License for the specific language governing permissions and
+## limitations under the License.
+##
+#
+
+-
+ version:
+ name: '1.6.0'
+ int8: &ref_1_6 {
+ 'static': &ref_1_6_static {
+ 'Conv': {
+ 'weight': &int8_sym_perchanneltensor_minmax {
+ 'dtype': ['int8'],
+ 'scheme': ['sym'],
+ 'granularity': ['per_channel', 'per_tensor'],
+ 'algorithm': ['minmax']
+ },
+ 'activation': &uint8_asym_pertensor_minmax {
+ 'dtype': ['uint8'],
+ 'scheme': ['asym'],
+ 'granularity': ['per_tensor'],
+ 'algorithm': ['minmax']
+ },
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'FusedConv': {
+ 'weight': *int8_sym_perchanneltensor_minmax, #'QDQ': *int8_sym_pertensor_minmax
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Gather': {
+ 'weight': &uint8_asym_perchanneltensor_minmax {
+ 'dtype': ['uint8'],
+ 'scheme': ['asym'],
+ 'granularity': ['per_channel', 'per_tensor'],
+ 'algorithm': ['minmax']
+ },
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'MatMul': {
+ 'weight': &int8_sym_pertensor_minmax {
+ 'dtype': ['int8'],
+ 'scheme': ['sym'],
+ 'granularity': ['per_tensor'],
+ 'algorithm': ['minmax']
+ },
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Attention': &default_static_qlinear_qdq {
+ 'weight': *int8_sym_pertensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Mul': &default_static_qlinear {
+ 'weight': *int8_sym_pertensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QLinear']
+ },
+ 'Relu': *default_static_qlinear_qdq,
+ 'Clip': *default_static_qlinear_qdq,
+ 'LeakyRelu': *default_static_qlinear_qdq,
+ 'Sigmoid': *default_static_qlinear_qdq,
+ 'MaxPool': *default_static_qlinear_qdq,
+ 'EmbedLayerNormalization': *default_static_qlinear_qdq,
+ 'GlobalAveragePool': *default_static_qlinear_qdq,
+ 'Add': *default_static_qlinear,
+ },
+ 'dynamic': &ref_1_6_dynamic {
+ 'Conv': {
+ 'weight': *uint8_asym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax
+ },
+ 'FusedConv': {
+ 'weight': *uint8_asym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax
+ },
+ 'MatMul': &default_dynamic {
+ 'weight': *int8_sym_pertensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax
+ },
+ 'Gather': *default_dynamic,
+ 'Attention': *default_dynamic,
+ 'EmbedLayerNormalization': *default_dynamic,
+ 'LSTM': *default_dynamic,
+ }
+ }
+ recipes: &default_optimization
+ graph_optimization: # from onnxruntime graph_optimization_level
+ level: ['DISABLE_ALL', 'ENABLE_BASIC', 'ENABLE_EXTENDED', 'ENABLE_ALL']
+
+-
+ version:
+ name: '1.7.0'
+ int8: {
+ 'static': {
+ 'FusedConv': {
+ 'weight': *int8_sym_perchanneltensor_minmax, #'QDQ': *int8_sym_pertensor_minmax
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Conv': {
+ 'weight': *int8_sym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Gather': {
+ 'weight': *uint8_asym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'MatMul': *default_static_qlinear_qdq,
+ 'Attention': *default_static_qlinear_qdq,
+ 'Mul': *default_static_qlinear,
+ 'Relu': *default_static_qlinear_qdq,
+ 'Clip': *default_static_qlinear_qdq,
+ 'LeakyRelu': *default_static_qlinear_qdq,
+ 'Sigmoid': *default_static_qlinear_qdq,
+ 'MaxPool': *default_static_qlinear_qdq,
+ 'EmbedLayerNormalization': *default_static_qlinear_qdq,
+ 'GlobalAveragePool': *default_static_qlinear_qdq,
+ 'Pad': *default_static_qlinear_qdq,
+ 'Split': *default_static_qlinear_qdq,
+ 'Add': *default_static_qlinear,
+ },
+ 'dynamic': *ref_1_6_dynamic
+ }
+ recipes:
+ <<: *default_optimization
+
+-
+ version:
+ name: '1.8.0'
+ int8: {
+ 'static': {
+ 'FusedConv': {
+ 'weight': *int8_sym_perchanneltensor_minmax, # QDQ: *int8_sym_pertensor_minmax
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Conv': {
+ 'weight': *int8_sym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Gather': {
+ 'weight': *uint8_asym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'MatMul': {
+ 'weight': *int8_sym_perchanneltensor_minmax, # QDQ: *int8_sym_pertensor_minmax
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Attention': *default_static_qlinear_qdq,
+ 'Mul': *default_static_qlinear,
+ 'Relu': *default_static_qlinear_qdq,
+ 'Clip': *default_static_qlinear_qdq,
+ 'LeakyRelu': *default_static_qlinear_qdq,
+ 'Sigmoid': *default_static_qlinear_qdq,
+ 'MaxPool': *default_static_qlinear_qdq,
+ 'EmbedLayerNormalization': *default_static_qlinear_qdq,
+ 'GlobalAveragePool': *default_static_qlinear_qdq,
+ 'Pad': *default_static_qlinear_qdq,
+ 'Split': *default_static_qlinear_qdq,
+ 'Add': *default_static_qlinear,
+ 'Squeeze': *default_static_qlinear_qdq,
+ 'Reshape': *default_static_qlinear_qdq,
+ 'Concat': *default_static_qlinear_qdq,
+ 'AveragePool': *default_static_qlinear_qdq,
+ 'Unsqueeze': *default_static_qlinear_qdq,
+ 'Transpose': *default_static_qlinear_qdq,
+ 'Resize': *default_static_qlinear_qdq,
+ },
+ 'dynamic': {
+ 'Conv': {
+ 'weight': *uint8_asym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax
+ },
+ 'FusedConv': {
+ 'weight': *uint8_asym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax
+ },
+ 'MatMul': {
+ 'weight': *int8_sym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax
+ },
+ 'Gather': *default_dynamic,
+ 'Attention': *default_dynamic,
+ 'EmbedLayerNormalization': *default_dynamic,
+ 'LSTM': *default_dynamic,
+ }
+ }
+ recipes:
+ <<: *default_optimization
+
+-
+ version:
+ name: '1.9.0'
+ int8: {
+ 'static': {
+ 'FusedConv': {
+ 'weight': *int8_sym_perchanneltensor_minmax, # QDQ: *int8_sym_pertensor_minmax
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Conv': {
+ 'weight': *int8_sym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Gather': {
+ 'weight': *uint8_asym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'MatMul': {
+ 'weight': *int8_sym_perchanneltensor_minmax, # QDQ: *int8_sym_pertensor_minmax
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'EmbedLayerNormalization': {
+ 'weight': *uint8_asym_pertensor_minmax, # QDQ: *int8_sym_pertensor_minmax
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Attention': *default_static_qlinear_qdq,
+ 'Mul': *default_static_qlinear,
+ 'Relu': *default_static_qlinear_qdq,
+ 'Clip': *default_static_qlinear_qdq,
+ 'LeakyRelu': *default_static_qlinear_qdq,
+ 'Sigmoid': *default_static_qlinear_qdq,
+ 'MaxPool': *default_static_qlinear_qdq,
+ 'GlobalAveragePool': *default_static_qlinear_qdq,
+ 'Pad': *default_static_qlinear_qdq,
+ 'Split': *default_static_qlinear_qdq,
+ 'Add': *default_static_qlinear,
+ 'Squeeze': *default_static_qlinear_qdq,
+ 'Reshape': *default_static_qlinear_qdq,
+ 'Concat': *default_static_qlinear_qdq,
+ 'AveragePool': *default_static_qlinear_qdq,
+ 'Unsqueeze': *default_static_qlinear_qdq,
+ 'Transpose': *default_static_qlinear_qdq,
+ 'Resize': *default_static_qlinear_qdq,
+ },
+ 'dynamic': &ref_1_9_dynamic {
+ 'Conv': {
+ 'weight': *uint8_asym_pertensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax
+ },
+ 'FusedConv': {
+ 'weight': *uint8_asym_pertensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax
+ },
+ 'MatMul': {
+ 'weight': *int8_sym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax
+ },
+ 'EmbedLayerNormalization': {
+ 'weight': *uint8_asym_pertensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax
+ },
+ 'Gather': *default_dynamic,
+ 'Attention': *default_dynamic,
+ 'LSTM': *default_dynamic,
+ }
+ }
+ recipes:
+ <<: *default_optimization
+
+-
+ version:
+ name: '1.10.0'
+ int8: {
+ 'static': {
+ 'FusedConv': {
+ 'weight': *int8_sym_perchanneltensor_minmax, # QDQ: *int8_sym_pertensor_minmax
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Conv': {
+ 'weight': *int8_sym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Gather': {
+ 'weight': *uint8_asym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'MatMul': {
+ 'weight': *int8_sym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'EmbedLayerNormalization': {
+ 'weight': *uint8_asym_pertensor_minmax, # QDQ: *int8_sym_pertensor_minmax
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Attention': *default_static_qlinear_qdq,
+ 'Mul': *default_static_qlinear,
+ 'Relu': *default_static_qlinear_qdq,
+ 'Clip': *default_static_qlinear_qdq,
+ 'LeakyRelu': *default_static_qlinear_qdq,
+ 'Sigmoid': *default_static_qlinear_qdq,
+ 'MaxPool': *default_static_qlinear_qdq,
+ 'GlobalAveragePool': *default_static_qlinear_qdq,
+ 'Pad': *default_static_qlinear_qdq,
+ 'Split': *default_static_qlinear_qdq,
+ 'Add': *default_static_qlinear,
+ 'Squeeze': *default_static_qlinear_qdq,
+ 'Reshape': *default_static_qlinear_qdq,
+ 'Concat': *default_static_qlinear_qdq,
+ 'AveragePool': *default_static_qlinear_qdq,
+ 'Unsqueeze': *default_static_qlinear_qdq,
+ 'Transpose': *default_static_qlinear_qdq,
+ 'Resize': *default_static_qlinear_qdq,
+ },
+ 'dynamic': *ref_1_9_dynamic
+ }
+ recipes:
+ <<: *default_optimization
+
+-
+ version:
+ name: '1.11.0'
+ int8: &ref_1_11 {
+ 'static': {
+ 'FusedConv': {
+ 'weight': *int8_sym_perchanneltensor_minmax, # QDQ: *int8_sym_pertensor_minmax
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Conv': {
+ 'weight': *int8_sym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Gather': {
+ 'weight': *uint8_asym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'MatMul': {
+ 'weight': *int8_sym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Gemm': {
+ 'weight': *int8_sym_perchanneltensor_minmax,
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'EmbedLayerNormalization': {
+ 'weight': *uint8_asym_pertensor_minmax, # QDQ: *int8_sym_pertensor_minmax
+ 'activation': *uint8_asym_pertensor_minmax,
+ 'mode': ['QDQ', 'QLinear']
+ },
+ 'Attention': *default_static_qlinear_qdq,
+ 'Mul': *default_static_qlinear,
+ 'Relu': *default_static_qlinear_qdq,
+ 'Clip': *default_static_qlinear_qdq,
+ 'LeakyRelu': *default_static_qlinear_qdq,
+ 'Sigmoid': *default_static_qlinear_qdq,
+ 'MaxPool': *default_static_qlinear_qdq,
+ 'GlobalAveragePool': *default_static_qlinear_qdq,
+ 'Pad': *default_static_qlinear_qdq,
+ 'Split': *default_static_qlinear_qdq,
+ 'Add': *default_static_qlinear,
+ 'Squeeze': *default_static_qlinear_qdq,
+ 'Reshape': *default_static_qlinear_qdq,
+ 'Concat': *default_static_qlinear_qdq,
+ 'AveragePool': *default_static_qlinear_qdq,
+ 'Unsqueeze': *default_static_qlinear_qdq,
+ 'Transpose': *default_static_qlinear_qdq,
+ 'ArgMax': *default_static_qlinear,
+ 'Resize': *default_static_qlinear_qdq,
+ },
+ 'dynamic': *ref_1_9_dynamic
+ }
+ recipes:
+ <<: *default_optimization
+
+-
+ version:
+ name: '1.14.0'
+ int8: *ref_1_11
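+ # op types eligible for bf16 mixed precision conversion with this backend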
+ bf16: &common_bf16 ['MatMul', 'Gemm', 'BatchNormalization', 'Softmax', 'Sum',
+ 'Abs', 'BiasGelu', 'Exp', 'FastGelu', 'Gelu', 'Log', 'Relu', 'Round', 'Sigmoid',
+ 'Sqrt', 'Tanh', 'Add', 'Sub', 'Mul', 'Div', 'Pow', 'ReduceMean', 'Equal',
+ 'FusedMatMul', 'Greater', 'GreaterOrEqual', 'LeakyRelu', 'Less', 'LessOrEqual',
+ 'Reshape', 'Squeeze', 'Transpose', 'Unsqueeze', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
+ 'ReduceLogSumExp', 'ReduceMax', 'ReduceProd', 'ReduceSum', 'ReduceSumSquare',
+ 'LayerNormalization', 'Concat']
+ recipes:
+ <<: *default_optimization
+
+-
+ version:
+ name: 'default'
+ int8: *ref_1_6
+ recipes:
+ <<: *default_optimization
diff --git a/neural_compressor/adaptor/ox_utils/calibration.py b/neural_compressor/adaptor/ox_utils/calibration.py
index a7a7756aaa1..0bbc3acef4b 100644
--- a/neural_compressor/adaptor/ox_utils/calibration.py
+++ b/neural_compressor/adaptor/ox_utils/calibration.py
@@ -472,7 +472,7 @@ def calculate_quantization_params(self, q_config, quantization_thresholds):
if tensor_name in output_name_to_nodes:
parent = output_name_to_nodes[tensor_name]
if parent and parent.name in q_config and \
- q_config[parent.name] not in ['fp32', 'fp16']:
+ q_config[parent.name] not in ['fp32', 'fp16', 'bf16']:
scheme = q_config[parent.name]['activation']['scheme']
qType = q_config[parent.name]['activation']['dtype']
elif self.backend in ['TensorrtExecutionProvider']:
diff --git a/neural_compressor/adaptor/ox_utils/quantizer.py b/neural_compressor/adaptor/ox_utils/quantizer.py
index 9749391ebb2..c6328112724 100644
--- a/neural_compressor/adaptor/ox_utils/quantizer.py
+++ b/neural_compressor/adaptor/ox_utils/quantizer.py
@@ -246,8 +246,11 @@ def merge_dedicated_qdq_pair(self):
def should_cast(self, node):
"""Check if node should be casted."""
if node.name in self.config and self.config[node.name] != 'fp32': # pragma: no cover
- return True
- else:
+ parent = self.model.get_parent(node, 0)
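+ # TensorProto dtype enum: 1 = FLOAT, 10 = FLOAT16, 16 = BFLOAT16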
+ if parent is not None and (parent.op_type != 'Cast' or parent.attribute[0].i in [1, 10, 16]):
+ return True
+ elif parent is None and node.input[0] in self.model.input():
+ return True
return False
def insert_qdq(self):
diff --git a/neural_compressor/adaptor/ox_utils/util.py b/neural_compressor/adaptor/ox_utils/util.py
index de369ca33a0..0040db9f21e 100644
--- a/neural_compressor/adaptor/ox_utils/util.py
+++ b/neural_compressor/adaptor/ox_utils/util.py
@@ -70,13 +70,15 @@
PROVIDERS = {
'default': 'CPUExecutionProvider',
'onnxrt_trt_ep': 'TensorrtExecutionProvider',
+ 'onnxrt_dnnl_ep': 'DnnlExecutionProvider',
'onnxrt_cuda_ep': 'CUDAExecutionProvider',
}
ONNXRT_BACKENDS = {
'CPUExecutionProvider': 'default',
'TensorrtExecutionProvider': 'onnxrt_trt_ep',
- 'CUDAExecutionProvider': 'onnxrt_cuda_ep'
+ 'CUDAExecutionProvider': 'onnxrt_cuda_ep',
+ 'DnnlExecutionProvider': 'onnxrt_dnnl_ep'
}
def dtype_to_name(dtype_mapping, dtype):
diff --git a/neural_compressor/config.py b/neural_compressor/config.py
index a0093899613..60771d9c7ba 100644
--- a/neural_compressor/config.py
+++ b/neural_compressor/config.py
@@ -259,7 +259,8 @@ class BenchmarkConfig:
inputs (list, optional): A list of strings containing the inputs of model. Default is an empty list.
outputs (list, optional): A list of strings containing the outputs of model. Default is an empty list.
backend (str, optional): Backend name for model execution. Supported values include: 'default', 'itex',
- 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep'. Default value is 'default'.
+ 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep', 'onnxrt_dnnl_ep'.
+ Default value is 'default'.
warmup (int, optional): The number of iterations to perform warmup before running performance tests.
Default value is 5.
iteration (int, optional): The number of iterations to run performance tests. Default is -1.
@@ -327,7 +328,7 @@ def backend(self):
def backend(self, backend):
"""Set backend."""
if _check_value('backend', backend, str, [
- 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep']):
+ 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep', 'onnxrt_dnnl_ep']):
self._backend = backend
@property
@@ -692,7 +693,8 @@ class _BaseQuantizationConfig:
Args:
inputs: Inputs of model, only required in tensorflow.
outputs: Outputs of model, only required in tensorflow.
- backend: Backend for model execution. Support 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep'
+ backend: Backend for model execution.
+ Support 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep', 'onnxrt_dnnl_ep'.
domain: Model domain. Support 'auto', 'cv', 'object_detection', 'nlp' and 'recommendation_system'.
Adaptor will use specific quantization settings for different domains automatically, and
explicitly specified quantization settings will override the automatic setting.
@@ -1058,7 +1060,7 @@ def backend(self):
@backend.setter
def backend(self, backend):
if _check_value('backend', backend, str, [
- 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep']):
+ 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep', 'onnxrt_dnnl_ep']):
self._backend = backend
@property
@@ -1103,7 +1105,8 @@ class PostTrainingQuantConfig(_BaseQuantizationConfig):
Args:
device: Support 'cpu' and 'gpu'.
- backend: Backend for model execution. Support 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep'
+ backend: Backend for model execution.
+ Support 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep', 'onnxrt_dnnl_ep'.
domain: Model domain. Support 'auto', 'cv', 'object_detection', 'nlp' and 'recommendation_system'.
Adaptor will use specific quantization settings for different domains automatically, and
explicitly specified quantization settings will override the automatic setting.
@@ -1262,7 +1265,8 @@ class QuantizationAwareTrainingConfig(_BaseQuantizationConfig):
Args:
device: Support 'cpu' and 'gpu'.
- backend: Backend for model execution. Support 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep'
+ backend: Backend for model execution.
+ Support 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep', 'onnxrt_dnnl_ep'.
inputs: Inputs of model, only required in tensorflow.
outputs: Outputs of model, only required in tensorflow.
op_type_dict: Tuning constraints on optype-wise for advance user to reduce tuning space.
@@ -1704,8 +1708,8 @@ class MixedPrecisionConfig(object):
device (str, optional): Device for execution.
Support 'cpu' and 'gpu', default is 'cpu'.
backend (str, optional): Backend for model execution.
- Support 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep',
- default is 'default', 'ipex' doesn't support tune.
+ Support 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep', 'onnxrt_dnnl_ep',
+ default is 'default'.
precisions ([str, list], optional): Target precision for mix precision conversion.
Support 'bf16' and 'fp16', default is 'bf16'.
model_name (str, optional): The name of the model. Default value is empty.
@@ -1864,7 +1868,7 @@ def backend(self):
def backend(self, backend):
"""Set backend."""
if _check_value('backend', backend, str, [
- 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep']):
+ 'default', 'itex', 'ipex', 'onnxrt_trt_ep', 'onnxrt_cuda_ep', 'onnxrt_dnnl_ep']):
self._backend = backend
@property
diff --git a/neural_compressor/mix_precision.py b/neural_compressor/mix_precision.py
index 057344591ab..2444466f691 100644
--- a/neural_compressor/mix_precision.py
+++ b/neural_compressor/mix_precision.py
@@ -98,11 +98,14 @@ def fit(model,
precisions = list(set(conf.precisions) - set(conf.excluded_precisions))
if ('bf16' in precisions or 'fp16' in precisions) and conf.framework == "onnxruntime": # pragma: no cover
- if conf.device == "cpu":
- logger.warning("Mix precision exits due to device isn't gpu for onnx models.")
+ if 'fp16' in precisions and not (conf.device == "gpu" and conf.backend == "onnxrt_cuda_ep"):
+ logger.warning("Mix precision exits due to fp16 for onnx models" \
+ "needs 'gpu' device and 'onnxrt_cuda_ep' backend.")
sys.exit(0)
- elif conf.backend != "onnxrt_cuda_ep":
- logger.warning("Mix precision exits due to backend isn't onnxrt_cuda_ep for onnx models.")
+ elif 'bf16' in precisions and (not (conf.backend == "onnxrt_cuda_ep" and conf.device == "gpu") and \
+ not (conf.backend == "onnxrt_dnnl_ep" and conf.device == "cpu")):
+ logger.warning("Mix precision exits due to bf16 for onnx models needs " \
+ "'gpu' device and 'onnxrt_cuda_ep' backend, or 'cpu' device and 'onnxrt_dnnl_ep' backend.")
sys.exit(0)
elif 'bf16' in precisions and not CpuInfo().bf16 and conf.framework != "onnxruntime": # pragma: no cover
if os.getenv('FORCE_BF16') == '1':
diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py
index e73efbb4feb..66632ece602 100644
--- a/neural_compressor/model/model.py
+++ b/neural_compressor/model/model.py
@@ -79,9 +79,9 @@ def _is_onnxruntime(model):
from onnxruntime_extensions import get_library_path
so.register_custom_ops_library(get_library_path())
if isinstance(model, str):
- ort.InferenceSession(model, so, providers=['CPUExecutionProvider'])
+ ort.InferenceSession(model, so, providers=ort.get_available_providers())
else:
- ort.InferenceSession(model.SerializeToString(), so, providers=['CPUExecutionProvider'])
+ ort.InferenceSession(model.SerializeToString(), so, providers=ort.get_available_providers())
except Exception as e: # pragma: no cover
if 'Message onnx.ModelProto exceeds maximum protobuf size of 2GB' in str(e):
logger.warning('Please use model path instead of onnx model object to quantize')
diff --git a/neural_compressor/strategy/strategy.py b/neural_compressor/strategy/strategy.py
index a947d0a3673..d7c560154c1 100644
--- a/neural_compressor/strategy/strategy.py
+++ b/neural_compressor/strategy/strategy.py
@@ -1253,6 +1253,11 @@ def _set_framework_info(self, q_dataloader, q_func=None):
framework_specific_info['backend'] == 'onnxrt_trt_ep':
framework_specific_info.update({'format': 'QDQ'})
framework = 'onnxrt_qdq'
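+ # fp16/bf16 are enabled for CUDA EP on the 'gpu' device; bf16 is also enabled for DNNL EP on 'cpu'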
+ if framework_specific_info['backend'] == 'onnxrt_cuda_ep' and self.config.device =='gpu':
+ framework_specific_info['use_fp16'] = True
+ framework_specific_info['use_bf16'] = True
+ if framework_specific_info['backend'] == 'onnxrt_dnnl_ep' and self.config.device == 'cpu':
+ framework_specific_info['use_bf16'] = True
if framework == 'pytorch_ipex' or framework == 'pytorch' or framework == 'pytorch_fx':
if self.config.backend == 'ipex':
framework = 'pytorch_ipex'
diff --git a/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py b/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py
index fc29326a3ca..4b99baf1e7f 100644
--- a/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py
+++ b/test/adaptor/onnxrt_adaptor/test_onnxrt_operators.py
@@ -10,6 +10,7 @@
from neural_compressor.adaptor.ox_utils.util import QuantizedInitializer, QuantizedValue, QuantizationMode
import onnxruntime as ort
from neural_compressor.config import ONNXQlinear2QDQConfig
+from neural_compressor.utils.utility import CpuInfo
def build_model():
initializers = []
@@ -1174,6 +1175,176 @@ def test_fp16(self):
session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['CUDAExecutionProvider'])
outputs = session.run(None, input_data)
+ def get_bf16_mixed_precision_model(self, model):
+ from neural_compressor import MixedPrecisionConfig
+ from neural_compressor.mix_precision import fit
+ config = MixedPrecisionConfig(backend='onnxrt_dnnl_ep', precision='bf16')
+ converted_model = fit(model, config)
+ return converted_model
+
+ @unittest.skipIf(not CpuInfo().bf16 or 'DnnlExecutionProvider' not in ort.get_available_providers(),
+ "skip since DnnlExecutionProvider is not available or bf16 is not supported on this CPU")
+ def test_bf16(self):
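+ # Each block below builds a single-op fp32 model, converts it to bf16 mixed precision,
+ # and checks that a Cast to BFLOAT16 (elem_type 16) was inserted and that the converted
+ # model still runs under DnnlExecutionProvider.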
+ optypes = ['Sum', 'Sub', 'Div', 'Pow', 'Add']
+ for optype in optypes:
+ inps = [['input1', TensorProto.FLOAT, (1,2)]]
+ outs = [['output', TensorProto.FLOAT, (1,2)]]
+ weights = [['input2', TensorProto.FLOAT, (1,2), np.random.random((2))]]
+ node_infos = [['test', ['input1', 'input2'], ['output'], optype]]
+ model = self.build_model(inps, outs, weights, node_infos)
+ input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
+ convert_model = self.get_bf16_mixed_precision_model(model)
+ self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
+ self.assertTrue(16 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
+ session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['DnnlExecutionProvider'])
+ outputs = session.run(None, input_data)
+
+ optypes = ['Equal', 'Greater', 'GreaterOrEqual', 'Less', 'LessOrEqual']
+ for optype in optypes:
+ inps = [['input1', TensorProto.FLOAT, (1,2)]]
+ outs = [['output', TensorProto.BOOL, (1,2)]]
+ weights = [['input2', TensorProto.FLOAT, (1,2), np.random.random((2))]]
+ node_infos = [['test', ['input1', 'input2'], ['output'], optype]]
+ model = self.build_model(inps, outs, weights, node_infos)
+ input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
+ convert_model = self.get_bf16_mixed_precision_model(model)
+ self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
+ self.assertTrue(16 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
+ session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['DnnlExecutionProvider'])
+ outputs = session.run(None, input_data)
+
+ optypes = ['Abs', 'Exp', 'Log', 'Round', 'Sqrt', 'Softmax', 'Tanh', 'Sigmoid', 'LeakyRelu']
+ for optype in optypes:
+ inps = [['input1', TensorProto.FLOAT, (1,2)]]
+ outs = [['output', TensorProto.FLOAT, (1,2)]]
+ node_infos = [['test', ['input1'], ['output'], optype]]
+ model = self.build_model(inps, outs, [], node_infos)
+ input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
+ convert_model = self.get_bf16_mixed_precision_model(model)
+ self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
+ self.assertTrue(16 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
+ session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['DnnlExecutionProvider'])
+ outputs = session.run(None, input_data)
+
+ optypes = ['ReduceMean', 'ReduceL1', 'ReduceL2', 'ReduceLogSum', 'ReduceLogSumExp', 'ReduceMax', 'ReduceProd', \
+ 'ReduceSum', 'ReduceSumSquare']
+ for optype in optypes:
+ inps = [['input1', TensorProto.FLOAT, (1,2)]]
+ outs = [['output', TensorProto.FLOAT, (1,1)]]
+ node_infos = [['test', ['input1'], ['output'], optype]]
+ model = self.build_model(inps, outs, [], node_infos)
+ input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
+ convert_model = self.get_bf16_mixed_precision_model(model)
+ self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
+ self.assertTrue(16 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
+ session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['DnnlExecutionProvider'])
+ outputs = session.run(None, input_data)
+
+ optypes = ['Gelu']
+ for optype in optypes:
+ inps = [['input1', TensorProto.FLOAT, (1,2)]]
+ outs = [['output', TensorProto.FLOAT, (1,2)]]
+ node_infos = [['test', ['input1'], ['output'], optype, 'com.microsoft']]
+ model = self.build_model(inps, outs, [], node_infos)
+ input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
+ convert_model = self.get_bf16_mixed_precision_model(model)
+ self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
+ self.assertTrue(16 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
+ session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['DnnlExecutionProvider'])
+ outputs = session.run(None, input_data)
+
+ optypes = ['BiasGelu', 'FastGelu']
+ for optype in optypes:
+ inps = [['input1', TensorProto.FLOAT, [2]]]
+ outs = [['output', TensorProto.FLOAT, [2]]]
+ weights = [['input2', TensorProto.FLOAT, [2], np.random.random((2))]]
+ node_infos = [['test', ['input1', 'input2'], ['output'], optype, 'com.microsoft']]
+ model = self.build_model(inps, outs, weights, node_infos)
+ input_data = self.build_test_data(['input1'], [(2,)], ['float32'])
+ convert_model = self.get_bf16_mixed_precision_model(model)
+ self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
+ self.assertTrue(16 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
+ session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['DnnlExecutionProvider'])
+ outputs = session.run(None, input_data)
+
+ optypes = ['MatMul']
+ for optype in optypes:
+ inps = [['input1', TensorProto.FLOAT, (1,2)]]
+ outs = [['output', TensorProto.FLOAT, (1,1)]]
+ weights = [['input2', TensorProto.FLOAT, (2,1), np.random.random((2))]]
+ node_infos = [['test', ['input1', 'input2'], ['output'], optype]]
+ model = self.build_model(inps, outs, weights, node_infos)
+ input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
+ convert_model = self.get_bf16_mixed_precision_model(model)
+ self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
+ self.assertTrue(16 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
+ session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['DnnlExecutionProvider'])
+ outputs = session.run(None, input_data)
+
+ optypes = ['FusedMatMul']
+ for optype in optypes:
+ inps = [['input1', TensorProto.FLOAT, (1,2)]]
+ outs = [['output', TensorProto.FLOAT, (1,1)]]
+ weights = [['input2', TensorProto.FLOAT, (2,1), np.random.random((2))]]
+ node_infos = [['test', ['input1', 'input2'], ['output'], optype, 'com.microsoft']]
+ model = self.build_model(inps, outs, weights, node_infos)
+ ort.InferenceSession(model.SerializeToString())
+ input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
+ convert_model = self.get_bf16_mixed_precision_model(model)
+ self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
+ self.assertTrue(16 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
+ session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['DnnlExecutionProvider'])
+ outputs = session.run(None, input_data)
+
+ optypes = ['Gemm']
+ for optype in optypes:
+ inps = [['input1', TensorProto.FLOAT, (1,2)]]
+ outs = [['output', TensorProto.FLOAT, (1,2)]]
+ weights = [['input2', TensorProto.FLOAT, (2,1), np.random.random((2))],
+ ['input3', TensorProto.FLOAT, [], np.random.random((1))]]
+ node_infos = [['test', ['input1', 'input2', 'input3'], ['output'], optype]]
+ model = self.build_model(inps, outs, weights, node_infos)
+ input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
+ convert_model = self.get_bf16_mixed_precision_model(model)
+ self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
+ self.assertTrue(16 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
+ session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['DnnlExecutionProvider'])
+ outputs = session.run(None, input_data)
+
+ optypes = ['LayerNormalization']
+ for optype in optypes:
+ inps = [['input1', TensorProto.FLOAT, (1,2)]]
+ outs = [['output1', TensorProto.FLOAT, (1,2)], ['output2', TensorProto.FLOAT, (1,2)], ['output3', TensorProto.FLOAT, (1,2)]]
+ weights = [['input2', TensorProto.FLOAT, (2,1), np.random.random((2))],
+ ['input3', TensorProto.FLOAT, (2,1), np.random.random((2))]]
+ node_infos = [['test', ['input1', 'input2', 'input3'], ['output1', 'output2', 'output3'], optype]]
+ model = self.build_model(inps, outs, weights, node_infos)
+ input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
+ convert_model = self.get_bf16_mixed_precision_model(model)
+ self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
+ self.assertTrue(16 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
+ session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['DnnlExecutionProvider'])
+ outputs = session.run(None, input_data)
+
+ optypes = ['BatchNormalization']
+ for optype in optypes:
+ inps = [['input1', TensorProto.FLOAT, [1, 2]]]
+ outs = [['output1', TensorProto.FLOAT, [1, 2]]]
+ weights = [['input2', TensorProto.FLOAT, [2], np.random.random((2))],
+ ['input3', TensorProto.FLOAT, [2], np.random.random((2))],
+ ['input4', TensorProto.FLOAT, [2], np.random.random((2))],
+ ['input5', TensorProto.FLOAT, [2], np.random.random((2))],]
+ node_infos = [['test', ['input1', 'input2', 'input3', 'input4', 'input5'], ['output1'], optype]]
+ model = self.build_model(inps, outs, weights, node_infos)
+ ort.InferenceSession(model.SerializeToString())
+ input_data = self.build_test_data(['input1'], [(1,2)], ['float32'])
+ convert_model = self.get_bf16_mixed_precision_model(model)
+ self.assertTrue('Cast' in set([i.op_type for i in convert_model.nodes()]))
+ self.assertTrue(16 in set([i.attribute[0].i for i in convert_model.nodes() if i.op_type == 'Cast']))
+ session = ort.InferenceSession(convert_model.model.SerializeToString(), providers=['DnnlExecutionProvider'])
+ outputs = session.run(None, input_data)
+
if __name__ == "__main__":
unittest.main()