From 209fa65d61a232e3dbd2d06714222067faaa50bc Mon Sep 17 00:00:00 2001
From: Paul Dittamo <37558497+pvditt@users.noreply.github.com>
Date: Mon, 2 Dec 2024 14:23:23 -0800
Subject: [PATCH] Update array node map task to support an additional plugin (#2934)

* have get_custom return subtask and remove un-needed array job

Signed-off-by: Paul Dittamo

* set is_original_sub_node_interface + FULL_STATE for python function task extensions

Signed-off-by: Paul Dittamo

* lint

Signed-off-by: Paul Dittamo

* unit test

Signed-off-by: Paul Dittamo

* clean up

Signed-off-by: Paul Dittamo

* idl update

Signed-off-by: Paul Dittamo

* Revert "idl update"

This reverts commit e98bc6e992734ade49f3e73c47f26bc495b27f3a.

Signed-off-by: Paul Dittamo

* pass in boolvalue

Signed-off-by: Paul Dittamo

* update unit test

Signed-off-by: Paul Dittamo

* Set flyteidl lower bound

Signed-off-by: Eduardo Apolinario

---------

Signed-off-by: Paul Dittamo
Signed-off-by: Eduardo Apolinario
Co-authored-by: Eduardo Apolinario
---
 flytekit/core/array_node.py                    |  4 ++
 flytekit/core/array_node_map_task.py           | 20 ++++++++-
 flytekit/models/core/workflow.py               | 11 ++++-
 flytekit/tools/translator.py                   |  3 ++
 pyproject.toml                                 |  2 +-
 tests/flytekit/unit/core/test_array_node.py    |  5 +++
 .../unit/core/test_array_node_map_task.py      | 41 +++++++++++++------
 7 files changed, 70 insertions(+), 16 deletions(-)

diff --git a/flytekit/core/array_node.py b/flytekit/core/array_node.py
index c83da43f87..0cb2c8d25c 100644
--- a/flytekit/core/array_node.py
+++ b/flytekit/core/array_node.py
@@ -217,6 +217,10 @@ def concurrency(self) -> Optional[int]:
     def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode:
         return self._execution_mode
 
+    @property
+    def is_original_sub_node_interface(self) -> bool:
+        return True
+
     def __call__(self, *args, **kwargs):
         if not self._bindings:
             ctx = FlyteContext.current_context()
diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py
index 5fd184e7fd..1e2af4b3be 100644
--- a/flytekit/core/array_node_map_task.py
+++ b/flytekit/core/array_node_map_task.py
@@ -8,6 +8,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast
 
 from flyteidl.core import tasks_pb2
+from flyteidl.core import workflow_pb2 as _core_workflow
 
 from flytekit.configuration import SerializationSettings
 from flytekit.core import tracker
@@ -21,7 +22,6 @@
 from flytekit.core.utils import timeit
 from flytekit.loggers import logger
 from flytekit.models import literals as _literal_models
-from flytekit.models.array_job import ArrayJob
 from flytekit.models.core.workflow import NodeMetadata
 from flytekit.models.interface import Variable
 from flytekit.models.task import Container, K8sPod, Sql, Task
@@ -106,6 +106,14 @@ def __init__(
         self._min_success_ratio: Optional[float] = min_success_ratio
         self._collection_interface = collection_interface
 
+        self._execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE
+        if (
+            type(python_function_task) in {PythonFunctionTask, PythonInstanceTask}
+            or isinstance(python_function_task, functools.partial)
+            and type(python_function_task.func) in {PythonFunctionTask, PythonInstanceTask}
+        ):
+            self._execution_mode = _core_workflow.ArrayNode.MINIMAL_STATE
+
         if "metadata" not in kwargs and actual_task.metadata:
             kwargs["metadata"] = actual_task.metadata
         if "security_ctx" not in kwargs and actual_task.security_context:
@@ -154,6 +162,14 @@ def python_function_task(self) -> Union[PythonFunctionTask, PythonInstanceTask]:
     def bound_inputs(self) -> Set[str]:
         return self._bound_inputs
 
+    @property
+    def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode:
+        return self._execution_mode
+
+    @property
+    def is_original_sub_node_interface(self) -> bool:
+        return False
+
     def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
         return self.python_function_task.get_extended_resources(settings)
 
@@ -169,7 +185,7 @@ def prepare_target(self):
             self.python_function_task.reset_command_fn()
 
     def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
-        return ArrayJob(parallelism=self._concurrency, min_success_ratio=self._min_success_ratio).to_dict()
+        return self._run_task.get_custom(settings) or {}
 
     def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]:
         return self.python_function_task.get_config(settings)
diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py
index cadb33a434..8d8bf9c9ef 100644
--- a/flytekit/models/core/workflow.py
+++ b/flytekit/models/core/workflow.py
@@ -3,6 +3,7 @@
 
 from flyteidl.core import tasks_pb2
 from flyteidl.core import workflow_pb2 as _core_workflow
+from google.protobuf.wrappers_pb2 import BoolValue
 
 from flytekit.models import common as _common
 from flytekit.models import interface as _interface
@@ -382,7 +383,13 @@ def from_flyte_idl(cls, pb2_object: _core_workflow.GateNode) -> "GateNode":
 
 class ArrayNode(_common.FlyteIdlEntity):
     def __init__(
-        self, node: "Node", parallelism=None, min_successes=None, min_success_ratio=None, execution_mode=None
+        self,
+        node: "Node",
+        parallelism=None,
+        min_successes=None,
+        min_success_ratio=None,
+        execution_mode=None,
+        is_original_sub_node_interface=False,
     ) -> None:
         """
         TODO: docstring
@@ -393,6 +400,7 @@ def __init__(
         self._min_successes = min_successes
         self._min_success_ratio = min_success_ratio
         self._execution_mode = execution_mode
+        self._is_original_sub_node_interface = is_original_sub_node_interface
 
     @property
     def node(self) -> "Node":
@@ -405,6 +413,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode:
             min_successes=self._min_successes,
             min_success_ratio=self._min_success_ratio,
             execution_mode=self._execution_mode,
+            is_original_sub_node_interface=BoolValue(value=self._is_original_sub_node_interface),
         )
 
     @classmethod
diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py
index 919fbed1a1..d4d8629265 100644
--- a/flytekit/tools/translator.py
+++ b/flytekit/tools/translator.py
@@ -596,6 +596,7 @@ def get_serializable_array_node(
         min_successes=array_node.min_successes,
         min_success_ratio=array_node.min_success_ratio,
         execution_mode=array_node.execution_mode,
+        is_original_sub_node_interface=array_node.is_original_sub_node_interface,
     )
 
 
@@ -629,6 +630,8 @@ def get_serializable_array_node_map_task(
         parallelism=entity.concurrency,
         min_successes=entity.min_successes,
         min_success_ratio=entity.min_success_ratio,
+        execution_mode=entity.execution_mode,
+        is_original_sub_node_interface=entity.is_original_sub_node_interface,
     )
 
 
diff --git a/pyproject.toml b/pyproject.toml
index 44f88480b5..73d228d8af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
     "diskcache>=5.2.1",
     "docker>=4.0.0",
     "docstring-parser>=0.9.0",
-    "flyteidl>=1.13.7",
+    "flyteidl>=1.13.9",
     "fsspec>=2023.3.0",
     "gcsfs>=2023.3.0",
     "googleapis-common-protos>=1.57",
diff --git a/tests/flytekit/unit/core/test_array_node.py b/tests/flytekit/unit/core/test_array_node.py
index 719e067b00..41f5e12bc9 100644
--- a/tests/flytekit/unit/core/test_array_node.py
+++ b/tests/flytekit/unit/core/test_array_node.py
@@ -2,6 +2,7 @@
 from collections import OrderedDict
 
 import pytest
+from flyteidl.core import workflow_pb2 as _core_workflow
 
 from flytekit import LaunchPlan, task, workflow
 from flytekit.core.context_manager import FlyteContextManager
@@ -96,6 +97,10 @@ def test_lp_serialization(target, overrides_metadata, serialization_settings):
 
     parent_node = wf_spec.template.nodes[0]
     assert parent_node.inputs[0].var == "a"
+    assert parent_node.array_node._min_success_ratio == 0.9
+    assert parent_node.array_node._parallelism == 10
+    assert parent_node.array_node._is_original_sub_node_interface
+    assert parent_node.array_node._execution_mode == _core_workflow.ArrayNode.FULL_STATE
     assert len(parent_node.inputs[0].binding.collection.bindings) == 3
     for binding in parent_node.inputs[0].binding.collection.bindings:
         assert binding.scalar.primitive.integer is not None
diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py
index 7621de3076..f025929a13 100644
--- a/tests/flytekit/unit/core/test_array_node_map_task.py
+++ b/tests/flytekit/unit/core/test_array_node_map_task.py
@@ -1,30 +1,33 @@
 import functools
 import os
-import pathlib
+import tempfile
 import typing
 from collections import OrderedDict
 from typing import List
-from typing_extensions import Annotated
-import tempfile
 
 import pytest
+from flyteidl.core import workflow_pb2 as _core_workflow
 
-from flytekit import dynamic, map_task, task, workflow
-from flytekit.types.directory import FlyteDirectory
+from flytekit import dynamic, map_task, task, workflow, PythonFunctionTask
 from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
 from flytekit.core import context_manager
 from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
-from flytekit.core.python_auto_container import PICKLE_FILE_PATH
 from flytekit.core.task import TaskMetadata
 from flytekit.core.type_engine import TypeEngine
-from flytekit.extras.accelerators import GPUAccelerator
 from flytekit.experimental.eager_function import eager
+from flytekit.extras.accelerators import GPUAccelerator
 from flytekit.models.literals import (
     Literal,
     LiteralMap,
     LiteralOffloadedMetadata,
 )
-from flytekit.tools.translator import get_serializable, Options
+from flytekit.tools.translator import get_serializable
+from flytekit.types.directory import FlyteDirectory
+
+
+class PythonFunctionTaskExtension(PythonFunctionTask):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
 
 
 @pytest.fixture
@@ -117,7 +120,6 @@ def t1(a: int) -> int:
 
     task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask)
     assert task_spec.template.metadata.retries.retries == 2
-    assert task_spec.template.custom["minSuccessRatio"] == 1.0
     assert task_spec.template.type == "python-task"
     assert task_spec.template.task_type_version == 1
     assert task_spec.template.container.args == [
@@ -381,25 +383,40 @@ def wf(x: typing.List[int]):
 
 def test_serialization_metadata2(serialization_settings):
     @task
-    def t1(a: int) -> int:
+    def t1(a: int) -> typing.Optional[int]:
         return a + 1
 
-    arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2, interruptible=True))
+    arraynode_maptask = map_task(t1, min_success_ratio=0.9, concurrency=10, metadata=TaskMetadata(retries=2, interruptible=True))
     assert arraynode_maptask.metadata.interruptible
 
     @workflow
     def wf(x: typing.List[int]):
         return arraynode_maptask(a=x)
 
+    full_state_array_node_map_task = map_task(PythonFunctionTaskExtension(task_config={}, task_function=t1))
+
+    @workflow
+    def wf1(x: typing.List[int]):
+        return full_state_array_node_map_task(a=x)
+
     od = OrderedDict()
     wf_spec = get_serializable(od, serialization_settings, wf)
     assert arraynode_maptask.construct_node_metadata().interruptible
-    assert wf_spec.template.nodes[0].metadata.interruptible
+    array_node = wf_spec.template.nodes[0]
+    assert array_node.metadata.interruptible
+    assert array_node.array_node._min_success_ratio == 0.9
+    assert array_node.array_node._parallelism == 10
+    assert not array_node.array_node._is_original_sub_node_interface
+    assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.MINIMAL_STATE
 
     task_spec = od[arraynode_maptask]
     assert task_spec.template.metadata.retries.retries == 2
     assert task_spec.template.metadata.interruptible
 
+    wf1_spec = get_serializable(od, serialization_settings, wf1)
+    array_node = wf1_spec.template.nodes[0]
+    assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.FULL_STATE
+
 
 def test_serialization_extended_resources(serialization_settings):
     @task(
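
A minimal sketch of the behavior these hunks introduce, assuming a flytekit build that includes this patch (flyteidl>=1.13.9). The task functions and the MyPluginTask class below are hypothetical stand-ins; only map_task, PythonFunctionTask, and the new execution_mode / is_original_sub_node_interface properties come from the change itself.

# Sketch only: illustrates execution-mode selection for ArrayNodeMapTask.
from flyteidl.core import workflow_pb2 as _core_workflow
from flytekit import PythonFunctionTask, map_task, task


@task
def double(a: int) -> int:
    # hypothetical plain task; @task produces a PythonFunctionTask
    return a * 2


def triple(a: int) -> int:
    # hypothetical function wrapped by the plugin-style task below
    return a * 3


class MyPluginTask(PythonFunctionTask):
    # hypothetical PythonFunctionTask extension, standing in for a plugin task
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


# Mapping a plain @task keeps the lighter MINIMAL_STATE execution mode,
# and the map task never reports the original sub-node interface.
plain_map = map_task(double)
assert plain_map.execution_mode == _core_workflow.ArrayNode.MINIMAL_STATE
assert not plain_map.is_original_sub_node_interface

# Mapping a PythonFunctionTask subclass switches the ArrayNode to FULL_STATE;
# get_custom() now defers to the wrapped sub-task instead of building an ArrayJob.
plugin_map = map_task(MyPluginTask(task_config={}, task_function=triple))
assert plugin_map.execution_mode == _core_workflow.ArrayNode.FULL_STATE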