From 0f260c2ccd6279fd46ee6118ea3af7779c611c17 Mon Sep 17 00:00:00 2001
From: Alexander Kozlov
Date: Wed, 8 Nov 2023 10:17:57 +0400
Subject: [PATCH] [DOC]: Added INT4 weight compression description (#20812)

* Added INT4 information into weight compression doc

* Added GPTQ info. Fixed comments

* Fixed list

* Fixed issues. Updated Gen.AI doc

* Applied comments

* Added additional info about GPTQ support

* Fixed typos

* Update docs/articles_en/openvino_workflow/gen_ai.md

Co-authored-by: Nico Galoppo

* Update docs/articles_en/openvino_workflow/gen_ai.md

Co-authored-by: Nico Galoppo

* Update docs/optimization_guide/nncf/code/weight_compression_openvino.py

Co-authored-by: Nico Galoppo

* Applied changes

* Update docs/articles_en/openvino_workflow/gen_ai.md

Co-authored-by: Tatiana Savina

* Update docs/articles_en/openvino_workflow/gen_ai.md

Co-authored-by: Tatiana Savina

* Update docs/articles_en/openvino_workflow/gen_ai.md

Co-authored-by: Tatiana Savina

* Update docs/articles_en/openvino_workflow/model_optimization_guide/weight_compression.md

Co-authored-by: Tatiana Savina

* Update docs/articles_en/openvino_workflow/model_optimization_guide/weight_compression.md

Co-authored-by: Tatiana Savina

* Update docs/articles_en/openvino_workflow/model_optimization_guide/weight_compression.md

Co-authored-by: Tatiana Savina

* Update docs/articles_en/openvino_workflow/model_optimization_guide/weight_compression.md

Co-authored-by: Tatiana Savina

* Update docs/articles_en/openvino_workflow/model_optimization_guide/weight_compression.md

Co-authored-by: Tatiana Savina

* Added table with results

* One more comment

---------

Co-authored-by: Nico Galoppo
Co-authored-by: Tatiana Savina
---
 docs/articles_en/openvino_workflow/gen_ai.md  |  22 ++++
 .../weight_compression.md                     | 103 +++++++++++++++++-
 .../nncf/code/weight_compression_openvino.py  |   9 +-
 3 files changed, 131 insertions(+), 3 deletions(-)

diff --git a/docs/articles_en/openvino_workflow/gen_ai.md b/docs/articles_en/openvino_workflow/gen_ai.md
index 4ecb55fcc2427c..40567d2daa353d 100644
--- a/docs/articles_en/openvino_workflow/gen_ai.md
+++ b/docs/articles_en/openvino_workflow/gen_ai.md
@@ -115,6 +115,28 @@ Optimum-Intel API also provides out-of-the-box model optimization through weight
 
 Weight compression is applied by default to models larger than one billion parameters and is also available for CLI interface as the ``--int8`` option.
 
+.. note::
+
+   8-bit weight compression is enabled by default for models larger than 1 billion parameters.
+
+`NNCF `__ also provides 4-bit weight compression, which is supported by OpenVINO. It can be applied to Optimum objects as follows:
+
+.. code-block:: python
+
+   from nncf import compress_weights, CompressWeightsMode
+
+   model = OVModelForCausalLM.from_pretrained(model_id, export=True, load_in_8bit=False)
+   model.model = compress_weights(model.model, mode=CompressWeightsMode.INT4_SYM, group_size=128, ratio=0.8)
+
+
+The optimized model can be saved as usual with a call to ``save_pretrained()``. For more details on compression options, refer to the :doc:`weight compression guide `.
+
+.. note::
+
+   OpenVINO also supports 4-bit models from the Hugging Face `Transformers `__ library optimized
+   with `GPTQ `__. In this case, there is no need for an additional model optimization step because model conversion will automatically preserve the INT4 optimization results, allowing model inference to benefit from them.
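+
+As a minimal end-to-end sketch of the steps above (the output directory name is an illustrative placeholder, and the model identifier is one of the models listed in the weight compression guide), the compressed model can be saved and then queried for a quick sanity check:
+
+.. code-block:: python
+
+   from nncf import compress_weights, CompressWeightsMode
+   from optimum.intel import OVModelForCausalLM
+   from transformers import AutoTokenizer
+
+   model_id = "databricks/dolly-v2-3b"  # example model; any supported causal LM works
+
+   # Export to OpenVINO IR without the default 8-bit compression, then compress weights to 4 bits.
+   model = OVModelForCausalLM.from_pretrained(model_id, export=True, load_in_8bit=False)
+   model.model = compress_weights(model.model, mode=CompressWeightsMode.INT4_SYM, group_size=128, ratio=0.8)
+   model.save_pretrained("dolly-v2-3b-int4-ov")  # placeholder output directory
+
+   # Reload the saved, compressed model and generate a short answer to verify it still works.
+   model = OVModelForCausalLM.from_pretrained("dolly-v2-3b-int4-ov")
+   tokenizer = AutoTokenizer.from_pretrained(model_id)
+   inputs = tokenizer("What is OpenVINO?", return_tensors="pt")
+   outputs = model.generate(**inputs, max_new_tokens=30)
+   print(tokenizer.decode(outputs[0], skip_special_tokens=True))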
+
+
 Below are some examples of using Optimum-Intel for model conversion and inference:
 
 * `Stable Diffusion v2.1 using Optimum-Intel OpenVINO `__
diff --git a/docs/articles_en/openvino_workflow/model_optimization_guide/weight_compression.md b/docs/articles_en/openvino_workflow/model_optimization_guide/weight_compression.md
index fb29a6d49b767f..fd9599a31f7ea7 100644
--- a/docs/articles_en/openvino_workflow/model_optimization_guide/weight_compression.md
+++ b/docs/articles_en/openvino_workflow/model_optimization_guide/weight_compression.md
@@ -10,12 +10,14 @@ Weight compression aims to reduce the memory footprint of a model. It can also l
 
 - enabling the inference of exceptionally large models that cannot be accommodated in the memory of the device;
 - improving the inference performance of the models by reducing the latency of the memory access when computing the operations with weights, for example, Linear layers.
 
-Currently, `Neural Network Compression Framework (NNCF) `__ provides 8-bit weight quantization as a compression method primarily designed to optimize LLMs. The main difference between weights compression and full model quantization (post-training quantization) is that activations remain floating-point in the case of weights compression which leads to a better accuracy. Weight compression for LLMs provides a solid inference performance improvement which is on par with the performance of the full model quantization. In addition, weight compression is data-free and does not require a calibration dataset, making it easy to use.
+Currently, `Neural Network Compression Framework (NNCF) `__ provides weight quantization to 8-bit and 4-bit integer data types as a compression method primarily designed to optimize LLMs. The main difference between weight compression and full model quantization is that activations remain floating-point in the case of weight compression, resulting in better accuracy. Weight compression for LLMs provides a solid inference performance improvement that is on par with that of full model quantization. In addition, weight compression is data-free and does not require a calibration dataset, making it easy to use.
 
 Compress Model Weights
 ######################
 
-The code snippet below shows how to compress the weights of the model represented in OpenVINO IR using NNCF:
+- **8-bit weight quantization** - this method is aimed at accurate optimization of the model and usually leads to significant performance improvements for Transformer-based models. Models with 8-bit compressed weights are performant on the vast majority of supported CPU and GPU platforms.
+
+The code snippet below shows how to apply 8-bit quantization to the weights of a model represented in OpenVINO IR using NNCF:
 
 .. tab-set::
@@ -28,6 +30,103 @@ The code snippet below shows how to compress the weights of the model represente
 
 Now, the model is ready for compilation and inference. It can be also saved into a compressed format, resulting in a smaller binary file.
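+
+As a minimal sketch of that flow (file names are illustrative placeholders), the compressed model can be saved with ``ov.save_model`` and compiled for inference:
+
+.. code-block:: python
+
+   import openvino as ov
+   from nncf import compress_weights
+
+   core = ov.Core()
+   model = core.read_model("model.xml")   # original FP32/FP16 OpenVINO IR
+
+   model = compress_weights(model)        # 8-bit weight compression by default
+
+   ov.save_model(model, "model_int8.xml")              # smaller binary on disk
+   compiled_model = core.compile_model(model, "CPU")   # ready for inference
+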
+- **4-bit weight quantization** - this method applies mixed-precision INT4-INT8 weight quantization, where INT4 is the primary precision and INT8 is the backup precision. It usually results in a smaller model size and lower inference latency, although the accuracy degradation could be higher, depending on the model. The method has several parameters that can provide different performance-accuracy trade-offs after optimization:
+
+  * ``mode`` - there are two modes to choose from: ``INT4_SYM`` - INT4 symmetric weight quantization, which results in faster inference and a smaller model size, and ``INT4_ASYM`` - INT4 asymmetric weight quantization with a variable zero-point, which provides more accurate results.
+
+  * ``group_size`` - controls the size of the group of weights that share the same quantization parameters. A smaller group size results in a more accurate optimized model, but comes with a larger footprint and slower inference. The following group sizes are recommended: ``128``, ``64``, ``32`` (``128`` is the default value).
+
+  * ``ratio`` - controls the ratio between INT4 and INT8 compressed layers in the model. For example, a ratio of 0.8 means that 80% of the layers will be compressed to INT4, while the rest will be compressed to INT8 precision.
+
+The example below shows 4-bit weight quantization applied on top of OpenVINO IR:
+
+.. tab-set::
+
+   .. tab-item:: OpenVINO
+      :sync: openvino
+
+      .. doxygensnippet:: docs/optimization_guide/nncf/code/weight_compression_openvino.py
+         :language: python
+         :fragment: [compression_4bit]
+
+.. note::
+
+   OpenVINO also supports 4-bit models from the Hugging Face `Transformers `__ library optimized
+   with `GPTQ `__. In this case, there is no need for an additional model optimization step because model conversion will automatically preserve the INT4 optimization results, allowing model inference to benefit from them.
+
+
+The table below shows examples of text generation models with different optimization settings:
+
+.. list-table::
+   :widths: 40 55 25 25
+   :header-rows: 1
+
+   * - Model
+     - Optimization
+     - Perplexity
+     - Model Size (GB)
+   * - databricks/dolly-v2-3b
+     - FP32
+     - 5.01
+     - 10.3
+   * - databricks/dolly-v2-3b
+     - INT8
+     - 5.07
+     - 2.6
+   * - databricks/dolly-v2-3b
+     - INT4_ASYM,group_size=32,ratio=0.5
+     - 5.28
+     - 2.2
+   * - facebook/opt-6.7b
+     - FP32
+     - 4.25
+     - 24.8
+   * - facebook/opt-6.7b
+     - INT8
+     - 4.27
+     - 6.2
+   * - facebook/opt-6.7b
+     - INT4_ASYM,group_size=64,ratio=0.8
+     - 4.32
+     - 4.1
+   * - meta-llama/Llama-2-7b-chat-hf
+     - FP32
+     - 3.28
+     - 25.1
+   * - meta-llama/Llama-2-7b-chat-hf
+     - INT8
+     - 3.29
+     - 6.3
+   * - meta-llama/Llama-2-7b-chat-hf
+     - INT4_ASYM,group_size=128,ratio=0.8
+     - 3.41
+     - 4.0
+   * - togethercomputer/RedPajama-INCITE-7B-Instruct
+     - FP32
+     - 4.15
+     - 25.6
+   * - togethercomputer/RedPajama-INCITE-7B-Instruct
+     - INT8
+     - 4.17
+     - 6.4
+   * - togethercomputer/RedPajama-INCITE-7B-Instruct
+     - INT4_ASYM,group_size=128,ratio=1.0
+     - 4.17
+     - 3.6
+   * - meta-llama/Llama-2-13b-chat-hf
+     - FP32
+     - 2.92
+     - 48.5
+   * - meta-llama/Llama-2-13b-chat-hf
+     - INT8
+     - 2.91
+     - 12.1
+   * - meta-llama/Llama-2-13b-chat-hf
+     - INT4_SYM,group_size=64,ratio=0.8
+     - 2.98
+     - 8.0
+
+
 Additional Resources
 ####################
diff --git a/docs/optimization_guide/nncf/code/weight_compression_openvino.py b/docs/optimization_guide/nncf/code/weight_compression_openvino.py
index c9ab67efd5aa32..d66fb28f4243c0 100644
--- a/docs/optimization_guide/nncf/code/weight_compression_openvino.py
+++ b/docs/optimization_guide/nncf/code/weight_compression_openvino.py
@@ -3,4 +3,11 @@
 
 ...
 model = compress_weights(model) # model is openvino.Model object
-#! [compression_8bit]
\ No newline at end of file
+#! [compression_8bit]
+
+#! [compression_4bit]
+from nncf import compress_weights, CompressWeightsMode
+
+...
+model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM, group_size=128, ratio=0.8) # model is openvino.Model object
+#! [compression_4bit]
\ No newline at end of file
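As a complement to the ``compression_4bit`` fragment above, here is a hedged end-to-end sketch that reads an OpenVINO IR from disk, applies one of the 4-bit configurations listed in the results table, and saves the compressed model (file names are illustrative placeholders):

.. code-block:: python

   import openvino as ov
   from nncf import compress_weights, CompressWeightsMode

   core = ov.Core()
   model = core.read_model("opt-6.7b.xml")  # placeholder path to an exported IR

   # The settings mirror the "facebook/opt-6.7b" row of the results table above.
   model = compress_weights(
       model,
       mode=CompressWeightsMode.INT4_ASYM,  # asymmetric 4-bit quantization, more accurate than INT4_SYM
       group_size=64,                       # smaller groups improve accuracy at the cost of size and speed
       ratio=0.8,                           # 80% of layers in INT4, the remaining 20% in INT8
   )

   ov.save_model(model, "opt-6.7b-int4.xml")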