Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make run_with_accelerate a pythonic decorator #2943

Merged
merged 10 commits into from
Aug 27, 2024
192 changes: 97 additions & 95 deletions src/zenml/integrations/huggingface/steps/accelerate_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
"""Step function to run any ZenML step using Accelerate."""

import functools
import inspect
from typing import Any, Callable, Dict, TypeVar, cast
from typing import Any, Callable, Dict, Optional, TypeVar, Union, cast

import cloudpickle as pickle
from accelerate.commands.launch import ( # type: ignore[import-untyped]
launch_command,
launch_command_parser,
)

from zenml import get_pipeline_context
from zenml.logger import get_logger
from zenml.steps import BaseStep
from zenml.utils.function_utils import _cli_arg_name, create_cli_wrapped_script
Expand All @@ -35,28 +35,31 @@


def run_with_accelerate(
step_function: BaseStep,
step_function_top_level: Optional[BaseStep] = None,
**accelerate_launch_kwargs: Any,
) -> BaseStep:
) -> Union[Callable[[BaseStep], BaseStep], BaseStep]:
"""Run a function with accelerate.

Accelerate package: https://huggingface.co/docs/accelerate/en/index
Example:
```python
from zenml import step, pipeline
from zenml.integrations.huggingface.steps import run_with_accelerate

@run_with_accelerate(num_processes=4, multi_gpu=True)
@step
def training_step(some_param: int, ...):
# your training code is below
...

@pipeline
def training_pipeline(some_param: int, ...):
run_with_accelerate(training_step, num_processes=4)(some_param, ...)
training_step(some_param, ...)
```

Args:
step_function: The step function to run.
step_function_top_level: The step function to run with accelerate [optional].
Used in functional calls like `run_with_accelerate(some_func,foo=bar)()`.
accelerate_launch_kwargs: A dictionary of arguments to pass along to the
`accelerate launch` command, including hardware selection, resource
allocation, and training paradigm options. Visit
Expand All @@ -65,100 +68,99 @@ def training_pipeline(some_param: int, ...):

Returns:
The accelerate-enabled version of the step.

Raises:
RuntimeError: If the decorator is misused.

"""

def _decorator(
entrypoint: F, accelerate_launch_kwargs: Dict[str, Any]
) -> F:
@functools.wraps(entrypoint)
def inner(*args: Any, **kwargs: Any) -> Any:
if args:
raise ValueError(
"Accelerated steps do not support positional arguments."
)

with create_cli_wrapped_script(
entrypoint, flavor="accelerate"
) as (
script_path,
output_path,
):
commands = [str(script_path.absolute())]
for k, v in kwargs.items():
k = _cli_arg_name(k)
if isinstance(v, bool):
if v:
commands.append(f"--{k}")
elif type(v) in (list, tuple, set):
for each in v:
commands += [f"--{k}", f"{each}"]
else:
commands += [f"--{k}", f"{v}"]
logger.debug(commands)

