Skip to content

Commit

Permalink
Fix taskmetadata
Browse files Browse the repository at this point in the history
Signed-off-by: byhsu <[email protected]>
  • Loading branch information
ByronHsu committed Jan 26, 2023
1 parent f0b243a commit 32ea24b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
6 changes: 3 additions & 3 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from kubernetes.client.models import V1EnvVar, V1ResourceRequirements

from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec
Expand Down Expand Up @@ -86,8 +86,8 @@ def __init__(
sec_ctx = SecurityContext(secrets=secret_requests)

# pod_template_name overwrites the metedata.pod_template_name
if "metadata" in kwargs:
kwargs["metadata"].pod_template_name = pod_template_name
kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata()
kwargs["metadata"].pod_template_name = pod_template_name

super().__init__(
task_type=task_type,
Expand Down
7 changes: 7 additions & 0 deletions tests/flytekit/unit/core/test_python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def test_pod_template(default_serialization_settings):
]
),
),
pod_template_name="A",
)


Expand Down Expand Up @@ -277,10 +278,16 @@ def test_minimum_pod_template(default_serialization_settings):
"task_with_minimum_pod_template",
]

#################
# Test pod_teamplte_name
#################
assert task_with_minimum_pod_template.metadata.pod_template_name == "A"

#################
# Test Serialization
#################
ts = get_serializable_task(default_serialization_settings, task_with_minimum_pod_template)
assert ts.template.container is None
# k8s_pod content is already verified above, so only check the existence here
assert ts.template.k8s_pod is not None
assert ts.template.metadata.pod_template_name == "A"
8 changes: 7 additions & 1 deletion tests/flytekit/unit/core/test_python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_pod_template():
]
),
),
pod_template_name="podTemplateA",
pod_template_name="A",
)
def func_with_pod_template(i: str):
print(i + 3)
Expand Down Expand Up @@ -192,10 +192,16 @@ def func_with_pod_template(i: str):
"func_with_pod_template",
]

#################
# Test pod_teamplte_name
#################
assert func_with_pod_template.metadata.pod_template_name == "A"

#################
# Test Serialization
#################
ts = get_serializable_task(default_serialization_settings, func_with_pod_template)
assert ts.template.container is None
# k8s_pod content is already verified above, so only check the existence here
assert ts.template.k8s_pod is not None
assert ts.template.metadata.pod_template_name == "A"

0 comments on commit 32ea24b

Please sign in to comment.