Skip to content

Commit

Permalink
Remove compile wrapper to simplify access to model attributes (micros…
Browse files Browse the repository at this point in the history
…oft#5581)

Having a wrapper of a compiled module brings various restrictions about
accessing attributes of the compiled model.
This PR removes the wrapper of compiled module to simplify the access to
the compiled model.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
tohtana and tjruwase authored Jun 17, 2024
1 parent 8831b57 commit 2a0c0e3
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 388 deletions.
153 changes: 1 addition & 152 deletions deepspeed/runtime/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,165 +3,14 @@

# DeepSpeed Team

from typing import Union, Callable, Dict, Any
import importlib
import torch
from ..pydantic_v1 import validator
from .config_utils import DeepSpeedConfigModel

COMPILE_CONFIG = "compile"


def is_compile_supported():
return hasattr(torch, "compiler")
return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")


def disable(func):
if is_compile_supported():
return torch.compiler.disable(func)
return func


def get_compile_config(param_dict):
if COMPILE_CONFIG in param_dict:
compile_config_dict = param_dict[COMPILE_CONFIG]
else:
compile_config_dict = {}
return CompileConfig(**compile_config_dict)


def get_backend_fn(backend: Union[str, Callable]) -> Union[str, Callable]:
if isinstance(backend, Callable):
return backend

elif isinstance(backend, str):
if backend in torch._dynamo.list_backends(exclude_tags=()):
return backend

# Get module name from backend name
module_name = '.'.join(backend.split('.')[:-1])
fn_name = backend.split('.')[-1]

try:
module = importlib.import_module(module_name)
backend_fn = getattr(module, fn_name)
except ImportError:
raise ValueError(
f"The backend {backend} is not in the list of available backends and could not be imported.")
return backend_fn

raise ValueError(f"backend for torch.compile must be a string or Callable: {backend}")


class CompileConfig(DeepSpeedConfigModel):
"""
[EXPERIMENTAL] This configuration enables users to activate `torch.compile` within DeepSpeed and customize its settings.
Please be aware that these features and API designs are experimental and subject to change.
"""

enabled: bool = False
"""
Enable torch.compile when True.
"""

backend: str = "inductor"
"""
Passed to `backend` argument of torch.compile.
If the given value is not in torch._dynamo.list_backends(),
DeepSpeed attempts to import and instantiate the module with the given name.
"""

kwargs: Dict[str, Any] = {}
"""
Passed to `kwargs` argument of torch.compile.
"""

@validator("enabled")
def validate_enabled(cls, field_value, values):
if field_value and not is_compile_supported():
raise ValueError("torch.compile is not supported on this version of PyTorch.")
return field_value


def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None):

class wrapper(mod.__class__):

def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
self.__dict__ = {k: module.__dict__[k] for k in module.__dict__ if not k in self.__class__.__dict__}

assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."

self.__dict__['wrapped'] = module
self._is_compiled = False
self._backend = get_backend_fn(compile_config.backend)
self._compile_kwargs = compile_config.kwargs
self._compiler_fn = None

def set_backend(self, backend: Union[str, Callable]):
"""Set the backend for torch.compile.
Args:
backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
You can directly pass a function that works as a backend.
See also `backend` field in `CompileConfig` for more details.
"""
self._backend = get_backend_fn(backend)

def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
"""Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
You can also pass a backend name with "backend" key to change the backend.
Args:
kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
"""

if "backend" in kwargs:
raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
self._compile_kwargs.update(kwargs)

def set_compiler_fn(self, compiler_fn: Callable) -> None:
"""Set a function to be used for compiling the module.
This function should take a torch.nn.Module as input and return a compiled module.
Note that other compile options are ignored when a compiler_fn is set.
Example:
```python
def my_compiler_fn(module: torch.nn.Module):
...
return torch.compile(module, ...)
engine.set_compiler_fn(my_compiler_fn)
```
"""
self._compiler_fn = compiler_fn

def forward(self, *args, **kwargs) -> Any:
if not self.is_compiled:
if self._compiler_fn is None:
self.__dict__['wrapped'] = torch.compile(self.wrapped,
backend=self._backend,
**self._compile_kwargs)
else:
self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
self._is_compiled = True

