Skip to content

Commit

Permalink
Update array node map task to support an additional plugin (#2934)
Browse files Browse the repository at this point in the history
* have get_custom return subtask and remove unneeded array job

Signed-off-by: Paul Dittamo <[email protected]>

* set is_original_sub_node_interface + FULL_STATE for python function task extensions

Signed-off-by: Paul Dittamo <[email protected]>

* lint

Signed-off-by: Paul Dittamo <[email protected]>

* unit test

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* idl update

Signed-off-by: Paul Dittamo <[email protected]>

* Revert "idl update"

This reverts commit e98bc6e.

Signed-off-by: Paul Dittamo <[email protected]>

* pass in boolvalue

Signed-off-by: Paul Dittamo <[email protected]>

* update unit test

Signed-off-by: Paul Dittamo <[email protected]>

* Set flyteidl lower bound

Signed-off-by: Eduardo Apolinario <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
pvditt and eapolinario authored Dec 2, 2024
1 parent 6cdc4ab commit 209fa65
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 16 deletions.
4 changes: 4 additions & 0 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ def concurrency(self) -> Optional[int]:
def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode:
return self._execution_mode

@property
def is_original_sub_node_interface(self) -> bool:
    """Whether the mapped sub-node keeps its original (un-transformed) interface.

    Always True for this ArrayNode; the map-task variant
    (ArrayNodeMapTask) reports False instead, and the flag is serialized
    into the ArrayNode IDL message as a BoolValue.
    """
    return True

def __call__(self, *args, **kwargs):
if not self._bindings:
ctx = FlyteContext.current_context()
Expand Down
20 changes: 18 additions & 2 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast

from flyteidl.core import tasks_pb2
from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit.configuration import SerializationSettings
from flytekit.core import tracker
Expand All @@ -21,7 +22,6 @@
from flytekit.core.utils import timeit
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
from flytekit.models.array_job import ArrayJob
from flytekit.models.core.workflow import NodeMetadata
from flytekit.models.interface import Variable
from flytekit.models.task import Container, K8sPod, Sql, Task
Expand Down Expand Up @@ -106,6 +106,14 @@ def __init__(
self._min_success_ratio: Optional[float] = min_success_ratio
self._collection_interface = collection_interface

self._execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE
if (
type(python_function_task) in {PythonFunctionTask, PythonInstanceTask}
or isinstance(python_function_task, functools.partial)
and type(python_function_task.func) in {PythonFunctionTask, PythonInstanceTask}
):
self._execution_mode = _core_workflow.ArrayNode.MINIMAL_STATE

if "metadata" not in kwargs and actual_task.metadata:
kwargs["metadata"] = actual_task.metadata
if "security_ctx" not in kwargs and actual_task.security_context:
Expand Down Expand Up @@ -154,6 +162,14 @@ def python_function_task(self) -> Union[PythonFunctionTask, PythonInstanceTask]:
def bound_inputs(self) -> Set[str]:
return self._bound_inputs

@property
def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode:
    """The ArrayNode execution mode chosen at construction time.

    Per __init__: MINIMAL_STATE when the underlying task is exactly a
    PythonFunctionTask/PythonInstanceTask (or a functools.partial of one),
    FULL_STATE otherwise (e.g. plugin task extensions).
    """
    return self._execution_mode

@property
def is_original_sub_node_interface(self) -> bool:
    """Whether the mapped sub-node keeps its original interface.

    Always False for the array-node map task: the sub-task's interface is
    transformed (inputs are mapped over lists), unlike the plain ArrayNode,
    whose property returns True.
    """
    return False

def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
return self.python_function_task.get_extended_resources(settings)

Expand All @@ -169,7 +185,7 @@ def prepare_target(self):
self.python_function_task.reset_command_fn()

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
return ArrayJob(parallelism=self._concurrency, min_success_ratio=self._min_success_ratio).to_dict()
return self._run_task.get_custom(settings) or {}

def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]:
return self.python_function_task.get_config(settings)
Expand Down
11 changes: 10 additions & 1 deletion flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from flyteidl.core import tasks_pb2
from flyteidl.core import workflow_pb2 as _core_workflow
from google.protobuf.wrappers_pb2 import BoolValue

from flytekit.models import common as _common
from flytekit.models import interface as _interface
Expand Down Expand Up @@ -382,7 +383,13 @@ def from_flyte_idl(cls, pb2_object: _core_workflow.GateNode) -> "GateNode":

class ArrayNode(_common.FlyteIdlEntity):
def __init__(
self, node: "Node", parallelism=None, min_successes=None, min_success_ratio=None, execution_mode=None
self,
node: "Node",
parallelism=None,
min_successes=None,
min_success_ratio=None,
execution_mode=None,
is_original_sub_node_interface=False,
) -> None:
"""
TODO: docstring
Expand All @@ -393,6 +400,7 @@ def __init__(
self._min_successes = min_successes
self._min_success_ratio = min_success_ratio
self._execution_mode = execution_mode
self._is_original_sub_node_interface = is_original_sub_node_interface

@property
def node(self) -> "Node":
Expand All @@ -405,6 +413,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode:
min_successes=self._min_successes,
min_success_ratio=self._min_success_ratio,
execution_mode=self._execution_mode,
is_original_sub_node_interface=BoolValue(value=self._is_original_sub_node_interface),
)

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ def get_serializable_array_node(
min_successes=array_node.min_successes,
min_success_ratio=array_node.min_success_ratio,
execution_mode=array_node.execution_mode,
is_original_sub_node_interface=array_node.is_original_sub_node_interface,
)


Expand Down Expand Up @@ -629,6 +630,8 @@ def get_serializable_array_node_map_task(
parallelism=entity.concurrency,
min_successes=entity.min_successes,
min_success_ratio=entity.min_success_ratio,
execution_mode=entity.execution_mode,
is_original_sub_node_interface=entity.is_original_sub_node_interface,
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.13.7",
"flyteidl>=1.13.9",
"fsspec>=2023.3.0",
"gcsfs>=2023.3.0",
"googleapis-common-protos>=1.57",
Expand Down
5 changes: 5 additions & 0 deletions tests/flytekit/unit/core/test_array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict

import pytest
from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit import LaunchPlan, task, workflow
from flytekit.core.context_manager import FlyteContextManager
Expand Down Expand Up @@ -96,6 +97,10 @@ def test_lp_serialization(target, overrides_metadata, serialization_settings):

parent_node = wf_spec.template.nodes[0]
assert parent_node.inputs[0].var == "a"
assert parent_node.array_node._min_success_ratio == 0.9
assert parent_node.array_node._parallelism == 10
assert parent_node.array_node._is_original_sub_node_interface
assert parent_node.array_node._execution_mode == _core_workflow.ArrayNode.FULL_STATE
assert len(parent_node.inputs[0].binding.collection.bindings) == 3
for binding in parent_node.inputs[0].binding.collection.bindings:
assert binding.scalar.primitive.integer is not None
Expand Down
41 changes: 29 additions & 12 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,33 @@
import functools
import os
import pathlib
import tempfile
import typing
from collections import OrderedDict
from typing import List
from typing_extensions import Annotated
import tempfile

import pytest
from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit import dynamic, map_task, task, workflow
from flytekit.types.directory import FlyteDirectory
from flytekit import dynamic, map_task, task, workflow, PythonFunctionTask
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.core import context_manager
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
from flytekit.core.python_auto_container import PICKLE_FILE_PATH
from flytekit.core.task import TaskMetadata
from flytekit.core.type_engine import TypeEngine
from flytekit.extras.accelerators import GPUAccelerator
from flytekit.experimental.eager_function import eager
from flytekit.extras.accelerators import GPUAccelerator
from flytekit.models.literals import (
Literal,
LiteralMap,
LiteralOffloadedMetadata,
)
from flytekit.tools.translator import get_serializable, Options
from flytekit.tools.translator import get_serializable
from flytekit.types.directory import FlyteDirectory


class PythonFunctionTaskExtension(PythonFunctionTask):
    """Minimal PythonFunctionTask subclass used in tests.

    Exists only so map_task receives a type that is a *subclass* of
    PythonFunctionTask (not the exact class), which should select the
    FULL_STATE ArrayNode execution mode instead of MINIMAL_STATE.
    Adds no behavior of its own.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through; the subclass identity alone is what matters.
        super().__init__(*args, **kwargs)

@pytest.fixture
Expand Down Expand Up @@ -117,7 +120,6 @@ def t1(a: int) -> int:
task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask)

