diff --git a/.azure-pipelines/scripts/models/env_setup.sh b/.azure-pipelines/scripts/models/env_setup.sh
index 99ef12b6355..0755afa8130 100644
--- a/.azure-pipelines/scripts/models/env_setup.sh
+++ b/.azure-pipelines/scripts/models/env_setup.sh
@@ -78,7 +78,7 @@ if [[ "${inc_new_api}" == "false" ]]; then
fi
cd ${model_src_dir}
-pip install ruamel_yaml
+pip install ruamel.yaml==0.17.40
pip install psutil
pip install protobuf==4.23.4
if [[ "${framework}" == "tensorflow" ]]; then
diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md
index b26c5194aa3..41a02e0d460 100644
--- a/docs/source/quantization_weight_only.md
+++ b/docs/source/quantization_weight_only.md
@@ -129,6 +129,36 @@ torch.save(compressed_model.state_dict(), "compressed_model.pt")
The saved_results folder contains two files: `best_model.pt` and `qconfig.json`, and the generated q_model is a fake quantized model.
+
+### **WOQ algorithms tuning**
+
+To find the best algorithm, users can omit specifying a particular algorithm. Compared with setting a specific algorithm, this tuning process traverses a set of pre-defined WOQ configurations and identifies the one yielding the best result. For usage details, please refer to the [tuning strategy](./tuning_strategies.md#Basic).
+
+> **Note:** Currently, this behavior is specific to the `ONNX Runtime` backend.
+
+**Pre-defined configurations**
+
+| WOQ configurations | setting |
+|:------------------:|:-------:|
+|RTN_G32ASYM| {"algorithm": "RTN", "group_size": 32, "scheme": "asym"}|
+|GPTQ_G32ASYM| {"algorithm": "GPTQ", "group_size": 32, "scheme": "asym"}|
+|GPTQ_G32ASYM_DISABLE_LAST_MATMUL| {"algorithm": "GPTQ", "group_size": 32, "scheme": "asym"} & disable last MatMul|
+|GPTQ_G128ASYM| {"algorithm": "GPTQ", "group_size": 128, "scheme": "asym"}|
+|AWQ_G32ASYM| {"algorithm": "AWQ", "group_size": 32, "scheme": "asym"}|
+
+**User code example**
+
+```python
+from neural_compressor import PostTrainingQuantConfig, quantization
+
+conf = PostTrainingQuantConfig(
+ approach="weight_only",
+ quant_level="auto", # quant_level supports "auto" or 1 for woq config tuning
+)
+q_model = quantization.fit(model, conf, eval_func=eval_func, calib_dataloader=dataloader)
+q_model.save("saved_results")
+```
+
+Refer to this [link](../../examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/weight_only) for an example of WOQ algorithms tuning on ONNX Llama models.
+
## Layer Wise Quantization
Large language models (LLMs) have shown exceptional performance across various tasks; meanwhile, their substantial parameter size poses significant challenges for deployment. Layer-wise quantization (LWQ) can greatly reduce the memory footprint of LLMs, usually by 80-90%, which means that users can quantize LLMs even on a single node using a GPU or CPU. It also allows quantizing the model on memory-constrained devices, making huge-sized LLM quantization possible.
@@ -143,22 +173,19 @@ Large language models (LLMs) have shown exceptional performance across various t
|:--------------:|:----------:|
| RTN | ✔ |
| AWQ | ✕ |
-| GPTQ | ✕ |
+| GPTQ | ✔ |
| TEQ | ✕ |
### Example
```python
from neural_compressor import PostTrainingQuantConfig, quantization
-from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell
+from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model
-fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True)
+fp32_model = load_empty_model(model_name_or_path, torchscript=True)
conf = PostTrainingQuantConfig(
approach="weight_only",
recipes={
"layer_wise_quant": True,
- "layer_wise_quant_args": {
- "model_path": "facebook/opt-125m",
- },
"rtn_args": {"enable_full_range": True},
},
)
@@ -171,6 +198,7 @@ q_model = quantization.fit(
)
output_dir = "./saved_model"
q_model.save(output_dir)
+
+from neural_compressor.utils.pytorch import load
+
+q_model = load(output_dir, fp32_model, weight_only=True, layer_wise=True)
```
## Reference
diff --git a/docs/source/tuning_strategies.md b/docs/source/tuning_strategies.md
index 31062a3e319..7788149f7e3 100644
--- a/docs/source/tuning_strategies.md
+++ b/docs/source/tuning_strategies.md
@@ -181,6 +181,8 @@ flowchart TD
> For [smooth quantization](./smooth_quant.md), users can tune the smooth quantization alpha by providing a list of scalars for the `alpha` item. The tuning process will take place at the **start stage** of the tuning procedure. For usage details, please refer to the [smooth quantization example](./smooth_quant.md#Example).
+> For [weight-only quantization](./quantization_weight_only.md), users can tune the weight-only algorithms from the available [pre-defined configurations](./quantization_weight_only.md#woq-algorithms-tuning). The tuning process will take place at the **start stage** of the tuning procedure, preceding the smooth quantization alpha tuning. For usage details, please refer to the [weight-only quantization example](./quantization_weight_only.md#woq-algorithms-tuning).
+*Please note that this behavior is specific to the `ONNX Runtime` backend.*
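+
+A minimal sketch of enabling this tuning (mirroring the user code example in the weight-only quantization doc; `model`, `eval_func`, and `dataloader` are placeholders to be supplied by the user):
+
+```python
+from neural_compressor import PostTrainingQuantConfig, quantization
+
+conf = PostTrainingQuantConfig(
+    approach="weight_only",
+    quant_level="auto",  # "auto" or 1 triggers WOQ config tuning
+)
+q_model = quantization.fit(model, conf, eval_func=eval_func, calib_dataloader=dataloader)
+```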
**1.** Default quantization
diff --git a/examples/.config/model_params_onnxrt.json b/examples/.config/model_params_onnxrt.json
index 8f84567f752..d0cb53b7fd7 100644
--- a/examples/.config/model_params_onnxrt.json
+++ b/examples/.config/model_params_onnxrt.json
@@ -322,6 +322,13 @@
"main_script": "main.py",
"batch_size": 1
},
+ "beit": {
+ "model_src_dir": "image_recognition/beit/quantization/ptq_static",
+ "dataset_location": "/tf_dataset/pytorch/ImageNet/raw",
+ "input_model": "/tf_dataset2/models/onnx/beit/beit_base_patch16_224_pt22k_ft22kto1k.onnx",
+ "main_script": "main.py",
+ "batch_size": 1
+ },
"mobilebert_squad_mlperf_qdq": {
"model_src_dir": "nlp/onnx_model_zoo/mobilebert/quantization/ptq_static",
"dataset_location": "/tf_dataset2/datasets/squad",
diff --git a/examples/README.md b/examples/README.md
index e8c9ea98b08..946947ebd40 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1133,6 +1133,12 @@ Intel® Neural Compressor validated examples with multiple compression technique
    <td>Post-Training Static Quantization</td>
    <td>qlinearops</td>
  </tr>
+  <tr>
+    <td>BEiT</td>
+    <td>Image Recognition</td>
+    <td>Post-Training Static Quantization</td>
+    <td>qlinearops</td>
+  </tr>
  <tr>
    <td>CodeBert</td>
    <td>Natural Language Processing</td>
diff --git a/examples/notebook/onnxruntime/Quick_Started_Notebook_of_INC_for_ONNXRuntime.ipynb b/examples/notebook/onnxruntime/Quick_Started_Notebook_of_INC_for_ONNXRuntime.ipynb
index c3e0aa5f027..e5d06c2be50 100644
--- a/examples/notebook/onnxruntime/Quick_Started_Notebook_of_INC_for_ONNXRuntime.ipynb
+++ b/examples/notebook/onnxruntime/Quick_Started_Notebook_of_INC_for_ONNXRuntime.ipynb
@@ -47,13 +47,14 @@
"outputs": [],
"source": [
"# install neural-compressor from source\n",
+ "import sys\n",
"!git clone https://github.com/intel/neural-compressor.git\n",
"%cd ./neural-compressor\n",
- "!pip install -r requirements.txt\n",
- "!python setup.py install\n",
+ "!{sys.executable} -m pip install -r requirements.txt\n",
+ "!{sys.executable} setup.py install\n",
"%cd ..\n",
"# or install stable basic version from pypi\n",
- "# pip install neural-compressor"
+ "# pip install neural-compressor\n"
]
},
{
@@ -65,10 +66,8 @@
},
"outputs": [],
"source": [
- "# install onnx related packages\n",
- "!pip install onnx onnxruntime onnxruntime-extensions\n",
- "# install other packages used in this notebook.\n",
- "!pip install torch transformers accelerate coloredlogs sympy numpy sentencepiece protobuf optimum"
+ "# install required packages\n",
+    "!{sys.executable} -m pip install -r requirements.txt\n"
]
},
{
@@ -168,7 +167,7 @@
"source": [
"!export GLUE_DIR=./glue_data\n",
"!wget https://raw.githubusercontent.com/Shimao-Zhang/Download_GLUE_Data/master/download_glue_data.py\n",
- "!python download_glue_data.py --data_dir=GLUE_DIR --tasks=SST"
+ "!{sys.executable} download_glue_data.py --data_dir=GLUE_DIR --tasks=SST\n"
]
},
{
@@ -193,7 +192,7 @@
"int8_model_path = \"onnx-model/int8-model.onnx\"\n",
"data_path = \"./GLUE_DIR/SST-2\"\n",
"task = \"sst-2\"\n",
- "batch_size = 8"
+ "batch_size = 8\n"
]
},
{
@@ -343,7 +342,7 @@
" label=label\n",
" )\n",
" features.append(feats)\n",
- " return features"
+ " return features\n"
]
},
{
@@ -377,7 +376,7 @@
" model_name_or_path=model_name_or_path,\n",
" model_type=\"distilbert\",\n",
" task=task)\n",
- "dataloader = DataLoader(framework=\"onnxruntime\", dataset=dataset, batch_size=batch_size)"
+ "dataloader = DataLoader(framework=\"onnxruntime\", dataset=dataset, batch_size=batch_size)\n"
]
},
{
@@ -448,7 +447,7 @@
" elif output_mode == \"regression\":\n",
" processed_preds = np.squeeze(self.pred_list)\n",
" result = transformers.glue_compute_metrics(self.task, processed_preds, self.label_list)\n",
- " return result[self.return_key[self.task]]"
+ " return result[self.return_key[self.task]]\n"
]
},
{
@@ -486,7 +485,7 @@
" ort_inputs.update({inputs_names[i]: inputs[i]})\n",
" predictions = session.run(None, ort_inputs)\n",
" metric.update(predictions[0], labels)\n",
- " return metric.result()"
+ " return metric.result()\n"
]
},
{
@@ -567,7 +566,7 @@
" num_heads=num_heads,\n",
" hidden_size=hidden_size,\n",
" optimization_options=opt_options)\n",
- "model = model_optimizer.model"
+ "model = model_optimizer.model\n"
]
},
{
@@ -722,7 +721,7 @@
" config,\n",
" eval_func=eval_func,\n",
" calib_dataloader=dataloader)\n",
- "q_model.save(int8_model_path)"
+ "q_model.save(int8_model_path)\n"
]
},
{
diff --git a/examples/notebook/onnxruntime/requirements.txt b/examples/notebook/onnxruntime/requirements.txt
new file mode 100644
index 00000000000..13f6ef9a8e1
--- /dev/null
+++ b/examples/notebook/onnxruntime/requirements.txt
@@ -0,0 +1,12 @@
+onnx
+onnxruntime
+onnxruntime-extensions
+torch
+transformers
+accelerate
+coloredlogs
+sympy
+numpy
+sentencepiece
+protobuf
+optimum
diff --git a/examples/notebook/pytorch/Quick_Started_Notebook_of_INC_for_Pytorch.ipynb b/examples/notebook/pytorch/Quick_Started_Notebook_of_INC_for_Pytorch.ipynb
index 0a7a277ee96..62d38981668 100644
--- a/examples/notebook/pytorch/Quick_Started_Notebook_of_INC_for_Pytorch.ipynb
+++ b/examples/notebook/pytorch/Quick_Started_Notebook_of_INC_for_Pytorch.ipynb
@@ -45,14 +45,15 @@
"outputs": [],
"source": [
"# install neural-compressor from source\n",
+ "import sys\n",
"!git clone https://github.com/intel/neural-compressor.git\n",
"%cd ./neural-compressor\n",
- "!pip install -r requirements.txt\n",
- "!python setup.py install\n",
+ "!{sys.executable} -m pip install -r requirements.txt\n",
+ "!{sys.executable} setup.py install\n",
"%cd ..\n",
"\n",
"# or install stable basic version from pypi\n",
- "!pip install neural-compressor"
+ "!{sys.executable} -m pip install neural-compressor\n"
]
},
{
@@ -62,7 +63,7 @@
"outputs": [],
"source": [
"# install other packages used in this notebook.\n",
- "!pip install torch>=1.9.0 transformers>=4.16.0 accelerate sympy numpy sentencepiece!=0.1.92 protobuf<=3.20.3 datasets>=1.1.3 scipy scikit-learn Keras-Preprocessing"
+ "!{sys.executable} -m pip install -r requirements.txt\n"
]
},
{
@@ -303,10 +304,10 @@
"outputs": [],
"source": [
"# fp32 benchmark\n",
- "!python benchmark.py --input_model ./pytorch_model.bin 2>&1|tee fp32_benchmark.log\n",
+ "!{sys.executable} benchmark.py --input_model ./pytorch_model.bin 2>&1|tee fp32_benchmark.log\n",
"\n",
"# int8 benchmark\n",
- "!python benchmark.py --input_model ./saved_results/best_model.pt 2>&1|tee int8_benchmark.log\n"
+ "!{sys.executable} benchmark.py --input_model ./saved_results/best_model.pt 2>&1|tee int8_benchmark.log\n"
]
}
],
diff --git a/examples/notebook/pytorch/requirements.txt b/examples/notebook/pytorch/requirements.txt
new file mode 100644
index 00000000000..aa1af71d2b3
--- /dev/null
+++ b/examples/notebook/pytorch/requirements.txt
@@ -0,0 +1,11 @@
+torch>=1.9.0
+transformers>=4.16.0
+accelerate
+sympy
+numpy
+sentencepiece!=0.1.92
+protobuf<=3.20.3
+datasets>=1.1.3
+scipy
+scikit-learn
+Keras-Preprocessing
diff --git a/examples/notebook/tensorflow/resnet/requirements.txt b/examples/notebook/tensorflow/resnet/requirements.txt
new file mode 100644
index 00000000000..a2d431e0ace
--- /dev/null
+++ b/examples/notebook/tensorflow/resnet/requirements.txt
@@ -0,0 +1,8 @@
+numpy
+neural-compressor
+tensorflow
+datasets
+requests
+urllib3
+pyOpenSSL
+git+https://github.com/huggingface/huggingface_hub
diff --git a/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb b/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb
index 8f1c36ef3d8..f2b168949f8 100644
--- a/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb
+++ b/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb
@@ -29,12 +29,11 @@
"metadata": {},
"outputs": [],
"source": [
+ "import sys\n",
"!conda install python==3.10 -y\n",
- "!pip install neural-compressor\n",
- "!wget -nc https://storage.googleapis.com/intel-optimized-tensorflow/models/v1_6/resnet50_fp32_pretrained_model.pb\n",
- "!pip install tensorflow\n",
- "!pip install datasets\n",
- "!pip install git+https://github.com/huggingface/huggingface_hub"
+ "!{sys.executable} -m pip install -r requirements.txt \n",
+ "\n",
+ "!wget -nc https://storage.googleapis.com/intel-optimized-tensorflow/models/v1_6/resnet50_fp32_pretrained_model.pb\n"
]
},
{
@@ -43,9 +42,11 @@
"metadata": {},
"outputs": [],
"source": [
+ "print(sys.executable)\n",
+ "!{sys.executable} -m pip list\n",
"import tensorflow as tf\n",
"import numpy as np\n",
- "import datasets"
+ "import datasets\n"
]
},
{
@@ -63,8 +64,8 @@
"source": [
"# login to huggingface to download the imagenet-1k dataset\n",
"# you should replace this read-only token with your own by create one on (https://huggingface.co/settings/tokens)\n",
- "# !huggingface-cli login --token \n",
- "!huggingface-cli login --token hf_xxxxxxxxxxxxxxxxxxxxxx"
+ "from huggingface_hub.hf_api import HfFolder\n",
+ "HfFolder.save_token('hf_xxxxxxxxxxxxxxxxxxxxxx')\n"
]
},
{
@@ -75,8 +76,8 @@
"source": [
"from datasets import load_dataset\n",
"# load dataset in streaming way will get an IterableDatset\n",
- "calib_dataset = load_dataset('imagenet-1k', split='train', streaming=True, use_auth_token=True)\n",
- "eval_dataset = load_dataset('imagenet-1k', split='validation', streaming=True, use_auth_token=True)"
+ "calib_dataset = load_dataset('imagenet-1k', split='train', streaming=True, token=True)\n",
+ "eval_dataset = load_dataset('imagenet-1k', split='validation', streaming=True, token=True)\n"
]
},
{
@@ -97,7 +98,7 @@
" return datasets.Dataset.from_dict(data)\n",
"\n",
"sub_calib_dataset = sample_data(calib_dataset, MAX_SAMPLE_LENGTG)\n",
- "sub_eval_dataset = sample_data(eval_dataset, MAX_SAMPLE_LENGTG)"
+ "sub_eval_dataset = sample_data(eval_dataset, MAX_SAMPLE_LENGTG)\n"
]
},
{
@@ -136,7 +137,7 @@
" batch_inputs = []\n",
" labels = []\n",
" def __len__(self):\n",
- " return self.length"
+ " return self.length\n"
]
},
{
@@ -146,7 +147,7 @@
"outputs": [],
"source": [
"calib_dataloader = CustomDataloader(dataset=sub_calib_dataset, batch_size=32)\n",
- "eval_dataloader = CustomDataloader(dataset=sub_eval_dataset, batch_size=32)"
+ "eval_dataloader = CustomDataloader(dataset=sub_eval_dataset, batch_size=32)\n"
]
},
{
@@ -193,7 +194,7 @@
" return acc\n",
"\n",
"q_model = quantization.fit(\"./resnet50_fp32_pretrained_model.pb\", conf=conf, calib_dataloader=calib_dataloader, eval_func=eval_func)\n",
- "q_model.save(\"resnet50_int8.pb\")"
+ "q_model.save(\"resnet50_int8.pb\")\n"
]
},
{
@@ -221,7 +222,7 @@
"metadata": {},
"outputs": [],
"source": [
- "!python resnet_benchmark.py --input_model resnet50_fp32_pretrained_model.pb 2>&1|tee fp32_benchmark.log"
+ "!{sys.executable} resnet_benchmark.py --input_model resnet50_fp32_pretrained_model.pb 2>&1|tee fp32_benchmark.log\n"
]
},
{
@@ -237,7 +238,7 @@
"metadata": {},
"outputs": [],
"source": [
- "!python resnet_benchmark.py --input_model resnet50_int8.pb 2>&1|tee int8_benchmark.log"
+ "!{sys.executable} resnet_benchmark.py --input_model resnet50_int8.pb 2>&1|tee int8_benchmark.log\n"
]
},
{
diff --git a/examples/notebook/tensorflow/vgg19_ibean/requirements.txt b/examples/notebook/tensorflow/vgg19_ibean/requirements.txt
new file mode 100644
index 00000000000..e866fcd37f6
--- /dev/null
+++ b/examples/notebook/tensorflow/vgg19_ibean/requirements.txt
@@ -0,0 +1,5 @@
+numpy
+matplotlib
+tensorflow
+tensorflow-hub
+tensorflow-datasets
diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/README.md b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/README.md
new file mode 100644
index 00000000000..f6775301e62
--- /dev/null
+++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/README.md
@@ -0,0 +1,55 @@
+Step-by-Step
+============
+
+This example loads the [BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) (BEiT) model and confirms its accuracy and performance based on the [ImageNet-1k dataset](http://www.image-net.org/). You need to download this dataset yourself.
+
+In this example, the BEiT model is pre-trained in a self-supervised fashion on ImageNet-22k - also called ImageNet-21k (14 million images, 21,841 classes) at resolution 224x224, and fine-tuned on the same dataset and subsequently on ImageNet-1k, both at resolution 224x224. It was first released in [this repository](https://github.com/microsoft/unilm/tree/master/beit).
+
+
+# Prerequisite
+
+## 1. Environment
+```shell
+pip install neural-compressor
+pip install -r requirements.txt
+```
+> Note: Validated ONNX Runtime [Version](/docs/source/installation_guide.md#validated-software-environment).
+
+## 2. Prepare Model
+
+Prepare the `beit_base_patch16_224` model and export it to ONNX:
+
+```shell
+python prepare_model.py --input_model=beit_base_patch16_224 --output_model=beit_base_patch16_224_pt22k_ft22kto1k.onnx
+```
+
+## 3. Prepare Dataset
+
+Download and extract [ImageNet-1k](http://www.image-net.org/) to the directory /path/to/imagenet. The directory should contain the folders below:
+
+```bash
+ls /path/to/imagenet
+train val
+```
+
+# Run
+
+## 1. Quantization
+
+Quantize model with QLinearOps:
+
+```bash
+bash run_quant.sh --input_model=/path/to/model \ # model path as *.onnx
+ --dataset_location=/path/to/imagenet \
+ --output_model=/path/to/save \
+ --quant_format="QOperator"
+```
+
+## 2. Benchmark
+
+```bash
+bash run_benchmark.sh --input_model=/path/to/model \ # model path as *.onnx
+ --dataset_location=/path/to/imagenet \
+ --batch_size=batch_size \
+ --mode=performance # or accuracy
+```
\ No newline at end of file
diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/beit_modeling_finetune.py b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/beit_modeling_finetune.py
new file mode 100644
index 00000000000..7e9cd599893
--- /dev/null
+++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/beit_modeling_finetune.py
@@ -0,0 +1,420 @@
+# --------------------------------------------------------
+# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
+# Github source: https://github.com/microsoft/unilm/tree/master/beit
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# By Hangbo Bao
+# Based on timm and DeiT code bases
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# https://github.com/facebookresearch/deit/
+# https://github.com/facebookresearch/dino
+# --------------------------------------------------------'
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic',
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+ **kwargs
+ }
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ # x = self.drop(x)
+        # dropout is commented out above to follow the original BERT implementation
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., window_size=None, attn_head_dim=None):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = \
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, rel_pos_bias=None):
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = \
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ window_size=None, attn_head_dim=None):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if init_values is not None and init_values > 0:
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x, rel_pos_bias=None):
+ if self.gamma_1 is None:
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = \
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+ def forward(self):
+ relative_position_bias = \
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
+ use_mean_pooling=True, init_scale=0.001, pretrained_cfg=None, pretrained_cfg_overlay=None):
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ if use_abs_pos_emb:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ else:
+ self.pos_embed = None
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ if use_shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+ else:
+ self.rel_pos_bias = None
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.use_rel_pos_bias = use_rel_pos_bias
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+ for i in range(depth)])
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ # trunc_normal_(self.mask_token, std=.02)
+ if isinstance(self.head, nn.Linear):
+ trunc_normal_(self.head.weight, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ if isinstance(self.head, nn.Linear):
+ self.head.weight.data.mul_(init_scale)
+ self.head.bias.data.mul_(init_scale)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ x = blk(x, rel_pos_bias=rel_pos_bias)
+
+ x = self.norm(x)
+ if self.fc_norm is not None:
+ t = x[:, 1:, :]
+ return self.fc_norm(t.mean(1))
+ else:
+ return x[:, 0]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+ def get_intermediate_layers(self, x):
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ features = []
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ x = blk(x, rel_pos_bias)
+ features.append(x)
+
+ return features
+
+
+@register_model
+def beit_base_patch16_224(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ return model
+
+
+@register_model
+def beit_base_patch16_384(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ return model
+
+
+@register_model
+def beit_large_patch16_224(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ return model
+
+
+@register_model
+def beit_large_patch16_384(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ return model
+
+
+@register_model
+def beit_large_patch16_512(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.default_cfg = _cfg()
+ return model
diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/main.py b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/main.py
new file mode 100644
index 00000000000..b7d9bc0eab0
--- /dev/null
+++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/main.py
@@ -0,0 +1,163 @@
+import os
+import tqdm
+import onnx
+import torch
+import logging
+import argparse
+import onnxruntime as ort
+from timm.utils import accuracy
+from torchvision import datasets, transforms
+from timm.data.constants import \
+ IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+ datefmt = '%m/%d/%Y %H:%M:%S',
+ level = logging.WARN)
+
+def build_eval_transform(input_size=224, imagenet_default_mean_and_std=False, crop_pct=None):
+ resize_im = input_size > 32
+ imagenet_default_mean_and_std = imagenet_default_mean_and_std
+ mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
+ std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
+
+ t = []
+ if resize_im:
+ if crop_pct is None:
+ if input_size < 384:
+ crop_pct = 224 / 256
+ else:
+ crop_pct = 1.0
+ size = int(input_size / crop_pct)
+ t.append(
+ transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
+ )
+ t.append(transforms.CenterCrop(input_size))
+
+ t.append(transforms.ToTensor())
+ t.append(transforms.Normalize(mean, std))
+ return transforms.Compose(t)
+
+def build_val_dataset(data_path):
+ transform = build_eval_transform()
+ root = os.path.join(data_path, 'val')
+ dataset = datasets.ImageFolder(root, transform=transform)
+ return dataset
+
+
+def evaluate_func(data_loader, model):
+ session = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
+ top1, top5 = 0, 0
+
+ for idx, batch in tqdm.tqdm(enumerate(data_loader), desc='eval'):
+ images = batch[0].cpu().detach().numpy()
+ target = batch[-1]
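+        # 'image' is the input name given to torch.onnx.export in prepare_model.py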
+ output = session.run(None, {'image': images})[0]
+ acc1, acc5 = accuracy(torch.from_numpy(output), target, topk=(1, 5))
+ top1 += acc1.cpu().detach().numpy()
+ top5 += acc5.cpu().detach().numpy()
+
+ top1 = top1 / len(data_loader)
+ top5 = top5 / len(data_loader)
+ print('* Acc@1 {:.3f} Acc@5 {:.3f}'.format(top1, top5))
+ return top1
+
+if __name__ == '__main__':
+ logger.info("Evaluating ONNXRuntime full precision accuracy and performance:")
+ parser = argparse.ArgumentParser(
+ description="BEiT quantization examples.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument(
+ '--model_path',
+ type=str,
+        help="Path to the pre-trained ONNX model file"
+ )
+ parser.add_argument(
+ '--dataset_location',
+ type=str,
+ help="Imagenet data path"
+ )
+ parser.add_argument(
+ '--benchmark',
+ action='store_true', \
+ default=False,
+        help="whether to benchmark the model"
+ )
+ parser.add_argument(
+ '--tune',
+ action='store_true', \
+ default=False,
+        help="whether to quantize the model"
+ )
+ parser.add_argument(
+ '--output_model',
+ type=str,
+ help="output model path"
+ )
+ parser.add_argument(
+ '--quant_format',
+ type=str,
+ default='default',
+ choices=['default', 'QDQ', 'QOperator'],
+ help="quantization format"
+ )
+ parser.add_argument(
+ '--mode',
+ type=str,
+ help="benchmark mode of performance or accuracy"
+ )
+ parser.add_argument(
+ "--batch_size",
+ default=64,
+ type=int,
+ )
+ parser.add_argument(
+ "--num_workers",
+ default=10,
+ type=int,
+ )
+ args = parser.parse_args()
+
+ val_dataset = build_val_dataset(args.dataset_location)
+ val_sampler = torch.utils.data.SequentialSampler(val_dataset)
+ val_data_loader = torch.utils.data.DataLoader(
+ val_dataset, sampler=val_sampler,
+ batch_size=int(1.5 * args.batch_size),
+ num_workers=args.num_workers,
+ drop_last=False
+ )
+
+ def eval(model):
+ return evaluate_func(val_data_loader, model)
+
+ model = onnx.load(args.model_path)
+
+ if args.tune:
+ from neural_compressor import PostTrainingQuantConfig, quantization
+ from neural_compressor.utils.constant import FP32
+
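+        # keep Conv ops and each block's fc2 MatMul in FP32, and skip output quantization for MatMul ops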
+ config = PostTrainingQuantConfig(approach="static",
+ quant_format=args.quant_format,
+ op_type_dict={'Conv': FP32},
+ op_name_dict={'/blocks.*/mlp/fc2/MatMul': FP32},
+ recipes={'optypes_to_exclude_output_quant': ['MatMul']},
+ )
+ q_model = quantization.fit(model,
+ config,
+ calib_dataloader=val_data_loader,
+ eval_func=eval)
+ q_model.save(args.output_model)
+
+ if args.benchmark:
+ if args.mode == 'performance':
+ from neural_compressor.benchmark import fit
+ from neural_compressor.config import BenchmarkConfig
+ conf = BenchmarkConfig(warmup=10, iteration=1000, cores_per_instance=4, num_of_instance=1)
+ fit(model, conf, b_dataloader=val_data_loader)
+ elif args.mode == 'accuracy':
+ acc_result = evaluate_func(val_data_loader, model)
+ print("Batch size = %d" % val_data_loader.batch_size)
+ print("Accuracy: %.5f" % acc_result)
+
+
diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/prepare_model.py b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/prepare_model.py
new file mode 100644
index 00000000000..683e115c38d
--- /dev/null
+++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/prepare_model.py
@@ -0,0 +1,102 @@
+import argparse
+import os
+import sys
+import torch
+from urllib import request
+from timm.models import create_model
+import beit_modeling_finetune
+
+MODEL_URLS = {"beit_base_patch16_224": "https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D",}
+MODEL_FILES = {"beit_base_patch16_224": "beit_base_patch16_224_pt22k_ft22kto1k.pth"}
+MAX_TIMES_RETRY_DOWNLOAD = 5
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input_model", type=str, required=False, default="beit_base_patch16_224")
+ parser.add_argument("--output_model", type=str, required=True)
+ return parser.parse_args()
+
+
+def progressbar(cur, total=100):
+ percent = '{:.2%}'.format(cur / total)
+ sys.stdout.write("\r[%-100s] %s" % ('#' * int(cur), percent))
+ sys.stdout.flush()
+
+
+def schedule(blocknum, blocksize, totalsize):
+ if totalsize == 0:
+ percent = 0
+ else:
+ percent = min(1.0, blocknum * blocksize / totalsize) * 100
+ progressbar(percent)
+
+
+def download_model(input_model, retry_times=5):
+ model_url = MODEL_URLS[input_model]
+ model_file = MODEL_FILES[input_model]
+ if os.path.isfile(model_file):
+ print(f"{model_file} exists, skip download")
+ return True
+
+ print("download model...")
+ retries = 0
+ while retries < retry_times:
+ try:
+ request.urlretrieve(model_url, model_file, schedule)
+ break
+ except KeyboardInterrupt:
+ return False
+ except:
+ retries += 1
+ print(f"Download failed{', Retry downloading...' if retries < retry_times else '!'}")
+ return retries < retry_times
+
+
+def export_model(input_model, output_model):
+ print("\nexport model...")
+
+ model = create_model(
+ input_model,
+ pretrained=False,
+ num_classes=1000,
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ attn_drop_rate=0.0,
+ drop_block_rate=None,
+ use_mean_pooling=True,
+ init_scale=0.001,
+ use_rel_pos_bias=True,
+ use_abs_pos_emb=False,
+ init_values=0.1,
+ )
+
+ checkpoint = torch.load(MODEL_FILES[input_model], map_location='cpu')
+ model.load_state_dict(checkpoint['model'])
+ print("Resume checkpoint %s" % MODEL_FILES[input_model])
+
+ model.eval()
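+    # export to ONNX with a dummy 224x224 input; the batch dimension is declared dynamic via dynamic_axes below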
+ x = torch.randn(1, 3, 224, 224, requires_grad=True)
+ torch.onnx.export(model,
+ x,
+ output_model,
+ export_params=True,
+ opset_version=13,
+ do_constant_folding=True,
+ input_names = ["image"],
+ output_names = ["output"],
+ dynamic_axes={"image" : {0 : "batch_size"},
+ "output" : {0 : "batch_size"}}
+ )
+ assert os.path.exists(output_model), f"Export failed! {output_model} doesn't exist!"
+
+
+def prepare_model(input_model, output_model):
+    is_download_successful = download_model(input_model, MAX_TIMES_RETRY_DOWNLOAD)
+ if is_download_successful:
+ export_model(input_model, output_model)
+
+
+if __name__ == "__main__":
+ args = parse_arguments()
+ prepare_model(args.input_model, args.output_model)
\ No newline at end of file
diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/requirements.txt b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/requirements.txt
new file mode 100644
index 00000000000..b6855b21b0a
--- /dev/null
+++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/requirements.txt
@@ -0,0 +1,6 @@
+torch
+torchvision
+timm
+onnx
+onnxruntime
+onnxruntime-extensions; python_version < '3.11'
\ No newline at end of file
diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_benchmark.sh b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_benchmark.sh
new file mode 100644
index 00000000000..41c190229d1
--- /dev/null
+++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_benchmark.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+set -x
+
+function main {
+ init_params "$@"
+ run_benchmark
+
+}
+
+# init params
+function init_params {
+ for var in "$@"
+ do
+ case $var in
+ --input_model=*)
+ input_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --dataset_location=*)
+ dataset_location=$(echo $var |cut -f2 -d=)
+ ;;
+ --mode=*)
+ mode=$(echo $var |cut -f2 -d=)
+ ;;
+ --batch_size=*)
+ batch_size=$(echo $var |cut -f2 -d=)
+ ;;
+ esac
+ done
+
+}
+
+# run_benchmark
+function run_benchmark {
+
+ python main.py \
+ --model_path ${input_model} \
+ --dataset_location ${dataset_location} \
+ --mode ${mode} \
+ --batch_size ${batch_size-1} \
+ --benchmark
+
+}
+
+main "$@"
diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_quant.sh b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_quant.sh
new file mode 100644
index 00000000000..7f9d10fa0e7
--- /dev/null
+++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_quant.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+set -x
+
+function main {
+ init_params "$@"
+ run_tuning
+
+}
+
+# init params
+function init_params {
+
+ for var in "$@"
+ do
+ case $var in
+ --input_model=*)
+ input_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --output_model=*)
+ output_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --dataset_location=*)
+ dataset_location=$(echo $var |cut -f2 -d=)
+ ;;
+ --quant_format=*)
+ quant_format=$(echo $var |cut -f2 -d=)
+ ;;
+ esac
+ done
+
+}
+
+# run_tuning
+function run_tuning {
+ python main.py \
+ --model_path ${input_model} \
+ --dataset_location ${dataset_location} \
+ --output_model ${output_model} \
+ --quant_format ${quant_format-default} \
+ --tune
+}
+
+main "$@"
diff --git a/examples/onnxrt/nlp/huggingface_model/language_modeling/quantization/ptq_dynamic/main.py b/examples/onnxrt/nlp/huggingface_model/language_modeling/quantization/ptq_dynamic/main.py
index 345f6fac183..5182ec6bbc5 100644
--- a/examples/onnxrt/nlp/huggingface_model/language_modeling/quantization/ptq_dynamic/main.py
+++ b/examples/onnxrt/nlp/huggingface_model/language_modeling/quantization/ptq_dynamic/main.py
@@ -197,8 +197,7 @@ def main():
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path,
use_fast=True,
- cache_dir=args.cache_dir if args.cache_dir else None,
- use_auth_token='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')
+ cache_dir=args.cache_dir if args.cache_dir else None)
if args.block_size <= 0:
args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model
args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py
index 149a030af09..a13a99d483b 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py
@@ -58,6 +58,13 @@ def __init__(self, gptq_dataloader):
def __iter__(self):
pass
+def filter_chatglmv1(seq):
+ bos_token_id = 130004
+ # eos_token_id = 130005
+ gmask_token_id = 130001
+ # return (bos_token_id in seq and eos_token_id in seq and mask_token_id in seq and gmask_token_id in seq)
+ return (len(seq) < 2048 and bos_token_id in seq and gmask_token_id in seq)
+
# INC original dataloader example
class Evaluator:
def __init__(self, dataset, tokenizer, batch_size=8, pad_val=1, pad_max=196, is_calib=False):
@@ -217,9 +224,9 @@ def skip(*args, **kwargs):
# model
if re.search("chatglm", args.model_name_or_path.lower()): # chatglm requires a different way to be loaded
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
- model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+ model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True).float().cpu()
else:
- tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True, trust_remote_code=True)
model = model.eval()
@@ -227,12 +234,20 @@ def skip(*args, **kwargs):
# calib_dataset = datasets.load_from_disk('/your/local/pile-10k/') # use this if trouble with connecting to HF
calib_dataset = calib_dataset.shuffle(seed=args.seed)
calib_evaluator = Evaluator(calib_dataset, tokenizer, args.calib_size, is_calib=True)
- calib_dataloader = DataLoader(
- calib_evaluator.dataset,
- batch_size=args.calib_size,
- shuffle=False,
- collate_fn=calib_evaluator.collate_batch,
- )
+ if hasattr(model.config, "_name_or_path") and "chatglm-6b" in model.config._name_or_path:
+ calib_dataloader = DataLoader(
+ calib_evaluator.dataset.filter(lambda example: filter_chatglmv1(example['input_ids'])),
+ batch_size=args.calib_size,
+ shuffle=False,
+ collate_fn=calib_evaluator.collate_batch,
+ )
+ else:
+ calib_dataloader = DataLoader(
+ calib_evaluator.dataset,
+ batch_size=args.calib_size,
+ shuffle=False,
+ collate_fn=calib_evaluator.collate_batch,
+ )
if args.gpu and torch.cuda.is_available():
DEV = torch.device('cuda:0')
@@ -294,7 +309,8 @@ def skip(*args, **kwargs):
dataloader=calib_dataloader,
nsamples = args.nsamples,
use_max_length = args.use_max_length,
- pad_max_length = args.pad_max_length
+ pad_max_length = args.pad_max_length,
+ device = DEV,
)
results = lm_evaluate(
diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 571346b515d..04eac28aab9 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -2739,14 +2739,20 @@ def _ipex_post_quant_process(self, model, q_model, dataloader, inplace=False):
q_model._model = ipex.quantization.convert(model._model, inplace=inplace)
try:
if isinstance(self.example_inputs, dict):
- q_model._model = torch.jit.trace(q_model._model, example_kwarg_inputs=self.example_inputs)
+ q_model._model = torch.jit.trace(
+ q_model._model,
+ example_kwarg_inputs=self.example_inputs,
+ )
else:
q_model._model = torch.jit.trace(q_model._model, self.example_inputs)
q_model._model = torch.jit.freeze(q_model._model.eval())
except:
if isinstance(self.example_inputs, dict):
q_model._model = torch.jit.trace(
- q_model._model, example_kwarg_inputs=self.example_inputs, strict=False
+ q_model._model,
+ example_kwarg_inputs=self.example_inputs,
+ strict=False,
+ check_trace=False,
)
else:
q_model._model = torch.jit.trace(q_model._model, self.example_inputs, strict=False)
@@ -2763,7 +2769,7 @@ def _ipex_post_quant_process(self, model, q_model, dataloader, inplace=False):
except:
if isinstance(self.example_inputs, dict):
q_model._model = torch.jit.trace(
- q_model._model, example_kwarg_inputs=self.example_inputs, strict=False
+ q_model._model, example_kwarg_inputs=self.example_inputs, strict=False, check_trace=False
)
else:
q_model._model = torch.jit.trace(q_model._model, self.example_inputs, strict=False)
@@ -3475,13 +3481,13 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
):
from .torch_utils.layer_wise_quant import LayerWiseQuant
- model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
+ # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
+ model_path = model._model.path
smooth_quant = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant", False)
alpha = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant_alpha", 0.5)
- assert (
- model_path is not None
- ), "the layer_wise_quant_args should have args model_path to load the weight of model."
- device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu")
+ # device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu")
+ assert model_path is not None, "The model_path should not be None."
+ device = self.device
lw_quant = LayerWiseQuant(
q_model._model,
model_path,
@@ -4514,14 +4520,12 @@ def rtn_quantize(self, model, tune_cfg):
# for layer_wise quant mode
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
if recipe_cfgs.get("layer_wise_quant", False):
- from neural_compressor.config import options
-
- from .torch_utils.layer_wise_quant.utils import _get_path, load_module
+ from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, load_module
- lwq_workspace = os.path.join(options.workspace, "lwq_tmpdir")
- os.makedirs(lwq_workspace, exist_ok=True)
- model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
- assert model_path, "model_path should specify in layer_wise_quant_args."
+ os.makedirs(LWQ_WORKSPACE, exist_ok=True)
+ # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
+ model_path = model.path
+ assert model_path, "model_path should not be None."
model_path = _get_path(model_path)
for key, config in tune_cfg["op"].items():
@@ -4557,7 +4561,7 @@ def rtn_quantize(self, model, tune_cfg):
# save and clean weight
from .torch_utils.layer_wise_quant.utils import clean_module_weight
- torch.save(m.state_dict(), os.path.join(lwq_workspace, f"{op_name}.pt"))
+ torch.save(m.state_dict(), os.path.join(LWQ_WORKSPACE, f"{op_name}.pt"))
clean_module_weight(m)
set_module(model, op_name, m)
if recipe_cfgs.get("layer_wise_quant", False):
@@ -4592,6 +4596,23 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
...
}
"""
+ # for layer_wise quant mode
+ recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
+ model_path = None
+ layer_wise = False
+ if recipe_cfgs.get("layer_wise_quant", False):
+ layer_wise = True
+ from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, register_weight_hooks
+
+ os.makedirs(LWQ_WORKSPACE, exist_ok=True)
+ # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None)
+ model_path = model.path
+ assert model_path, "model_path should not be None."
+ model_path = _get_path(model_path)
+ lwq_handles = register_weight_hooks(
+ model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE
+ )
+
weight_config = {}
for key, config in tune_cfg["op"].items():
op_name, op_type = key
@@ -4616,7 +4637,15 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
)
# tune_cfg => weight_config
model, quantization_perm = gptq_quantize(
- model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, self.device
+ model,
+ weight_config,
+ dataloader,
+ nsamples,
+ use_max_length,
+ pad_max_length,
+ self.device,
+ layer_wise,
+ model_path,
)
return model, quantization_perm
diff --git a/neural_compressor/adaptor/torch_utils/gptq.py b/neural_compressor/adaptor/torch_utils/gptq.py
index b2710558e57..1a33addb364 100644
--- a/neural_compressor/adaptor/torch_utils/gptq.py
+++ b/neural_compressor/adaptor/torch_utils/gptq.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import gc
import math
import random
import re
@@ -72,7 +73,7 @@ def is_leaf(module):
return True if children_cnt == 0 else False
-def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList]):
+def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList, torch.nn.Sequential]):
"""Search transformer stacked structures, which is critical in LLMs and GPTQ execution.
Args:
@@ -88,21 +89,41 @@ def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList]):
"transformers": {}, Dict# TODO
}
"""
- gptq_related_blocks = {
- "embeddings": {},
- "transformers_pre": {}, # todo
- "transformers_name": "", # None
- "transformers": [], # None
- "transformers_post": {}, # todo
- }
- for n, m in module.named_modules():
- if type(m) in module_types:
- gptq_related_blocks["transformers_name"] = n
- gptq_related_blocks["transformers"] = m
- return gptq_related_blocks
- else:
+ if type(module).__name__ == "MixFormerSequentialForCausalLM": # pragma: no cover
+ gptq_related_blocks = {
+ "embeddings": {},
+ "transformers_pre": {}, # todo
+ "transformers_name": "", # None
+ "transformers": [], # None
+ "transformers_post": {}, # todo
+ }
+ for n, m in module.named_modules():
+ if type(m) in module_types:
+ gptq_related_blocks["transformers_name"] = n
+ gptq_related_blocks["transformers"] = m
+ break
+ else:
+ continue
+ for n, m in gptq_related_blocks["transformers"][0].named_modules():
if is_leaf(m):
gptq_related_blocks["embeddings"][n] = m
+ gptq_related_blocks["transformers"] = gptq_related_blocks["transformers"][1:-1]
+ else:
+ gptq_related_blocks = {
+ "embeddings": {},
+ "transformers_pre": {}, # todo
+ "transformers_name": "", # None
+ "transformers": [], # None
+ "transformers_post": {}, # todo
+ }
+ for n, m in module.named_modules():
+ if type(m) in module_types:
+ gptq_related_blocks["transformers_name"] = n
+ gptq_related_blocks["transformers"] = m
+ return gptq_related_blocks
+ else:
+ if is_leaf(m):
+ gptq_related_blocks["embeddings"][n] = m
return gptq_related_blocks
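A quick, hedged sanity check of the broadened tracing (the toy module below is not from the repo): with `torch.nn.Sequential` added to `module_types`, decoder stacks expressed as plain Sequentials are now discovered, and the `MixFormerSequentialForCausalLM` branch additionally peels the embedding and head entries off such a stack.

```python
import torch

from neural_compressor.adaptor.torch_utils.gptq import trace_gptq_target_blocks


class ToyLM(torch.nn.Module):
    """Stand-in for a model whose decoder stack is an nn.Sequential rather than an nn.ModuleList."""

    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(100, 16)
        self.layers = torch.nn.Sequential(*[torch.nn.Linear(16, 16) for _ in range(3)])


blocks = trace_gptq_target_blocks(ToyLM())
assert blocks["transformers_name"] == "layers"  # the Sequential stack is picked up
assert len(blocks["transformers"]) == 3         # three quantizable blocks
assert "embed" in blocks["embeddings"]          # leaf modules seen before the stack
```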
@@ -175,6 +196,7 @@ def __init__(
use_max_length=True,
pad_max_length=2048,
device=None,
+ layer_wise=False,
):
"""
Args:
@@ -196,7 +218,7 @@ def __init__(
"""
# model
self.model = model
- self.use_cache = self.model.config.use_cache
+ # self.use_cache = self.model.config.use_cache
self.gptq_related_blocks = trace_gptq_target_blocks(self.model) # get the transformer block list above
self.dtype = next(iter(self.model.parameters())).dtype
log_quantizable_layers_per_transformer(self.gptq_related_blocks)
@@ -215,9 +237,13 @@ def __init__(
self.check_layer_config()
# device
- self.device = model.device
+ self.device = device
+ if str(self.model.device).startswith("cuda"):
+ self.device = self.model.device
self.is_ready = False
+ self.layer_wise = layer_wise
+
# dataloader
self.use_max_length = use_max_length
self.pad_max_length = pad_max_length
@@ -411,6 +437,12 @@ def get_layer_config(self, layer_name):
pass
return config
+ def track_hidden_states(self, data):
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, tuple) or isinstance(data, list):
+ return data[0]
+
@torch.no_grad()
def pre_quantization(self):
"""Prepare input calibration data and other attributes which are critical for gptq execution."""
@@ -438,11 +470,13 @@ def forward(layer, *args, **kwargs):
raise ValueError
# Step1: fetch the embeddings and other layers before the transformer stack.
- for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
- embedding_layer = embedding_layer.to(self.device)
+ if not self.layer_wise:
+ for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
+ embedding_layer = embedding_layer.to(self.device)
# Step2: modify the first transformer block's forward function to obtain inputs for calibration
- self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
+ if not self.layer_wise:
+ self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
forward_cache = self.gptq_related_blocks["transformers"][0].forward
self.gptq_related_blocks["transformers"][0].forward = partial(
forward, self.gptq_related_blocks["transformers"][0]
@@ -451,7 +485,8 @@ def forward(layer, *args, **kwargs):
# Step3: run forward to obtain calibration datasets
logger.info("Collecting calibration inputs...")
for batch in tqdm(self.dataloader):
- batch = move_input_to_device(batch, self.device)
+ if not self.layer_wise:
+ batch = move_input_to_device(batch, self.device)
try:
if isinstance(batch, tuple) or isinstance(batch, list):
self.model(batch[0])
@@ -473,9 +508,10 @@ def forward(layer, *args, **kwargs):
# Step 4: restore original forward function, relocate layers back to cpu.
self.gptq_related_blocks["transformers"][0].forward = forward_cache
- self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
- for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
- embedding_layer.to(self.device)
+ if not self.layer_wise:
+ self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
+ for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
+ embedding_layer.to(self.device)
torch.cuda.empty_cache()
# end
logger.info("GPTQ quantization prepared.")
@@ -501,7 +537,7 @@ def update_blockwise_hidden_states(self, outs):
self.cache_positional_arguments[0] = outs[:]
@torch.no_grad()
- def execute_quantization(self, means=None, stds=None):
+ def execute_quantization(self, means=None, stds=None, model_path=None):
"""Run quantization."""
# Step1: prepare quantization (calibration datasets)
@@ -513,7 +549,11 @@ def execute_quantization(self, means=None, stds=None):
tblock_length = len(self.gptq_related_blocks["transformers"])
for block_idx in range(tblock_length):
logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..")
- transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device)
+ if not self.layer_wise:
+                # without layer-wise quantization, the whole block is still moved to the target device
+ transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device)
+ else:
+ transformer_block = self.gptq_related_blocks["transformers"][block_idx] # .to(self.device)
# Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized.
sub_layers = find_layers(transformer_block)
sub_layers_to_quant = {}
@@ -534,8 +574,16 @@ def execute_quantization(self, means=None, stds=None):
# weight_config_this_layer = self.weight_config.get(
# self.get_full_layer_name(layer_name, block_idx), None
# )
- weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx))
- gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name])
+ full_layer_name = self.get_full_layer_name(layer_name, block_idx)
+ weight_config_this_layer = self.get_layer_config(full_layer_name)
+ if self.layer_wise:
+ from ..torch_utils.layer_wise_quant.utils import load_value
+
+ W = load_value(self.model, full_layer_name + ".weight", model_path)
+ else:
+ W = sub_layers[layer_name].weight.data.clone()
+
+ gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device)
# gptq_for_this_block[layer_name].quantizer = Quantizer()
gptq_for_this_block[layer_name].quantizer.configure(
weight_config_this_layer["wbits"],
@@ -555,11 +603,11 @@ def tmp(_, inp, out):
for layer_name in sub_layers:
handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name)))
idx = self.cache_key_arguments.pop("i")
- # import pdb;pdb.set_trace()
for j in range(len(self.dataloader)):
cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
- out = transformer_block(*cache_positional_batch, **cache_keyword_batch)[0]
+ out = transformer_block(*cache_positional_batch, **cache_keyword_batch)
+ out = self.track_hidden_states(out)
self.cache_key_arguments["i"] = idx
for h in handles:
h.remove()
@@ -570,12 +618,44 @@ def tmp(_, inp, out):
# )
weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx))
logger.info(f"Quantizing layer {layer_name}")
- scale, zp = gptq_for_this_block[layer_name].fasterquant(
+ if self.layer_wise:
+ from ..torch_utils.layer_wise_quant.utils import load_value
+
+ full_layer_name = self.get_full_layer_name(layer_name, block_idx)
+ W = load_value(self.model, full_layer_name + ".weight", model_path)
+ else:
+ W = sub_layers[layer_name].weight.data.clone()
+ scale, zp, Q = gptq_for_this_block[layer_name].fasterquant(
+ W,
blocksize=weight_config_this_layer["block_size"],
percdamp=weight_config_this_layer["percdamp"],
groupsize=weight_config_this_layer["group_size"],
act_order=weight_config_this_layer["act_order"],
)
+ if self.layer_wise:
+ from ..torch_utils.layer_wise_quant.utils import (
+ LWQ_WORKSPACE,
+ clean_module_weight,
+ load_value,
+ set_module_tensor_to_device,
+ )
+
+ sub_layer = sub_layers[layer_name]
+ full_layer_name = self.get_full_layer_name(layer_name, block_idx)
+ for n, p in sub_layer.named_parameters():
+ param_name = full_layer_name + "." + n
+ if n == "weight":
+ set_module_tensor_to_device(self.model, param_name, self.device, Q)
+ else:
+ value = load_value(self.model, param_name, model_path)
+ set_module_tensor_to_device(self.model, param_name, self.device, value)
+ # sub_layer.weight.data = Q
+ torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
+ clean_module_weight(sub_layer)
+ del Q
+ gc.collect()
+ else:
+ sub_layers[layer_name].weight.data = Q
gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale}
if not weight_config_this_layer["sym"]:
gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp
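A self-contained sketch of the per-layer bookkeeping the `layer_wise` branch performs after `fasterquant`, using a toy linear layer and a made-up `Q` tensor in place of the returned quantized weight: materialize `Q` on the live module, snapshot the module into the LWQ workspace, then strip the weights again so peak memory stays flat.

```python
import os

import torch
from accelerate.utils import set_module_tensor_to_device

from neural_compressor.adaptor.torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, clean_module_weight

model = torch.nn.Sequential(torch.nn.Linear(8, 8))
Q = torch.randn(8, 8)  # stands in for the quantized weight returned by fasterquant

set_module_tensor_to_device(model, "0.weight", "cpu", Q)                 # put Q on the module
os.makedirs(LWQ_WORKSPACE, exist_ok=True)
torch.save(model[0].state_dict(), os.path.join(LWQ_WORKSPACE, "0.pt"))   # persist the shard
clean_module_weight(model[0])                                            # free the in-memory copy
```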
@@ -591,10 +671,14 @@ def tmp(_, inp, out):
for j in range(len(self.dataloader)):
cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
- out = transformer_block(*cache_positional_batch, **cache_keyword_batch)[0]
+ out = transformer_block(*cache_positional_batch, **cache_keyword_batch)
+ out = self.track_hidden_states(out)
outs.append(out)
self.cache_key_arguments["i"] = idx
- self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
+ if self.layer_wise:
+ self.gptq_related_blocks["transformers"][block_idx] = transformer_block
+ else:
+ self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
del gptq_for_this_block
torch.cuda.empty_cache()
# iteratively replace the input with output, thus layerwise quantization can continue.
@@ -602,7 +686,7 @@ def tmp(_, inp, out):
logger.info("------------------------------")
logger.info("Quantization done")
- self.model.config.use_cache = self.use_cache
+ # self.model.config.use_cache = self.use_cache
# obtain model (all weight only quantization API function should return)
for k, v in gptq_config.items():
@@ -617,10 +701,10 @@ class GPTQ:
GPTQ: Accurate Post-training Compression for Generative Pretrained Transformers (https://arxiv.org/abs/2210.17323)
"""
- def __init__(self, layer):
+ def __init__(self, layer, W, device="cpu"):
self.layer = layer
- self.device = self.layer.weight.device
- W = layer.weight.data.clone()
+ self.device = device
+ # W = layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
@@ -661,8 +745,9 @@ def add_batch(self, inp, out):
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self.H += inp.matmul(inp.t()) # H = X*X, which should be a sysm matrix
- def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
- W = self.layer.weight.data.clone()
+ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
+ # W = self.layer.weight.data.clone()
+ weight_shape, weight_dtype = W.shape, W.data.dtype
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
@@ -740,7 +825,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals
# logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}")
# logger.info(f"{torch.sum(Losses)}")
- if self.device != torch.device("cpu"):
+ if str(self.device).startswith("cuda"):
torch.cuda.synchronize()
logger.info(f"time {(time.time() - tick)}")
logger.info(f"error {torch.sum(Losses).item()}")
@@ -751,7 +836,8 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals
if isinstance(self.layer, transformers.Conv1D):
Q = Q.t()
- self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
+ # self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
+ Q = Q.reshape(weight_shape).to(weight_dtype)
if DEBUG:
logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}")
@@ -760,7 +846,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals
zero.append(self.quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
- return scale, zero
+ return scale, zero, Q
def free(self):
if DEBUG:
diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py
index 6f01b1288bf..f347da31bb9 100644
--- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py
+++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py
@@ -15,5 +15,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Torch layer-wise quantization module."""
-from .utils import load_shell
+from .utils import load_empty_model
from .quantize import LayerWiseQuant
diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py
index 1746ad82140..9e3f8789dad 100644
--- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py
+++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py
@@ -40,7 +40,7 @@
update_module,
)
-TMP_DIR = os.path.join(default_workspace, "layer_wise_quant_tmp_dir")
+TMP_DIR = os.path.join(default_workspace, "lwq_tmpdir")
def mk_tmp_dir():
@@ -92,7 +92,7 @@ def __init__(
alpha=0.5,
):
"""Init LayerWiseQuant."""
- # self.q_model = load_shell(pretrained_model_name_or_path, cls)
+ # self.q_model = load_empty_model(pretrained_model_name_or_path, cls)
self.q_model = q_model
self.fp32_model = deepcopy(self.q_model)
self.path = _get_path(pretrained_model_name_or_path)
diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py
index e932c40480a..8bd3d32d320 100644
--- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py
+++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py
@@ -25,7 +25,7 @@
torch = LazyImport("torch")
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from ....config import options
@@ -107,7 +107,7 @@ def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
return file_path
-def load_shell(pretrained_model_name_or_path, cls, **kwargs):
+def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs):
"""Load a empty model."""
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local: # pragma: no cover
@@ -124,6 +124,7 @@ def load_shell(pretrained_model_name_or_path, cls, **kwargs):
model = cls(config)
model.tie_weights()
model.eval()
+ model.path = pretrained_model_name_or_path
return model
@@ -223,7 +224,10 @@ def load_module(model, module_name, path, device="cpu"):
set_module_tensor_to_device(model, param_name, device, value)
-def register_weight_hooks(model, path, device="cpu", clean_weight=True):
+def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_path=None):
+ if saved_path:
+ os.makedirs(saved_path, exist_ok=True)
+
def forward_pre_hook(name):
def hook(module, input):
state_dict = None
@@ -241,6 +245,9 @@ def hook(module, input):
def forward_hook(name):
def hook(module, input, output):
+ if saved_path:
+ file_path = os.path.join(saved_path, f"{name}.pt")
+ torch.save(module.state_dict(), file_path)
clean_module_weight(module)
return hook
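A hedged usage sketch of the new `saved_path` argument, mirroring how the PyTorch adaptor hunk earlier in this patch wires it; the opt-125m checkpoint is only an example, and `cls` can now be omitted since it defaults to `AutoModelForCausalLM`.

```python
from neural_compressor.adaptor.torch_utils.layer_wise_quant.utils import (
    LWQ_WORKSPACE,
    _get_path,
    load_empty_model,
    register_weight_hooks,
)

model = load_empty_model("facebook/opt-125m")  # parameters stay on the meta device
model_path = _get_path(model.path)             # resolve the on-disk checkpoint location
lwq_handles = register_weight_hooks(
    model, model_path, device="cpu", clean_weight=True, saved_path=LWQ_WORKSPACE
)
# pre-forward hook: load a module's weights from model_path right before it runs
# post-forward hook: dump module.state_dict() into saved_path, then clean the weights again
```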
diff --git a/neural_compressor/adaptor/torch_utils/mixed_precision.py b/neural_compressor/adaptor/torch_utils/mixed_precision.py
index dce5fe54b3c..6c461ea195b 100644
--- a/neural_compressor/adaptor/torch_utils/mixed_precision.py
+++ b/neural_compressor/adaptor/torch_utils/mixed_precision.py
@@ -38,7 +38,9 @@ def ipex_mixed_precision(model, example_inputs=None, device="cpu"):
try:
mp_model = torch.jit.trace(mp_model, example_kwarg_inputs=example_inputs)
except:
- mp_model = torch.jit.trace(mp_model, example_kwarg_inputs=example_inputs, strict=False)
+ mp_model = torch.jit.trace(
+ mp_model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False
+ )
else:
try:
mp_model = torch.jit.trace(mp_model, example_inputs)
diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py
index 65804468ec4..7b59b3ce3e5 100644
--- a/neural_compressor/adaptor/torch_utils/smooth_quant.py
+++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -1186,7 +1186,9 @@ def trace(self, model, dummy_input):
dummy_input = move_input_to_device(dummy_input, "cpu")
if isinstance(dummy_input, dict) or isinstance(dummy_input, UserDict):
try:
- traced_model = torch.jit.trace(model, example_kwarg_inputs=dict(dummy_input), strict=False)
+ traced_model = torch.jit.trace(
+ model, example_kwarg_inputs=dict(dummy_input), strict=False, check_trace=False
+ )
traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics)
except Exception as e:
logger.warning(e)
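Both tracing fallbacks above now pass `check_trace=False`; a minimal, hypothetical sketch (assuming PyTorch 2.x, where `example_kwarg_inputs` is available) of the combination they rely on: `strict=False` tolerates dict-style outputs, and `check_trace=False` skips the extra verification forward pass.

```python
import torch


class Toy(torch.nn.Module):
    def forward(self, input_ids):
        # dict-style outputs are common for HF models; strict tracing would reject them
        return {"logits": input_ids.float().sum(dim=-1)}


traced = torch.jit.trace(
    Toy().eval(),
    example_kwarg_inputs={"input_ids": torch.ones(1, 4, dtype=torch.long)},
    strict=False,       # allow dict / non-tensor outputs
    check_trace=False,  # do not re-run the model to verify the trace, as in the hunks above
)
```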
diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py
index b9376a72c7f..7ba86eaa344 100644
--- a/neural_compressor/adaptor/torch_utils/weight_only.py
+++ b/neural_compressor/adaptor/torch_utils/weight_only.py
@@ -470,15 +470,27 @@ def rtn_quantize(
def gptq_quantize(
- model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, pad_max_length=2048, device=None
+ model,
+ weight_config={},
+ dataloader=None,
+ nsamples=128,
+ use_max_length=True,
+ pad_max_length=2048,
+ device=None,
+ layer_wise=False,
+ model_path=None,
):
"""Run weight-only quantization with."""
# TODO: unify weight_config keys, add docstring, and support default config
assert isinstance(model, torch.nn.Module), "only support torch module"
+ if layer_wise:
+        assert model_path is not None, "model_path should not be None when using layer_wise mode"
from .gptq import GPTQuantizer
- gptq_quantizer = GPTQuantizer(model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device)
- fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization()
+ gptq_quantizer = GPTQuantizer(
+ model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device, layer_wise=layer_wise
+ )
+ fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization(model_path=model_path)
logger.info("GPTQ quantizing done.")
return fp32_modified_model, gptq_config
diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py
index 95f273eeab5..eeada402f35 100644
--- a/neural_compressor/model/torch_model.py
+++ b/neural_compressor/model/torch_model.py
@@ -356,7 +356,8 @@ def save(self, root=None):
if os.path.exists(os.path.join(LWQ_WORKSPACE, f"{name}.pt")):
state_dict = torch.load(os.path.join(LWQ_WORKSPACE, f"{name}.pt"))
model_path = _get_path(
- self.q_config["recipe_cfgs"]["layer_wise_quant_args"].get("model_path")
+ # self.q_config["recipe_cfgs"]["layer_wise_quant_args"].get("model_path")
+ self._model.path
)
for n, p in module.named_parameters():
param_name = name + "." + n
diff --git a/test/algorithm/test_layer_wise_quant.py b/test/algorithm/test_layer_wise_quant.py
index e25036bb021..2eef1e89bd9 100644
--- a/test/algorithm/test_layer_wise_quant.py
+++ b/test/algorithm/test_layer_wise_quant.py
@@ -8,14 +8,14 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor import PostTrainingQuantConfig, quantization
-from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell
+from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model
from neural_compressor.utils.pytorch import load
class TestLayerWise(unittest.TestCase):
def test_layer_wise(self):
model_name_or_path = "facebook/opt-125m"
- fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True)
+ fp32_model = load_empty_model(model_name_or_path, torchscript=True)
class TestDataset(Dataset):
def __init__(self, size=5, shape=128):
@@ -65,7 +65,7 @@ def test_util(self):
)
model_name_or_path = "facebook/opt-125m"
- model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True)
+ model = load_empty_model(model_name_or_path, torchscript=True)
children = get_children(model)
named_children = get_named_children(model)
self.assertEqual(children, [v for k, v in named_children])
diff --git a/test/algorithm/test_lwq_weight_only.py b/test/algorithm/test_lwq_weight_only.py
index c8840123382..3735f0b2581 100644
--- a/test/algorithm/test_lwq_weight_only.py
+++ b/test/algorithm/test_lwq_weight_only.py
@@ -1,21 +1,22 @@
import shutil
import sys
import unittest
+from copy import deepcopy
sys.path.insert(0, "./")
import torch
from torch.utils.data import DataLoader, Dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor import PostTrainingQuantConfig, quantization
-from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell
+from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model
from neural_compressor.utils.pytorch import load
class TestLayerWise(unittest.TestCase):
- def test_layer_wise(self):
- model_name_or_path = "facebook/opt-125m"
- fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True)
+ @classmethod
+    def setUpClass(cls):
+        cls.model_name_or_path = "facebook/opt-125m"
+        cls.fp32_model = load_empty_model(cls.model_name_or_path, torchscript=True)
class TestDataset(Dataset):
def __init__(self, size=5, shape=128):
@@ -29,30 +30,58 @@ def __len__(self):
return self.len
eval_dataset = TestDataset()
- eval_dataloader = DataLoader(eval_dataset, batch_size=8)
+        cls.eval_dataloader = DataLoader(eval_dataset, batch_size=8)
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree("./saved_model", ignore_errors=True)
+
+ def test_rtn_lwq(self):
conf = PostTrainingQuantConfig(
approach="weight_only",
recipes={
"layer_wise_quant": True,
- "layer_wise_quant_args": {
- "model_path": "facebook/opt-125m",
- },
+ # "layer_wise_quant_args": {
+ # "model_path": "facebook/opt-125m",
+ # },
"rtn_args": {"enable_full_range": True},
},
)
q_model = quantization.fit(
- fp32_model,
+ deepcopy(self.fp32_model),
conf,
- calib_dataloader=eval_dataloader,
+ calib_dataloader=self.eval_dataloader,
eval_func=lambda x: 0.1,
)
ouput_dir = "./saved_model"
q_model.save(ouput_dir)
- load_model = load(ouput_dir, fp32_model, weight_only=True)
+ load_model = load(ouput_dir, deepcopy(self.fp32_model), weight_only=True)
+ self.assertNotEqual(load_model.lm_head.weight.device.type, "meta")
+
+ def test_gptq_lwq(self):
+ conf = PostTrainingQuantConfig(
+ approach="weight_only",
+ op_type_dict={
+ ".*": { # re.match
+ "weight": {
+ "bits": 4, # 1-8 bits
+ "group_size": 32,
+ "scheme": "sym",
+ "algorithm": "GPTQ",
+ },
+ },
+ },
+ recipes={
+ "gptq_args": {"actorder": True, "mse": True, "perchannel": False},
+ "layer_wise_quant": True,
+ },
+ )
+ q_model = quantization.fit(deepcopy(self.fp32_model), conf, calib_dataloader=self.eval_dataloader)
+        output_dir = "./saved_model"
+        q_model.save(output_dir)
+        load_model = load(output_dir, deepcopy(self.fp32_model), weight_only=True, layer_wise=True)
self.assertNotEqual(load_model.lm_head.weight.device.type, "meta")
- shutil.rmtree(ouput_dir)
if __name__ == "__main__":