diff --git a/python/llm/src/ipex_llm/transformers/model.py b/python/llm/src/ipex_llm/transformers/model.py
index f81ee840942..3e68d8ac2d2 100644
--- a/python/llm/src/ipex_llm/transformers/model.py
+++ b/python/llm/src/ipex_llm/transformers/model.py
@@ -114,7 +114,7 @@ class _BaseAutoModelClass:
 
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
-    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
+    @patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
     def from_pretrained(cls,
                         *args,
                         **kwargs):
@@ -542,7 +542,7 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
 
     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
-    @patch("transformers.utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
+    @patch("transformers.modeling_utils.is_torch_sdpa_available", patch_sdpa_available, create=True)
     def load_low_bit(cls,
                      pretrained_model_name_or_path,
                      *model_args,
diff --git a/python/llm/src/ipex_llm/transformers/patches.py b/python/llm/src/ipex_llm/transformers/patches.py
index f115ffa5402..743232c5d1f 100644
--- a/python/llm/src/ipex_llm/transformers/patches.py
+++ b/python/llm/src/ipex_llm/transformers/patches.py
@@ -17,6 +17,7 @@
 
 from typing import List
 from transformers.dynamic_module_utils import get_imports
+from ipex_llm.utils.ipex_importer import IPEXImporter
 
 
 def patch_flash_attn_import(filename: str) -> List[str]:
@@ -28,4 +29,11 @@ def patch_flash_attn_import(filename: str) -> List[str]:
 
 
 def patch_sdpa_available() -> bool:
-    return False
+    if IPEXImporter.is_xpu_version_installed():
+        return False
+    else:
+        try:
+            from transformers.utils import is_torch_sdpa_available
+            return is_torch_sdpa_available()
+        except ImportError:
+            return False
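
A minimal sketch of the patching pattern this diff relies on, assuming standard unittest.mock.patch semantics. "demo_module" is a hypothetical stand-in for transformers.modeling_utils (the real target string is "transformers.modeling_utils.is_torch_sdpa_available"); the point is that create=True lets the decorator install the replacement even against transformers releases that never defined is_torch_sdpa_available, which is why the retargeted @patch lines above stay safe across versions:

    import sys
    import types
    from unittest.mock import patch

    # Stand-in module that, like older transformers releases, does NOT
    # define is_torch_sdpa_available. Registering it in sys.modules lets
    # mock.patch resolve the dotted target string.
    demo_module = types.ModuleType("demo_module")
    sys.modules["demo_module"] = demo_module

    def fake_sdpa_available() -> bool:
        # Plays the role of patch_sdpa_available: report SDPA as unavailable.
        return False

    # create=True installs the attribute even though demo_module never
    # defined it; without create=True this patch would raise AttributeError.
    @patch("demo_module.is_torch_sdpa_available", fake_sdpa_available, create=True)
    def load_model() -> bool:
        # Inside the decorated call the patched attribute is visible.
        return demo_module.is_torch_sdpa_available()

    print(load_model())                                     # False
    print(hasattr(demo_module, "is_torch_sdpa_available"))  # False: patch undone on exit

Note the design of the new patch_sdpa_available in patches.py: on XPU builds it forces False (SDPA is disabled), while on other builds it defers to the real transformers.utils.is_torch_sdpa_available, falling back to False if that helper does not exist in the installed transformers version.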