Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pod_template and pod_template_name arguments for PythonAutoContainerTask, its downstream tasks, and @task. #1425

Merged
merged 13 commits into from
Feb 3, 2023
148 changes: 94 additions & 54 deletions dev-requirements.txt

Large diffs are not rendered by default.

272 changes: 165 additions & 107 deletions doc-requirements.txt

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
TaskMetadata - Wrapper object that allows users to specify Task
Resources - Things like CPUs/Memory, etc.
WorkflowFailurePolicy - Customizes what happens when a workflow fails.

PodTemplate - Custom PodTemplate for a task.

Dynamic and Nested Workflows
==============================
Expand Down Expand Up @@ -175,6 +175,7 @@
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.map_task import map_task
from flytekit.core.notification import Email, PagerDuty, Slack
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.reference import get_reference_entity
from flytekit.core.reference_entity import LaunchPlanReference, TaskReference, WorkflowReference
Expand Down
3 changes: 3 additions & 0 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class TaskMetadata(object):
timeout (Optional[Union[datetime.timedelta, int]]): the max amount of time for which one execution of this task
should be executed for. The execution will be terminated if the runtime exceeds the given timeout
(approximately)
pod_template_name (Optional[str]): the name of existing PodTemplate resource in the cluster which will be used in this task.
"""

cache: bool = False
Expand All @@ -94,6 +95,7 @@ class TaskMetadata(object):
deprecated: str = ""
retries: int = 0
timeout: Optional[Union[datetime.timedelta, int]] = None
pod_template_name: Optional[str] = None

def __post_init__(self):
if self.timeout:
Expand Down Expand Up @@ -127,6 +129,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
discovery_version=self.cache_version,
deprecated_error_message=self.deprecated,
cache_serializable=self.cache_serialize,
pod_template_name=self.pod_template_name,
)


Expand Down
1 change: 1 addition & 0 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flytekit.models.security import Secret, SecurityContext


# TODO: do we need pod_template here? Seems that it is a raw container not running in pods
hamersaw marked this conversation as resolved.
Show resolved Hide resolved
ByronHsu marked this conversation as resolved.
Show resolved Hide resolved
class ContainerTask(PythonTask):
"""
This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast
Expand Down
20 changes: 20 additions & 0 deletions flytekit/core/pod_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from dataclasses import dataclass
from typing import Dict, Optional

from kubernetes.client.models import V1PodSpec

from flytekit.exceptions import user as _user_exceptions

PRIMARY_CONTAINER_DEFAULT_NAME = "primary"


@dataclass
class PodTemplate(object):
ByronHsu marked this conversation as resolved.
Show resolved Hide resolved
ByronHsu marked this conversation as resolved.
Show resolved Hide resolved
pod_spec: V1PodSpec = V1PodSpec(containers=[])
primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME
hamersaw marked this conversation as resolved.
Show resolved Hide resolved
labels: Optional[Dict[str, str]] = None
annotations: Optional[Dict[str, str]] = None

def __post_init__(self):
if not self.primary_container_name:
raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined")
91 changes: 89 additions & 2 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@
import re
from abc import ABC
from types import ModuleType
from typing import Callable, Dict, List, Optional, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from flyteidl.core import tasks_pb2 as _core_task
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements

from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
Expand All @@ -18,6 +23,11 @@
from flytekit.models.security import Secret, SecurityContext

T = TypeVar("T")
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"


def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str:
return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")