parser = launch_command_parser()
args = parser.parse_args(commands)
for k, v in accelerate_launch_kwargs.items():
if k in args:
setattr(args, k, v)
else:
logger.warning(
f"You passed in `{k}` as an `accelerate launch` argument, but it was not accepted. "
"Please check https://huggingface.co/docs/accelerate/en/package_reference/cli#accelerate-launch "
"to find out more about supported arguments and retry."
)
try:
launch_command(args)
except Exception as e:
logger.error(
"Accelerate training job failed... See error message for details."
def _decorator(step_function: BaseStep) -> BaseStep:
def _wrapper(
entrypoint: F, accelerate_launch_kwargs: Dict[str, Any]
) -> F:
@functools.wraps(entrypoint)
def inner(*args: Any, **kwargs: Any) -> Any:
if args:
raise ValueError(
"Accelerated steps do not support positional arguments."
)
raise RuntimeError(
"Accelerate training job failed."
) from e
else:
logger.info(
"Accelerate training job finished successfully."
)
return pickle.load(open(output_path, "rb"))

return cast(F, inner)

import __main__

if __main__.__file__ == inspect.getsourcefile(step_function.entrypoint):
raise RuntimeError(
f"`{run_with_accelerate.__name__}` decorator cannot be used "
"with steps defined inside the entrypoint script, please move "
f"your step `{step_function.name}` code to another file and retry."
with create_cli_wrapped_script(
entrypoint, flavor="accelerate"
) as (
script_path,
output_path,
):
commands = [str(script_path.absolute())]
for k, v in kwargs.items():
k = _cli_arg_name(k)
if isinstance(v, bool):
if v:
commands.append(f"--{k}")
elif type(v) in (list, tuple, set):
for each in v:
commands += [f"--{k}", f"{each}"]
else:
commands += [f"--{k}", f"{v}"]
logger.debug(commands)

parser = launch_command_parser()
args = parser.parse_args(commands)
for k, v in accelerate_launch_kwargs.items():
if k in args:
setattr(args, k, v)
else:
logger.warning(
f"You passed in `{k}` as an `accelerate launch` argument, but it was not accepted. "
"Please check https://huggingface.co/docs/accelerate/en/package_reference/cli#accelerate-launch "
"to find out more about supported arguments and retry."
)
try:
launch_command(args)
except Exception as e:
logger.error(
"Accelerate training job failed... See error message for details."
)
raise RuntimeError(
"Accelerate training job failed."
) from e
else:
logger.info(
"Accelerate training job finished successfully."
)
return pickle.load(open(output_path, "rb"))

return cast(F, inner)

try:
get_pipeline_context()
except RuntimeError:
pass
else:
raise RuntimeError(
f"`{run_with_accelerate.__name__}` decorator cannot be used "
"in a functional way with steps, please apply decoration "
"directly to a step instead. This behavior will be also "
"allowed in future, but now it faces technical limitations.\n"
"Example (allowed):\n"
f"@{run_with_accelerate.__name__}(...)\n"
f"def {step_function.name}(...):\n"
" ...\n"
"Example (not allowed):\n"
"def my_pipeline(...):\n"
f" run_with_accelerate({step_function.name},...)(...)\n"
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
)

setattr(
step_function, "unwrapped_entrypoint", step_function.entrypoint
)
if f"@{run_with_accelerate.__name__}" in inspect.getsource(
step_function.entrypoint
):
raise RuntimeError(
f"`{run_with_accelerate.__name__}` decorator cannot be used "
"directly on steps using '@' syntax, please use a functional "
"decoration in your pipeline script instead.\n"
"Example (not allowed):\n"
f"@{run_with_accelerate.__name__}\n"
f"def {step_function.name}(...):\n"
" ...\n"
"Example (allowed):\n"
"def my_pipeline(...):\n"
f" run_with_accelerate({step_function.name})(...)\n"
setattr(
step_function,
"entrypoint",
_wrapper(
step_function.entrypoint,
accelerate_launch_kwargs=accelerate_launch_kwargs,
),
)

setattr(
step_function,
"entrypoint",
_decorator(
step_function.entrypoint,
accelerate_launch_kwargs=accelerate_launch_kwargs,
),
)
return step_function

return step_function
if step_function_top_level:
return _decorator(step_function_top_level)
return _decorator
8 changes: 4 additions & 4 deletions src/zenml/utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@
import sys
sys.path.append(r"{func_path}")

from {func_module} import {func_name} as func_to_wrap
from {func_module} import {func_name} as step_function

if entrypoint:=getattr(func_to_wrap, "entrypoint", None):
func = _cli_wrapped_function(entrypoint)
if unwrapped_entrypoint:=getattr(step_function, "unwrapped_entrypoint", None):
func = _cli_wrapped_function(unwrapped_entrypoint)
else:
func = _cli_wrapped_function(func_to_wrap)
func = _cli_wrapped_function(step_function.entrypoint)
"""
_CLI_WRAPPED_MAINS = {
"accelerate": """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import shutil
from pathlib import Path

import pytest
import transformers
from accelerate import Accelerator
from datasets import load_from_disk
Expand Down Expand Up @@ -74,17 +75,20 @@ def get_full_path(folder: str):
return str(ft_model_dir)


@pipeline(enable_cache=False)
def train_pipe():
model_dir = run_with_accelerate(train, num_processes=2, use_cpu=True)()
# if it is StepArtifact, we are still composing the pipeline
if not isinstance(model_dir, StepArtifact):
assert isinstance(model_dir, str)
assert model_dir == "model_dir"
train_accelerated = run_with_accelerate(train, num_processes=2, use_cpu=True)


def test_accelerate_runner_on_cpu_with_toy_model(clean_client):
"""Tests whether the run_with_accelerate wrapper works as expected."""

@pipeline(enable_cache=False)
def train_pipe():
model_dir = train_accelerated()
# if it is StepArtifact, we are still composing the pipeline
if not isinstance(model_dir, StepArtifact):
assert isinstance(model_dir, str)
assert model_dir == "model_dir"

try:
prev_files = os.listdir()
response = train_pipe()
Expand All @@ -93,3 +97,14 @@ def test_accelerate_runner_on_cpu_with_toy_model(clean_client):
cur_files = os.listdir()
for each in set(cur_files) - set(prev_files):
shutil.rmtree(each)


def test_accelerate_runner_fails_on_functional_use(clean_client):
"""Tests whether the run_with_accelerate wrapper works as expected."""

@pipeline(enable_cache=False)
def train_pipe():
_ = run_with_accelerate(train, num_processes=2, use_cpu=True)

with pytest.raises(RuntimeError):
train_pipe()
Loading