From f89736f7b24cb4005fb06316ace0203733af231c Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Tue, 26 Mar 2024 12:34:30 +0800 Subject: [PATCH 1/4] add esimd sdp for pvc --- .../src/ipex_llm/transformers/models/llama.py | 2 +- .../src/ipex_llm/transformers/models/utils.py | 27 +++++++++++-------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py index 45d944c5e4d..68391499d0d 100644 --- a/python/llm/src/ipex_llm/transformers/models/llama.py +++ b/python/llm/src/ipex_llm/transformers/models/llama.py @@ -622,7 +622,7 @@ def llama_attention_forward_4_31_original( is_causal=True) attn_weights = None elif not self.training and not hidden_states.requires_grad and \ - use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states): + use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask): import linear_fp16_esimd attn_output = linear_fp16_esimd.sdp_forward(query_states, key_states, diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index 1a4e1f0b94d..a5442c8254d 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -299,7 +299,7 @@ def use_flash_attention(query, key, attention_mask=None): return True -def use_esimd_sdp(q_len, k_len, head_dim, query_states): +def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None): if head_dim != 128: # esimd_sdp only support head_dim = 128 now return False @@ -315,17 +315,22 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states): elif query_states.dtype != torch.float16: # esimd_sdp only has optimization for FP16 now return False - else: - device_name = torch.xpu.get_device_name(query_states.device.index) - if device_name.startswith("Intel(R) Arc(TM) A") or \ - device_name.startswith("Intel(R) Data Center GPU Flex"): - import linear_fp16_esimd - if hasattr(linear_fp16_esimd, "sdp_forward"): - return True - else: - return False - else: + elif query_states.shape[0] > 1 and attention_mask is not None: + # for batched input, can't accept attention_mask + if len(torch.nonzero(attention_mask, as_tuple=True)) > 1: + return False + + device_name = torch.xpu.get_device_name(query_states.device.index) + if device_name.startswith("Intel(R) Arc(TM) A") or \ + device_name.startswith("Intel(R) Data Center GPU Flex") or \ + device_name.startswith("Intel(R) Data Center GPU Max"): + import linear_fp16_esimd + if not hasattr(linear_fp16_esimd, "sdp_forward"): return False + else: + return False + + return True def mlp_fusion_check(x, qtype, training): From 8318c3f12af21ecae229b07e6a94916eb6fe4ccc Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Tue, 26 Mar 2024 12:36:08 +0800 Subject: [PATCH 2/4] update --- python/llm/dev/benchmark/all-in-one/run.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/llm/dev/benchmark/all-in-one/run.py b/python/llm/dev/benchmark/all-in-one/run.py index f5c0ecfacad..04f4791c1ff 100644 --- a/python/llm/dev/benchmark/all-in-one/run.py +++ b/python/llm/dev/benchmark/all-in-one/run.py @@ -121,7 +121,7 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1, low_bit, cpu_embedding if 'win' in test_api else 'N/A', round(result[in_out_pair][-1][5], 2), - result[in_out_pair][-1][6] if any(keyword in test_api for keyword in ['int4_gpu', 'int4_fp16_gpu_win', 'int4_loadlowbit_gpu' ]) else 'N/A', + result[in_out_pair][-1][6] if any(keyword in test_api for keyword in ['int4_gpu', 'int4_fp16_gpu_win', 'int4_loadlowbit_gpu', 'fp16_gpu']) else 'N/A', streaming if 'win' in test_api else 'N/A'], ) @@ -408,7 +408,7 @@ def run_transformer_int4_gpu(repo_id, model = model.to('xpu') elif origin_repo_id in LLAMA_IDS: model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True, - use_cache=True).eval() + use_cache=True, torch_dtype=torch.float16).eval() tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) model = model.to('xpu') else: @@ -716,7 +716,7 @@ def run_bigdl_fp16_gpu(repo_id, print(output[0]) if i >= warm_up: result[in_out].append([model.first_cost, model.rest_cost_mean, model.encoder_time, - actual_in_len, actual_out_len, load_time]) + actual_in_len, actual_out_len, load_time, model.peak_memory]) del model torch.xpu.empty_cache() return result From 028228f8edc8858a2b79f4ea35048953fa0c8909 Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Tue, 26 Mar 2024 15:01:33 +0800 Subject: [PATCH 3/4] fix --- python/llm/dev/benchmark/all-in-one/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/dev/benchmark/all-in-one/run.py b/python/llm/dev/benchmark/all-in-one/run.py index 04f4791c1ff..ddd130d2ac4 100644 --- a/python/llm/dev/benchmark/all-in-one/run.py +++ b/python/llm/dev/benchmark/all-in-one/run.py @@ -408,7 +408,7 @@ def run_transformer_int4_gpu(repo_id, model = model.to('xpu') elif origin_repo_id in LLAMA_IDS: model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True, - use_cache=True, torch_dtype=torch.float16).eval() + use_cache=True).eval() tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True) model = model.to('xpu') else: From 5e267ca4f2df93bbc2860f3f259c1a7727259a5e Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Tue, 26 Mar 2024 18:58:24 +0800 Subject: [PATCH 4/4] fix batch --- python/llm/src/ipex_llm/transformers/models/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index a5442c8254d..79c1ffecf71 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -317,7 +317,8 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None): return False elif query_states.shape[0] > 1 and attention_mask is not None: # for batched input, can't accept attention_mask - if len(torch.nonzero(attention_mask, as_tuple=True)) > 1: + # TODO: this check needs some time + if not torch.all(attention_mask.eq(0)): return False device_name = torch.xpu.get_device_name(query_states.device.index)