Skip to content

Commit

Permalink
Merge branch 'master' into optional-flytefile-dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
eapolinario authored Dec 24, 2022
2 parents 7c20187 + 425f488 commit 92396b5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 33 deletions.
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ tomli==2.0.1
# coverage
# mypy
# pytest
torch==1.12.1
torch==1.13.1
# via -r dev-requirements.in
traitlets==5.6.0
# via
Expand Down
2 changes: 1 addition & 1 deletion flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
help="Directory to write the output zip file containing the protobuf definitions",
)
@click.option(
"-d",
"-D",
"--destination-dir",
required=False,
type=str,
Expand Down
42 changes: 11 additions & 31 deletions plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple, Union

from flyteidl.core import tasks_pb2 as _core_task
Expand All @@ -18,6 +19,7 @@ def _sanitize_resource_name(resource: _task_models.Resources.ResourceEntry) -> s
return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")


@dataclass
class Pod(object):
"""
Pod is a platform-wide configuration that uses pod templates. By default, every task is launched as a container in a pod.
Expand All @@ -29,39 +31,17 @@ class Pod(object):
:param Optional[Dict[str, str]] annotations: Annotations are key/value pairs that are attached to arbitrary non-identifying metadata to pod spec.
"""

def __init__(
self,
pod_spec: V1PodSpec,
primary_container_name: str,
labels: Optional[Dict[str, str]] = None,
annotations: Optional[Dict[str, str]] = None,
):
if not pod_spec:
pod_spec: V1PodSpec
primary_container_name: str = _PRIMARY_CONTAINER_NAME_FIELD
labels: Optional[Dict[str, str]] = None
annotations: Optional[Dict[str, str]] = None

def __post_init_(self):
if not self.pod_spec:
raise _user_exceptions.FlyteValidationException("A pod spec cannot be undefined")
if not primary_container_name:
if not self.primary_container_name:
raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined")

self._pod_spec = pod_spec
self._primary_container_name = primary_container_name
self._labels = labels
self._annotations = annotations

@property
def pod_spec(self) -> V1PodSpec:
return self._pod_spec

@property
def primary_container_name(self) -> str:
return self._primary_container_name

@property
def labels(self) -> Optional[Dict[str, str]]:
return self._labels

@property
def annotations(self) -> Optional[Dict[str, str]]:
return self._annotations


class PodFunctionTask(PythonFunctionTask[Pod]):
def __init__(self, task_config: Pod, task_function: Callable, **kwargs):
Expand Down Expand Up @@ -114,7 +94,7 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]

final_containers.append(container)

self.task_config._pod_spec.containers = final_containers
self.task_config.pod_spec.containers = final_containers

return ApiClient().sanitize_for_serialization(self.task_config.pod_spec)

Expand Down

0 comments on commit 92396b5

Please sign in to comment.