Skip to content

Commit

Permalink
Make array_node_map_task the default map_task (#2242)
Browse files Browse the repository at this point in the history
* Swap arraynode map_task

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix unit tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add pytest-icdiff to dev-requirements.in

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix test_node_creation.py tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove array_node_map_task from experimental module and rename legacy map task module to `legacy_map_task`

Signed-off-by: Eduardo Apolinario <[email protected]>

* Lint

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix one more mention of legacy map task code

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove `--experimental` and include `--legacy`

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix flytekit-k8s-pod/tests/test_pod.py test

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove the `--legacy` flag and rearrange how the map task resolvers are loaded

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove map task resolver imports from entrypoint

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix name in k8s-pod plugin

Signed-off-by: Eduardo Apolinario <[email protected]>

---------

Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario and eapolinario authored Mar 14, 2024
1 parent d61e79e commit 4767fd8
Show file tree
Hide file tree
Showing 13 changed files with 57 additions and 64 deletions.
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ google-cloud-bigquery-storage
IPython
keyrings.alt
setuptools_scm
pytest-icdiff

# Tensorflow is not available for python 3.12 yet: https://github.com/tensorflow/tensorflow/issues/62003
tensorflow; python_version<'3.12'
Expand Down
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@
from importlib.metadata import entry_points

from flytekit._version import __version__
from flytekit.core.array_node_map_task import map_task
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.base_task import SecurityContext, TaskMetadata, kwtypes
from flytekit.core.checkpointer import Checkpoint
Expand All @@ -216,7 +217,6 @@
from flytekit.core.gate import approve, sleep, wait_for_input
from flytekit.core.hash import HashMethod
from flytekit.core.launch_plan import LaunchPlan, reference_launch_plan
from flytekit.core.map_task import map_task
from flytekit.core.notification import Email, PagerDuty, Slack
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
Expand Down
18 changes: 6 additions & 12 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
)
from flytekit.core import constants as _constants
from flytekit.core import utils
from flytekit.core.array_node_map_task import ArrayNodeMapTaskResolver
from flytekit.core.base_task import IgnoreOutputs, PythonTask
from flytekit.core.checkpointer import SyncCheckpoint
from flytekit.core.context_manager import (
Expand All @@ -33,7 +32,6 @@
OutputMetadataTracker,
)
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.map_task import MapTaskResolver
from flytekit.core.promise import VoidPromise
from flytekit.deck.deck import _output_deck
from flytekit.exceptions import scopes as _scoped_exceptions
Expand Down Expand Up @@ -392,7 +390,6 @@ def _execute_map_task(
prev_checkpoint: Optional[str] = None,
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
experimental: Optional[bool] = False,
):
"""
This function should be called by map task and aws-batch task
Expand All @@ -418,14 +415,14 @@ def _execute_map_task(
raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir
) as ctx:
task_index = _compute_array_job_index()
if experimental:
mtr = ArrayNodeMapTaskResolver()
else:
mtr = MapTaskResolver()
output_prefix = os.path.join(output_prefix, str(task_index))

mtr = load_object_from_module(resolver)()
map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency)

# Special case for the legacy map task resolver: we need to append the task index to the output prefix.
# TODO: (https://github.com/flyteorg/flyte/issues/5011) Remove legacy map task
if mtr.name() == "flytekit.core.legacy_map_task.MapTaskResolver":
output_prefix = os.path.join(output_prefix, str(task_index))

if test:
logger.info(
f"Test detected, returning. Inputs: {inputs} Computed task index: {task_index} "
Expand Down Expand Up @@ -557,7 +554,6 @@ def handle_sigterm(signum, frame):
@_click.option("--resolver", required=True)
@_click.option("--checkpoint-path", required=False)
@_click.option("--prev-checkpoint", required=False)
@_click.option("--experimental", is_flag=True, default=False, required=False)
@_click.argument(
"resolver-args",
type=_click.UNPROCESSED,
Expand All @@ -574,7 +570,6 @@ def map_execute_task_cmd(
resolver,
resolver_args,
prev_checkpoint,
experimental,
checkpoint_path,
):
logger.info(get_version_message())
Expand All @@ -595,7 +590,6 @@ def map_execute_task_cmd(
resolver_args=resolver_args,
checkpoint_path=checkpoint_path,
prev_checkpoint=prev_checkpoint,
experimental=experimental,
)


Expand Down
5 changes: 2 additions & 3 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
"{{.checkpointOutputPrefix}}",
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--experimental",
"--resolver",
mt.name(),
"--",
Expand Down Expand Up @@ -248,7 +247,7 @@ def _outputs_interface(self) -> Dict[Any, Variable]:
"""

