From 4011c184aed6929dcbd420aea368bef0c33003aa Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Mon, 16 Dec 2024 16:00:28 +0800
Subject: [PATCH 1/3] use new fused layer norm

---
 python/llm/src/ipex_llm/transformers/convert.py |  7 +++----
 .../src/ipex_llm/transformers/models/bloom.py   | 17 -----------------
 .../src/ipex_llm/transformers/models/common.py  | 17 ++++++++++++++++-
 3 files changed, 19 insertions(+), 22 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 4f1b0d3d63f..ada7be90ffd 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1296,10 +1296,9 @@ def _optimize_post(model, lightweight_bmm=False):
     trans_version = transformers.__version__
 
     # convert all nn.LayerNorm
-    from ipex_llm.transformers.models.bloom import bloom_layer_norm_forward
-    convert_forward(model,
-                    nn.LayerNorm,
-                    bloom_layer_norm_forward)
+    from ipex_llm.transformers.models.common import layer_norm_forward
+    convert_forward(model, nn.LayerNorm, layer_norm_forward)
+
     from ipex_llm.transformers.models.llama import llama_rms_norm_forward
     from ipex_llm.transformers.models.llama import llama_mlp_forward
 

diff --git a/python/llm/src/ipex_llm/transformers/models/bloom.py b/python/llm/src/ipex_llm/transformers/models/bloom.py
index 54a6d052903..4967aa1897c 100644
--- a/python/llm/src/ipex_llm/transformers/models/bloom.py
+++ b/python/llm/src/ipex_llm/transformers/models/bloom.py
@@ -64,23 +64,6 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
     return out
 
 
-def bloom_layer_norm_forward(self, hidden_states):
-    if use_fused_layer_norm(hidden_states, self.training):
-        import xe_addons
-        result = xe_addons.fused_layer_norm(hidden_states,
-                                            [self.weight.size(0)],
-                                            self.weight,
-                                            self.bias,
-                                            self.eps)
-        # if nelement == 0, means fused norm failed, go back to python implement.
-        if result.nelement != 0:
-            return result
-    input_dtype = hidden_states.dtype
-    result = F.layer_norm(hidden_states.to(self.weight.dtype),
-                          self.normalized_shape, self.weight, self.bias, self.eps)
-    return result.to(input_dtype)
-
-
 def bloom_attention_forward(
     self,
     hidden_states: torch.Tensor,

diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 5b7de52caab..c17fa3b59f0 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
+import math
 import torch
 from typing import List
 
@@ -159,7 +160,7 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
     else:
         eps = self.epsilon
 
-    if hidden_states.device.type == 'xpu':
+    if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
         import xe_addons
         x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
         output = xe_addons.rms_norm(weight, x_2d, eps)
@@ -169,3 +170,17 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
     variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
     hidden_states = hidden_states * torch.rsqrt(variance + eps)
     return weight * hidden_states.to(input_dtype)
+
+
+def layer_norm_forward(self, hidden_states: torch.Tensor):
+    if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
+        import xe_addons
+        hidden_size = math.prod(self.normalized_shape)
+        x_2d = hidden_states.reshape(-1, hidden_size).contiguous()
+        output = xe_addons.layer_norm(x_2d, self.weight, self.bias, self.eps)
+        return output.reshape(hidden_states.shape)
+    else:
+        return torch.nn.functional.layer_norm(
+            input, self.normalized_shape,
+            self.weight, self.bias, self.eps
+        )
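
Note on the mechanism: convert_forward(model, nn.LayerNorm, layer_norm_forward) rebinds the forward method of every matching submodule, so all nn.LayerNorm instances (not only BLOOM's) now dispatch to the shared implementation in models/common.py; math.prod(self.normalized_shape) lets the fused kernel handle multi-dimensional normalized shapes by flattening them into a single trailing dimension. A minimal sketch of the rebinding idea (not ipex_llm's exact implementation; the CPU stand-in below skips the xe_addons XPU branch):

    import types
    import torch
    import torch.nn as nn

    def convert_forward_sketch(model: nn.Module, target: type, new_forward):
        # Rebind `forward` on every submodule of the target type.
        for module in model.modules():
            if isinstance(module, target):
                module.forward = types.MethodType(new_forward, module)

    def layer_norm_forward_sketch(self, hidden_states: torch.Tensor):
        # CPU stand-in: the real patched forward calls xe_addons.layer_norm
        # for float/half tensors on XPU and falls back to this path otherwise.
        return nn.functional.layer_norm(
            hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)

    model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
    convert_forward_sketch(model, nn.LayerNorm, layer_norm_forward_sketch)
    print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
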
From cf71acdfd0ace3f8059a59727cd5bd8b9261834a Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Mon, 16 Dec 2024 16:02:32 +0800
Subject: [PATCH 2/3] fix

---
 python/llm/src/ipex_llm/transformers/models/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index c17fa3b59f0..4c2e830cbdd 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -181,6 +181,6 @@ def layer_norm_forward(self, hidden_states: torch.Tensor):
         return output.reshape(hidden_states.shape)
     else:
         return torch.nn.functional.layer_norm(
-            input, self.normalized_shape,
+            hidden_states, self.normalized_shape,
             self.weight, self.bias, self.eps
         )
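
The bug fixed here: in the fallback branch of patch 1, `input` referred to Python's builtin function rather than the activation tensor (so it raised no NameError), and any non-XPU or unsupported-dtype input failed with "TypeError: layer_norm(): argument 'input' ... must be Tensor". A quick sanity check of the fixed fallback path, assuming an environment where ipex_llm with these patches is installed (on a CPU tensor the xe_addons branch is skipped):

    import torch
    from ipex_llm.transformers.models.common import layer_norm_forward

    norm = torch.nn.LayerNorm(64)
    x = torch.randn(2, 8, 64)
    # Call the patched forward directly, passing the module as `self`;
    # with a CPU tensor this exercises the fixed fallback branch.
    out = layer_norm_forward(norm, x)
    print(torch.allclose(out, norm(x)))  # True
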
From f359b967fe7b78a90d336cb093b46e7f790b1e3f Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 17 Dec 2024 10:31:32 +0800
Subject: [PATCH 3/3] update

---
 .../test_transformers_api_layernorm.py | 38 +++++++++----------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/python/llm/test/inference_gpu/test_transformers_api_layernorm.py b/python/llm/test/inference_gpu/test_transformers_api_layernorm.py
index 68a15d8a579..b0cd8a178f7 100644
--- a/python/llm/test/inference_gpu/test_transformers_api_layernorm.py
+++ b/python/llm/test/inference_gpu/test_transformers_api_layernorm.py
@@ -13,39 +13,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-    
+
 import os
 import pytest
 import gc
-    
+
 import torch
 from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
 from transformers import LlamaTokenizer, AutoTokenizer
-    
+
 device = os.environ['DEVICE']
 print(f'Running on {device}')
-    
+
 PROMPT = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
 TEST_MODEL_LIST = [
     ("Falcon-7B", AutoModelForCausalLM, AutoTokenizer, os.environ.get('FALCON_7B_ORIGIN_PATH'))
 ]
-    
+
 class Test_Optimize_Gpu_Model:
     def setup_method(self):
         self.layer_outputs = []
         self.pre_layer_outputs = []
-    
+
     def run_optimize_gpu_model(self, Name, Model, Tokenizer, model_path, LayerNorm_layer, layer_before_LayerNorm, lower_bound):
         with torch.inference_mode():
            def pre_forward_hook(module, input, output, layer_name):
                self.pre_layer_outputs.append(output)
-    
+
            def forward_hook(module, input, output, layer_name):
                self.layer_outputs.append(output)
-    
+
            tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
            input_ids = tokenizer.encode(PROMPT, return_tensors="pt").to(device)
-    
+
            model = Model.from_pretrained(model_path,
                                          load_in_4bit=True,
                                          optimize_model=False,
@@ -64,18 +64,18 @@ def forward_hook(module, input, output, layer_name):
            # the list `layer_output` has only one element.
            layer_tensor = self.layer_outputs.pop()
            model.to('cpu')
-    
+
            opt_model = Model.from_pretrained(model_path,
                                              load_in_4bit=True,
                                              optimize_model=True,
                                              trust_remote_code=True)
            opt_model = opt_model.to(device)
-    
-    
+
+
            def replace_forward_hook(module, input, output, layer_name):
                output = self.pre_layer_outputs[0]
                return output
-    
+
            for layer_name, layer_module in opt_model.named_modules():
                if layer_name == layer_before_LayerNorm:
                    layer_module.register_forward_hook(
@@ -89,12 +89,12 @@ def replace_forward_hook(module, input, output, layer_name):
            # the list `layer_output` has only one element.
            opt_layer_tensor = self.layer_outputs[0]
            opt_model.to('cpu')
-    
-    
+
+
            LayerNorm_output_diff = []
            for i, (t1, t2) in enumerate(zip(layer_tensor, opt_layer_tensor)):
                LayerNorm_output_diff.append(t1 - t2)
-    
+
            max_diff_tensor = [torch.max(item).item() for item in LayerNorm_output_diff]
            print(max_diff_tensor)
            torch.xpu.empty_cache()
@@ -102,16 +102,16 @@ def replace_forward_hook(module, input, output, layer_name):
            del model
            del opt_model
            gc.collect()
            assert all(max_diff <= lower_bound for max_diff in max_diff_tensor)
-    
+
    @pytest.mark.parametrize('Name, Model, Tokenizer, model_path',TEST_MODEL_LIST)
    def test_dynamic_functions(self, Name, Model, Tokenizer, model_path):
        if Name == "Falcon-7B":
            self.Falcon_7B_gpu_model(Name, Model, Tokenizer, model_path)
-    
+
    def Falcon_7B_gpu_model(self, Name, Model, Tokenizer, model_path):
        # currently only compare the output of the last LayerNorm layer.
        layer_before_LayerNorm = "transformer.h.30"
        LayerNorm_layer = "transformer.h.31.input_layernorm"
-       lower_bound = 0
+       lower_bound = 1e-5
        self.run_optimize_gpu_model(Name, Model, Tokenizer, model_path, LayerNorm_layer, layer_before_LayerNorm, lower_bound)
\ No newline at end of file
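
Besides stripping trailing whitespace, this patch relaxes the Falcon-7B tolerance from an exact match (lower_bound = 0) to 1e-5: the fused xe_addons kernel may accumulate in a different order than the eager PyTorch implementation, so bit-exact equality is not guaranteed. The test compares torch.max of the signed difference against that bound. A standalone sketch of the comparison pattern, using stand-in tensors rather than real model outputs:

    import torch

    # Stand-in tensors: `eager` mimics the unoptimized LayerNorm output,
    # `fused` mimics the fused-kernel output with tiny rounding noise.
    eager = torch.nn.functional.layer_norm(
        torch.randn(4, 64), (64,), torch.ones(64), torch.zeros(64), 1e-5)
    fused = eager + 1e-7 * torch.randn_like(eager)

    max_diff = torch.max(fused - eager).item()
    assert max_diff <= 1e-5  # mirrors the relaxed lower_bound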