Skip to content

Commit

Permalink
Fix sagemaker orchestrator and step operator env vars and other minor…
Browse files Browse the repository at this point in the history
… bugs (#1993)

* Support union for flavor schema properties in the CLI

* Fix AWS container registry check for existing repos

* Split large env vars into chunks for the sagemaker orchestrator

* Fix docstring

* Move env var splitting to its own utils module

* Split env vars passed to the sagemaker step operator

* Fix errors

* Add unit tests for env splitting utilities

* Actually adding the unit tests
  • Loading branch information
stefannica authored Oct 30, 2023
1 parent 80935c8 commit eb32d88
Show file tree
Hide file tree
Showing 10 changed files with 332 additions and 13 deletions.
11 changes: 10 additions & 1 deletion src/zenml/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,8 +1324,17 @@ def describe_pydantic_object(schema_json: Dict[str, Any]) -> None:
warning("Properties", bold=True)
for prop, prop_schema in properties.items():
if "$ref" not in prop_schema.keys():
if "type" in prop_schema.keys():
prop_type = prop_schema["type"]
elif "anyOf" in prop_schema.keys():
prop_type = ", ".join(
[p["type"] for p in prop_schema["anyOf"]]
)
prop_type = f"one of: {prop_type}"
else:
prop_type = "object"
warning(
f"{prop}, {prop_schema['type']}"
f"{prop}, {prop_type}"
f"{', REQUIRED' if prop in required else ''}"
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import List, Optional, cast

import boto3
from botocore.exceptions import ClientError, NoCredentialsError
from botocore.exceptions import BotoCoreError, ClientError

from zenml.container_registries.base_container_registry import (
BaseContainerRegistry,
Expand Down Expand Up @@ -80,7 +80,7 @@ def prepare_image_push(self, image_name: str) -> None:
response = boto3.client(
"ecr", region_name=self._get_region()
).describe_repositories()
except NoCredentialsError:
except (BotoCoreError, ClientError):
logger.warning(
"Amazon ECR requires you to create a repository before you can "
f"push an image to it. ZenML is trying to push the image "
Expand Down
24 changes: 20 additions & 4 deletions src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,21 @@
from zenml.constants import (
METADATA_ORCHESTRATOR_URL,
)
from zenml.entrypoints import StepEntrypointConfiguration
from zenml.enums import StackComponentType
from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
SagemakerOrchestratorConfig,
SagemakerOrchestratorSettings,
)
from zenml.integrations.aws.orchestrators.sagemaker_orchestrator_entrypoint_config import (
SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
SagemakerEntrypointConfiguration,
)
from zenml.logger import get_logger
from zenml.metadata.metadata_types import MetadataType, Uri
from zenml.orchestrators import ContainerizedOrchestrator
from zenml.orchestrators.utils import get_orchestrator_run_name
from zenml.stack import StackValidator
from zenml.utils.env_utils import split_environment_variables

if TYPE_CHECKING:
from zenml.models.pipeline_deployment_models import (
Expand Down Expand Up @@ -206,12 +210,24 @@ def prepare_or_run_pipeline(
boto_session=boto_session, default_bucket=self.config.bucket
)

# Sagemaker does not allow environment variables longer than 256
# characters to be passed to Processor steps. If an environment variable
# is longer than 256 characters, we split it into multiple environment
# variables (chunks) and re-construct it on the other side using the
# custom entrypoint configuration.
split_environment_variables(
size_limit=SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
env=environment,
)

sagemaker_steps = []
for step_name, step in deployment.step_configurations.items():
image = self.get_image(deployment=deployment, step_name=step_name)
command = StepEntrypointConfiguration.get_entrypoint_command()
arguments = StepEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name, deployment_id=deployment.id
command = SagemakerEntrypointConfiguration.get_entrypoint_command()
arguments = (
SagemakerEntrypointConfiguration.get_entrypoint_arguments(
step_name=step_name, deployment_id=deployment.id
)
)
entrypoint = command + arguments

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Entrypoint configuration for ZenML Sagemaker pipeline steps."""

from zenml.entrypoints.step_entrypoint_configuration import (
StepEntrypointConfiguration,
)
from zenml.utils.env_utils import reconstruct_environment_variables

SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT = 256


class SagemakerEntrypointConfiguration(StepEntrypointConfiguration):
"""Entrypoint configuration for ZenML Sagemaker pipeline steps.
The only purpose of this entrypoint configuration is to reconstruct the
environment variables that exceed the maximum length of 256 characters
allowed for Sagemaker Processor steps from their individual components.
"""

def run(self) -> None:
"""Runs the step."""
# Reconstruct the environment variables that exceed the maximum length
# of 256 characters from their individual chunks
reconstruct_environment_variables()

# Run the step
super().run()
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,17 @@
SagemakerStepOperatorConfig,
SagemakerStepOperatorSettings,
)
from zenml.integrations.aws.step_operators.sagemaker_step_operator_entrypoint_config import (
SAGEMAKER_ESTIMATOR_STEP_ENV_VAR_SIZE_LIMIT,
SagemakerEntrypointConfiguration,
)
from zenml.logger import get_logger
from zenml.stack import Stack, StackValidator
from zenml.step_operators import BaseStepOperator
from zenml.step_operators.step_operator_entrypoint_configuration import (
StepOperatorEntrypointConfiguration,
)
from zenml.utils.env_utils import split_environment_variables
from zenml.utils.string_utils import random_str

if TYPE_CHECKING:
Expand Down Expand Up @@ -67,6 +75,17 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]:
"""
return SagemakerStepOperatorSettings

@property
def entrypoint_config_class(
self,
) -> Type[StepOperatorEntrypointConfiguration]:
"""Returns the entrypoint configuration class for this step operator.
Returns:
The entrypoint configuration class for this step operator.
"""
return SagemakerEntrypointConfiguration

@property
def validator(self) -> Optional[StackValidator]:
"""Validates the stack.
Expand Down Expand Up @@ -159,6 +178,16 @@ def launch(
self.name,
)

# Sagemaker does not allow environment variables longer than 512
# characters to be passed to Estimator steps. If an environment variable
# is longer than 512 characters, we split it into multiple environment
# variables (chunks) and re-construct it on the other side using the
# custom entrypoint configuration.
split_environment_variables(
env=environment,
size_limit=SAGEMAKER_ESTIMATOR_STEP_ENV_VAR_SIZE_LIMIT,
)

image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Entrypoint configuration for ZenML Sagemaker step operator."""

from zenml.step_operators.step_operator_entrypoint_configuration import (
StepOperatorEntrypointConfiguration,
)
from zenml.utils.env_utils import reconstruct_environment_variables

SAGEMAKER_ESTIMATOR_STEP_ENV_VAR_SIZE_LIMIT = 512


class SagemakerEntrypointConfiguration(StepOperatorEntrypointConfiguration):
"""Entrypoint configuration for ZenML Sagemaker step operator.
The only purpose of this entrypoint configuration is to reconstruct the
environment variables that exceed the maximum length of 512 characters
allowed for Sagemaker Estimator steps from their individual components.
"""

def run(self) -> None:
"""Runs the step."""
# Reconstruct the environment variables that exceed the maximum length
# of 512 characters from their individual chunks
reconstruct_environment_variables()

# Run the step
super().run()
9 changes: 3 additions & 6 deletions src/zenml/orchestrators/step_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,17 +508,14 @@ def _run_step_with_step_operator(
step_operator_name: The name of the step operator to use.
step_run_info: Additional information needed to run the step.
"""
from zenml.step_operators.step_operator_entrypoint_configuration import (
StepOperatorEntrypointConfiguration,
)

step_operator = _get_step_operator(
stack=self._stack,
step_operator_name=step_operator_name,
)
entrypoint_cfg_class = step_operator.entrypoint_config_class
entrypoint_command = (
StepOperatorEntrypointConfiguration.get_entrypoint_command()
+ StepOperatorEntrypointConfiguration.get_entrypoint_arguments(
entrypoint_cfg_class.get_entrypoint_command()
+ entrypoint_cfg_class.get_entrypoint_arguments(
step_name=self._step_name,
deployment_id=self._deployment.id,
step_run_id=str(step_run_info.step_run_id),
Expand Down
18 changes: 18 additions & 0 deletions src/zenml/step_operators/base_step_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from zenml.logger import get_logger
from zenml.stack import Flavor, StackComponent
from zenml.stack.stack_component import StackComponentConfig
from zenml.step_operators.step_operator_entrypoint_configuration import (
StepOperatorEntrypointConfiguration,
)

if TYPE_CHECKING:
from zenml.config.step_run_info import StepRunInfo
Expand All @@ -43,6 +46,21 @@ def config(self) -> BaseStepOperatorConfig:
"""
return cast(BaseStepOperatorConfig, self._config)

@property
def entrypoint_config_class(
self,
) -> Type[StepOperatorEntrypointConfiguration]:
"""Returns the entrypoint configuration class for this step operator.
Concrete step operator implementations may override this property
to return a custom entrypoint configuration class if they need to
customize the entrypoint configuration.
Returns:
The entrypoint configuration class for this step operator.
"""
return StepOperatorEntrypointConfiguration

@abstractmethod
def launch(
self,
Expand Down
100 changes: 100 additions & 0 deletions src/zenml/utils/env_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Utility functions for handling environment variables."""
import os
from typing import Dict, List, Optional, cast

ENV_VAR_CHUNK_SUFFIX = "_CHUNK_"


def split_environment_variables(
size_limit: int,
env: Optional[Dict[str, str]] = None,
) -> None:
"""Split long environment variables into chunks.
Splits the input environment variables with values that exceed the supplied
maximum length into individual components. The input environment variables
are modified in-place.
Args:
size_limit: Maximum length of an environment variable value.
env: Input environment variables dictionary. If not supplied, the
OS environment variables are used.
Raises:
RuntimeError: If an environment variable value is too large and requires
more than 10 chunks.
"""
if env is None:
env = cast(Dict[str, str], os.environ)

for key, value in env.copy().items():
if len(value) <= size_limit:
continue

# We keep the number of chunks to a maximum of 10 to avoid generating
# too many environment variables chunks and also to make the
# reconstruction easier to implement
if len(value) > size_limit * 10:
raise RuntimeError(
f"Environment variable {key} exceeds the maximum length of "
f"{size_limit * 10} characters."
)

env.pop(key)

# Split the environment variable into chunks
chunks = [
value[i : i + size_limit] for i in range(0, len(value), size_limit)
]
for i, chunk in enumerate(chunks):
env[f"{key}{ENV_VAR_CHUNK_SUFFIX}{i}"] = chunk


def reconstruct_environment_variables(
env: Optional[Dict[str, str]] = None
) -> None:
"""Reconstruct environment variables that were split into chunks.
Reconstructs the environment variables with values that were split into
individual chunks because they were too large. The input environment
variables are modified in-place.
Args:
env: Input environment variables dictionary. If not supplied, the OS
environment variables are used.
"""
if env is None:
env = cast(Dict[str, str], os.environ)

chunks: Dict[str, List[str]] = {}
for key in env.keys():
if not key[:-1].endswith(ENV_VAR_CHUNK_SUFFIX):
continue

# Collect all chunks of the same environment variable
original_key = key[: -(len(ENV_VAR_CHUNK_SUFFIX) + 1)]
chunks.setdefault(original_key, [])
chunks[original_key].append(key)

# Reconstruct the environment variables from their chunks
for key, chunk_keys in chunks.items():
chunk_keys.sort()
value = "".join([env[key] for key in chunk_keys])
env[key] = value

# Remove the chunk environment variables
for key in chunk_keys:
env.pop(key)
Loading

0 comments on commit eb32d88

Please sign in to comment.