From 676914b4720c6558b6750f068a6eaf631c2efdaf Mon Sep 17 00:00:00 2001 From: Paul Dittamo <37558497+pvditt@users.noreply.github.com> Date: Wed, 31 Jul 2024 09:59:04 -0700 Subject: [PATCH] Support ArrayNode mapping over Launch Plans (#2480) * set up array node Signed-off-by: Paul Dittamo * wip array node task wrapper Signed-off-by: Paul Dittamo * support function like callability Signed-off-by: Paul Dittamo * temp check in some progress on python func wrapper Signed-off-by: Paul Dittamo * only support launch plans in new array node class for now Signed-off-by: Paul Dittamo * add map task array node implementation wrapper Signed-off-by: Paul Dittamo * ArrayNode only supports LPs for now Signed-off-by: Paul Dittamo * support local execute for new array node implementation Signed-off-by: Paul Dittamo * add local execute unit tests for array node Signed-off-by: Paul Dittamo * set execution version in array node spec Signed-off-by: Paul Dittamo * check input types for local execute Signed-off-by: Paul Dittamo * remove code that is un-needed for now Signed-off-by: Paul Dittamo * clean up array node class Signed-off-by: Paul Dittamo * improve naming Signed-off-by: Paul Dittamo * clean up Signed-off-by: Paul Dittamo * utilize enum execution mode to set array node execution path Signed-off-by: Paul Dittamo * default execution mode to FULL_STATE for new array node class Signed-off-by: Paul Dittamo * support min_successes for new array node Signed-off-by: Paul Dittamo * add map task wrapper unit test Signed-off-by: Paul Dittamo * set min successes for array node map task wrapper Signed-off-by: Paul Dittamo * update docstrings Signed-off-by: Paul Dittamo * Install flyteidl from master in plugins tests Signed-off-by: Eduardo Apolinario * lint Signed-off-by: Paul Dittamo * clean up min success/ratio setting Signed-off-by: Paul Dittamo * lint Signed-off-by: Paul Dittamo * make array node class callable Signed-off-by: Paul Dittamo --------- Signed-off-by: Paul Dittamo 
Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- flytekit/core/array_node.py | 226 ++++++++++++++++++++ flytekit/core/array_node_map_task.py | 37 ++++ flytekit/models/core/workflow.py | 6 +- flytekit/remote/remote.py | 1 + flytekit/tools/translator.py | 34 ++- tests/flytekit/unit/core/test_array_node.py | 104 +++++++++ 6 files changed, 405 insertions(+), 3 deletions(-) create mode 100644 flytekit/core/array_node.py create mode 100644 tests/flytekit/unit/core/test_array_node.py diff --git a/flytekit/core/array_node.py b/flytekit/core/array_node.py new file mode 100644 index 0000000000..a7cea7ff32 --- /dev/null +++ b/flytekit/core/array_node.py @@ -0,0 +1,226 @@ +import math +from typing import Any, List, Optional, Set, Tuple, Union + +from flyteidl.core import workflow_pb2 as _core_workflow + +from flytekit.core import interface as flyte_interface +from flytekit.core.context_manager import ExecutionState, FlyteContext +from flytekit.core.interface import transform_interface_to_list_interface, transform_interface_to_typed_interface +from flytekit.core.launch_plan import LaunchPlan +from flytekit.core.node import Node +from flytekit.core.promise import ( + Promise, + VoidPromise, + flyte_entity_call_handler, + translate_inputs_to_literals, +) +from flytekit.core.task import TaskMetadata +from flytekit.loggers import logger +from flytekit.models import literals as _literal_models +from flytekit.models.core import workflow as _workflow_model +from flytekit.models.literals import Literal, LiteralCollection, Scalar + + +class ArrayNode: + def __init__( + self, + target: LaunchPlan, + execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE, + concurrency: Optional[int] = None, + min_successes: Optional[int] = None, + min_success_ratio: Optional[float] = None, + bound_inputs: Optional[Set[str]] = None, + metadata: Optional[Union[_workflow_model.NodeMetadata, TaskMetadata]] = None, + ): + """ + :param target: The 
target Flyte entity to map over + :param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given batch + size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until + all inputs are processed. If set to 0, this means unbounded concurrency. If left unspecified, this means the + array node will inherit parallelism from the workflow + :param min_successes: The minimum number of successful executions. If set, this takes precedence over + min_success_ratio + :param min_success_ratio: The minimum ratio of successful executions. + :param bound_inputs: The set of inputs that should be bound to the map task + :param execution_mode: The execution mode for propeller to use when handling ArrayNode + :param metadata: The metadata for the underlying entity + """ + self.target = target + self._concurrency = concurrency + self._execution_mode = execution_mode + self.id = target.name + + if min_successes is not None: + self._min_successes = min_successes + self._min_success_ratio = None + else: + self._min_success_ratio = min_success_ratio if min_success_ratio is not None else 1.0 + self._min_successes = 0 + + n_outputs = len(self.target.python_interface.outputs) + if n_outputs > 1: + raise ValueError("Only tasks with a single output are supported in map tasks.") + + self._bound_inputs: Set[str] = bound_inputs or set(bound_inputs) if bound_inputs else set() + + output_as_list_of_optionals = min_success_ratio is not None and min_success_ratio != 1 and n_outputs == 1 + collection_interface = transform_interface_to_list_interface( + self.target.python_interface, self._bound_inputs, output_as_list_of_optionals + ) + self._collection_interface = collection_interface + + self.metadata = None + if isinstance(target, LaunchPlan): + if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE: + raise ValueError("Only execution version 1 is supported for LaunchPlans.") + if metadata: + if 
isinstance(metadata, _workflow_model.NodeMetadata): + self.metadata = metadata + else: + raise Exception("Invalid metadata for LaunchPlan. Should be NodeMetadata.") + else: + raise Exception("Only LaunchPlans are supported for now.") + + def construct_node_metadata(self) -> _workflow_model.NodeMetadata: + # Part of SupportsNodeCreation interface + # TODO - include passed in metadata + return _workflow_model.NodeMetadata(name=self.target.name) + + @property + def name(self) -> str: + # Part of SupportsNodeCreation interface + return self.target.name + + @property + def python_interface(self) -> flyte_interface.Interface: + # Part of SupportsNodeCreation interface + return self._collection_interface + + @property + def bindings(self) -> List[_literal_models.Binding]: + # Required in get_serializable_node + return [] + + @property + def upstream_nodes(self) -> List[Node]: + # Required in get_serializable_node + return [] + + @property + def flyte_entity(self) -> Any: + return self.target + + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: + outputs_expected = True + if not self.python_interface.outputs: + outputs_expected = False + + mapped_entity_count = 0 + for k in self.python_interface.inputs.keys(): + if k not in self._bound_inputs: + v = kwargs[k] + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], self.target.python_interface.inputs[k]): + mapped_entity_count = len(v) + break + else: + raise ValueError( + f"Expected a list of {self.target.python_interface.inputs[k]} but got {type(v)} instead." 
+ ) + + failed_count = 0 + min_successes = mapped_entity_count + if self._min_successes: + min_successes = self._min_successes + elif self._min_success_ratio: + min_successes = math.ceil(min_successes * self._min_success_ratio) + + literals = [] + for i in range(mapped_entity_count): + single_instance_inputs = {} + for k in self.python_interface.inputs.keys(): + if k not in self._bound_inputs: + single_instance_inputs[k] = kwargs[k][i] + else: + single_instance_inputs[k] = kwargs[k] + + # translate Python native inputs to Flyte literals + typed_interface = transform_interface_to_typed_interface(self.target.python_interface) + literal_map = translate_inputs_to_literals( + ctx, + incoming_values=single_instance_inputs, + flyte_interface_types={} if typed_interface is None else typed_interface.inputs, + native_types=self.target.python_interface.inputs, + ) + kwargs_literals = {k1: Promise(var=k1, val=v1) for k1, v1 in literal_map.items()} + + try: + output = self.target.__call__(**kwargs_literals) + if outputs_expected: + literals.append(output.val) + except Exception as exc: + if outputs_expected: + literal_with_none = Literal(scalar=Scalar(none_type=_literal_models.Void())) + literals.append(literal_with_none) + failed_count += 1 + if mapped_entity_count - failed_count < min_successes: + logger.error("The number of successful tasks is lower than the minimum") + raise exc + + if outputs_expected: + return Promise(var="o0", val=Literal(collection=LiteralCollection(literals=literals))) + return VoidPromise(self.name) + + def local_execution_mode(self): + return ExecutionState.Mode.LOCAL_TASK_EXECUTION + + @property + def min_success_ratio(self) -> Optional[float]: + return self._min_success_ratio + + @property + def min_successes(self) -> Optional[int]: + return self._min_successes + + @property + def concurrency(self) -> Optional[int]: + return self._concurrency + + @property + def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode: + return 
self._execution_mode + + def __call__(self, *args, **kwargs): + return flyte_entity_call_handler(self, *args, **kwargs) + + +def array_node( + target: Union[LaunchPlan], + concurrency: Optional[int] = None, + min_success_ratio: Optional[float] = None, + min_successes: Optional[int] = None, +): + """ + ArrayNode implementation that maps over tasks and other Flyte entities + + :param target: The target Flyte entity to map over + :param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given batch + size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until + all inputs are processed. If set to 0, this means unbounded concurrency. If left unspecified, this means the + array node will inherit parallelism from the workflow + :param min_successes: The minimum number of successful executions. If set, this takes precedence over + min_success_ratio + :param min_success_ratio: The minimum ratio of successful executions + :return: A callable function that takes in keyword arguments and returns a Promise created by + flyte_entity_call_handler + """ + if not isinstance(target, LaunchPlan): + raise ValueError("Only LaunchPlans are supported for now.") + + node = ArrayNode( + target=target, + concurrency=concurrency, + min_successes=min_successes, + min_success_ratio=min_success_ratio, + ) + + return node diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 575654b57d..337716eb08 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -12,9 +12,11 @@ from flytekit.configuration import SerializationSettings from flytekit.core import tracker +from flytekit.core.array_node import array_node from flytekit.core.base_task import PythonTask, TaskResolverMixin from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.interface import 
transform_interface_to_list_interface +from flytekit.core.launch_plan import LaunchPlan from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask from flytekit.core.type_engine import TypeEngine, is_annotated from flytekit.core.utils import timeit @@ -347,6 +349,41 @@ def _raw_execute(self, **kwargs) -> Any: def map_task( + target: Union[LaunchPlan, PythonFunctionTask], + concurrency: Optional[int] = None, + min_successes: Optional[int] = None, + min_success_ratio: float = 1.0, + **kwargs, +): + """ + Wrapper that creates a map task utilizing either the existing ArrayNodeMapTask + or the drop in replacement ArrayNode implementation + + :param target: The Flyte entity of which will be mapped over + :param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given batch + size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until + all inputs are processed. If set to 0, this means unbounded concurrency. If left unspecified, this means the + array node will inherit parallelism from the workflow + :param min_successes: The minimum number of successful executions + :param min_success_ratio: The minimum ratio of successful executions + """ + if isinstance(target, LaunchPlan): + return array_node( + target=target, + concurrency=concurrency, + min_successes=min_successes, + min_success_ratio=min_success_ratio, + ) + return array_node_map_task( + task_function=target, + concurrency=concurrency, + min_successes=min_successes, + min_success_ratio=min_success_ratio, + **kwargs, + ) + + +def array_node_map_task( task_function: PythonFunctionTask, concurrency: Optional[int] = None, # TODO why no min_successes? 
diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 44fe7e1f44..28a9fbc091 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -381,7 +381,9 @@ def from_flyte_idl(cls, pb2_object: _core_workflow.GateNode) -> "GateNode": class ArrayNode(_common.FlyteIdlEntity): - def __init__(self, node: "Node", parallelism=None, min_successes=None, min_success_ratio=None) -> None: + def __init__( + self, node: "Node", parallelism=None, min_successes=None, min_success_ratio=None, execution_mode=None + ) -> None: """ TODO: docstring """ @@ -390,6 +392,7 @@ def __init__(self, node: "Node", parallelism=None, min_successes=None, min_succe # TODO either min_successes or min_success_ratio should be set self._min_successes = min_successes self._min_success_ratio = min_success_ratio + self._execution_mode = execution_mode @property def node(self) -> "Node": @@ -401,6 +404,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode: parallelism=self._parallelism, min_successes=self._min_successes, min_success_ratio=self._min_success_ratio, + execution_mode=self._execution_mode, ) @classmethod diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 66b1ae54b6..1406e6a560 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -667,6 +667,7 @@ def raw_register( workflow_model.WorkflowNode, workflow_model.BranchNode, workflow_model.TaskNode, + workflow_model.ArrayNode, ), ): return None diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index a77e0a0bf5..5f34732600 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -10,6 +10,7 @@ from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import constants as _common_constants from flytekit.core import context_manager +from flytekit.core.array_node import ArrayNode from flytekit.core.array_node_map_task import ArrayNodeMapTask from flytekit.core.base_task 
import PythonTask from flytekit.core.condition import BranchNode @@ -49,6 +50,7 @@ ReferenceTask, ReferenceLaunchPlan, ReferenceEntity, + ArrayNode, ] FlyteControlPlaneEntity = Union[ TaskSpec, @@ -471,15 +473,24 @@ def get_serializable_node( from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow - if isinstance(entity.flyte_entity, ArrayNodeMapTask): + if isinstance(entity.flyte_entity, ArrayNode): node_model = workflow_model.Node( id=_dnsify(entity.id), - metadata=entity.metadata, + metadata=entity.flyte_entity.construct_node_metadata(), inputs=entity.bindings, upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], array_node=get_serializable_array_node(entity_mapping, settings, entity, options=options), ) + elif isinstance(entity.flyte_entity, ArrayNodeMapTask): + node_model = workflow_model.Node( + id=_dnsify(entity.id), + metadata=entity.metadata, + inputs=entity.bindings, + upstream_node_ids=[n.id for n in upstream_nodes], + output_aliases=[], + array_node=get_serializable_array_node_map_task(entity_mapping, settings, entity, options=options), + ) # TODO: do I need this? 
# if entity._aliases: # node_model._output_aliases = entity._aliases @@ -617,6 +628,22 @@ def get_serializable_node( def get_serializable_array_node( + entity_mapping: OrderedDict, + settings: SerializationSettings, + node: FlyteLocalEntity, + options: Optional[Options] = None, +) -> ArrayNodeModel: + array_node = node.flyte_entity + return ArrayNodeModel( + node=get_serializable_node(entity_mapping, settings, array_node, options=options), + parallelism=array_node.concurrency, + min_successes=array_node.min_successes, + min_success_ratio=array_node.min_success_ratio, + execution_mode=array_node.execution_mode, + ) + + +def get_serializable_array_node_map_task( entity_mapping: OrderedDict, settings: SerializationSettings, node: Node, @@ -790,6 +817,9 @@ def get_serializable( elif isinstance(entity, FlyteLaunchPlan): cp_entity = entity + elif isinstance(entity, ArrayNode): + cp_entity = get_serializable_array_node(entity_mapping, settings, entity, options) + else: raise Exception(f"Non serializable type found {type(entity)} Entity {entity}") diff --git a/tests/flytekit/unit/core/test_array_node.py b/tests/flytekit/unit/core/test_array_node.py new file mode 100644 index 0000000000..f7704d4afd --- /dev/null +++ b/tests/flytekit/unit/core/test_array_node.py @@ -0,0 +1,104 @@ +import typing +from collections import OrderedDict + +import pytest + +from flytekit import LaunchPlan, current_context, task, workflow +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.array_node import array_node +from flytekit.core.array_node_map_task import map_task +from flytekit.models.core import identifier as identifier_models +from flytekit.tools.translator import get_serializable + + +@pytest.fixture +def serialization_settings(): + default_img = Image(name="default", fqn="test", tag="tag") + return SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + 
image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + +@task +def multiply(val: int, val1: int) -> int: + return val * val1 + + +@workflow +def parent_wf(a: int, b: int) -> int: + return multiply(val=a, val1=b) + + +lp = LaunchPlan.get_default_launch_plan(current_context(), parent_wf) + + +@workflow +def grandparent_wf() -> list[int]: + return array_node(lp, concurrency=10, min_success_ratio=0.9)(a=[1, 3, 5], b=[2, 4, 6]) + + +def test_lp_serialization(serialization_settings): + + wf_spec = get_serializable(OrderedDict(), serialization_settings, grandparent_wf) + assert len(wf_spec.template.nodes) == 1 + assert wf_spec.template.nodes[0].array_node is not None + assert wf_spec.template.nodes[0].array_node.node is not None + assert wf_spec.template.nodes[0].array_node.node.workflow_node is not None + assert ( + wf_spec.template.nodes[0].array_node.node.workflow_node.launchplan_ref.resource_type + == identifier_models.ResourceType.LAUNCH_PLAN + ) + assert wf_spec.template.nodes[0].array_node.node.workflow_node.launchplan_ref.name == "tests.flytekit.unit.core.test_array_node.parent_wf" + assert wf_spec.template.nodes[0].array_node._min_success_ratio == 0.9 + assert wf_spec.template.nodes[0].array_node._parallelism == 10 + + +@pytest.mark.parametrize( + "min_successes, min_success_ratio, should_raise_error", + [ + (None, None, True), + (None, 1, True), + (None, 0.75, False), + (None, 0.5, False), + (1, None, False), + (3, None, False), + (4, None, True), + # Test min_successes takes precedence over min_success_ratio + (1, 1.0, False), + (4, 0.1, True), + ], +) +def test_local_exec_lp_min_successes(min_successes, min_success_ratio, should_raise_error): + @task + def ex_task(val: int) -> int: + if val == 1: + raise Exception("Test") + return val + + @workflow + def ex_wf(val: int) -> int: + return ex_task(val=val) + + ex_lp = LaunchPlan.get_default_launch_plan(current_context(), ex_wf) + + @workflow + def grandparent_ex_wf() -> 
list[typing.Optional[int]]: + return array_node(ex_lp, min_successes=min_successes, min_success_ratio=min_success_ratio)(val=[1, 2, 3, 4]) + + if should_raise_error: + with pytest.raises(Exception): + grandparent_ex_wf() + else: + assert grandparent_ex_wf() == [None, 2, 3, 4] + + +def test_map_task_wrapper(): + mapped_task = map_task(multiply)(val=[1, 3, 5], val1=[2, 4, 6]) + assert mapped_task == [2, 12, 30] + + mapped_lp = map_task(lp)(a=[1, 3, 5], b=[2, 4, 6]) + assert mapped_lp == [2, 12, 30]