class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC):
Expand All @@ -40,6 +50,8 @@ def __init__(
environment: Optional[Dict[str, str]] = None,
task_resolver: Optional[TaskResolverMixin] = None,
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
Comment on lines +53 to +54
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we briefly explore name changes here? pod_template and pod_template_name are not very descriptive and certainly don't capture a client-side PodTemplate application vs server-side k8s resource PodTemplate configuration. Thoughts?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it is very confusing now, but honestly I don't have a better idea on my mind. I also feel they should be one of. If users can set both, it will be more confusing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still needs discussion!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We sync'd offline and decided to go with these names.

**kwargs,
):
"""
Expand All @@ -64,13 +76,20 @@ def __init__(
- `Confidant <https://lyft.github.io/confidant/>`__
- `Kube secrets <https://kubernetes.io/docs/concepts/configuration/secret/>`__
- `AWS Parameter store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`__
:param pod_template: custom PodTemplate.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""
sec_ctx = None
if secret_requests:
for s in secret_requests:
if not isinstance(s, Secret):
raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}")
sec_ctx = SecurityContext(secrets=secret_requests)

# pod_template_name overwrites the metedata.pod_template_name
kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata()
kwargs["metadata"].pod_template_name = pod_template_name

super().__init__(
task_type=task_type,
name=name,
Expand Down Expand Up @@ -98,6 +117,8 @@ def __init__(
self._task_resolver = task_resolver or default_task_resolver
self._get_command_fn = self.get_default_command

self.pod_template = pod_template

@property
def task_resolver(self) -> Optional[TaskResolverMixin]:
return self._task_resolver
Expand Down Expand Up @@ -157,6 +178,14 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
return self._get_command_fn(settings)

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
# if pod_template exists, return None
# but in get_k8s_pod, return pod_template merged with container
if self.pod_template is not None:
return None
else:
return self._get_container(settings)

def _get_container(self, settings: SerializationSettings) -> _task_model.Container:
env = {}
for elem in (settings.env, self.environment):
if elem:
Expand All @@ -179,6 +208,64 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe
memory_limit=self.resources.limits.mem,
)

def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]:
containers = self.pod_template.pod_spec.containers
primary_exists = False

for container in containers:
if container.name == self.pod_template.primary_container_name:
primary_exists = True
break

if not primary_exists:
# insert a placeholder primary container if it is not defined in the pod spec.
containers.append(V1Container(name=self.pod_template.primary_container_name))
final_containers = []
for container in containers:
# In the case of the primary container, we overwrite specific container attributes
# with the default values used in the regular Python task.
# The attributes include: image, command, args, resource, and env (env is unioned)
if container.name == self.pod_template.primary_container_name:
sdk_default_container = self._get_container(settings)
container.image = sdk_default_container.image
# clear existing commands
container.command = sdk_default_container.command
# also clear existing args
container.args = sdk_default_container.args
limits, requests = {}, {}
for resource in sdk_default_container.resources.limits:
limits[_sanitize_resource_name(resource)] = resource.value
for resource in sdk_default_container.resources.requests:
requests[_sanitize_resource_name(resource)] = resource.value
resource_requirements = V1ResourceRequirements(limits=limits, requests=requests)
if len(limits) > 0 or len(requests) > 0:
# Important! Only copy over resource requirements if they are non-empty.
container.resources = resource_requirements
container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + (
container.env or []
)
final_containers.append(container)
self.pod_template.pod_spec.containers = final_containers

return ApiClient().sanitize_for_serialization(self.pod_template.pod_spec)

def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod:
if self.pod_template is None:
return None
return _task_model.K8sPod(
pod_spec=self._serialize_pod_spec(settings),
metadata=_task_model.K8sObjectMetadata(
labels=self.pod_template.labels,
annotations=self.pod_template.annotations,
),
)

# need to call super in all its children tasks
def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]:
if self.pod_template is None:
return {}
return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name}


class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
"""
Expand Down
7 changes: 7 additions & 0 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
from flytekit.core.resources import Resources
Expand Down Expand Up @@ -92,6 +93,8 @@ def task(
task_resolver: Optional[TaskResolverMixin] = None,
docs: Optional[Documentation] = None,
disable_deck: bool = True,
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
) -> Union[Callable, PythonFunctionTask]:
"""
This is the core decorator to use for any task type in flytekit.
Expand Down Expand Up @@ -182,6 +185,8 @@ def foo2():
:param task_resolver: Provide a custom task resolver.
:param disable_deck: If true, this task will not output deck html file
:param docs: Documentation about this task
:param pod_template: custom PodTemplate.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""

def wrapper(fn) -> PythonFunctionTask:
Expand All @@ -208,6 +213,8 @@ def wrapper(fn) -> PythonFunctionTask:
task_resolver=task_resolver,
disable_deck=disable_deck,
docs=docs,
pod_template=pod_template,
pod_template_name=pod_template_name,
)
update_wrapper(task_instance, fn)
return task_instance
Expand Down
14 changes: 14 additions & 0 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def __init__(
discovery_version,
deprecated_error_message,
cache_serializable,
pod_template_name,
):
"""
Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts,
Expand All @@ -196,6 +197,7 @@ def __init__(
receive deprecation warnings.
:param bool cache_serializable: Whether or not caching operations are executed in serial. This means only a
single instance over identical inputs is executed, other concurrent executions wait for the cached results.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""
self._discoverable = discoverable
self._runtime = runtime
Expand All @@ -205,6 +207,7 @@ def __init__(
self._discovery_version = discovery_version
self._deprecated_error_message = deprecated_error_message
self._cache_serializable = cache_serializable
self._pod_template_name = pod_template_name

@property
def discoverable(self):
Expand Down Expand Up @@ -274,6 +277,15 @@ def cache_serializable(self):
"""
return self._cache_serializable

@property
def pod_template_name(self):
"""
The name of the existing PodTemplate resource which will be used in this task.
:rtype: Text
"""
# TODO: comment
ByronHsu marked this conversation as resolved.
Show resolved Hide resolved
return self._pod_template_name

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.task_pb2.TaskMetadata
Expand All @@ -286,6 +298,7 @@ def to_flyte_idl(self):
discovery_version=self.discovery_version,
deprecated_error_message=self.deprecated_error_message,
cache_serializable=self.cache_serializable,
pod_template_name=self.pod_template_name,
)
if self.timeout:
tm.timeout.FromTimedelta(self.timeout)
Expand All @@ -306,6 +319,7 @@ def from_flyte_idl(cls, pb2_object):
discovery_version=pb2_object.discovery_version,
deprecated_error_message=pb2_object.deprecated_error_message,
cache_serializable=pb2_object.cache_serializable,
pod_template_name=pb2_object.pod_template_name,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
# Parameters in taskTemplate config will be used to create aws job definition.
# More detail about job definition: https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html
return {"platformCapabilities": self._task_config.platformCapabilities}
return {**super().get_config(settings), "platformCapabilities": self._task_config.platformCapabilities}

def get_command(self, settings: SerializationSettings) -> List[str]:
container_args = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def fn(settings: SerializationSettings) -> typing.List[str]:
return self._config_task_instance.get_k8s_pod(settings)

def get_config(self, settings: SerializationSettings) -> typing.Dict[str, str]:
return self._config_task_instance.get_config(settings)
return {**super().get_config(settings), **self._config_task_instance.get_config(settings)}

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
return self._config_task_instance.pre_execute(user_params)
Expand Down
Loading