From a9df7da60927d387eaebbb0f630ee0ec90c75f74 Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Sun, 19 May 2024 16:39:09 +0300 Subject: [PATCH 01/51] [SW-184941] INC CI, CD and Promotion Change-Id: I60c420f9776e1bdab7bb9e02e5bcbdb6891bfe52 --- requirements_pt.txt | 1 - setup.py | 8 +++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/requirements_pt.txt b/requirements_pt.txt index c3891a27b99..018b1b9dbf6 100644 --- a/requirements_pt.txt +++ b/requirements_pt.txt @@ -3,6 +3,5 @@ numpy < 2.0 peft prettytable psutil -py-cpuinfo pydantic tbb diff --git a/setup.py b/setup.py index bb23ac7866a..f4706563d00 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,5 @@ import os import re -import subprocess import sys from io import open @@ -135,11 +134,10 @@ def get_build_version(): description="Repository of Intel® Neural Compressor", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", - keywords="quantization,auto-tuning,post-training static quantization," - "post-training dynamic quantization,quantization-aware training", license="Apache 2.0", - url="https://github.com/intel/neural-compressor", - packages=include_packages, + keywords="quantization", + url="", + packages=find_packages(include=['neural_compressor', 'neural_compressor.*']), include_package_data=True, package_data=package_data, install_requires=install_requires, From 14f031e516262e3a953febb71d8e2b6b2e0bec18 Mon Sep 17 00:00:00 2001 From: Ron Ben Moshe Date: Thu, 6 Jun 2024 10:58:15 +0300 Subject: [PATCH 02/51] [SW-183320]updated setup.py Change-Id: I592af89486cb1d9e0b5197521c428920197a9103 --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index f4706563d00..a2392358572 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ import os import re +import subprocess import sys from io import open From ee7e5c8f72db7d20df377cf19b26edfaa71892c5 Mon Sep 17 00:00:00 2001 From: Zhou Yuwen Date: Wed, 22 May 2024 07:39:06 +0000 Subject: [PATCH 03/51] [SW-177474] add HQT FP8 porting code Change-Id: I4676f13a5ed43c444f2ec68675cc41335e7234dd Signed-off-by: Zhou Yuwen --- .../quantization/habana_fp8/README.md | 24 - .../models/configuration_chatglm.py | 61 - .../habana_fp8/models/modeling_chatglm.py | 1294 ----------------- .../habana_fp8/models/modeling_llama.py | 1263 ---------------- .../models/tokenization_baichuan.py | 255 ---- .../quantization/habana_fp8/requirement.txt | 7 - .../quantization/habana_fp8/run_llm.py | 222 --- .../quantization/habana_fp8/utils.py | 255 ---- examples/fp8_sample/README.md | 96 ++ examples/fp8_sample/maxabs_measure.json | 7 + examples/fp8_sample/maxabs_quant.json | 8 + examples/fp8_sample/quant_config.json | 8 + examples/fp8_sample/sample_one_step.py | 57 + examples/fp8_sample/sample_two_steps.py | 50 + .../{habana_fp8 => fp8_quant}/__init__.py | 9 +- .../torch/algorithms/fp8_quant/common.py | 98 ++ .../torch/algorithms/fp8_quant/fp8_quant.py | 61 + .../algorithms/fp8_quant/helper_modules.py | 118 ++ .../torch/algorithms/habana_fp8/fp8_quant.py | 220 --- .../torch/algorithms/habana_fp8/modules.py | 487 ------- .../torch/algorithms/habana_fp8/observer.py | 440 ------ .../torch/algorithms/habana_fp8/save_load.py | 105 -- .../torch/algorithms/habana_fp8/scale.py | 59 - .../algorithms/habana_fp8/tensor/__init__.py | 13 - .../algorithms/habana_fp8/tensor/convert.cpp | 63 - neural_compressor/torch/amp/__init__.py | 15 - neural_compressor/torch/amp/autocast.py | 95 -- neural_compressor/torch/amp/fp8/__init__.py | 13 - 
neural_compressor/torch/amp/fp8/functions.py | 134 -- .../torch/quantization/__init__.py | 2 +- .../torch/quantization/algorithm_entry.py | 30 +- .../torch/quantization/config.py | 161 +- .../torch/quantization/quantize.py | 28 +- setup.py | 3 +- test/3x/torch/amp/test_fp8_amp.py | 75 - .../torch/quantization/habana_fp8/test_fp8.py | 189 --- 36 files changed, 660 insertions(+), 5365 deletions(-) delete mode 100644 examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md delete mode 100644 examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py delete mode 100644 examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py delete mode 100644 examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py delete mode 100644 examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py delete mode 100644 examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt delete mode 100644 examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py delete mode 100644 examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py create mode 100644 examples/fp8_sample/README.md create mode 100644 examples/fp8_sample/maxabs_measure.json create mode 100644 examples/fp8_sample/maxabs_quant.json create mode 100644 examples/fp8_sample/quant_config.json create mode 100644 examples/fp8_sample/sample_one_step.py create mode 100644 examples/fp8_sample/sample_two_steps.py rename neural_compressor/torch/algorithms/{habana_fp8 => fp8_quant}/__init__.py (70%) create mode 100644 neural_compressor/torch/algorithms/fp8_quant/common.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/fp8_quant.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/helper_modules.py delete mode 100644 neural_compressor/torch/algorithms/habana_fp8/fp8_quant.py delete mode 100644 neural_compressor/torch/algorithms/habana_fp8/modules.py delete mode 100644 neural_compressor/torch/algorithms/habana_fp8/observer.py delete mode 100644 neural_compressor/torch/algorithms/habana_fp8/save_load.py delete mode 100644 neural_compressor/torch/algorithms/habana_fp8/scale.py delete mode 100644 neural_compressor/torch/algorithms/habana_fp8/tensor/__init__.py delete mode 100644 neural_compressor/torch/algorithms/habana_fp8/tensor/convert.cpp delete mode 100644 neural_compressor/torch/amp/__init__.py delete mode 100644 neural_compressor/torch/amp/autocast.py delete mode 100644 neural_compressor/torch/amp/fp8/__init__.py delete mode 100644 neural_compressor/torch/amp/fp8/functions.py delete mode 100644 test/3x/torch/amp/test_fp8_amp.py delete mode 100644 test/3x/torch/quantization/habana_fp8/test_fp8.py diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md deleted file mode 100644 index eb39321b173..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# Run - -## Run FP32 model -``` python -python run_llm.py --model [model_name_or_path] --to_graph [--performance]|[--accuracy --tasks 
lambada_openai --batch_size 8]|[--generate --max_new_tokens 10] -``` - -## Run BF16/FP16 model -``` python -python run_llm.py --model [model_name_or_path] --approach cast --precision [bf16|fp16] --to_graph [--performance]|[--accuracy --tasks lambada_openai --batch_size 8]|[--generate --max_new_tokens 10] -``` - -## Run FP8 model -``` python -python run_llm.py --model [model_name_or_path] --approach [dynamic|static|cast] --precision [fp8_e4m3|fp8_e5m2] --to_graph [--performance]|[--accuracy --tasks lambada_openai --batch_size 8]|[--generate --max_new_tokens 10] -``` - -# Multi-card Inference -With deepspeed we can leverage multi-cards inference with a prefix in command, below it's a demonstration of 4 card inference. - -```python -deepspeed --num_gpus=4 run_llm.py --model [model_name_or_path] --approach [dynamic|static|cast] --precision [fp8_e4m3|fp8_e5m2] --to_graph [--performance]|[--accuracy --tasks lambada_openai --batch_size 8]|[--generate --max_new_tokens 10] -``` -deepspeed --num_gpus=4 run_llm.py --model facebook/opt-125m --approach static --precision fp8_e4m3 --to_graph --accuracy --tasks lambada_openai --batch_size 8 \ No newline at end of file diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py deleted file mode 100644 index 35600185f5a..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py +++ /dev/null @@ -1,61 +0,0 @@ -from transformers import PretrainedConfig - - -class ChatGLMConfig(PretrainedConfig): - model_type = "chatglm" - def __init__( - self, - num_layers=28, - padded_vocab_size=65024, - hidden_size=4096, - ffn_hidden_size=13696, - kv_channels=128, - num_attention_heads=32, - seq_length=2048, - hidden_dropout=0.0, - classifier_dropout=None, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - rmsnorm=True, - apply_residual_connection_post_layernorm=False, - post_layer_norm=True, - add_bias_linear=False, - add_qkv_bias=False, - bias_dropout_fusion=True, - multi_query_attention=False, - multi_query_group_num=1, - apply_query_key_layer_scaling=True, - attention_softmax_in_fp32=True, - fp32_residual_connection=False, - quantization_bit=0, - pre_seq_len=None, - prefix_projection=False, - **kwargs - ): - self.num_layers = num_layers - self.vocab_size = padded_vocab_size - self.padded_vocab_size = padded_vocab_size - self.hidden_size = hidden_size - self.ffn_hidden_size = ffn_hidden_size - self.kv_channels = kv_channels - self.num_attention_heads = num_attention_heads - self.seq_length = seq_length - self.hidden_dropout = hidden_dropout - self.classifier_dropout = classifier_dropout - self.attention_dropout = attention_dropout - self.layernorm_epsilon = layernorm_epsilon - self.rmsnorm = rmsnorm - self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm - self.post_layer_norm = post_layer_norm - self.add_bias_linear = add_bias_linear - self.add_qkv_bias = add_qkv_bias - self.bias_dropout_fusion = bias_dropout_fusion - self.multi_query_attention = multi_query_attention - self.multi_query_group_num = multi_query_group_num - self.apply_query_key_layer_scaling = apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = attention_softmax_in_fp32 - self.fp32_residual_connection = fp32_residual_connection - self.quantization_bit = quantization_bit - 
self.pre_seq_len = pre_seq_len - self.prefix_projection = prefix_projection - super().__init__(**kwargs) \ No newline at end of file diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py deleted file mode 100644 index be1cd520af5..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py +++ /dev/null @@ -1,1294 +0,0 @@ -""" PyTorch ChatGLM model. """ - -import math -import copy -import warnings -import re -import sys - -import torch -import torch.utils.checkpoint -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any -from copy import deepcopy - -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != 'darwin': - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM" -_CONFIG_FOR_DOC = "ChatGLMConfig" - -CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "THUDM/chatglm3-6b", - # See all ChatGLM models at https://huggingface.co/models?filter=chatglm -] - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 5] = 5e4 - return scores - - -class PrefixEncoder(torch.nn.Module): - """ - The torch.nn model to encode the prefix - Input shape: (batch-size, prefix-length) - Output shape: (batch-size, prefix-length, 2*layers*hidden) - """ - - def __init__(self, config: ChatGLMConfig): - super().__init__() - self.prefix_projection = config.prefix_projection - if self.prefix_projection: - # Use a two-layer MLP to encode the prefix - kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 - self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(kv_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, kv_size) - ) - else: - self.embedding = torch.nn.Embedding(config.pre_seq_len, - config.num_layers * config.kv_channels * config.multi_query_group_num * 2) - - def forward(self, prefix: torch.Tensor): - if self.prefix_projection: - prefix_tokens = self.embedding(prefix) - past_key_values = self.trans(prefix_tokens) - else: - past_key_values = self.embedding(prefix) - return past_key_values - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: 
bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim, original_impl=False, device=None, dtype=None): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.dim = dim - self.original_impl = original_impl - - def forward_impl( - self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 - ): - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. - """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.outer(seq_idx, theta).float() - - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - - # this is to mimic the behaviour of complex32, else we will get different results - if dtype in (torch.float16, torch.bfloat16, torch.int8): - cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() - return cache - - def forward(self, max_seq_len, offset=0): - return self.forward_impl( - max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device - ) - -### INC change ### -# @torch.jit.script - -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [sq, b, np, hn] - sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:sq] - xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) - rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) - x_out2 = x_out2.flatten(3) - return torch.cat((x_out2, x_pass), dim=-1) - - -class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) - self.eps = eps - - def forward(self, hidden_states: torch.Tensor): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * 
torch.rsqrt(variance + self.eps) - - return (self.weight * hidden_states).to(input_dtype) - - -class CoreAttention(torch.nn.Module): - def __init__(self, config: ChatGLMConfig, layer_number): - super(CoreAttention, self).__init__() - - self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - - projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. - self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - self.coeff = coeff - - self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split('.')[0]) - if pytorch_major_version >= 2: - query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - is_causal=True) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) - context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - else: - # Raw attention scores - - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - # Raw attention scores. 
[b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.attention_softmax_in_fp32: - attention_scores = attention_scores.float() - if self.coeff is not None: - attention_scores = attention_scores * self.coeff - if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], - device=attention_scores.device, dtype=torch.bool) - attention_mask.tril_() - attention_mask = ~attention_mask - if attention_mask is not None: - attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.type_as(value_layer) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - - -class SelfAttention(torch.nn.Module): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(SelfAttention, self).__init__() - self.layer_number = max(1, layer_number) - - self.projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. 
- self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - self.multi_query_attention = config.multi_query_attention - self.qkv_hidden_size = 3 * self.projection_size - if self.multi_query_attention: - self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = ( - self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num - ) - self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, - device=device, **_config_to_kwargs(config) - ) - - self.core_attention = CoreAttention(config, self.layer_number) - - # Output. - self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, - device=device, **_config_to_kwargs(config) - ) - - def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): - if self.multi_query_attention: - num_attention_heads = self.num_multi_query_groups_per_partition - else: - num_attention_heads = self.num_attention_heads_per_partition - return torch.empty( - inference_max_sequence_len, - batch_size, - num_attention_heads, - self.hidden_size_per_attention_head, - dtype=dtype, - device=device, - ) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True - ): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - # adjust key and value for inference - if kv_cache is not None: - cache_k, cache_v = kv_cache - key_layer = torch.cat((cache_k, key_layer), dim=0) - value_layer = torch.cat((cache_v, value_layer), dim=0) - if use_cache: - kv_cache = (key_layer, value_layer) - else: - kv_cache = None - - if self.multi_query_attention: - key_layer = 
key_layer.unsqueeze(-2) - key_layer = key_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - key_layer = key_layer.contiguous().view( - key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.unsqueeze(-2) - value_layer = value_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - value_layer = value_layer.contiguous().view( - value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - - # ================================== - # core attention computation - # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. [sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, kv_cache - - -def _config_to_kwargs(args): - common_kwargs = { - "dtype": args.torch_dtype, - } - return common_kwargs - - -class MLP(torch.nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, config: ChatGLMConfig, device=None): - super(MLP, self).__init__() - - self.add_bias = config.add_bias_linear - - # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear( - config.hidden_size, - config.ffn_hidden_size * 2, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def swiglu(x): - x = torch.chunk(x, 2, dim=-1) - return F.silu(x[0]) * x[1] - - self.activation_func = swiglu - - # Project back to h. - self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, - config.hidden_size, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def forward(self, hidden_states): - # [s, b, 4hp] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output = self.dense_4h_to_h(intermediate_parallel) - return output - - -class GLMBlock(torch.nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(GLMBlock, self).__init__() - self.layer_number = layer_number - - self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm - - self.fp32_residual_connection = config.fp32_residual_connection - - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # Self attention. 
- self.self_attention = SelfAttention(config, layer_number, device=device) - self.hidden_dropout = config.hidden_dropout - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # MLP - self.mlp = MLP(config, device=device) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, - ): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - rotary_pos_emb, - kv_cache=kv_cache, - use_cache=use_cache - ) - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + layernorm_input - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - # MLP. - mlp_output = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) - output = residual + output - - return output, kv_cache - - -class GLMTransformer(torch.nn.Module): - """Transformer class.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(GLMTransformer, self).__init__() - - self.fp32_residual_connection = config.fp32_residual_connection - self.post_layer_norm = config.post_layer_norm - - # Number of layers. - self.num_layers = config.num_layers - - # Transformer layers. - def build_layer(layer_number): - return GLMBlock(config, layer_number, device=device) - - self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) - - if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - self.gradient_checkpointing = False - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, - ): - if not kv_caches: - kv_caches = [None for _ in range(self.num_layers)] - presents = () if use_cache else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - all_self_attentions = None - all_hidden_states = () if output_hidden_states else None - for index in range(self.num_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer = self._get_layer(index) - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_caches[index], - use_cache - ) - else: - layer_ret = layer( - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache - ) - hidden_states, kv_cache = layer_ret - if use_cache: - presents = presents + (kv_cache,) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # Final layer norm. - if self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states, presents, all_hidden_states, all_self_attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, past_key_values, padding_mask=None): - batch_size, seq_length = input_ids.shape - full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) - full_attention_mask.tril_() - past_length = 0 - if past_key_values: - past_length = past_key_values[0][0].shape[0] - if past_length: - full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, - device=input_ids.device), full_attention_mask), dim=-1) - if padding_mask is not None: - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - if not past_length and padding_mask is not None: - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - def get_position_ids(self, input_ids, device): - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - return position_ids - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, GLMTransformer): - module.gradient_checkpointing = value - - -class Embedding(torch.nn.Module): - """Language model embeddings.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(Embedding, self).__init__() - - self.hidden_size = config.hidden_size - # Word embeddings (parallel). - self.word_embeddings = nn.Embedding( - config.padded_vocab_size, - self.hidden_size, - dtype=config.torch_dtype, - device=device - ) - self.fp32_residual_connection = config.fp32_residual_connection - - def forward(self, input_ids): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - embeddings = words_embeddings - # Data format change to avoid explicit transposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - # If the input flag for fp32 residual connection is set, convert for float. 
- if self.fp32_residual_connection: - embeddings = embeddings.float() - return embeddings - - -class ChatGLMModel(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - init_kwargs = {} - if device is not None: - init_kwargs["device"] = device - self.embedding = init_method(Embedding, config, **init_kwargs) - self.num_layers = config.num_layers - self.multi_query_group_num = config.multi_query_group_num - self.kv_channels = config.kv_channels - - # Rotary positional embeddings - self.seq_length = config.seq_length - rotary_dim = ( - config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels - ) - - self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, - dtype=config.torch_dtype) - self.encoder = init_method(GLMTransformer, config, **init_kwargs) - self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, - dtype=config.torch_dtype, **init_kwargs) - self.pre_seq_len = config.pre_seq_len - self.prefix_projection = config.prefix_projection - if self.pre_seq_len is not None: - for param in self.parameters(): - param.requires_grad = False - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - self.prefix_encoder = PrefixEncoder(config) - self.dropout = torch.nn.Dropout(0.1) - - def get_input_embeddings(self): - return self.embedding.word_embeddings - - def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) - past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.num_layers * 2, - self.multi_query_group_num, - self.kv_channels - ) - # seq_len, b, nh, hidden_size - past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) - return past_key_values - - def forward( - self, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size, seq_length = input_ids.shape - - if inputs_embeds is None: - inputs_embeds = self.embedding(input_ids) - - if self.pre_seq_len is not None: - if past_key_values is None: - past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, - dtype=inputs_embeds.dtype) - if attention_mask is not None: - attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), - attention_mask], dim=-1) - - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) - - # Rotary 
positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] - rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() - - # Run encoder. - hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( - inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states - ) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def quantize(self, weight_bit_width: int): - from .quantization import quantize - quantize(self.encoder, weight_bit_width) - return self - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.max_sequence_length = config.max_length - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - self.config = config - self.quantized = False - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - model_kwargs["is_first_forward"] = False - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - is_first_forward: bool = True, - **kwargs - ) -> dict: - # only last token for input_ids if past is not None - if position_ids is None: - position_ids = self.get_position_ids(input_ids, device=input_ids.device) - if not is_first_forward: - if past_key_values is not None: - position_ids = position_ids[..., -1:] - input_ids = input_ids[:, -1:] - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "position_ids": position_ids, - "attention_mask": attention_mask, - "return_last_logit": True, - "use_cache": use_cache - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: 
Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - if return_last_logit: - hidden_states = hidden_states[-1:] - lm_logits = self.transformer.output_layer(hidden_states) - lm_logits = lm_logits.transpose(0, 1).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. 
- """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, output, history): - content = "" - history = deepcopy(history) - for response in output.split("<|assistant|>"): - metadata, content = response.split("\n", maxsplit=1) - if not metadata.strip(): - content = content.strip() - history.append({"role": "assistant", "metadata": metadata, "content": content}) - content = content.replace("[[训练时间]]", "2023年") - else: - history.append({"role": "assistant", "metadata": metadata, "content": content}) - if history[0]["role"] == "system" and "tools" in history[0]: - content = "\n".join(content.split("\n")[1:-1]) - def tool_call(**kwargs): - return kwargs - parameters = eval(content) - content = {"name": metadata.strip(), "parameters": parameters} - else: - content = {"name": metadata.strip(), "content": content} - return content, history - - @torch.inference_mode() - def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", - max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, - **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - inputs = tokenizer.build_chat_input(query, history=history, role=role) - inputs = inputs.to(self.device) - eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), - tokenizer.get_command("<|observation|>")] - outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] - response = tokenizer.decode(outputs) - history.append({"role": role, "content": query}) - response, history = self.process_response(response, history) - return response, history - - @torch.inference_mode() - def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", - past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, - logits_processor=None, return_past_key_values=False, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), - tokenizer.get_command("<|observation|>")] - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if past_key_values is None: - inputs = tokenizer.build_chat_input(query, history=history, role=role) - else: - inputs = tokenizer.build_chat_input(query, role=role) - inputs = inputs.to(self.device) - if past_key_values is not None: - past_length = past_key_values[0][0].shape[0] - if self.transformer.pre_seq_len is not None: - past_length -= self.transformer.pre_seq_len - inputs.position_ids += past_length - attention_mask = inputs.attention_mask - attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) - inputs['attention_mask'] = attention_mask - history.append({"role": role, "content": query}) - for outputs in 
self.stream_generate(**inputs, past_key_values=past_key_values, - eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, - **gen_kwargs): - if return_past_key_values: - outputs, past_key_values = outputs - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] - response = tokenizer.decode(outputs) - if response and response[-1] != "�": - response, new_history = self.process_response(response, history) - if return_past_key_values: - yield response, new_history, past_key_values - else: - yield response, new_history - - @torch.inference_mode() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - return_past_key_values=False, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - model_kwargs["use_cache"] = generation_config.use_cache - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. 
Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) - if return_past_key_values: - yield input_ids, outputs.past_key_values - else: - yield input_ids - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - - def quantize(self, bits: int, empty_init=False, device=None, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - - self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, - **kwargs) - return self - - -class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.num_labels = config.num_labels - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - - self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half) - if config.classifier_dropout is not None: - self.dropout = nn.Dropout(config.classifier_dropout) - else: - self.dropout = None - self.config = config - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - full_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: 
Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - full_attention_mask=full_attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - pooled_hidden_states = hidden_states[-1] - if self.dropout is not None: - pooled_hidden_states = self.dropout(pooled_hidden_states) - logits = self.classifier_head(pooled_hidden_states) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze().float(), labels.squeeze()) - else: - loss = loss_fct(logits.float(), labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py deleted file mode 100644 index 4cd1b6e18e8..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py +++ /dev/null @@ -1,1263 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - logging, - replace_return_docstrings, -) -from transformers.utils.import_utils import is_torch_fx_available -from transformers.models.llama.configuration_llama import LlamaConfig - - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa -### INC code ### -from neural_compressor.torch.quantization.modules import Matmul, BatchMatmul, Autocast - -# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. -# It means that the function will not be traced through and simply appear as a node in the graph. -if is_torch_fx_available(): - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - warnings.warn( - "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._prepare_4d_attention_mask" - ) - return AttentionMaskConverter._prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) - - -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - warnings.warn( - "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. 
Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask" - ) - return AttentionMaskConverter._make_causal_mask( - input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length - ) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) - - -class LlamaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.config.pretraining_tp > 1: - slice = self.intermediate_size // self.config.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 - ) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self._init_rope() - ### INC code ### - self.matmul1 = Matmul() - self.matmul2 = Matmul() - self.cast1 = Autocast() - self.cast2 = Autocast() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) - - bsz, q_len, _ = hidden_states.size() - - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - ### INC code ### - key_states = self.cast1(key_states) - value_states = self.cast2(value_states) - # import habana_frameworks.torch.core as htcore - # htcore.mark_step() - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - ### INC code ### - attn_weights = self.matmul1(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - - ### INC code ### - attn_output = self.matmul2(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.config.pretraining_tp > 1: - attn_output = 
attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaFlashAttention2(LlamaAttention): - """ - Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # LlamaFlashAttention2 attention does not support output_attentions - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. 
(LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - # Handle the case where the model is quantized - if hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) - """ - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=self.is_causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal - ) - - return attn_output - - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. 
- attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = ( - LlamaAttention(config=config) - if not getattr(config, "_flash_attn_2_enabled", False) - else LlamaFlashAttention2(config=config) - ) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. 
- - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
- - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = 
inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if getattr(self.config, "_flash_attn_2_enabled", False): - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - # embed positions - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - 
return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - 
past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( - logits.device - ) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + 
transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py deleted file mode 100644 index 5b7054d3227..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright 2023 Baichuan Inc. All Rights Reserved. - -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple - -import sentencepiece as spm - -from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} - -PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": {}, - "tokenizer_file": {}, -} -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} - - -class BaichuanTokenizer(PreTrainedTokenizer): - """ - Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding. - - Args: - vocab_file (`str`): - Path to the vocabulary file. 
- """ - - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids", "attention_mask"] - - def __init__( - self, - vocab_file, - unk_token="", - bos_token="", - eos_token="", - pad_token=None, - sp_model_kwargs: Optional[Dict[str, Any]] = None, - add_bos_token=True, - add_eos_token=False, - clean_up_tokenization_spaces=False, - **kwargs, - ): - self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token - ### INC code ### - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(vocab_file) - - super().__init__( - bos_token=bos_token, - eos_token=eos_token, - unk_token=unk_token, - pad_token=pad_token, - add_bos_token=add_bos_token, - add_eos_token=add_eos_token, - sp_model_kwargs=self.sp_model_kwargs, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - **kwargs, - ) - self.vocab_file = vocab_file - self.add_bos_token = add_bos_token - self.add_eos_token = add_eos_token - #self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - #self.sp_model.Load(vocab_file) - - def __getstate__(self): - state = self.__dict__.copy() - state["sp_model"] = None - return state - - def __setstate__(self, d): - self.__dict__ = d - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(self.vocab_file) - - @property - def vocab_size(self): - """Returns vocab size""" - return self.sp_model.get_piece_size() - - def get_vocab(self): - """Returns vocab as a dict""" - vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} - vocab.update(self.added_tokens_encoder) - return vocab - - def _tokenize(self, text): - """Returns a tokenized string.""" - return self.sp_model.encode(text, out_type=str) - - def _convert_token_to_id(self, token): - """Converts a token (str) in an id using the vocab.""" - return self.sp_model.piece_to_id(token) - - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" - token = self.sp_model.IdToPiece(index) - return token - - def convert_tokens_to_string(self, tokens): - """Converts a sequence of tokens (string) in a single string.""" - current_sub_tokens = [] - out_string = "" - prev_is_special = False - for i, token in enumerate(tokens): - # make sure that special tokens are not decoded using sentencepiece model - if token in self.all_special_tokens: - if not prev_is_special and i != 0: - out_string += " " - out_string += self.sp_model.decode(current_sub_tokens) + token - prev_is_special = True - current_sub_tokens = [] - else: - current_sub_tokens.append(token) - prev_is_special = False - out_string += self.sp_model.decode(current_sub_tokens) - return out_string - - def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: - """ - Save the vocabulary and special tokens file to a directory. - - Args: - save_directory (`str`): - The directory in which to save the vocabulary. 
- - Returns: - `Tuple(str)`: Paths to the files saved. - """ - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] - ) - - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): - copyfile(self.vocab_file, out_vocab_file) - elif not os.path.isfile(self.vocab_file): - with open(out_vocab_file, "wb") as fi: - content_spiece_model = self.sp_model.serialized_model_proto() - fi.write(content_spiece_model) - - return (out_vocab_file,) - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = bos_token_id + token_ids_0 + eos_token_id - - if token_ids_1 is not None: - output = output + bos_token_id + token_ids_1 + eos_token_id - - return output - - def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: - """ - Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer `prepare_for_model` method. - - Args: - token_ids_0 (`List[int]`): - List of IDs. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - already_has_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the token list is already formatted with special tokens for the model. - - Returns: - `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. - """ - if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True - ) - - bos_token_id = [1] if self.add_bos_token else [] - eos_token_id = [1] if self.add_eos_token else [] - - if token_ids_1 is None: - return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return ( - bos_token_id - + ([0] * len(token_ids_0)) - + eos_token_id - + bos_token_id - + ([0] * len(token_ids_1)) - + eos_token_id - ) - - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT - sequence pair mask has the following format: - - ``` - 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 - | first sequence | second sequence | - ``` - - if token_ids_1 is None, only returns the first portion of the mask (0s). - - Args: - token_ids_0 (`List[int]`): - List of ids. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 
- """ - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) - - if token_ids_1 is not None: - output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) - - return output diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt deleted file mode 100644 index d3655acd742..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt +++ /dev/null @@ -1,7 +0,0 @@ -transformers -datasets -accelerate -SentencePiece -lm_eval==0.3.0 -openpyxl -einops diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py deleted file mode 100644 index e77ef2c6a33..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py +++ /dev/null @@ -1,222 +0,0 @@ -import os -os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False" - -### USE_GAUDI2_SCALE requires PT_USE_FP8_AMAX for torch.mm/bmm, or got failure -# os.environ["USE_GAUDI2_SCALE"] = "True" -# os.environ["PT_USE_FP8_AMAX"] = "True" - -### graphs will dump to .graph_dumps folder -# os.environ["GRAPH_VISUALIZATION"] = "True" -# import shutil -# shutil.rmtree(".graph_dumps", ignore_errors=True) - -import argparse -import time -import json -import re -import torch -import habana_frameworks.torch.hpex -import torch.nn.functional as F -import deepspeed -import transformers -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig -import habana_frameworks.torch.core as htcore - -from utils import show_msg, eval_func, init_empty_model, init_model, init_tokenizer - - -torch.set_grad_enabled(False) -htcore.hpu_set_env() -torch.device('hpu') - - -parser = argparse.ArgumentParser() -parser.add_argument( - "--model", nargs="?", default="facebook/opt-125m" -) -parser.add_argument( - "--trust_remote_code", default=True, - help="Transformers parameter: use the external repo") -parser.add_argument( - "--revision", default=None, - help="Transformers parameter: set the model hub commit number") -parser.add_argument("--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k") -parser.add_argument("--output_dir", nargs="?", default="./saved_results") -parser.add_argument("--to_graph", action="store_true") -parser.add_argument("--approach", type=str, default=None, - help="Select from ['dynamic', 'static' 'cast']") -parser.add_argument("--precision", type=str, default='fp32', - help="Select from ['fp8_e4m3', 'fp8_e5m2', 'bf16', 'fp16', 'fp32'], \ - ['bf16', 'fp16'] only work with cast approach") -parser.add_argument("--autotune", action="store_true") -parser.add_argument("--accuracy", action="store_true") -parser.add_argument("--performance", action="store_true") -parser.add_argument("--generate", action="store_true") -parser.add_argument("--skip_fp8_mm", action="store_true") -parser.add_argument("--dump_to_excel", action="store_true") -parser.add_argument("--save", action="store_true") -parser.add_argument("--load", action="store_true") -parser.add_argument("--batch_size", default=1, type=int, - help="For accuracy measurement only.") 
-parser.add_argument("--pad_max_length", default=512, type=int, - help="Pad input ids to max length.") -parser.add_argument("--calib_iters", default=100, type=int, - help="calibration iters.") -parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], \ - type=str, choices=["hellaswag", "lambada_openai", "piqa", "winogrande", "copa", - "rte", "openbookqa", "lambada_standard", "wikitext"], - help="tasks list for accuracy validation") -parser.add_argument("--limit", default=None, type=int, - help="the sample num of evaluation.") -parser.add_argument("--max_new_tokens", default=100, type=int, - help="calibration iters.") -parser.add_argument('--buckets', type=int, nargs='+', \ - help="Input length buckets to use with static_shapes", default=[256, 512]) -parser.add_argument("--local_rank", - type=int, - default=-1, - help="local_rank for distributed training on gpus") -parser.add_argument("--skip_lm_head", action="store_true") -args = parser.parse_args() - - -world_size = int(os.getenv('WORLD_SIZE', '1')) -local_rank = int(os.getenv('LOCAL_RANK', '-1')) - - -if args.load: - user_model = init_empty_model(args.model) -else: - user_model = init_model(args) -user_model.eval() - - -tokenizer = init_tokenizer(args) - - -### dynamic & static quantization ### -if args.approach in ["dynamic", "static"] and not args.load: - print("device:", next(user_model.parameters()).device) - from neural_compressor.torch.quantization import ( - quantize, autotune, FP8Config, get_default_fp8_config, TuningConfig, get_default_fp8_config_set - ) - dtype = args.precision - if args.approach == "dynamic": - from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic - user_model = quantize_dynamic(user_model, dtype, inplace=True) - elif args.approach == "static": - qconfig = FP8Config(w_dtype=dtype, act_dtype=dtype, approach="static") - if args.skip_lm_head: - fp32_config = FP8Config(w_dtype="fp32", act_dtype="fp32") - qconfig.set_local("lm_head", fp32_config) - # dataset - from datasets import load_dataset - calib_dataset = load_dataset(args.dataset, split="train").select(range(100)) - calib_dataset = calib_dataset.shuffle(seed=42) - calib_data = [] - for examples in calib_dataset: - calib_data.append( - tokenizer( - examples["text"], - return_tensors="pt", - max_length=64, - padding="max_length", - truncation=True - ) - ) - - def calib_func(model): - for i, calib_input in enumerate(calib_data): - if i >= args.calib_iters: - break - model( - input_ids=calib_input["input_ids"].to('hpu'), - attention_mask=calib_input["attention_mask"].to('hpu'), - ) - - user_model = quantize(user_model, qconfig, calib_func, inplace=True) - # saving - print(user_model) - if args.save and local_rank in [-1, 0]: - user_model.save("saved_results") - - -if args.load: - from neural_compressor.torch.quantization import load - user_model = load("saved_results", user_model) - - -if args.approach in ["dynamic", "static"] or args.load: - # It enables weights constant folding - from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const - _mark_params_as_const(user_model) # can reduce memory allocated and speed up - _check_params_as_const(user_model) - - - -# If torch.matmul and torch.bmm are not replaced by INC module, -# Below codes can make torch.matmul and torch.bmm run on fp8 by injection. 
-if not args.skip_fp8_mm and args.precision in ['fp8_e4m3', 'fp8_e5m2']: - def replace_torch_mm_bmm(): - from neural_compressor.torch.amp.fp8.functions import fp8_matmul - torch.matmul = fp8_matmul - torch.bmm = fp8_matmul - - replace_torch_mm_bmm() - - -# inference optimization -if args.to_graph: - import habana_frameworks.torch.hpu.graphs as htgraphs - user_model = htgraphs.wrap_in_hpu_graph(user_model) - - -# dump message of HPU after quantization or reloading -show_msg() - - -### generation, performance and accuracy validation ### -if args.generate: - input_prompt = "Here is my prompt" - print("Prompt sentence:", input_prompt) - generation_config = { - "min_new_tokens": args.max_new_tokens, "max_new_tokens": args.max_new_tokens, - # "do_sample": False, "temperature": 0.9, "num_beams": 4, - } - input_tokens = tokenizer(input_prompt, return_tensors="pt").to('hpu') - eval_start = time.perf_counter() - if args.approach == "cast": - from neural_compressor.torch.amp import autocast - if args.precision == "fp8_e4m3": - dtype = torch.float8_e4m3fn - elif args.precision == "fp8_e5m2": - dtype = torch.float8_e5m2 - elif args.precision == "fp16": - dtype = torch.float16 - elif args.precision == "bf16": - dtype = torch.bfloat16 - with autocast('hpu', dtype=dtype): - outputs = user_model.generate(**input_tokens, **generation_config) - else: - outputs = user_model.generate(**input_tokens, **generation_config) - - output_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) - eval_end = time.perf_counter() - print("Generated sentence:", output_sentence) - print("Duration:", eval_end - eval_start) - - -if args.performance: - eval_start = time.perf_counter() - input_prompt = "Intel is a company which" - input_tokens = torch.ones((1, 128), dtype=torch.long).to('hpu') - generation_config = {"min_new_tokens": 100, "max_new_tokens": 100} - outputs = user_model.generate(input_tokens, **generation_config) - print("Duration of generating 100 tokens :", time.perf_counter() - eval_start) - - -if args.accuracy: - eval_func(user_model, tokenizer=tokenizer, args=args) - -# dump final message of HPU -show_msg() diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py deleted file mode 100644 index 843287cddfa..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py +++ /dev/null @@ -1,255 +0,0 @@ -import os -import re -import torch -from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer - - -world_size = int(os.getenv('WORLD_SIZE', '1')) -local_rank = int(os.getenv('LOCAL_RANK', '-1')) - - -def init_model(args): - import deepspeed - model_dtype = torch.float32 - if re.search("llama", args.model.lower()) or re.search("bloom", args.model.lower()): - if world_size > 1: - config = AutoConfig.from_pretrained(args.model) - model_dtype = torch.bfloat16 # RuntimeErrorCastToFp8V2 input must be of float or bfloat16 dtype - deepspeed.init_distributed(dist_backend="hccl") - with deepspeed.OnDevice(dtype=model_dtype, device="meta"): - user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype) - import tempfile - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - from optimum.habana.checkpoint_utils import write_checkpoints_json # in optimum-habana - write_checkpoints_json( - args.model, - local_rank, - checkpoints_json, - token=None, - ) 
- else: - user_model = AutoModelForCausalLM.from_pretrained( - args.model, - device_map='hpu', - torch_dtype=model_dtype, - ) - elif re.search("chatglm", args.model.lower()): - from models.modeling_chatglm import ChatGLMForConditionalGeneration - user_model = ChatGLMForConditionalGeneration.from_pretrained( - args.model, - revision=args.revision, - device_map='hpu', - torch_dtype=model_dtype, - ) - # print(user_model.transformer.output_layer.weight.dtype) # always fp16 - user_model.float() # static fp8 need float32 for graph compiler - else: - user_model = AutoModelForCausalLM.from_pretrained( - args.model, - trust_remote_code=args.trust_remote_code, - revision=args.revision, - device_map='hpu', - torch_dtype=model_dtype, - ) - # load weight for multi-cards - if world_size > 1: - if re.search("llama", args.model.lower()) or re.search("bloom", args.model.lower()): - ds_inference_kwargs = {"dtype": model_dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = False - from transformers.models.llama.modeling_llama import LlamaDecoderLayer - ds_inference_kwargs["injection_policy"] = {LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")} - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - ds_model = deepspeed.init_inference(user_model, **ds_inference_kwargs) - else: - ds_model = deepspeed.init_inference(user_model, - mp_size=world_size, - replace_with_kernel_inject=False) - user_model = ds_model.module - return user_model - - -def init_empty_model(model_name): - from accelerate import init_empty_weights - model_dtype = torch.float32 - config = AutoConfig.from_pretrained(model_name) - with init_empty_weights(): - model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype) - return model - - -def init_tokenizer(args): - # tokenizer - if re.search("baichuan", args.model.lower()): - from models.tokenization_baichuan import BaichuanTokenizer - tokenizer = BaichuanTokenizer.from_pretrained( - args.model, - trust_remote_code=args.trust_remote_code - ) - else: - tokenizer = AutoTokenizer.from_pretrained( - args.model, - trust_remote_code=args.trust_remote_code - ) - tokenizer.pad_token = tokenizer.eos_token - return tokenizer - - -def show_msg(): - import numpy as np - import glob - from habana_frameworks.torch.hpu import memory_stats - print("Number of HPU graphs:", len(glob.glob(".graph_dumps/*PreGraph*"))) - mem_stats = memory_stats() - mem_dict = { - "memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2), - "max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2), - "total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2), - } - for k, v in mem_dict.items(): - print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v)) - - -def itrex_bootstrap_stderr(f, xs, iters): - from lm_eval.metrics import _bootstrap_internal, sample_stddev - res = [] - chunk_size = min(1000, iters) - it = _bootstrap_internal(f, chunk_size) - for i in range(iters // chunk_size): - bootstrap = it((i, xs)) - res.extend(bootstrap) - return sample_stddev(res) - - -def save_to_excel(dict): - import pandas as pd - df_new = pd.DataFrame(dict) - try: - df_existing = pd.read_excel('output.xlsx') - except FileNotFoundError: - df_existing = pd.DataFrame() - df_combined = pd.concat([df_existing, df_new], axis=0, ignore_index=True) - df_combined.to_excel('output.xlsx', index=False, engine='openpyxl', header=True) - - -def eval_func(user_model, tokenizer, args): - import os - import re - import 
time - import json - import torch - import habana_frameworks.torch.hpex - import torch.nn.functional as F - import lm_eval - import lm_eval.tasks - import lm_eval.evaluator - - # to avoid out-of-memory caused by Popen for large language models. - lm_eval.metrics.bootstrap_stderr = itrex_bootstrap_stderr - - class HabanaModelAdapter(lm_eval.base.BaseLM): - def __init__(self, tokenizer, model, args, options): - super().__init__() - self.tokenizer = tokenizer - self.model = model.eval() - self._batch_size = args.batch_size - self.buckets = list(sorted(args.buckets)) - self.options = options - self._device = "hpu" - torch.set_grad_enabled(False) - - @property - def eot_token_id(self): - return self.model.config.eos_token_id - - @property - def max_length(self): - return self.buckets[-1] - - @property - def max_gen_toks(self): - raise NotImplementedError() - - @property - def batch_size(self): - return self._batch_size - - @property - def device(self): - # We need to do padding ourselves, otherwise we'll end up with recompilations - # Returning 'cpu' to keep tensors on CPU in lm_eval code - return 'cpu' # 'hpu' - - def tok_encode(self, string): - if ( - re.search("chatglm3", args.model.lower()) or - re.search("llama", args.model.lower()) or - re.search("mistral", args.model.lower()) - ): - string = string.lstrip() - return self.tokenizer.encode(string, add_special_tokens=False) - - def tok_decode(self, tokens): - return self.tokenizer.decode(tokens, skip_special_tokens=True) - - def _model_generate(self, context, max_length, eos_token_id): - raise NotImplementedError() - - def find_bucket(self, length): - return [b for b in self.buckets if b >= length][0] - - def _model_call(self, inputs): - seq_length = inputs.shape[-1] - padding_length = 0 - bucket_length = self.find_bucket(seq_length) - padding_length = bucket_length - seq_length - inputs = F.pad(inputs, (0, padding_length), value=self.model.config.pad_token_id) - logits = self.model(inputs.to(self._device))["logits"].cpu() - - if padding_length > 0: - logits = logits[:, :-padding_length, :] - logits = logits.to(torch.float32) - return logits - - lm_tasks = lm_eval.tasks.get_task_dict(args.tasks) - options = None - lm = HabanaModelAdapter(tokenizer, user_model, args, options) - - eval_start = time.perf_counter() - if args.approach == "cast": - from neural_compressor.torch.amp import autocast - if args.precision == "fp8_e4m3": - dtype = torch.float8_e4m3fn - elif args.precision == "fp8_e5m2": - dtype = torch.float8_e5m2 - elif args.precision == "fp16": - dtype = torch.float16 - elif args.precision == "bf16": - dtype = torch.bfloat16 - with autocast('hpu', dtype=dtype): - results = lm_eval.evaluator.evaluate(lm, lm_tasks, limit=args.limit) - else: - results = lm_eval.evaluator.evaluate(lm, lm_tasks, limit=args.limit) - print(lm_eval.evaluator.make_table(results)) - eval_end = time.perf_counter() - print("Duration:", eval_end - eval_start) - results['args'] = vars(args) - results['duration'] = eval_end - eval_start - - # make sure that result is dumped only once during multi-cards evaluation - local_rank = int(os.getenv('LOCAL_RANK', '-1')) - if local_rank in [-1, 0]: - dumped = json.dumps(results, indent=2) - accu_dict = {} - case_name = str(args.approach) + "-" + args.precision - for task_name in args.tasks: - if task_name == "wikitext": - print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]), flush=True) - accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["word_perplexity"]] 
-            else:
-                print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]), flush=True)
-                accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["acc"]]
-        accu_dict["duration"] = [args.model, case_name, results["duration"]]
-        if args.dump_to_excel:
-            save_to_excel(accu_dict)
-        return results["results"][task_name]["acc"]
diff --git a/examples/fp8_sample/README.md b/examples/fp8_sample/README.md
new file mode 100644
index 00000000000..b758768ef0f
--- /dev/null
+++ b/examples/fp8_sample/README.md
@@ -0,0 +1,96 @@
+### Usage demo
+
+#### Two steps to get a quantized model
+
+```diff
+import torch
++ from neural_compressor.torch.quantization import FP8Config, convert, prepare, finalize_calibration
+import habana_frameworks.torch.core as htcore
+
+class M(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(10, 5)
+        self.fc2 = torch.nn.Linear(5, 10)
+
+    def forward(self, inp):
+        x1 = self.fc1(inp)
+        x2 = self.fc2(x1)
+        return x2
+
+model = M().eval()
+
++ config = FP8Config.from_json_file(args.quant_config)  # args.quant_config is the path of the JSON config file
+
++ if config.measure:
++     model = prepare(model, config)
+
++ if config.quantize:
++     htcore.hpu_initialize()
++     model = convert(model, config)
+
+# user code runs here
+with torch.no_grad():
+    model.to("hpu")
+    output = model(torch.randn(1, 10).to("hpu"))
+    print(output)
+
++ if config.measure:
++     finalize_calibration(model)
+```
+
+
+The complete script and configuration files are available in [sample_two_steps.py](./sample_two_steps.py), [maxabs_measure.json](./maxabs_measure.json) and [maxabs_quant.json](./maxabs_quant.json).
+
+First, measure the tensor quantization statistics:
+```shell
+python sample_two_steps.py --quant_config=maxabs_measure.json
+```
+
+Then quantize the model based on the previous measurements:
+```shell
+python sample_two_steps.py --quant_config=maxabs_quant.json
+```
+
+#### One step to get a quantized model
+
+```diff
+import torch
++ from neural_compressor.torch.quantization import FP8Config, convert, prepare, finalize_calibration
+import habana_frameworks.torch.core as htcore
+
+class M(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(10, 5)
+        self.fc2 = torch.nn.Linear(5, 10)
+
+    def forward(self, inp):
+        x1 = self.fc1(inp)
+        x2 = self.fc2(x1)
+        return x2
+
+model = M().to("hpu")
+
++ config = FP8Config.from_json_file(args.quant_config)  # args.quant_config is the path of the JSON config file
++ model = prepare(model, config)
+
+# user code runs calibration
+with torch.no_grad():
+    output = model(torch.randn(1, 10).to("hpu"))
+    print(output)
+
++ finalize_calibration(model)
++ model = convert(model)
+
+# user code benchmarks the quantized model
+with torch.no_grad():
+    output = model(torch.randn(1, 10).to("hpu"))
+    print(output)
+```
+
+The complete script and config are available in [sample_one_step.py](./sample_one_step.py) and [quant_config.json](./quant_config.json).
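For orientation, the same one-step flow can also be applied to a Hugging Face causal LM on Gaudi. The sketch below is an editorial illustration rather than part of this patch: the model name, prompt, and single-batch calibration pass are placeholders, and a real run would calibrate on a proper dataset. It only reuses the `FP8Config`/`prepare`/`finalize_calibration`/`convert` calls shown above.

```python
import torch
import habana_frameworks.torch.core as htcore
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare

model_name = "facebook/opt-125m"  # placeholder model; any causal LM supported on Gaudi should work
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval().to("hpu")
htcore.hpu_initialize()

# AUTO mode (see quant_config.json) measures and quantizes in a single run
config = FP8Config.from_json_file("quant_config.json")
model = prepare(model, config)

# short calibration pass over a placeholder prompt
with torch.no_grad():
    inputs = tokenizer("Intel Neural Compressor FP8 example", return_tensors="pt").to("hpu")
    model(**inputs)

finalize_calibration(model)
model = convert(model)

# the quantized model is ready for inference
with torch.no_grad():
    print(model(**inputs).logits.shape)
```

The original toy sample can still be run directly with the command below.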
+ +```shell +python sample_one_step.py --quant_config=quant_config.json +``` diff --git a/examples/fp8_sample/maxabs_measure.json b/examples/fp8_sample/maxabs_measure.json new file mode 100644 index 00000000000..8d55f33e57a --- /dev/null +++ b/examples/fp8_sample/maxabs_measure.json @@ -0,0 +1,7 @@ +{ + "mode": "MEASURE", + "observer": "maxabs", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure" +} diff --git a/examples/fp8_sample/maxabs_quant.json b/examples/fp8_sample/maxabs_quant.json new file mode 100644 index 00000000000..d1f76f8f630 --- /dev/null +++ b/examples/fp8_sample/maxabs_quant.json @@ -0,0 +1,8 @@ +{ + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure" +} diff --git a/examples/fp8_sample/quant_config.json b/examples/fp8_sample/quant_config.json new file mode 100644 index 00000000000..c139d13bbea --- /dev/null +++ b/examples/fp8_sample/quant_config.json @@ -0,0 +1,8 @@ +{ + "mode": "AUTO", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure" +} diff --git a/examples/fp8_sample/sample_one_step.py b/examples/fp8_sample/sample_one_step.py new file mode 100644 index 00000000000..18eb7bfba4c --- /dev/null +++ b/examples/fp8_sample/sample_one_step.py @@ -0,0 +1,57 @@ +import argparse +import torch +import habana_frameworks.torch.core as htcore +htcore.hpu_set_env() + +from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare + +torch.manual_seed(1) + + +# 1. python sample_one_step.py --quant_config=quant_config.json + + +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + self.fc2 = torch.nn.Linear(5, 10) + + def forward(self, inp): + x1 = self.fc1(inp) + x2 = self.fc2(x1) + return x2 + + +def eval_func(model): + # user's eval func + input = torch.randn(1, 10) + model(input.to("hpu")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Habana FP8 sample code.", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--quant_config", type=str, help="json file of quantization config") + args = parser.parse_args() + + model = M().eval().to("hpu") + htcore.hpu_initialize() + + config = FP8Config.from_json_file(args.quant_config) + model = prepare(model, config) + + # for calibration + with torch.no_grad(): + # model.to("hpu") + output = model(torch.randn(1, 10).to("hpu")) + + finalize_calibration(model) + model = convert(model) + print(model) + + # for benchmark + with torch.no_grad(): + output = model(torch.randn(1, 10).to("hpu")) + print(output) diff --git a/examples/fp8_sample/sample_two_steps.py b/examples/fp8_sample/sample_two_steps.py new file mode 100644 index 00000000000..9e17748b9b0 --- /dev/null +++ b/examples/fp8_sample/sample_two_steps.py @@ -0,0 +1,50 @@ +import argparse +import torch +import habana_frameworks.torch.core as htcore +htcore.hpu_set_env() + +from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare + +torch.manual_seed(1) + +# 1. python sample_two_steps.py --quant_config=maxabs_measure.json +# 2. 
python sample_two_steps.py --quant_config=maxabs_quant.json + + +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + self.fc2 = torch.nn.Linear(5, 10) + + def forward(self, inp): + x1 = self.fc1(inp) + x2 = self.fc2(x1) + return x2 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Habana FP8 sample code.", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--quant_config", type=str, help="json file of quantization config") + args = parser.parse_args() + + model = M().eval() + config = FP8Config.from_json_file(args.quant_config) + + if config.measure: + model = prepare(model, config) + + if config.quantize: + htcore.hpu_initialize() + model = convert(model, config) + print(model) + + with torch.no_grad(): + model.to("hpu") + output = model(torch.randn(1, 10).to("hpu")) + print(output) + + if config.measure: + finalize_calibration(model) diff --git a/neural_compressor/torch/algorithms/habana_fp8/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/__init__.py similarity index 70% rename from neural_compressor/torch/algorithms/habana_fp8/__init__.py rename to neural_compressor/torch/algorithms/fp8_quant/__init__.py index fe3a05d7d0b..d16760b5e81 100644 --- a/neural_compressor/torch/algorithms/habana_fp8/__init__.py +++ b/neural_compressor/torch/algorithms/fp8_quant/__init__.py @@ -12,5 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .fp8_quant import quantize_dynamic, quantize, white_list -from .save_load import save, load +from neural_compressor.torch.algorithms.fp8_quant.common import ( + update_mode, + save_calib_result, + restore_patched_module, + with_patched_module, +) +from neural_compressor.torch.algorithms.fp8_quant.fp8_quant import FP8Quantizer diff --git a/neural_compressor/torch/algorithms/fp8_quant/common.py b/neural_compressor/torch/algorithms/fp8_quant/common.py new file mode 100644 index 00000000000..b038a367a78 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/common.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import tempfile +from collections import namedtuple +from pathlib import Path +from typing import Union + +import torch + + +def save_calib_result(model): + import habana_quantization_toolkit as hqt + hqt.finish_measurements(model) + + +def update_mode(config_path, measure_step=False, quant_step=False): + with open(config_path, 'r') as file: + config = json.load(file) + + if (measure_step and config.get("mode") == "MEASURE") or (quant_step and config.get("mode") == "QUANTIZE"): + return config_path + else: + if measure_step: + config["mode"] = "MEASURE" + if quant_step: + config["mode"] = "QUANTIZE" + + temp_file = tempfile.NamedTemporaryFile(suffix=".json", delete=False) + temp_file_path = temp_file.name + + with open(temp_file_path, 'w') as temp_file: + json.dump(config, temp_file) + + return temp_file_path + + +def generate_model_info(model): + mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"]) + parent_child_mod_dict = {} + + def create_mod_info_recursion(parent): + for name, mod in parent.named_children(): + parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent) + create_mod_info_recursion(mod) + + create_mod_info_recursion(model) + return parent_child_mod_dict + + +def get_patched_mod_list(): + from habana_quantization_toolkit._core.common import mod_default_dict + + patched_mod_list = [] + for patched_mod in mod_default_dict.values(): + patched_mod_list.append(patched_mod.patched_module.__name__) + return patched_mod_list + + +def restore_patched_module(patched_model): + from neural_compressor.torch.algorithms.fp8_quant.helper_modules import helper_mods + patched_mod_list = get_patched_mod_list() + + parent_child_mod_dict = generate_model_info(patched_model) + with torch.no_grad(): + for name, patched_mod in patched_model.named_modules(): + patched_mod_type_str = patched_mod.__class__.__name__ + if patched_mod_type_str in patched_mod_list: + parent = parent_child_mod_dict[patched_mod].parent + name = parent_child_mod_dict[patched_mod].name + class_name_org = getattr(patched_mod, "class_name_org", None) or \ + patched_mod.__class__.__name__.split("Patched")[-1] + origin_mod = helper_mods[class_name_org](patched_mod) + origin_mod.forward = patched_mod.forward_orig + setattr(parent, name, origin_mod) + + +def with_patched_module(model): + patched_mod_list = get_patched_mod_list() + + for name, mod in model.named_modules(): + mod_type = mod.__class__.__name__ + if mod_type in patched_mod_list: + return True + return False diff --git a/neural_compressor/torch/algorithms/fp8_quant/fp8_quant.py b/neural_compressor/torch/algorithms/fp8_quant/fp8_quant.py new file mode 100644 index 00000000000..f9ce9145569 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/fp8_quant.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from neural_compressor.common.utils import FP8_QUANT +from neural_compressor.torch.algorithms import Quantizer +from neural_compressor.torch.algorithms.fp8_quant import ( + restore_patched_module, + update_mode, + with_patched_module, +) + + +class FP8Quantizer(Quantizer): + def __init__(self, quant_config): + super().__init__(quant_config) + if isinstance(quant_config, dict): + json_file = [cfg.json_file for cfg in quant_config.values()] + assert len(json_file) > 0, "Cannot get json file from config." + self.quant_config = json_file[0] + + def prepare(self, model): + _prepare(model, self.quant_config) + return model + + def convert(self, model): + if with_patched_module(model): + # for INC flow, it calls `prepare` and then `convert` user-facing API in one run + restore_patched_module(model) + _convert(model, self.quant_config) + return model + + +def _convert(model, config_path): + import habana_quantization_toolkit as hqt + + # update mode to QUANTIZE + config_path = update_mode(config_path, quant_step=True) + + return hqt.prep_model(model, config_path) + + +def _prepare(model, config_path): + import habana_quantization_toolkit as hqt + + # update mode to MEASURE + config_path = update_mode(config_path, measure_step=True) + + return hqt.prep_model(model, config_path) diff --git a/neural_compressor/torch/algorithms/fp8_quant/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/helper_modules.py new file mode 100644 index 00000000000..6c7154328d7 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/helper_modules.py @@ -0,0 +1,118 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +# For mapping revert patched module to origin module + +helper_mods = {} + +def helper_mod_register(name): + def decorator(mod): + helper_mods[name] = mod + return mod + return decorator + +@helper_mod_register(name="Matmul") +class Matmul(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + +@helper_mod_register(name="Linear") +class Linear(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + +@helper_mod_register(name="FalconLinear") +class FalconLinear(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + +@helper_mod_register(name="KVCache") +class KVCache(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.allocate = patched_mod.org_allocate + self.get_shape = patched_mod.get_shape + self.forward = patched_mod.forward + self.update = patched_mod.update + +@helper_mod_register(name="Conv2d") +class Conv2d(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + +@helper_mod_register(name="LoRACompatibleLinear") +class LoRACompatibleLinear(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + +@helper_mod_register(name="LoRACompatibleConv") +class LoRACompatibleConv(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + +@helper_mod_register(name="Softmax") +class Softmax(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + +@helper_mod_register(name="LinearLayer") +class LinearLayer(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + +@helper_mod_register(name="LinearAllreduce") +class LinearAllreduce(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + +@helper_mod_register(name="ScopedLinearAllReduce") +class ScopedLinearAllReduce(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + +@helper_mod_register(name="LmHeadLinearAllreduce") +class LmHeadLinearAllreduce(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + +@helper_mod_register(name="ModuleFusedSDPA") +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org diff --git 
a/neural_compressor/torch/algorithms/habana_fp8/fp8_quant.py b/neural_compressor/torch/algorithms/habana_fp8/fp8_quant.py deleted file mode 100644 index c80cc443531..00000000000 --- a/neural_compressor/torch/algorithms/habana_fp8/fp8_quant.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint:disable=import-error - -import copy -import os - -import habana_frameworks.torch.core as htcore -import torch -from deepspeed.module_inject import LinearAllreduce, LinearLayer -from deepspeed.module_inject.layers import LmHeadLinearAllreduce -from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const - -from neural_compressor.torch.utils import fetch_module, logger, set_module - -from .modules import ( # fp32; dynamic modules; static modules; dtype amax - Autocast, - BatchMatmul, - FP8BatchMatmul, - FP8Cast, - FP8DynamicBatchMatmul, - FP8DynamicLinear, - FP8DynamicMatmul, - FP8Linear, - FP8LinearAllreduce, - FP8LinearLayer, - FP8LmHeadLinearAllreduce, - FP8Matmul, - Matmul, -) -from .observer import observer_mapping - -quantization_mapping = { - LinearAllreduce: FP8LinearAllreduce, - LinearLayer: FP8LinearLayer, - LmHeadLinearAllreduce: FP8LmHeadLinearAllreduce, - torch.nn.Linear: FP8Linear, - BatchMatmul: FP8BatchMatmul, - Matmul: FP8Matmul, - Autocast: FP8Cast, - # torch.matmul: fp8_matmul -} -white_list = tuple(quantization_mapping.keys()) - - -FP8_DTYPE = [torch.float8_e5m2, torch.float8_e4m3fn, "fp8_e5m2", "fp8_e4m3"] -dtype_mapping = {"fp8_e5m2": torch.float8_e5m2, "fp8_e4m3": torch.float8_e4m3fn} -# enable inference optimizations -htcore.hpu_initialize() - - -def _replace_module(module, qconfig): - assert qconfig.w_dtype == qconfig.act_dtype, "weight and activation should be the same dtype." 
- dtype = dtype_mapping[qconfig.w_dtype] - # only modules that have weight should use this observer - if hasattr(module, "weight"): - observer_cls = observer_mapping[qconfig.w_observer] - observer_obj = observer_cls(dtype=dtype) - if qconfig.approach == "static": - if isinstance(module, white_list): - QModule = quantization_mapping[type(module)] - qmodule = QModule(module, dtype) - elif qconfig.approach == "dynamic": - if isinstance(module, torch.nn.Linear): - # need module for initialization - qmodule = FP8DynamicLinear(module, dtype) - elif isinstance(module, Matmul): - qmodule = FP8DynamicMatmul(dtype) - elif isinstance(module, BatchMatmul): - qmodule = FP8DynamicBatchMatmul(dtype) - elif isinstance(module, Autocast): - qmodule = FP8Cast(dtype=dtype) - # only modules that have weight should use this API - if hasattr(qmodule, "from_float"): - qmodule.from_float(module, observer_obj) - return qmodule - - -def quantize_dynamic(model, dtype=torch.float8_e4m3fn, inplace=True): - torch.set_grad_enabled(False) - q_model = model if inplace else copy.deepcopy(model) - if isinstance(dtype, str): - dtype = dtype_mapping[dtype] - for n, m in q_model.named_modules(): - if isinstance(m, torch.nn.Linear): - observer_cls = observer_mapping["minmax_per_channel"] - observer_obj = observer_cls(dtype=dtype) - new_m = FP8DynamicLinear(m, dtype) # need m for init - new_m.from_float(m, observer_obj) - set_module(q_model, n, new_m) - elif isinstance(m, Matmul): - new_m = FP8DynamicMatmul(dtype) - set_module(q_model, n, new_m) - elif isinstance(m, BatchMatmul): - new_m = FP8DynamicBatchMatmul(dtype) - set_module(q_model, n, new_m) - elif isinstance(m, Autocast): - new_m = FP8Cast(dtype=dtype) - set_module(q_model, n, new_m) - htcore.mark_step() - _mark_params_as_const(q_model) - _check_params_as_const(q_model) - return q_model - - -def _add_observer(module, qconfig): - act_observer = qconfig.act_observer - - def input_observer_forward_pre_hook(self, input): - try: - if isinstance(input[0], torch.Tensor): - self.input_activation_post_process(input[0]) - if hasattr(self, "input_activation_post_process1") and isinstance(input[1], torch.Tensor): - self.input_activation_post_process1(input[1]) - return input - except Exception as e: - # The KL act_observer may encounter a overflow error on EltwiseAdd. 
- pass - - ### Insert input observer into model, only for fp8_e4m3 static quantization ### - observer_cls = observer_mapping[act_observer] - - if isinstance(module, white_list): - observer_obj = observer_cls(dtype=dtype_mapping[qconfig.act_dtype]) - module.add_module("input_activation_post_process", observer_obj) - if isinstance(module, (BatchMatmul, Matmul)): - observer_obj = observer_cls(dtype=dtype_mapping[qconfig.act_dtype]) - module.add_module("input_activation_post_process1", observer_obj) - module.register_forward_pre_hook(input_observer_forward_pre_hook) - - -def _remove_observer(module): - import deepspeed.comm as dist - from torch.distributed import ReduceOp - - if hasattr(module, "input_activation_post_process"): - scale = module.input_activation_post_process.calculate_qparams() - if dist.is_initialized(): - scale = scale.to("hpu") - dist.all_reduce(scale, op=ReduceOp.MAX) - if hasattr(module, "input_activation_post_process1"): - module.register_parameter("scale1", torch.nn.Parameter(scale)) - else: - module.register_parameter("scale", torch.nn.Parameter(scale)) - delattr(module, "input_activation_post_process") - if hasattr(module, "input_activation_post_process1"): - scale = module.input_activation_post_process1.calculate_qparams() - if dist.is_initialized(): - scale = scale.to("hpu") - dist.all_reduce(scale, op=ReduceOp.MAX) - module.register_parameter("scale2", torch.nn.Parameter(scale)) - delattr(module, "input_activation_post_process1") - - # remove observer hooks - hook_map = module._forward_pre_hooks - handle_ids_to_remove = set() - for handle_id, hook_fn in hook_map.items(): - if hasattr(hook_fn, "__name__") and hook_fn.__name__ == "input_observer_forward_pre_hook": - handle_ids_to_remove.add(handle_id) - for handle_id in handle_ids_to_remove: - hook_map.pop(handle_id) - - -def prepare(model, qconfig_mapping): - model.qconfig = qconfig_mapping - for (op_name, op_type), qconfig in qconfig_mapping.items(): - if qconfig.approach == "dynamic": - continue - if qconfig.w_dtype not in FP8_DTYPE: - continue - module = fetch_module(model, op_name) - if module is None: - logger.info(f"{op_name} is not found in model.") - continue - _add_observer(module, qconfig) - set_module(model, op_name, module) - return model - - -def convert(model): - for (op_name, op_type), qconfig in model.qconfig.items(): - if qconfig.w_dtype not in FP8_DTYPE: - continue - module = fetch_module(model, op_name) - if module is None: - logger.info(f"{op_name} is not found in model.") - continue - if qconfig.approach != "dynamic": - _remove_observer(module) - module = _replace_module(module, qconfig) - set_module(model, op_name, module) - htcore.mark_step() - return model - - -def quantize(model, qconfig_mapping, run_fn=None, run_args=None, inplace=True): - torch.set_grad_enabled(False) - q_model = model if inplace else copy.deepcopy(model) - q_model = prepare(q_model, qconfig_mapping) - if run_fn is not None: - if run_args is not None: - run_fn(q_model, *run_args) - else: - run_fn(q_model) - q_model = convert(q_model) - _mark_params_as_const(q_model) - _check_params_as_const(q_model) - return q_model diff --git a/neural_compressor/torch/algorithms/habana_fp8/modules.py b/neural_compressor/torch/algorithms/habana_fp8/modules.py deleted file mode 100644 index 99b9faf1f72..00000000000 --- a/neural_compressor/torch/algorithms/habana_fp8/modules.py +++ /dev/null @@ -1,487 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this 
file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint:disable=import-error - -import os - -import habana_frameworks.torch.core as htcore -import habana_frameworks.torch.hpex -import torch -import torch.nn as nn -from torch.nn import functional as F - -from neural_compressor.common import logger - -from .observer import calculate_qparams - - -##################### FP32 modules ####################### -class Matmul(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.matmul(x, y) - - -class BatchMatmul(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.bmm(x, y) - - -class Autocast(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x - - -##################### FP8 modules ####################### -class FP8DynamicLinear(torch.nn.Module): - def __init__(self, org_module, dtype=torch.float8_e4m3fn) -> None: - super().__init__() - # attributes - self.use_amax = True - self.dtype = dtype - self.in_features = org_module.in_features - self.out_features = org_module.out_features - self.weight_dtype = self.dtype - self.out_dtype = org_module.weight.dtype - # register weight, bias - self.register_buffer( - "weight", - torch.empty( - self.in_features, - self.out_features, - device="hpu", - dtype=self.weight_dtype, - ), - ) - if org_module.bias is not None: - self.register_buffer( - "bias", - torch.empty( - self.out_features, - device="hpu", - dtype=self.out_dtype, - ), - ) - else: - self.bias = None - - def from_float(self, org_module, w_observer): - # register scale - if not org_module.weight.device.type == "meta": - w_observer(org_module.weight) - weight_scale = w_observer.calculate_qparams() - else: - weight_scale = torch.tensor([1.0]) - self.register_buffer( - "weight_scale", - torch.tensor( - weight_scale, - device="hpu", - dtype=torch.float32, - ), - ) - self.register_buffer( - "weight_scale_inv", - torch.tensor( - torch.reciprocal(weight_scale), - device="hpu", - dtype=torch.float32, - ), - ) - # copy weight and bias - if not org_module.weight.device.type == "meta": - org_module.to("hpu") - self.weight.data.copy_( - torch.ops.hpu.cast_to_fp8_v2(org_module.weight.T, self.weight_scale_inv, False, False, self.dtype)[0] - ) - if org_module.bias is not None: - self.bias.data.copy_(org_module.bias.data.type(self.out_dtype)) - - def forward(self, inp): - assert inp.shape[-1] == self.in_features, "GEMM not possible" - org_middle_shape = inp.shape[1:-1] - inp = inp.view(-1, self.in_features) - if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: - if self.use_amax: - input_scale = calculate_qparams(inp.min(), inp.max(), self.dtype) - input_scale_inv = torch.reciprocal(input_scale) - else: - input_scale, input_scale_inv = None, None - inp = torch.ops.hpu.cast_to_fp8_v2(inp, input_scale_inv, False, False, self.dtype)[0] - else: - input_scale, input_scale_inv = None, None - out = torch.ops.hpu.fp8_gemm_v2( - inp, - False, - self.weight, - False, - None, - self.out_dtype, - input_scale, # inv is used for recover scale - self.weight_scale, - self.bias, - 
False, - ) - out = out.view(-1, *org_middle_shape, out.shape[-1]) - return out - - def extra_repr(self) -> str: - return "in_features={}, out_features={}, bias={}, format={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.dtype, - ) - - -class FP8DynamicMatmul(torch.nn.Module): - def __init__(self, dtype) -> None: - super().__init__() - self.dtype = dtype - self.use_amax = True - self.out_dtype = torch.float32 - - def forward(self, input1, input2): - dim1 = input1.shape[-1] - dim2 = input2.shape[-2] - assert dim1 == dim2, "GEMM not possible" - - # process input1 - if input1.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: - self.out_dtype = input1.dtype - if self.use_amax: - input1_scale = calculate_qparams(input1.min(), input1.max(), self.dtype) - input1_scale_inv = torch.reciprocal(input1_scale) - else: - input1_scale, input1_scale_inv = None, None - input1 = torch.ops.hpu.cast_to_fp8_v2(input1, input1_scale_inv, False, False, self.dtype)[0] - else: - # skip cast for input1 - input1_scale, input1_scale_inv = None, None - # process input2 - if input2.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: - self.out_dtype = input2.dtype - if self.use_amax: - input2_scale = calculate_qparams(input2.min(), input2.max(), self.dtype) - input2_scale_inv = torch.reciprocal(input2_scale) - else: - input2_scale, input2_scale_inv = None, None - input2 = torch.ops.hpu.cast_to_fp8_v2(input2, input2_scale_inv, False, False, self.dtype)[0] - else: - # skip cast for input2 - input2_scale, input2_scale_inv = None, None - # calculate - out = torch.ops.hpu.fp8_gemm_v2( - input1, - False, - input2, - False, - None, - self.out_dtype, - input1_scale, # inv is used for recover scale - input2_scale, - None, - False, - ) - return out - - def extra_repr(self) -> str: - return "format={}".format(self.dtype) - - -class FP8DynamicBatchMatmul(FP8DynamicMatmul): - pass - - -class FP8Linear(torch.nn.Module): - def __init__(self, org_module, dtype) -> None: - super().__init__() - # attributes - self.in_features = org_module.in_features - self.out_features = org_module.out_features - self.dtype = dtype - self.weight_dtype = self.dtype - self.out_dtype = org_module.weight.dtype - self.register_buffer( - "weight", - torch.empty( - self.in_features, - self.out_features, - device="hpu", - dtype=self.weight_dtype, - ), - ) - if org_module.bias is not None: - self.register_buffer( - "bias", - torch.empty( - self.out_features, - device="hpu", - dtype=self.out_dtype, - ), - ) - else: - self.bias = None - - def from_float(self, org_module, w_observer): - # register scale - if not org_module.weight.device.type == "meta": - w_observer(org_module.weight) - weight_scale = w_observer.calculate_qparams() - else: - weight_scale = torch.tensor([1.0]) - self.register_buffer( - "weight_scale", - torch.tensor( - weight_scale, - device="hpu", - dtype=torch.float32, - ), - ) - self.register_buffer( - "weight_scale_inv", - torch.tensor( - torch.reciprocal(weight_scale), - device="hpu", - dtype=torch.float32, - ), - ) - # copy weight and bias - if not org_module.weight.device.type == "meta": - org_module.to("hpu") - self.weight.data.copy_( - torch.ops.hpu.cast_to_fp8_v2(org_module.weight.T, self.weight_scale_inv, False, False, self.dtype)[0] - ) - if org_module.bias is not None: - self.bias.data.copy_(org_module.bias.data.type(self.out_dtype)) - # register input scale - input_scale = org_module.scale if hasattr(org_module, "scale") else torch.tensor([1.0]) - self.register_buffer( - "input_scale", - 
torch.tensor( - input_scale, - device="hpu", - dtype=torch.float32, - ), - ) - self.register_buffer( - "input_scale_inv", - torch.tensor( - torch.reciprocal(input_scale), - device="hpu", - dtype=torch.float32, - ), - ) - - def forward(self, inp): - assert inp.shape[-1] == self.in_features, "GEMM not possible" - org_middle_shape = inp.shape[1:-1] - inp = inp.view(-1, self.in_features) - inp = torch.ops.hpu.cast_to_fp8_v2(inp, self.input_scale_inv, False, False, self.dtype)[0] - out = torch.ops.hpu.fp8_gemm_v2( - inp, - False, - self.weight, - False, - None, - self.out_dtype, - self.input_scale, # inv is used for recover scale - self.weight_scale, - self.bias, - False, - ) - out = out.view(-1, *org_middle_shape, out.shape[-1]) - return out - - def extra_repr(self) -> str: - return "in_features={}, out_features={}, bias={}, scale={}, format={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.input_scale.tolist() if hasattr(self, "input_scale") else None, - self.dtype, - ) - - -class FP8Matmul(torch.nn.Module): - def __init__(self, org_module, dtype) -> None: - super().__init__() - org_module.to("hpu") - self.dtype = dtype - self.out_dtype = torch.float32 - scale1 = org_module.scale1 if hasattr(org_module, "scale1") else 1.0 - scale2 = org_module.scale2 if hasattr(org_module, "scale2") else 1.0 - self.register_buffer( - "scale1", - torch.tensor( - scale1, - device="hpu", - dtype=self.out_dtype, - ), - ) - self.register_buffer( - "scale2", - torch.tensor( - scale2, - device="hpu", - dtype=self.out_dtype, - ), - ) - - def forward(self, input1, input2): - dim1 = input1.shape[-1] - dim2 = input2.shape[-2] - assert dim1 == dim2, "GEMM not possible" - - if input1.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: - self.out_dtype = input1.dtype - self.scale1_inv = torch.reciprocal(self.scale1) - input1 = torch.ops.hpu.cast_to_fp8_v2(input1, self.scale1_inv, False, False, self.dtype)[0] - else: - self.scale1_inv = None - if input2.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: - self.out_dtype = input2.dtype - self.scale2_inv = torch.reciprocal(self.scale2) - input2 = torch.ops.hpu.cast_to_fp8_v2(input2, self.scale2_inv, False, False, self.dtype)[0] - else: - self.scale2_inv = None - out = torch.ops.hpu.fp8_gemm_v2( - input1, - False, - input2, - False, - None, - self.out_dtype, - self.scale1, # inv is used for recover scale - self.scale2, - None, - False, - ) - return out - - def extra_repr(self) -> str: - return "scales={}, format={}".format( - (self.scale1.tolist(), self.scale2.tolist()), - self.dtype, - ) - - -class FP8BatchMatmul(FP8Matmul): - pass - - -class FP8Cast(torch.nn.Module): - def __init__(self, org_module=None, dtype=torch.float8_e4m3fn) -> None: - super().__init__() - self.dtype = dtype - if org_module is not None: - org_module.to("hpu") - scale = org_module.scale if hasattr(org_module, "scale") else 1.0 - self.register_buffer( - "scale", - torch.tensor( - scale, - device="hpu", - dtype=torch.float32, - ), - ) - self.scale, self.scale_inv = None, None # due to next matmul doesn't know this scale - else: - self.scale, self.scale_inv = None, None - - def forward(self, input): - if input.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: - out = torch.ops.hpu.cast_to_fp8_v2(input, self.scale_inv, False, False, self.dtype)[0] - else: - out = input - return out - - def extra_repr(self) -> str: - return "scales={}, format={}".format( - self.scale, - self.dtype, - ) - - -FP8LinearLayer = FP8Linear - - -class FP8LinearAllreduce(FP8Linear): - 
def forward(self, inp): - assert inp.shape[-1] == self.in_features, "GEMM not possible" - inputmat = inp.view(-1, self.in_features) - inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.input_scale_inv, False, False, self.dtype)[0] - out = torch.ops.hpu.fp8_gemm_v2( - inputmat, - False, - self.weight, - False, - None, - self.out_dtype, - self.input_scale, - self.weight_scale, - None, - False, - ) - from deepspeed import comm as dist - - if self.mp_group is not None: - dist.inference_all_reduce(out, group=self.mp_group) - if self.bias is not None: - out += self.bias - return out.view(-1, *inp.shape[1:-1], out.shape[-1]) - - -class FP8LmHeadLinearAllreduce(FP8Linear): - def forward(self, inp): - # from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list - # input_shard_size = get_shard_size(inp.shape[-1], self.world_size) - # input_shard_offset = sum(get_shard_size_list(inp.shape[-1], self.world_size)[0:self.rank]) - - # inputmat = inp[:, :, input_shard_offset:input_shard_offset + input_shard_size] - assert ( - inp.shape[-1] % self.world_size == 0 - ), "Please ensure that self.world_size is divisible by input.shape[-1]" - input_shard = inp.shape[-1] // self.world_size - inp_part = inp[:, :, self.rank * input_shard : (self.rank + 1) * input_shard] - inputmat = inp_part.view(-1, input_shard) # dim=2 will help kernel speed - inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.input_scale_inv, False, False, self.dtype)[0] - out = torch.ops.hpu.fp8_gemm_v2( - inputmat, - False, - self.weight, - False, - None, - self.out_dtype, - self.input_scale, - self.weight_scale, - None, - False, - ) - from deepspeed import comm as dist - - if self.mp_group is not None: - dist.inference_all_reduce(out, group=self.mp_group) - if self.bias is not None: - out += self.bias - return out.view(-1, *inp.shape[1:-1], out.shape[-1]) diff --git a/neural_compressor/torch/algorithms/habana_fp8/observer.py b/neural_compressor/torch/algorithms/habana_fp8/observer.py deleted file mode 100644 index fd29892ddb7..00000000000 --- a/neural_compressor/torch/algorithms/habana_fp8/observer.py +++ /dev/null @@ -1,440 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# pylint:disable=import-error - -import os -from typing import Tuple - -import habana_frameworks.torch.core as htcore -import torch -from torch.ao.quantization.observer import * - -E4M3_AMAX = torch.tensor(240, dtype=torch.float).to("cpu") -E5M2_AMAX = torch.tensor(57344, dtype=torch.float).to("cpu") -USE_HW_SCALE = bool(os.getenv("USE_HW_SCALE", False)) -USE_POW2_SCALE = bool(os.getenv("USE_POW2_SCALE", False)) -observer_mapping = {} - - -def observer_registry(name): - def new_observer(observer_cls): - global observer_mapping - observer_mapping[name] = observer_cls - return observer_cls - - return new_observer - - -def _map_gaudi_scale(scale): - if USE_HW_SCALE: - scale_list = torch.tensor([16, 1, 1 / 16, 1 / 256]) - return torch.clip( - 2 ** (torch.ceil(torch.log2(scale) / 4) * 4), - torch.tensor(scale_list[-1], dtype=scale.dtype, device=scale.device), - torch.tensor(scale_list[0], dtype=scale.dtype, device=scale.device), - ) - elif USE_POW2_SCALE: - return 2 ** torch.ceil(torch.log2(scale)) - else: - return scale - - -def calculate_qparams(min_val, max_val, dtype): - amax = torch.max(torch.abs(min_val), torch.abs(max_val)) - dtype_amax = E4M3_AMAX if dtype == torch.float8_e4m3fn else E5M2_AMAX - scale = amax / dtype_amax - scale = scale.reshape(-1) - return _map_gaudi_scale(scale) - - -@observer_registry(name="minmax") -class FP8MinMaxObserver(ObserverBase): - def __init__( - self, - dtype: torch.dtype = torch.float8_e4m3fn, - ) -> None: - # bins: The number of bins used for histogram calculation. - super().__init__(dtype=dtype) - assert isinstance(dtype, torch.dtype), "Please make sure the dtype of observer is torch.dtype." - factory_kwargs = {"device": "cpu", "dtype": torch.float32} - self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) - self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) - - def forward(self, x_orig): - r"""Records the running minimum and maximum of ``x``.""" - if x_orig.numel() == 0: - return x_orig - x = x_orig.detach() - x = x.to(self.min_val.dtype) - min_val_cur, max_val_cur = torch.aminmax(x) - min_val = torch.min(min_val_cur, self.min_val) - max_val = torch.max(max_val_cur, self.max_val) - self.min_val.copy_(min_val) - self.max_val.copy_(max_val) - return x_orig - - def calculate_qparams(self): - r"""Calculates the quantization parameters.""" - scale = calculate_qparams(self.min_val, self.max_val, self.dtype) - return scale - - def extra_repr(self): - return f"min_val={self.min_val}, max_val={self.max_val}" - - def reset_min_max_vals(self): - """Resets the min/max values.""" - self.min_val.copy_(torch.tensor(float("inf"))) - self.max_val.copy_(torch.tensor(float("-inf"))) - - -@observer_registry(name="minmax_per_channel") -class FP8PerChannelMinMaxObserver(ObserverBase): - def __init__( - self, - dtype: torch.dtype = torch.float8_e4m3fn, - ch_axis=0, # weight_shape = (out_features, in_features) - ) -> None: - # bins: The number of bins used for histogram calculation. - super().__init__(dtype=dtype) - assert isinstance(dtype, torch.dtype), "Please make sure the dtype of observer is torch.dtype." 
- self.ch_axis = ch_axis - factory_kwargs = {"device": "cpu", "dtype": torch.float32} - self.register_buffer("min_val", torch.tensor([], **factory_kwargs)) - self.register_buffer("max_val", torch.tensor([], **factory_kwargs)) - - def forward(self, x_orig): - if x_orig.numel() == 0: - return x_orig - x = x_orig.detach() - min_val = self.min_val - max_val = self.max_val - x_dim = x.size() - - new_axis_list = [i for i in range(len(x_dim))] - new_axis_list[self.ch_axis] = 0 - new_axis_list[0] = self.ch_axis - y = x.permute(new_axis_list) - # Need to match dtype of min/max because the updates to buffers - # are done in place and types need to match for comparisons - y = y.to(self.min_val.dtype) - y = torch.flatten(y, start_dim=1) - if min_val.numel() == 0 or max_val.numel() == 0: - min_val, max_val = torch.aminmax(y, dim=1) - else: - min_val_cur, max_val_cur = torch.aminmax(y, dim=1) - min_val = torch.min(min_val_cur, min_val) - max_val = torch.max(max_val_cur, max_val) - self.min_val.resize_(min_val.shape) - self.max_val.resize_(max_val.shape) - self.min_val.copy_(min_val) - self.max_val.copy_(max_val) - return x_orig - - def calculate_qparams(self): - r"""Calculates the quantization parameters.""" - scale = calculate_qparams(self.min_val, self.max_val, self.dtype) - return scale - - def extra_repr(self): - return f"min_val={self.min_val}, max_val={self.max_val}" - - def reset_min_max_vals(self): - """Resets the min/max values.""" - self.min_val.copy_(torch.tensor(float("inf"))) - self.max_val.copy_(torch.tensor(float("-inf"))) - - -@observer_registry(name="kl") -class FP8HistogramObserver(ObserverBase): - def __init__( - self, - dtype: torch.dtype = torch.float8_e4m3fn, - bins: int = 2048, - upsample_rate: int = 128, - qscheme=torch.per_tensor_affine, - eps=torch.finfo(torch.float32).eps, - ) -> None: - # bins: The number of bins used for histogram calculation. - super().__init__(dtype=dtype) - assert isinstance(dtype, torch.dtype), "Please make sure the dtype of observer is torch.dtype." - self.bins = bins - factory_kwargs = {"device": "cpu", "dtype": torch.float32} - self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs)) - self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) - self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) - self.dst_nbins = 2 ** torch.finfo(self.dtype).bits - self.upsample_rate = upsample_rate - - def calculate_qparams(self, **kwargs): - new_min, new_max = self._non_linear_param_search() - scale = calculate_qparams(new_min, new_max, self.dtype) - return scale - - def _get_norm(self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor) -> torch.Tensor: - r"""Compute the norm of the values uniformaly distributed between - delta_begin and delta_end. - Currently only L2 norm is supported. 
- - norm = density * (integral_{begin, end} x^2) - = density * (end^3 - begin^3) / 3 - """ - norm = (delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin) / 3 - return density * norm - - def _get_dst_bin(self, src_bin_begin, src_bin_end, dst_bin_max): - # get dst bin value - FP8_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX - scale = FP8_amax / dst_bin_max - if torch.isinf(torch.tensor(scale)): - scale = torch.tensor(3.4e38) - tmp = torch.ops.hpu.cast_to_fp8_v2(src_bin_begin.to("hpu"), scale.to("hpu"), False, False, self.dtype)[0] - dst_bin_begin = torch.ops.hpu.cast_from_fp8(tmp, None, torch.float32).to("cpu") - tmp = torch.ops.hpu.cast_to_fp8_v2(src_bin_end.to("hpu"), scale.to("hpu"), False, False, self.dtype)[0] - dst_bin_end = torch.ops.hpu.cast_from_fp8(tmp, None, torch.float32).to("cpu") - # get bin width of dst bin value, dst_bin_begin must contain 0 and the max qvalue. - dst_bin = list(set(dst_bin_begin.detach().cpu().numpy())) - dst_bin.sort() - width_dict = {} - bin_of_dst_dict = {} - for i, bin in enumerate(dst_bin): - bin_of_dst_dict[bin] = i - if bin == 0: - width_dict[bin] = {"left": 0, "right": dst_bin[i + 1]} - elif i == len(dst_bin) - 1: - width_dict[bin] = {"left": dst_bin[i] - dst_bin[i - 1], "right": dst_bin[i] - dst_bin[i - 1]} - else: - width_dict[bin] = {"left": dst_bin[i] - dst_bin[i - 1], "right": dst_bin[i + 1] - dst_bin[i]} - dst_bin_of_begin = [bin_of_dst_dict[float(i)] for i in dst_bin_begin] - dst_bin_of_end = [bin_of_dst_dict[float(i)] for i in dst_bin_end] - left_dst_bin_end_width = [width_dict[float(i)]["left"] for i in dst_bin_end] - right_dst_bin_begin_width = [width_dict[float(i)]["right"] for i in dst_bin_begin] - return ( - dst_bin_begin, - dst_bin_end, - torch.tensor(dst_bin_of_begin), - torch.tensor(dst_bin_of_end), - torch.tensor(left_dst_bin_end_width), - torch.tensor(right_dst_bin_begin_width), - ) - - def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int): - r"""Compute the quantization error if we use start_bin to end_bin as the - min and max to do the quantization.""" - bin_width = (self.max_val.item() - self.min_val.item()) / self.bins - dst_bin_max = bin_width * (next_end_bin - next_start_bin + 1) - - src_bin = torch.arange(self.bins, device=self.histogram.device) - src_bin_begin = src_bin * bin_width - src_bin_end = src_bin_begin + bin_width - ( - dst_bin_begin, - dst_bin_end, - dst_bin_of_begin, - dst_bin_of_end, - left_dst_bin_end_width, - right_dst_bin_begin_width, - ) = self._get_dst_bin(src_bin_begin, src_bin_end, dst_bin_max) - - dst_bin_of_begin_center = dst_bin_begin + right_dst_bin_begin_width - dst_bin_of_end_center = dst_bin_end + left_dst_bin_end_width - - density = self.histogram / bin_width - - norm = torch.zeros(self.bins, device=self.histogram.device) - - delta_begin = src_bin_begin - dst_bin_of_begin_center - delta_end = right_dst_bin_begin_width - - norm += self._get_norm(delta_begin, delta_end, density) - - norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm( - torch.tensor(-left_dst_bin_end_width), torch.tensor(right_dst_bin_begin_width), density - ) - - delta_begin = -left_dst_bin_end_width - delta_end = src_bin_end - dst_bin_of_end_center - norm += self._get_norm(delta_begin, delta_end, density) - - return norm.sum().item() - - def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Non-linear parameter search. - - An approximation for L2 error minimization for selecting min/max. 
- By selecting new min/max, we filter out outliers in input distribution. - This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in - caffe2/quantization/server/norm_minimization.cc - """ - assert self.histogram.size()[0] == self.bins, "bins mismatch" - bin_width = (self.max_val - self.min_val) / self.bins - - # cumulative sum - total = torch.sum(self.histogram).item() - cSum = torch.cumsum(self.histogram, dim=0) - - stepsize = 1e-5 # granularity - alpha = 0.0 # lower bound - beta = 1.0 # upper bound - start_bin = 0 - end_bin = self.bins - 1 - norm_min = float("inf") - - while alpha < beta: - # Find the next step - next_alpha = alpha - next_beta = beta - stepsize - - # find the right bins between the quantile bounds - # keep the left bins at zero due to fp8 symmetry - l = 0 - r = end_bin - while r > start_bin and cSum[r] > next_beta * total: - r = r - 1 - - # decide the next move - next_start_bin = start_bin - next_end_bin = end_bin - if (l - start_bin) <= (end_bin - r): - # move the end bin - next_end_bin = r - beta = next_beta - - if next_start_bin == start_bin and next_end_bin == end_bin: - continue - - # calculate the quantization error using next_start_bin and next_end_bin - norm = self._compute_quantization_error(next_start_bin, next_end_bin) - - if norm > norm_min: - break - norm_min = norm - start_bin = next_start_bin - end_bin = next_end_bin - - new_min = self.min_val + bin_width * start_bin - new_max = self.min_val + bin_width * (end_bin + 1) - return new_min, new_max - - def _adjust_min_max( - self, combined_min: torch.Tensor, combined_max: torch.Tensor, upsample_rate: int - ) -> Tuple[torch.Tensor, torch.Tensor, int, int]: - # We ensure that: - # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins) - # This allows us to have a common grid of resolution s, where we can align - # the input histogram - # start_idx maps min_val to the histogram bin index. - - # Compute the width of histogram bins is a straightforward solution, where - # hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate) - # Underflow happens if the numerator is close to the smallest positive subnormal number of FP32 - # Therefore, we avoid such division operation. - downsample_rate = int( - torch.ceil((combined_max - combined_min) * upsample_rate / (self.max_val - self.min_val)).item() - ) - e = downsample_rate * (self.max_val - self.min_val) / upsample_rate - (combined_max - combined_min) - start_idx = int( - torch.round( - (self.min_val - combined_min) * self.bins * upsample_rate / (self.max_val - self.min_val) - ).item() - ) - combined_max = combined_max + e - combined_min = combined_min - return combined_min, combined_max, downsample_rate, start_idx - - def _combine_histograms( - self, - orig_hist: torch.Tensor, - new_hist: torch.Tensor, - upsample_rate: int, - downsample_rate: int, - start_idx: int, - Nbins: int, - ) -> torch.Tensor: - # First up-sample the histogram with new data by a factor of L - # This creates an approximate probability density that's piecewise constant - upsampled_histogram = new_hist.repeat_interleave(upsample_rate) - # Now insert the upsampled histogram into the output - # histogram, which is initialized with zeros. 
- # The offset at which the histogram is introduced is determined - # by the start index as the output histogram can cover a wider range - histogram_with_output_range = torch.zeros((Nbins * downsample_rate), device=orig_hist.device) - histogram_with_output_range[start_idx : Nbins * upsample_rate + start_idx] = upsampled_histogram - # Compute integral histogram, double precision is needed to ensure - # that there are no overflows - integral_histogram = torch.cumsum(histogram_with_output_range, 0, dtype=torch.double)[ - downsample_rate - 1 :: downsample_rate - ] - # Finally perform interpolation - shifted_integral_histogram = torch.zeros((Nbins), device=orig_hist.device) - shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1] - interpolated_histogram = (integral_histogram - shifted_integral_histogram) / upsample_rate - orig_hist = orig_hist + interpolated_histogram.to(torch.float) - return orig_hist - - def forward(self, x_orig: torch.Tensor) -> torch.Tensor: - if x_orig.numel() == 0: - return x_orig - x = x_orig.detach() - # use abs due to fp8 symmetry - x = torch.abs(x) - min_val = self.min_val - max_val = self.max_val - same_values = min_val.item() == max_val.item() - is_uninitialized = min_val == float("inf") and max_val == float("-inf") - if is_uninitialized or same_values: - min_val, max_val = torch.aminmax(x) - self.min_val.resize_(min_val.shape) - self.min_val.copy_(min_val) - self.max_val.resize_(max_val.shape) - self.max_val.copy_(max_val) - assert min_val.numel() == 1 and max_val.numel() == 1, "histogram min/max values must be scalar." - torch.histc(x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram) - else: - new_min, new_max = torch.aminmax(x) - combined_min = torch.min(new_min, min_val) - combined_max = torch.max(new_max, max_val) - # combine the existing histogram and new histogram into 1 histogram - # We do this by first upsampling the histogram to a dense grid - # and then downsampling the histogram efficiently - ( - combined_min, - combined_max, - downsample_rate, - start_idx, - ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate) - assert combined_min.numel() == 1 and combined_max.numel() == 1, "histogram min/max values must be scalar." - combined_histogram = torch.histc(x, self.bins, min=int(combined_min), max=int(combined_max)) - if combined_min == min_val and combined_max == max_val: - combined_histogram += self.histogram - else: - combined_histogram = self._combine_histograms( - combined_histogram, - self.histogram, - self.upsample_rate, - downsample_rate, - start_idx, - self.bins, - ) - - self.histogram.detach_().resize_(combined_histogram.shape) - self.histogram.copy_(combined_histogram) - self.min_val.detach_().resize_(combined_min.shape) - self.min_val.copy_(combined_min) - self.max_val.detach_().resize_(combined_max.shape) - self.max_val.copy_(combined_max) - return x_orig - - def extra_repr(self): - return f"min_val={self.min_val}, max_val={self.max_val}" diff --git a/neural_compressor/torch/algorithms/habana_fp8/save_load.py b/neural_compressor/torch/algorithms/habana_fp8/save_load.py deleted file mode 100644 index 8079a130625..00000000000 --- a/neural_compressor/torch/algorithms/habana_fp8/save_load.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint:disable=import-error - -import json -import os - -import habana_frameworks.torch.core as htcore -import torch - -from neural_compressor.common.utils import load_config_mapping, save_config_mapping -from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger - -from .fp8_quant import FP8_DTYPE, dtype_mapping -from .modules import ( # fp32; dynamic modules - Autocast, - BatchMatmul, - FP8Cast, - FP8DynamicBatchMatmul, - FP8DynamicLinear, - FP8DynamicMatmul, - Matmul, -) -from .observer import observer_mapping - - -def save(model, output_dir="./saved_results"): - if not os.path.exists(output_dir): - os.mkdir(output_dir) - qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) - qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME) - # saving process - save_config_mapping(model.qconfig, qconfig_file_path) - - import fp8_convert - - stat_dict = {} - for k, v in model.state_dict().items(): - if v.dtype in FP8_DTYPE: - v = fp8_convert.to_u8(v.to("cpu")) - stat_dict[k] = v.to("cpu") - torch.save(stat_dict, qmodel_file_path) - - logger.info("Save state_dict of quantized model to {}.".format(qmodel_file_path)) - logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path)) - - -def load(model, output_dir="./saved_results"): - from neural_compressor.torch.utils import fetch_module, set_module - - from .fp8_quant import quantization_mapping, white_list - - qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) - stat_dict = torch.load(qmodel_file_path) - import fp8_convert - - for (op_name, op_type), op_qconfig in model.qconfig.items(): - dtype = dtype_mapping[op_qconfig.w_dtype] - # only modules that have weight should use this observer - observer_cls = observer_mapping[op_qconfig.w_observer] - observer_obj = observer_cls(dtype=dtype) - choice = 1 if dtype == torch.float8_e4m3fn else 0 - if op_name + ".weight" in stat_dict: - stat_dict[op_name + ".weight"] = fp8_convert.from_u8(stat_dict[op_name + ".weight"], choice) - if dtype not in FP8_DTYPE: - continue - module = fetch_module(model, op_name) - # replace module - if op_qconfig.approach == "static": - if isinstance(module, white_list): - QModule = quantization_mapping[type(module)] - qmodule = QModule(module, dtype) - else: - if isinstance(module, torch.nn.Linear): - # need module for initialization - qmodule = FP8DynamicLinear(module, dtype) - elif isinstance(module, Matmul): - qmodule = FP8DynamicMatmul(dtype) - elif isinstance(module, BatchMatmul): - qmodule = FP8DynamicBatchMatmul(dtype) - elif isinstance(module, Autocast): - qmodule = FP8Cast(dtype=dtype) - # only modules that have weight should use this API - if hasattr(qmodule, "from_float"): - qmodule.from_float(module, observer_obj) - # replace module with qmodule - set_module(model, op_name, qmodule) - htcore.mark_step() - model.load_state_dict(stat_dict, assign=True) - model.to("hpu") - htcore.mark_step() - logger.info("Quantized model loading successful.") - return model diff --git 
a/neural_compressor/torch/algorithms/habana_fp8/scale.py b/neural_compressor/torch/algorithms/habana_fp8/scale.py deleted file mode 100644 index 1dfaee24502..00000000000 --- a/neural_compressor/torch/algorithms/habana_fp8/scale.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint:disable=import-error - -import habana_frameworks.torch.core as htcore -import torch - -scale_method_mapping = {} - - -def scale_method_registry(name): - def new_scale_method(scale_method_cls): - global scale_method_mapping - scale_method_mapping[name] = scale_method_cls - return scale_method_cls - - return new_scale_method - - -@scale_method_registry("hw") -def hardware_scale_method(scale): - scale_list = torch.tensor([16, 1, 1 / 16, 1 / 256]) - return torch.clip( - 2 ** (torch.ceil(torch.log2(scale) / 4) * 4), - torch.tensor(scale_list[-1], dtype=scale.dtype, device=scale.device), - torch.tensor(scale_list[0], dtype=scale.dtype, device=scale.device), - ) - - -@scale_method_registry("pow2") -def pow2_scale_method(scale): - return 2 ** torch.ceil(torch.log2(scale)) - - -@scale_method_registry("unit") -def unit_scale_method(scale): - return torch.tensor(1.0) - - -@scale_method_registry("self") -def self_scale_method(scale): - return scale - - -def map_gaudi_scale(scale, method): - scale_method = scale_method_mapping[method] - return scale_method(scale) diff --git a/neural_compressor/torch/algorithms/habana_fp8/tensor/__init__.py b/neural_compressor/torch/algorithms/habana_fp8/tensor/__init__.py deleted file mode 100644 index 28f108cb636..00000000000 --- a/neural_compressor/torch/algorithms/habana_fp8/tensor/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/neural_compressor/torch/algorithms/habana_fp8/tensor/convert.cpp b/neural_compressor/torch/algorithms/habana_fp8/tensor/convert.cpp deleted file mode 100644 index f22c5c82c89..00000000000 --- a/neural_compressor/torch/algorithms/habana_fp8/tensor/convert.cpp +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2024 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Temporary implementation of fp8 tensor saving and loading -// Will remove after Habana torch applies below patch: -// https://github.com/pytorch/pytorch/pull/114662 - - -#include - - -// function prototype declaration -torch::Tensor to_u8(torch::Tensor tensor); -torch::Tensor from_u8(torch::Tensor tensor, int choice=1); - - -torch::Tensor to_u8(torch::Tensor tensor) { - auto p = tensor.data_ptr(); - // RuntimeError: HPU device type not enabled. - auto options = torch::TensorOptions().device(torch::kCPU).dtype(torch::kUInt8); - auto tmp = torch::from_blob(p, tensor.sizes(), options); - // copy to avoid memory leak. - torch::Tensor tensor_uint8 = torch::empty_like(tensor, torch::kUInt8).copy_(tmp); - return tensor_uint8; -}; - - -/* -choice=1 means torch.float8_e4m3fn; -others means torch.float8_e5m2; -*/ -torch::Tensor from_u8(torch::Tensor tensor, int choice) { - auto p = tensor.data_ptr(); - torch::ScalarType dtype; - if (choice == 1) { - dtype = torch::kFloat8_e4m3fn; - } - else { - dtype = torch::kFloat8_e5m2; - } - auto options = torch::TensorOptions().device(torch::kCPU).dtype(dtype); - auto tmp = torch::from_blob(p, tensor.sizes(), options); - // copy to avoid memory leak. - torch::Tensor tensor_fp8 = torch::empty_like(tensor, dtype).copy_(tmp); - return tensor_fp8; -}; - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("to_u8", &to_u8, "Convert tensor to u8 for saving."); - m.def("from_u8", &from_u8, "Recover tensor from u8 for loading."); -}; diff --git a/neural_compressor/torch/amp/__init__.py b/neural_compressor/torch/amp/__init__.py deleted file mode 100644 index 87a0c8287d0..00000000000 --- a/neural_compressor/torch/amp/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .autocast import autocast diff --git a/neural_compressor/torch/amp/autocast.py b/neural_compressor/torch/amp/autocast.py deleted file mode 100644 index 7375b80c0f5..00000000000 --- a/neural_compressor/torch/amp/autocast.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from typing import Any, Optional - -import torch -from torch.types import _dtype - - -class autocast: - r"""Instances of :class:`autocast` serve as context managers or decorators that - allow regions of your script to run in mixed precision. - - In these regions, ops run in an op-specific dtype chosen by autocast - to improve performance while maintaining accuracy. - - When entering an autocast-enabled region, Tensors may be any type. - You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting. - - :class:`autocast` should wrap only the forward pass(es) of your network, including the loss - computation(s). Backward passes under autocast are not recommended. - Backward ops run in the same type that autocast used for corresponding forward ops. - - # Enables autocasting for the inference pass - with torch.autocast(device_type="hpu", dtype=torch.float8_e4m3fn): - output = model(input) - - :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model:: - - class AutocastModel(nn.Module): - ... - @torch.autocast(device_type="cuda") - def forward(self, input): - ... - - The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator - must be invoked in that thread. This affects :class:`torch.nn.DataParallel` and - :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process - (see :ref:`Working with Multiple GPUs`). - - Args: - device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and 'hpu'. - The type is the same as the `type` attribute of a :class:`torch.device`. - Thus, you may obtain the device type of a tensor using `Tensor.device.type`. - enabled(bool, optional): Whether autocasting should be enabled in the region. - Default: ``True`` - dtype(torch_dtype, optional): Whether to use torch.float16 or torch.bfloat16. - cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled. 
- Default: ``True`` - """ - - def __init__( - self, - device_type: str, - dtype: Optional[_dtype] = None, - enabled: bool = True, - cache_enabled: Optional[bool] = None, - ): - self.device = device_type - if dtype is not None: - self.fast_dtype = dtype - if cache_enabled is not None: - self._cache_enabled = cache_enabled - if not (device_type == "hpu" and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]): - self._autocast = torch.autocast(device_type, dtype, enabled, cache_enabled) - - def __enter__(self) -> None: - if self.device == "hpu" and self.fast_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - from neural_compressor.torch.amp.fp8.functions import replace_func - - # This function will replace F.linear and torch.matmul with the fp8 one - replace_func(self.fast_dtype) - else: - self._autocast.__enter__() - - def __exit__(self, exc_type, exc_value, traceback) -> None: - if self.device == "hpu" and self.fast_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - from neural_compressor.torch.amp.fp8.functions import recover_func - - # This function will recover F.linear and torch.matmul with the original one - recover_func() - else: - self._autocast.__exit__(exc_type, exc_value, traceback) diff --git a/neural_compressor/torch/amp/fp8/__init__.py b/neural_compressor/torch/amp/fp8/__init__.py deleted file mode 100644 index 28f108cb636..00000000000 --- a/neural_compressor/torch/amp/fp8/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/neural_compressor/torch/amp/fp8/functions.py b/neural_compressor/torch/amp/fp8/functions.py deleted file mode 100644 index f8f19a64b17..00000000000 --- a/neural_compressor/torch/amp/fp8/functions.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# pylint:disable=import-error - -import os - -import habana_frameworks.torch.core as htcore -import habana_frameworks.torch.hpex -import torch -from torch.nn import functional as F - -from neural_compressor.torch.algorithms.habana_fp8.observer import calculate_qparams -from neural_compressor.torch.utils import logger - -_F_linear = F.linear -_torch_matmul = torch.matmul -_torch_bmm = torch.bmm - - -DATA_TYPE = torch.float8_e4m3fn -USE_AMAX = bool(os.getenv("PT_USE_FP8_AMAX", False)) - - -def fp8_linear_forward(input, weight, bias=None): - out_dtype = torch.float32 - org_middle_shape = input.shape[1:-1] - input = input.view((-1, weight.shape[-1])) - # process input - if input.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: - out_dtype = input.dtype - if USE_AMAX: - input_scale = calculate_qparams(input.min(), input.max(), DATA_TYPE) - input_scale_inv = torch.reciprocal(input_scale) - else: - input_scale, input_scale_inv = None, None - input = torch.ops.hpu.cast_to_fp8_v2(input, input_scale_inv, False, False, DATA_TYPE)[0] - else: - # skip cast for input - input_scale, input_scale_inv = None, None - # process weight - if weight.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: - out_dtype = weight.dtype - if USE_AMAX: - weight_scale = calculate_qparams(weight.min(), weight.max(), DATA_TYPE) - weight_scale_inv = torch.reciprocal(weight_scale) - else: - weight_scale, weight_scale_inv = None, None - weight = torch.ops.hpu.cast_to_fp8_v2(weight, weight_scale_inv, False, False, DATA_TYPE)[0] - else: - # skip cast for weight - weight_scale, weight_scale_inv = None, None - out = torch.ops.hpu.fp8_gemm_v2( - input, - False, - weight, - True, - None, - out_dtype, - input_scale, - weight_scale, - bias, - False, - ) - out = out.view(-1, *org_middle_shape, out.shape[-1]) - return out - - -def fp8_matmul(input1, input2): - out_dtype = torch.float32 - # process input1 - if input1.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: - out_dtype = input1.dtype - if USE_AMAX: - input1_scale = calculate_qparams(input1.min(), input1.max(), DATA_TYPE) - input1_scale_inv = torch.reciprocal(input1_scale) - else: - input1_scale, input1_scale_inv = None, None - input1 = torch.ops.hpu.cast_to_fp8_v2(input1, input1_scale_inv, False, False, DATA_TYPE)[0] - else: - # skip cast for input1 - input1_scale, input1_scale_inv = None, None - # process input2 - if input2.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: - out_dtype = input2.dtype - if USE_AMAX: - input2_scale = calculate_qparams(input2.min(), input2.max(), DATA_TYPE) - input2_scale_inv = torch.reciprocal(input2_scale) - else: - input2_scale, input2_scale_inv = None, None - input2 = torch.ops.hpu.cast_to_fp8_v2(input2, input2_scale_inv, False, False, DATA_TYPE)[0] - else: - # skip cast for input2 - input2_scale, input2_scale_inv = None, None - # calculate - out = torch.ops.hpu.fp8_gemm_v2( - input1, - False, - input2, - False, - None, - out_dtype, - input1_scale, - input2_scale, - None, - False, - ) - return out - - -def replace_func(dtype): - global DATA_TYPE - DATA_TYPE = dtype - F.linear = fp8_linear_forward - torch.matmul = fp8_matmul - torch.bmm = fp8_matmul - logger.debug("F.linear and torch.matmul are replaced with the fp8 one") - - -def recover_func(): - F.linear = _F_linear - torch.matmul = _torch_matmul - torch.bmm = _torch_bmm - logger.debug("F.linear and torch.matmul are recovered") diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py index f6a015eb89f..64d7816ad81 
100644 --- a/neural_compressor/torch/quantization/__init__.py +++ b/neural_compressor/torch/quantization/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. """Intel Neural Compressor Pytorch quantization API.""" -from neural_compressor.torch.quantization.quantize import quantize, prepare, convert +from neural_compressor.torch.quantization.quantize import quantize, prepare, convert, finalize_calibration from neural_compressor.torch.quantization.config import ( RTNConfig, get_default_rtn_config, diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index c86c604152c..82edd0c610d 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -678,20 +678,22 @@ def hqq_entry( ###################### Habana FP8 Algo Entry ################################## -from neural_compressor.torch.utils import is_hpex_available - -if is_hpex_available(): - from neural_compressor.torch.algorithms.habana_fp8 import quantize, save - - @register_algo(FP8_QUANT) - def fp8_quant_entry( - model: torch.nn.Module, configs_mapping: Dict[Tuple[str], FP8Config], *args, **kwargs - ) -> torch.nn.Module: - kwargs.pop("example_inputs") - model = quantize(model, configs_mapping, *args, **kwargs) - model.qconfig = configs_mapping - model.save = MethodType(save, model) - return model +@register_algo(FP8_QUANT) +@torch.no_grad() +def fp8_entry( + model: torch.nn.Module, + configs_mapping: Dict[Tuple[str], FP8Config], + mode: Mode = Mode.QUANTIZE, + *args, + **kwargs, +) -> torch.nn.Module: + """The main entry to apply fp8 quantization.""" + from neural_compressor.torch.algorithms.fp8_quant import FP8Quantizer + + quantizer = get_quantizer(model, quantizer_cls=FP8Quantizer, quant_config=configs_mapping) + model = quantizer.execute(model, mode=mode) + postprocess_model(model, mode, quantizer) + return model ###################### MX Quant Algo Entry ################################## diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 29f944b93e3..ecc18848e52 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -18,6 +18,8 @@ """Intel Neural Compressor Pytorch quantization config API.""" +import json +import importlib from collections import OrderedDict from typing import Callable, Dict, List, NamedTuple, Optional from typing import OrderedDict as OrderedDictType @@ -1606,81 +1608,142 @@ def get_default_hqq_config() -> HQQConfig: return HQQConfig() -######################## FP8 Config ############################### +######################## FP8 Quant Config ############################### +# refer to habana_quantization_toolkit/_core/common.py +FP8_WHITE_LIST = ( + "Matmul", "Linear", "FalconLinear", "KVCache", "Conv2d", + "LoRACompatibleLinear", "LoRACompatibleConv", "Softmax", "ModuleFusedSDPA") +if importlib.util.find_spec("deepspeed"): + FP8_WHITE_LIST.append( + "LinearLayer", "LinearAllreduce","ScopedLinearAllReduce", "LmHeadLinearAllreduce") + @register_config(framework_name=FRAMEWORK_NAME, algo_name=FP8_QUANT) class FP8Config(TorchBaseConfig): """Config class for FP8 quantization.""" name = FP8_QUANT - supported_configs: List[OperatorConfig] = [] + + # tunable params params_list = [ - "w_dtype", - "w_observer", - "act_dtype", - "act_observer", - "approach", - "device", + "fp8_config", + "scale_method", + "observer", + "measure_exclude", ] def 
__init__( self, - w_dtype: str = "fp8_e4m3", - w_observer: Union[str, List[str]] = "minmax_per_channel", - act_dtype: str = "fp8_e4m3", - act_observer: Union[str, List[str]] = "minmax", - approach: Union[str, List[str]] = "static", - device: Union[str, List[str]] = "hpu", - white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, + dump_stats_path: str = "./hqt_output/measure", + fp8_config: str = "E4M3", + hp_dtype: torch.dtype = torch.bfloat16, + blocklist: dict = {'names': [], 'types': ()}, + allowlist: dict = {'names': [], 'types': FP8_WHITE_LIST}, + mode: str = "AUTO", + scale_method: str = "maxabs_hw", + scale_params: dict = {}, + observer: str = "maxabs", + mod_dict: dict = {}, + measure_exclude: str = "OUTPUT", + **kwargs, ): - """Init FP8 config. + """Init FP8 config.""" + super().__init__() + self.dump_stats_path =dump_stats_path + self.fp8_config = fp8_config + self.hp_dtype = hp_dtype + self.blocklist = blocklist + self.allowlist = allowlist + self.mode = mode + self.scale_method = scale_method + self.scale_params = scale_params + self.observer = observer + self.mod_dict = mod_dict + self._json_file = None + + @property + def measure(self): + return self.mode == "MEASURE" + + @property + def quantize(self): + return self.mode == "QUANTIZE" + + @property + def json_file(self): + if self._json_file is None: + import tempfile + from pathlib import Path + + json_file_tmp = tempfile.NamedTemporaryFile(suffix=".json") + self.to_json_file(json_file_tmp.name) + self.json_file(json_file_tmp.name) + return self._json_file + + @json_file.setter + def json_file(self, json_file): + self._json_file = json_file - Args: - """ - super().__init__(white_list=white_list) - self.w_dtype = w_dtype - self.w_observer = w_observer - self.act_dtype = act_dtype - self.act_observer = act_observer - self.approach = approach - self.device = device - self._post_init() + @classmethod + def from_json_file(cls, filename): + with open(filename, "r", encoding="utf-8") as file: + config_dict = json.load(file) + config = cls.from_dict(config_dict) + config.json_file = filename + return config @classmethod - def register_supported_configs(cls) -> List[OperatorConfig]: + def get_config_set_for_tuning(cls) -> Union[None, "FP8Config", List["FP8Config"]]: + # just a simple example here + # usually write parameter combinations that are more suitable to tune based on experience. 
+ return FP8Config( + fp8_config=["E4M3", "E5M2"], + scale_method=["without_scale", "maxabs_hw"], + measure_exclude=["NONE", "OUTPUT"]) + + @classmethod + def register_supported_configs(cls): + """Add all supported configs.""" supported_configs = [] - fp8_config = FP8Config( - w_dtype=["fp8_e5m2", "fp8_e4m3"], - w_observer=["minmax", "minmax_per_channel"], - act_dtype=["fp8_e5m2", "fp8_e4m3"], - act_observer=["minmax", "kl"], - approach=["static", "dynamic"], - device=["hpu"], + linear_rtn_config = FP8Config( + mode=["AUTO", "MEASURE", "QUANTIZE"], + fp8_config=["E4M3", "E5M2"], + scale_method=["without_scale", "unit_scale", "max", "maxabs_hw", + "maxabs_pow2", "maxabs_hw_opt_weight", "maxabs_pow2_opt_weight", + "smoothquant_weights_output_channel_maxabs_pow2", + "weaksmoothquant_weights_output_channel_maxabs_pow2", + "act_maxabs_hw_weights_pcs_maxabs_pow2", + "act_maxabs_hw_weights_pcs_opt_pow2", + "act_maxabs_pow2_weights_pcs_maxabs_pow2", + "act_maxabs_pow2_weights_pcs_opt_pow2", + "smoothquant_opt"], + observer=["shape", "maxabs", "maxabs_per_channel", "save"], + measure_exclude=["NONE", "OUTPUT", "INPUT", "ALL"], ) - if is_hpex_available(): - from neural_compressor.torch.algorithms.habana_fp8 import white_list - - operators = white_list - else: - operators = () - supported_configs.append(OperatorConfig(config=fp8_config, operators=operators)) + operators = list(FP8_WHITE_LIST) + supported_configs.append(OperatorConfig(config=linear_rtn_config, operators=operators)) cls.supported_configs = supported_configs @staticmethod def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: - from neural_compressor.torch.algorithms.habana_fp8 import white_list - filter_result = [] for op_name, module in model.named_modules(): - if isinstance(module, white_list): - pair = (op_name, type(module).__name__) + if module.__class__.__name__ in FP8_WHITE_LIST or \ + module.__class__.__name__.split("Patched")[-1] in FP8_WHITE_LIST: + pair = (op_name, module.__class__.__name__) filter_result.append(pair) logger.debug(f"Get model info: {filter_result}") return filter_result - @classmethod - def get_config_set_for_tuning(cls) -> Union[None, "FP8Config", List["FP8Config"]]: - # TODO fwk owner needs to update it. 
- return FP8Config(act_observer=["minmax", "kl"]) + def to_config_mapping( + self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None + ): + config_mapping = OrderedDict() + if config_list is None: + config_list = [self] + for config in config_list: + for op_name, op_type in model_info: + config_mapping[(op_name, op_type)] = self + return config_mapping def get_default_fp8_config() -> FP8Config: diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py index 85e73d47078..08e8d7c889d 100644 --- a/neural_compressor/torch/quantization/quantize.py +++ b/neural_compressor/torch/quantization/quantize.py @@ -20,7 +20,7 @@ from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry from neural_compressor.common.utils import Mode, call_counter, log_process -from neural_compressor.torch.quantization.config import SmoothQuantConfig, StaticQuantConfig +from neural_compressor.torch.quantization.config import SmoothQuantConfig, StaticQuantConfig, FP8Config from neural_compressor.torch.utils import is_ipex_available, logger from neural_compressor.torch.utils.utility import WHITE_MODULE_LIST, algos_mapping, get_model_info @@ -62,8 +62,8 @@ def quantize( assert isinstance( quant_config, BaseConfig ), f"Please pass a dict or config instance as the quantization configuration, but got {type(quant_config)}." - logger.info("Quantize model with config:") - logger.info(quant_config.to_dict()) + logger.debug("Quantize model with config:") + logger.debug(quant_config.to_dict()) # select quantization algo according to config if is_ipex_available and ( @@ -132,8 +132,8 @@ def prepare( assert isinstance( quant_config, BaseConfig ), f"Please pass a dict or config instance as the quantization configuration, but got {type(quant_config)}." - logger.info("Prepare model with config:") - logger.info(quant_config.to_dict()) + logger.debug("Prepare model with config:") + logger.debug(quant_config.to_dict()) # select quantization algo according to config if is_ipex_available and ( @@ -179,8 +179,9 @@ def convert( """ q_model = model if inplace else copy.deepcopy(model) - # TODO: Optimize the check for prepared flag after adding HQT FP8 Quant - assert getattr(model, "is_prepared", False), "Please run prepare function before convert." + assert ( + getattr(model, "is_prepared", False) or quant_config is not None + ), "Please pass quant_config to convert function." if getattr(model, "is_prepared", False): if quant_config is None: @@ -195,8 +196,8 @@ def convert( assert isinstance( quant_config, BaseConfig ), f"Please pass a dict or config instance as the quantization configuration, but got {type(quant_config)}." 
- logger.info("Convert model with config:") - logger.info(quant_config.to_dict()) + logger.debug("Convert model with config:") + logger.debug(quant_config.to_dict()) # select quantization algo according to config if is_ipex_available and ( @@ -220,3 +221,12 @@ def convert( ) setattr(q_model, "is_quantized", True) return q_model + + +def finalize_calibration(model): + if hasattr(model, "quant_config") and isinstance(model.quant_config, FP8Config): # FP8 + from neural_compressor.torch.algorithms.fp8_quant import save_calib_result + + save_calib_result(model) + else: + raise NotImplementedError("`finalize_calibration` only supports FP8 measurement now.") diff --git a/setup.py b/setup.py index a2392358572..ebabaa97b78 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,8 @@ def get_build_version(): return __version__ try: result = subprocess.run(["git", "describe", "--tags"], capture_output=True, text=True, check=True) - _, distance, commit = result.stdout.strip().split("-") + distance = result.stdout.strip().split("-")[-2] + commit = result.stdout.strip().split("-")[-1] return f"{__version__}.dev{distance}+{commit}" except subprocess.CalledProcessError: return __version__ diff --git a/test/3x/torch/amp/test_fp8_amp.py b/test/3x/torch/amp/test_fp8_amp.py deleted file mode 100644 index a5212467723..00000000000 --- a/test/3x/torch/amp/test_fp8_amp.py +++ /dev/null @@ -1,75 +0,0 @@ -import copy -import os -import shutil -import unittest - -import torch - -from neural_compressor.torch.amp import autocast -from neural_compressor.torch.utils import is_hpex_available - -# if not is_hpex_available(): -# exit() - - -class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc1 = torch.nn.Linear(10, 5) - self.fc2 = torch.nn.Linear(5, 10) - - def forward(self, inp): - x1 = self.fc1(inp) - x2 = self.fc2(x1) - x3 = torch.matmul(inp.T, x2) - x3 = x3.unsqueeze(0) - x3 = torch.bmm(x3, x3) - return x3 - - -@unittest.skipIf(not is_hpex_available(), "HPEX is required for HPU inference") -class TestPytorchFP8Adaptor(unittest.TestCase): - @classmethod - def setUpClass(self): - self.model = M().to("hpu") - self.inp = torch.randn(1, 10).to("hpu") - - @classmethod - def tearDownClass(self): - shutil.rmtree("./saved", ignore_errors=True) - shutil.rmtree("./.graph_dumps", ignore_errors=True) - shutil.rmtree("runs", ignore_errors=True) - - def test_autocast(self): - m = copy.deepcopy(self.model) - inp = self.inp - fp32_out = m(inp) - with autocast("hpu", dtype=torch.bfloat16) and torch.no_grad(): - bf16_out = m(inp) - print("BF16 MSE:", (bf16_out - fp32_out).pow(2).sum()) - - with autocast("hpu", dtype=torch.float8_e5m2) and torch.no_grad(): - e5m2_out = m(inp) - print("FP8_E5M2 MSE:", (e5m2_out - fp32_out).pow(2).sum()) - - with autocast("hpu", dtype=torch.float8_e4m3fn) and torch.no_grad(): - e4m3_out = m(inp) - print("FP8_E4M3 MSE:", (e4m3_out - fp32_out).pow(2).sum()) - - def test_autocast_use_amax(self): - os.environ["PT_USE_FP8_AMAX"] = str(1) - m = copy.deepcopy(self.model) - inp = self.inp - fp32_out = m(inp) - with autocast("hpu", dtype=torch.float8_e5m2) and torch.no_grad(): - e5m2_out = m(inp) - print("FP8_E5M2 using amax MSE:", (e5m2_out - fp32_out).pow(2).sum()) - - with autocast("hpu", dtype=torch.float8_e4m3fn) and torch.no_grad(): - e4m3_out = m(inp) - print("FP8_E4M3 using amax MSE:", (e4m3_out - fp32_out).pow(2).sum()) - os.environ.pop("PT_USE_FP8_AMAX", None) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/3x/torch/quantization/habana_fp8/test_fp8.py 
b/test/3x/torch/quantization/habana_fp8/test_fp8.py deleted file mode 100644 index 8fafc302f65..00000000000 --- a/test/3x/torch/quantization/habana_fp8/test_fp8.py +++ /dev/null @@ -1,189 +0,0 @@ -import copy -import shutil - -import pytest -import torch - -from neural_compressor.torch.utils import is_hpex_available - -if is_hpex_available(): - from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic - from neural_compressor.torch.algorithms.habana_fp8.modules import ( - BatchMatmul, - FP8BatchMatmul, - FP8DynamicBatchMatmul, - FP8DynamicLinear, - FP8DynamicMatmul, - FP8Linear, - FP8Matmul, - Matmul, - ) - from neural_compressor.torch.quantization import ( - FP8Config, - TuningConfig, - autotune, - get_default_fp8_config, - get_default_fp8_config_set, - quantize, - ) - - torch.set_grad_enabled(False) - - -class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc1 = torch.nn.Linear(10, 5) - self.fc2 = torch.nn.Linear(5, 10) - self.mm = Matmul() - self.bmm = BatchMatmul() - - def forward(self, inp): - x1 = self.fc1(inp) - x2 = self.fc2(x1) - x3 = self.mm(inp.T, x2) - x3 = x3.unsqueeze(0) - x4 = self.mm(inp.T, x2) - x4 = x4.unsqueeze(0) + 1 ## SW-178838 - x5 = self.bmm(x3, x4) - x6 = self.bmm(x3, x4) - out = x5 + x6 - return out - - -@pytest.mark.skipif(not is_hpex_available(), reason="no hpex in environment here.") -class TestPytorchFP8Adaptor: - def setup_class(self): - self.model = M().to("hpu") - self.inp = torch.randn(1, 10).to("hpu") - self.fp32_out = self.model(self.inp) - - def teardown_class(self): - shutil.rmtree("./saved", ignore_errors=True) - shutil.rmtree("./.graph_dumps", ignore_errors=True) - shutil.rmtree("runs", ignore_errors=True) - - def test_dynamic_accu(self): - m = copy.deepcopy(self.model) - inp = self.inp - fp32_out = m(inp) - m = quantize_dynamic(m, dtype="fp8_e5m2", inplace=True) - assert isinstance(m.fc1, FP8DynamicLinear), "Unexpected result. Please double check." - assert isinstance(m.mm, FP8DynamicMatmul), "Unexpected result. Please double check." - assert isinstance(m.bmm, FP8DynamicBatchMatmul), "Unexpected result. Please double check." - print(m) - fp8_out = m(inp) - print("Dynamic quantization FP8_E5M2 MSE:", (fp32_out - fp8_out).pow(2).sum()) - - m = copy.deepcopy(self.model) - inp = self.inp - fp32_out = m(inp) - m = quantize_dynamic(m, dtype="fp8_e4m3", inplace=True) - assert isinstance(m.fc1, FP8DynamicLinear), "Unexpected result. Please double check." - assert isinstance(m.mm, FP8DynamicMatmul), "Unexpected result. Please double check." - assert isinstance(m.bmm, FP8DynamicBatchMatmul), "Unexpected result. Please double check." - print(m) - fp8_out = m(inp) - print("Dynamic quantization FP8_E4M3 MSE:", (fp32_out - fp8_out).pow(2).sum()) - - m = copy.deepcopy(self.model) - inp = self.inp - fp32_out = m(inp) - qconfig = FP8Config(approach="dynamic") - m = quantize(m, qconfig, inplace=True) - assert isinstance(m.fc1, FP8DynamicLinear), "Unexpected result. Please double check." - assert isinstance(m.mm, FP8DynamicMatmul), "Unexpected result. Please double check." - assert isinstance(m.bmm, FP8DynamicBatchMatmul), "Unexpected result. Please double check." 
- print(m) - fp8_out = m(inp) - print("Dynamic quantization FP8_E4M3 MSE:", (fp32_out - fp8_out).pow(2).sum()) - - @pytest.mark.parametrize("dtype", ["fp8_e5m2", "fp8_e4m3"]) - @pytest.mark.parametrize("w_observer", ["minmax", "minmax_per_channel"]) - @pytest.mark.parametrize("act_observer", ["minmax", "kl"]) - def test_static_accu(self, dtype, w_observer, act_observer): - m = copy.deepcopy(self.model) - inp = self.inp - qconfig = FP8Config( - w_dtype=dtype, w_observer=w_observer, act_dtype=dtype, act_observer=act_observer, approach="static" - ) - - def calib_func(model): - model(inp) - - m = quantize(m, qconfig, run_fn=calib_func, inplace=True) - assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check." - assert isinstance(m.mm, FP8Matmul), "Unexpected result. Please double check." - assert isinstance(m.bmm, FP8BatchMatmul), "Unexpected result. Please double check." - fp8_out = m(inp) - print("Static quantization config:", dtype, w_observer, act_observer) - print("Static quantization MSE:", (self.fp32_out - fp8_out).pow(2).sum()) - - def test_convert(self): - # Temporary implementation of fp8 tensor saving and loading - # Will remove after Habana torch applies below patch: - # https://github.com/pytorch/pytorch/pull/114662 - # e4m3 - fp8_inp = torch.ops.hpu.cast_to_fp8_v2(self.inp, 500, dtype=torch.float8_e4m3fn)[0].to("cpu") - import fp8_convert - - int8_inp = fp8_convert.to_u8(fp8_inp) - torch.save(int8_inp, "tmp.pt") - saved_int8_inp = torch.load("tmp.pt") - recovered_inp = fp8_convert.from_u8(saved_int8_inp, 1) - assert (fp8_inp == recovered_inp).all(), "Unexpected result. Please double check." - # e5m2 - fp8_inp = torch.ops.hpu.cast_to_fp8_v2(self.inp, 500, dtype=torch.float8_e5m2)[0].to("cpu") - int8_inp = fp8_convert.to_u8(fp8_inp) - recovered_inp = fp8_convert.from_u8(int8_inp, 0) - assert (fp8_inp == recovered_inp).all(), "Unexpected result. Please double check." - - def test_save_load(self): - m = copy.deepcopy(self.model) - inp = self.inp - qconfig = get_default_fp8_config() - - def calib_func(model): - model(inp) - - m = quantize(m, qconfig, run_fn=calib_func, inplace=True) - fp8_out = m(inp) - m.save("saved_results") - - from neural_compressor.torch.quantization import load - - m = copy.deepcopy(self.model) - m = load("saved_results", m) - recovered_out = m(inp) - assert (recovered_out == fp8_out).all(), "Unexpected result. Please double check." - assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check." - assert isinstance(m.mm, FP8Matmul), "Unexpected result. Please double check." - assert isinstance(m.bmm, FP8BatchMatmul), "Unexpected result. Please double check." - - def test_autotune(self): - m = copy.deepcopy(self.model) - inp = self.inp - fp32_out = m(inp) - - def calib_func(model): - model(inp) - - accu_list = [1.0, 0.9, 0.99] - - def eval_func(model): - nonlocal accu_list - return accu_list.pop() - - tune_config = TuningConfig( - config_set=get_default_fp8_config_set(), - tolerable_loss=0.01, - ) - best_model = autotune( - model=m, - tune_config=tune_config, - run_fn=calib_func, - eval_fns=eval_func, - ) - assert isinstance(best_model.fc1, FP8Linear), "Unexpected result. Please double check." - assert isinstance(best_model.mm, FP8Matmul), "Unexpected result. Please double check." - assert isinstance(best_model.bmm, FP8BatchMatmul), "Unexpected result. Please double check." 
From ca1444bd80c5a0b2ee16185cf6962e4b6a3f8c93 Mon Sep 17 00:00:00 2001 From: Uri Livne Date: Wed, 19 Jun 2024 15:05:12 +0300 Subject: [PATCH 04/51] [SW-189361] Fix white list extend Change-Id: Ic2021c248798fce37710d28014a6d59259c868a3 --- neural_compressor/torch/quantization/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index ecc18848e52..d8aefe1f3ff 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -1610,12 +1610,12 @@ def get_default_hqq_config() -> HQQConfig: ######################## FP8 Quant Config ############################### # refer to habana_quantization_toolkit/_core/common.py -FP8_WHITE_LIST = ( +FP8_WHITE_LIST = [ "Matmul", "Linear", "FalconLinear", "KVCache", "Conv2d", - "LoRACompatibleLinear", "LoRACompatibleConv", "Softmax", "ModuleFusedSDPA") + "LoRACompatibleLinear", "LoRACompatibleConv", "Softmax", "ModuleFusedSDPA"] if importlib.util.find_spec("deepspeed"): - FP8_WHITE_LIST.append( - "LinearLayer", "LinearAllreduce","ScopedLinearAllReduce", "LmHeadLinearAllreduce") + FP8_WHITE_LIST.extend( + ["LinearLayer", "LinearAllreduce","ScopedLinearAllReduce", "LmHeadLinearAllreduce"]) @register_config(framework_name=FRAMEWORK_NAME, algo_name=FP8_QUANT) class FP8Config(TorchBaseConfig): From dfec104431916e5d0f6e6de0cf411c45d3f6e7b7 Mon Sep 17 00:00:00 2001 From: Uri Livne Date: Wed, 3 Jul 2024 17:22:02 +0300 Subject: [PATCH 05/51] [SW-191317] Raise exception according to hqt config object Change-Id: I06ba8fa912c811c88912987c11e5c12ef328348a --- neural_compressor/torch/algorithms/fp8_quant/common.py | 8 +++++++- neural_compressor/torch/quantization/quantize.py | 4 ---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/neural_compressor/torch/algorithms/fp8_quant/common.py b/neural_compressor/torch/algorithms/fp8_quant/common.py index b038a367a78..4a603c677ac 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/common.py +++ b/neural_compressor/torch/algorithms/fp8_quant/common.py @@ -24,7 +24,13 @@ def save_calib_result(model): import habana_quantization_toolkit as hqt - hqt.finish_measurements(model) + if (hasattr(model, "__hqt_config__") and + isinstance(model.__hqt_config__, hqt._quant_common.quant_config.Fp8cfg)): + # TODO SW-184714 modify hqt notation to inc notation once code is ported + hqt.finish_measurements(model) + else: + raise NotImplementedError("Saving calibration results currently supported only in HPU.") + def update_mode(config_path, measure_step=False, quant_step=False): diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py index 08e8d7c889d..5c161e5bb8b 100644 --- a/neural_compressor/torch/quantization/quantize.py +++ b/neural_compressor/torch/quantization/quantize.py @@ -224,9 +224,5 @@ def convert( def finalize_calibration(model): - if hasattr(model, "quant_config") and isinstance(model.quant_config, FP8Config): # FP8 from neural_compressor.torch.algorithms.fp8_quant import save_calib_result - save_calib_result(model) - else: - raise NotImplementedError("`finalize_calibration` only supports FP8 measurement now.") From 216d94b278cbb62b425fadc37161ecacd5f09449 Mon Sep 17 00:00:00 2001 From: Uri Livne Date: Sat, 6 Jul 2024 20:06:08 +0300 Subject: [PATCH 06/51] [SW-184714] Port HQT code into INC HQT lib content was copied as is under fp8_quant Tests were copied to 3.x torch location 
Change-Id: Iec6e1fa7ac4bf1df1c95b429524c40e32bc13ac9 --- .../torch/algorithms/fp8_quant/__init__.py | 1 + .../algorithms/fp8_quant/_core/__init__.py | 0 .../algorithms/fp8_quant/_core/common.py | 255 ++++++ .../algorithms/fp8_quant/_core/fp_utils.py | 172 ++++ .../algorithms/fp8_quant/_core/measure.py | 419 +++++++++ .../fp8_quant/_core/quant_dequant.py | 55 ++ .../algorithms/fp8_quant/_core/quantize.py | 96 +++ .../torch/algorithms/fp8_quant/_core/scale.py | 438 ++++++++++ .../fp8_quant/_core/scale_methods/__init__.py | 3 + .../fp8_quant/_core/scale_methods/max_abs.py | 397 +++++++++ .../_core/scale_methods/smooth_quant.py | 118 +++ .../_core/scale_methods/unit_scale.py | 52 ++ .../torch/algorithms/fp8_quant/_core/utils.py | 49 ++ .../fp8_quant/_quant_common/__init__.py | 0 .../fp8_quant/_quant_common/helper_modules.py | 812 ++++++++++++++++++ .../fp8_quant/_quant_common/quant_config.py | 250 ++++++ .../torch/algorithms/fp8_quant/common.py | 8 +- .../custom_config/custom_example.json | 5 + .../custom_config/llama_measure.json | 14 + .../fp8_quant/custom_config/llama_quant.json | 17 + .../custom_config/measure_config.json | 12 + .../fp8_quant/custom_config/quant_config.json | 13 + .../torch/algorithms/fp8_quant/fp8_quant.py | 7 +- .../fp8_quant/prepare_quant/__init__.py | 0 .../fp8_quant/prepare_quant/prepare_model.py | 36 + .../algorithms/fp8_quant/scripts/__init__.py | 0 .../scripts/regression_detection/__init__.py | 0 .../regression_detection/golden_metrics.json | 74 ++ .../regression_detection.py | 117 +++ .../algorithms/fp8_quant/utils/__init__.py | 0 .../algorithms/fp8_quant/utils/logger.py | 240 ++++++ .../3x/torch/algorithms/fp8_quant/__init__.py | 6 + .../3x/torch/algorithms/fp8_quant/conftest.py | 12 + .../torch/algorithms/fp8_quant/fp8_tests.py | 174 ++++ test/3x/torch/algorithms/fp8_quant/pytest.ini | 3 + .../fp8_quant/test_jsons/test_hw_quant.json | 16 + ...st_hw_quant_ignored_unmeasured_models.json | 17 + .../fp8_quant/test_jsons/test_measure.json | 13 + .../fp8_quant/test_jsons/test_pow2_quant.json | 16 + .../fp8_quant/test_jsons/test_unit_quant.json | 16 + test/3x/torch/algorithms/fp8_quant/tester.py | 218 +++++ .../fp8_quant/unit_tests/__init__.py | 6 + .../fp8_quant/unit_tests/test_deepspeed.py | 86 ++ .../test_functions/test_config_json.py | 29 + .../test_functions/test_matmul_fp8.py | 71 ++ .../unit_tests/test_layers/test_conv2d.py | 40 + .../unit_tests/test_layers/test_linear.py | 33 + .../unit_tests/test_layers/test_matmul.py | 56 ++ 48 files changed, 4465 insertions(+), 7 deletions(-) create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_core/__init__.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_core/common.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_core/measure.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_core/scale.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/__init__.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/max_abs.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/smooth_quant.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/unit_scale.py create mode 100644 
neural_compressor/torch/algorithms/fp8_quant/_core/utils.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_quant_common/__init__.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/custom_config/custom_example.json create mode 100644 neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_measure.json create mode 100644 neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_quant.json create mode 100755 neural_compressor/torch/algorithms/fp8_quant/custom_config/measure_config.json create mode 100755 neural_compressor/torch/algorithms/fp8_quant/custom_config/quant_config.json create mode 100644 neural_compressor/torch/algorithms/fp8_quant/prepare_quant/__init__.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/prepare_quant/prepare_model.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/scripts/__init__.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/scripts/regression_detection/__init__.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/scripts/regression_detection/golden_metrics.json create mode 100644 neural_compressor/torch/algorithms/fp8_quant/scripts/regression_detection/regression_detection.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/utils/__init__.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/utils/logger.py create mode 100644 test/3x/torch/algorithms/fp8_quant/__init__.py create mode 100644 test/3x/torch/algorithms/fp8_quant/conftest.py create mode 100644 test/3x/torch/algorithms/fp8_quant/fp8_tests.py create mode 100644 test/3x/torch/algorithms/fp8_quant/pytest.ini create mode 100644 test/3x/torch/algorithms/fp8_quant/test_jsons/test_hw_quant.json create mode 100644 test/3x/torch/algorithms/fp8_quant/test_jsons/test_hw_quant_ignored_unmeasured_models.json create mode 100644 test/3x/torch/algorithms/fp8_quant/test_jsons/test_measure.json create mode 100644 test/3x/torch/algorithms/fp8_quant/test_jsons/test_pow2_quant.json create mode 100644 test/3x/torch/algorithms/fp8_quant/test_jsons/test_unit_quant.json create mode 100644 test/3x/torch/algorithms/fp8_quant/tester.py create mode 100644 test/3x/torch/algorithms/fp8_quant/unit_tests/__init__.py create mode 100644 test/3x/torch/algorithms/fp8_quant/unit_tests/test_deepspeed.py create mode 100644 test/3x/torch/algorithms/fp8_quant/unit_tests/test_functions/test_config_json.py create mode 100644 test/3x/torch/algorithms/fp8_quant/unit_tests/test_functions/test_matmul_fp8.py create mode 100644 test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_conv2d.py create mode 100644 test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_linear.py create mode 100644 test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_matmul.py diff --git a/neural_compressor/torch/algorithms/fp8_quant/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/__init__.py index d16760b5e81..bea97db811c 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/__init__.py +++ b/neural_compressor/torch/algorithms/fp8_quant/__init__.py @@ -18,4 +18,5 @@ restore_patched_module, with_patched_module, ) +from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import finish_measurements, prep_model from neural_compressor.torch.algorithms.fp8_quant.fp8_quant import 
FP8Quantizer diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/_core/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/common.py b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py new file mode 100644 index 00000000000..c155146dcc6 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py @@ -0,0 +1,255 @@ +import os +import torch +import json +import numpy as np +import functools +import importlib.util + +from .._quant_common.helper_modules import * +from .._quant_common.quant_config import get_hqt_config +from ..utils.logger import logger + +deepspeed_exists = False +if importlib.util.find_spec("deepspeed"): # check if deepspeed is installed + deepspeed_exists = True + +UNMEASURED_MODELS = "UnmeasuredModels" + + +class ModuleInfo: + def __init__(self, type, patched_module): + self.type = type + self.patched_module = patched_module + + +class ModuleConfig: + def __init__(self, inputs=(None,), outputs=(None,), params=None): + self.inputs = inputs + self.outputs = outputs + self.params = params if params is not None else {} + + +class ModuleExtraConfig: + def __init__(self, inputs=(None,), outputs=(None,), params=None, scale=None, config_params=None): + self.inputs = inputs + self.outputs = outputs + self.params = params if params is not None else {} + self.scale = scale + self.config_params = config_params if config_params is not None else {} + + +class ModuleType: + def __init__(self, num_inputs, param_names, num_outputs, required_output): + self.num_inputs = num_inputs + self.param_names = param_names + self.num_outputs = num_outputs + self.required_output = required_output + + +mod_types = { + "linear": ModuleType(1, ["weight"], 1, False), + "matmul": ModuleType(2, [], 1, False), + "kv_cache": ModuleType(1, [], 1, False), + "softmax": ModuleType(1, [], 1, True), + "fused_sdpa": ModuleType(3, [], 2, True), +} +descale_fcn = lambda x, scale: torch.mul(x, scale) +scale_fcn = lambda x, scale: torch.div(x, scale) +mat_scale_fcn = lambda x, scale_col, scale_row: torch.div(torch.div(x, scale_col), scale_row) +cast_fcn = lambda x, dtype: x.to(dtype=dtype) +cast_to_fp8_fcn = lambda x, dtype, scale_inv=None: torch.ops.hpu.cast_to_fp8_v2(x, scale_inv, False, False, dtype)[0] +cast_from_fp8_fcn = lambda x, dtype, scale=None: torch.ops.hpu.cast_from_fp8(x, scale, dtype) + + +class ShapeList: + data = None + + +def rec_fn(x, fn): + if isinstance(x, dict): + return {k: rec_fn(x[k], fn) for k in x} + elif isinstance(x, list): + return [rec_fn(k, fn) for k in x] + elif isinstance(x, tuple): + return tuple([rec_fn(k, fn) for k in x]) + else: + return fn(x) + + +def np_to_pt(x): + return rec_fn(x, lambda x: torch.tensor(x) if isinstance(x, np.ndarray) else x) + + +def pt_to_np(x): + return rec_fn( + x, + lambda x: (x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x), + ) + + +def np_to_list(x): + return rec_fn(x, lambda x: x.tolist() if isinstance(x, np.ndarray) else x) + + +def list_to_np(x): + return rec_fn(x, lambda x: np.array(x) if isinstance(x, list) else x) + + +def save_json(d, fname): + with open(fname, "w") as f: + json.dump(d, f, indent=4) + + +def load_json(fname): + with open(fname, "r") as f: + d = json.load(f) + return d + + +def save_npz(d, fname): + np.savez(fname, d) + + +def load_npz(fname): + d = np.load(fname, allow_pickle=True) + return d["arr_0"].item() + + +def save_file(model, 
d, source_format, fname, mode): + config = get_hqt_config(model) + logger.debug("Saving %s file: %s", mode, fname) + ext = os.path.splitext(fname)[1] + target_format = file_functions[ext][0] + dc = rec_fn(d, format_functions[(source_format, target_format)]) + df = { + "GlobalRank": config.cfg["global_rank"], + "LocalRank": config.cfg["local_rank"], + "Mode": mode, + "Nodes": dc, + } + try: + file_functions[ext][1](df, fname) + except: + pass + + +# convert module config data to other format +def module_convert(m, fcn): + mt = ModuleConfig( + tuple([fcn(x) for x in m.inputs]), + tuple([fcn(m.outputs)],) if type(m.outputs) == np.ndarray else tuple([fcn(y) for y in m.outputs]), + {k: fcn(m.params[k]) for k in m.params}, + ) + return mt + + +def fix_fields(d): + if "input" in d: + d["inputs"] = d.pop("input") + if "output" in d: + d["outputs"] = d.pop("output") + return d + + +def load_file(fname, target_format, fail_on_file_not_exist): + logger.debug("Loading file: %s", fname) + ext = os.path.splitext(fname)[1] + source_format = file_functions[ext][0] + d = {} + if os.path.isfile(fname): + d = file_functions[ext][2](fname) + elif fail_on_file_not_exist: + raise FileNotFoundError(f"Failed to load file {fname}") + if "Nodes" in d: + dc = {k: ModuleConfig(**fix_fields(d["Nodes"][k])) for k in d["Nodes"]} + dc = {k: module_convert(dc[k], format_functions[(source_format, target_format)]) for k in dc} + else: + dc = {} + return dc + + +def save_scales(model, d, source_format, fname): + dc = {k: d[k].__dict__ for k in d} + save_file(model, dc, source_format, fname, "Scale") + + +def load_scales(fname, target_format): + logger.debug("Loading scales file %s", fname) + d = load_file(fname, target_format, False) + return d + + +def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype): + scales_temp = {k: scales_obj[k].__dict__ for k in scales_obj} + scales_temp = format_functions_rec((scales_file_format, torch.Tensor))(scales_temp) + scales_temp = rec_fn(scales_temp, lambda x: x.to(dtype=hp_dtype, device="hpu")) + scales = {k: ModuleConfig(**scales_temp[k]) for k in scales_temp} + return scales + + +file_functions = { + ".json": (list, save_json, load_json), + ".npz": (np.ndarray, save_npz, load_npz), +} + +format_functions = { + (torch.Tensor, torch.Tensor): lambda x: x, + (np.ndarray, np.ndarray): lambda x: x, + (list, list): lambda x: x, + (torch.Tensor, np.ndarray): lambda x: x.detach().cpu().float().numpy(), + (torch.Tensor, list): lambda x: x.detach().cpu().float().numpy().tolist(), + (np.ndarray, torch.Tensor): torch.tensor, + (np.ndarray, list): lambda x: x.tolist(), + (list, torch.Tensor): torch.tensor, + (list, np.ndarray): lambda x: np.array(x), + (list, ShapeList): lambda x: [int(s) for s in x[0]], +} + + +format_functions_rec = lambda k: functools.partial(rec_fn, fn=format_functions[k]) + +mod_default_dict = { + "Matmul": ModuleInfo("matmul", PatchedMatmul), + "Linear": ModuleInfo("linear", PatchedLinear), + "RowParallelLinear": ModuleInfo("linear", PatchedRowParallelLinear), + "ColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear), + "MergedColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear), + "QKVParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear), + "FalconLinear": ModuleInfo("linear", PatchedLinear), + "KVCache": ModuleInfo("kv_cache", PatchedKVCache), + "VLLMKVCache": ModuleInfo("kv_cache", PatchedVLLMKVCache), + "Conv2d": ModuleInfo("linear", PatchedConv2d), + "LoRACompatibleLinear": ModuleInfo("linear", 
PatchedLoRACompatibleLinear), + "LoRACompatibleConv": ModuleInfo("linear", PatchedLoRACompatibleConv), + "Softmax": ModuleInfo("softmax", PatchedSoftmax), + "ModuleFusedSDPA": ModuleInfo("fused_sdpa", PatchedModuleFusedSDPA), +} + + +if deepspeed_exists: + mod_default_dict.update( + { + "LinearLayer": ModuleInfo("linear", PatchedLinear), + "LinearAllreduce": ModuleInfo("linear", PatchedLinearAllReduce), + "ScopedLinearAllReduce": ModuleInfo("linear", PatchedLinearAllReduce), + "LmHeadLinearAllreduce": ModuleInfo("linear", PatchedLmHeadLinearAllreduce), + } + ) + + +class ModInstInfo: + def __init__(self, name, parent): + self.name = name + self.parent = parent + + +parent_child_mod_dict = {} + + +def generate_model_info(model): + def create_mod_info_recursion(parent): + for name, mod in parent.named_children(): + parent_child_mod_dict[mod] = ModInstInfo(name, parent) + create_mod_info_recursion(mod) + + create_mod_info_recursion(model) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py b/neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py new file mode 100644 index 00000000000..14f54d4eaa8 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py @@ -0,0 +1,172 @@ +import torch +import habana_frameworks.torch.core as htcore +import habana_frameworks.torch.utils.experimental as htexp +from .common import * + +GAUDI2 = htexp.synDeviceType.synDeviceGaudi2 +GAUDI3 = htexp.synDeviceType.synDeviceGaudi3 + +EXP_WIDTH = { + torch.float32: 8, + torch.bfloat16: 8, + torch.float8_e4m3fn: 4, + torch.float8_e5m2: 5, +} + + +def get_default_exp_bias(dtype): + exp_width = EXP_WIDTH[dtype] + return 2 ** (exp_width - 1) - 1 + + +EXP_BIAS_SETS = { + (GAUDI2, torch.float8_e4m3fn): [3, 7, 11, 15], + (GAUDI2, torch.float8_e5m2): [15], + (GAUDI3, torch.float8_e4m3fn): range(0, 63), + (GAUDI3, torch.float8_e5m2): range(0, 63), +} + +MAX_RANGE = { + torch.float32: 2 ** ((2**8 - 2 - get_default_exp_bias(torch.float32))) * (2 - 2 ** -(23)), + torch.bfloat16: 2 ** ((2**8 - 2 - get_default_exp_bias(torch.bfloat16))) * (2 - 2 ** -(7)), + torch.float8_e4m3fn: 2 ** ((2**4 - 2 - get_default_exp_bias(torch.float8_e4m3fn))) * (2 - 2 ** -(8 - 1 - 4)), + torch.float8_e5m2: 2 ** ((2**5 - 2 - get_default_exp_bias(torch.float8_e5m2))) * (2 - 2 ** -(8 - 1 - 5)), +} + + +def get_fullscale(dtype, exp_bias=None): + default_exp_bias = get_default_exp_bias(dtype) + fullscale = MAX_RANGE[dtype] + exp_bias = default_exp_bias if exp_bias == None else exp_bias + fullscale = fullscale * (2 ** (default_exp_bias - exp_bias)) + return fullscale + + +def get_fullscales_by_expbias_set(dtype, expbias_set): + return [get_fullscale(dtype, exp_bias=eb) for eb in expbias_set] + + +def get_fp8_hw_alligned_scales(dtype, device): + exp_bias_set = EXP_BIAS_SETS.get((device, dtype), None) + return ( + None + if exp_bias_set == None + else [x / MAX_RANGE[dtype] for x in get_fullscales_by_expbias_set(dtype, exp_bias_set)] + ) + + +DEVICES_SCALE_FACTORS = { + htexp.synDeviceType.synDeviceGaudi2: 4, + htexp.synDeviceType.synDeviceGaudi3: 1, +} +FP8_143_SCALES = { + device: get_fp8_hw_alligned_scales(torch.float8_e4m3fn, device) for device in DEVICES_SCALE_FACTORS.keys() +} +FP8_143_SCALES_TRAITS = { + device: ( + min(FP8_143_SCALES[device]), + max(FP8_143_SCALES[device]), + DEVICES_SCALE_FACTORS[device], + ) + for device in DEVICES_SCALE_FACTORS.keys() +} + + +def calc_maxabs_scale(xmaxabs, fullscale, backoff=1): + scale = xmaxabs / (fullscale * backoff) + return scale + + +def 
scale_to_pow2(scale): + scale_pow2 = 2 ** torch.ceil(torch.log2(scale)) + return scale_pow2 + + +# Considering range of hw alligned scales: 2^a, 2^a+1,..., 2^b (a=2^b then s=2^b, therefor min(_, 2^b) +# if m<=2^a then s=2^a, therefor max(_, 2^a) --> 2^a <= min(max(_,2^a),2^b) <=2^b +# if s^a 0: + sd[mname]["params"] = dict() + sdl[mname]["params"] = dict() + for param_name in mcd[mname].params: + if mcd[mname].params[param_name].state is not None: + sd[mname]["params"][param_name] = ( + mcd[mname].params[param_name].state.detach().cpu().float().numpy() + ) + sdl[mname]["params"][param_name] = ( + mcd[mname].params[param_name].state.detach().cpu().float().numpy().tolist() + ) + return sd, sdl + + +def save_measurements(model, fname=None): + config = get_hqt_config(model).cfg + if config["mode"] in [QuantMode.MEASURE, QuantMode.SHAPE]: + if fname is None: + if ("measure_file" in config) and (config["measure_file"] is not None): + fname_base = config["measure_file"] + measure_type = "DynamicRange" + elif ("shape_file" in config) and (config["shape_file"] is not None) and (config["observer"] == "shape"): + fname_base = config["shape_file"] + measure_type = "Shape" + fname_np = fname_base + ".npz" + fname_list = fname_base + ".json" + else: + logger.warning("'fname' is not None - Measurements/Shapes will not be saved") + return + mcd = get_mod_extra_config_dict(model) + sd, sdl = measure_control_to_state_dict(mcd) + + logger.info("Dumping measurements") + save_file(model, sd, np.ndarray, fname_np, measure_type) + save_file(model, sdl, list, fname_list, measure_type) + save_json(gmod_list, fname_base + "_mod_list.json") + + +def load_measurements(model, fname): + config = get_hqt_config(model).cfg + source_fname = fname if fname is not None else config["measure_file"] + fname_np = source_fname + ".npz" + d = load_file( + fname_np, + np.ndarray, + fail_on_file_not_exist=(config["scale_method"] != ScaleMethod.UNIT_SCALE), + ) + from collections import defaultdict + + d = defaultdict(lambda: None, d) + + return d + + +def get_default_config(mod_list): + config = {k: "default" for k in mod_list} + return config + + +def save_json(d, fname): + with open(fname, "w") as f: + json.dump(d, f, indent=4) + + +def load_json(fname): + with open(fname, "r") as f: + d = json.load(f) + return d + + +class MaxAbsObserver: + def __init__(self, name, mod, d_shape=None, params=None): + self.name = name + self.mod = mod + self.first = True + self.used = False + self.state = self.init_state_from_shape(d_shape) + + def init_state(self, x): + device = x.device + state = torch.zeros((1, 1), device=device, dtype=torch.float32) + self.shape = list(x.shape) + return state + + def init_state_from_shape(self, x_shape, device="hpu"): + state = torch.zeros((1, 1), device=device, dtype=torch.float32) + self.first = False + return state + + def update_state(self, x): + # TODO: [SW-189690] Find better way to update self.state in MaxAbsObserver class in HQT + self.state = torch.maximum(torch.max(torch.abs(x)), self.state) + + def measure(self, x): + if self.first: + self.state = self.init_state(x) + self.first = False + self.update_state(x) + self.used = True + + def is_used(self): + return self.used + + +class MaxAbsPerChannelObserver: + def __init__(self, name, mod, d_shape=None, params=None): + self.name = name + self.mod = mod + self.first = True + self.state = None + self.used = False + self.dim = params["dim"] if (params is not None) and ("dim" in params) else -1 + if d_shape is not None: + p = list(range(len(d_shape))) + 
self.dim = self.dim if self.dim >= 0 else len(d_shape) + self.dim + p[-1] = self.dim + p[self.dim] = len(d_shape) - 1 + self.p = p + self.state = self.init_state_from_shape(d_shape) + + def init_state(self, x): + device = x.device + Nch = x.shape[self.dim] + self.Nch = Nch + state = torch.zeros((Nch, 1), device=device, dtype=torch.float32) + self.shape = list(x.shape) + return state + + def init_state_from_shape(self, x_shape, device="hpu"): + device = device + Nch = x_shape[self.dim] + self.Nch = Nch + state = torch.zeros((Nch, 1), device=device, dtype=torch.float32) + self.first = False + return state + + def update_state(self, x): + self.state.copy_( + torch.maximum( + torch.max( + torch.abs(x.permute(self.p).reshape([-1, self.Nch])), + dim=0, + keepdim=True, + )[0].t(), + self.state, + ) + ) + + def measure(self, x): + if self.first: + self.state = self.init_state(x) + self.first = False + self.update_state(x) + self.used = True + + def is_used(self): + return self.used + + +def save_module(mod): + folder_name = os.path.join(mod.config["dump_stats_base_path"], "tensors") + os.makedirs(folder_name, exist_ok=True) + file_base_name = os.path.join(folder_name, imod_dict[mod] + "_module.pt") + torch.save(mod.state_dict(), file_base_name) + + +class SaveObserver: + def __init__(self, name, mod, d_shape=None, params=None): + self.name = name + self.mod = mod + self.first = True + self.cnt = -1 + self.folder_name = os.path.join(config["dump_stats_base_path"], "tensors") + os.makedirs(self.folder_name, exist_ok=True) + self.file_base_name = os.path.join(self.folder_name, imod_dict[mod] + "_" + name + "_iter") + self.state = self.init_state_from_shape(d_shape) + self.used = False + + def init_state(self, x): + device = x.device + state = torch.zeros((1, 1), device=device, dtype=torch.float32) + self.shape = list(x.shape) + return state + + def init_state_from_shape(self, x_shape, device="hpu"): + state = torch.zeros((1, 1), device=device, dtype=torch.float32) + self.first = False + return state + + def update_state(self, x): + self.cnt += 1 + torch.save(x, self.file_base_name + str(self.cnt) + ".pt") + + def measure(self, x): + self.update_state(x) + self.used = True + + def is_used(self): + return self.used + + +class ShapeObserver: + def __init__(self, name, mod, d_shape=None, params=None): + self.name = name + self.mod = mod + self.state = None + + def init_state(self, x): + device = x.device + Ndim = len(x.shape) + self.Ndim = Ndim + state = torch.tensor(x.shape, device=device, dtype=torch.int32).reshape((1, Ndim)) + return state + + def init_state_from_shape(self, x_shape, device="hpu"): + logger.info("ShapeObserver doesn't support init_state_from_shape") + return + + def update_state(self, x): + logger.info("ShapeObserver doesn't support update_state") + return + + def measure(self, x): + self.state = self.init_state(x) + + def is_used(self): + return self.state is not None + + +observer_types = { + "shape": ShapeObserver, + "maxabs": MaxAbsObserver, + "maxabs_per_channel": MaxAbsPerChannelObserver, + "save": SaveObserver, +} + +observer_params = { + "maxabs_per_channel": { + "linear": ModuleConfig(({"dim": -1},), ({"dim": -1},), {"weight": {"dim": 0}}), + "matmul": ModuleConfig(({"dim": -1}, {"dim": -2},), ({"dim": -1},), None), + } +} diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py b/neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py new file mode 100644 index 00000000000..50b604b7d89 --- /dev/null +++ 
b/neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py @@ -0,0 +1,55 @@ +import torch.nn as nn +from abc import abstractmethod +from .common import * + + +class QuantDequantBase(nn.Module): + def __init__(self, lp_dtype, hp_dtype="", *args, **kwargs): + super(QuantDequantBase, self).__init__(*args, **kwargs) + self.lp_dtype = lp_dtype + self.hp_dtype = hp_dtype + + @abstractmethod + def forward(self, *args, **kwargs): + pass + + def extra_repr(self) -> str: + return f"lp_dtype={self.lp_dtype}, hp_dtype={self.hp_dtype}" + + +class QuantDequantNone(QuantDequantBase): + def __init__(self, lp_dtype, hp_dtype, *args, **kwargs): + super(QuantDequantNone, self).__init__(lp_dtype, hp_dtype, *args, **kwargs) + + def forward(self, *args, **kwargs): + return args[0] + + def extra_repr(self) -> str: + repr = super(QuantDequantNone, self).extra_repr() + return f"{repr}, doesn't quantize nor dequantize" + + +class QuantInput(QuantDequantBase): + def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs): + super(QuantInput, self).__init__(lp_dtype, hp_dtype, *args, **kwargs) + self.scale_inv = nn.Parameter(scale_inv) + + def forward(self, x): + return cast_to_fp8_fcn(x, self.lp_dtype, self.scale_inv) + + def extra_repr(self) -> str: + repr = super(QuantInput, self).extra_repr() + return f"{repr}, scale_inv dtype={self.scale_inv.dtype}" + + +class DequantOutput(QuantDequantBase): + def __init__(self, scale, lp_dtype, hp_dtype, *args, **kwargs): + super(DequantOutput, self).__init__(lp_dtype, hp_dtype, *args, **kwargs) + self.scale = nn.Parameter(scale) + + def forward(self, x): + return cast_from_fp8_fcn(x, self.hp_dtype, self.scale) + + def extra_repr(self) -> str: + repr = super(DequantOutput, self).extra_repr() + return f"{repr}, scale dtype={self.scale.dtype}" diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py new file mode 100644 index 00000000000..76ee0a1d635 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +import habana_frameworks.torch.core as htcore +from .._quant_common.quant_config import get_hqt_config +from .._quant_common.helper_modules import PatchedUnmeasuredModule +from .measure import load_measurements +from .scale import scale_method_mapping, get_config, scaling_methods +from .common import ( + mod_default_dict, + generate_model_info, + parent_child_mod_dict, + UNMEASURED_MODELS, +) +from ..utils.logger import logger + + +def patch_module(mod, qconfig, mod_dict, patched_mod=None): + parent = parent_child_mod_dict[mod].parent + name = parent_child_mod_dict[mod].name + if patched_mod is None: + patched_mod = mod_dict[mod.__class__.__name__].patched_module(mod, qconfig) + setattr(parent, name, patched_mod) + + +def apply_hf_hook(module): + if hasattr(module, "_hf_hook"): + module._hf_hook.pre_forward(module) + module._hf_hook.detach_hook(module) + delattr(module, "_hf_hook") + if hasattr(module, "_old_forward"): + module.forward = module._old_forward + delattr(module, "_old_forward") + + +def quantize_params(mod, mod_extra_config): + for param_name in mod_extra_config.params: + quantizer = mod_extra_config.params[param_name] + param = getattr(mod, param_name) + quantized_param = quantizer(param.to("hpu")) + delattr(mod, param_name) + setattr(mod, param_name, nn.Parameter(quantized_param)) + quantized_param = getattr(mod, param_name) + quantized_param.requires_grad_(False) + 
htcore.mark_step() + + +def prepare_model(model, qconfig, mod_list, hp_dtype=torch.float): + config = get_hqt_config(model) + patched_modules = [] + patched_module_types = set() + with torch.no_grad(): + for name, mod in model.named_modules(): + if name in qconfig[UNMEASURED_MODELS]: + if not config.cfg["ignore_modules_wo_measures"]: + patch_module(mod, None, None, PatchedUnmeasuredModule(name)) + else: + logger.debug("Module %s was not quantized.", name) + continue + # When offloading weight to disk, need to transfer the weight from disk to cpu using hf_hook + apply_hf_hook(mod) + if name in mod_list: + mod_extra_config = qconfig[name] + quantize_params(mod, mod_extra_config) + patch_module(mod, mod_extra_config, mod_default_dict) + patched_modules.append(name) + patched_module_types.add(type(mod)) + logger.debug("Patched module types: %s", patched_module_types) + logger.debug("Patched modules: %s", patched_modules) + logger.debug("Total patched modules: %d", len(patched_modules)) + model = model.to("hpu") + htcore.mark_step() + + +def quantize(model, mod_list): + config = get_hqt_config(model) + generate_model_info(model) + hp_dtype = config.cfg["hp_dtype"] + lp_dtype = config.cfg["fp8_config"] + measurement = load_measurements(model, config.cfg["measure_file"]) + # FIXME make sure this takes unit_scale or measured scale, from Configs + scaling_method_name = scale_method_mapping[(config.cfg["scale_method"], config.cfg["observer"])] + scaling_method = scaling_methods[scaling_method_name] + params = config.cfg["scale_params"] + params["hp_dtype"] = hp_dtype + params["lp_dtype"] = lp_dtype + qconfig = get_config( + model, + measurement, + mod_default_dict, + scaling_method, + params, + config.cfg["scale_file"], + False, + mod_list, + ) + prepare_model(model, qconfig, mod_list, hp_dtype=hp_dtype) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py new file mode 100644 index 00000000000..a85c79b660b --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py @@ -0,0 +1,438 @@ +import torch +import numpy as np + +from .._quant_common.quant_config import ScaleMethod +from .scale_methods import * +from .quant_dequant import * + +from .fp_utils import * +from .common import * +from ..utils.logger import logger + + +def matmul_scales_to_mod_config(mod, scales, params): + scales_inv = invert_scales(scales) + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + input_config = [QuantInput(s_inv, lp_dtype, hp_dtype) for s_inv in scales_inv.inputs] + # outputs as bf16, and descaled in gemm under PatchedMatmul, so no need to work here + output_config = [QuantDequantNone(lp_dtype, hp_dtype)] + config = ModuleConfig(input_config, output_config, {}) + return config + + +def fsdpa_scales_to_mod_config(mod, scales, params): + scales_inv = invert_scales(scales) + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + input_config = [QuantInput(s_inv, lp_dtype, hp_dtype) for s_inv in scales_inv.inputs] + output_config = [DequantOutput(scales.outputs[0], lp_dtype, hp_dtype)] + config = ModuleConfig(input_config, output_config, {}) + return config + + +def linear_scales_to_mod_config(mod, scales, params): + scales_inv = invert_scales(scales) + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + input_config = [QuantInput(scales_inv.inputs[0], lp_dtype, hp_dtype)] + # outputs as bf16, and descaled in gemm under PatchedLinear, so no need to work here + output_config = 
[QuantDequantNone(lp_dtype, hp_dtype)] + if isinstance(scales_inv.params["weight"], (torch.Tensor, float)): + weight_config = QuantInput(scales_inv.params["weight"], lp_dtype, hp_dtype) + elif isinstance(scales_inv.params["weight"], dict): + weight_scale_inv_out_ch = scales_inv.params["weight"][0] + weight_scale_inv_in_ch = scales_inv.params["weight"][1] + if isinstance(weight_scale_inv_out_ch, torch.Tensor): + scale_inv = torch.mul( + weight_scale_inv_in_ch.reshape([1, -1]), + weight_scale_inv_out_ch.reshape([-1, 1]), + ) + else: + # TODO SW-169781: Handle here scalar weight for PCQ + raise TypeError(f"Unknown weight scales type: {type(weight_scale_inv_out_ch)}.") + weight_config = QuantInput(scale_inv, lp_dtype, hp_dtype) + else: + logger.error("Unknown weight scales format.") + params_config = {"weight": weight_config} + if hasattr(mod, "bias") and (getattr(mod, "bias") is not None): + # In PatchedLinear the bias is added to the output of gemm. + # The output is expected to be descaled and in bf16, so we don't need to touch the bias. + bias_config = QuantDequantNone(lp_dtype, hp_dtype) + params_config.update({"bias": bias_config}) + config = ModuleConfig(input_config, output_config, params_config) + return config + + +def kv_cache_scales_to_mod_config(mod, scales, params): + # how quant/dequant will be applied on layer tensors + scales_inv = invert_scales(scales) + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + input_config = [QuantInput(scales_inv.inputs[0], lp_dtype, hp_dtype)] + output_config = [DequantOutput(scales.outputs[0], lp_dtype, hp_dtype)] + config = ModuleConfig(input_config, output_config) + return config + + +def softmax_scales_to_mod_config(mod, scales, params): + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + output_config = [DequantOutput(scales.outputs[0], lp_dtype, hp_dtype)] + return ModuleConfig(None, output_config) + + +def get_config( + model, + measurement, + mod_dict, + method, + params, + scales_file=None, + recalc_scales=False, + mod_list=None, +): + with torch.no_grad(): + top_level_config = get_hqt_config(model) + qconfig = {UNMEASURED_MODELS: []} + scales_file_format = np.ndarray # file_functions[os.path.splitext(scales_file)[1]][0] + scales_obj = ( + load_scales(scales_file + ".npz", scales_file_format) + if (scales_file is not None) and not recalc_scales + else {} + ) + scales = convert_scales_to_tensors_dict(scales_obj, scales_file_format, params["hp_dtype"]) + model_dict = dict(model.named_modules()) + for mname in mod_list: + mod = model_dict[mname] + set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module + mod_type_str = mod.__class__.__name__ + layer_type = mod_dict[mod_type_str].type + if mname not in scales: + logger.debug("Calcuating scales for layer %s", mname) + if mname not in measurement: + qconfig[UNMEASURED_MODELS].append(mname) + logger.debug( + "Layer '%s' has no measurements therefore it can't be quantized.", + mname, + ) + continue + layer_measure = measurement[mname] # ModuleConfig() of measurements + scales[mname] = method[layer_type][0](mod, layer_measure, params) # ModuleConfig() of scales + if scales_file is not None: + scales_obj[mname] = ModuleConfig( + **format_functions_rec((torch.Tensor, scales_file_format))(scales[mname].__dict__) + ) + + logger.debug( + "Preparing quantization functions for layer %s layer_type=%s", + mname, + layer_type, + ) + mod_config = method[layer_type][1](mod, scales[mname], params) # ModuleConfig() of QuantDequant + 
mod_extra_config = ModuleExtraConfig( + mod_config.inputs, + mod_config.outputs, + mod_config.params, + scales[mname], + params, + ) + qconfig[mname] = mod_extra_config + if scales_file is not None: + save_scales(model, scales_obj, scales_file_format, scales_file + ".npz") + save_scales(model, scales_obj, scales_file_format, scales_file + ".json") + return qconfig + + +scaling_methods = { + "unit_scale": { + "linear": (linear_unit_scale_scales, linear_scales_to_mod_config), + "matmul": (matmul_unit_scale_scales, matmul_scales_to_mod_config), + "softmax": (softmax_unit_scale_scales, softmax_scales_to_mod_config), + "kv_cache": (kv_cache_unit_scale_scales, kv_cache_scales_to_mod_config), + "fused_sdpa": (fsdpa_unit_scale_scales, fsdpa_scales_to_mod_config), + }, + "act_maxabs_pts_weight_maxabs_pts_pow2_hw": { + "linear": ( + linear_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + "kv_cache": ( + kv_cache_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + kv_cache_scales_to_mod_config, + ), + "softmax": ( + softmax_input_unit_output_maxabs_pts_hw_scales, + softmax_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + fsdpa_scales_to_mod_config, + ), + }, + "act_maxabs_pts_weight_maxabs_pts_pow2": { + "linear": ( + linear_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + matmul_scales_to_mod_config, + ), + }, + "act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2": { + "linear": ( + linear_act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + # kv_cache is pts as op in hw doesn't work in pcs + "kv_cache": ( + kv_cache_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + kv_cache_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + fsdpa_scales_to_mod_config, + ), + }, + "act_maxabs_pts_weight_opt_pts_pow2": { + "linear": ( + linear_act_maxabs_pts_weight_opt_pts_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + matmul_scales_to_mod_config, + ), + }, + "act_maxabs_pts_weight_opt_pts_hw": { + "linear": ( + linear_act_maxabs_pts_weight_opt_pts_hw_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + "softmax": ( + softmax_input_unit_output_maxabs_pts_hw_scales, + softmax_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + fsdpa_scales_to_mod_config, + ), + }, + "act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2": { + "linear": ( + linear_act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + # kv_cache is pts as op in hw doesn't work in pcs + "kv_cache": ( + kv_cache_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + kv_cache_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + fsdpa_scales_to_mod_config, + ), + }, + "act_maxabs_pts_pow2_weights_maxabs_pcs_pow2": { + "linear": ( + 
linear_act_maxabs_pts_pow2_weights_maxabs_pcs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + matmul_scales_to_mod_config, + ), + # kv_cache is pts as op in hw doesn't work in pcs + "kv_cache": ( + kv_cache_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + kv_cache_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + fsdpa_scales_to_mod_config, + ), + }, + "act_maxabs_pts_pow2_weights_opt_pcs_pow2": { + "linear": ( + linear_act_maxabs_pts_pow2_weights_opt_pcs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + matmul_scales_to_mod_config, + ), + # kv_cache is pts as op in hw doesn't work in pcs + "kv_cache": ( + kv_cache_act_maxabs_pts_pow2_weight_opt_pcs_pow2_scales, + kv_cache_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + fsdpa_scales_to_mod_config, + ), + }, + "smoothquant_weights_opt_pow2": { + "linear": ( + linear_smoothquant_weights_opt_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + }, + "smoothquant_weights_maxabs_pow2": { + "linear": ( + linear_smoothquant_weights_maxabs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + }, + "weaksmoothquant_weights_maxabs_pow2": { + "linear": ( + linear_weaksmoothquant_weights_maxabs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + }, +} + +scale_method_mapping = { + (ScaleMethod.UNIT_SCALE, "maxabs"): "unit_scale", + (ScaleMethod.UNIT_SCALE, "maxabs_per_channel"): "unit_scale", + (ScaleMethod.MAXABS_HW, "maxabs"): "act_maxabs_pts_weight_maxabs_pts_pow2_hw", + (ScaleMethod.MAXABS_POW2, "maxabs"): "act_maxabs_pts_weight_maxabs_pts_pow2", + (ScaleMethod.MAXABS_HW_OPT_WEIGHT, "maxabs"): "act_maxabs_pts_weight_opt_pts_hw", + ( + ScaleMethod.MAXABS_POW2_OPT_WEIGHT, + "maxabs", + ): "act_maxabs_pts_weight_opt_pts_pow2", + ( + ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2, + "maxabs", + ): "act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2", + ( + ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2, + "maxabs_per_channel", + ): "act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2", + ( + ScaleMethod.SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2, + "maxabs_per_channel", + ): "smoothquant_weights_maxabs_pow2", + ( + ScaleMethod.WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2, + "maxabs_per_channel", + ): "weaksmoothquant_weights_maxabs_pow2", + ( + ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2, + "maxabs", + ): "act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2", + ( + ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2, + "maxabs_per_channel", + ): "act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2", + ( + ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2, + "maxabs", + ): "act_maxabs_pts_pow2_weights_maxabs_pcs_pow2", + ( + ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2, + "maxabs_per_channel", + ): "act_maxabs_pts_pow2_weights_maxabs_pcs_pow2", + ( + ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2, + "maxabs", + ): "act_maxabs_pts_pow2_weights_opt_pcs_pow2", + ( + ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2, + "maxabs_per_channel", + ): "act_maxabs_pts_pow2_weights_opt_pcs_pow2", + ( + 
ScaleMethod.WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2, + "maxabs_per_channel", + ): "weaksmoothquant_weights_maxabs_pow2", + ( + ScaleMethod.SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2, + "maxabs_per_channel", + ): "smoothquant_weights_maxabs_pow2", + (ScaleMethod.SMOOTHQUANT_OPT, "maxabs_per_channel"): "smoothquant_weights_opt_pow2", +} + +scaling_params = { + "unit_scale": {}, + "act_maxabs_pts_weight_maxabs_pts_pow2_hw": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + }, + "act_maxabs_pts_weight_maxabs_pts_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + }, + "act_maxabs_pts_weight_opt_pts_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "weight_scales": [2.0**s for s in range(-10, 10)], + }, + "act_maxabs_pts_weight_opt_pts_hw": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "weight_scales": [2.0**s for s in [4, 0, -4, -8]], + }, + "smoothquant_weights_maxabs_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "alpha": 0.5, + }, + "weaksmoothquant_weights_maxabs_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "alpha": 0.5, + }, + "act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + }, + "act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "weight_scales": [2.0**s for s in range(-3, 5)], + }, + "act_maxabs_pts_pow2_weights_maxabs_pcs_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + }, + "act_maxabs_pts_pow2_weights_opt_pcs_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "weight_scales": [2.0**s for s in range(-3, 5)], + }, + "smoothquant_weights_opt_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "alpha": 0.5, + "transformed_weight_scales": [2.0**s for s in range(-3, 5)], + }, +} diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/__init__.py new file mode 100644 index 00000000000..1c0b11e3c99 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/__init__.py @@ -0,0 +1,3 @@ +from .max_abs import * +from .unit_scale import * +from .smooth_quant import * diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/max_abs.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/max_abs.py new file mode 100644 index 00000000000..d991a68aca2 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/max_abs.py @@ -0,0 +1,397 @@ +import torch + +from ..fp_utils import * +from ..common import * + + +def linear_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + weight_scale = calc_maxabs_scale( + torch.max(torch.abs(mod.weight.detach())).to(dtype=hp_dtype, device=device), + fullscale, + weight_backoff, + ) + input_scale = scale_to_pow2_hw(input_scale, device_type=config["device_type"]) + weight_scale = scale_to_pow2_hw(weight_scale, device_type=config["device_type"]) + output_scale = input_scale * weight_scale + return ModuleConfig((input_scale,), (output_scale,), {"weight": 
weight_scale}) + + +def linear_act_maxabs_pts_weight_maxabs_pts_pow2_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + weight_scale = calc_maxabs_scale( + torch.max(torch.abs(mod.weight.detach())).to(dtype=hp_dtype, device=device), + fullscale, + weight_backoff, + ) + input_scale = scale_to_pow2(input_scale) + weight_scale = scale_to_pow2(weight_scale) + output_scale = input_scale * weight_scale + return ModuleConfig((input_scale,), (output_scale,), {"weight": weight_scale}) + + +def matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = [ + calc_maxabs_scale( + torch.tensor(x, dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + for x in measurement.inputs + ] + input_scale = [scale_to_pow2_hw(x, device_type=config["device_type"]) for x in input_scale] + output_scale = [input_scale[0] * input_scale[1]] + return ModuleConfig(input_scale, output_scale, {}) + + +def matmul_act_maxabs_pts_weight_maxabs_pts_pow2_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = [ + calc_maxabs_scale( + torch.tensor(x, dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + for x in measurement.inputs + ] + input_scale = [scale_to_pow2(x) for x in input_scale] + output_scale = [input_scale[0] * input_scale[1]] + return ModuleConfig(input_scale, output_scale, {}) + + +def fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = [ + calc_maxabs_scale( + torch.tensor(x, dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + for x in measurement.inputs + ] + # add amax scale to input scales + input_scale.append( + calc_maxabs_scale( + torch.tensor(measurement.outputs[1], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + ) + input_scale = [scale_to_pow2_hw(x, device_type=config["device_type"]) for x in input_scale] + output_scale = calc_maxabs_scale( + torch.tensor(measurement.outputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + output_scale = [scale_to_pow2_hw(output_scale, device_type=config["device_type"])] + return ModuleConfig(input_scale, output_scale, {}) + + +def fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = [ + calc_maxabs_scale( + torch.tensor(x, dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) 
+ for x in measurement.inputs + ] + # fsdpa is combined out of - BMM1(Q,K) -> Softmax -> BMM2(AMAX,V) + # during measure we recieve the amax value from the cguid and apply it during quant as input + input_scale.append( + calc_maxabs_scale( + torch.tensor(measurement.outputs[1], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + ) + input_scale = [scale_to_pow2(x) for x in input_scale] + output_scale = calc_maxabs_scale( + torch.tensor(measurement.outputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + output_scale = [scale_to_pow2(output_scale)] + return ModuleConfig(input_scale, output_scale, {}) + + +def linear_act_maxabs_pts_weight_opt_pts_pow2_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + scales = params["weight_scales"] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + weight_scale = mmse_scale(mod.weight, scales, lp_dtype, hp_dtype) + input_scale = scale_to_pow2(input_scale) + weight_scale = scale_to_pow2(weight_scale) + output_scale = input_scale * weight_scale + return ModuleConfig((input_scale,), (output_scale,), {"weight": weight_scale}) + + +def linear_act_maxabs_pts_weight_opt_pts_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + scales = params["weight_scales"] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + weight_scale = mmse_scale(mod.weight, scales, lp_dtype, hp_dtype) + input_scale = scale_to_pow2_hw(input_scale, device_type=config["device_type"]) + weight_scale = scale_to_pow2_hw(weight_scale, device_type=config["device_type"]) + output_scale = input_scale * weight_scale + return ModuleConfig((input_scale,), (output_scale,), {"weight": weight_scale}) + + +def kv_cache_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + # calc the scale per layer tensor + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + input_scale_list = [scale_to_pow2_hw(input_scale, device_type=config["device_type"])] + output_scale = [input_scale_list[0]] # output scale is same as the first input (current data) since range is same + return ModuleConfig(input_scale_list, output_scale, {}) + + +def kv_cache_act_maxabs_pts_pow2_weight_opt_pcs_pow2_scales(mod, measurement, params): + # calc the scale per layer tensor + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + input_scale_list = 
[scale_to_pow2(input_scale)] + output_scale = [input_scale_list[0]] # output scale is same as the first input (current data) since range is same + return ModuleConfig(input_scale_list, output_scale, {}) + + +def softmax_input_unit_output_maxabs_pts_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + output_scale = calc_maxabs_scale( + torch.tensor(measurement.outputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + output_scale = [scale_to_pow2_hw(output_scale, device_type=config["device_type"])] + return ModuleConfig((), output_scale, {}) + + +def linear_act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + input_scale = scale_to_pow2_hw(input_scale, device_type=config["device_type"]) + weight_scale_in_ch = torch.ones([mod.weight.shape[1], 1], dtype=hp_dtype, device=device) + + weight_range_out_ch = torch.max(torch.abs(mod.weight), dim=1)[0].reshape([-1, 1]) + weight_maxabs_scale_out_ch = calc_maxabs_scale(weight_range_out_ch, fullscale, weight_backoff) + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) + output_scale = weight_maxabs_scale_out_ch * input_scale + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + { + "weight": { + 0: weight_maxabs_scale_out_ch.flatten(), + 1: weight_scale_in_ch.flatten(), + } + }, + ) + + +def linear_act_maxabs_pts_pow2_weights_maxabs_pcs_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + input_scale = scale_to_pow2(input_scale) + weight_scale_in_ch = torch.ones([mod.weight.shape[1], 1], dtype=hp_dtype, device=device) + + weight_range_out_ch = torch.max(torch.abs(mod.weight), dim=1)[0].reshape([-1, 1]) + weight_maxabs_scale_out_ch = calc_maxabs_scale(weight_range_out_ch, fullscale, weight_backoff) + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) + output_scale = weight_maxabs_scale_out_ch * input_scale + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + { + "weight": { + 0: weight_maxabs_scale_out_ch.flatten(), + 1: weight_scale_in_ch.flatten(), + } + }, + ) + + +def linear_act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + weight_scales = params["weight_scales"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + 
) + input_scale = scale_to_pow2_hw(input_scale, device_type=config["device_type"]) + weight_scale_in_ch = torch.ones([mod.weight.shape[1], 1], dtype=hp_dtype, device=device) + + weight_range_out_ch = torch.max(torch.abs(mod.weight), dim=1)[0].reshape([-1, 1]) + weight_maxabs_scale_out_ch = calc_maxabs_scale(weight_range_out_ch, fullscale, weight_backoff) + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) + weight_opt_scale_out_ch = mmse_scale_multi( + torch.transpose(mod.weight, 0, 1), + weight_maxabs_scale_out_ch.squeeze(), + weight_scales, + lp_dtype, + hp_dtype, + ).unsqueeze(1) + weight_maxabs_scale_out_ch = weight_opt_scale_out_ch + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) # should be power of 2, just making sure + output_scale = weight_maxabs_scale_out_ch * input_scale + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + { + "weight": { + 0: weight_maxabs_scale_out_ch.flatten(), + 1: weight_scale_in_ch.flatten(), + } + }, + ) + + +def linear_act_maxabs_pts_pow2_weights_opt_pcs_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + weight_scales = params["weight_scales"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + input_scale = scale_to_pow2(input_scale) + weight_scale_in_ch = torch.ones([mod.weight.shape[1], 1], dtype=hp_dtype, device=device) + + weight_range_out_ch = torch.max(torch.abs(mod.weight), dim=1)[0].reshape([-1, 1]) + weight_maxabs_scale_out_ch = calc_maxabs_scale(weight_range_out_ch, fullscale, weight_backoff) + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) + weight_opt_scale_out_ch = mmse_scale_multi( + torch.transpose(mod.weight, 0, 1), + weight_maxabs_scale_out_ch.squeeze(), + weight_scales, + lp_dtype, + hp_dtype, + ).unsqueeze(1) + weight_maxabs_scale_out_ch = weight_opt_scale_out_ch + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) # should be power of 2, just making sure + output_scale = weight_maxabs_scale_out_ch * input_scale + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + { + "weight": { + 0: weight_maxabs_scale_out_ch.flatten(), + 1: weight_scale_in_ch.flatten(), + } + }, + ) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/smooth_quant.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/smooth_quant.py new file mode 100644 index 00000000000..3a216e6ef15 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/smooth_quant.py @@ -0,0 +1,118 @@ +import torch +from tqdm import tqdm + +from ..fp_utils import * +from ..common import * + + +def linear_smoothquant_weights_opt_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + alpha = params["alpha"] + transformed_weight_scales = params["transformed_weight_scales"] + input_range = torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device) + weight_range_in_ch = torch.max(torch.abs(mod.weight), dim=0)[0].reshape([-1, 1]) + input_scale = calc_maxabs_scale(input_range, fullscale, 
input_backoff) + weight_scale_in_ch = calc_maxabs_scale(weight_range_in_ch, fullscale, weight_backoff) + input_scale = (input_scale**alpha) / (weight_scale_in_ch ** (1 - alpha)) + input_scale = scale_to_pow2(input_scale) + weight_scale_in_ch = 1 / input_scale + trans_weight = scale_fcn(mod.weight, weight_scale_in_ch.reshape([1, -1])) + trans_weight_range_out_ch = torch.max(torch.abs(trans_weight), dim=1)[0].reshape([-1, 1]) + trans_weight_maxabs_scale_out_ch = calc_maxabs_scale(trans_weight_range_out_ch, fullscale, weight_backoff) + trans_weight_maxabs_scale_out_ch = scale_to_pow2(trans_weight_maxabs_scale_out_ch) + trans_weight_scale_out_ch = torch.zeros(mod.weight.shape[0]) + for k in tqdm(range(trans_weight_scale_out_ch.shape[0])): + trans_weight_scale_out_ch[k] = mmse_scale( + trans_weight[k, :], + [s * trans_weight_maxabs_scale_out_ch[k] for s in transformed_weight_scales], + lp_dtype, + hp_dtype, + ) + weight_scale_out_ch = scale_to_pow2(trans_weight_scale_out_ch) + output_scale = torch.tensor(weight_scale_out_ch, dtype=hp_dtype, device=device) + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + {"weight": {0: weight_scale_out_ch.flatten(), 1: weight_scale_in_ch.flatten()}}, + ) + + +def linear_smoothquant_weights_maxabs_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + alpha = params["alpha"] + input_range = torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device) + weight_range_in_ch = torch.max(torch.abs(mod.weight), dim=0)[0].reshape([-1, 1]) + input_scale = calc_maxabs_scale(input_range, 1.0, 1.0) + weight_scale_in_ch = calc_maxabs_scale(weight_range_in_ch, 1.0, 1.0) + input_scale = (input_scale**alpha) / (weight_scale_in_ch ** (1 - alpha)) + input_scale = scale_to_pow2(input_scale) + input_range_post = input_range / input_scale + input_scale_post = calc_maxabs_scale(input_range_post.max(), fullscale, input_backoff) + input_scale_post = scale_to_pow2(input_scale_post) + input_scale = input_scale * input_scale_post + weight_scale_in_ch = 1 / input_scale + trans_weight = scale_fcn(mod.weight, weight_scale_in_ch.reshape([1, -1])) + trans_weight_range_out_ch = torch.max(torch.abs(trans_weight), dim=1)[0].reshape([-1, 1]) + trans_weight_maxabs_scale_out_ch = calc_maxabs_scale(trans_weight_range_out_ch, fullscale, weight_backoff) + trans_weight_maxabs_scale_out_ch = scale_to_pow2(trans_weight_maxabs_scale_out_ch) + weight_scale_out_ch = scale_to_pow2(trans_weight_maxabs_scale_out_ch) + output_scale = torch.tensor(weight_scale_out_ch, dtype=hp_dtype, device=device) + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + {"weight": {0: weight_scale_out_ch.flatten(), 1: weight_scale_in_ch.flatten()}}, + ) + + +def linear_weaksmoothquant_weights_maxabs_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + alpha = params["alpha"] + input_range = torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max().clamp(min=1e-5) + input_range_mid = input_range.max() / torch.sqrt(input_range.max() / input_range.min().clamp(min=1e-5)) + input_scale_pcs = calc_maxabs_scale(input_range.clamp(min=1e-5), input_range_mid, 
1.0).clamp(min=1e-5) + weight_range_in_ch = torch.max(torch.abs(mod.weight), dim=0)[0].reshape([-1, 1]).clamp(min=1e-5) + weight_range_in_ch_mid = weight_range_in_ch.max() / torch.sqrt( + weight_range_in_ch.max() / weight_range_in_ch.min().clamp(min=1e-5) + ).clamp(min=1e-5) + weight_scale_pcs = calc_maxabs_scale(weight_range_in_ch.clamp(min=1e-5), weight_range_in_ch_mid, 1.0).clamp( + min=1e-5 + ) + + input_scale = ((input_scale_pcs**alpha) / (weight_scale_pcs ** (1 - alpha))).clamp(min=1e-5) + input_scale = scale_to_pow2(input_scale) + input_scale_post = calc_maxabs_scale((input_range / input_scale).max(), fullscale, input_backoff) + input_scale_post = scale_to_pow2(input_scale_post) + + weight_scale_in_ch = torch.ones([mod.weight.shape[1], 1], dtype=hp_dtype, device=device) * (1 / input_scale) + + trans_weight = scale_fcn(mod.weight, weight_scale_in_ch.reshape([1, -1])) + weight_range_out_ch = torch.max(torch.abs(trans_weight), dim=1)[0].reshape([-1, 1]) + + weight_maxabs_scale_out_ch = calc_maxabs_scale(weight_range_out_ch, fullscale, weight_backoff) + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) + output_scale = weight_maxabs_scale_out_ch * input_scale_post + return ModuleConfig( + (input_scale.flatten() * input_scale_post,), + (output_scale.flatten(),), + { + "weight": { + 0: weight_maxabs_scale_out_ch.flatten(), + 1: weight_scale_in_ch.flatten(), + } + }, + ) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/unit_scale.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/unit_scale.py new file mode 100644 index 00000000000..6be7673aace --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/unit_scale.py @@ -0,0 +1,52 @@ +import torch + +from ..fp_utils import * +from ..common import * + + +def linear_unit_scale_scales(mod, measurement, params): + device = torch.device("hpu") + hp_dtype = params["hp_dtype"] + input_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + weight_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + output_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + return ModuleConfig((input_scale,), (output_scale,), {"weight": weight_scale}) + + +def fsdpa_unit_scale_scales(mod, measurement, params): + device = torch.device("hpu") + hp_dtype = torch.float32 # params["hp_dtype"] + q_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + k_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + v_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + softmax_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + input_scale = (q_scale, k_scale, v_scale, softmax_scale) + output_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + return ModuleConfig(input_scale, output_scale, {}) + + +def matmul_unit_scale_scales(mod, measurement, params): + device = torch.device("hpu") + hp_dtype = params["hp_dtype"] + input_scale = ( + torch.tensor(1.0, dtype=hp_dtype, device=device), + torch.tensor(1.0, dtype=hp_dtype, device=device), + ) + output_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + return ModuleConfig(input_scale, output_scale, {}) + + +def softmax_unit_scale_scales(mod, measurement, params): + device = torch.device("hpu") + hp_dtype = params["hp_dtype"] + input_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + output_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + return ModuleConfig(input_scale, output_scale) + + +def kv_cache_unit_scale_scales(mod, measurement, params): + device = 
torch.device("hpu") + hp_dtype = params["hp_dtype"] + input_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + output_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + return ModuleConfig(input_scale, output_scale) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py b/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py new file mode 100644 index 00000000000..a4652bd1755 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py @@ -0,0 +1,49 @@ +from .measure import prepare_model as prepare_model_for_measure +from .quantize import quantize +from .scale import scaling_params, scale_method_mapping +from .._quant_common.quant_config import QuantMode, get_hqt_config + +from .._quant_common.helper_modules import * +from ..utils.logger import logger +from .common import mod_default_dict + +def update_mod_dict(config): + assert len(config.cfg['mod_dict']) == 0, f"Custom modules are not supported: {config.cfg['mod_dict'].keys()}. Please add it in the code." + config.cfg['mod_dict'].update({k: mod_default_dict[k].type for k in mod_default_dict}) + +def print_init_info(config): + import importlib.metadata + versionStr = importlib.metadata.version('habana_quantization_toolkit') + locationStr = versionStr.find('git') + 3 + logger.info("HQT Git revision = %s", versionStr[locationStr:]) + logger.info("HQT Configuration = %s", config) + +def is_substr(substr_list, target): + return any([x in target for x in substr_list]) + +def prepare_model(model): + config = get_hqt_config(model) + update_mod_dict(config) + allowlist=set(config.cfg['mod_dict'].keys()) + blocklist=set() + for type_st in config.cfg['blocklist']['types']: + blocklist.add(type_st) + allowlist.difference_update(blocklist) + allowlist_tuple=tuple(allowlist) + mod_list=[] + for name, mod in model.named_modules(): + mod_type=mod.__class__.__name__ + if (mod_type in allowlist_tuple) and (is_substr(config.cfg['allowlist']['names'], name) or len(config.cfg['allowlist']['names'])==0) and (not is_substr(config.cfg['blocklist']['names'], name)): + mod_list.append(name) + + print_init_info(config) + + logger.debug("Module list: %s", mod_list) + logger.info("Total modules : %d", len(mod_list)) + if (config.cfg['mode']==QuantMode.MEASURE) or (config.cfg['mode']==QuantMode.SHAPE): + return prepare_model_for_measure(model, mod_list) + elif config.cfg['mode']==QuantMode.QUANTIZE: + scaling_method_name = scale_method_mapping[(config.cfg['scale_method'], config.cfg['observer'])] + scaling_params[scaling_method_name].update(config.cfg['scale_params']) + config.cfg['scale_params'] = scaling_params[scaling_method_name] + return quantize(model, mod_list) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py new file mode 100644 index 00000000000..61d26f081ff --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py @@ -0,0 +1,812 @@ +import torch.nn as nn +import torch + +from .quant_config import QuantMode, get_hqt_config, set_hqt_config + +try: # backwards compatibility for 1.16 + from habana_frameworks.torch.hpex.kernels import fp8_fused_sdpa +except ImportError: + pass + + +class BMM(nn.Module): + def __init__(self): + 
super().__init__() + + def forward(self, x, y): + return torch.bmm(x, y) + + +class Matmul(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args, **kwargs): + return torch.matmul(*args, **kwargs) + + +class Identity(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.clone() + + +class Softmax(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, dim=None): + return torch.softmax(x, dim) + + +def matmul_fp8( + input, + other, + out=None, + out_dtype=torch.bfloat16, + scale_input_inv=None, + scale_other_inv=None, +): + res = torch.ops.hpu.fp8_gemm_v2( + input, + False, + other, + False, + out, + out_dtype, + scale_input_inv, + scale_other_inv, + None, + False, + ) + return res + + +def measure_input(input, observer): + for i in range(len(observer)): + observer[i].measure(input[i]) + + +def measure_output(output, observer): + if observer: + for i in range(len(observer)): + observer[i].measure(output[i]) + + +def conv2d_fp8( + input, + other, + bias, + stride, + padding, + dilation, + groups, + out_dtype=torch.bfloat16, + scale_input_inv=None, + scale_other_inv=None, +): + return torch.ops.hpu.conv2d_fp8( + input=input, + weight=other, + bias=bias, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + out_dtype=out_dtype, + scale_input=scale_input_inv, + scale_weight=scale_other_inv, + ) + + +def set_attrs_from_orig_model(cls_instance, mod, mod_extra_config, *func_names): + cls_instance.__dict__.update(mod.__dict__) + config = get_hqt_config(cls_instance) + cls_instance.extra_repr_org = mod.extra_repr + cls_instance.class_name_org = mod.__class__.__name__ + cls_instance._mod_extra_config = mod_extra_config + cls_instance.quantization_mode = config.cfg["mode"] + cls_instance.forward_orig = mod.forward + if func_names is not None: + for func in func_names: + setattr(cls_instance, func, getattr(mod, func)) + + +def get_current_repr(cls_instance, *member_names): + curr_repr = "" + if cls_instance.quantization_mode == QuantMode.QUANTIZE: + first_name = True + for name in member_names: + if not first_name: + curr_repr += ", " + curr_repr += f"{name} dtype={getattr(cls_instance, name).dtype}" + first_name = False + return curr_repr + + +def extra_representation(org_repr, org_name, curr_repr): + repr = f"original={org_name}," + (" " + org_repr + "," if org_repr != "" else "") + return f"{repr} {curr_repr}" + + +def _raise_lora_layer_error(layer_class): + raise RuntimeError( + f"{layer_class} quantization is not supported in case of lora_layer member is not None." 
+ f" Can add {layer_class} to 'blocklist' field in quantization config file" + ) + + +class PatchedMatmul(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input_0 = self._mod_extra_config.inputs[0] + self.quant_input_1 = self._mod_extra_config.inputs[1] + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + self.scale_other = nn.Parameter(mod_extra_config.scale.inputs[1]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input, other): + qinput = self.quant_input_0(input) + qother = self.quant_input_1(other) + output = matmul_fp8( + qinput, + qother, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_other, + ) + return output + + def forward_measure(self, input, other): + measure_input((input, other), observer=self._mod_extra_config.inputs) + output = self.forward_orig(input, other) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_other"), + ) + + +class PatchedLinear(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + # When offloading weights to disk using device_map, the module forward is overridden. + # __dict__.update call again overrides the PatchedLinear forward with the forward that device_map planted. + # So need to set PatchedLinear forawrd to be the right forward. 
+ self.forward = self.forward_quant + self.quant_input = self._mod_extra_config.inputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)): + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif isinstance(mod_extra_config.scale.params["weight"], dict): + # PCQ weight is calculated with actual weight [0] and ones [1] + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"][0]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward_quant(self, input): + qinput = self.quant_input(input) + y = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + output = y + self.bias if (self.bias is not None) else y + return output + + def forward_measure(self, input): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = self.forward_orig(input) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedLinearAllReduce(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + self.scoped_version = mod.__class__.__name__ == "ScopedLinearAllReduce" + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)): + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif isinstance(mod_extra_config.scale.params["weight"], dict): + # PCQ weight is calculated with actual weight [0] and ones [1] + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"][0]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input): + # pre_all_reduce + qinput = self.quant_input(input) + output = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + dqoutput = self.quant_output(output) + if not self.scoped_version: + self.all_reduce(dqoutput) + dqoutput = self.post_all_reduce(dqoutput) + return dqoutput + + def forward_measure(self, input): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = torch.matmul(input, self.weight.transpose(-1, -2)) + measure_output((output,), self._mod_extra_config.outputs) + # in scoped version all reduce is being called outside of the layer + if not self.scoped_version: + self.all_reduce(output) + output = self.post_all_reduce(output) + return output + + def all_reduce(self, input): + if self.mp_group is not None: + from deepspeed import comm as dist + + dist.inference_all_reduce(input, group=self.mp_group) + + def post_all_reduce(self, input): + output = input + 
self.bias if (self.bias is not None) else input + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedRowParallelLinear(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config, "resolve_input") + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)): + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif isinstance(mod_extra_config.scale.params["weight"], dict): + # PCQ weight is calculated with actual weight [0] and ones [1] + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"][0]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input): + resolved_input = self.resolve_input(input) + qinput = self.quant_input(resolved_input) + output = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + dqoutput = self.quant_output(output) + if self.reduce_results: + dqoutput = self.collective_func(dqoutput) + return self.post_all_reduce(dqoutput) + + def forward_measure(self, input): + resolved_input = self.resolve_input(input) + measure_input((resolved_input,), observer=self._mod_extra_config.inputs) + output = torch.matmul(resolved_input, self.weight.transpose(-1, -2)) + measure_output((output,), self._mod_extra_config.outputs) + if self.reduce_results: + output = self.collective_func(output) + return self.post_all_reduce(output) + + def post_all_reduce(self, output): + assert ( + self.reduce_results or (not self.bias) or self.skip_bias_add + ), "When not reduce the results, adding bias to the results can lead to incorrect results" + if not self.skip_bias_add: + output = output + self.bias if self.bias is not None else output + output_bias = None + else: + output_bias = self.bias + return output, output_bias + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedColumnParallelLinear(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)): + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif isinstance(mod_extra_config.scale.params["weight"], dict): + # PCQ weight is calculated with actual weight [0] and ones [1] + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"][0]) + elif (self.quantization_mode == QuantMode.MEASURE) or 
(self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input): + qinput = self.quant_input(input) + output = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + dqoutput = self.quant_output(output) + if self.gather_output: + dqoutput = self.collective_func(dqoutput) + return self.post_all_reduce(dqoutput) + + def forward_measure(self, input): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = torch.matmul(input, self.weight.transpose(-1, -2)) + measure_output((output,), self._mod_extra_config.outputs) + if self.gather_output: + output = self.collective_func(output) + return self.post_all_reduce(output) + + def post_all_reduce(self, output): + if not self.skip_bias_add: + output = output + self.bias if self.bias is not None else output + output_bias = None + else: + output_bias = self.bias + return output, output_bias + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedLmHeadLinearAllreduce(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)): + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif isinstance(mod_extra_config.scale.params["weight"], dict): + # PCQ weight is calculated with actual weight [0] and ones [1] + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"][0]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input): + assert ( + input.shape[-1] % self.world_size == 0 + ), "Please ensure that self.world_size is divisible by input.shape[-1]" + input_shard = input.shape[-1] // self.world_size + splittedInput = input[:, :, self.rank * input_shard : (self.rank + 1) * input_shard] + qinput = self.quant_input(splittedInput) + output = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + dqoutput = self.quant_output(output) + + if self.mp_group is not None: + from deepspeed import comm as dist + + dist.inference_all_reduce(dqoutput, group=self.mp_group) + if self.bias is not None: + dqoutput += self.bias + return dqoutput + + def forward_measure(self, input): + assert ( + input.shape[-1] % self.world_size == 0 + ), "Please ensure that self.world_size is divisible by input.shape[-1]" + input_shard = input.shape[-1] // self.world_size + splittedInput = input[:, :, self.rank * input_shard : (self.rank + 1) * input_shard] + measure_input((splittedInput,), observer=self._mod_extra_config.inputs) + output = torch.matmul(splittedInput, self.weight.t()) + measure_output((output,), self._mod_extra_config.outputs) + + if self.mp_group is not None: + from deepspeed import comm as dist + + 
dist.inference_all_reduce(output, group=self.mp_group) + if self.bias is not None: + output += self.bias + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedKVCache(nn.Module): + # Module to patch KVCache module from llama model + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config, "forward", "get_shape") + self.org_allocate = mod.allocate + self.org_update = mod.update + if self.quantization_mode == QuantMode.QUANTIZE: + mod.update = self.update + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.update = self.update_measure + mod.update = self.update_measure + + # overwrite allocate function of original module to force allocation in fp8 + def allocate(self, inp_seq_len, dtype, device, shape): + dtype = torch.float8_e4m3fn if (self.quantization_mode == QuantMode.QUANTIZE) else dtype + return self.org_allocate(inp_seq_len, dtype, device, shape) + + # overwrite update function of original module to force quant and dequant of cache input and output + def update(self, prev, cur, dim, idx, inp_seq_len): + qinput = self.quant_input(cur) + output = self.org_update(prev, qinput, dim, idx, inp_seq_len) + if output.dtype == torch.float8_e4m3fn: + return self.quant_output(output) + else: + return output + + # overwrite update function of original module to force quant and dequant of cache input and output + def update_measure(self, prev, cur, dim, idx, inp_seq_len): + measure_input((cur,), self._mod_extra_config.inputs) + output = self.org_update(prev, cur, dim, idx, inp_seq_len) + measure_output((output,), self._mod_extra_config.outputs) + return output + + +class PatchedVLLMKVCache(nn.Module): + # Module to patch VLLMKVCache module from llama model + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + self.orig_fetch_from_cache = mod.fetch_from_cache + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.fetch_from_cache = mod.fetch_from_cache + self.forward = self.forward_measure + + def forward(self, input, cache, block_indices, block_offset): + qinput = self.quant_input(input) + output_cache = self.forward_orig(qinput, cache, block_indices, block_offset) + return self.quant_output(output_cache) + + def forward_measure(self, input, cache, block_indices, block_offset): + measure_input((input), self._mod_extra_config.inputs) + output_cache = self.forward_orig(input, cache, block_indices, block_offset) + measure_output((output_cache), self._mod_extra_config.outputs) + return output_cache + + def fetch_from_cache(self, cache, blocks, permutations): + quant_cache = self.quant_input(cache) + output_cache = self.orig_fetch_from_cache(quant_cache, blocks, permutations) + for i in range(len(output_cache)): + output_cache[i]=self.quant_output(output_cache[i]) + return output_cache + + +class PatchedConv2d(nn.Conv2d): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + 
set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input): + qinput = self.quant_input(input) + output = conv2d_fp8( + qinput, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + return output + + def forward_measure(self, input): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = self.forward_orig(input) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedSoftmax(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_output = self._mod_extra_config.outputs[0] + # input scale is 1 assuming the input to SM is descaled because we are using HW supported scales + self.scale_input = nn.Parameter(torch.Tensor([1.0])) + self.scale_output = nn.Parameter(torch.Tensor([1 / mod_extra_config.scale.outputs[0]])) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, x, dim=None, invAttnHead=None): + output = torch.ops.hpu.softmax_fp8(x, dim, self.scale_input, self.scale_output, invAttnHead) + return self.quant_output(output) + + def forward_measure(self, x, dim=None, invAttnHead=None): + measure_input((x,), observer=self._mod_extra_config.inputs) + output = self.forward_orig(x, dim, invAttnHead) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_output"), + ) + + +class PatchedLoRACompatibleLinear(nn.Linear): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input, scale: float = 1.0): + qinput = self.quant_input(input) + y = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + output = y + self.bias if (self.bias is not None) else y + if self.lora_layer is not None: + # TODO SW-174899 support lora layer quantization + _raise_lora_layer_error(self.class_name_org) + # output = output + (scale 
* self.lora_layer(input)) + return output + + def forward_measure(self, input, scale: float = 1.0): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = self.forward_orig(input, scale) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedLoRACompatibleConv(nn.Conv2d): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input, scale: float = 1.0): + qinput = self.quant_input(input) + if self.lora_layer is not None: + # TODO SW-174899 support lora layer quantization + _raise_lora_layer_error(self.class_name_org) + # output = conv2d_fp8(qinput, self.weight, None, self.stride, self.padding, self.dilation, self.groups, \ + # out_dtype=self._mod_extra_config.config_params["hp_dtype"], scale_input_inv=self.scale_input, scale_other_inv=self.scale_weight) + # output = output + (scale * self.lora_layer(input)) + # output = output+torch.unsqueeze(torch.unsqueeze(self.bias,1), 1) if (self.bias is not None) else output + else: + output = conv2d_fp8( + qinput, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + return output + + def forward_measure(self, input, scale: float = 1.0): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = self.forward_orig(input, scale) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedModuleFusedSDPA(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + # fsdpa is combined out of - BMM1(Q,K) -> Softmax -> BMM2(AMAX,V) + # during measure we recieve the amax value from the cguid and apply it during quant as input + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_q = self._mod_extra_config.inputs[0] + self.quant_k = self._mod_extra_config.inputs[1] + self.quant_v = self._mod_extra_config.inputs[2] + self.dequant_output = self._mod_extra_config.outputs[0] + self.scale_q = nn.Parameter(mod_extra_config.scale.inputs[0].type(torch.float32)) + self.scale_k = nn.Parameter(mod_extra_config.scale.inputs[1].type(torch.float32)) + self.scale_v = nn.Parameter(mod_extra_config.scale.inputs[2].type(torch.float32)) + self.descale_amax = nn.Parameter(mod_extra_config.scale.inputs[3].type(torch.float32)) + self.scale_output = nn.Parameter(1 / mod_extra_config.scale.outputs[0].type(torch.float32)) + self.scale_amax = nn.Parameter(1 / self.descale_amax) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = 
self.forward_measure + + def forward( + self, + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + softmax_mode="None", + ): + qinput = self.quant_q(q).detach() + kinput = self.quant_k(k).detach() + vinput = self.quant_v(v).detach() + results = fp8_fused_sdpa( + qinput, + kinput, + vinput, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + # fp8_fused_sdpa in fp8 mode supports only FastSoftmax + softmax_mode="None", + d_scale_q=self.scale_q, + d_scale_k=self.scale_k, + d_scale_v=self.scale_v, + q_scale_s=self.scale_amax, + q_scale_o=self.scale_output, + d_scale_s=self.descale_amax, + is_amax_s=False, + ) + output = results[0] + d_out = self.dequant_output(output) + return d_out + + def forward_measure( + self, + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + softmax_mode="fast", + ): + dq = q.detach() + dk = k.detach() + dv = v.detach() + measure_input((dq, dk, dv), observer=self._mod_extra_config.inputs) + results = fp8_fused_sdpa( + dq, + dk, + dv, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + # fp8_fused_sdpa in bf16 can use either FastSoftmax or regular + softmax_mode="fast", + is_amax_s=True, + ) + output = results[0] + amax = results[1] + measure_output((output, amax), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr( + self, + "scale_q", + "scale_k", + "scale_v", + "descale_amax", + "scale_amax", + "scale_output", + ), + ) + + +class PatchedUnmeasuredModule(nn.Module): + def __init__(self, name, *args, **kwargs): + super().__init__() + self.name = name + + def forward(self, *args, **kwargs): + raise Exception( + "Error - Layer '{}' was called but was not quantized because no measures were supplied.".format(self.name) + ) + + def extra_repr(self) -> str: + return f"Dummy patch of {self.name} to raise excption as there are no measurements provided." 
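All of the Patched* wrappers above share the same skeleton: set_attrs_from_orig_model copies the original module's state, QUANTIZE mode installs per-tensor input quantizers and registers fp8 scales as nn.Parameters, and MEASURE/SHAPE mode reroutes forward to a measuring variant that runs the original forward between observer calls. A minimal sketch of that skeleton for a hypothetical single-input op (PatchedMyOp and my_fp8_kernel are illustrative placeholders, not part of this patch):

    class PatchedMyOp(nn.Module):
        def __init__(self, mod, mod_extra_config, *args, **kwargs):
            super().__init__()
            set_attrs_from_orig_model(self, mod, mod_extra_config)
            if self.quantization_mode == QuantMode.QUANTIZE:
                # fp8 path: input quantizer plus its scale registered as a parameter
                self.quant_input = self._mod_extra_config.inputs[0]
                self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0])
            elif self.quantization_mode in (QuantMode.MEASURE, QuantMode.SHAPE):
                # calibration path: observe ranges around the original forward
                self.forward = self.forward_measure

        def forward(self, x):
            qx = self.quant_input(x)
            # my_fp8_kernel stands in for one of the fp8 ops above (matmul_fp8, conv2d_fp8, ...)
            return my_fp8_kernel(qx, scale_input_inv=self.scale_input,
                                 out_dtype=self._mod_extra_config.config_params["hp_dtype"])

        def forward_measure(self, x):
            measure_input((x,), observer=self._mod_extra_config.inputs)
            out = self.forward_orig(x)
            measure_output((out,), self._mod_extra_config.outputs)
            return out

The real wrappers differ mainly in how many inputs they quantize and which fp8 kernel they dispatch to (matmul_fp8, conv2d_fp8, softmax_fp8, fp8_fused_sdpa).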
diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py new file mode 100644 index 00000000000..10c94dea640 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +import json +import os +import torch +from enum import Enum, Flag, auto +from dataclasses import dataclass +from json.decoder import JSONDecodeError +from typing import Any, Mapping +import habana_frameworks.torch.utils.experimental as htexp + +from ..utils.logger import logger + +local_rank = int(os.getenv("LOCAL_RANK", "-1")) +world_size = int(os.getenv("WORLD_SIZE", "-1")) +global_rank = int(os.getenv("RANK", "-1")) + + +class QuantMode(Enum): + NONE = 0 + QUANTIZE = 1 + MEASURE = 2 + SHAPE = 3 + + +class MeasureExclude(Flag): + NONE = auto() + INPUT = auto() + OUTPUT = auto() + PARAMS = auto() + ALL = auto() + + +class ScaleMethod(Enum): + MAX = 1 + UNIT_SCALE = 2 + MAXABS_HW = 3 + MAXABS_POW2 = 4 + SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 = 5 + WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 = 6 + ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2 = 7 + ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2 = 8 + ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2 = 9 + ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2 = 10 + SMOOTHQUANT_OPT = 11 + MAXABS_HW_OPT_WEIGHT = 12 + MAXABS_POW2_OPT_WEIGHT = 13 + + +def get_hqt_config(mod) -> Fp8cfg: + return mod.__hqt_config__ + + +def set_hqt_config(mod, config): + mod.__hqt_config__ = config + + +@dataclass +class Fp8cfg: + cfg: Mapping[str, Any] + + def parse(custom_config: Mapping[str, str]) -> Fp8cfg: + measured_global_config = { + "dump_stats_path": "stats", + "fp8_config": torch.float8_e4m3fn, # The parameters of the chosen Quantization methed + "hp_dtype": torch.bfloat16, # The parameters of the chosen Quantization methed + "blocklist": { + "names": [], + "types": (), + }, # types and names to not be quantized + "allowlist": { + "names": [], + "types": ("torch.nn.Linear", "torch.nn.Conv2d", "BMM"), + }, # types and names to be quantized. Allowlist by names is not yet implemented + "mode": QuantMode.QUANTIZE, # Quantize or Measure + "scale_method": ScaleMethod.UNIT_SCALE, # Method to quantize with + "scale_params": {}, # scaling parameters that are different then the default ones + "observer": "maxabs", # Supported ['shape', 'maxabs', 'maxabs_per_channel', 'save'] + "mod_dict": {}, + "ignore_modules_wo_measures": False, # Determines whether to fail quantization on modules without existing measures or not to quantize them + "local_rank": local_rank if local_rank >= 0 else None, + "global_rank": None, + "world_size": world_size if world_size >= 0 else None, + "seperate_measure_files": True, # Determines whether to expect one or several measure files when using more than one gaudi + "device_type": htexp._get_device_type(), # Determines device type: Gaudi2, Gaudi3... 
+ "measure_exclude": MeasureExclude.OUTPUT, + } + # assert measured_global_config['allowlist']['names'] == [''], "Allowlist names not yet implemented" + + # go over all user-defined keys from json, handle various cases + for keys in custom_config: + if keys == "mode": + if custom_config[keys] == "NONE": + custom_config[keys] = QuantMode.NONE + elif custom_config[keys] == "QUANTIZE": + custom_config[keys] = QuantMode.QUANTIZE + elif custom_config[keys] == "MEASURE": + custom_config[keys] = QuantMode.MEASURE + elif custom_config[keys] == "SHAPE": + custom_config[keys] = QuantMode.SHAPE + else: + raise ValueError("invalid mode in custom config. Enter Quantize or Measure") + + if keys == "measure_exclude": + if custom_config[keys] == "NONE": + custom_config[keys] = MeasureExclude.NONE + elif custom_config[keys] == "OUTPUT": + custom_config[keys] = MeasureExclude.OUTPUT + elif custom_config[keys] == "INPUT": + custom_config[keys] = MeasureExclude.INPUT + elif custom_config[keys] == "ALL": + custom_config[keys] = MeasureExclude.ALL + else: + raise ValueError("invalid measure exclude value in custom config. Enter OUTPUT or NONE") + + if keys == "fp8_config": + if custom_config[keys].lower() == "e4m3": + custom_config[keys] = torch.float8_e4m3fn + + elif custom_config[keys].lower() == "e5m2": + custom_config[keys] = torch.float8_e5m2 + else: + raise ValueError("invalid fp8_config in custom config. Enter E4M3 or E5M2") + + if keys == "scale_method": + if custom_config[keys].lower() == "unit_scale": + custom_config[keys] = ScaleMethod.UNIT_SCALE + elif custom_config[keys].lower() == "max": + custom_config[keys] = ScaleMethod.MAX + elif custom_config[keys].lower() == "maxabs_hw": + custom_config[keys] = ScaleMethod.MAXABS_HW + elif custom_config[keys].lower() == "maxabs_pow2": + custom_config[keys] = ScaleMethod.MAXABS_POW2 + elif custom_config[keys].lower() == "maxabs_hw_opt_weight": + custom_config[keys] = ScaleMethod.MAXABS_HW_OPT_WEIGHT + elif custom_config[keys].lower() == "maxabs_pow2_opt_weight": + custom_config[keys] = ScaleMethod.MAXABS_POW2_OPT_WEIGHT + elif custom_config[keys].lower() == "smoothquant_weights_output_channel_maxabs_pow2": + custom_config[keys] = ScaleMethod.SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 + elif custom_config[keys].lower() == "weaksmoothquant_weights_output_channel_maxabs_pow2": + custom_config[keys] = ScaleMethod.WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 + elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_maxabs_pow2": + custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2 + elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_opt_pow2": + custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2 + elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_maxabs_pow2": + custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2 + elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_opt_pow2": + custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2 + elif custom_config[keys].lower() == "smoothquant_opt": + custom_config[keys] = ScaleMethod.SMOOTHQUANT_OPT + else: + raise ValueError( + f'Invalid fp8_config in custom config ({custom_config[keys]}). should be in ["max", "unit_scale", "maxabs_hw", "maxabs_pow2", "maxabs_per_channel_pow2", "smoothquant_opt"]' + ) + + if keys == "ignore_modules_wo_measures": + custom_config[keys] = custom_config[keys].lower() == "true" + + # TODO [SW-175936] - remove checking for old key names whitelist and blacklist. 
+ if isinstance(custom_config[keys], dict): + for keys_2 in custom_config[keys]: + if keys == "whitelist": + measured_global_config["allowlist"][keys_2] = custom_config[keys][keys_2] + elif keys == "blacklist": + measured_global_config["blocklist"][keys_2] = custom_config[keys][keys_2] + else: + measured_global_config[keys][keys_2] = custom_config[keys][keys_2] + else: + if keys == "whitelist": + measured_global_config["allowlist"] = custom_config[keys] + elif keys == "blacklist": + measured_global_config["blocklist"] = custom_config[keys] + else: + measured_global_config[keys] = custom_config[keys] + + # If seperate_measure_files is True (default value), then it is assumed that there are multiple distinct measure and scale files + # and they are stored in / loaded from paths with the correct index as a suffix. Else, only one is searched for. + measured_global_config["local_rank"] = ( + local_rank if local_rank >= 0 and (custom_config.get("seperate_measure_files", True) == True) else None + ) + + base_name = measured_global_config["dump_stats_path"].split("/")[-1] + folder_name = measured_global_config["dump_stats_path"][: -(len(base_name))] + measured_global_config["dump_stats_base_path"] = folder_name + os.makedirs(folder_name, exist_ok=True) + worker_st = ( + "" + if measured_global_config["local_rank"] == None + else "_" + str(measured_global_config["local_rank"]) + "_" + str(measured_global_config["world_size"]) + ) + measured_global_config["shape_file"] = measured_global_config["dump_stats_path"] + "_hooks_shape" + worker_st + measured_global_config["scale_file"] = ( + measured_global_config["dump_stats_path"] + + "_hooks_" + + measured_global_config["observer"] + + "_" + + measured_global_config["scale_method"].name + + worker_st + ) + if (measured_global_config["mode"] == QuantMode.MEASURE) or ( + measured_global_config["mode"] == QuantMode.QUANTIZE + ): + measured_global_config["measure_file"] = ( + measured_global_config["dump_stats_path"] + "_hooks_" + measured_global_config["observer"] + worker_st + ) + # measured_global_config['dump_stats_path'] += '_hooks_.json' + + logger.debug("HQT Paths:") + logger.debug("base_name='%s'", base_name) + logger.debug("folder_name='%s'", folder_name) + logger.debug( + "measured_global_config['shape_file']='%s'", + measured_global_config["shape_file"], + ) + logger.debug( + "measured_global_config['scale_file']='%s'", + measured_global_config["scale_file"], + ) + if "measure_file" in measured_global_config.keys(): + logger.debug( + "measured_global_config['measure_file']='%s'", + measured_global_config["measure_file"], + ) + logger.debug( + "measured_global_config['dump_stats_path']='%s'", + measured_global_config["dump_stats_path"], + ) + + return Fp8cfg(cfg=measured_global_config) + + +def _read_config_from_file(config_path: str) -> Mapping[str, str]: + logger.debug("QUANT PACKAGE: using %s config", config_path) + + module_directory = os.path.dirname(os.path.abspath(__file__)) + + # if file in absolute path doesn't exist, try looking in cfg directory + if not os.path.isfile(config_path): + config_path = os.path.join(module_directory, "..", f"custom_config/{config_path}.json") + try: + logger.info("QUANT PACKAGE: Loading %s", config_path) + with open(config_path) as config_json: + config = json.load(config_json) + except FileNotFoundError as e: + raise Exception(f"Got exception: {e}. QUANT PACKAGE: Can't open {config_path}!") + except JSONDecodeError as e: + config_json.close() + raise Exception(f"Got exception: {e}. 
QUANT PACKAGE: Can't load {config_path} json!") + return config diff --git a/neural_compressor/torch/algorithms/fp8_quant/common.py b/neural_compressor/torch/algorithms/fp8_quant/common.py index 4a603c677ac..ff1dc90a43f 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/common.py +++ b/neural_compressor/torch/algorithms/fp8_quant/common.py @@ -21,13 +21,15 @@ import torch +from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import finish_measurements +from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import Fp8cfg + def save_calib_result(model): - import habana_quantization_toolkit as hqt if (hasattr(model, "__hqt_config__") and - isinstance(model.__hqt_config__, hqt._quant_common.quant_config.Fp8cfg)): + isinstance(model.__hqt_config__, Fp8cfg)): # TODO SW-184714 modify hqt notation to inc notation once code is ported - hqt.finish_measurements(model) + finish_measurements(model) else: raise NotImplementedError("Saving calibration results currently supported only in HPU.") diff --git a/neural_compressor/torch/algorithms/fp8_quant/custom_config/custom_example.json b/neural_compressor/torch/algorithms/fp8_quant/custom_config/custom_example.json new file mode 100644 index 00000000000..26b8af220a7 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/custom_config/custom_example.json @@ -0,0 +1,5 @@ +{ + "mode": "MEASURE", + "scale_method": "MAX", + "fp8_config": "E4M3" +} \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_measure.json b/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_measure.json new file mode 100644 index 00000000000..fc675067c22 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_measure.json @@ -0,0 +1,14 @@ +{ + "mode": "MEASURE", + "observer": "maxabs", + "allowlist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": [] + }, + "quantize_weight": false, + "dump_stats_path": "./llama_output/7b_measure" +} \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_quant.json b/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_quant.json new file mode 100644 index 00000000000..f341964187a --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_quant.json @@ -0,0 +1,17 @@ +{ + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "allowlist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": [ + "lm_head" + ] + }, + "quantize_weight": false, + "dump_stats_path": "./llama_output/7b_measure" +} \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/custom_config/measure_config.json b/neural_compressor/torch/algorithms/fp8_quant/custom_config/measure_config.json new file mode 100755 index 00000000000..b8c4d29b781 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/custom_config/measure_config.json @@ -0,0 +1,12 @@ +{ + "mode": "MEASURE", + "scale_method": "MAX", + "quantize_weight": true, + "dump_stats_path": "./run_outputs/fp8/stats", + "allowlist": { + "types": [ + "torch.nn.Linear", + "torch.nn.Conv2d" + ] + } +} \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/custom_config/quant_config.json b/neural_compressor/torch/algorithms/fp8_quant/custom_config/quant_config.json new file mode 100755 index 00000000000..286a1632257 --- /dev/null +++ 
b/neural_compressor/torch/algorithms/fp8_quant/custom_config/quant_config.json @@ -0,0 +1,13 @@ +{ + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "fp8_config": "E4M3", + "allowlist": { + "types": [ + "torch.nn.Linear", + "torch.nn.Conv2d" + ] + }, + "dump_stats_path": "./run_outputs/fp8/stats" +} \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/fp8_quant.py b/neural_compressor/torch/algorithms/fp8_quant/fp8_quant.py index f9ce9145569..bbde53fb417 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/fp8_quant.py +++ b/neural_compressor/torch/algorithms/fp8_quant/fp8_quant.py @@ -20,6 +20,7 @@ restore_patched_module, update_mode, with_patched_module, + prep_model, ) @@ -44,12 +45,10 @@ def convert(self, model): def _convert(model, config_path): - import habana_quantization_toolkit as hqt - # update mode to QUANTIZE config_path = update_mode(config_path, quant_step=True) - return hqt.prep_model(model, config_path) + return prep_model(model, config_path) def _prepare(model, config_path): @@ -58,4 +57,4 @@ def _prepare(model, config_path): # update mode to MEASURE config_path = update_mode(config_path, measure_step=True) - return hqt.prep_model(model, config_path) + return prep_model(model, config_path) diff --git a/neural_compressor/torch/algorithms/fp8_quant/prepare_quant/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/prepare_quant/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/neural_compressor/torch/algorithms/fp8_quant/prepare_quant/prepare_model.py b/neural_compressor/torch/algorithms/fp8_quant/prepare_quant/prepare_model.py new file mode 100644 index 00000000000..8a38f79388b --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/prepare_quant/prepare_model.py @@ -0,0 +1,36 @@ +import os +from typing import Optional +from .._quant_common.quant_config import Fp8cfg +from .._core.measure import save_measurements +from .._core.utils import prepare_model +from .._quant_common.quant_config import ( + _read_config_from_file, + Fp8cfg, + set_hqt_config, +) + +def _prep_model_with_predefined_config(model, *, config: Fp8cfg): + set_hqt_config(model, config) + prepare_model(model) + + +def prep_model(model, config_path: Optional[str] = None): + """ + Prepare this model with the given (absolute or relative) path of the json file containing the configuration. + If `config_path` is not given or `None`, + instead perform the legacy behavior of checking for env variable `QUANT_CONFIG`. 
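+
+    A minimal usage sketch (the json file name is illustrative):
+
+        prep_model(model, config_path="maxabs_measure.json")  # patch modules per the config
+        # ... run calibration / inference ...
+        finish_measurements(model)  # dump the collected measurements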
+ """ + if config_path is None: + config_path = os.getenv("QUANT_CONFIG") + if config_path is None: + raise EnvironmentError( + "Either pass config_path parameter explicitly (recommended), or set environment variable QUANT_CONFIG" + ) + + config = _read_config_from_file(config_path=config_path) + config = Fp8cfg.parse(config) + return _prep_model_with_predefined_config(model, config=config) + + +def finish_measurements(model): + save_measurements(model) diff --git a/neural_compressor/torch/algorithms/fp8_quant/scripts/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/scripts/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/neural_compressor/torch/algorithms/fp8_quant/scripts/regression_detection/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/scripts/regression_detection/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/neural_compressor/torch/algorithms/fp8_quant/scripts/regression_detection/golden_metrics.json b/neural_compressor/torch/algorithms/fp8_quant/scripts/regression_detection/golden_metrics.json new file mode 100644 index 00000000000..8409f7ffb47 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/scripts/regression_detection/golden_metrics.json @@ -0,0 +1,74 @@ +{ + "bf16": { + "winogrande": { + "mean": 0.7995, + "sem": 0.0112 + }, + "hellaswag": { + "mean": 0.6529, + "sem": 0.0048 + }, + "piqa": { + "mean": 0.8166, + "sem": 0.0090 + }, + "lambada_openai": { + "mean": 0.7900, + "sem": 0.0057 + } + }, + "fp8": { + "ptq": { + "winogrande": { + "mean": 0.7948, + "sem": 0.0113, + "mean_diff": -0.0047, + "sem_diff": 0.0058 + }, + "hellaswag": { + "mean": 0.6473, + "sem": 0.0048, + "mean_diff": 0.0056, + "sem_diff": 0.0014 + }, + "piqa": { + "mean": 0.8134, + "sem": 0.0091, + "mean_diff": -0.0033, + "sem_diff": 0.0034 + }, + "lambada_openai": { + "mean": 0.7900, + "sem": 0.0057, + "mean_diff": 0.0000, + "sem_diff": 0.0021 + } + }, + "pcq": { + "winogrande": { + "mean": 0.8003, + "sem": 0.0112, + "mean_diff": 0.0008, + "sem_diff": 0.0060 + }, + "hellaswag": { + "mean": 0.6512, + "sem": 0.0048, + "mean_diff": -0.0017, + "sem_diff": 0.0010 + }, + "piqa": { + "mean": 0.8150, + "sem": 0.0091, + "mean_diff": -0.0016, + "sem_diff": 0.0031 + }, + "lambada_openai": { + "mean": 0.7920, + "sem": 0.0057, + "mean_diff": 0.0019, + "sem_diff": 0.0021 + } + } + } +} \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/scripts/regression_detection/regression_detection.py b/neural_compressor/torch/algorithms/fp8_quant/scripts/regression_detection/regression_detection.py new file mode 100644 index 00000000000..59d609a48dd --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/scripts/regression_detection/regression_detection.py @@ -0,0 +1,117 @@ +import argparse +import numpy as np +import scipy +import json + +tasks = ["winogrande", "hellaswag", "piqa", "lambada_openai"] + + +def ztest(ref_mean=0.0, ref_stderr=1.0, test_mean=0.0, test_stderr=0.0): + z_score = (test_mean - ref_mean) / np.sqrt(ref_stderr**2 + test_stderr**2) + p_value = 1.0 + scipy.special.erf(-np.abs(z_score) / np.sqrt(2)) + return p_value + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="Regression detection using Z-Test. 
We assume we have mean and SEM of golden run arranged in json and test results json and we compare the results to see if degregation occurred.", + ) + parser.add_argument( + "--hp_dtype", + type=str, + help="Data type of the high precision test", + default=None, + ) + parser.add_argument( + "--lp_dtype", + type=str, + help="Data type of the low precision test", + default=None, + ) + parser.add_argument( + "--golden_metrics", + type=str, + help="Path to json that includes mean, SEM and diff golden metrics of bf16 and fp8 precision.", + default=None, + ) + parser.add_argument( + "--test_metrics_lp", + type=str, + help="Path to json that includes mean, SEM and diff test metrics of lp precision.", + default=None, + ) + parser.add_argument( + "--test_metrics_hp", + type=str, + help="Path to json that includes mean, SEM and diff test metrics of high precision.", + default=None, + ) + parser.add_argument("--quantization_mode", type=str, help="quantization mode", default=None) + args = parser.parse_args() + mode = args.quantization_mode + hp_dtype = args.hp_dtype + lp_dtype = args.lp_dtype + golden_metrics_path = args.golden_metrics + test_metrics_lp_path = args.test_metrics_lp + test_metrics_hp_path = args.test_metrics_hp + if golden_metrics_path is None or test_metrics_hp_path is None or test_metrics_lp_path is None: + print("Please provide golden_metrics, test_metrics_hp_path and test_metrics_lp_path json paths") + exit(1) + + with open(golden_metrics_path, "r") as f: + golden_metrics_json = json.load(f) + + with open(test_metrics_lp_path, "r") as f: + test_metrics_lp_json = json.load(f) + test_metrics_lp_json = test_metrics_lp_json["results"] + + with open(test_metrics_hp_path, "r") as f: + test_metrics_hp_json = json.load(f) + test_metrics_hp_json = test_metrics_hp_json["results"] + + regressions = [] + for task in tasks: + # The two-sample z-test comparing the golden and under-test high-precision configuration + ref_mean_hp = golden_metrics_json[hp_dtype][task]["mean"] + ref_stderr_hp = golden_metrics_json[hp_dtype][task]["sem"] + test_mean_hp = test_metrics_hp_json[task]["acc"] + test_stderr_hp = test_metrics_hp_json[task]["acc_stderr"] + p_hp_value = ztest(ref_mean_hp, ref_stderr_hp, test_mean_hp, test_stderr_hp) + print(f"Z-Test high precision p-value={p_hp_value*100:.2f}% in {task} task") + if p_hp_value < 0.05: + regressions.append(f"Z-Test high precision p-value is less than 0.05 in {task} task.") + + # The two-sample z-test comparing the golden and under-test low-precision configuration + if mode != None: + ref_mean_lp = golden_metrics_json[lp_dtype][mode][task]["mean"] + ref_stderr_lp = golden_metrics_json[lp_dtype][mode][task]["sem"] + else: + ref_mean_lp = golden_metrics_json[lp_dtype][task]["mean"] + ref_stderr_lp = golden_metrics_json[lp_dtype][task]["sem"] + test_mean_lp = test_metrics_lp_json[task]["acc"] + test_stderr_lp = test_metrics_lp_json[task]["acc_stderr"] + p_lp_value = ztest(ref_mean_lp, ref_stderr_lp, test_mean_lp, test_stderr_lp) + print(f"Z-Test low precision p-value={p_lp_value*100:.2f}% in {task} task") + if p_lp_value < 0.05: + regressions.append(f"Z-Test low precision p-value is less than 0.05 in {task} task.") + + # The single-sample z-test comparing the golden and under-test degradation of low-precision configuration + if mode != None: + ref_mean_diff = golden_metrics_json[lp_dtype][mode][task]["mean_diff"] + ref_stderr_diff = golden_metrics_json[lp_dtype][mode][task]["sem_diff"] + else: + ref_mean_diff = golden_metrics_json[lp_dtype][task]["mean_diff"] + 
ref_stderr_diff = golden_metrics_json[lp_dtype][task]["sem_diff"] + test_mean_diff = test_mean_hp - test_mean_lp + p_diff_value = ztest(ref_mean_diff, ref_stderr_diff, test_mean_diff) + print(f"Z-Test low precision diff p-value={p_diff_value*100:.2f}% in {task} task") + if p_diff_value < 0.05: + regressions.append(f"Z-Test low precision diff p-value is less than 0.05 in {task} task.") + + if len(regressions) == 0: + print("No regressions were detected!") + else: + print("Regressions were detected!") + for regression in regressions: + print(regression) diff --git a/neural_compressor/torch/algorithms/fp8_quant/utils/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/utils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/neural_compressor/torch/algorithms/fp8_quant/utils/logger.py b/neural_compressor/torch/algorithms/fp8_quant/utils/logger.py new file mode 100644 index 00000000000..b4724fe31eb --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/utils/logger.py @@ -0,0 +1,240 @@ +# Taken and adjusted from neural-compressor-fork/neural_compressor/common/utils/logger.py +# Should be merged with INC logger once HQT code is inserted into INC +# TODO: SW-185347 merge INC logger with HQT logger +"""Logger: handles logging functionalities.""" + + +import logging +from logging.handlers import RotatingFileHandler +import os + +__all__ = ["logger"] + +# Define color escape codes +RESET = "\033[0m" +BOLD = "\033[1m" +UNDERLINE = "\033[4m" +WHITE = "\033[37m" +BG_RED = "\033[41m" +RED = "\033[91m" +GREEN = "\033[92m" +YELLOW = "\033[93m" +BLUE = "\033[94m" +PURPLE = "\033[95m" +CYAN = "\033[96m" + + +def _pretty_dict(value, indent=0): + """Make the logger dict pretty.""" + prefix = "\n" + " " * (indent + 4) + if isinstance(value, dict): + items = [prefix + repr(key) + ": " + _pretty_dict(value[key], indent + 4) for key in value] + return "{%s}" % (",".join(items) + "\n" + " " * indent) + elif isinstance(value, list): + items = [prefix + _pretty_dict(item, indent + 4) for item in value] + return "[%s]" % (",".join(items) + "\n" + " " * indent) + elif isinstance(value, tuple): + items = [prefix + _pretty_dict(item, indent + 4) for item in value] + return "(%s)" % (",".join(items) + "\n" + " " * indent) + else: + return repr(value) + + +logging.TRACE = 5 # There is no 'trace' level for python logger. + + +def trace(self, msg, *args, **kwargs): + """ + Log 'msg % args' with severity 'TRACE'. + + To pass exception information, use the keyword argument exc_info with + a true value, e.g. 
+ + logger.trace("Houston, we have a %s", "thorny problem", exc_info=1) + """ + if self.isEnabledFor(logging.TRACE): + self._log(logging.TRACE, msg, args, **kwargs) + + +logging.Logger.trace = trace +logging.IGNORE = 60 +logging.addLevelName(logging.TRACE, "TRACE") +logging.__all__ += ["TRACE", "trace"] + +log_levels = { + "0": logging.TRACE, # = 5 + "1": logging.DEBUG, # = 10 + "2": logging.INFO, # = 20 + "3": logging.WARNING, # = 30 + "4": logging.ERROR, # = 40 + "5": logging.CRITICAL, # = 50 + "6": logging.IGNORE, # = 60 (Disabling logger) +} +MAX_LOG_LEVEL_NAME_LEN = 8 + +DEFAULT_LOG_FILE_SIZE = 1024 * 1024 * 10 +DEFAULT_LOG_FILE_AMOUNT = 5 + + +class _Logger(object): + """_Logger class.""" + + __instance = None + + def __new__(cls): + """Create a singleton _Logger instance.""" + if _Logger.__instance is None: + _Logger.__instance = object.__new__(cls) + _Logger.__instance._init_log() + return _Logger.__instance + + def get_enable_console_val(self): + enableConsole = os.environ.get("ENABLE_CONSOLE", "False").upper() + if enableConsole not in ["TRUE", "FALSE"]: + raise Exception(f"Env var 'ENABLE_CONSOLE' has to be true or false.") + return enableConsole == "TRUE" + + def get_log_level(self): + log_level_str = os.environ.get("LOG_LEVEL_HQT", os.environ.get("LOG_LEVEL_ALL")) + if log_level_str is None: + return logging.INFO + if log_level_str not in log_levels: + raise Exception(f"Wrong Log Level value: '{log_level_str}'. Must be an integer 0-6.") + return log_levels[log_level_str] + + def prepare_logger_format(self): + # Time printing is added to format according to the value of PRINT_TIME env var. + print_time = os.environ.get("PRINT_TIME", "True") + time_format = "" if print_time.upper() in ["0", "FALSE"] else "%(asctime)s.%(msecs)06d" + return f"[{time_format}][%(name)s][%(levelname)s] %(message)s" + + # Create a formatter with lower case level name + @staticmethod + class LowercaseLevelNameFormatter(logging.Formatter): + def format(self, record): + level_name = record.levelname + record.levelname = record.levelname.lower().ljust(MAX_LOG_LEVEL_NAME_LEN) + message = super().format(record) + record.levelname = level_name + return message + + # Create a formatter with color for the console output + @staticmethod + class ColoredFormatter(logging.Formatter): + def format(self, record): + message = super().format(record) + # if record.levelname == 'TRACE': + # stays black + if record.levelname == "DEBUG": + style = CYAN + elif record.levelname == "INFO": + style = GREEN + elif record.levelname == "WARNING": + style = f"{BOLD}{YELLOW}" + elif record.levelname == "ERROR": + style = f"{BOLD}{RED}" + elif record.levelname == "CRITICAL": + style = f"{BG_RED}{BOLD}{WHITE}" + else: + return message + return message.replace( + record.levelname, + f"{style}{record.levelname.lower().ljust(MAX_LOG_LEVEL_NAME_LEN)}{RESET}", + 1, + ) + + def _init_log(self): + """Setup the logger format and handler.""" + enableConsole = self.get_enable_console_val() + self._logger = logging.getLogger("HQT") + log_level = self.get_log_level() + if log_level == logging.IGNORE: + self._logger.disabled = True + else: + # according to: swtools_sdk/hl_logger/src/hllog_core.cpp + self._logger.handlers.clear() + self._logger.setLevel(log_level) + logging_format = self.prepare_logger_format() + hls_id = int(os.getenv("HLS_ID", "-1")) + local_rank_id = int(os.getenv("ID", os.getenv("OMPI_COMM_WORLD_RANK", "-1"))) + habana_logs_path = os.getenv("HABANA_LOGS") + if habana_logs_path is None: + habana_logs_path = ( + 
"/tmp/.habana_logs" if os.getenv("HOME") is None else os.getenv("HOME") + "/.habana_logs" + ) + log_folder = f"{habana_logs_path}{''if hls_id < 0 else '/{}'.format(hls_id)}" + log_folder = f"{log_folder}{''if local_rank_id < 0 else '/{}'.format(local_rank_id)}" + try: + os.makedirs(log_folder, exist_ok=True) + except OSError as error: + print( + f"Warning: Directory '{log_folder}' can not be created for HQT logs: {error.strerror}. Logger is disabled." + ) + self._logger.disabled = True + pass + file_path = log_folder + "/hqt_log.txt" + log_file_size = int(os.getenv("HQT_LOG_FILE_SIZE", DEFAULT_LOG_FILE_SIZE)) + if log_file_size < 0: + print( + f"Warning: Log file size value is not valid [{log_file_size}]. Using default value [{DEFAULT_LOG_FILE_SIZE}]" + ) + log_file_size = DEFAULT_LOG_FILE_SIZE + log_file_amount = int(os.getenv("HQT_LOG_FILE_AMOUNT", DEFAULT_LOG_FILE_AMOUNT)) + if log_file_amount < 0: + print( + f"Warning: Log file amount value is not valid [{log_file_amount}]. Using default value [{DEFAULT_LOG_FILE_AMOUNT}]" + ) + log_file_amount = DEFAULT_LOG_FILE_AMOUNT + fileHandler = RotatingFileHandler( + file_path, backupCount=log_file_amount, maxBytes=log_file_size + ) # default mode = append ("a") + formatter = _Logger.LowercaseLevelNameFormatter(logging_format, "%Y-%m-%d %H:%M:%S") + fileHandler.setFormatter(formatter) + self._logger.addHandler(fileHandler) + if enableConsole: + import sys + + streamHandler = logging.StreamHandler(sys.stdout) + if sys.stdout.isatty(): + streamHandler.setFormatter(_Logger.ColoredFormatter(logging_format, "%Y-%m-%d %H:%M:%S")) + else: + streamHandler.setFormatter(formatter) + self._logger.addHandler(streamHandler) + self._logger.propagate = False + + def log(self, func, msg, *args, **kwargs): + kwargs.setdefault("stacklevel", 3) + if isinstance(msg, dict): + for _, line in enumerate(_pretty_dict(msg).split("\n")): + func(line, *args, **kwargs) + else: + func(msg, *args, **kwargs) + + def trace(self, msg, *args, **kwargs): + """Output log with the trace level.""" + self.log(self._logger.trace, msg, *args, **kwargs) + + def debug(self, msg, *args, **kwargs): + """Output log with the debug level.""" + self.log(self._logger.debug, msg, *args, **kwargs) + + def info(self, msg, *args, **kwargs): + """Output log with the info level.""" + self.log(self._logger.info, msg, *args, **kwargs) + + def warning(self, msg, *args, **kwargs): + """Output log with the warning level (Alias of the method warn).""" + self.log(self._logger.warning, msg, *args, **kwargs) + + def error(self, msg, *args, **kwargs): + """Output log with the error level.""" + self.log(self._logger.error, msg, *args, **kwargs) + + def critical(self, msg, *args, **kwargs): + """Output log with the critical level.""" + self.log(self._logger.critical, msg, *args, **kwargs) + + fatal = critical + + +logger = _Logger() diff --git a/test/3x/torch/algorithms/fp8_quant/__init__.py b/test/3x/torch/algorithms/fp8_quant/__init__.py new file mode 100644 index 00000000000..7fec54b4191 --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/__init__.py @@ -0,0 +1,6 @@ +from .tester import run_accuracy_test, TestVector + +__all__ = [ + "run_accuracy_test", + "TestVector", +] diff --git a/test/3x/torch/algorithms/fp8_quant/conftest.py b/test/3x/torch/algorithms/fp8_quant/conftest.py new file mode 100644 index 00000000000..3497af8b3f4 --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/conftest.py @@ -0,0 +1,12 @@ +# Called once at the beginning of the test session +def pytest_sessionstart(): + import 
habana_frameworks.torch.core as htcore + import torch + + htcore.hpu_set_env() + + # Use reproducible results + torch.use_deterministic_algorithms(True) + + # Fix the seed - just in case + torch.manual_seed(0) diff --git a/test/3x/torch/algorithms/fp8_quant/fp8_tests.py b/test/3x/torch/algorithms/fp8_quant/fp8_tests.py new file mode 100644 index 00000000000..adb9e426409 --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/fp8_tests.py @@ -0,0 +1,174 @@ +import torch + +import habana_quantization_toolkit +import habana_frameworks.torch.core as htcore + +# This file is for small tests run for debug flow and accuracy. (Not for CI) + + +class TinyBlock(torch.nn.Module): + + def __init__(self): + super(TinyBlock, self).__init__() + self.pre_linear = torch.nn.Linear(2, 1, bias=False) + self.pre_linear.weight = torch.nn.Parameter(torch.ones([1, 2])) + + def forward(self, x): + x = self.pre_linear(x) + return x + + +class TinyModel(torch.nn.Module): + + def __init__(self): + super(TinyModel, self).__init__() + self.block = TinyBlock() + + def forward(self, x): + x = self.block(x) + return x + + +class TinyBlock2(torch.nn.Module): + + def __init__(self): + super(TinyBlock2, self).__init__() + self.pre_linear = torch.nn.Linear(2, 1, bias=False) + self.pre_linear.weight = torch.nn.Parameter(torch.ones([1, 2])) + self.pre_linear2 = torch.nn.Linear(1, 1, bias=False) + self.pre_linear2.weight = torch.nn.Parameter(torch.ones([1, 1])) + + def forward(self, x): + x = self.pre_linear(x) + x = self.pre_linear2(x) + return x + + +class TinyModel2(torch.nn.Module): + + def __init__(self): + super(TinyModel2, self).__init__() + self.block = TinyBlock2() + + def forward(self, x): + x = self.block(x) + return x + + +class TinyModel3(torch.nn.Module): + + def __init__(self): + super(TinyModel3, self).__init__() + self.block = TinyBlock() + self.block2 = TinyBlock2() + + def forward(self, x, b): + if b: + x = self.block(x) + else: + x = self.block2(x) + return x + + +model = TinyModel() +model.eval() +model = model.to("hpu").to(torch.bfloat16) +htcore.hpu_initialize() +habana_quantization_toolkit.prep_model(model) # fp8 additions + + +with torch.no_grad(): + + # >>> new_fp8converted_input = (torch.tensor(MaxAbs(input), dtype=torch.bfloat16) / torch.tensor(InputScale, dtype=torch.bfloat16)).to(torch.float8_e4m3fn) + # >>> new_fp8converted_weight = (torch.tensor(MaxAbs(weight), dtype=torch.bfloat16) / torch.tensor(WeightScale, dtype=torch.bfloat16)).to(torch.float8_e4m3fn) + # >>> mul_result = new_fp8converted_weight.to(torch.bfloat16) * new_fp8converted_input.to(torch.bfloat16) + # >>> result = mul_result * torch.tensor(InputScale, dtype=torch.bfloat16) * torch.tensor(WeightScale, dtype=torch.bfloat16) + + # If the results of the first 2 lines > 240 (or nan), assume they are equal to 240. 
(In G2 or G3 with specific fp8 representation settings) + + # Run simulator: + # Gaudi2: run_coral_sim --chip-type gaudi2 -r -D 32 + # Gaudi3: run_coral_sim --chip-type gaudi3 -r -D 32 + # cd .../quantization_toolkit/habana_quantization_toolkit/tests/ + + # Test1: (Disable (comment) all other tests, delete all files from the test_outputs folder) + # Run: + # QUANT_CONFIG=test_jsons/test_measure.json python3 fp8_tests.py + # QUANT_CONFIG=test_jsons/test_hw_quant.json python3 fp8_tests.py + # QUANT_CONFIG=test_jsons/test_pow2_quant.json python3 fp8_tests.py + # QUANT_CONFIG=test_jsons/test_unit_quant.json python3 fp8_tests.py + + out_arange = model((torch.tensor([[232, 0]], dtype=torch.bfloat16)).to("hpu")) + print(out_arange) + + out_arange = model((torch.tensor([[240, 0]], dtype=torch.bfloat16)).to("hpu")) + print(out_arange) + + out_arange = model((torch.tensor([[248, 0]], dtype=torch.bfloat16)).to("hpu")) + print(out_arange) + + # Result (Same for Gaudi2 and Gaudi3): + # for HW/POW2: + # tensor([[224.]], device='hpu:0', dtype=torch.bfloat16) + # tensor([[240.]], device='hpu:0', dtype=torch.bfloat16) + # tensor([[256.]], device='hpu:0', dtype=torch.bfloat16) + # for Unit: + # tensor([[224.]], device='hpu:0', dtype=torch.bfloat16) + # tensor([[240.]], device='hpu:0', dtype=torch.bfloat16) + # tensor([[240.]], device='hpu:0', dtype=torch.bfloat16) + + # Test2: (Disable (comment) all other tests, delete all files from the test_outputs folder) + # Run: + # QUANT_CONFIG=test_jsons/test_measure.json python3 fp8_tests.py + # QUANT_CONFIG=test_jsons/test_hw_quant.json python3 fp8_tests.py + # QUANT_CONFIG=test_jsons/test_pow2_quant.json python3 fp8_tests.py + # QUANT_CONFIG=test_jsons/test_unit_quant.json python3 fp8_tests.py + + out_arange = model((torch.tensor([[3720, 0]], dtype=torch.bfloat16)).to("hpu")) + print(out_arange) + + out_arange = model((torch.tensor([[3721, 0]], dtype=torch.bfloat16)).to("hpu")) + print(out_arange) + + out_arange = model((torch.tensor([[13721, 0]], dtype=torch.bfloat16)).to("hpu")) + print(out_arange) + + # Result: + # for HW (Gaudi2): + # tensor([[3584.]], device='hpu:0', dtype=torch.bfloat16) + # tensor([[3840.]], device='hpu:0', dtype=torch.bfloat16) + # tensor([[3840.]], device='hpu:0', dtype=torch.bfloat16) + # for HW (Gaudi3) and Pow2: + # tensor([[3584.]], device='hpu:0', dtype=torch.bfloat16) + # tensor([[3840.]], device='hpu:0', dtype=torch.bfloat16) + # tensor([[13312.]], device='hpu:0', dtype=torch.bfloat16) + # for Unit: + # tensor([[240.]], device='hpu:0', dtype=torch.bfloat16) + # tensor([[240.]], device='hpu:0', dtype=torch.bfloat16) + # tensor([[240.]], device='hpu:0', dtype=torch.bfloat16) + + # Test3: (Disable (comment) all other tests, delete all files from the test_outputs folder) + # (Change Line 73 above to: model = TinyModel3()) + # Run: (add LOG_LEVEL_HQT=0/1 for additional logs) + # (Uncomment lines 164+165) + # 1) QUANT_CONFIG=test_jsons/test_measure.json python3 fp8_tests.py + # 2) QUANT_CONFIG=test_jsons/test_hw_quant.json python3 fp8_tests.py + # 3) QUANT_CONFIG=test_jsons/test_hw_quant_ignored_unmeasured_models.json python3 fp8_tests.py + # (Comment lines 164+165, Uncomment lines 166+167) + # 4) QUANT_CONFIG=test_jsons/test_hw_quant.json python3 fp8_tests.py + # 5) QUANT_CONFIG=test_jsons/test_hw_quant_ignored_unmeasured_models.json python3 fp8_tests.py + + # out_arange = model((torch.tensor([[232, 0]], dtype=torch.bfloat16)).to('hpu'), True) + # print(out_arange) + # out_arange = model((torch.tensor([[232, 0]], 
dtype=torch.bfloat16)).to('hpu'), False) + # print(out_arange) + + # Result: + # 1) tensor([[232.]], device='hpu:0', dtype=torch.bfloat16) + # 2) tensor([[224.]], device='hpu:0', dtype=torch.bfloat16) + # 3) tensor([[224.]], device='hpu:0', dtype=torch.bfloat16) + # 4) Exception: Error - Layer 'block2.pre_linear' was called but was not quantized because no measures were supplied. + # 5) tensor([[232.]], device='hpu:0', dtype=torch.bfloat16) + + # fp8 additions + habana_quantization_toolkit.finish_measurements(model) diff --git a/test/3x/torch/algorithms/fp8_quant/pytest.ini b/test/3x/torch/algorithms/fp8_quant/pytest.ini new file mode 100644 index 00000000000..e081c3c20c8 --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + deepspeed: marks tests as deepspeed (deselect with '-m "not deepspeed"') diff --git a/test/3x/torch/algorithms/fp8_quant/test_jsons/test_hw_quant.json b/test/3x/torch/algorithms/fp8_quant/test_jsons/test_hw_quant.json new file mode 100644 index 00000000000..eb4f8e8208e --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/test_jsons/test_hw_quant.json @@ -0,0 +1,16 @@ +{ + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "allowlist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": [ + "lm_head" + ] + }, + "dump_stats_path": "./test_outputs/unit_test" +} \ No newline at end of file diff --git a/test/3x/torch/algorithms/fp8_quant/test_jsons/test_hw_quant_ignored_unmeasured_models.json b/test/3x/torch/algorithms/fp8_quant/test_jsons/test_hw_quant_ignored_unmeasured_models.json new file mode 100644 index 00000000000..54a779cee7e --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/test_jsons/test_hw_quant_ignored_unmeasured_models.json @@ -0,0 +1,17 @@ +{ + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "allowlist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": [ + "lm_head" + ] + }, + "dump_stats_path": "./test_outputs/unit_test", + "ignore_modules_wo_measures": "true" +} \ No newline at end of file diff --git a/test/3x/torch/algorithms/fp8_quant/test_jsons/test_measure.json b/test/3x/torch/algorithms/fp8_quant/test_jsons/test_measure.json new file mode 100644 index 00000000000..e2743faafa7 --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/test_jsons/test_measure.json @@ -0,0 +1,13 @@ +{ + "mode": "MEASURE", + "observer": "maxabs", + "allowlist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": [] + }, + "dump_stats_path": "./test_outputs/unit_test" +} \ No newline at end of file diff --git a/test/3x/torch/algorithms/fp8_quant/test_jsons/test_pow2_quant.json b/test/3x/torch/algorithms/fp8_quant/test_jsons/test_pow2_quant.json new file mode 100644 index 00000000000..7f44824fa9d --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/test_jsons/test_pow2_quant.json @@ -0,0 +1,16 @@ +{ + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_pow2", + "allowlist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": [ + "lm_head" + ] + }, + "dump_stats_path": "./test_outputs/unit_test" +} \ No newline at end of file diff --git a/test/3x/torch/algorithms/fp8_quant/test_jsons/test_unit_quant.json b/test/3x/torch/algorithms/fp8_quant/test_jsons/test_unit_quant.json new file mode 100644 index 00000000000..60127bbad20 --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/test_jsons/test_unit_quant.json @@ -0,0 +1,16 @@ +{ + "mode": 
"QUANTIZE", + "observer": "maxabs", + "scale_method": "unit_scale", + "allowlist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": [ + "lm_head" + ] + }, + "dump_stats_path": "./test_outputs/unit_test" +} \ No newline at end of file diff --git a/test/3x/torch/algorithms/fp8_quant/tester.py b/test/3x/torch/algorithms/fp8_quant/tester.py new file mode 100644 index 00000000000..374c9ada590 --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/tester.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import itertools +import logging +import os.path +import random +import typing +from dataclasses import dataclass + +import torch + +import habana_frameworks as htcore + +from habana_quantization_toolkit._core.common import mod_default_dict + +from habana_quantization_toolkit._quant_common.quant_config import ( + Fp8cfg, + QuantMode, + ScaleMethod, +) + + +@dataclass +class TestVector: + # Mark to pytest that it is not a tester class + __test__ = False + + inputs: typing.Sequence[torch.Tensor] + atol: typing.Optional[float] = None + rtol: typing.Optional[float] = None + + +M = typing.TypeVar("M", bound=torch.nn.Module) + + +def _assert_quantized_correctly(*, reference_model: WrapModel, quantized_model: WrapModel): + """ + In quantized mode, assert the reference model is not quantized, and the quantized model is. + Otherwise, assert that both are not quantized. + """ + for reference_name in mod_default_dict.keys(): + quantized_name = mod_default_dict[reference_name].patched_module.__name__ + + assert not reference_model.has_name(quantized_name) + assert not quantized_model.has_name(reference_name), f"{reference_name=} should not be in the quantized model" + + if reference_model.has_name(reference_name): + assert quantized_model.has_name(quantized_name), f"{quantized_name=} should be in the quantized model" + + +def run_accuracy_test( + *, + module_class: typing.Type[M], + module_args: typing.Sequence = (), + module_kwargs: typing.Mapping = {}, + lp_dtype: torch.dtype, + scale_method: ScaleMethod, + measure_vectors: typing.Optional[typing.Iterable[TestVector]] = None, + test_vectors: typing.Iterable[TestVector], + seed: typing.Optional[int] = None, +): + """ + Run both the reference and the quantized versions of this module, + and compare the outputs on every test vector. + + First the measure vectors are used for measurements. + + This test also makes asserts the quantization actually happened. + This may be moved to another tests in the future. + + You can use the generate_test_vectors.py script to generate input test vectors. + + Args: + module_class: The reference module class to test. + This should be the direct module to test, e.g. Matmul, Linear, etc. + module_args: The positional arguments to pass to the module constructor. Default is empty. + module_kwargs: The keyword arguments to pass to the module constructor. Default is empty. + lp_dtype: The dtype to quantize to. + scale_method: The scaling method to use. + measure_vectors: An iterable of vectors, each contains a sequence of inputs. + If not given, `itertools.tee()` for `test_vectors` will be used. + That is, all the test vectors will be used for the measurements. + test_vectors: An iterable of test vectors, each contains a sequence of inputs and tolerance + seed: The random seed to use. If not given, will use a default seed derived from the module name. 
+ """ + + # If no measure vectors given - use the same dataset as for the test vectors + # Use `help(itertools.tee)` for more info + if measure_vectors is None: + measure_vectors, test_vectors = itertools.tee(test_vectors) + + for mode in [QuantMode.MEASURE, QuantMode.QUANTIZE]: + import habana_quantization_toolkit.prepare_quant.prepare_model as hqt + + reference_model = WrapModel(module_class, seed, *module_args, **module_kwargs) + quantized_model = WrapModel(module_class, seed, *module_args, **module_kwargs) + + config = _get_test_only_config( + mode=mode, + lp_dtype=lp_dtype, + scale_method=scale_method, + ) + hqt._prep_model_with_predefined_config(quantized_model, config=config) + + _assert_quantized_correctly(reference_model=reference_model, quantized_model=quantized_model) + + vectors = { + QuantMode.MEASURE: measure_vectors, + QuantMode.QUANTIZE: test_vectors, + }[mode] + + for vector in vectors: + reference_output = reference_model(*(input.clone() for input in vector.inputs)).to(float) + quantized_output = quantized_model(*(input.clone() for input in vector.inputs)).to(float) + + # Override tolerance values given by the caller + tolerance = { + key: getattr(vector, key) for key in ["atol", "rtol"] if getattr(vector, key, None) is not None + } + + # Accuracy check against the reference module + assert torch.allclose(reference_output, quantized_output, **tolerance), ( + f"Test vector fails in accuracy test: " + f"\n inputs={vector.inputs}" + f"\n {reference_output=}" + f"\n {quantized_output=}" + f"\n {lp_dtype=}" + f"\n {scale_method.name=}" + ) + + hqt.finish_measurements(quantized_model) + + +def _set_optional_seed(*, module_class: typing.Type[M], seed: typing.Optional[int]): + """ + Set random seed to a unique reproducible value derived from the module. + + Args: + module_class: The module class to test. + This should be the direct module to test, e.g. Matmul, Linear, etc. + seed: The random seed to use. If not given, will use a default seed derived from the module name. + """ + if seed is None: + import hashlib + + # We use sha256 to ensure a deterministic has, as opposed to `builtins.hash`, which sadly is not so. + seed = int.from_bytes( + bytes=hashlib.sha256(module_class.__name__.encode("utf-8")).digest()[:4], + byteorder="big", + ) + + logging.info(f"Using {seed=}") + + random.seed(seed) + torch.manual_seed(seed) + + +class WrapModel(torch.nn.Module): + """ + Wrap an inner module. + If we do not wrap the inner module, it will not be quantized properly. + + Maybe we can change this behavior in the future. + """ + + def __init__( + self, + module_class: typing.Type[M], + seed: typing.Optional[int], + /, + *args, + **kwargs, + ): + super().__init__() + _set_optional_seed(module_class=module_class, seed=seed) + self.inner = module_class(*args, **kwargs) + + def forward(self, *args, **kwargs): + return self.inner(*args, **kwargs) + + def has_name(self, module_name: str) -> bool: + return any(module._get_name() == module_name for module in self.modules()) + + +TEST_ONLY_OUTPUT_DIRECTORY = f"habana_quantization_toolkit/tests/output/" + + +def get_test_unique_dump_path(): + # This is a unique id of the test including the parameters, thanks to pytest. 
+ # TODO: make sure this globally-ever unique (probably add global init timestamp) + unique_test_id = os.environ.get("PYTEST_CURRENT_TEST") + return os.path.join(TEST_ONLY_OUTPUT_DIRECTORY, unique_test_id) + + +def _get_test_only_config( + *, + mode: QuantMode, + scale_method: ScaleMethod, + lp_dtype: torch.dtype, +) -> Fp8cfg: + """ + Should NOT be used externally. + + Return a new config used only for the tests. + """ + + # TODO: replace this with a version that does not use strings but direct values. + # It is currently needed because of how Fp8cfg.parse() works. + return Fp8cfg.parse( + { + "method": "HOOKS", + "mode": mode.name, + "observer": "maxabs", + "fp8_config": str(lp_dtype).replace("torch.float8_", "")[:4], + "scale_method": scale_method.name, + "dump_stats_path": get_test_unique_dump_path(), + } + ) diff --git a/test/3x/torch/algorithms/fp8_quant/unit_tests/__init__.py b/test/3x/torch/algorithms/fp8_quant/unit_tests/__init__.py new file mode 100644 index 00000000000..2516c4e1ef6 --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/unit_tests/__init__.py @@ -0,0 +1,6 @@ +""" +The unit_test package contains a `test_.py` file for every module supported +in the habana quantization toolkit. + +To run use `pytest`. +""" diff --git a/test/3x/torch/algorithms/fp8_quant/unit_tests/test_deepspeed.py b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_deepspeed.py new file mode 100644 index 00000000000..f0fe3ffcfff --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_deepspeed.py @@ -0,0 +1,86 @@ +import os +import typing + +import pytest +import torch + +from habana_quantization_toolkit._quant_common.quant_config import ScaleMethod + +from habana_quantization_toolkit.tests import run_accuracy_test, TestVector + + +class LinearBlock(torch.nn.Module): + def __init__(self): + super(LinearBlock, self).__init__() + self.linear_ = torch.nn.Linear(2, 2, bias=True) + self.linear_.weight = torch.nn.Parameter(torch.arange(0.0, 4.0).reshape(2, 2)) + self.linear_.bias = torch.nn.Parameter(torch.zeros(2)) + + def forward(self, x): + return self.linear_(x) + + +class TinyBlock(torch.nn.Module): + def __init__(self): + super(TinyBlock, self).__init__() + self.pre_linear = torch.nn.Linear(2, 2, bias=False) + self.pre_linear.weight = torch.nn.Parameter(torch.ones((2, 2)) / 4) + + self.linear1 = LinearBlock() + self.post_linear = torch.nn.Linear(2, 2, bias=False) + self.post_linear.weight = torch.nn.Parameter(torch.ones((2, 2)) / 4) + self.linear2 = LinearBlock() + + def forward(self, x): + x = self.pre_linear(x) + x = self.linear1(x) + x = self.post_linear(x) + x = self.linear2(x) + x = x.sum() + return x + + +class TinyModel(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + + block = TinyBlock() + + # no kernel inject - currently only works on Habana's DeepSpeed fork! + # these layers will be switched to LinearAllReduce. + injection_policy = {TinyBlock: ("linear1.linear_", "linear2.linear_")} + + # Initialize deepspeed on model creation + import deepspeed + block = deepspeed.init_inference( + block, + injection_policy=injection_policy, + **kwargs, + ) + self.block = block.module + + def forward(self, x): + return self.block(x) + + +def get_test_vectors(dtype: torch.dtype) -> typing.Iterable[TestVector]: + yield TestVector( + inputs=[torch.ones(1, 2).to(device="hpu", dtype=dtype)], + ) + + +# @pytest.mark.parametrize("hp_dtype", [torch.bfloat16, torch.float32]) +# TODO: float32 doesn't work - WHY? 
+# TODO: add ticket +@pytest.mark.deepspeed +@pytest.mark.parametrize("hp_dtype", [torch.bfloat16]) +@pytest.mark.parametrize("lp_dtype", [torch.float8_e4m3fn]) +def test_deepspeed_accuracy(hp_dtype: torch.dtype, lp_dtype: torch.dtype): + world_size = 1 + run_accuracy_test( + module_class=TinyModel, + test_vectors=get_test_vectors(dtype=hp_dtype), + lp_dtype=lp_dtype, + scale_method=ScaleMethod.MAXABS_HW, + module_kwargs={"dtype": hp_dtype, "mp_size": world_size}, + ) diff --git a/test/3x/torch/algorithms/fp8_quant/unit_tests/test_functions/test_config_json.py b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_functions/test_config_json.py new file mode 100644 index 00000000000..502aaeb457d --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_functions/test_config_json.py @@ -0,0 +1,29 @@ +""" +Use this module as an example of how to write new unit tests for layers. +""" + +import torch + +import habana_quantization_toolkit as hqt + +from habana_quantization_toolkit._quant_common.quant_config import QuantMode +from habana_quantization_toolkit._quant_common.helper_modules import Matmul + + +class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.inner = Matmul() + + +def test_config_json(): + model = Model() + + for mode in [QuantMode.MEASURE, QuantMode.QUANTIZE]: + name = { + QuantMode.MEASURE: "measure", + QuantMode.QUANTIZE: "quant", + }[mode] + config_path = f"llama_{name}" + hqt.prep_model(model, config_path=config_path) + hqt.finish_measurements(model) diff --git a/test/3x/torch/algorithms/fp8_quant/unit_tests/test_functions/test_matmul_fp8.py b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_functions/test_matmul_fp8.py new file mode 100644 index 00000000000..97151d54cd8 --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_functions/test_matmul_fp8.py @@ -0,0 +1,71 @@ +import itertools +import pytest +import torch + +from typing import Iterable, Tuple +from habana_quantization_toolkit._core.fp_utils import FP8_143_SCALES +from habana_quantization_toolkit._quant_common.helper_modules import matmul_fp8 +import habana_frameworks.torch.utils.experimental as htexp + + +def run_test_matmul_fp8( + *, + hp_dtype: torch.dtype, + lp_dtype: torch.dtype, + scales: Tuple[float, float], +): + torch.manual_seed(0) + x = torch.randn(2, 2, dtype=float).clone() + y = torch.randn(2, 2, dtype=float).clone() + + x_scale, y_scale = scales + expected_result = (torch.matmul(x, y) / x_scale / y_scale).to(dtype=hp_dtype) + + result = matmul_fp8( + input=x.to(device="hpu").to(dtype=lp_dtype), + other=y.to(device="hpu").to(dtype=lp_dtype), + out_dtype=hp_dtype, + scale_input_inv=1 / x_scale, + scale_other_inv=1 / y_scale, + ) + + assert torch.allclose(expected_result, result, rtol=0.1), f"Matmul failed for {x_scale=} {y_scale=}" + + +def get_fp8_143_scales(): + device_type = htexp._get_device_type() + return FP8_143_SCALES[device_type] + + +def get_scales_pairs_not_both_hw_aligned() -> Iterable[Tuple[float, float]]: + not_hw_aligned_scales = [0.25] + + return itertools.chain( + zip(not_hw_aligned_scales, not_hw_aligned_scales), + zip(not_hw_aligned_scales, get_fp8_143_scales()), + zip(get_fp8_143_scales(), not_hw_aligned_scales), + ) + + +def get_scales_pairs_both_hw_aligned() -> Iterable[Tuple[float, float]]: + return zip(get_fp8_143_scales(), get_fp8_143_scales()) + + +@pytest.mark.parametrize("hp_dtype", [torch.bfloat16]) +@pytest.mark.parametrize("lp_dtype", [torch.float8_e4m3fn]) +def test_matmul_fp8_not_both_hw_aligned( + hp_dtype: 
torch.dtype, + lp_dtype: torch.dtype, +): + for scales in get_scales_pairs_not_both_hw_aligned(): + run_test_matmul_fp8(hp_dtype=hp_dtype, lp_dtype=lp_dtype, scales=scales) + + +@pytest.mark.parametrize("hp_dtype", [torch.bfloat16]) +@pytest.mark.parametrize("lp_dtype", [torch.float8_e4m3fn]) +def test_matmul_fp8_both_hw_aligned( + hp_dtype: torch.dtype, + lp_dtype: torch.dtype, +): + for scales in get_scales_pairs_both_hw_aligned(): + run_test_matmul_fp8(hp_dtype=hp_dtype, lp_dtype=lp_dtype, scales=scales) diff --git a/test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_conv2d.py b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_conv2d.py new file mode 100644 index 00000000000..6994bf437ca --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_conv2d.py @@ -0,0 +1,40 @@ +import typing + +import pytest +import torch +from habana_quantization_toolkit._quant_common.quant_config import ScaleMethod + +from habana_quantization_toolkit.tests import run_accuracy_test, TestVector + + +def get_test_vectors(*, dtype: torch.dtype, C_in: int, H: int, W: int) -> typing.Iterable[TestVector]: + yield TestVector( + inputs=[torch.ones(1, C_in, H, W, dtype=dtype, device="hpu")], + atol=0.2, + ) + + +@pytest.mark.parametrize("hp_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("lp_dtype", [torch.float8_e4m3fn]) +def test_conv2d_accuracy(hp_dtype: torch.dtype, lp_dtype: torch.dtype): + C_in = 1 + C_out = 1 + K = 3 + + H = W = 8 + + run_accuracy_test( + module_class=torch.nn.Conv2d, + module_kwargs={ + "in_channels": C_in, + "out_channels": C_out, + "kernel_size": K, + "padding": 1, + "bias": False, + "device": "hpu", + "dtype": hp_dtype, + }, + lp_dtype=lp_dtype, + scale_method=ScaleMethod.MAXABS_HW, + test_vectors=get_test_vectors(dtype=hp_dtype, C_in=C_in, H=H, W=W), + ) diff --git a/test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_linear.py b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_linear.py new file mode 100644 index 00000000000..528b5d9358d --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_linear.py @@ -0,0 +1,33 @@ +import typing + +import pytest +import torch +from habana_quantization_toolkit._quant_common.quant_config import ScaleMethod + +from habana_quantization_toolkit.tests import run_accuracy_test, TestVector + + +def get_test_vectors(*, dtype: torch.dtype, N: int, D_in: int) -> typing.Iterable[TestVector]: + yield TestVector( + inputs=[torch.ones(N, D_in, dtype=dtype, device="hpu")], + atol=0.02, + ) + + +@pytest.mark.parametrize("hp_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("lp_dtype", [torch.float8_e4m3fn]) +def test_linear_accuracy(hp_dtype: torch.dtype, lp_dtype: torch.dtype): + N = 1 + D_in = 8 + H = 5 + run_accuracy_test( + module_class=torch.nn.Linear, + module_kwargs={ + "in_features": D_in, + "out_features": H, + "bias": False, + }, + lp_dtype=lp_dtype, + scale_method=ScaleMethod.MAXABS_HW, + test_vectors=get_test_vectors(dtype=hp_dtype, N=N, D_in=D_in), + ) diff --git a/test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_matmul.py b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_matmul.py new file mode 100644 index 00000000000..86ae332b311 --- /dev/null +++ b/test/3x/torch/algorithms/fp8_quant/unit_tests/test_layers/test_matmul.py @@ -0,0 +1,56 @@ +import typing + +import pytest +import torch +from habana_quantization_toolkit._quant_common.quant_config import ScaleMethod + +from 
habana_quantization_toolkit.tests import run_accuracy_test, TestVector + + +def get_test_vectors(*, dtype: torch.dtype) -> typing.Iterable[TestVector]: + yield TestVector( + inputs=[ + torch.eye(2, dtype=dtype, device="hpu"), + torch.eye(2, dtype=dtype, device="hpu"), + ], + atol=0.2, + ) + yield TestVector( + inputs=[ + torch.randn((2, 2), dtype=dtype, device="hpu"), + torch.randn((2, 2), dtype=dtype, device="hpu"), + ], + atol=0.2, + ) + yield TestVector( + inputs=[ + torch.eye(2, dtype=dtype, device="hpu"), + torch.randn((2, 2), dtype=dtype, device="hpu"), + ], + atol=0.2, + ) + + +class Matmul(torch.nn.Module): + """ + This is a mimic of other implementations of `Matmul`. + It is here to not create a dependency on optimum-habana (which is logically needed). + It should not be used directly in user code. + """ + + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +@pytest.mark.parametrize("hp_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("lp_dtype", [torch.float8_e4m3fn]) +def test_matmul_accuracy(hp_dtype: torch.dtype, lp_dtype: torch.dtype): + run_accuracy_test( + module_class=Matmul, + lp_dtype=lp_dtype, + scale_method=ScaleMethod.MAXABS_HW, + test_vectors=get_test_vectors(dtype=hp_dtype), + ) From 96bffd97a94412b43566105dfe0b8007525021dd Mon Sep 17 00:00:00 2001 From: Uri Livne Date: Sun, 7 Jul 2024 18:23:30 +0300 Subject: [PATCH 07/51] [SW-184714] Add internal folder to fp8 quant This is a folder used for experiments, not to be used by users Change-Id: I9e221ae582794e304e95392c0f37638f7bce69bc --- .../internal/diffusion_evaluation/README | 32 + .../SR_evaluation/README.md | 37 + .../SR_evaluation/create_SR_dataset.py | 87 ++ .../imagenet1000_clsidx_to_labels.txt | 1000 +++++++++++++++++ .../SR_evaluation/super_res_eval.py | 70 ++ .../diffusion_evaluation/create_dataset.py | 90 ++ .../diffusion_evaluation/evaluator.py | 102 ++ .../imagenet_quant.py | 75 ++ .../inference_quant_examples/run_example.sh | 5 + 9 files changed, 1498 insertions(+) create mode 100644 neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/README create mode 100644 neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/SR_evaluation/README.md create mode 100644 neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/SR_evaluation/create_SR_dataset.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/SR_evaluation/imagenet1000_clsidx_to_labels.txt create mode 100644 neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/SR_evaluation/super_res_eval.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/create_dataset.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/evaluator.py create mode 100644 neural_compressor/torch/algorithms/fp8_quant/internal/inference_quant_examples/imagenet_quant.py create mode 100755 neural_compressor/torch/algorithms/fp8_quant/internal/inference_quant_examples/run_example.sh diff --git a/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/README b/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/README new file mode 100644 index 00000000000..3a71ca76a5e --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/README @@ -0,0 +1,32 @@ +How to calculate FID and clip score: + +We will use the MS-COCO database. 
We use this for two things:
+- Generating a large number of prompts which we can use to create diffusion images
+- Once we have diffusion images, we need a "ground truth" dataset to calculate the FID.
+
+1) Run a Python script which does the following:
+   - Takes a subset of MS-COCO
+   - Creates a CSV with prompts which can then be fed to the diffusion model. These prompts are taken from captions of the images in the subset
+   - Creates a new folder with the images from the subset
+   - The standard number of images for this evaluation is 30K or 10K
+
+Run the following:
+
+python create_dataset.py /datasets/coco2014
+
+Now, create the generated images from the CSV file.
+
+IMPORTANT!! - the script that does the actual evaluation (explained below) expects to get an image whose title is the prompt. For example, if the prompt is "a monster playing the guitar" then the name of the file created using diffusion should be "/a monster playing the guitar.png" (or jpg, etc.)
+
+IMPORTANT!! #2 - from my experience, stable diffusion inference returns an error for prompts containing the character '/'. There are very few, around one in a thousand. My recommendation: if you want to evaluate N images, create a subset of size N+30 and delete prompts with '/' in them. After creating the CSV I just deleted these prompts manually (takes 10 seconds to do).
+(Perhaps automating this should be a future commit.)
+
+2) Now, run the evaluation script. This does the following:
+- Calculates the CLIP score - takes the CLIP embedding of each generated image and the embedding of the caption that created it (in this case each image and its file name), then calculates the cosine distance between them.
+- Calculates the FID - takes the real and generated images and computes the FID distance metric.
+- Insert the number of images to evaluate - this can be the number of images in the subset created above, or fewer.
+
+To do this, run:
+
+python evaluator.py --device hpu --real_images_path /datasets/coco2014/val2014 --diff_images_path --num_of_images
+
diff --git a/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/SR_evaluation/README.md b/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/SR_evaluation/README.md
new file mode 100644
index 00000000000..208a9b8d81c
--- /dev/null
+++ b/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/SR_evaluation/README.md
@@ -0,0 +1,37 @@
+How to calculate PSNR and SSIM for Super Resolution
+We will use the ImageNet validation dataset.
+
+The evaluation is done in the following steps:
+1) We take the ImageNet validation set, which has 50,000 images (we can also take a subset).
+2) Crop these images to 256*256 (center cropped) and save them as the "ground truth" dataset. The name of the saved image is its label.
+3) Downsample the images to 64*64 (using bicubic interpolation) and then restore them using Super Resolution.
+4) Calculate PSNR and SSIM between each ground truth image and its restored image, and print the mean.
+
+Steps 1, 2 and 4 are included here, while step 3 (downsampling and restoring) should be done separately, using the desired Super Resolution method. Keep in mind that this script assumes that the images are stored in a specific format (detailed later). The restored images path should later be given as an input to step 4.
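+
+As a reference, step 4 boils down to something like the following sketch (assuming scikit-image >= 0.19 and Pillow are installed; the function and path handling are illustrative, not the exact script):
+
+    import numpy as np
+    from PIL import Image
+    from skimage.metrics import peak_signal_noise_ratio, structural_similarity
+
+    def mean_psnr_ssim(gt_paths, restored_paths):
+        # gt_paths[i] and restored_paths[i] are assumed to point to the same image (256*256, RGB)
+        psnrs, ssims = [], []
+        for gt_path, sr_path in zip(gt_paths, restored_paths):
+            gt = np.array(Image.open(gt_path).convert("RGB"))
+            sr = np.array(Image.open(sr_path).convert("RGB"))
+            psnrs.append(peak_signal_noise_ratio(gt, sr, data_range=255))
+            ssims.append(structural_similarity(gt, sr, channel_axis=-1, data_range=255))
+        return float(np.mean(psnrs)), float(np.mean(ssims))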
+
+You can skip steps 1 and 2 and use the images at /datasets/imagenet/val_cropped_labeled.
+You can also run a Python script which does the following to the ImageNet validation dataset:
+ - Crops images to 256*256 (this can also be changed using the --resize argument; 256*256 is the default)
+ - Saves the images with the convention /