From cc3a6582a0299804519af7083f3852b768f84faa Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 13:14:02 -0700 Subject: [PATCH] [CI/Test] fix swap test for multi gpu (#4689) --- tests/kernels/test_cache.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 8a27d51bb78d5..4cae15c79c489 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -222,11 +222,12 @@ def test_reshape_and_cache_flash( random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) + torch.set_default_device(device) # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) qkv = torch.randn(num_tokens, 3, @@ -245,6 +246,7 @@ def test_reshape_and_cache_flash( head_size, kv_cache_dtype, dtype, + device=device, ) key_cache, value_cache = key_caches[0], value_caches[0]