Skip to content

Commit

Permalink
FIX Import error in BOFT half precision test (#1995)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBossan authored Aug 8, 2024
1 parent 9988cb9 commit 41c274e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_gpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -3093,14 +3093,14 @@ class TestBOFT:
def test_boft_half_linear(self):
# Check that we can use BoFT with model loaded in half precision
layer = torch.nn.Linear(160, 160).cuda()
layer = boft.Linear(layer, "layer", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
layer = boft.layer.Linear(layer, "layer", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
x = torch.randn(160, 160, device="cuda", dtype=torch.bfloat16)
layer(x) # does not raise

@require_torch_gpu
@pytest.mark.single_gpu_tests
def test_boft_half_conv(self):
conv = torch.nn.Conv2d(1, 1, 4).cuda()
conv = boft.Conv2d(conv, "conv", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
conv = boft.layer.Conv2d(conv, "conv", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
x = torch.randn(1, 160, 160, device="cuda", dtype=torch.bfloat16)
conv(x) # does not raise

0 comments on commit 41c274e

Please sign in to comment.