From eb32d8822cb60302fb1616824fafba151d5133a3 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 30 Oct 2023 08:43:42 +0100 Subject: [PATCH] Fix sagemaker orchestrator and step operator env vars and other minor bugs (#1993) * Support union for flavor schema properties in the CLI * Fix AWS container registry check for existing repos * Split large env vars into chunks for the sagemaker orchestrator * Fix docstring * Move env var splitting to its own utils module * Split env vars passed to the sagemaker step operator * Fix errors * Add unit tests for env splitting utilities * Actually adding the unit tests --- src/zenml/cli/utils.py | 11 +- .../aws_container_registry.py | 4 +- .../orchestrators/sagemaker_orchestrator.py | 24 ++++- ...agemaker_orchestrator_entrypoint_config.py | 39 +++++++ .../step_operators/sagemaker_step_operator.py | 29 +++++ ...gemaker_step_operator_entrypoint_config.py | 39 +++++++ src/zenml/orchestrators/step_launcher.py | 9 +- .../step_operators/base_step_operator.py | 18 ++++ src/zenml/utils/env_utils.py | 100 ++++++++++++++++++ tests/unit/utils/test_env_utils.py | 72 +++++++++++++ 10 files changed, 332 insertions(+), 13 deletions(-) create mode 100644 src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator_entrypoint_config.py create mode 100644 src/zenml/integrations/aws/step_operators/sagemaker_step_operator_entrypoint_config.py create mode 100644 src/zenml/utils/env_utils.py create mode 100644 tests/unit/utils/test_env_utils.py diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index 05bfda802a7..963fc93cddd 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -1324,8 +1324,17 @@ def describe_pydantic_object(schema_json: Dict[str, Any]) -> None: warning("Properties", bold=True) for prop, prop_schema in properties.items(): if "$ref" not in prop_schema.keys(): + if "type" in prop_schema.keys(): + prop_type = prop_schema["type"] + elif "anyOf" in prop_schema.keys(): + prop_type = ", ".join( + [p["type"] for 
p in prop_schema["anyOf"]] + ) + prop_type = f"one of: {prop_type}" + else: + prop_type = "object" warning( - f"{prop}, {prop_schema['type']}" + f"{prop}, {prop_type}" f"{', REQUIRED' if prop in required else ''}" ) diff --git a/src/zenml/integrations/aws/container_registries/aws_container_registry.py b/src/zenml/integrations/aws/container_registries/aws_container_registry.py index b4042979c04..0867570794a 100644 --- a/src/zenml/integrations/aws/container_registries/aws_container_registry.py +++ b/src/zenml/integrations/aws/container_registries/aws_container_registry.py @@ -17,7 +17,7 @@ from typing import List, Optional, cast import boto3 -from botocore.exceptions import ClientError, NoCredentialsError +from botocore.exceptions import BotoCoreError, ClientError from zenml.container_registries.base_container_registry import ( BaseContainerRegistry, @@ -80,7 +80,7 @@ def prepare_image_push(self, image_name: str) -> None: response = boto3.client( "ecr", region_name=self._get_region() ).describe_repositories() - except NoCredentialsError: + except (BotoCoreError, ClientError): logger.warning( "Amazon ECR requires you to create a repository before you can " f"push an image to it. 
ZenML is trying to push the image " diff --git a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py index febe48895ad..481be263a8f 100644 --- a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +++ b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py @@ -30,17 +30,21 @@ from zenml.constants import ( METADATA_ORCHESTRATOR_URL, ) -from zenml.entrypoints import StepEntrypointConfiguration from zenml.enums import StackComponentType from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import ( SagemakerOrchestratorConfig, SagemakerOrchestratorSettings, ) +from zenml.integrations.aws.orchestrators.sagemaker_orchestrator_entrypoint_config import ( + SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT, + SagemakerEntrypointConfiguration, +) from zenml.logger import get_logger from zenml.metadata.metadata_types import MetadataType, Uri from zenml.orchestrators import ContainerizedOrchestrator from zenml.orchestrators.utils import get_orchestrator_run_name from zenml.stack import StackValidator +from zenml.utils.env_utils import split_environment_variables if TYPE_CHECKING: from zenml.models.pipeline_deployment_models import ( @@ -206,12 +210,24 @@ def prepare_or_run_pipeline( boto_session=boto_session, default_bucket=self.config.bucket ) + # Sagemaker does not allow environment variables longer than 256 + # characters to be passed to Processor steps. If an environment variable + # is longer than 256 characters, we split it into multiple environment + # variables (chunks) and re-construct it on the other side using the + # custom entrypoint configuration. 
+ split_environment_variables( + size_limit=SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT, + env=environment, + ) + sagemaker_steps = [] for step_name, step in deployment.step_configurations.items(): image = self.get_image(deployment=deployment, step_name=step_name) - command = StepEntrypointConfiguration.get_entrypoint_command() - arguments = StepEntrypointConfiguration.get_entrypoint_arguments( - step_name=step_name, deployment_id=deployment.id + command = SagemakerEntrypointConfiguration.get_entrypoint_command() + arguments = ( + SagemakerEntrypointConfiguration.get_entrypoint_arguments( + step_name=step_name, deployment_id=deployment.id + ) ) entrypoint = command + arguments diff --git a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator_entrypoint_config.py b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator_entrypoint_config.py new file mode 100644 index 00000000000..83c3d3ee94d --- /dev/null +++ b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator_entrypoint_config.py @@ -0,0 +1,39 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
"""Entrypoint configuration for ZenML Sagemaker pipeline steps."""

from zenml.entrypoints.step_entrypoint_configuration import (
    StepEntrypointConfiguration,
)
from zenml.utils.env_utils import reconstruct_environment_variables

# Maximum length of a single environment variable value that is passed to a
# Sagemaker Processor step; longer values are split into chunks by the
# orchestrator before submission.
SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT = 256


class SagemakerEntrypointConfiguration(StepEntrypointConfiguration):
    """Step entrypoint used inside Sagemaker Processor containers.

    Behaves exactly like the regular step entrypoint configuration, except
    that the chunked environment variables are merged back together before
    the step is executed. Chunking is needed because values longer than 256
    characters are not accepted for Sagemaker Processor steps.
    """

    def run(self) -> None:
        """Restore the chunked environment variables, then run the step."""
        # Must happen first: the step itself expects to find the original,
        # un-chunked variables in the process environment.
        reconstruct_environment_variables()

        super().run()
SagemakerStepOperatorSettings + @property + def entrypoint_config_class( + self, + ) -> Type[StepOperatorEntrypointConfiguration]: + """Returns the entrypoint configuration class for this step operator. + + Returns: + The entrypoint configuration class for this step operator. + """ + return SagemakerEntrypointConfiguration + @property def validator(self) -> Optional[StackValidator]: """Validates the stack. @@ -159,6 +178,16 @@ def launch( self.name, ) + # Sagemaker does not allow environment variables longer than 512 + # characters to be passed to Estimator steps. If an environment variable + # is longer than 512 characters, we split it into multiple environment + # variables (chunks) and re-construct it on the other side using the + # custom entrypoint configuration. + split_environment_variables( + env=environment, + size_limit=SAGEMAKER_ESTIMATOR_STEP_ENV_VAR_SIZE_LIMIT, + ) + image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY) environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command) diff --git a/src/zenml/integrations/aws/step_operators/sagemaker_step_operator_entrypoint_config.py b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator_entrypoint_config.py new file mode 100644 index 00000000000..a06260708e4 --- /dev/null +++ b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator_entrypoint_config.py @@ -0,0 +1,39 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
"""Entrypoint configuration for ZenML Sagemaker step operator."""

from zenml.step_operators.step_operator_entrypoint_configuration import (
    StepOperatorEntrypointConfiguration,
)
from zenml.utils.env_utils import reconstruct_environment_variables

# Maximum length of a single environment variable value that is passed to a
# Sagemaker Estimator step; longer values are split into chunks by the step
# operator before submission.
SAGEMAKER_ESTIMATOR_STEP_ENV_VAR_SIZE_LIMIT = 512


class SagemakerEntrypointConfiguration(StepOperatorEntrypointConfiguration):
    """Step operator entrypoint used inside Sagemaker Estimator containers.

    Behaves exactly like the regular step operator entrypoint configuration,
    except that the chunked environment variables are merged back together
    before the step is executed. Chunking is needed because values longer
    than 512 characters are not accepted for Sagemaker Estimator steps.
    """

    def run(self) -> None:
        """Restore the chunked environment variables, then run the step."""
        # Must happen first: the step itself expects to find the original,
        # un-chunked variables in the process environment.
        reconstruct_environment_variables()

        super().run()
""" - from zenml.step_operators.step_operator_entrypoint_configuration import ( - StepOperatorEntrypointConfiguration, - ) - step_operator = _get_step_operator( stack=self._stack, step_operator_name=step_operator_name, ) + entrypoint_cfg_class = step_operator.entrypoint_config_class entrypoint_command = ( - StepOperatorEntrypointConfiguration.get_entrypoint_command() - + StepOperatorEntrypointConfiguration.get_entrypoint_arguments( + entrypoint_cfg_class.get_entrypoint_command() + + entrypoint_cfg_class.get_entrypoint_arguments( step_name=self._step_name, deployment_id=self._deployment.id, step_run_id=str(step_run_info.step_run_id), diff --git a/src/zenml/step_operators/base_step_operator.py b/src/zenml/step_operators/base_step_operator.py index 421443bf02b..5c7bfbe8e7a 100644 --- a/src/zenml/step_operators/base_step_operator.py +++ b/src/zenml/step_operators/base_step_operator.py @@ -20,6 +20,9 @@ from zenml.logger import get_logger from zenml.stack import Flavor, StackComponent from zenml.stack.stack_component import StackComponentConfig +from zenml.step_operators.step_operator_entrypoint_configuration import ( + StepOperatorEntrypointConfiguration, +) if TYPE_CHECKING: from zenml.config.step_run_info import StepRunInfo @@ -43,6 +46,21 @@ def config(self) -> BaseStepOperatorConfig: """ return cast(BaseStepOperatorConfig, self._config) + @property + def entrypoint_config_class( + self, + ) -> Type[StepOperatorEntrypointConfiguration]: + """Returns the entrypoint configuration class for this step operator. + + Concrete step operator implementations may override this property + to return a custom entrypoint configuration class if they need to + customize the entrypoint configuration. + + Returns: + The entrypoint configuration class for this step operator. 
# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Utility functions for handling environment variables."""
import os
from typing import Dict, List, Optional, cast

ENV_VAR_CHUNK_SUFFIX = "_CHUNK_"

# A value is split into at most this many chunks. Keeping the count at ten
# means every chunk index is a single decimal digit, which is what the
# reconstruction relies on to recognise chunk variables.
_MAX_CHUNKS = 10
_CHUNK_INDEX_DIGITS = "0123456789"


def split_environment_variables(
    size_limit: int,
    env: Optional[Dict[str, str]] = None,
) -> None:
    """Split long environment variables into chunks.

    Every variable whose value is longer than `size_limit` is removed and
    replaced by the variables `<NAME>_CHUNK_0`, `<NAME>_CHUNK_1`, ... each
    holding at most `size_limit` characters of the original value. The input
    environment variables are modified in-place.

    All values are validated before anything is modified, so the environment
    is left untouched when a `RuntimeError` is raised.

    Args:
        size_limit: Maximum length of an environment variable value.
        env: Input environment variables dictionary. If not supplied, the
            OS environment variables are used.

    Raises:
        RuntimeError: If an environment variable value is too large and
            requires more than 10 chunks.
    """
    if env is None:
        env = cast(Dict[str, str], os.environ)

    max_length = size_limit * _MAX_CHUNKS

    # Work on a snapshot: the mapping is modified further down.
    oversized = {
        key: value for key, value in list(env.items()) if len(value) > size_limit
    }

    # Validate everything up front so that a failure does not leave the
    # environment half-split.
    for key, value in oversized.items():
        if len(value) > max_length:
            raise RuntimeError(
                f"Environment variable {key} exceeds the maximum length of "
                f"{max_length} characters."
            )

    for key, value in oversized.items():
        del env[key]
        for index, start in enumerate(range(0, len(value), size_limit)):
            env[f"{key}{ENV_VAR_CHUNK_SUFFIX}{index}"] = value[
                start : start + size_limit
            ]


def reconstruct_environment_variables(
    env: Optional[Dict[str, str]] = None
) -> None:
    """Reconstruct environment variables that were split into chunks.

    Merges the `<NAME>_CHUNK_<digit>` variables created by
    `split_environment_variables` back into a single `<NAME>` variable and
    removes the chunk variables. The input environment variables are
    modified in-place.

    A group of variables is only merged if its chunk indices are single
    decimal digits forming the complete sequence 0, 1, ..., n-1. Variables
    that merely happen to end in `_CHUNK_<something>` (a non-digit suffix, a
    gap in the sequence, or an empty base name) are left untouched.

    Args:
        env: Input environment variables dictionary. If not supplied, the OS
            environment variables are used.
    """
    if env is None:
        env = cast(Dict[str, str], os.environ)

    # Maps the original variable name to {chunk index: chunk variable name}.
    chunks: Dict[str, Dict[int, str]] = {}
    for key in list(env):
        original_key, separator, index = key.rpartition(ENV_VAR_CHUNK_SUFFIX)
        if not separator or not original_key:
            continue
        if len(index) != 1 or index not in _CHUNK_INDEX_DIGITS:
            continue
        chunks.setdefault(original_key, {})[int(index)] = key

    for original_key, indexed_keys in chunks.items():
        expected_indices = list(range(len(indexed_keys)))
        if sorted(indexed_keys) != expected_indices:
            # Incomplete sequence: this was not produced by the splitter.
            continue

        ordered_keys: List[str] = [indexed_keys[i] for i in expected_indices]
        env[original_key] = "".join(env[chunk_key] for chunk_key in ordered_keys)

        # Remove the chunk environment variables
        for chunk_key in ordered_keys:
            del env[chunk_key]

from contextlib import ExitStack as does_not_raise

import pytest

from zenml.utils.env_utils import (
    reconstruct_environment_variables,
    split_environment_variables,
)


def test_split_reconstruct_large_env_vars():
    """Splitting and then reconstructing must round-trip the environment."""
    original = {
        "ARIA_TEST_ENV_VAR": "aria",
        "AXL_TEST_ENV_VAR": "axl is gray and puffy",
        "BLUPUS_TEST_ENV_VAR": "blupus",
    }
    # Values of up to four characters stay as they are; longer ones are
    # replaced by numbered four-character chunks.
    expected_after_split = {
        "ARIA_TEST_ENV_VAR": "aria",
        "AXL_TEST_ENV_VAR_CHUNK_0": "axl ",
        "AXL_TEST_ENV_VAR_CHUNK_1": "is g",
        "AXL_TEST_ENV_VAR_CHUNK_2": "ray ",
        "AXL_TEST_ENV_VAR_CHUNK_3": "and ",
        "AXL_TEST_ENV_VAR_CHUNK_4": "puff",
        "AXL_TEST_ENV_VAR_CHUNK_5": "y",
        "BLUPUS_TEST_ENV_VAR_CHUNK_0": "blup",
        "BLUPUS_TEST_ENV_VAR_CHUNK_1": "us",
    }

    env = dict(original)

    split_environment_variables(env=env, size_limit=4)
    assert env == expected_after_split

    reconstruct_environment_variables(env=env)
    assert env == original


def test_split_too_large_env_var_fails():
    """A value that would need more than ten chunks must be rejected."""
    # 35 characters: fits into ten chunks of four characters.
    splittable_env = {
        "ARIA_TEST_ENV_VAR": "aria",
        "AXL_TEST_ENV_VAR": "axl is gray and puffy and wonderful",
    }
    with does_not_raise():
        split_environment_variables(env=splittable_env, size_limit=4)

    # 52 characters: would need more than ten chunks of four characters.
    oversized_env = {
        "ARIA_TEST_ENV_VAR": "aria",
        "AXL_TEST_ENV_VAR": "axl is gray and puffy and wonderful and otherworldly",
    }
    with pytest.raises(RuntimeError):
        split_environment_variables(env=oversized_env, size_limit=4)