ctx = FlyteContextManager.current_context()
if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state and ctx.execution_state.is_local_execution():
# In workflow execution mode we actually need to use the parent (mapper) task output interface.
return self.interface.outputs
return self.python_function_task.interface.outputs
Expand Down Expand Up @@ -367,7 +366,7 @@ def foo((i: int, j: str) -> str:
"""

def name(self) -> str:
return "ArrayNodeMapTaskResolver"
return "flytekit.core.array_node_map_task.ArrayNodeMapTaskResolver"

@timeit("Load map task")
def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> ArrayNodeMapTask:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def set_command_prefix(self, cmd: typing.Optional[typing.List[str]]):
@contextmanager
def prepare_target(self):
"""
TODO: why do we do this?
Alters the underlying run_task command to modify it for map task execution and then resets it after.
"""
self._run_task.set_command_fn(self.get_command)
Expand Down Expand Up @@ -386,7 +385,7 @@ def foo((i: int, j: str) -> str:
"""

def name(self) -> str:
return "MapTaskResolver"
return "flytekit.core.legacy_map_task.MapTaskResolver"

@timeit("Load map task")
def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> MapPythonTask:
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def flyte_entity(self) -> Any:
@property
def run_entity(self) -> Any:
from flytekit.core.array_node_map_task import ArrayNodeMapTask
from flytekit.core.map_task import MapPythonTask
from flytekit.core.legacy_map_task import MapPythonTask

if isinstance(self.flyte_entity, MapPythonTask):
return self.flyte_entity.run_task
Expand Down
2 changes: 2 additions & 0 deletions flytekit/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Experimental features of flytekit."""

# TODO(eapolinario): Remove this once a new flytekit release is out and
# references are updated in the monodocs build.
from flytekit.core.array_node_map_task import map_task # noqa: F401
from flytekit.experimental.eager_function import EagerException, eager
2 changes: 1 addition & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from flytekit.core.container_task import ContainerTask
from flytekit.core.gate import Gate
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.core.map_task import MapPythonTask
from flytekit.core.legacy_map_task import MapPythonTask
from flytekit.core.node import Node
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-k8s-pod/tests/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def simple_pod_task(i: int):
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--resolver",
"MapTaskResolver",
"flytekit.core.array_node_map_task.ArrayNodeMapTaskResolver",
"--",
"vars",
"",
Expand Down
45 changes: 21 additions & 24 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

import pytest

from flytekit import task, workflow
from flytekit import map_task, task, workflow
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
from flytekit.core.task import TaskMetadata
from flytekit.experimental import map_task as array_node_map_task
from flytekit.tools.translator import get_serializable


Expand All @@ -32,7 +31,7 @@ def say_hello(name: str) -> str:

@workflow
def wf() -> List[str]:
return array_node_map_task(say_hello)(name=["abc", "def"])
return map_task(say_hello)(name=["abc", "def"])

res = wf()
assert res is not None
Expand All @@ -49,8 +48,8 @@ def create_input_list() -> List[str]:

@workflow
def wf() -> List[str]:
xs = array_node_map_task(say_hello)(name=create_input_list())
return array_node_map_task(say_hello)(name=xs)
xs = map_task(say_hello)(name=create_input_list())
return map_task(say_hello)(name=xs)

assert wf() == ["hello hello earth!!", "hello hello mars!!"]

Expand All @@ -60,7 +59,7 @@ def test_serialization(serialization_settings):
def t1(a: int) -> int:
return a + 1

arraynode_maptask = array_node_map_task(t1, metadata=TaskMetadata(retries=2))
arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2))
task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask)

assert task_spec.template.metadata.retries.retries == 2
Expand All @@ -79,9 +78,8 @@ def t1(a: int) -> int:
"{{.checkpointOutputPrefix}}",
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--experimental",
"--resolver",
"ArrayNodeMapTaskResolver",
"flytekit.core.array_node_map_task.ArrayNodeMapTaskResolver",
"--",
"vars",
"",
Expand All @@ -101,7 +99,7 @@ def test_fast_serialization(serialization_settings):
def t1(a: int) -> int:
return a + 1

arraynode_maptask = array_node_map_task(t1, metadata=TaskMetadata(retries=2))
arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2))
task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask)

assert task_spec.template.container.args == [
Expand All @@ -122,9 +120,8 @@ def t1(a: int) -> int:
"{{.checkpointOutputPrefix}}",
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--experimental",
"--resolver",
"ArrayNodeMapTaskResolver",
"flytekit.core.array_node_map_task.ArrayNodeMapTaskResolver",
"--",
"vars",
"",
Expand Down Expand Up @@ -172,8 +169,8 @@ def test_metadata_in_task_name(kwargs1, kwargs2, same):
def say_hello(name: str) -> str:
return f"hello {name}!"

t1 = array_node_map_task(say_hello, **kwargs1)
t2 = array_node_map_task(say_hello, **kwargs2)
t1 = map_task(say_hello, **kwargs1)
t2 = map_task(say_hello, **kwargs2)

assert (t1.name == t2.name) is same

Expand All @@ -183,7 +180,7 @@ def test_inputs_outputs_length():
def many_inputs(a: int, b: str, c: float) -> str:
return f"{a} - {b} - {c}"

m = array_node_map_task(many_inputs)
m = map_task(many_inputs)
assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": List[float]}
assert (
m.name
Expand All @@ -193,7 +190,7 @@ def many_inputs(a: int, b: str, c: float) -> str:
assert str(r_m.python_interface) == str(m.python_interface)

p1 = functools.partial(many_inputs, c=1.0)
m = array_node_map_task(p1)
m = map_task(p1)
assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": float}
assert (
m.name
Expand All @@ -203,7 +200,7 @@ def many_inputs(a: int, b: str, c: float) -> str:
assert str(r_m.python_interface) == str(m.python_interface)

p2 = functools.partial(p1, b="hello")
m = array_node_map_task(p2)
m = map_task(p2)
assert m.python_interface.inputs == {"a": List[int], "b": str, "c": float}
assert (
m.name
Expand All @@ -213,7 +210,7 @@ def many_inputs(a: int, b: str, c: float) -> str:
assert str(r_m.python_interface) == str(m.python_interface)

p3 = functools.partial(p2, a=1)
m = array_node_map_task(p3)
m = map_task(p3)
assert m.python_interface.inputs == {"a": int, "b": str, "c": float}
assert (
m.name
Expand All @@ -230,7 +227,7 @@ def many_outputs(a: int) -> (int, str):
return a, f"{a}"

with pytest.raises(ValueError):
_ = array_node_map_task(many_outputs)
_ = map_task(many_outputs)


def test_parameter_order():
Expand All @@ -250,9 +247,9 @@ def task3(c: str, a: int, b: float) -> str:
param_b = [0.1, 0.2, 0.3]
param_c = "c"

m1 = array_node_map_task(functools.partial(task1, c=param_c))(a=param_a, b=param_b)
m2 = array_node_map_task(functools.partial(task2, c=param_c))(a=param_a, b=param_b)
m3 = array_node_map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b)
m1 = map_task(functools.partial(task1, c=param_c))(a=param_a, b=param_b)
m2 = map_task(functools.partial(task2, c=param_c))(a=param_a, b=param_b)
m3 = map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b)

assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"]

Expand All @@ -262,7 +259,7 @@ def test_bounded_inputs_vars_order(serialization_settings):
def task1(a: int, b: float, c: str) -> str:
return f"{a} - {b} - {c}"

mt = array_node_map_task(functools.partial(task1, c=1.0, b="hello", a=1))
mt = map_task(functools.partial(task1, c=1.0, b="hello", a=1))
mtr = ArrayNodeMapTaskResolver()
args = mtr.loader_args(serialization_settings, mt)

Expand All @@ -287,7 +284,7 @@ def some_task1(inputs: int) -> int:

@workflow
def my_wf1() -> typing.List[typing.Optional[int]]:
return array_node_map_task(some_task1, min_success_ratio=min_success_ratio)(inputs=[1, 2, 3, 4])
return map_task(some_task1, min_success_ratio=min_success_ratio)(inputs=[1, 2, 3, 4])

if should_raise_error:
with pytest.raises(ValueError):
Expand All @@ -303,6 +300,6 @@ def my_mappable_task(a: int) -> typing.Optional[str]:

@workflow
def wf(x: typing.List[int]):
array_node_map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")
map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")

assert wf.nodes[0]._container_image == "random:image"
8 changes: 4 additions & 4 deletions tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pytest

import flytekit.configuration
from flytekit import LaunchPlan, Resources, map_task
from flytekit import LaunchPlan, Resources
from flytekit.configuration import Image, ImageConfig
from flytekit.core.map_task import MapPythonTask, MapTaskResolver
from flytekit.core.legacy_map_task import MapPythonTask, MapTaskResolver, map_task
from flytekit.core.task import TaskMetadata, task
from flytekit.core.workflow import workflow
from flytekit.tools.translator import get_serializable
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_serialization(serialization_settings):
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--resolver",
"MapTaskResolver",
"flytekit.core.legacy_map_task.MapTaskResolver",
"--",
"vars",
"",
Expand Down Expand Up @@ -247,7 +247,7 @@ def test_map_task_resolver(serialization_settings):
assert mt.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": typing.List[float]}
assert mt.python_interface.outputs == list_outputs
mtr = MapTaskResolver()
assert mtr.name() == "MapTaskResolver"
assert mtr.name() == "flytekit.core.legacy_map_task.MapTaskResolver"
args = mtr.loader_args(serialization_settings, mt)
t = mtr.load_task(loader_args=args)
assert t.python_interface.inputs == mt.python_interface.inputs
Expand Down
Loading

0 comments on commit 4767fd8

Please sign in to comment.