[tests] enable test_vera_dtypes on XPU #2017

Merged · 2 commits · Aug 20, 2024
tests/test_vera.py: 10 changes (7 additions, 3 deletions)
@@ -22,6 +22,7 @@
 from torch import nn

 from peft import PeftModel, VeraConfig, get_peft_model
+from peft.utils import infer_device


 class MLP(nn.Module):
@@ -284,9 +285,12 @@ def test_vera_different_shapes(self, mlp):

     @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
     def test_vera_dtypes(self, dtype):
-        # 1872
-        if (dtype == torch.bfloat16) and not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
-            pytest.skip("bfloat16 not supported on this system, skipping the test")
+        if dtype == torch.bfloat16:
+            # skip if bf16 is not supported on hardware, see #1872
+            is_xpu = infer_device() == "xpu"
+            is_cuda_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
+            if not (is_xpu or is_cuda_bf16):
+                pytest.skip("bfloat16 not supported on this system, skipping the test")

         model = MLP().to(dtype)
         config = VeraConfig(target_modules=["lin1", "lin2"], init_weights=False)
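
For context, the skip condition introduced above can be checked on its own outside the test suite. Below is a minimal sketch; the helper name `bf16_supported` is illustrative, only `infer_device` and the `torch.cuda` calls come from the diff:

```python
import torch

from peft.utils import infer_device  # same helper the updated test imports


def bf16_supported() -> bool:
    """Mirror the test's skip condition: bf16 runs on XPU, or on CUDA devices that report support."""
    # XPU is treated as bf16-capable, matching the updated test.
    if infer_device() == "xpu":
        return True
    # Otherwise defer to PyTorch's CUDA capability check.
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


if __name__ == "__main__":
    print(f"bfloat16 usable on this system: {bf16_supported()}")
```

Running `pytest tests/test_vera.py -k test_vera_dtypes` then exercises all three parametrized dtypes, with the bfloat16 case skipped wherever this check would return False.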