assert task_spec.template.metadata.retries.retries == 2
assert task_spec.template.custom["minSuccessRatio"] == 1.0
assert task_spec.template.type == "python-task"
assert task_spec.template.task_type_version == 1
assert task_spec.template.container.args == [
Expand Down Expand Up @@ -381,25 +383,40 @@ def wf(x: typing.List[int]):

def test_serialization_metadata2(serialization_settings):
@task
def t1(a: int) -> int:
def t1(a: int) -> typing.Optional[int]:
return a + 1

arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2, interruptible=True))
arraynode_maptask = map_task(t1, min_success_ratio=0.9, concurrency=10, metadata=TaskMetadata(retries=2, interruptible=True))
assert arraynode_maptask.metadata.interruptible

@workflow
def wf(x: typing.List[int]):
return arraynode_maptask(a=x)

full_state_array_node_map_task = map_task(PythonFunctionTaskExtension(task_config={}, task_function=t1))

@workflow
def wf1(x: typing.List[int]):
return full_state_array_node_map_task(a=x)

od = OrderedDict()
wf_spec = get_serializable(od, serialization_settings, wf)

assert arraynode_maptask.construct_node_metadata().interruptible
assert wf_spec.template.nodes[0].metadata.interruptible
array_node = wf_spec.template.nodes[0]
assert array_node.metadata.interruptible
assert array_node.array_node._min_success_ratio == 0.9
assert array_node.array_node._parallelism == 10
assert not array_node.array_node._is_original_sub_node_interface
assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.MINIMAL_STATE
task_spec = od[arraynode_maptask]
assert task_spec.template.metadata.retries.retries == 2
assert task_spec.template.metadata.interruptible

wf1_spec = get_serializable(od, serialization_settings, wf1)
array_node = wf1_spec.template.nodes[0]
assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.FULL_STATE


def test_serialization_extended_resources(serialization_settings):
@task(
Expand Down

0 comments on commit 209fa65

Please sign in to comment.