Skip to content

Commit

Permalink
Add pod_template and pod_template_name arguments for `PythonAutoC…
Browse files Browse the repository at this point in the history
…ontainerTask`, its downstream tasks, and `@task`. (#1425)

* Add `pod_template` and `pod_template_name` arguments for `PythonAutoContainerTask`, its downstream tasks, and `@task`

Signed-off-by: byhsu <[email protected]>

* clean

Signed-off-by: byhsu <[email protected]>

* fix test

Signed-off-by: byhsu <[email protected]>

* Fix taskmetadata

Signed-off-by: byhsu <[email protected]>

* add kubernetes in setup.py

Signed-off-by: byhsu <[email protected]>

* address comments

Signed-off-by: byhsu <[email protected]>

* Regenerate requirements using python 3.7

Signed-off-by: Eduardo Apolinario <[email protected]>
Signed-off-by: byhsu <[email protected]>

* keep container validation

Signed-off-by: byhsu <[email protected]>

* bump idl version

Signed-off-by: byhsu <[email protected]>

* Regenerate requirements using python 3.7

Signed-off-by: Eduardo Apolinario <[email protected]>

* Regenerate doc-requirements.txt

Signed-off-by: Eduardo Apolinario <[email protected]>

* fix

Signed-off-by: byhsu <[email protected]>

---------

Signed-off-by: byhsu <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: byhsu <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
3 people committed Feb 22, 2023
1 parent 86fe324 commit 59574e5
Show file tree
Hide file tree
Showing 20 changed files with 885 additions and 334 deletions.
156 changes: 103 additions & 53 deletions dev-requirements.txt

Large diffs are not rendered by default.

378 changes: 173 additions & 205 deletions doc-requirements.txt

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
TaskMetadata - Wrapper object that allows users to specify Task
Resources - Things like CPUs/Memory, etc.
WorkflowFailurePolicy - Customizes what happens when a workflow fails.
PodTemplate - Custom PodTemplate for a task.
Dynamic and Nested Workflows
==============================
Expand Down Expand Up @@ -176,6 +176,7 @@
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.map_task import map_task
from flytekit.core.notification import Email, PagerDuty, Slack
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.reference import get_reference_entity
from flytekit.core.reference_entity import LaunchPlanReference, TaskReference, WorkflowReference
Expand Down
3 changes: 3 additions & 0 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class TaskMetadata(object):
timeout (Optional[Union[datetime.timedelta, int]]): the max amount of time for which one execution of this task
should be executed for. The execution will be terminated if the runtime exceeds the given timeout
(approximately)
pod_template_name (Optional[str]): the name of an existing PodTemplate resource in the cluster which will be used in this task.
"""

cache: bool = False
Expand All @@ -94,6 +95,7 @@ class TaskMetadata(object):
deprecated: str = ""
retries: int = 0
timeout: Optional[Union[datetime.timedelta, int]] = None
pod_template_name: Optional[str] = None

def __post_init__(self):
if self.timeout:
Expand Down Expand Up @@ -127,6 +129,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
discovery_version=self.cache_version,
deprecated_error_message=self.deprecated,
cache_serializable=self.cache_serialize,
pod_template_name=self.pod_template_name,
)


Expand Down
1 change: 1 addition & 0 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flytekit.models.security import Secret, SecurityContext


# TODO: do we need pod_template here? Seems that it is a raw container not running in pods
class ContainerTask(PythonTask):
"""
This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast
Expand Down
20 changes: 20 additions & 0 deletions flytekit/core/pod_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from dataclasses import dataclass, field
from typing import Dict, Optional

from kubernetes.client.models import V1PodSpec

from flytekit.exceptions import user as _user_exceptions

PRIMARY_CONTAINER_DEFAULT_NAME = "primary"


@dataclass
class PodTemplate(object):
    """
    Custom PodTemplate for a task.

    Attributes:
        pod_spec: The Kubernetes pod spec to use. Defaults to a new, empty-container ``V1PodSpec``
            created separately for every ``PodTemplate`` instance.
        primary_container_name: Name of the primary container. Must be non-empty.
        labels: Optional labels for the pod.
        annotations: Optional annotations for the pod.

    Raises:
        FlyteValidationException: if ``primary_container_name`` is empty or ``None``.
    """

    # Build the default per instance. A single class-level V1PodSpec would be shared (and mutated)
    # across every PodTemplate, and Python 3.11+ rejects unhashable dataclass defaults outright.
    pod_spec: V1PodSpec = field(default_factory=lambda: V1PodSpec(containers=[]))
    primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME
    labels: Optional[Dict[str, str]] = None
    annotations: Optional[Dict[str, str]] = None

    def __post_init__(self):
        # An empty string or None cannot identify the primary container.
        if not self.primary_container_name:
            raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined")
90 changes: 88 additions & 2 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@
import re
from abc import ABC
from types import ModuleType
from typing import Callable, Dict, List, Optional, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from flyteidl.core import tasks_pb2 as _core_task
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements

from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
Expand All @@ -18,6 +23,11 @@
from flytekit.models.security import Secret, SecurityContext

T = TypeVar("T")
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"


def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str:
    """
    Return the lower-cased, dash-separated form of the ``ResourceName`` enum name for ``resource.name``.

    :param resource: Resource entry whose ``name`` is a ``ResourceName`` enum value.
    :return: The enum member's name, lower-cased and with every ``_`` replaced by ``-``.
    """
    enum_member_name = _core_task.Resources.ResourceName.Name(resource.name)
    lowered = enum_member_name.lower()
    return lowered.replace("_", "-")


class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC):
Expand All @@ -40,6 +50,8 @@ def __init__(
environment: Optional[Dict[str, str]] = None,
task_resolver: Optional[TaskResolverMixin] = None,
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
**kwargs,
):
"""
Expand All @@ -64,13 +76,20 @@ def __init__(
- `Confidant <https://lyft.github.io/confidant/>`__
- `Kube secrets <https://kubernetes.io/docs/concepts/configuration/secret/>`__
- `AWS Parameter store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`__
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""
sec_ctx = None
if secret_requests:
for s in secret_requests:
if not isinstance(s, Secret):
raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}")
sec_ctx = SecurityContext(secrets=secret_requests)

# pod_template_name overwrites the metadata.pod_template_name
kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata()
kwargs["metadata"].pod_template_name = pod_template_name

super().__init__(
task_type=task_type,
name=name,
Expand Down Expand Up @@ -98,6 +117,8 @@ def __init__(
self._task_resolver = task_resolver or default_task_resolver
self._get_command_fn = self.get_default_command

self.pod_template = pod_template

@property
def task_resolver(self) -> Optional[TaskResolverMixin]:
return self._task_resolver
Expand Down Expand Up @@ -157,6 +178,13 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
return self._get_command_fn(settings)

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
    """
    Return the container definition for this task, or ``None`` when a pod template is set.

    :param settings: Serialization settings forwarded to ``_get_container``.
    :return: The result of ``_get_container(settings)`` if ``self.pod_template`` is ``None``; otherwise ``None``.
    """
    if self.pod_template is None:
        return self._get_container(settings)
    # With a pod template the container is not returned here; get_k8s_pod returns the
    # pod template merged with the container instead.
    return None

def _get_container(self, settings: SerializationSettings) -> _task_model.Container:
env = {}
for elem in (settings.env, self.environment):
if elem:
Expand All @@ -179,6 +207,64 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe
memory_limit=self.resources.limits.mem,
)

def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]:
    """
    Serialize the pod template's pod spec, with the primary container filled in from the SDK defaults.

    For each container named ``self.pod_template.primary_container_name``, the image, command and args
    are replaced with those of ``self._get_container(settings)``. Resource requirements are replaced only
    when the SDK container declares any, and the SDK environment variables are placed in front of the
    container's own. If no container carries the primary name, a placeholder primary container is added.

    The user-supplied ``PodTemplate`` is left untouched: all changes are made on shallow copies, so
    serializing the same task repeatedly (or sharing one template between tasks) does not accumulate
    duplicate env vars or placeholder containers.

    :param settings: Serialization settings forwarded to ``_get_container``.
    :return: The pod spec as produced by ``ApiClient().sanitize_for_serialization``.
    """
    # Function-scope import so this method brings its own dependency into scope.
    import copy

    primary_name = self.pod_template.primary_container_name
    # Private list: appending a placeholder must not modify the template's own container list.
    containers = list(self.pod_template.pod_spec.containers or [])

    if not any(container.name == primary_name for container in containers):
        # insert a placeholder primary container if it is not defined in the pod spec.
        containers.append(V1Container(name=primary_name))

    final_containers = []
    for container in containers:
        # In the case of the primary container, we overwrite specific container attributes
        # with the default values used in the regular Python task.
        # The attributes include: image, command, args, resource, and env (env is unioned)
        if container.name == primary_name:
            # Shallow copy: attributes are reassigned below, never mutated in place.
            container = copy.copy(container)
            sdk_default_container = self._get_container(settings)
            container.image = sdk_default_container.image
            # clear existing commands
            container.command = sdk_default_container.command
            # also clear existing args
            container.args = sdk_default_container.args
            limits = {
                _sanitize_resource_name(resource): resource.value
                for resource in sdk_default_container.resources.limits
            }
            requests = {
                _sanitize_resource_name(resource): resource.value
                for resource in sdk_default_container.resources.requests
            }
            if limits or requests:
                # Important! Only copy over resource requirements if they are non-empty.
                container.resources = V1ResourceRequirements(limits=limits, requests=requests)
            container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + (
                container.env or []
            )
        final_containers.append(container)

    # Attach the merged containers to a copy of the spec rather than to the template's spec.
    pod_spec = copy.copy(self.pod_template.pod_spec)
    pod_spec.containers = final_containers

    return ApiClient().sanitize_for_serialization(pod_spec)

def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod:
    """
    Build the ``K8sPod`` model for this task from its pod template.

    :param settings: Serialization settings forwarded to ``_serialize_pod_spec``.
    :return: ``None`` when no pod template is set; otherwise a ``K8sPod`` holding the serialized pod spec
        and the template's labels and annotations.
    """
    if self.pod_template is None:
        return None

    serialized_spec = self._serialize_pod_spec(settings)
    object_metadata = _task_model.K8sObjectMetadata(
        labels=self.pod_template.labels,
        annotations=self.pod_template.annotations,
    )
    return _task_model.K8sPod(pod_spec=serialized_spec, metadata=object_metadata)

# need to call super in all its children tasks
def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]:
    """
    Return the task config entries contributed by the pod template.

    :param settings: Serialization settings (not used by this implementation).
    :return: An empty dict when no pod template is set; otherwise a dict mapping
        ``_PRIMARY_CONTAINER_NAME_FIELD`` to the template's primary container name.
    """
    config = {}
    if self.pod_template is not None:
        config[_PRIMARY_CONTAINER_NAME_FIELD] = self.pod_template.primary_container_name
    return config


class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
"""
Expand Down
7 changes: 7 additions & 0 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
from flytekit.core.resources import Resources
Expand Down Expand Up @@ -92,6 +93,8 @@ def task(
task_resolver: Optional[TaskResolverMixin] = None,
docs: Optional[Documentation] = None,
disable_deck: bool = True,
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
) -> Union[Callable, PythonFunctionTask]:
"""
This is the core decorator to use for any task type in flytekit.
Expand Down Expand Up @@ -182,6 +185,8 @@ def foo2():
:param task_resolver: Provide a custom task resolver.
:param disable_deck: If true, this task will not output deck html file
:param docs: Documentation about this task
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""

def wrapper(fn) -> PythonFunctionTask:
Expand All @@ -208,6 +213,8 @@ def wrapper(fn) -> PythonFunctionTask:
task_resolver=task_resolver,
disable_deck=disable_deck,
docs=docs,
pod_template=pod_template,
pod_template_name=pod_template_name,
)
update_wrapper(task_instance, fn)
return task_instance
Expand Down
13 changes: 13 additions & 0 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def __init__(
discovery_version,
deprecated_error_message,
cache_serializable,
pod_template_name,
):
"""
Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts,
Expand All @@ -196,6 +197,7 @@ def __init__(
receive deprecation warnings.
:param bool cache_serializable: Whether or not caching operations are executed in serial. This means only a
single instance over identical inputs is executed, other concurrent executions wait for the cached results.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""
self._discoverable = discoverable
self._runtime = runtime
Expand All @@ -205,6 +207,7 @@ def __init__(
self._discovery_version = discovery_version
self._deprecated_error_message = deprecated_error_message
self._cache_serializable = cache_serializable
self._pod_template_name = pod_template_name

@property
def discoverable(self):
Expand Down Expand Up @@ -274,6 +277,14 @@ def cache_serializable(self):
"""
return self._cache_serializable

@property
def pod_template_name(self):
    """
    Name of the existing PodTemplate resource which will be used in this task.

    :rtype: Text
    """
    name = self._pod_template_name
    return name

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.task_pb2.TaskMetadata
Expand All @@ -286,6 +297,7 @@ def to_flyte_idl(self):
discovery_version=self.discovery_version,
deprecated_error_message=self.deprecated_error_message,
cache_serializable=self.cache_serializable,
pod_template_name=self.pod_template_name,
)
if self.timeout:
tm.timeout.FromTimedelta(self.timeout)
Expand All @@ -306,6 +318,7 @@ def from_flyte_idl(cls, pb2_object):
discovery_version=pb2_object.discovery_version,
deprecated_error_message=pb2_object.deprecated_error_message,
cache_serializable=pb2_object.cache_serializable,
pod_template_name=pb2_object.pod_template_name,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
# Parameters in taskTemplate config will be used to create aws job definition.
# More detail about job definition: https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html
return {"platformCapabilities": self._task_config.platformCapabilities}
return {**super().get_config(settings), "platformCapabilities": self._task_config.platformCapabilities}

def get_command(self, settings: SerializationSettings) -> List[str]:
container_args = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def fn(settings: SerializationSettings) -> typing.List[str]:
return self._config_task_instance.get_k8s_pod(settings)

def get_config(self, settings: SerializationSettings) -> typing.Dict[str, str]:
return self._config_task_instance.get_config(settings)
return {**super().get_config(settings), **self._config_task_instance.get_config(settings)}

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
return self._config_task_instance.pre_execute(user_params)
Expand Down
Loading

0 comments on commit 59574e5

Please sign in to comment.