From 2a0c0e3c27099531f364ff290f22153ab05cb359 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Date: Mon, 17 Jun 2024 14:42:54 -0700
Subject: [PATCH] Remove compile wrapper to simplify access to model attributes (#5581)

Wrapping the compiled module brings various restrictions on accessing
attributes of the compiled model. This PR removes the wrapper around the
compiled module so that attributes of the original model remain directly
accessible; compilation is now requested explicitly via `engine.compile()`
instead of the `"compile"` section of the DeepSpeed config.

---------

Co-authored-by: Olatunji Ruwase
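A minimal usage sketch of the API after this change (`MyModel` and
`ds_config` are hypothetical placeholders, not part of this patch):

```python
import deepspeed

model = MyModel()  # hypothetical torch.nn.Module
engine, _, _, _ = deepspeed.initialize(config=ds_config,
                                       model=model,
                                       model_parameters=model.parameters())

# New explicit API; compile_kwargs are forwarded to torch.compile().
engine.compile(compile_kwargs={"mode": "default"})

# nn.Module.compile() compiles in place instead of wrapping, so the
# original module's attributes stay directly reachable.
engine.module.my_custom_method()  # hypothetical user-defined method
```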
- """ - - @validator("enabled") - def validate_enabled(cls, field_value, values): - if field_value and not is_compile_supported(): - raise ValueError("torch.compile is not supported on this version of PyTorch.") - return field_value - - -def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None): - - class wrapper(mod.__class__): - - def __init__(self, module, compile_config: Union[CompileConfig, None] = None): - self.__dict__ = {k: module.__dict__[k] for k in module.__dict__ if not k in self.__class__.__dict__} - - assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch." - - self.__dict__['wrapped'] = module - self._is_compiled = False - self._backend = get_backend_fn(compile_config.backend) - self._compile_kwargs = compile_config.kwargs - self._compiler_fn = None - - def set_backend(self, backend: Union[str, Callable]): - """Set the backend for torch.compile. - - Args: - backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. - You can directly pass a function that works as a backend. - See also `backend` field in `CompileConfig` for more details. - """ - self._backend = get_backend_fn(backend) - - def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: - """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. - You can also pass a backend name with "backend" key to change the backend. - - Args: - kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. - """ - - if "backend" in kwargs: - raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") - self._compile_kwargs.update(kwargs) - - def set_compiler_fn(self, compiler_fn: Callable) -> None: - """Set a function to be used for compiling the module. - This function should take a torch.nn.Module as input and return a compiled module. - Note that other compile options are ignored when a compiler_fn is set. - - Example: - ```python - def my_compiler_fn(module: torch.nn.Module): - ... - return torch.compile(module, ...) 
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 04b122963a38..b49b4a8b6086 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -31,7 +31,6 @@
 from ..comm.config import DeepSpeedCommsConfig
 from ..monitor.config import get_monitor_config
 from ..inference.config import WeightQuantConfig
-from .compiler import get_compile_config
 from deepspeed import comm as dist
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel

@@ -911,8 +910,6 @@ def _initialize_params(self, param_dict):
         self.weight_quantization_config = WeightQuantConfig(
             **param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None

-        self.compile_config = get_compile_config(param_dict)
-
         self.timers_config = get_timers_config(param_dict)

     def _batch_assertion(self):
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 08ab05d79b6a..4c418fbc532e 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -90,7 +90,7 @@
 from .pipe.module import PipelineModule
 from .utils import get_ma_status
-from .compiler import CompiledModuleWrapper
+from .compiler import is_compile_supported
 from ..ops.adam import FusedAdam
 from ..moe.sharded_moe import TopKGate, MOELayer
 from ..moe.layer import MoE
@@ -361,8 +361,7 @@ def __init__(self,
         self.flatten = _flatten_dense_tensors
         self.unflatten = _unflatten_dense_tensors

-        if self._config.compile_config.enabled:
-            self._set_client_model(CompiledModuleWrapper(self.module, self._config.compile_config))
+        self._is_compiled = False

     def destroy(self):
         if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
@@ -3604,3 +3603,20 @@ def empty_partition_cache(self):
         self.optimizer.empty_partition_cache()
         gc.collect()
         get_accelerator().empty_cache()
+
+    def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}) -> None:
+        """Compile the module using the specified backend and kwargs.
+        Compilation happens in place via torch.nn.Module.compile, so module attributes stay accessible.
+        """
+        if not is_compile_supported():
+            raise RuntimeError("compile is not supported in your version of PyTorch.")
+
+        if self.is_compiled:
+            return
+
+        self.module.compile(backend=backend, **compile_kwargs)
+        self._is_compiled = True
+
+    @property
+    def is_compiled(self) -> bool:
+        return self._is_compiled
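The new `compile()` raises when torch.compile is unavailable and is a
no-op once the engine has been compiled, so callers can guard and repeat
it safely. A sketch (`engine` is assumed to come from
`deepspeed.initialize`; the kwarg is just an illustrative torch.compile
option):

```python
import deepspeed

if deepspeed.is_compile_supported():
    engine.compile(compile_kwargs={"dynamic": True})  # first call compiles
    engine.compile()  # later calls return immediately
    assert engine.is_compiled
```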
+ """ + if not is_compile_supported(): + raise RuntimeError("compile is not supported in your version of PyTorch.") + + if self.is_compiled: + return + + self.module.compile(backend=backend, **compile_kwargs) + self._is_compiled = True + + @property + def is_compiled(self) -> bool: + return self._is_compiled diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index a2d2465c9666..a88c0a2d146c 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -933,15 +933,7 @@ def __init__( _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, mpu) if config_dict_or_path is not None else None if _ds_config is not None: - if _ds_config.zero_config.memory_efficient_linear and _ds_config.compile_config.enabled: - # memory_efficient_linear displays numerous errors when torch.compile is enabled. - # Refer to https://github.com/pytorch/pytorch/issues/119059 for details. - # Further investigation into performance is necessary, even after resolving this issue because - # the `memory_efficient_linear` module may lead to more graph breaks compared to the original implementation. - logger.warning(f'memory_efficient_linear is disabled when torch.compile is enabled.') - mem_efficient_linear = False - else: - mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear + mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype) if not dist.is_initialized(): diff --git a/tests/unit/common.py b/tests/unit/common.py index 58bb26ca18b4..1774bcfae9ff 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -203,10 +203,13 @@ def _launch_non_daemonic_procs(self, num_procs): master_port = get_master_port() skip_msg = mp.Queue() # Allows forked processes to share pytest.skip reason processes = [] + prev_start_method = mp.get_start_method() + mp.set_start_method('spawn', force=True) for local_rank in range(num_procs): p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg)) p.start() processes.append(p) + mp.set_start_method(prev_start_method, force=True) # Now loop and wait for a test to complete. The spin-wait here isn't a big # deal because the number of processes will be O(#GPUs) << O(#CPUs). diff --git a/tests/unit/runtime/compile/test_compile_wrapper.py b/tests/unit/runtime/compile/test_compile_wrapper.py deleted file mode 100644 index 62af25ac3ba4..000000000000 --- a/tests/unit/runtime/compile/test_compile_wrapper.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Microsoft Corporation. 
diff --git a/tests/unit/runtime/compile/test_compile_wrapper.py b/tests/unit/runtime/compile/test_compile_wrapper.py
deleted file mode 100644
index 62af25ac3ba4..000000000000
--- a/tests/unit/runtime/compile/test_compile_wrapper.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
-import pytest
-import torch
-
-import deepspeed
-from deepspeed.accelerator import get_accelerator
-from deepspeed.utils.torch import required_torch_version
-
-from unit.common import DistributedTest
-
-pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1),
-                                reason="Compile tests requires Pytorch version 2.1 or above")
-
-
-@pytest.fixture
-def base_config():
-    config_dict = {
-        "train_micro_batch_size_per_gpu": 1,
-        "optimizer": {
-            "type": "Adam",
-            "params": {
-                "lr": 0.00015
-            }
-        },
-        "fp16": {
-            "enabled": True
-        },
-        "compile": {
-            "enabled": True,
-            "backend": get_accelerator().get_compile_backend()
-        }
-    }
-    return config_dict
-
-
-class SmallModelWithCustomMethod(torch.nn.Module):
-
-    def __init__(self, hidden_dim, test_value):
-        super(SmallModelWithCustomMethod, self).__init__()
-        self.fc = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
-        self.v = test_value
-
-    def forward(self, x):
-        return self.fc(x)
-
-    # Custom function that is not part of DeepSpeed engine.
-    def get_v(self):
-        return self.v
-
-
-class TestCustomMethod(DistributedTest):
-    world_size = 1
-    non_daemonic_procs = True
-
-    def _init_engine(self, config, test_value):
-        hidden_dim = 10
-        model = SmallModelWithCustomMethod(hidden_dim, test_value)
-        engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters())
-        return engine
-
-    def _run_model(self, engine):
-        train_batch_size = 1
-        device = torch.device(get_accelerator().current_device_name())
-        dtype = engine.module.fc.weight.dtype
-        hidden_dim = engine.module.fc.weight.shape[1]
-        x = torch.rand(train_batch_size, hidden_dim, device=device, dtype=dtype)
-        engine(x)
-
-    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
-    def test_custom_function(self, base_config):
-        if get_accelerator().device_name() == "cpu":
-            pytest.skip("CPU accelerator does not support this test yet.")
-        test_value = 10
-
-        engine = self._init_engine(base_config, test_value)
-        assert engine.module.get_v() == test_value
-        self._run_model(engine)
-
-        # The model is compiled after the first run.
-        # Thus we make sure the custom method is still available after compilation.
-        assert engine.module.get_v() == test_value
diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py
index a0736b0f5425..ca80eef8b31e 100644
--- a/tests/unit/runtime/compile/test_compile_zero.py
+++ b/tests/unit/runtime/compile/test_compile_zero.py
@@ -50,10 +50,6 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device):
             },
             "zero_optimization": {
                 "stage": zero_stage,
-            },
-            "compile": {
-                "enabled": True,
-                "backend": get_accelerator().get_compile_backend()
             }
         }
diff --git a/tests/unit/runtime/compile/test_load_config.py b/tests/unit/runtime/compile/test_load_config.py
deleted file mode 100644
index cee8d3b23f6b..000000000000
--- a/tests/unit/runtime/compile/test_load_config.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
-import pytest
-import torch
-
-from unit.simple_model import SimpleModel
-import deepspeed
-from deepspeed.accelerator import get_accelerator
-from deepspeed.utils.torch import required_torch_version
-
-from unit.common import DistributedTest
-
-pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1),
-                                reason="Compile tests requires Pytorch version 2.1 or above")
-
-custom_backend_called = False
-custom_compler_fn_called = False
-
-if deepspeed.is_compile_supported():
-    # PyTorch v1 does not have torch.fx
-    def custom_backend(gm: torch.fx.GraphModule, example_inputs):
-        global custom_backend_called
-        custom_backend_called = True
-        return gm.forward
-
-    def custom_compiler_fn(module: torch.nn.Module):
-        global custom_compler_fn_called
-        custom_compler_fn_called = True
-        return torch.compile(module)
-
-
-@pytest.fixture
-def base_config():
-    config_dict = {
-        "train_micro_batch_size_per_gpu": 1,
-        "optimizer": {
-            "type": "Adam",
-            "params": {
-                "lr": 0.00015
-            }
-        },
-        "fp16": {
-            "enabled": True
-        },
-        "compile": {
-            "enabled": True,
-            "backend": get_accelerator().get_compile_backend()
-        }
-    }
-
-    return config_dict
-
-
-class TestConfigLoad(DistributedTest):
-    world_size = 1
-    non_daemonic_procs = True
-
-    def _init_engine(self, config):
-        hidden_dim = 10
-        model = SimpleModel(hidden_dim)
-        engine, _, _, _ = deepspeed.initialize(config=config, model=model, model_parameters=model.parameters())
-        return engine
-
-    def _run_model(self, engine):
-        train_batch_size = 1
-        device = torch.device(get_accelerator().current_device_name())
-        dtype = engine.module.linears[0].weight.dtype
-        hidden_dim = engine.module.linears[0].weight.shape[1]
-        x = torch.rand(train_batch_size, hidden_dim, device=device, dtype=dtype)
-        y = torch.randn_like(x)
-        engine(x, y)
-
-    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
-    def test_compile(self, base_config):
-        if get_accelerator().device_name() == "cpu":
-            pytest.skip("CPU accelerator does not support this test yet.")
-        engine = self._init_engine(base_config)
-        self._run_model(engine)
-        assert engine.is_compiled
-
-    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
-    def test_custom_backend(self, base_config):
-        if get_accelerator().device_name() == "cpu":
-            pytest.skip("CPU accelerator does not support this test yet.")
-        global custom_backend_called
-        custom_backend_called = False
-
-        engine = self._init_engine(base_config)
-        engine.set_backend(f"{__name__}.custom_backend")
-        self._run_model(engine)
-        assert custom_backend_called
-
-    def test_compile_disabled(self, base_config):
-        if get_accelerator().device_name() == "cpu":
-            pytest.skip("CPU accelerator does not support this test yet.")
-        base_config["compile"]["enabled"] = False
-        engine = self._init_engine(base_config)
-        self._run_model(engine)
-
-    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
-    def test_compile_kwargs(self, base_config):
-        if get_accelerator().device_name() == "cpu":
-            pytest.skip("CPU accelerator does not support this test yet.")
-        base_config["compile"]["kwargs"] = {"mode": "default"}
-        engine = self._init_engine(base_config)
-        self._run_model(engine)
-        assert "mode" in engine.torch_compile_kwargs
-
-    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
-    def test_set_compile_kwargs(self, base_config):
-        if get_accelerator().device_name() == "cpu":
-            pytest.skip("CPU accelerator does not support this test yet.")
-        engine = self._init_engine(base_config)
-        engine.set_torch_compile_kwargs({"mode": "default"})
-        self._run_model(engine)
-        assert "mode" in engine.torch_compile_kwargs
-
-    @pytest.mark.skipif(not deepspeed.is_compile_supported(), reason="torch.compile is not supported")
-    def test_set_compiler_fn(self, base_config):
-        if get_accelerator().device_name() == "cpu":
-            pytest.skip("CPU accelerator does not support this test yet.")
-        global custom_compler_fn_called
-        custom_compler_fn_called = False
-
-        engine = self._init_engine(base_config)
-        engine.set_compiler_fn(custom_compiler_fn)
-        self._run_model(engine)
-        assert custom_compler_fn_called
diff --git a/tests/unit/runtime/compile/util.py b/tests/unit/runtime/compile/util.py
index 86eadf3f6976..d53886a81429 100644
--- a/tests/unit/runtime/compile/util.py
+++ b/tests/unit/runtime/compile/util.py
@@ -84,7 +84,6 @@ def compare_loss(self, config, dtype):
         baseline_config = deepcopy(config)
         baseline_config["zero_optimization"]["stage"] = 0
         baseline_config["zero_optimization"]["offload_optimizer"] = {}
-        baseline_config["compile"]["enabled"] = False
         baseline_engine, baseline_optimizer, _, _ = deepspeed.initialize(config=baseline_config,
                                                                          model=baseline_model,
                                                                          model_parameters=baseline_model.parameters())
@@ -101,6 +100,7 @@ def compare_loss(self, config, dtype):
         target_engine, target_optimizer, _, _ = deepspeed.initialize(config=config,
                                                                      model=target_model,
                                                                      model_parameters=target_model.parameters())
+        target_engine.compile()

         train_batch_size = config["train_micro_batch_size_per_gpu"]