Add layer wise quantization doc and ONNXRT example #1434

Merged (25 commits) on Dec 6, 2023

Changes from 17 commits

Commits
- abb6312 update onnxrt woq example (yuwenzho, Nov 28, 2023)
- 0655119 add layer-wise quantization example (yuwenzho, Nov 30, 2023)
- d4e39e8 Merge branch 'master' into yuwenzho/woq_example_doc (yuwenzho, Nov 30, 2023)
- ca67739 fix docstring (yuwenzho, Nov 30, 2023)
- 54179d0 update README.md (yuwenzho, Nov 30, 2023)
- ea2078e add layer-wise quantization doc (yuwenzho, Nov 30, 2023)
- 1acdc9c update onnxrt lwq figure (yuwenzho, Nov 30, 2023)
- a6a656b [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 30, 2023)
- 7035187 update quantization_layer_wise.md (yuwenzho, Nov 30, 2023)
- 66be652 update ox_utils (yuwenzho, Nov 30, 2023)
- 785f6a6 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 30, 2023)
- ae13425 Merge branch 'master' into yuwenzho/lwq_example_doc (yuwenzho, Dec 1, 2023)
- de59987 fix import bug (yuwenzho, Dec 1, 2023)
- 595007b [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 1, 2023)
- 966aa9b Update run_quant.sh (yuwenzho, Dec 1, 2023)
- 1f10a92 fix import bug (yuwenzho, Dec 1, 2023)
- 472ebbe [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 1, 2023)
- 23ce23b update requirement (yuwenzho, Dec 5, 2023)
- da05ba4 update README and doc (yuwenzho, Dec 5, 2023)
- e3116d9 update onnx_model.py (yuwenzho, Dec 5, 2023)
- cf126ce [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 5, 2023)
- 0ccb151 Merge branch 'master' into yuwenzho/lwq_example_doc (yuwenzho, Dec 5, 2023)
- 449114f update onnx_model.py (yuwenzho, Dec 5, 2023)
- ed3f844 update main.py (yuwenzho, Dec 6, 2023)
- 95ed5fd update main.py (yuwenzho, Dec 6, 2023)
Binary file added docs/source/imgs/lwq_ort.png
98 changes: 98 additions & 0 deletions docs/source/quantization_layer_wise.md
@@ -0,0 +1,98 @@
Layer Wise Quantization (LWQ)
=====

1. [Introduction](#introduction)

2. [Supported Framework Model Matrix](#supported-framework-model-matrix)

3. [Examples](#examples)

## Introduction

Large language models (LLMs) have shown exceptional performance across various tasks; however, their substantial parameter size poses significant challenges for deployment. Layer-wise quantization (LWQ) can greatly reduce the memory footprint required to quantize LLMs, usually by 80-90%, which means that users can quantize LLMs even on a single node with a GPU or CPU. Because the model is quantized layer by layer under memory-constrained conditions, quantizing huge LLMs becomes possible.

<img src="./imgs/lwq.png" width=780 height=429>

*Figure 1: The process of layer-wise quantization for a PyTorch model. Grey represents empty parameters, blue represents parameters that need to be quantized, and every rectangle inside the model represents one layer.*

<img src="./imgs/lwq_ort.png" width=900 height=400>

*Figure 2: The process of layer-wise quantization for an ONNX model. The graph of the LLM is split into several parts, and each subgraph is quantized in turn.*
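
The sketch below illustrates the core idea in plain PyTorch. It is a toy illustration only, not the Neural Compressor implementation: `rtn_quantize` and `state_dict_loader` are hypothetical stand-ins for round-to-nearest weight quantization and per-layer weight loading.

```python
import torch


def rtn_quantize(weight: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Toy round-to-nearest fake-quantization of one weight tensor."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = weight.abs().amax() / qmax
    return (weight / scale).round().clamp(-qmax - 1, qmax) * scale


def layer_wise_quantize(state_dict_loader, layer_names):
    """Quantize a model one layer at a time.

    `state_dict_loader(name)` is a hypothetical callback that reads only one
    layer's weight from disk, which is what an "empty" (shell) model plus
    per-layer checkpoint loading provides in practice.
    """
    quantized = {}
    for name in layer_names:
        weight = state_dict_loader(name)        # materialize only this layer
        quantized[name] = rtn_quantize(weight)  # quantize it
        del weight                              # release the fp32 copy before the next layer
    return quantized


if __name__ == "__main__":
    # Toy stand-in for per-layer checkpoint reads.
    layers = {f"layer{i}.weight": torch.randn(1024, 1024) for i in range(3)}
    quantized = layer_wise_quantize(lambda name: layers[name], list(layers))
    print({name: tuple(w.shape) for name, w in quantized.items()})
```

Because only one layer's parameters are resident at any time, peak memory during quantization stays close to the footprint of a single layer rather than the whole model.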

## Supported Framework Model Matrix


<table class="tg">
<thead>
<tr>
<th colspan="2" style="text-align:center;vertical-align:middle">Types/Framework</th>
<th style="text-align:center;vertical-align:middle">PyTorch</th>
<th style="text-align:center;vertical-align:middle">ONNX Runtime</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align:center;vertical-align:middle" colspan="2">W8A8 Post Training Static Quantization</td>
<td style="text-align:center;vertical-align:middle">&#10004;</td>
<td style="text-align:center;vertical-align:middle">&#10004;</td>
</tr>
<tr>
<td style="text-align:center;vertical-align:middle" rowspan="4">Weight-only Quantization</td>
<td style="text-align:center;vertical-align:middle">RTN</td>
<td style="text-align:center;vertical-align:middle">&#10004;</td>
<td style="text-align:center;vertical-align:middle" rowspan="4">&#10005;</td>
</tr>
<tr>
<td style="text-align:center;vertical-align:middle">AWQ</td>
<td style="text-align:center;vertical-align:middle">&#10005;</td>
</tr>
<tr>
<td style="text-align:center;vertical-align:middle">GPTQ</td>
<td style="text-align:center;vertical-align:middle">&#10004;</td>
</tr>
<tr>
<td style="text-align:center;vertical-align:middle">TEQ</td>
<td style="text-align:center;vertical-align:middle">&#10005;</td>
</tr>
</tbody>
</table>

## Examples

#### PyTorch framework example

```python
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model

# `load` is assumed here to be the standard Neural Compressor PyTorch load helper
from neural_compressor.utils.pytorch import load

fp32_model = load_empty_model(model_name_or_path, torchscript=True)
conf = PostTrainingQuantConfig(
    approach="weight_only",
    recipes={
        "layer_wise_quant": True,
        "rtn_args": {"enable_full_range": True},
    },
)

q_model = quantization.fit(
    fp32_model,
    conf,
    calib_dataloader=eval_dataloader,
    eval_func=lambda x: 0.1,
)
output_dir = "./saved_model"
q_model.save(output_dir)
q_model = load(output_dir, fp32_model, weight_only=True, layer_wise=True)
```

#### ONNX Runtime framework example

```python
from neural_compressor import quantization, PostTrainingQuantConfig

conf = PostTrainingQuantConfig(recipes={"layer_wise_quant": True})
q_model = quantization.fit(fp32_model_path, conf, calib_dataloader=dataloader)
q_model.save(int8_model_path)
```
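
After quantization, the saved model can be checked like any other ONNX model. The following is a minimal sanity-check sketch, assuming `int8_model_path` from the example above points to the saved `.onnx` file; note that for models larger than 2 GB, Neural Compressor expects a model path rather than a loaded ModelProto, so any external weight files saved next to the graph are picked up automatically.

```python
import onnxruntime as ort

int8_model_path = "/path/to/model_int8.onnx"  # placeholder: the path passed to q_model.save(...)

# Open the quantized model and list its inputs; this only verifies that the saved
# graph (plus any external data files) loads correctly, it does not run generation.
session = ort.InferenceSession(int8_model_path, providers=["CPUExecutionProvider"])
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)
```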

Refer to the [ONNX Runtime llama-2 LWQ example](../../examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/weight_only) for a complete example.
7 changes: 7 additions & 0 deletions examples/.config/model_params_onnxrt.json
@@ -763,6 +763,13 @@
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-lwq": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/ptq_static",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/llama-2-7b",
"main_script": "main.py",
"batch_size": 1
},
"llama-2-7b-rtn": {
"model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
"dataset_location": "",
@@ -34,6 +34,8 @@ optimum-cli export onnx --model meta-llama/Llama-2-7b-hf --task text-generation-

## 1. Quantization

Run SmoothQuant

```bash
bash run_quant.sh --input_model=/path/to/model \ # folder path of onnx model
--output_model=/path/to/model_tune \ # folder path to save onnx model
@@ -44,6 +46,19 @@ bash run_quant.sh --input_model=/path/to/model \ # folder path of onnx model
--quant_format="QOperator" # or QDQ, optional
```

Additionally, set `--layer_wise=True` to use layer-wise quantization and reduce memory usage. Please note that layer-wise quantization for ONNX models is still under development and only supports W8A8 quantization for now. For more details, refer to [layer-wise quantization](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_layer_wise.md).

```bash
bash run_quant.sh --input_model=/path/to/model \ # folder path of onnx model
--output_model=/path/to/model_tune \ # folder path to save onnx model
--batch_size=batch_size \ # optional
--dataset NeelNanda/pile-10k \
--tokenizer=meta-llama/Llama-2-7b-hf \ # model name or folder path containing all relevant files for model's tokenizer
--quant_format="QOperator" \ # or QDQ, optional
--layer_wise=True
```


## 2. Benchmark

Accuracy:
@@ -116,6 +116,11 @@
    type=int,
    default=4
)
parser.add_argument(
    '--layer_wise',
    action='store_true',
    default=False,
)
args = parser.parse_args()

# load model
@@ -258,15 +263,36 @@ def __iter__(self):

if args.tune:
    from neural_compressor import quantization, PostTrainingQuantConfig
    config = PostTrainingQuantConfig(
        calibration_sampling_size=[8],
        recipes={'optypes_to_exclude_output_quant': ['MatMul'],
                 'smooth_quant': True,
                 'smooth_quant_args': {'alpha': args.smooth_quant_alpha}},
        op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
    for model in ['decoder_model.onnx', 'decoder_with_past_model.onnx']:
        q_model = quantization.fit(
            os.path.join(args.model_path, model),
            config,
            calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
        q_model.save(os.path.join(args.output_model, model))
    if args.layer_wise:
        # layer-wise quantization for ONNX models is still under development and only supports W8A8 quantization now
        config = PostTrainingQuantConfig(
            calibration_sampling_size=[8],
            recipes={'optypes_to_exclude_output_quant': ['MatMul'],
                     'layer_wise_quant': True},
            op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
        for model in ['decoder_model.onnx']:
            # only quantize decoder_model in the layer-wise path for now
            q_model = quantization.fit(
                os.path.join(args.model_path, model),
                config,
                calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
            q_model.save(os.path.join(args.output_model, model))

        tokenizer.save_pretrained(args.output_model)

    else:
        config = PostTrainingQuantConfig(
            calibration_sampling_size=[8],
            recipes={'optypes_to_exclude_output_quant': ['MatMul'],
                     'smooth_quant': True,
                     'smooth_quant_args': {'alpha': args.smooth_quant_alpha},
                     },
            op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
        for model in ['decoder_model.onnx', 'decoder_with_past_model.onnx']:
            q_model = quantization.fit(
                os.path.join(args.model_path, model),
                config,
                calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
            q_model.save(os.path.join(args.output_model, model))

        tokenizer.save_pretrained(args.output_model)
@@ -32,6 +32,9 @@ function init_params {
--tokenizer=*)
tokenizer=$(echo $var |cut -f2 -d=)
;;
--layer_wise=*)
layer_wise=$(echo $var |cut -f2 -d=)
;;
esac
done

@@ -59,6 +62,11 @@ function run_tuning {
echo "Created directory $output_model"
fi

# check if layer_wise option is set to true (case insensitive)
if [ "${layer_wise,,}" = "true" ]; then
extra_cmd="--layer_wise"
fi

python main.py \
--quant_format ${quant_format-QOperator} \
--model_path ${input_model} \
@@ -67,7 +75,8 @@ function run_tuning {
--batch_size ${batch_size-1} \
--smooth_quant_alpha ${alpha-0.6} \
--dataset ${dataset-NeelNanda/pile-10k} \
--tune
--tune \
${extra_cmd}
}

main "$@"
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/onnxrt.py
@@ -1021,7 +1021,7 @@ def _pre_optimize(self, model, level=1):
from onnx.external_data_helper import load_external_data_for_model

load_external_data_for_model(tmp_model, os.path.split(model.model_path)[0])
model.model_path = sess_options.optimized_model_filepath
model.model_path = sess_options.optimized_model_filepath
else:
model.model_path = sess_options.optimized_model_filepath

4 changes: 2 additions & 2 deletions neural_compressor/adaptor/ox_utils/calibration.py
@@ -264,9 +264,9 @@ def get_intermediate_outputs(self, q_config=None):
            for output in session.get_outputs()
        ]
        augment_model_wrapper = (
            ONNXModel(self.augmented_model)
            ONNXModel(self.augmented_model, load_external_data=False)
            if not self.model_wrapper.is_large_model
            else ONNXModel(self.model_wrapper.model_path + "_augment.onnx")
            else ONNXModel(self.model_wrapper.model_path + "_augment.onnx", load_external_data=False)
        )
        input_name_to_nodes = augment_model_wrapper.input_name_to_nodes
        output_name_to_node = augment_model_wrapper.output_name_to_node
38 changes: 38 additions & 0 deletions neural_compressor/adaptor/ox_utils/util.py
@@ -29,6 +29,9 @@
numpy_helper = LazyImport("onnx.numpy_helper")
onnx_proto = LazyImport("onnx.onnx_pb")
torch = LazyImport("torch")
symbolic_shape_infer = LazyImport("onnxruntime.tools.symbolic_shape_infer")
onnx = LazyImport("onnx")


__producer__ = "onnx.quantize"
__version__ = "0.1.0"
@@ -594,3 +597,38 @@ def to_numpy(data):
)
else:
return data


def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, base_dir=""):
    """Symbolic shape inference."""

    class SymbolicShapeInference(symbolic_shape_infer.SymbolicShapeInference):
        def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix="", base_dir=""):
            super().__init__(int_max, auto_merge, guess_output_rank, verbose, prefix)
            self.base_dir = base_dir

        def _get_value(self, node, idx):
            name = node.input[idx]
            assert name in self.sympy_data_ or name in self.initializers_
            return (
                self.sympy_data_[name]
                if name in self.sympy_data_
                else numpy_helper.to_array(self.initializers_[name], base_dir=self.base_dir)
            )

    onnx_opset = symbolic_shape_infer.get_opset(in_mp)
    if (not onnx_opset) or onnx_opset < 7:
        logger.warning("Only support models of onnx opset 7 and above.")
        return None
    symbolic_shape_inference = SymbolicShapeInference(
        int_max, auto_merge, guess_output_rank, verbose, base_dir=base_dir
    )
    all_shapes_inferred = False
    symbolic_shape_inference._preprocess(in_mp)
    while symbolic_shape_inference.run_:
        all_shapes_inferred = symbolic_shape_inference._infer_impl()
    symbolic_shape_inference._update_output_from_vi()
    if not all_shapes_inferred:
        onnx.save_model(symbolic_shape_inference.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
        raise Exception("Incomplete symbolic shape inference")
    return symbolic_shape_inference.out_mp_
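
For context, a hedged usage sketch of the helper above (mirroring how `split_model_with_node` in `onnx_model.py` calls it below). The model path is a placeholder, and the snippet assumes onnxruntime is installed, since the helper lazily imports `onnxruntime.tools.symbolic_shape_infer`.

```python
import os

import onnx

from neural_compressor.adaptor.ox_utils.util import infer_shapes

# Load the graph without pulling external weight tensors into memory; `base_dir`
# tells the shape-inference subclass where to find them when an initializer's
# value is actually needed.
model_path = "/path/to/decoder_model.onnx"  # placeholder path
model_proto = onnx.load(model_path, load_external_data=False)
inferred = infer_shapes(model_proto, auto_merge=True, base_dir=os.path.dirname(model_path))
```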
13 changes: 10 additions & 3 deletions neural_compressor/model/onnx_model.py
@@ -40,14 +40,21 @@ def __init__(self, model, **kwargs):

        Args:
            model (str or ModelProto): path to onnx model or loaded ModelProto model object.
            ignore_warning (bool): ignore large model warning. Default is False.
            load_external_data (bool): load external data for large model. Default is True.
        """
        self._model = model if not isinstance(model, str) else onnx.load(model)
        self._model = model if not isinstance(model, str) else onnx.load(model, load_external_data=False)
        self._model_path = None if not isinstance(model, str) else model

        self.check_is_large_model()
        if self._is_large_model and self._model_path is None and not kwargs.get("ignore_warning", False):
            logger.warning("Model size > 2GB. Please use model path instead of onnx model object to quantize")

        if self._is_large_model and isinstance(model, str) and kwargs.get("load_external_data", True):
            from onnx.external_data_helper import load_external_data_for_model

            load_external_data_for_model(self._model, os.path.dirname(self._model_path))

        self._config = None
        if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()):
            from transformers import PretrainedConfig
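
A hedged sketch of how the new keyword can be used when wrapping a model stored with external data (the path is a placeholder; `load_external_data=False` keeps only the graph in memory):

```python
from neural_compressor.model.onnx_model import ONNXModel

# Wrap a >2GB ONNX model without eagerly loading its external weight files.
# Passing the model path (rather than a ModelProto) lets later passes locate
# the external data next to the .onnx file when they need it.
wrapper = ONNXModel("/path/to/decoder_model.onnx", load_external_data=False)
print(wrapper.is_large_model)
```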
@@ -1038,9 +1045,9 @@ def split_model_with_node(
        if shape_infer:
            try:
                # need ort.GraphOptimizationLevel <= ORT_ENABLE_BASIC
                import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer
                from neural_compressor.adaptor.ox_utils.util import infer_shapes

                self._model = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(self._model, auto_merge=True)
                self._model = infer_shapes(self._model, auto_merge=True, base_dir=os.path.dirname(self._model_path))
            except Exception as e:  # pragma: no cover
                logger.error("Shape infer fails for layer-wise quantization")
                if "Incomplete symbolic shape inference" in str(e):