From 443d00779acac739c3a185f384b78236eaac9643 Mon Sep 17 00:00:00 2001
From: xinhe
Date: Fri, 13 Sep 2024 21:35:32 +0800
Subject: [PATCH] add INC_FORCE_DEVICE introduction (#1988)

* add INC_FORCE_DEVICE introduction

Signed-off-by: xin3he

* Update PyTorch.md

* Update PyTorch.md

* Update docs/source/3x/PyTorch.md

Co-authored-by: Yi Liu

* rename to INC_TARGET_DEVICE

Signed-off-by: xin3he

---------

Signed-off-by: xin3he
Co-authored-by: Yi Liu
---
 docs/source/3x/PyTorch.md                    | 14 +++++++++-
 .../torch/utils/auto_accelerator.py          | 26 +++++++++----------
 .../quantization/weight_only/test_hqq.py     |  4 +--
 test/3x/torch/utils/test_auto_accelerator.py | 16 +++++++-----
 4 files changed, 37 insertions(+), 23 deletions(-)

diff --git a/docs/source/3x/PyTorch.md b/docs/source/3x/PyTorch.md
index 2c2111d4d69..608ccc57bd3 100644
--- a/docs/source/3x/PyTorch.md
+++ b/docs/source/3x/PyTorch.md
@@ -245,7 +245,7 @@ Deep Learning
 2. How to set different configuration for specific op_name or op_type?
 
-   > INC extends a `set_local` method based on the global configuration object to set custom configuration.
+   > Neural Compressor extends a `set_local` method based on the global configuration object to set custom configuration.
 
    ```python
    def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig:
@@ -264,3 +264,15 @@ Deep Learning
    quant_config.set_local(".*mlp.*", RTNConfig(bits=8))  # For layers with "mlp" in their names, set bits=8
    quant_config.set_local("Conv1d", RTNConfig(dtype="fp32"))  # For Conv1d layers, do not quantize them.
    ```
+
+3. How to specify an accelerator?
+
+   > Neural Compressor provides automatic accelerator detection, including HPU, XPU, CUDA, and CPU.
+
+   > In some special cases, such as poor performance or memory limitations, the automatically detected accelerator may not be suitable. In such situations, users can override the detected accelerator by setting the environment variable `INC_TARGET_DEVICE`.
+
+   > Usage:
+
+   ```bash
+   export INC_TARGET_DEVICE=cpu
+   ```
diff --git a/neural_compressor/torch/utils/auto_accelerator.py b/neural_compressor/torch/utils/auto_accelerator.py
index b72cffff5ad..57c7ed9a9b0 100644
--- a/neural_compressor/torch/utils/auto_accelerator.py
+++ b/neural_compressor/torch/utils/auto_accelerator.py
@@ -395,19 +395,19 @@ def mark_step(self):
 def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
     """Automatically detects and selects the appropriate accelerator.
 
-    Force use the cpu on node has both cpu and gpu: `FORCE_DEVICE=cpu` python main.py ...
-    The `FORCE_DEVICE` is case insensitive.
-    The environment variable `FORCE_DEVICE` has higher priority than the `device_name`.
+    To force CPU use on a node that has both a CPU and a GPU: `INC_TARGET_DEVICE=cpu python main.py ...`
+    The value of `INC_TARGET_DEVICE` is case-insensitive.
+    The environment variable `INC_TARGET_DEVICE` takes priority over the `device_name` argument.
     TODO: refine the docs and logic later
     """
-    # 1. Get the device setting from environment variable `FORCE_DEVICE`.
-    FORCE_DEVICE = os.environ.get("FORCE_DEVICE", None)
-    if FORCE_DEVICE:
-        FORCE_DEVICE = FORCE_DEVICE.lower()
-    # 2. If the `FORCE_DEVICE` is set and the accelerator is available, use it.
-    if FORCE_DEVICE and accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE) is not None:
-        logger.warning("Force use %s accelerator.", FORCE_DEVICE)
-        return accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE)()
+    # 1. Get the device setting from the environment variable `INC_TARGET_DEVICE`.
+    INC_TARGET_DEVICE = os.environ.get("INC_TARGET_DEVICE", None)
+    if INC_TARGET_DEVICE:
+        INC_TARGET_DEVICE = INC_TARGET_DEVICE.lower()
+    # 2. If `INC_TARGET_DEVICE` is set and the accelerator is available, use it.
+    if INC_TARGET_DEVICE and accelerator_registry.get_accelerator_cls_by_name(INC_TARGET_DEVICE) is not None:
+        logger.warning("Force use %s accelerator.", INC_TARGET_DEVICE)
+        return accelerator_registry.get_accelerator_cls_by_name(INC_TARGET_DEVICE)()
     # 3. If the `device_name` is set and the accelerator is available, use it.
     if device_name != "auto":
         if accelerator_registry.get_accelerator_cls_by_name(device_name) is not None:
@@ -425,8 +425,8 @@ def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
 # Force use cpu accelerator even if cuda is available.
-# FORCE_DEVICE = "cpu" python ...
+# INC_TARGET_DEVICE="cpu" python ...
 # or
-# FORCE_DEVICE = "CPU" python ...
+# INC_TARGET_DEVICE="CPU" python ...
 # or
 # CUDA_VISIBLE_DEVICES="" python ...
diff --git a/test/3x/torch/quantization/weight_only/test_hqq.py b/test/3x/torch/quantization/weight_only/test_hqq.py
index d6e0352c312..05c2de75433 100644
--- a/test/3x/torch/quantization/weight_only/test_hqq.py
+++ b/test/3x/torch/quantization/weight_only/test_hqq.py
@@ -68,7 +68,7 @@ def setup_class(cls):
     @pytest.fixture
     def force_use_cpu(self, monkeypatch):
         # Force use CPU
-        monkeypatch.setenv("FORCE_DEVICE", "cpu")
+        monkeypatch.setenv("INC_TARGET_DEVICE", "cpu")
 
     @pytest.fixture
     def force_not_half(self, monkeypatch):
@@ -194,7 +194,7 @@ def test_hqq_module(
         if device_name == "cuda" and not torch.cuda.is_available():
             pytest.skip("Skipping CUDA test because cuda is not available")
         if device_name == "cpu":
-            os.environ["FORCE_DEVICE"] = "cpu"
+            os.environ["INC_TARGET_DEVICE"] = "cpu"
             hqq_global_option.use_half = False
 
         _common_hqq_test(
diff --git a/test/3x/torch/utils/test_auto_accelerator.py b/test/3x/torch/utils/test_auto_accelerator.py
index dea9cdce918..06dcdf8c722 100644
--- a/test/3x/torch/utils/test_auto_accelerator.py
+++ b/test/3x/torch/utils/test_auto_accelerator.py
@@ -17,7 +17,9 @@
 @pytest.mark.skipif(not HPU_Accelerator.is_available(), reason="HPEX is not available")
 class TestHPUAccelerator:
     def test_cuda_accelerator(self):
-        assert os.environ.get("FORCE_DEVICE", None) is None, "FORCE_DEVICE shouldn't be set. HPU is the first priority."
+        assert (
+            os.environ.get("INC_TARGET_DEVICE", None) is None
+        ), "INC_TARGET_DEVICE shouldn't be set. HPU is the first priority."
         accelerator = auto_detect_accelerator()
         assert accelerator.current_device() == 0, f"{accelerator.current_device()}"
         assert accelerator.current_device_name() == "hpu:0"
@@ -47,10 +49,10 @@ class TestXPUAccelerator:
     @pytest.fixture
     def force_use_xpu(self, monkeypatch):
         # Force use xpu
-        monkeypatch.setenv("FORCE_DEVICE", "xpu")
+        monkeypatch.setenv("INC_TARGET_DEVICE", "xpu")
 
     def test_xpu_accelerator(self, force_use_xpu):
-        print(f"FORCE_DEVICE: {os.environ.get('FORCE_DEVICE', None)}")
+        print(f"INC_TARGET_DEVICE: {os.environ.get('INC_TARGET_DEVICE', None)}")
         accelerator = auto_detect_accelerator()
         assert accelerator.current_device() == 0, f"{accelerator.current_device()}"
         assert accelerator.current_device_name() == "xpu:0"
@@ -79,10 +81,10 @@ class TestCPUAccelerator:
     @pytest.fixture
     def force_use_cpu(self, monkeypatch):
         # Force use CPU
-        monkeypatch.setenv("FORCE_DEVICE", "cpu")
+        monkeypatch.setenv("INC_TARGET_DEVICE", "cpu")
 
     def test_cpu_accelerator(self, force_use_cpu):
-        print(f"FORCE_DEVICE: {os.environ.get('FORCE_DEVICE', None)}")
+        print(f"INC_TARGET_DEVICE: {os.environ.get('INC_TARGET_DEVICE', None)}")
         accelerator = auto_detect_accelerator()
         assert accelerator.current_device() == "cpu", f"{accelerator.current_device()}"
         assert accelerator.current_device_name() == "cpu"
@@ -99,10 +101,10 @@ class TestCUDAAccelerator:
     @pytest.fixture
    def force_use_cuda(self, monkeypatch):
         # Force use CUDA
-        monkeypatch.setenv("FORCE_DEVICE", "cuda")
+        monkeypatch.setenv("INC_TARGET_DEVICE", "cuda")
 
     def test_cuda_accelerator(self, force_use_cuda):
-        print(f"FORCE_DEVICE: {os.environ.get('FORCE_DEVICE', None)}")
+        print(f"INC_TARGET_DEVICE: {os.environ.get('INC_TARGET_DEVICE', None)}")
         accelerator = auto_detect_accelerator()
         assert accelerator.current_device() == 0, f"{accelerator.current_device()}"
         assert accelerator.current_device_name() == "cuda:0"
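
For reference, a minimal sketch of how the renamed override behaves once this patch is applied. It assumes a build of `neural_compressor` that includes these changes; `auto_detect_accelerator` and `current_device_name` are the APIs shown in the diff above, and the expected output follows the assertions in `test_auto_accelerator.py`:

```python
import os

# The override must be set before detection runs; per the docstring above,
# INC_TARGET_DEVICE takes priority over the `device_name` argument.
os.environ["INC_TARGET_DEVICE"] = "cpu"

from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

accelerator = auto_detect_accelerator()   # env var wins over auto detection
print(accelerator.current_device_name())  # "cpu", as asserted in the tests
```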