return self.__dict__['wrapped'](*args, **kwargs)

@property
def is_compiled(self) -> bool:
return self._is_compiled

@property
def backend(self) -> Union[str, Callable]:
return self._backend

@property
def torch_compile_kwargs(self) -> Dict[str, Any]:
return self._compile_kwargs

@property
def compiler_fn(self) -> Union[Callable, None]:
return self._compiler_fn

return wrapper(mod, compile_config)
3 changes: 0 additions & 3 deletions deepspeed/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config
from ..inference.config import WeightQuantConfig
from .compiler import get_compile_config

from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
Expand Down Expand Up @@ -911,8 +910,6 @@ def _initialize_params(self, param_dict):
self.weight_quantization_config = WeightQuantConfig(
**param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None

self.compile_config = get_compile_config(param_dict)

self.timers_config = get_timers_config(param_dict)

def _batch_assertion(self):
Expand Down
22 changes: 19 additions & 3 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@

from .pipe.module import PipelineModule
from .utils import get_ma_status
from .compiler import CompiledModuleWrapper
from .compiler import is_compile_supported
from ..ops.adam import FusedAdam
from ..moe.sharded_moe import TopKGate, MOELayer
from ..moe.layer import MoE
Expand Down Expand Up @@ -361,8 +361,7 @@ def __init__(self,
self.flatten = _flatten_dense_tensors
self.unflatten = _unflatten_dense_tensors

if self._config.compile_config.enabled:
self._set_client_model(CompiledModuleWrapper(self.module, self._config.compile_config))
self._is_compiled = False

def destroy(self):
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
Expand Down Expand Up @@ -3604,3 +3603,20 @@ def empty_partition_cache(self):
self.optimizer.empty_partition_cache()
gc.collect()
get_accelerator().empty_cache()

def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}) -> None:
"""Compile the module using the specified backend and kwargs.
If a compiler_fn is set, it will be used instead of torch.compile().
"""
if not is_compile_supported():
raise RuntimeError("compile is not supported in your version of PyTorch.")

if self.is_compiled:
return

self.module.compile(backend=backend, **compile_kwargs)
self._is_compiled = True

@property
def is_compiled(self) -> bool:
return self._is_compiled
10 changes: 1 addition & 9 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,15 +933,7 @@ def __init__(
_ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path,
mpu) if config_dict_or_path is not None else None
if _ds_config is not None:
if _ds_config.zero_config.memory_efficient_linear and _ds_config.compile_config.enabled:
# memory_efficient_linear displays numerous errors when torch.compile is enabled.
# Refer to https://github.com/pytorch/pytorch/issues/119059 for details.
# Further investigation into performance is necessary, even after resolving this issue because
# the `memory_efficient_linear` module may lead to more graph breaks compared to the original implementation.
logger.warning(f'memory_efficient_linear is disabled when torch.compile is enabled.')
mem_efficient_linear = False
else:
mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear
mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear

super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype)
if not dist.is_initialized():
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,13 @@ def _launch_non_daemonic_procs(self, num_procs):
master_port = get_master_port()
skip_msg = mp.Queue() # Allows forked processes to share pytest.skip reason
processes = []
prev_start_method = mp.get_start_method()
mp.set_start_method('spawn', force=True)
for local_rank in range(num_procs):
p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg))
p.start()
processes.append(p)
mp.set_start_method(prev_start_method, force=True)

# Now loop and wait for a test to complete. The spin-wait here isn't a big
# deal because the number of processes will be O(#GPUs) << O(#CPUs).
Expand Down
85 changes: 0 additions & 85 deletions tests/unit/runtime/compile/test_compile_wrapper.py

This file was deleted.

4 changes: 0 additions & 4 deletions tests/unit/runtime/compile/test_compile_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device):
},
"zero_optimization": {
"stage": zero_stage,
},
"compile": {
"enabled": True,
"backend": get_accelerator().get_compile_backend()
}
}

Expand Down
Loading

0 comments on commit 2a0c0e3

Please sign in to comment.