Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server-Side Component Config Validation #1988

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 25 additions & 82 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
EntityExistsError,
IllegalOperationError,
InitializationException,
StackComponentValidationError,
ValidationError,
ZenKeyError,
)
Expand Down Expand Up @@ -167,7 +166,7 @@
if TYPE_CHECKING:
from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum
from zenml.service_connectors.service_connector import ServiceConnector
from zenml.stack import Stack, StackComponentConfig
from zenml.stack import Stack
from zenml.zen_stores.base_zen_store import BaseZenStore

logger = get_logger(__name__)
Expand Down Expand Up @@ -1949,24 +1948,23 @@ def _validate_stack_configuration(
f"unregistered {component_type} with id "
f"'{component_id}'."
) from e
# Get the flavor model
flavor_model = self.get_flavor_by_name_and_type(
name=component.flavor, component_type=component.type
)

# Create and validate the configuration
from zenml.stack import Flavor
# Create and validate the configuration
from zenml.stack.utils import validate_stack_component_config

flavor = Flavor.from_model(flavor_model)
configuration = flavor.config_class(**component.configuration)
if configuration.is_local:
local_components.append(
f"{component.type.value}: {component.name}"
)
elif configuration.is_remote:
remote_components.append(
f"{component.type.value}: {component.name}"
configuration = validate_stack_component_config(
configuration_dict=component.configuration,
flavor_name=component.flavor,
component_type=component.type,
)
if configuration.is_local:
local_components.append(
f"{component.type.value}: {component.name}"
)
elif configuration.is_remote:
remote_components.append(
f"{component.type.value}: {component.name}"
)

if local_components and remote_components:
logger.warning(
Expand Down Expand Up @@ -2142,22 +2140,12 @@ def create_stack_component(
Returns:
The model of the registered component.
"""
# Get the flavor model
flavor_model = self.get_flavor_by_name_and_type(
name=flavor,
component_type=component_type,
)
from zenml.stack.utils import validate_stack_component_config

# Create and validate the configuration
from zenml.stack import Flavor

flavor_class = Flavor.from_model(flavor_model)
configuration_obj = flavor_class.config_class(
warn_about_plain_text_secrets=True, **configuration
)

self._validate_stack_component_configuration(
component_type, configuration=configuration_obj
validate_stack_component_config(
configuration_dict=configuration,
flavor_name=flavor,
component_type=component_type,
)

create_component_model = ComponentRequestModel(
Expand Down Expand Up @@ -2254,26 +2242,20 @@ def update_stack_component(
if configuration is not None:
existing_configuration = component.configuration
existing_configuration.update(configuration)

existing_configuration = {
k: v
for k, v in existing_configuration.items()
if v is not None
}

flavor_model = self.get_flavor_by_name_and_type(
name=component.flavor,
from zenml.stack.utils import validate_stack_component_config

validate_stack_component_config(
configuration_dict=existing_configuration,
flavor_name=component.flavor,
component_type=component.type,
)

from zenml.stack import Flavor

flavor = Flavor.from_model(flavor_model)
configuration_obj = flavor.config_class(**existing_configuration)

self._validate_stack_component_configuration(
component.type, configuration=configuration_obj
)
update_model.configuration = existing_configuration

if labels is not None:
Expand Down Expand Up @@ -2320,45 +2302,6 @@ def delete_stack_component(
component.name,
)

def _validate_stack_component_configuration(
self,
component_type: "StackComponentType",
configuration: "StackComponentConfig",
) -> None:
"""Validates the configuration of a stack component.

Args:
component_type: The type of the component.
configuration: The component configuration to validate.

Raises:
StackComponentValidationError: in case the stack component configuration is invalid.
"""
from zenml.enums import StoreType

if configuration.is_remote and self.zen_store.is_local_store():
if self.zen_store.type != StoreType.REST:
logger.warning(
"You are configuring a stack component that is running "
"remotely while using a local ZenML server. The component "
"may not be able to reach the local ZenML server and will "
"therefore not be functional. Please consider deploying "
"and/or using a remote ZenML server instead."
)
elif configuration.is_local and not self.zen_store.is_local_store():
logger.warning(
"You are configuring a stack component that is using "
"local resources while connected to a remote ZenML server. The "
"stack component may not be usable from other hosts or by "
"other users. You should consider using a non-local stack "
"component alternative instead."
)
if not configuration.is_valid:
raise StackComponentValidationError(
f"Invalid stack component configuration. please verify "
f"the configurations set for {component_type}."
)

# .---------.
# | FLAVORS |
# '---------'
Expand Down
4 changes: 0 additions & 4 deletions src/zenml/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,6 @@ class StackValidationError(ZenMLBaseException):
"""Raised when a stack configuration is not valid."""


class StackComponentValidationError(ZenMLBaseException):
"""Raised when a stack component configuration is not valid."""


class ProvisioningError(ZenMLBaseException):
"""Raised when an error occurs when provisioning resources for a StackComponent."""

Expand Down
148 changes: 148 additions & 0 deletions src/zenml/stack/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Util functions for handling stacks, components, and flavors."""

from typing import Any, Dict, Optional

from zenml.client import Client
from zenml.enums import StackComponentType, StoreType
from zenml.logger import get_logger
from zenml.models.flavor_models import FlavorFilterModel, FlavorResponseModel
from zenml.stack.flavor import Flavor
from zenml.stack.stack_component import StackComponentConfig
from zenml.zen_stores.base_zen_store import BaseZenStore

logger = get_logger(__name__)


def validate_stack_component_config(
configuration_dict: Dict[str, Any],
flavor_name: str,
component_type: StackComponentType,
zen_store: Optional[BaseZenStore] = None,
) -> StackComponentConfig:
"""Validate the configuration of a stack component.

Args:
configuration_dict: The stack component configuration to validate.
flavor_name: The name of the flavor of the stack component.
component_type: The type of the stack component.
zen_store: An optional ZenStore in which to look for the flavor. If not
provided, the flavor will be fetched via the regular ZenML Client.
This is mainly useful for checks running inside the ZenML server.

Returns:
The validated stack component configuration.

Raises:
ValueError: If the configuration is invalid.
"""
flavor_class = get_stack_component_flavor_class(
flavor_name=flavor_name,
component_type=component_type,
zen_store=zen_store,
)
configuration = flavor_class.config_class(**configuration_dict)
if not configuration.is_valid:
raise ValueError(
f"Invalid stack component configuration. Please verify "
f"the configurations set for {component_type}."
)
_warn_if_config_server_mismatch(
configuration, zen_store=zen_store or Client().zen_store
)
return configuration


def _warn_if_config_server_mismatch(
configuration: StackComponentConfig, zen_store: BaseZenStore
) -> None:
"""Warn if the configuration is mismatched with the ZenML server."""
if configuration.is_remote and zen_store.is_local_store():
if zen_store.type != StoreType.REST:
logger.warning(
"You are configuring a stack component that is running "
fa9r marked this conversation as resolved.
Show resolved Hide resolved
"remotely while using a local ZenML server. The component "
"may not be able to reach the local ZenML server and will "
"therefore not be functional. Please consider deploying "
"and/or using a remote ZenML server instead."
)
elif configuration.is_local and not zen_store.is_local_store():
logger.warning(
"You are configuring a stack component that is using "
"local resources while connected to a remote ZenML server. The "
"stack component may not be usable from other hosts or by "
"other users. You should consider using a non-local stack "
"component alternative instead."
)


def get_stack_component_flavor_class(
flavor_name: str,
component_type: StackComponentType,
zen_store: Optional[BaseZenStore] = None,
) -> Flavor:
"""Get the flavor class of a stack component.

Args:
flavor_name: The name of a stack component flavor.
component_type: The type of the stack component.
zen_store: An optional ZenStore in which to look for the flavor. If not
provided, the flavor will be fetched via the regular ZenML Client.
This is mainly useful for checks running inside the ZenML server.

Returns:
The flavor class of the stack component.
"""
if zen_store:
flavor_model = get_flavor_by_name_and_type_from_zen_store(
zen_store=zen_store,
flavor_name=flavor_name,
component_type=component_type,
)
else:
flavor_model = Client().get_flavor_by_name_and_type(
name=flavor_name,
component_type=component_type,
)
flavor_class = Flavor.from_model(flavor_model)
return flavor_class


def get_flavor_by_name_and_type_from_zen_store(
zen_store: BaseZenStore,
flavor_name: str,
component_type: StackComponentType,
) -> FlavorResponseModel:
"""Get a stack component flavor by name and type from a ZenStore.

Args:
zen_store: The ZenStore to query.
flavor_name: The name of a stack component flavor.
component_type: The type of the stack component.

Returns:
The flavor model.

Raises:
KeyError: If no flavor with the given name and type exists.
"""
flavors = zen_store.list_flavors(
FlavorFilterModel(name=flavor_name, type=component_type)
)
if not flavors:
raise KeyError(
f"No flavor with name '{flavor_name}' and type '{component_type}' exists."
)
return flavors[0]
10 changes: 10 additions & 0 deletions src/zenml/zen_server/routers/stack_components_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ def update_stack_component(
Returns:
Updated stack component.
"""
if component_update.configuration:
from zenml.stack.utils import validate_stack_component_config

existing_component = zen_store().get_stack_component(component_id)
validate_stack_component_config(
configuration_dict=component_update.configuration,
flavor_name=existing_component.flavor,
component_type=existing_component.type,
zen_store=zen_store(),
)
return zen_store().update_stack_component(
component_id=component_id,
component_update=component_update,
Expand Down
10 changes: 8 additions & 2 deletions src/zenml/zen_server/routers/workspaces_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,14 @@ def create_stack_component(
"is not supported."
)

# TODO: [server] if possible it should validate here that the configuration
# conforms to the flavor
from zenml.stack.utils import validate_stack_component_config

validate_stack_component_config(
configuration_dict=component.configuration,
flavor_name=component.flavor,
component_type=component.type,
zen_store=zen_store(),
)

return zen_store().create_stack_component(component=component)

Expand Down
Loading