diff --git a/.azure-pipelines/scripts/models/env_setup.sh b/.azure-pipelines/scripts/models/env_setup.sh index 99ef12b6355..0755afa8130 100644 --- a/.azure-pipelines/scripts/models/env_setup.sh +++ b/.azure-pipelines/scripts/models/env_setup.sh @@ -78,7 +78,7 @@ if [[ "${inc_new_api}" == "false" ]]; then fi cd ${model_src_dir} -pip install ruamel_yaml +pip install ruamel.yaml==0.17.40 pip install psutil pip install protobuf==4.23.4 if [[ "${framework}" == "tensorflow" ]]; then diff --git a/docs/source/quantization_weight_only.md b/docs/source/quantization_weight_only.md index b26c5194aa3..41a02e0d460 100644 --- a/docs/source/quantization_weight_only.md +++ b/docs/source/quantization_weight_only.md @@ -129,6 +129,36 @@ torch.save(compressed_model.state_dict(), "compressed_model.pt") The saved_results folder contains two files: `best_model.pt` and `qconfig.json`, and the generated q_model is a fake quantized model. + +### **WOQ algorithms tuning** + +To find the best algorithm, users can simply omit specifying a particular one. Instead of applying a fixed algorithm, the tuning process traverses a set of pre-defined WOQ configurations and identifies the one that yields the best result. For detailed usage, please refer to the [tuning strategy](./tuning_strategies.md#Basic). + +> **Note:** Currently, this behavior is specific to the `ONNX Runtime` backend. + +**Pre-defined configurations** + +| WOQ configurations | setting | +|:------------------:|:-------:| +|RTN_G32ASYM| {"algorithm": "RTN", "group_size": 32, "scheme": "asym"}| +|GPTQ_G32ASYM| {"algorithm": "GPTQ", "group_size": 32, "scheme": "asym"}| +|GPTQ_G32ASYM_DISABLE_LAST_MATMUL| {"algorithm": "GPTQ", "group_size": 32, "scheme": "asym"}
& disable last MatMul| +|GPTQ_G128ASYM| {"algorithm": "GPTQ", "group_size": 128, "scheme": "asym"}| +|AWQ_G32ASYM| {"algorithm": "AWQ", "group_size": 32, "scheme": "asym"}| + +**User code example** + +```python +conf = PostTrainingQuantConfig( + approach="weight_only", + quant_level="auto", # quant_level supports "auto" or 1 for woq config tuning +) +q_model = quantization.fit(model, conf, eval_func=eval_func, calib_dataloader=dataloader) +q_model.save("saved_results") +``` + +Refer to this [link](../../examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/weight_only) for an example of WOQ algorithms tuning on ONNX Llama models. + ## Layer Wise Quantization Large language models (LLMs) have shown exceptional performance across various tasks, meanwhile, the substantial parameter size poses significant challenges for deployment. Layer-wise quantization(LWQ) can greatly reduce the memory footprint of LLMs, usually 80-90% reduction, which means that users can quantize LLMs even on single node using GPU or CPU. We can quantize the model under memory-constrained devices, therefore making the huge-sized LLM quantization possible. @@ -143,22 +173,19 @@ Large language models (LLMs) have shown exceptional performance across various t |:--------------:|:----------:| | RTN | ✔ | | AWQ | ✕ | -| GPTQ | ✕ | +| GPTQ | ✔ | | TEQ | ✕ | ### Example ```python from neural_compressor import PostTrainingQuantConfig, quantization -from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell +from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model -fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) +fp32_model = load_empty_model(model_name_or_path, torchscript=True) conf = PostTrainingQuantConfig( approach="weight_only", recipes={ "layer_wise_quant": True, - "layer_wise_quant_args": { - "model_path": "facebook/opt-125m", - }, "rtn_args": {"enable_full_range": True}, }, ) @@ -171,6 +198,7 @@ q_model = quantization.fit( ) ouput_dir = "./saved_model" q_model.save(ouput_dir) +q_model = load(ouput_dir, fp32_model, weight_only=True, layer_wise=True) ``` ## Reference diff --git a/docs/source/tuning_strategies.md b/docs/source/tuning_strategies.md index 31062a3e319..7788149f7e3 100644 --- a/docs/source/tuning_strategies.md +++ b/docs/source/tuning_strategies.md @@ -181,6 +181,8 @@ flowchart TD > For [smooth quantization](./smooth_quant.md), users can tune the smooth quantization alpha by providing a list of scalars for the `alpha` item. The tuning process will take place at the **start stage** of the tuning procedure. For details usage, please refer to the [smooth quantization example](./smooth_quant.md#Example). +> For [weight-only quantization](./quantization_weight_only.md), users can tune the weight-only algorithms from the available [pre-defined configurations](./quantization_weight_only.md#woq-algorithms-tuning). The tuning process will take place at the **start stage** of the tuning procedure, preceding the smooth quantization alpha tuning. For details usage, please refer to the [weight-only quantization example](./quantization_weight_only.md#woq-algorithms-tuning). 
+*Please note that this behavior is specific to the `ONNX Runtime` backend.* **1.** Default quantization diff --git a/examples/.config/model_params_onnxrt.json b/examples/.config/model_params_onnxrt.json index 8f84567f752..d0cb53b7fd7 100644 --- a/examples/.config/model_params_onnxrt.json +++ b/examples/.config/model_params_onnxrt.json @@ -322,6 +322,13 @@ "main_script": "main.py", "batch_size": 1 }, + "beit": { + "model_src_dir": "image_recognition/beit/quantization/ptq_static", + "dataset_location": "/tf_dataset/pytorch/ImageNet/raw", + "input_model": "/tf_dataset2/models/onnx/beit/beit_base_patch16_224_pt22k_ft22kto1k.onnx", + "main_script": "main.py", + "batch_size": 1 + }, "mobilebert_squad_mlperf_qdq": { "model_src_dir": "nlp/onnx_model_zoo/mobilebert/quantization/ptq_static", "dataset_location": "/tf_dataset2/datasets/squad", diff --git a/examples/README.md b/examples/README.md index e8c9ea98b08..946947ebd40 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1133,6 +1133,12 @@ IntelĀ® Neural Compressor validated examples with multiple compression technique Post-Training Static Quantization qlinearops + + BEiT + Image Recognition + Post-Training Static Quantization + qlinearops + CodeBert Natural Language Processing diff --git a/examples/notebook/onnxruntime/Quick_Started_Notebook_of_INC_for_ONNXRuntime.ipynb b/examples/notebook/onnxruntime/Quick_Started_Notebook_of_INC_for_ONNXRuntime.ipynb index c3e0aa5f027..e5d06c2be50 100644 --- a/examples/notebook/onnxruntime/Quick_Started_Notebook_of_INC_for_ONNXRuntime.ipynb +++ b/examples/notebook/onnxruntime/Quick_Started_Notebook_of_INC_for_ONNXRuntime.ipynb @@ -47,13 +47,14 @@ "outputs": [], "source": [ "# install neural-compressor from source\n", + "import sys\n", "!git clone https://github.com/intel/neural-compressor.git\n", "%cd ./neural-compressor\n", - "!pip install -r requirements.txt\n", - "!python setup.py install\n", + "!{sys.executable} -m pip install -r requirements.txt\n", + "!{sys.executable} setup.py install\n", "%cd ..\n", "# or install stable basic version from pypi\n", - "# pip install neural-compressor" + "# pip install neural-compressor\n" ] }, { @@ -65,10 +66,8 @@ }, "outputs": [], "source": [ - "# install onnx related packages\n", - "!pip install onnx onnxruntime onnxruntime-extensions\n", - "# install other packages used in this notebook.\n", - "!pip install torch transformers accelerate coloredlogs sympy numpy sentencepiece protobuf optimum" + "# install required packages\n", + "!{sys.executable} install -r requirements.txt\n" ] }, { @@ -168,7 +167,7 @@ "source": [ "!export GLUE_DIR=./glue_data\n", "!wget https://raw.githubusercontent.com/Shimao-Zhang/Download_GLUE_Data/master/download_glue_data.py\n", - "!python download_glue_data.py --data_dir=GLUE_DIR --tasks=SST" + "!{sys.executable} download_glue_data.py --data_dir=GLUE_DIR --tasks=SST\n" ] }, { @@ -193,7 +192,7 @@ "int8_model_path = \"onnx-model/int8-model.onnx\"\n", "data_path = \"./GLUE_DIR/SST-2\"\n", "task = \"sst-2\"\n", - "batch_size = 8" + "batch_size = 8\n" ] }, { @@ -343,7 +342,7 @@ " label=label\n", " )\n", " features.append(feats)\n", - " return features" + " return features\n" ] }, { @@ -377,7 +376,7 @@ " model_name_or_path=model_name_or_path,\n", " model_type=\"distilbert\",\n", " task=task)\n", - "dataloader = DataLoader(framework=\"onnxruntime\", dataset=dataset, batch_size=batch_size)" + "dataloader = DataLoader(framework=\"onnxruntime\", dataset=dataset, batch_size=batch_size)\n" ] }, { @@ -448,7 +447,7 @@ " elif output_mode == 
\"regression\":\n", " processed_preds = np.squeeze(self.pred_list)\n", " result = transformers.glue_compute_metrics(self.task, processed_preds, self.label_list)\n", - " return result[self.return_key[self.task]]" + " return result[self.return_key[self.task]]\n" ] }, { @@ -486,7 +485,7 @@ " ort_inputs.update({inputs_names[i]: inputs[i]})\n", " predictions = session.run(None, ort_inputs)\n", " metric.update(predictions[0], labels)\n", - " return metric.result()" + " return metric.result()\n" ] }, { @@ -567,7 +566,7 @@ " num_heads=num_heads,\n", " hidden_size=hidden_size,\n", " optimization_options=opt_options)\n", - "model = model_optimizer.model" + "model = model_optimizer.model\n" ] }, { @@ -722,7 +721,7 @@ " config,\n", " eval_func=eval_func,\n", " calib_dataloader=dataloader)\n", - "q_model.save(int8_model_path)" + "q_model.save(int8_model_path)\n" ] }, { diff --git a/examples/notebook/onnxruntime/requirements.txt b/examples/notebook/onnxruntime/requirements.txt new file mode 100644 index 00000000000..13f6ef9a8e1 --- /dev/null +++ b/examples/notebook/onnxruntime/requirements.txt @@ -0,0 +1,12 @@ +onnx +onnxruntime +onnxruntime-extensions +torch +transformers +accelerate +coloredlogs +sympy +numpy +sentencepiece +protobuf +optimum diff --git a/examples/notebook/pytorch/Quick_Started_Notebook_of_INC_for_Pytorch.ipynb b/examples/notebook/pytorch/Quick_Started_Notebook_of_INC_for_Pytorch.ipynb index 0a7a277ee96..62d38981668 100644 --- a/examples/notebook/pytorch/Quick_Started_Notebook_of_INC_for_Pytorch.ipynb +++ b/examples/notebook/pytorch/Quick_Started_Notebook_of_INC_for_Pytorch.ipynb @@ -45,14 +45,15 @@ "outputs": [], "source": [ "# install neural-compressor from source\n", + "import sys\n", "!git clone https://github.com/intel/neural-compressor.git\n", "%cd ./neural-compressor\n", - "!pip install -r requirements.txt\n", - "!python setup.py install\n", + "!{sys.executable} -m pip install -r requirements.txt\n", + "!{sys.executable} setup.py install\n", "%cd ..\n", "\n", "# or install stable basic version from pypi\n", - "!pip install neural-compressor" + "!{sys.executable} -m pip install neural-compressor\n" ] }, { @@ -62,7 +63,7 @@ "outputs": [], "source": [ "# install other packages used in this notebook.\n", - "!pip install torch>=1.9.0 transformers>=4.16.0 accelerate sympy numpy sentencepiece!=0.1.92 protobuf<=3.20.3 datasets>=1.1.3 scipy scikit-learn Keras-Preprocessing" + "!{sys.executable} -m pip install -r requirements.txt\n" ] }, { @@ -303,10 +304,10 @@ "outputs": [], "source": [ "# fp32 benchmark\n", - "!python benchmark.py --input_model ./pytorch_model.bin 2>&1|tee fp32_benchmark.log\n", + "!{sys.executable} benchmark.py --input_model ./pytorch_model.bin 2>&1|tee fp32_benchmark.log\n", "\n", "# int8 benchmark\n", - "!python benchmark.py --input_model ./saved_results/best_model.pt 2>&1|tee int8_benchmark.log\n" + "!{sys.executable} benchmark.py --input_model ./saved_results/best_model.pt 2>&1|tee int8_benchmark.log\n" ] } ], diff --git a/examples/notebook/pytorch/requirements.txt b/examples/notebook/pytorch/requirements.txt new file mode 100644 index 00000000000..aa1af71d2b3 --- /dev/null +++ b/examples/notebook/pytorch/requirements.txt @@ -0,0 +1,11 @@ +torch>=1.9.0 +transformers>=4.16.0 +accelerate +sympy +numpy +sentencepiece!=0.1.92 +protobuf<=3.20.3 +datasets>=1.1.3 +scipy +scikit-learn +Keras-Preprocessing diff --git a/examples/notebook/tensorflow/resnet/requirements.txt b/examples/notebook/tensorflow/resnet/requirements.txt new file mode 100644 index 
00000000000..a2d431e0ace --- /dev/null +++ b/examples/notebook/tensorflow/resnet/requirements.txt @@ -0,0 +1,8 @@ +numpy +neural-compressor +tensorflow +datasets +requests +urllib3 +pyOpenSSL +git+https://github.com/huggingface/huggingface_hub diff --git a/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb b/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb index 8f1c36ef3d8..f2b168949f8 100644 --- a/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb +++ b/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb @@ -29,12 +29,11 @@ "metadata": {}, "outputs": [], "source": [ + "import sys\n", "!conda install python==3.10 -y\n", - "!pip install neural-compressor\n", - "!wget -nc https://storage.googleapis.com/intel-optimized-tensorflow/models/v1_6/resnet50_fp32_pretrained_model.pb\n", - "!pip install tensorflow\n", - "!pip install datasets\n", - "!pip install git+https://github.com/huggingface/huggingface_hub" + "!{sys.executable} -m pip install -r requirements.txt \n", + "\n", + "!wget -nc https://storage.googleapis.com/intel-optimized-tensorflow/models/v1_6/resnet50_fp32_pretrained_model.pb\n" ] }, { @@ -43,9 +42,11 @@ "metadata": {}, "outputs": [], "source": [ + "print(sys.executable)\n", + "!{sys.executable} -m pip list\n", "import tensorflow as tf\n", "import numpy as np\n", - "import datasets" + "import datasets\n" ] }, { @@ -63,8 +64,8 @@ "source": [ "# login to huggingface to download the imagenet-1k dataset\n", "# you should replace this read-only token with your own by create one on (https://huggingface.co/settings/tokens)\n", - "# !huggingface-cli login --token \n", - "!huggingface-cli login --token hf_xxxxxxxxxxxxxxxxxxxxxx" + "from huggingface_hub.hf_api import HfFolder\n", + "HfFolder.save_token('hf_xxxxxxxxxxxxxxxxxxxxxx')\n" ] }, { @@ -75,8 +76,8 @@ "source": [ "from datasets import load_dataset\n", "# load dataset in streaming way will get an IterableDatset\n", - "calib_dataset = load_dataset('imagenet-1k', split='train', streaming=True, use_auth_token=True)\n", - "eval_dataset = load_dataset('imagenet-1k', split='validation', streaming=True, use_auth_token=True)" + "calib_dataset = load_dataset('imagenet-1k', split='train', streaming=True, token=True)\n", + "eval_dataset = load_dataset('imagenet-1k', split='validation', streaming=True, token=True)\n" ] }, { @@ -97,7 +98,7 @@ " return datasets.Dataset.from_dict(data)\n", "\n", "sub_calib_dataset = sample_data(calib_dataset, MAX_SAMPLE_LENGTG)\n", - "sub_eval_dataset = sample_data(eval_dataset, MAX_SAMPLE_LENGTG)" + "sub_eval_dataset = sample_data(eval_dataset, MAX_SAMPLE_LENGTG)\n" ] }, { @@ -136,7 +137,7 @@ " batch_inputs = []\n", " labels = []\n", " def __len__(self):\n", - " return self.length" + " return self.length\n" ] }, { @@ -146,7 +147,7 @@ "outputs": [], "source": [ "calib_dataloader = CustomDataloader(dataset=sub_calib_dataset, batch_size=32)\n", - "eval_dataloader = CustomDataloader(dataset=sub_eval_dataset, batch_size=32)" + "eval_dataloader = CustomDataloader(dataset=sub_eval_dataset, batch_size=32)\n" ] }, { @@ -193,7 +194,7 @@ " return acc\n", "\n", "q_model = quantization.fit(\"./resnet50_fp32_pretrained_model.pb\", conf=conf, calib_dataloader=calib_dataloader, eval_func=eval_func)\n", - "q_model.save(\"resnet50_int8.pb\")" + "q_model.save(\"resnet50_int8.pb\")\n" ] }, { @@ -221,7 +222,7 @@ "metadata": {}, "outputs": [], "source": [ - "!python resnet_benchmark.py --input_model resnet50_fp32_pretrained_model.pb 2>&1|tee fp32_benchmark.log" + "!{sys.executable} 
resnet_benchmark.py --input_model resnet50_fp32_pretrained_model.pb 2>&1|tee fp32_benchmark.log\n" ] }, { @@ -237,7 +238,7 @@ "metadata": {}, "outputs": [], "source": [ - "!python resnet_benchmark.py --input_model resnet50_int8.pb 2>&1|tee int8_benchmark.log" + "!{sys.executable} resnet_benchmark.py --input_model resnet50_int8.pb 2>&1|tee int8_benchmark.log\n" ] }, { diff --git a/examples/notebook/tensorflow/vgg19_ibean/requirements.txt b/examples/notebook/tensorflow/vgg19_ibean/requirements.txt new file mode 100644 index 00000000000..e866fcd37f6 --- /dev/null +++ b/examples/notebook/tensorflow/vgg19_ibean/requirements.txt @@ -0,0 +1,5 @@ +numpy +matplotlib +tensorflow +tensorflow-hub +tensorflow-datasets diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/README.md b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/README.md new file mode 100644 index 00000000000..f6775301e62 --- /dev/null +++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/README.md @@ -0,0 +1,55 @@ +Step-by-Step +============ + +This example loads the [BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) (BEiT) model and confirms its accuracy and performance on the [ImageNet-1k dataset](http://www.image-net.org/). You need to download this dataset yourself. + +In this example, the BEiT model is pre-trained in a self-supervised fashion on ImageNet-22k - also called ImageNet-21k (14 million images, 21,841 classes) at resolution 224x224, and fine-tuned on the same dataset at resolution 224x224. It was first released in [this repository](https://github.com/microsoft/unilm/tree/master/beit). + + +# Prerequisite + +## 1. Environment +```shell +pip install neural-compressor +pip install -r requirements.txt +``` +> Note: Validated ONNX Runtime [Version](/docs/source/installation_guide.md#validated-software-environment). + +## 2. Prepare Model + +Prepare the BEiT base model and export it to ONNX format: + +```shell +python prepare_model.py --input_model=beit_base_patch16_224 --output_model=beit_base_patch16_224_pt22k_ft22kto1k.onnx +``` + +## 3. Prepare Dataset + +Download and extract [ImageNet-1k](http://www.image-net.org/) to dir: /path/to/imagenet. The directory should contain the following folders: + +```bash +ls /path/to/imagenet +train val +``` + +# Run + +## 1. Quantization + +Quantize the model with QLinearOps: + +```bash +bash run_quant.sh --input_model=/path/to/model \ # model path as *.onnx + --dataset_location=/path/to/imagenet \ + --output_model=/path/to/save \ + --quant_format="QOperator" +``` + +## 2. 
Benchmark + +```bash +bash run_benchmark.sh --input_model=/path/to/model \ # model path as *.onnx + --dataset_location=/path/to/imagenet \ + --batch_size=batch_size \ + --mode=performance # or accuracy +``` \ No newline at end of file diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/beit_modeling_finetune.py b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/beit_modeling_finetune.py new file mode 100644 index 00000000000..7e9cd599893 --- /dev/null +++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/beit_modeling_finetune.py @@ -0,0 +1,420 @@ +# -------------------------------------------------------- +# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) +# Github source: https://github.com/microsoft/unilm/tree/master/beit +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# By Hangbo Bao +# Based on timm and DeiT code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/facebookresearch/deit/ +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + **kwargs + } + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the original BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = 
qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias=None): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, + use_mean_pooling=True, init_scale=0.001, pretrained_cfg=None, pretrained_cfg_overlay=None): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, 
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) + for i in range(depth)]) + self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) + self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + # trunc_normal_(self.mask_token, std=.02) + if isinstance(self.head, nn.Linear): + trunc_normal_(self.head.weight, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() + + if isinstance(self.head, nn.Linear): + self.head.weight.data.mul_(init_scale) + self.head.bias.data.mul_(init_scale) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias=rel_pos_bias) + + x = self.norm(x) + if self.fc_norm is not None: + t = x[:, 1:, :] + return self.fc_norm(t.mean(1)) + else: + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + def get_intermediate_layers(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + features = [] + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias) + features.append(x) + + return features + + +@register_model +def beit_base_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + return model + + +@register_model +def beit_base_patch16_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + 
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + return model + + +@register_model +def beit_large_patch16_224(pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + return model + + +@register_model +def beit_large_patch16_384(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + return model + + +@register_model +def beit_large_patch16_512(pretrained=False, **kwargs): + model = VisionTransformer( + img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + return model diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/main.py b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/main.py new file mode 100644 index 00000000000..b7d9bc0eab0 --- /dev/null +++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/main.py @@ -0,0 +1,163 @@ +import os +import tqdm +import onnx +import torch +import logging +import argparse +import onnxruntime as ort +from timm.utils import accuracy +from torchvision import datasets, transforms +from timm.data.constants import \ + IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD + +logger = logging.getLogger(__name__) +logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + level = logging.WARN) + +def build_eval_transform(input_size=224, imagenet_default_mean_and_std=False, crop_pct=None): + resize_im = input_size > 32 + imagenet_default_mean_and_std = imagenet_default_mean_and_std + mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN + std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD + + t = [] + if resize_im: + if crop_pct is None: + if input_size < 384: + crop_pct = 224 / 256 + else: + crop_pct = 1.0 + size = int(input_size / crop_pct) + t.append( + transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 
224 images + ) + t.append(transforms.CenterCrop(input_size)) + + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(mean, std)) + return transforms.Compose(t) + +def build_val_dataset(data_path): + transform = build_eval_transform() + root = os.path.join(data_path, 'val') + dataset = datasets.ImageFolder(root, transform=transform) + return dataset + + +def evaluate_func(data_loader, model): + session = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) + top1, top5 = 0, 0 + + for idx, batch in tqdm.tqdm(enumerate(data_loader), desc='eval'): + images = batch[0].cpu().detach().numpy() + target = batch[-1] + output = session.run(None, {'image': images})[0] + acc1, acc5 = accuracy(torch.from_numpy(output), target, topk=(1, 5)) + top1 += acc1.cpu().detach().numpy() + top5 += acc5.cpu().detach().numpy() + + top1 = top1 / len(data_loader) + top5 = top5 / len(data_loader) + print('* Acc@1 {:.3f} Acc@5 {:.3f}'.format(top1, top5)) + return top1 + +if __name__ == '__main__': + logger.info("Evaluating ONNXRuntime full precision accuracy and performance:") + parser = argparse.ArgumentParser( + description="BEiT quantization examples.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + '--model_path', + type=str, + help="Pre-trained model on onnx file" + ) + parser.add_argument( + '--dataset_location', + type=str, + help="Imagenet data path" + ) + parser.add_argument( + '--benchmark', + action='store_true', \ + default=False, + help="whether benchmark the model" + ) + parser.add_argument( + '--tune', + action='store_true', \ + default=False, + help="whether quantize the model" + ) + parser.add_argument( + '--output_model', + type=str, + help="output model path" + ) + parser.add_argument( + '--quant_format', + type=str, + default='default', + choices=['default', 'QDQ', 'QOperator'], + help="quantization format" + ) + parser.add_argument( + '--mode', + type=str, + help="benchmark mode of performance or accuracy" + ) + parser.add_argument( + "--batch_size", + default=64, + type=int, + ) + parser.add_argument( + "--num_workers", + default=10, + type=int, + ) + args = parser.parse_args() + + val_dataset = build_val_dataset(args.dataset_location) + val_sampler = torch.utils.data.SequentialSampler(val_dataset) + val_data_loader = torch.utils.data.DataLoader( + val_dataset, sampler=val_sampler, + batch_size=int(1.5 * args.batch_size), + num_workers=args.num_workers, + drop_last=False + ) + + def eval(model): + return evaluate_func(val_data_loader, model) + + model = onnx.load(args.model_path) + + if args.tune: + from neural_compressor import PostTrainingQuantConfig, quantization + from neural_compressor.utils.constant import FP32 + + config = PostTrainingQuantConfig(approach="static", + quant_format=args.quant_format, + op_type_dict={'Conv': FP32}, + op_name_dict={'/blocks.*/mlp/fc2/MatMul': FP32}, + recipes={'optypes_to_exclude_output_quant': ['MatMul']}, + ) + q_model = quantization.fit(model, + config, + calib_dataloader=val_data_loader, + eval_func=eval) + q_model.save(args.output_model) + + if args.benchmark: + if args.mode == 'performance': + from neural_compressor.benchmark import fit + from neural_compressor.config import BenchmarkConfig + conf = BenchmarkConfig(warmup=10, iteration=1000, cores_per_instance=4, num_of_instance=1) + fit(model, conf, b_dataloader=val_data_loader) + elif args.mode == 'accuracy': + acc_result = evaluate_func(val_data_loader, model) + print("Batch size = %d" % val_data_loader.batch_size) + 
print("Accuracy: %.5f" % acc_result) + + diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/prepare_model.py b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/prepare_model.py new file mode 100644 index 00000000000..683e115c38d --- /dev/null +++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/prepare_model.py @@ -0,0 +1,102 @@ +import argparse +import os +import sys +import torch +from urllib import request +from timm.models import create_model +import beit_modeling_finetune + +MODEL_URLS = {"beit_base_patch16_224": "https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D",} +MODEL_FILES = {"beit_base_patch16_224": "beit_base_patch16_224_pt22k_ft22kto1k.pth"} +MAX_TIMES_RETRY_DOWNLOAD = 5 + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_model", type=str, required=False, default="beit_base_patch16_224") + parser.add_argument("--output_model", type=str, required=True) + return parser.parse_args() + + +def progressbar(cur, total=100): + percent = '{:.2%}'.format(cur / total) + sys.stdout.write("\r[%-100s] %s" % ('#' * int(cur), percent)) + sys.stdout.flush() + + +def schedule(blocknum, blocksize, totalsize): + if totalsize == 0: + percent = 0 + else: + percent = min(1.0, blocknum * blocksize / totalsize) * 100 + progressbar(percent) + + +def download_model(input_model, retry_times=5): + model_url = MODEL_URLS[input_model] + model_file = MODEL_FILES[input_model] + if os.path.isfile(model_file): + print(f"{model_file} exists, skip download") + return True + + print("download model...") + retries = 0 + while retries < retry_times: + try: + request.urlretrieve(model_url, model_file, schedule) + break + except KeyboardInterrupt: + return False + except: + retries += 1 + print(f"Download failed{', Retry downloading...' if retries < retry_times else '!'}") + return retries < retry_times + + +def export_model(input_model, output_model): + print("\nexport model...") + + model = create_model( + input_model, + pretrained=False, + num_classes=1000, + drop_rate=0.0, + drop_path_rate=0.1, + attn_drop_rate=0.0, + drop_block_rate=None, + use_mean_pooling=True, + init_scale=0.001, + use_rel_pos_bias=True, + use_abs_pos_emb=False, + init_values=0.1, + ) + + checkpoint = torch.load(MODEL_FILES[input_model], map_location='cpu') + model.load_state_dict(checkpoint['model']) + print("Resume checkpoint %s" % MODEL_FILES[input_model]) + + model.eval() + x = torch.randn(1, 3, 224, 224, requires_grad=True) + torch.onnx.export(model, + x, + output_model, + export_params=True, + opset_version=13, + do_constant_folding=True, + input_names = ["image"], + output_names = ["output"], + dynamic_axes={"image" : {0 : "batch_size"}, + "output" : {0 : "batch_size"}} + ) + assert os.path.exists(output_model), f"Export failed! {output_model} doesn't exist!" 
+ + +def prepare_model(input_model, output_model): + is_download_successful = download_model(args.input_model, MAX_TIMES_RETRY_DOWNLOAD) + if is_download_successful: + export_model(input_model, output_model) + + +if __name__ == "__main__": + args = parse_arguments() + prepare_model(args.input_model, args.output_model) \ No newline at end of file diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/requirements.txt b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/requirements.txt new file mode 100644 index 00000000000..b6855b21b0a --- /dev/null +++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/requirements.txt @@ -0,0 +1,6 @@ +torch +torchvision +timm +onnx +onnxruntime +onnxruntime-extensions; python_version < '3.11' \ No newline at end of file diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_benchmark.sh b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_benchmark.sh new file mode 100644 index 00000000000..41c190229d1 --- /dev/null +++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_benchmark.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -x + +function main { + init_params "$@" + run_benchmark + +} + +# init params +function init_params { + for var in "$@" + do + case $var in + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --dataset_location=*) + dataset_location=$(echo $var |cut -f2 -d=) + ;; + --mode=*) + mode=$(echo $var |cut -f2 -d=) + ;; + --batch_size=*) + batch_size=$(echo $var |cut -f2 -d=) + ;; + esac + done + +} + +# run_benchmark +function run_benchmark { + + python main.py \ + --model_path ${input_model} \ + --dataset_location ${dataset_location} \ + --mode ${mode} \ + --batch_size ${batch_size-1} \ + --benchmark + +} + +main "$@" diff --git a/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_quant.sh b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_quant.sh new file mode 100644 index 00000000000..7f9d10fa0e7 --- /dev/null +++ b/examples/onnxrt/image_recognition/beit/quantization/ptq_static/run_quant.sh @@ -0,0 +1,43 @@ +#!/bin/bash +set -x + +function main { + init_params "$@" + run_tuning + +} + +# init params +function init_params { + + for var in "$@" + do + case $var in + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --output_model=*) + output_model=$(echo $var |cut -f2 -d=) + ;; + --dataset_location=*) + dataset_location=$(echo $var |cut -f2 -d=) + ;; + --quant_format=*) + quant_format=$(echo $var |cut -f2 -d=) + ;; + esac + done + +} + +# run_tuning +function run_tuning { + python main.py \ + --model_path ${input_model} \ + --dataset_location ${dataset_location} \ + --output_model ${output_model} \ + --quant_format ${quant_format-default} \ + --tune +} + +main "$@" diff --git a/examples/onnxrt/nlp/huggingface_model/language_modeling/quantization/ptq_dynamic/main.py b/examples/onnxrt/nlp/huggingface_model/language_modeling/quantization/ptq_dynamic/main.py index 345f6fac183..5182ec6bbc5 100644 --- a/examples/onnxrt/nlp/huggingface_model/language_modeling/quantization/ptq_dynamic/main.py +++ b/examples/onnxrt/nlp/huggingface_model/language_modeling/quantization/ptq_dynamic/main.py @@ -197,8 +197,7 @@ def main(): tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, use_fast=True, - cache_dir=args.cache_dir if args.cache_dir else None, - use_auth_token='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') + cache_dir=args.cache_dir if args.cache_dir else None) if args.block_size <= 0: 
args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model args.block_size = min(args.block_size, tokenizer.max_len_single_sentence) diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py index 149a030af09..a13a99d483b 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py @@ -58,6 +58,13 @@ def __init__(self, gptq_dataloader): def __iter__(self): pass +def filter_chatglmv1(seq): + bos_token_id = 130004 + # eos_token_id = 130005 + gmask_token_id = 130001 + # return (bos_token_id in seq and eos_token_id in seq and mask_token_id in seq and gmask_token_id in seq) + return (len(seq) < 2048 and bos_token_id in seq and gmask_token_id in seq) + # INC original dataloader example class Evaluator: def __init__(self, dataset, tokenizer, batch_size=8, pad_val=1, pad_max=196, is_calib=False): @@ -217,9 +224,9 @@ def skip(*args, **kwargs): # model if re.search("chatglm", args.model_name_or_path.lower()): # chatglm requires a different way to be loaded tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) - model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True) + model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True).float().cpu() else: - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, low_cpu_mem_usage=True, trust_remote_code=True) model = model.eval() @@ -227,12 +234,20 @@ def skip(*args, **kwargs): # calib_dataset = datasets.load_from_disk('/your/local/pile-10k/') # use this if trouble with connecting to HF calib_dataset = calib_dataset.shuffle(seed=args.seed) calib_evaluator = Evaluator(calib_dataset, tokenizer, args.calib_size, is_calib=True) - calib_dataloader = DataLoader( - calib_evaluator.dataset, - batch_size=args.calib_size, - shuffle=False, - collate_fn=calib_evaluator.collate_batch, - ) + if hasattr(model.config, "_name_or_path") and "chatglm-6b" in model.config._name_or_path: + calib_dataloader = DataLoader( + calib_evaluator.dataset.filter(lambda example: filter_chatglmv1(example['input_ids'])), + batch_size=args.calib_size, + shuffle=False, + collate_fn=calib_evaluator.collate_batch, + ) + else: + calib_dataloader = DataLoader( + calib_evaluator.dataset, + batch_size=args.calib_size, + shuffle=False, + collate_fn=calib_evaluator.collate_batch, + ) if args.gpu and torch.cuda.is_available(): DEV = torch.device('cuda:0') @@ -294,7 +309,8 @@ def skip(*args, **kwargs): dataloader=calib_dataloader, nsamples = args.nsamples, use_max_length = args.use_max_length, - pad_max_length = args.pad_max_length + pad_max_length = args.pad_max_length, + device = DEV, ) results = lm_evaluate( diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 571346b515d..04eac28aab9 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -2739,14 +2739,20 @@ def _ipex_post_quant_process(self, model, q_model, dataloader, inplace=False): q_model._model = 
ipex.quantization.convert(model._model, inplace=inplace) try: if isinstance(self.example_inputs, dict): - q_model._model = torch.jit.trace(q_model._model, example_kwarg_inputs=self.example_inputs) + q_model._model = torch.jit.trace( + q_model._model, + example_kwarg_inputs=self.example_inputs, + ) else: q_model._model = torch.jit.trace(q_model._model, self.example_inputs) q_model._model = torch.jit.freeze(q_model._model.eval()) except: if isinstance(self.example_inputs, dict): q_model._model = torch.jit.trace( - q_model._model, example_kwarg_inputs=self.example_inputs, strict=False + q_model._model, + example_kwarg_inputs=self.example_inputs, + strict=False, + check_trace=False, ) else: q_model._model = torch.jit.trace(q_model._model, self.example_inputs, strict=False) @@ -2763,7 +2769,7 @@ def _ipex_post_quant_process(self, model, q_model, dataloader, inplace=False): except: if isinstance(self.example_inputs, dict): q_model._model = torch.jit.trace( - q_model._model, example_kwarg_inputs=self.example_inputs, strict=False + q_model._model, example_kwarg_inputs=self.example_inputs, strict=False, check_trace=False ) else: q_model._model = torch.jit.trace(q_model._model, self.example_inputs, strict=False) @@ -3475,13 +3481,13 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None): ): from .torch_utils.layer_wise_quant import LayerWiseQuant - model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + model_path = model._model.path smooth_quant = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant", False) alpha = recipe_cfgs["layer_wise_quant_args"].get("smooth_quant_alpha", 0.5) - assert ( - model_path is not None - ), "the layer_wise_quant_args should have args model_path to load the weight of model." - device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu") + # device = recipe_cfgs["layer_wise_quant_args"].get("decvice", "cpu") + assert model_path is not None, "The model_path should not be None." + device = self.device lw_quant = LayerWiseQuant( q_model._model, model_path, @@ -4514,14 +4520,12 @@ def rtn_quantize(self, model, tune_cfg): # for layer_wise quant mode recipe_cfgs = tune_cfg.get("recipe_cfgs", None) if recipe_cfgs.get("layer_wise_quant", False): - from neural_compressor.config import options - - from .torch_utils.layer_wise_quant.utils import _get_path, load_module + from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, load_module - lwq_workspace = os.path.join(options.workspace, "lwq_tmpdir") - os.makedirs(lwq_workspace, exist_ok=True) - model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) - assert model_path, "model_path should specify in layer_wise_quant_args." + os.makedirs(LWQ_WORKSPACE, exist_ok=True) + # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + model_path = model.path + assert model_path, "model_path should not be None." model_path = _get_path(model_path) for key, config in tune_cfg["op"].items(): @@ -4557,7 +4561,7 @@ def rtn_quantize(self, model, tune_cfg): # save and clean weight from .torch_utils.layer_wise_quant.utils import clean_module_weight - torch.save(m.state_dict(), os.path.join(lwq_workspace, f"{op_name}.pt")) + torch.save(m.state_dict(), os.path.join(LWQ_WORKSPACE, f"{op_name}.pt")) clean_module_weight(m) set_module(model, op_name, m) if recipe_cfgs.get("layer_wise_quant", False): @@ -4592,6 +4596,23 @@ def gptq_quantize(self, model, tune_cfg, dataloader): ... 
} """ + # for layer_wise quant mode + recipe_cfgs = tune_cfg.get("recipe_cfgs", None) + model_path = None + layer_wise = False + if recipe_cfgs.get("layer_wise_quant", False): + layer_wise = True + from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, register_weight_hooks + + os.makedirs(LWQ_WORKSPACE, exist_ok=True) + # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + model_path = model.path + assert model_path, "model_path should not be None." + model_path = _get_path(model_path) + lwq_handles = register_weight_hooks( + model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE + ) + weight_config = {} for key, config in tune_cfg["op"].items(): op_name, op_type = key @@ -4616,7 +4637,15 @@ def gptq_quantize(self, model, tune_cfg, dataloader): ) # tune_cfg => weight_config model, quantization_perm = gptq_quantize( - model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, self.device + model, + weight_config, + dataloader, + nsamples, + use_max_length, + pad_max_length, + self.device, + layer_wise, + model_path, ) return model, quantization_perm diff --git a/neural_compressor/adaptor/torch_utils/gptq.py b/neural_compressor/adaptor/torch_utils/gptq.py index b2710558e57..1a33addb364 100644 --- a/neural_compressor/adaptor/torch_utils/gptq.py +++ b/neural_compressor/adaptor/torch_utils/gptq.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import math import random import re @@ -72,7 +73,7 @@ def is_leaf(module): return True if children_cnt == 0 else False -def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList]): +def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList, torch.nn.Sequential]): """Search transformer stacked structures, which is critical in LLMs and GPTQ execution. 
Args: @@ -88,21 +89,41 @@ def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList]): "transformers": {}, Dict# TODO } """ - gptq_related_blocks = { - "embeddings": {}, - "transformers_pre": {}, # todo - "transformers_name": "", # None - "transformers": [], # None - "transformers_post": {}, # todo - } - for n, m in module.named_modules(): - if type(m) in module_types: - gptq_related_blocks["transformers_name"] = n - gptq_related_blocks["transformers"] = m - return gptq_related_blocks - else: + if type(module).__name__ == "MixFormerSequentialForCausalLM": # pragma: no cover + gptq_related_blocks = { + "embeddings": {}, + "transformers_pre": {}, # todo + "transformers_name": "", # None + "transformers": [], # None + "transformers_post": {}, # todo + } + for n, m in module.named_modules(): + if type(m) in module_types: + gptq_related_blocks["transformers_name"] = n + gptq_related_blocks["transformers"] = m + break + else: + continue + for n, m in gptq_related_blocks["transformers"][0].named_modules(): if is_leaf(m): gptq_related_blocks["embeddings"][n] = m + gptq_related_blocks["transformers"] = gptq_related_blocks["transformers"][1:-1] + else: + gptq_related_blocks = { + "embeddings": {}, + "transformers_pre": {}, # todo + "transformers_name": "", # None + "transformers": [], # None + "transformers_post": {}, # todo + } + for n, m in module.named_modules(): + if type(m) in module_types: + gptq_related_blocks["transformers_name"] = n + gptq_related_blocks["transformers"] = m + return gptq_related_blocks + else: + if is_leaf(m): + gptq_related_blocks["embeddings"][n] = m return gptq_related_blocks @@ -175,6 +196,7 @@ def __init__( use_max_length=True, pad_max_length=2048, device=None, + layer_wise=False, ): """ Args: @@ -196,7 +218,7 @@ def __init__( """ # model self.model = model - self.use_cache = self.model.config.use_cache + # self.use_cache = self.model.config.use_cache self.gptq_related_blocks = trace_gptq_target_blocks(self.model) # get the transformer block list above self.dtype = next(iter(self.model.parameters())).dtype log_quantizable_layers_per_transformer(self.gptq_related_blocks) @@ -215,9 +237,13 @@ def __init__( self.check_layer_config() # device - self.device = model.device + self.device = device + if str(self.model.device).startswith("cuda"): + self.device = self.model.device self.is_ready = False + self.layer_wise = layer_wise + # dataloader self.use_max_length = use_max_length self.pad_max_length = pad_max_length @@ -411,6 +437,12 @@ def get_layer_config(self, layer_name): pass return config + def track_hidden_states(self, data): + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, tuple) or isinstance(data, list): + return data[0] + @torch.no_grad() def pre_quantization(self): """Prepare input calibration data and other attributes which are critical for gptq execution.""" @@ -438,11 +470,13 @@ def forward(layer, *args, **kwargs): raise ValueError # Step1: fetch the embeddings and other layers before the transformer stack. 
- for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): - embedding_layer = embedding_layer.to(self.device) + if not self.layer_wise: + for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): + embedding_layer = embedding_layer.to(self.device) # Step2: modify the first transformer block's forward function to obtain inputs for calibration - self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device) + if not self.layer_wise: + self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device) forward_cache = self.gptq_related_blocks["transformers"][0].forward self.gptq_related_blocks["transformers"][0].forward = partial( forward, self.gptq_related_blocks["transformers"][0] @@ -451,7 +485,8 @@ def forward(layer, *args, **kwargs): # Step3: run forward to obtain calibration datasets logger.info("Collecting calibration inputs...") for batch in tqdm(self.dataloader): - batch = move_input_to_device(batch, self.device) + if not self.layer_wise: + batch = move_input_to_device(batch, self.device) try: if isinstance(batch, tuple) or isinstance(batch, list): self.model(batch[0]) @@ -473,9 +508,10 @@ def forward(layer, *args, **kwargs): # Step 4: restore original forward function, relocate layers back to cpu. self.gptq_related_blocks["transformers"][0].forward = forward_cache - self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu() - for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): - embedding_layer.to(self.device) + if not self.layer_wise: + self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu() + for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): + embedding_layer.to(self.device) torch.cuda.empty_cache() # end logger.info("GPTQ quantization prepared.") @@ -501,7 +537,7 @@ def update_blockwise_hidden_states(self, outs): self.cache_positional_arguments[0] = outs[:] @torch.no_grad() - def execute_quantization(self, means=None, stds=None): + def execute_quantization(self, means=None, stds=None, model_path=None): """Run quantization.""" # Step1: prepare quantization (calibration datasets) @@ -513,7 +549,11 @@ def execute_quantization(self, means=None, stds=None): tblock_length = len(self.gptq_related_blocks["transformers"]) for block_idx in range(tblock_length): logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..") - transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device) + if not self.layer_wise: + # if we do not apply layer-wise feature, we still place the entire block on the GPU + transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device) + else: + transformer_block = self.gptq_related_blocks["transformers"][block_idx] # .to(self.device) # Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized. 
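# In layer-wise mode the block is likewise left in place (its tensors are still empty), so each
# quantizable sub-layer's weight is read from the checkpoint with load_value(...) and handed to
# GPTQ explicitly. After fasterquant returns the quantized tensor Q, it is written back with
# set_module_tensor_to_device, the layer's state_dict is dumped under LWQ_WORKSPACE, and
# clean_module_weight releases the materialized weight before moving on to the next layer.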
sub_layers = find_layers(transformer_block) sub_layers_to_quant = {} @@ -534,8 +574,16 @@ def execute_quantization(self, means=None, stds=None): # weight_config_this_layer = self.weight_config.get( # self.get_full_layer_name(layer_name, block_idx), None # ) - weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) - gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name]) + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + weight_config_this_layer = self.get_layer_config(full_layer_name) + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import load_value + + W = load_value(self.model, full_layer_name + ".weight", model_path) + else: + W = sub_layers[layer_name].weight.data.clone() + + gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device) # gptq_for_this_block[layer_name].quantizer = Quantizer() gptq_for_this_block[layer_name].quantizer.configure( weight_config_this_layer["wbits"], @@ -555,11 +603,11 @@ def tmp(_, inp, out): for layer_name in sub_layers: handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name))) idx = self.cache_key_arguments.pop("i") - # import pdb;pdb.set_trace() for j in range(len(self.dataloader)): cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) - out = transformer_block(*cache_positional_batch, **cache_keyword_batch)[0] + out = transformer_block(*cache_positional_batch, **cache_keyword_batch) + out = self.track_hidden_states(out) self.cache_key_arguments["i"] = idx for h in handles: h.remove() @@ -570,12 +618,44 @@ def tmp(_, inp, out): # ) weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) logger.info(f"Quantizing layer {layer_name}") - scale, zp = gptq_for_this_block[layer_name].fasterquant( + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import load_value + + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + W = load_value(self.model, full_layer_name + ".weight", model_path) + else: + W = sub_layers[layer_name].weight.data.clone() + scale, zp, Q = gptq_for_this_block[layer_name].fasterquant( + W, blocksize=weight_config_this_layer["block_size"], percdamp=weight_config_this_layer["percdamp"], groupsize=weight_config_this_layer["group_size"], act_order=weight_config_this_layer["act_order"], ) + if self.layer_wise: + from ..torch_utils.layer_wise_quant.utils import ( + LWQ_WORKSPACE, + clean_module_weight, + load_value, + set_module_tensor_to_device, + ) + + sub_layer = sub_layers[layer_name] + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + for n, p in sub_layer.named_parameters(): + param_name = full_layer_name + "." 
+ n + if n == "weight": + set_module_tensor_to_device(self.model, param_name, self.device, Q) + else: + value = load_value(self.model, param_name, model_path) + set_module_tensor_to_device(self.model, param_name, self.device, value) + # sub_layer.weight.data = Q + torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt") + clean_module_weight(sub_layer) + del Q + gc.collect() + else: + sub_layers[layer_name].weight.data = Q gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale} if not weight_config_this_layer["sym"]: gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp @@ -591,10 +671,14 @@ def tmp(_, inp, out): for j in range(len(self.dataloader)): cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) - out = transformer_block(*cache_positional_batch, **cache_keyword_batch)[0] + out = transformer_block(*cache_positional_batch, **cache_keyword_batch) + out = self.track_hidden_states(out) outs.append(out) self.cache_key_arguments["i"] = idx - self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() + if self.layer_wise: + self.gptq_related_blocks["transformers"][block_idx] = transformer_block + else: + self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() del gptq_for_this_block torch.cuda.empty_cache() # iteratively replace the input with output, thus layerwise quantization can continue. @@ -602,7 +686,7 @@ def tmp(_, inp, out): logger.info("------------------------------") logger.info("Quantization done") - self.model.config.use_cache = self.use_cache + # self.model.config.use_cache = self.use_cache # obtain model (all weight only quantization API function should return) for k, v in gptq_config.items(): @@ -617,10 +701,10 @@ class GPTQ: GPTQ: Accurate Post-training Compression for Generative Pretrained Transformers (https://arxiv.org/abs/2210.17323) """ - def __init__(self, layer): + def __init__(self, layer, W, device="cpu"): self.layer = layer - self.device = self.layer.weight.device - W = layer.weight.data.clone() + self.device = device + # W = layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d): W = W.flatten(1) if isinstance(self.layer, transformers.Conv1D): @@ -661,8 +745,9 @@ def add_batch(self, inp, out): # self.H += 2 / self.nsamples * inp.matmul(inp.t()) self.H += inp.matmul(inp.t()) # H = X*X, which should be a sysm matrix - def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False): - W = self.layer.weight.data.clone() + def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False): + # W = self.layer.weight.data.clone() + weight_shape, weight_dtype = W.shape, W.data.dtype if isinstance(self.layer, nn.Conv2d): W = W.flatten(1) if isinstance(self.layer, transformers.Conv1D): @@ -740,7 +825,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals # logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}") # logger.info(f"{torch.sum(Losses)}") - if self.device != torch.device("cpu"): + if str(self.device).startswith("cuda"): torch.cuda.synchronize() logger.info(f"time {(time.time() - tick)}") logger.info(f"error {torch.sum(Losses).item()}") @@ -751,7 +836,8 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals if isinstance(self.layer, transformers.Conv1D): Q = Q.t() - 
self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + # self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + Q = Q.reshape(weight_shape).to(weight_dtype) if DEBUG: logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}") @@ -760,7 +846,7 @@ def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=Fals zero.append(self.quantizer.zero) scale = torch.cat(scale, dim=1) zero = torch.cat(zero, dim=1) - return scale, zero + return scale, zero, Q def free(self): if DEBUG: diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py index 6f01b1288bf..f347da31bb9 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/__init__.py @@ -15,5 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. """Torch layer-wise quantization module.""" -from .utils import load_shell +from .utils import load_empty_model from .quantize import LayerWiseQuant diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py index 1746ad82140..9e3f8789dad 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py @@ -40,7 +40,7 @@ update_module, ) -TMP_DIR = os.path.join(default_workspace, "layer_wise_quant_tmp_dir") +TMP_DIR = os.path.join(default_workspace, "lwq_tmpdir") def mk_tmp_dir(): @@ -92,7 +92,7 @@ def __init__( alpha=0.5, ): """Init LayerWiseQuant.""" - # self.q_model = load_shell(pretrained_model_name_or_path, cls) + # self.q_model = load_empty_model(pretrained_model_name_or_path, cls) self.q_model = q_model self.fp32_model = deepcopy(self.q_model) self.path = _get_path(pretrained_model_name_or_path) diff --git a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py index e932c40480a..8bd3d32d320 100644 --- a/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py +++ b/neural_compressor/adaptor/torch_utils/layer_wise_quant/utils.py @@ -25,7 +25,7 @@ torch = LazyImport("torch") from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device -from transformers import AutoConfig +from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.auto.auto_factory import _BaseAutoModelClass from ....config import options @@ -107,7 +107,7 @@ def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None): return file_path -def load_shell(pretrained_model_name_or_path, cls, **kwargs): +def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs): """Load a empty model.""" is_local = os.path.isdir(pretrained_model_name_or_path) if is_local: # pragma: no cover @@ -124,6 +124,7 @@ def load_shell(pretrained_model_name_or_path, cls, **kwargs): model = cls(config) model.tie_weights() model.eval() + model.path = pretrained_model_name_or_path return model @@ -223,7 +224,10 @@ def load_module(model, module_name, path, device="cpu"): set_module_tensor_to_device(model, param_name, device, value) -def register_weight_hooks(model, path, device="cpu", clean_weight=True): +def register_weight_hooks(model, path, device="cpu", clean_weight=True, 
saved_path=None): + if saved_path: + os.makedirs(saved_path, exist_ok=True) + def forward_pre_hook(name): def hook(module, input): state_dict = None @@ -241,6 +245,9 @@ def hook(module, input): def forward_hook(name): def hook(module, input, output): + if saved_path: + file_path = os.path.join(saved_path, f"{name}.pt") + torch.save(module.state_dict(), file_path) clean_module_weight(module) return hook diff --git a/neural_compressor/adaptor/torch_utils/mixed_precision.py b/neural_compressor/adaptor/torch_utils/mixed_precision.py index dce5fe54b3c..6c461ea195b 100644 --- a/neural_compressor/adaptor/torch_utils/mixed_precision.py +++ b/neural_compressor/adaptor/torch_utils/mixed_precision.py @@ -38,7 +38,9 @@ def ipex_mixed_precision(model, example_inputs=None, device="cpu"): try: mp_model = torch.jit.trace(mp_model, example_kwarg_inputs=example_inputs) except: - mp_model = torch.jit.trace(mp_model, example_kwarg_inputs=example_inputs, strict=False) + mp_model = torch.jit.trace( + mp_model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False + ) else: try: mp_model = torch.jit.trace(mp_model, example_inputs) diff --git a/neural_compressor/adaptor/torch_utils/smooth_quant.py b/neural_compressor/adaptor/torch_utils/smooth_quant.py index 65804468ec4..7b59b3ce3e5 100644 --- a/neural_compressor/adaptor/torch_utils/smooth_quant.py +++ b/neural_compressor/adaptor/torch_utils/smooth_quant.py @@ -1186,7 +1186,9 @@ def trace(self, model, dummy_input): dummy_input = move_input_to_device(dummy_input, "cpu") if isinstance(dummy_input, dict) or isinstance(dummy_input, UserDict): try: - traced_model = torch.jit.trace(model, example_kwarg_inputs=dict(dummy_input), strict=False) + traced_model = torch.jit.trace( + model, example_kwarg_inputs=dict(dummy_input), strict=False, check_trace=False + ) traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics) except Exception as e: logger.warning(e) diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index b9376a72c7f..7ba86eaa344 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -470,15 +470,27 @@ def rtn_quantize( def gptq_quantize( - model, weight_config={}, dataloader=None, nsamples=128, use_max_length=True, pad_max_length=2048, device=None + model, + weight_config={}, + dataloader=None, + nsamples=128, + use_max_length=True, + pad_max_length=2048, + device=None, + layer_wise=False, + model_path=None, ): """Run weight-only quantization with.""" # TODO: unify weight_config keys, add docstring, and support default config assert isinstance(model, torch.nn.Module), "only support torch module" + if layer_wise: + assert model_path is not None, "model_path should not be None when use layer_wise mode" from .gptq import GPTQuantizer - gptq_quantizer = GPTQuantizer(model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device) - fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization() + gptq_quantizer = GPTQuantizer( + model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device, layer_wise=layer_wise + ) + fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization(model_path=model_path) logger.info("GPTQ quantizing done.") return fp32_modified_model, gptq_config diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py index 95f273eeab5..eeada402f35 100644 --- 
a/neural_compressor/model/torch_model.py +++ b/neural_compressor/model/torch_model.py @@ -356,7 +356,8 @@ def save(self, root=None): if os.path.exists(os.path.join(LWQ_WORKSPACE, f"{name}.pt")): state_dict = torch.load(os.path.join(LWQ_WORKSPACE, f"{name}.pt")) model_path = _get_path( - self.q_config["recipe_cfgs"]["layer_wise_quant_args"].get("model_path") + # self.q_config["recipe_cfgs"]["layer_wise_quant_args"].get("model_path") + self._model.path ) for n, p in module.named_parameters(): param_name = name + "." + n diff --git a/test/algorithm/test_layer_wise_quant.py b/test/algorithm/test_layer_wise_quant.py index e25036bb021..2eef1e89bd9 100644 --- a/test/algorithm/test_layer_wise_quant.py +++ b/test/algorithm/test_layer_wise_quant.py @@ -8,14 +8,14 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from neural_compressor import PostTrainingQuantConfig, quantization -from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell +from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model from neural_compressor.utils.pytorch import load class TestLayerWise(unittest.TestCase): def test_layer_wise(self): model_name_or_path = "facebook/opt-125m" - fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) + fp32_model = load_empty_model(model_name_or_path, torchscript=True) class TestDataset(Dataset): def __init__(self, size=5, shape=128): @@ -65,7 +65,7 @@ def test_util(self): ) model_name_or_path = "facebook/opt-125m" - model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) + model = load_empty_model(model_name_or_path, torchscript=True) children = get_children(model) named_children = get_named_children(model) self.assertEqual(children, [v for k, v in named_children]) diff --git a/test/algorithm/test_lwq_weight_only.py b/test/algorithm/test_lwq_weight_only.py index c8840123382..3735f0b2581 100644 --- a/test/algorithm/test_lwq_weight_only.py +++ b/test/algorithm/test_lwq_weight_only.py @@ -1,21 +1,22 @@ import shutil import sys import unittest +from copy import deepcopy sys.path.insert(0, "./") import torch from torch.utils.data import DataLoader, Dataset -from transformers import AutoModelForCausalLM, AutoTokenizer from neural_compressor import PostTrainingQuantConfig, quantization -from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_shell +from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model from neural_compressor.utils.pytorch import load class TestLayerWise(unittest.TestCase): - def test_layer_wise(self): - model_name_or_path = "facebook/opt-125m" - fp32_model = load_shell(model_name_or_path, AutoModelForCausalLM, torchscript=True) + @classmethod + def setUpClass(self): + self.model_name_or_path = "facebook/opt-125m" + self.fp32_model = load_empty_model(self.model_name_or_path, torchscript=True) class TestDataset(Dataset): def __init__(self, size=5, shape=128): @@ -29,30 +30,58 @@ def __len__(self): return self.len eval_dataset = TestDataset() - eval_dataloader = DataLoader(eval_dataset, batch_size=8) + self.eval_dataloader = DataLoader(eval_dataset, batch_size=8) + @classmethod + def tearDownClass(cls): + shutil.rmtree("./saved_model", ignore_errors=True) + + def test_rtn_lwq(self): conf = PostTrainingQuantConfig( approach="weight_only", recipes={ "layer_wise_quant": True, - "layer_wise_quant_args": { - "model_path": "facebook/opt-125m", - }, + # "layer_wise_quant_args": { + # "model_path": "facebook/opt-125m", + # }, "rtn_args": 
{"enable_full_range": True}, }, ) q_model = quantization.fit( - fp32_model, + deepcopy(self.fp32_model), conf, - calib_dataloader=eval_dataloader, + calib_dataloader=self.eval_dataloader, eval_func=lambda x: 0.1, ) ouput_dir = "./saved_model" q_model.save(ouput_dir) - load_model = load(ouput_dir, fp32_model, weight_only=True) + load_model = load(ouput_dir, deepcopy(self.fp32_model), weight_only=True) + self.assertNotEqual(load_model.lm_head.weight.device.type, "meta") + + def test_gptq_lwq(self): + conf = PostTrainingQuantConfig( + approach="weight_only", + op_type_dict={ + ".*": { # re.match + "weight": { + "bits": 4, # 1-8 bits + "group_size": 32, + "scheme": "sym", + "algorithm": "GPTQ", + }, + }, + }, + recipes={ + "gptq_args": {"actorder": True, "mse": True, "perchannel": False}, + "layer_wise_quant": True, + }, + ) + q_model = quantization.fit(deepcopy(self.fp32_model), conf, calib_dataloader=self.eval_dataloader) + ouput_dir = "./saved_model" + q_model.save(ouput_dir) + load_model = load(ouput_dir, deepcopy(self.fp32_model), weight_only=True, layer_wise=True) self.assertNotEqual(load_model.lm_head.weight.device.type, "meta") - shutil.rmtree(ouput_dir) if __name__ == "__main__":
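Taken together, these changes let GPTQ weight-only quantization run layer by layer on a model that is never fully loaded into memory. The sketch below mirrors the new test_gptq_lwq test added above; the calibration dataset body is an illustrative stand-in (the diff elides the test dataset's __init__), and facebook/opt-125m is simply the checkpoint the tests use.

from copy import deepcopy

import torch
from torch.utils.data import DataLoader, Dataset

from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model
from neural_compressor.utils.pytorch import load


class ToyDataset(Dataset):
    """Random token ids, only meant to drive calibration in this sketch."""

    def __init__(self, size=5, shape=128):
        self.len = size
        # small illustrative id range; any ids below the model's vocab size work
        self.input_ids = torch.randint(low=0, high=2000, size=(size, shape), dtype=torch.int64)

    def __getitem__(self, index):
        return self.input_ids[index]

    def __len__(self):
        return self.len


# Build an empty (meta-device) model; weights stay on disk and are streamed in per layer.
fp32_model = load_empty_model("facebook/opt-125m", torchscript=True)
dataloader = DataLoader(ToyDataset(), batch_size=8)

conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {
            "weight": {"bits": 4, "group_size": 32, "scheme": "sym", "algorithm": "GPTQ"},
        },
    },
    recipes={
        "layer_wise_quant": True,  # model_path is now taken from fp32_model.path
        "gptq_args": {"actorder": True, "mse": True, "perchannel": False},
    },
)

q_model = quantization.fit(deepcopy(fp32_model), conf, calib_dataloader=dataloader)
q_model.save("./saved_model")

# Reloading also needs layer_wise=True so the per-layer state dicts are resolved.
loaded_model = load("./saved_model", deepcopy(fp32_model), weight_only=True, layer_wise=True)

Passing deepcopy(fp32_model) keeps a pristine empty model available for the final load(...) call, matching how the updated tests reuse a single class-level fixture.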