Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server-Side Component Config Validation #1988

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 51 additions & 81 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
EntityExistsError,
IllegalOperationError,
InitializationException,
StackComponentValidationError,
ValidationError,
ZenKeyError,
)
Expand Down Expand Up @@ -167,7 +166,7 @@
if TYPE_CHECKING:
from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum
from zenml.service_connectors.service_connector import ServiceConnector
from zenml.stack import Stack, StackComponentConfig
from zenml.stack import Stack
from zenml.zen_stores.base_zen_store import BaseZenStore

logger = get_logger(__name__)
Expand Down Expand Up @@ -1949,24 +1948,32 @@ def _validate_stack_configuration(
f"unregistered {component_type} with id "
f"'{component_id}'."
) from e
# Get the flavor model
flavor_model = self.get_flavor_by_name_and_type(
name=component.flavor, component_type=component.type
)

# Create and validate the configuration
from zenml.stack import Flavor

flavor = Flavor.from_model(flavor_model)
configuration = flavor.config_class(**component.configuration)
if configuration.is_local:
local_components.append(
f"{component.type.value}: {component.name}"
# Create and validate the configuration
from zenml.stack.utils import (
validate_stack_component_config,
warn_if_config_server_mismatch,
)
elif configuration.is_remote:
remote_components.append(
f"{component.type.value}: {component.name}"

configuration = validate_stack_component_config(
configuration_dict=component.configuration,
flavor_name=component.flavor,
component_type=component.type,
# Always enforce validation of custom flavors
validate_custom_flavors=True,
)
# Guaranteed to not be None by setting
# `validate_custom_flavors=True` above
assert configuration is not None
warn_if_config_server_mismatch(configuration)
if configuration.is_local:
local_components.append(
f"{component.type.value}: {component.name}"
)
elif configuration.is_remote:
remote_components.append(
f"{component.type.value}: {component.name}"
)

if local_components and remote_components:
logger.warning(
Expand Down Expand Up @@ -2142,23 +2149,22 @@ def create_stack_component(
Returns:
The model of the registered component.
"""
# Get the flavor model
flavor_model = self.get_flavor_by_name_and_type(
name=flavor,
component_type=component_type,
from zenml.stack.utils import (
validate_stack_component_config,
warn_if_config_server_mismatch,
)

# Create and validate the configuration
from zenml.stack import Flavor

flavor_class = Flavor.from_model(flavor_model)
configuration_obj = flavor_class.config_class(
warn_about_plain_text_secrets=True, **configuration
)

self._validate_stack_component_configuration(
component_type, configuration=configuration_obj
validated_config = validate_stack_component_config(
configuration_dict=configuration,
flavor_name=flavor,
component_type=component_type,
# Always enforce validation of custom flavors
validate_custom_flavors=True,
)
# Guaranteed to not be None by setting
# `validate_custom_flavors=True` above
assert validated_config is not None
warn_if_config_server_mismatch(validated_config)

create_component_model = ComponentRequestModel(
name=name,
Expand Down Expand Up @@ -2254,26 +2260,29 @@ def update_stack_component(
if configuration is not None:
existing_configuration = component.configuration
existing_configuration.update(configuration)

existing_configuration = {
k: v
for k, v in existing_configuration.items()
if v is not None
}

flavor_model = self.get_flavor_by_name_and_type(
name=component.flavor,
component_type=component.type,
from zenml.stack.utils import (
validate_stack_component_config,
warn_if_config_server_mismatch,
)

from zenml.stack import Flavor

flavor = Flavor.from_model(flavor_model)
configuration_obj = flavor.config_class(**existing_configuration)

self._validate_stack_component_configuration(
component.type, configuration=configuration_obj
validated_config = validate_stack_component_config(
configuration_dict=existing_configuration,
flavor_name=component.flavor,
component_type=component.type,
# Always enforce validation of custom flavors
validate_custom_flavors=True,
)
# Guaranteed to not be None by setting
# `validate_custom_flavors=True` above
assert validated_config is not None
warn_if_config_server_mismatch(validated_config)

update_model.configuration = existing_configuration

if labels is not None:
Expand Down Expand Up @@ -2320,45 +2329,6 @@ def delete_stack_component(
component.name,
)

def _validate_stack_component_configuration(
self,
component_type: "StackComponentType",
configuration: "StackComponentConfig",
) -> None:
"""Validates the configuration of a stack component.

Args:
component_type: The type of the component.
configuration: The component configuration to validate.

Raises:
StackComponentValidationError: in case the stack component configuration is invalid.
"""
from zenml.enums import StoreType

if configuration.is_remote and self.zen_store.is_local_store():
if self.zen_store.type != StoreType.REST:
logger.warning(
"You are configuring a stack component that is running "
"remotely while using a local ZenML server. The component "
"may not be able to reach the local ZenML server and will "
"therefore not be functional. Please consider deploying "
"and/or using a remote ZenML server instead."
)
elif configuration.is_local and not self.zen_store.is_local_store():
logger.warning(
"You are configuring a stack component that is using "
"local resources while connected to a remote ZenML server. The "
"stack component may not be usable from other hosts or by "
"other users. You should consider using a non-local stack "
"component alternative instead."
)
if not configuration.is_valid:
raise StackComponentValidationError(
f"Invalid stack component configuration. please verify "
f"the configurations set for {component_type}."
)

# .---------.
# | FLAVORS |
# '---------'
Expand Down
4 changes: 0 additions & 4 deletions src/zenml/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,6 @@ class StackValidationError(ZenMLBaseException):
"""Raised when a stack configuration is not valid."""


class StackComponentValidationError(ZenMLBaseException):
"""Raised when a stack component configuration is not valid."""


class ProvisioningError(ZenMLBaseException):
"""Raised when an error occurs when provisioning resources for a StackComponent."""

Expand Down
140 changes: 140 additions & 0 deletions src/zenml/stack/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Util functions for handling stacks, components, and flavors."""

from typing import Any, Dict, Optional

from zenml.client import Client
from zenml.enums import StackComponentType, StoreType
from zenml.logger import get_logger
from zenml.models.flavor_models import FlavorFilterModel, FlavorResponseModel
from zenml.stack.flavor import Flavor
from zenml.stack.stack_component import StackComponentConfig
from zenml.zen_stores.base_zen_store import BaseZenStore

logger = get_logger(__name__)


def validate_stack_component_config(
configuration_dict: Dict[str, Any],
flavor_name: str,
component_type: StackComponentType,
zen_store: Optional[BaseZenStore] = None,
validate_custom_flavors: bool = True,
) -> Optional[StackComponentConfig]:
"""Validate the configuration of a stack component.

Args:
configuration_dict: The stack component configuration to validate.
flavor_name: The name of the flavor of the stack component.
component_type: The type of the stack component.
zen_store: An optional ZenStore in which to look for the flavor. If not
provided, the flavor will be fetched via the regular ZenML Client.
This is mainly useful for checks running inside the ZenML server.
validate_custom_flavors: When loading custom flavors from the local
environment, this flag decides whether the import failures are
raised or an empty value is returned.

Returns:
The validated stack component configuration or None, if the
flavor is a custom flavor that could not be imported from the local
environment and the `validate_custom_flavors` flag is set to False.

Raises:
ValueError: If the configuration is invalid.
ImportError: If the flavor class could not be imported.
ModuleNotFoundError: If the flavor class could not be imported.
"""
if zen_store:
flavor_model = get_flavor_by_name_and_type_from_zen_store(
zen_store=zen_store,
flavor_name=flavor_name,
component_type=component_type,
)
else:
flavor_model = Client().get_flavor_by_name_and_type(
name=flavor_name,
component_type=component_type,
)
try:
flavor_class = Flavor.from_model(flavor_model)
except (ImportError, ModuleNotFoundError):
# The flavor class couldn't be loaded.
if flavor_model.is_custom and not validate_custom_flavors:
return None
raise

configuration = flavor_class.config_class(**configuration_dict)
if not configuration.is_valid:
raise ValueError(
f"Invalid stack component configuration. Please verify "
f"the configurations set for {component_type}."
)
return configuration


def warn_if_config_server_mismatch(
configuration: StackComponentConfig,
) -> None:
"""Warn if a component configuration is mismatched with the ZenML server.

Args:
configuration: The component configuration to check.
"""
zen_store = Client().zen_store
if configuration.is_remote and zen_store.is_local_store():
if zen_store.type != StoreType.REST:
logger.warning(
"You are configuring a stack component that is running "
fa9r marked this conversation as resolved.
Show resolved Hide resolved
"remotely while using a local ZenML server. The component "
"may not be able to reach the local ZenML server and will "
"therefore not be functional. Please consider deploying "
"and/or using a remote ZenML server instead."
)
elif configuration.is_local and not zen_store.is_local_store():
logger.warning(
"You are configuring a stack component that is using "
"local resources while connected to a remote ZenML server. The "
"stack component may not be usable from other hosts or by "
"other users. You should consider using a non-local stack "
"component alternative instead."
)


def get_flavor_by_name_and_type_from_zen_store(
zen_store: BaseZenStore,
flavor_name: str,
component_type: StackComponentType,
) -> FlavorResponseModel:
"""Get a stack component flavor by name and type from a ZenStore.

Args:
zen_store: The ZenStore to query.
flavor_name: The name of a stack component flavor.
component_type: The type of the stack component.

Returns:
The flavor model.

Raises:
KeyError: If no flavor with the given name and type exists.
"""
flavors = zen_store.list_flavors(
FlavorFilterModel(name=flavor_name, type=component_type)
)
if not flavors:
raise KeyError(
f"No flavor with name '{flavor_name}' and type '{component_type}' exists."
)
return flavors[0]
12 changes: 12 additions & 0 deletions src/zenml/zen_server/routers/stack_components_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ def update_stack_component(
Returns:
Updated stack component.
"""
if component_update.configuration:
from zenml.stack.utils import validate_stack_component_config

existing_component = zen_store().get_stack_component(component_id)
validate_stack_component_config(
configuration_dict=component_update.configuration,
flavor_name=existing_component.flavor,
component_type=existing_component.type,
zen_store=zen_store(),
# We allow custom flavors to fail import on the server side.
validate_custom_flavors=False,
)
return zen_store().update_stack_component(
component_id=component_id,
component_update=component_update,
Expand Down
12 changes: 10 additions & 2 deletions src/zenml/zen_server/routers/workspaces_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,16 @@ def create_stack_component(
"is not supported."
)

# TODO: [server] if possible it should validate here that the configuration
# conforms to the flavor
from zenml.stack.utils import validate_stack_component_config

validate_stack_component_config(
configuration_dict=component.configuration,
flavor_name=component.flavor,
component_type=component.type,
zen_store=zen_store(),
# We allow custom flavors to fail import on the server side.
validate_custom_flavors=False,
)

return zen_store().create_stack_component(component=component)

Expand Down
Loading