Leverage warm pools for AWS Sagemaker #3027

Merged
38 changes: 37 additions & 1 deletion docs/book/component-guide/orchestrators/sagemaker.md
@@ -197,12 +197,44 @@ They can then be applied to a step as follows:
@step(settings={"orchestrator": sagemaker_orchestrator_settings})
```

For example, if your ZenML component is configured to use `ml.c5.xlarge` with 400GB additional storage by default, all steps will use it except for the step above, which will use `ml.t3.medium` with 30GB additional storage.
For example, if your ZenML component is configured to use `ml.c5.xlarge` with 400GB additional storage by default, all steps will use it except for the step above, which will use `ml.t3.medium` (for Processing Steps) or `ml.m5.xlarge` (for Training Steps) with 30GB additional storage. See the next section for details on how ZenML decides which Sagemaker Step type to use.
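
As a rough sketch (assuming the import path below; `my_step` and the values are illustrative), such a step-level override could look like this:

```python
from zenml import step
from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
    SagemakerOrchestratorSettings,
)

# Step-level override: this step requests a smaller instance and less
# storage than the component-level defaults.
sagemaker_orchestrator_settings = SagemakerOrchestratorSettings(
    instance_type="ml.t3.medium",
    volume_size_in_gb=30,
)

@step(settings={"orchestrator": sagemaker_orchestrator_settings})
def my_step() -> None:
    ...
```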

Check out [this docs page](../../how-to/use-configuration-files/runtime-configuration.md) for more information on how to specify settings in general.

For more information and a full list of configurable attributes of the Sagemaker orchestrator, check out the [SDK Docs](https://sdkdocs.zenml.io/latest/integration\_code\_docs/integrations-aws/#zenml.integrations.aws.orchestrators.sagemaker\_orchestrator.SagemakerOrchestrator).

### Using Warm Pools for your pipelines

[Warm Pools in SageMaker](https://docs.aws.amazon.com/sagemaker/latest/dg/train-warm-pools.html) can significantly reduce the startup time of your pipeline steps, leading to faster iterations and improved development efficiency. This feature keeps compute instances in a "warm" state, ready to quickly start new jobs.

To enable Warm Pools, use the `SagemakerOrchestratorSettings` class:

```python
sagemaker_orchestrator_settings = SagemakerOrchestratorSettings(
    keep_alive_period_in_seconds=300,  # 5 minutes, the default value
)
```

This configuration keeps instances warm for 5 minutes after each job completes, allowing subsequent jobs to start faster if initiated within this timeframe. The reduced startup time can be particularly beneficial for iterative development processes or frequently run pipelines.

If you prefer not to use Warm Pools, you can explicitly disable them:

```python
sagemaker_orchestrator_settings = SagemakerOrchestratorSettings(
    keep_alive_period_in_seconds=None,
)
```

By default, the SageMaker orchestrator uses Training Steps where possible, which can offer performance benefits and better integration with SageMaker's training capabilities. To disable this behavior:

```python
sagemaker_orchestrator_settings = SagemakerOrchestratorSettings(
    use_training_step=False,
)
```

These settings let you fine-tune your SageMaker orchestrator configuration, balancing the faster startup times offered by Warm Pools against tighter control over resource usage. Optimizing them can reduce overall pipeline runtime and improve the efficiency of your development workflow.
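
For instance, a combined configuration applied at the pipeline level might look like the following sketch (the keep-alive value and `my_pipeline` are illustrative):

```python
from zenml import pipeline
from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
    SagemakerOrchestratorSettings,
)

sagemaker_orchestrator_settings = SagemakerOrchestratorSettings(
    keep_alive_period_in_seconds=600,  # keep instances warm for 10 minutes
    use_training_step=True,  # the default, spelled out for clarity
)

@pipeline(settings={"orchestrator": sagemaker_orchestrator_settings})
def my_pipeline() -> None:
    ...
```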

#### S3 data access in ZenML steps

In Sagemaker jobs, it is possible to [access data that is located in S3](https://docs.aws.amazon.com/sagemaker/latest/dg/model-access-training-data.html). Similarly, it is possible to write data from a job to an S3 bucket. The ZenML Sagemaker orchestrator supports this via the `SagemakerOrchestratorSettings`, and hence at the component, pipeline, and step levels.
@@ -266,6 +298,10 @@ sagemaker_orchestrator_settings = SagemakerOrchestratorSettings(
)
```

{% hint style="warning" %}
Using multichannel output or an output mode other than `EndOfJob` makes it impossible to use TrainingStep and, by extension, Warm Pools. See the corresponding section of this document for details.
{% endhint %}
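
For example, the following sketch (bucket names are hypothetical) contrasts a single-channel output, which keeps TrainingStep and Warm Pools available, with a multichannel output, which forces a fallback to ProcessingStep:

```python
from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
    SagemakerOrchestratorSettings,
)

# Single-channel output with the default "EndOfJob" mode:
# TrainingStep (and therefore Warm Pools) remain usable.
single_channel_settings = SagemakerOrchestratorSettings(
    output_data_s3_uri="s3://your-bucket-name/outputs",
    output_data_s3_mode="EndOfJob",
)

# Multichannel output: ZenML must fall back to ProcessingStep,
# so Warm Pools are unavailable for these steps.
multi_channel_settings = SagemakerOrchestratorSettings(
    output_data_s3_uri={
        "metrics": "s3://your-bucket-name/metrics",
        "reports": "s3://your-bucket-name/reports",
    },
    use_training_step=False,  # required when the output is a dict
)
```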

### Enabling CUDA for GPU-backed hardware

Note that if you wish to use this orchestrator to run steps on a GPU, you will need to follow [the instructions on this page](../../how-to/training-with-gpus/training-with-gpus.md) to ensure that it works. This requires some additional settings customization and is essential for enabling CUDA so that the GPU can deliver its full acceleration.
@@ -15,7 +15,7 @@

from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union

from pydantic import Field
from pydantic import Field, model_validator

from zenml.config.base_settings import BaseSettings
from zenml.integrations.aws import (
@@ -25,23 +25,38 @@
from zenml.models import ServiceConnectorRequirements
from zenml.orchestrators import BaseOrchestratorConfig
from zenml.orchestrators.base_orchestrator import BaseOrchestratorFlavor
from zenml.utils import deprecation_utils
from zenml.utils.secret_utils import SecretField

if TYPE_CHECKING:
from zenml.integrations.aws.orchestrators import SagemakerOrchestrator

DEFAULT_TRAINING_INSTANCE_TYPE = "ml.m5.xlarge"
DEFAULT_PROCESSING_INSTANCE_TYPE = "ml.t3.medium"
DEFAULT_OUTPUT_DATA_S3_MODE = "EndOfJob"


class SagemakerOrchestratorSettings(BaseSettings):
"""Settings for the Sagemaker orchestrator.

Attributes:
instance_type: The instance type to use for the processing job.
processor_role: The IAM role to use for the step execution on a Processor.
execution_role: The IAM role to use for the step execution.
processor_role: DEPRECATED: use `execution_role` instead.
volume_size_in_gb: The size of the EBS volume to use for the processing
job.
max_runtime_in_seconds: The maximum runtime in seconds for the
processing job.
processor_tags: Tags to apply to the Processor assigned to the step.
tags: Tags to apply to the Processor/Estimator assigned to the step.
processor_tags: DEPRECATED: use `tags` instead.
keep_alive_period_in_seconds: The time in seconds after which the
provisioned instance is terminated if not used. This only applies
to the TrainingStep type, which cannot be used if
`output_data_s3_uri` is set to a Dict[str, str].
use_training_step: Whether to use the TrainingStep type. The
TrainingStep type cannot be used if `output_data_s3_uri` is set
to a Dict[str, str] or if
`output_data_s3_mode` != "EndOfJob".
processor_args: Arguments that are directly passed to the SageMaker
Processor for a specific step, allowing for overriding the default
settings provided when configuring the component. See
@@ -50,6 +65,13 @@ class SagemakerOrchestratorSettings(BaseSettings):
For processor_args.instance_type, check
https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
for a list of available instance types.
estimator_args: Arguments that are directly passed to the SageMaker
Estimator for a specific step, allowing for overriding the default
settings provided when configuring the component. See
https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator
for a full list of arguments.
For a list of available instance types, check
https://docs.aws.amazon.com/sagemaker/latest/dg/cmn-info-instance-types.html.
input_data_s3_mode: How data is made available to the container.
Two possible input modes: File, Pipe.
input_data_s3_uri: S3 URI where data is located if not locally,
@@ -74,23 +96,70 @@
Data must be available locally in /opt/ml/processing/output/data/<ChannelName>.
"""

instance_type: str = "ml.t3.medium"
processor_role: Optional[str] = None
instance_type: Optional[str] = None
execution_role: Optional[str] = None
volume_size_in_gb: int = 30
max_runtime_in_seconds: int = 86400
processor_tags: Dict[str, str] = {}
tags: Dict[str, str] = {}
keep_alive_period_in_seconds: Optional[int] = 300 # 5 minutes
use_training_step: Optional[bool] = None

processor_args: Dict[str, Any] = {}
estimator_args: Dict[str, Any] = {}

input_data_s3_mode: str = "File"
input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
default=None, union_mode="left_to_right"
)

output_data_s3_mode: str = "EndOfJob"
output_data_s3_mode: str = DEFAULT_OUTPUT_DATA_S3_MODE
output_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
default=None, union_mode="left_to_right"
)

processor_role: Optional[str] = None
processor_tags: Dict[str, str] = {}
_deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
("processor_role", "execution_role"), ("processor_tags", "tags")
)

@model_validator(mode="before")
def validate_model(cls, data: Dict[str, Any]) -> Dict[str, Any]:
"""Check if model is configured correctly.

Args:
data: The model data.

Returns:
The validated model data.

Raises:
ValueError: If the model is configured incorrectly.
"""
use_training_step = data.get("use_training_step", True)
output_data_s3_uri = data.get("output_data_s3_uri", None)
output_data_s3_mode = data.get(
"output_data_s3_mode", DEFAULT_OUTPUT_DATA_S3_MODE
)
if use_training_step and (
isinstance(output_data_s3_uri, dict)
or (
isinstance(output_data_s3_uri, str)
and (output_data_s3_mode != DEFAULT_OUTPUT_DATA_S3_MODE)
)
):
raise ValueError(
"`use_training_step=True` is not supported when `output_data_s3_uri` is a dict or "
f"when `output_data_s3_mode` is not '{DEFAULT_OUTPUT_DATA_S3_MODE}'."
)
instance_type = data.get("instance_type", None)
if instance_type is None:
if use_training_step:
data["instance_type"] = DEFAULT_TRAINING_INSTANCE_TYPE
else:
data["instance_type"] = DEFAULT_PROCESSING_INSTANCE_TYPE
return data
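
# Illustrative sketch (not part of the diff): how the validator above
# behaves for a few configurations.
#
#   SagemakerOrchestratorSettings()
#       -> use_training_step is treated as True; instance_type is filled
#          in as DEFAULT_TRAINING_INSTANCE_TYPE ("ml.m5.xlarge").
#   SagemakerOrchestratorSettings(use_training_step=False)
#       -> instance_type is filled in as
#          DEFAULT_PROCESSING_INSTANCE_TYPE ("ml.t3.medium").
#   SagemakerOrchestratorSettings(
#       output_data_s3_uri={"metrics": "s3://bucket/metrics"},
#   )
#       -> raises ValueError, because a multichannel output is
#          incompatible with the (implicit) TrainingStep type.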


class SagemakerOrchestratorConfig(
BaseOrchestratorConfig, SagemakerOrchestratorSettings
124 changes: 81 additions & 43 deletions src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
@@ -25,7 +25,7 @@
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.steps import ProcessingStep
from sagemaker.workflow.steps import ProcessingStep, TrainingStep

from zenml.config.base_settings import BaseSettings
from zenml.constants import (
@@ -238,47 +238,61 @@ def prepare_or_run_pipeline(
ExecutionVariables.PIPELINE_EXECUTION_ARN
)

# Retrieve Processor arguments provided in the Step settings.
processor_args_for_step = step_settings.processor_args or {}
use_training_step = (
step_settings.use_training_step
if step_settings.use_training_step is not None
else (
self.config.use_training_step
if self.config.use_training_step is not None
else True
)
)
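# Resolution order (sketch): the step-level setting wins, then the
# component-level config, then a default of True. For example, a step
# with use_training_step=False runs as a ProcessingStep even if the
# component leaves the option unset.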

# Retrieve Executor arguments provided in the Step settings.
if use_training_step:
args_for_step_executor = step_settings.estimator_args or {}
else:
args_for_step_executor = step_settings.processor_args or {}

# Set default values from configured orchestrator Component to arguments
# to be used when they are not present in processor_args.
processor_args_for_step.setdefault(
"instance_type", step_settings.instance_type
)
processor_args_for_step.setdefault(
args_for_step_executor.setdefault(
"role",
step_settings.processor_role or self.config.execution_role,
step_settings.execution_role or self.config.execution_role,
)
processor_args_for_step.setdefault(
args_for_step_executor.setdefault(
"volume_size_in_gb", step_settings.volume_size_in_gb
)
processor_args_for_step.setdefault(
args_for_step_executor.setdefault(
"max_runtime_in_seconds", step_settings.max_runtime_in_seconds
)
processor_args_for_step.setdefault(
tags = step_settings.tags
args_for_step_executor.setdefault(
"tags",
[
{"Key": key, "Value": value}
for key, value in step_settings.processor_tags.items()
]
if step_settings.processor_tags
else None,
(
[
{"Key": key, "Value": value}
for key, value in tags.items()
]
if tags
else None
),
)
args_for_step_executor.setdefault(
"instance_type", step_settings.instance_type
)

# Set values that cannot be overwritten
processor_args_for_step["image_uri"] = image
processor_args_for_step["instance_count"] = 1
processor_args_for_step["sagemaker_session"] = session
processor_args_for_step["entrypoint"] = entrypoint
processor_args_for_step["base_job_name"] = orchestrator_run_name
processor_args_for_step["env"] = environment
args_for_step_executor["image_uri"] = image
args_for_step_executor["instance_count"] = 1
args_for_step_executor["sagemaker_session"] = session
args_for_step_executor["base_job_name"] = orchestrator_run_name

# Convert network_config to sagemaker.network.NetworkConfig if present
network_config = processor_args_for_step.get("network_config")
network_config = args_for_step_executor.get("network_config")
if network_config and isinstance(network_config, dict):
try:
processor_args_for_step["network_config"] = NetworkConfig(
args_for_step_executor["network_config"] = NetworkConfig(
**network_config
)
except TypeError:
@@ -317,17 +331,21 @@ def prepare_or_run_pipeline(

# Construct S3 outputs from container for step
outputs = None
output_path = None

if step_settings.output_data_s3_uri is None:
pass
elif isinstance(step_settings.output_data_s3_uri, str):
outputs = [
ProcessingOutput(
source="/opt/ml/processing/output/data",
destination=step_settings.output_data_s3_uri,
s3_upload_mode=step_settings.output_data_s3_mode,
)
]
if use_training_step:
output_path = step_settings.output_data_s3_uri
else:
outputs = [
ProcessingOutput(
source="/opt/ml/processing/output/data",
destination=step_settings.output_data_s3_uri,
s3_upload_mode=step_settings.output_data_s3_mode,
)
]
elif isinstance(step_settings.output_data_s3_uri, dict):
outputs = []
for (
@@ -342,17 +360,37 @@
)
)

# Create Processor and ProcessingStep
processor = sagemaker.processing.Processor(
**processor_args_for_step
)
sagemaker_step = ProcessingStep(
name=step_name,
processor=processor,
depends_on=step.spec.upstream_steps,
inputs=inputs,
outputs=outputs,
)
if use_training_step:
# Create Estimator and TrainingStep
estimator = sagemaker.estimator.Estimator(
keep_alive_period_in_seconds=step_settings.keep_alive_period_in_seconds,
output_path=output_path,
environment=environment,
container_entry_point=entrypoint,
**args_for_step_executor,
)
sagemaker_step = TrainingStep(
name=step_name,
depends_on=step.spec.upstream_steps,
inputs=inputs,
estimator=estimator,
)
else:
# Create Processor and ProcessingStep
processor = sagemaker.processing.Processor(
entrypoint=entrypoint,
env=environment,
**args_for_step_executor,
)

sagemaker_step = ProcessingStep(
name=step_name,
processor=processor,
depends_on=step.spec.upstream_steps,
inputs=inputs,
outputs=outputs,
)

sagemaker_steps.append(sagemaker_step)

# construct the pipeline from the sagemaker_steps