Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch sdpa check function in specific module attributes table #12285

Merged
merged 7 commits into from
Oct 29, 2024

Conversation

leonardozcm
Copy link
Contributor

@leonardozcm leonardozcm commented Oct 28, 2024

Description

When a module import another module, they maintain a copy of the attributes ref table from the imported module.
before this PR:

(Pdb)  print(transformers.modeling_utils.is_torch_sdpa_available)
<function is_torch_sdpa_available at 0x000002594D5A4540>
(Pdb)  print(transformers.utils.is_torch_sdpa_available)
<function patch_sdpa_available at 0x0000025951C031A0>
(Pdb)  print(transformers.utils.import_utils.is_torch_sdpa_available)
<function is_torch_sdpa_available at 0x000002594D5A4540>
(Pdb) print(patch_sdpa_available)
<function patch_sdpa_available at 0x0000025951C031A0>

after this PR

(Pdb) print(transformers.modeling_utils.is_torch_sdpa_available)
<function patch_sdpa_available at 0x000001ECD0CF31A0>
(Pdb) print(transformers.utils.is_torch_sdpa_available)
<function is_torch_sdpa_available at 0x000001ECCD65C540>
(Pdb) print(transformers.utils.import_utils.is_torch_sdpa_available)
<function is_torch_sdpa_available at 0x000001ECCD65C540>
(Pdb) print(patch_sdpa_available)
<function patch_sdpa_available at 0x000001ECD0CF31A0>

@Oscilloscope98

This comment was marked as outdated.

@Oscilloscope98
Copy link
Contributor

seems to forget to change here:
image

@Oscilloscope98

This comment was marked as outdated.

@Oscilloscope98

This comment was marked as outdated.

@@ -17,6 +17,8 @@

from typing import List
from transformers.dynamic_module_utils import get_imports
from transformers.utils import is_torch_sdpa_available
Copy link
Contributor

@Oscilloscope98 Oscilloscope98 Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if for the lower transformers version without support for is_torch_sdpa_available?

Maybe add sth like

def patch_sdpa_available() -> bool:
    if IPEXImporter.is_xpu_version_installed():
        return False
    else:
        try:
            from transformers.utils import is_torch_sdpa_available
            return is_torch_sdpa_available()
        except xxx:
            return False 

@@ -17,6 +17,8 @@

from typing import List
from transformers.dynamic_module_utils import get_imports
from transformers.utils import is_torch_sdpa_available
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the import here

@leonardozcm
Copy link
Contributor Author

@leonardozcm leonardozcm merged commit 546f455 into intel-analytics:main Oct 29, 2024
1 check passed
@leonardozcm leonardozcm deleted the fix_sdpa_ref_table branch October 29, 2024 10:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants