From e81a2dd901dd1b93291555722c6d96901940be06 Mon Sep 17 00:00:00 2001 From: Yi30 <106061964+yiliu30@users.noreply.github.com> Date: Tue, 19 Mar 2024 09:30:17 +0800 Subject: [PATCH] Fixed the auto device detection (#1674) * update the device name to current device Signed-off-by: yiliu30 --- neural_compressor/torch/utils/environ.py | 2 +- .../weight_only => }/test_auto_accelerator.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) rename test/3x/torch/{quantization/weight_only => }/test_auto_accelerator.py (83%) diff --git a/neural_compressor/torch/utils/environ.py b/neural_compressor/torch/utils/environ.py index cab60b40416..699d03047cd 100644 --- a/neural_compressor/torch/utils/environ.py +++ b/neural_compressor/torch/utils/environ.py @@ -66,5 +66,5 @@ def get_device(device_name="auto"): from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator runtime_accelerator = auto_detect_accelerator(device_name) - device = runtime_accelerator.name() + device = runtime_accelerator.current_device_name() return device diff --git a/test/3x/torch/quantization/weight_only/test_auto_accelerator.py b/test/3x/torch/test_auto_accelerator.py similarity index 83% rename from test/3x/torch/quantization/weight_only/test_auto_accelerator.py rename to test/3x/torch/test_auto_accelerator.py index c1137b18fdd..42656c8bcd5 100644 --- a/test/3x/torch/quantization/weight_only/test_auto_accelerator.py +++ b/test/3x/torch/test_auto_accelerator.py @@ -3,6 +3,7 @@ import pytest import torch +from neural_compressor.torch.utils import get_device from neural_compressor.torch.utils.auto_accelerator import accelerator_registry, auto_detect_accelerator @@ -52,6 +53,16 @@ def test_cuda_accelerator(self, force_use_cuda): assert accelerator.synchronize() is None assert accelerator.empty_cache() is None + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Only one GPU is available") + def test_get_device(self): + accelerator = auto_detect_accelerator() + assert accelerator.set_device(1) is None + assert accelerator.current_device_name() == "cuda:1" + cur_device = get_device() + assert cur_device == "cuda:1" + tmp_tensor = torch.tensor([1, 2], device=cur_device) + assert "cuda:1" == str(tmp_tensor.device) + class TestAutoAccelerator: