Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make array_node_map_task the default map_task #2242

Merged
merged 12 commits into from
Mar 14, 2024
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ google-cloud-bigquery-storage
IPython
keyrings.alt
setuptools_scm
pytest-icdiff
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gives better diffs in case of failures: https://github.com/hjwp/pytest-icdiff


# Tensorflow is not available for python 3.12 yet: https://github.com/tensorflow/tensorflow/issues/62003
tensorflow; python_version<'3.12'
Expand Down
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@
from importlib.metadata import entry_points

from flytekit._version import __version__
from flytekit.core.array_node_map_task import map_task
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.base_task import SecurityContext, TaskMetadata, kwtypes
from flytekit.core.checkpointer import Checkpoint
Expand All @@ -216,7 +217,6 @@
from flytekit.core.gate import approve, sleep, wait_for_input
from flytekit.core.hash import HashMethod
from flytekit.core.launch_plan import LaunchPlan, reference_launch_plan
from flytekit.core.map_task import map_task
from flytekit.core.notification import Email, PagerDuty, Slack
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
Expand Down
16 changes: 8 additions & 8 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from flytekit.core.checkpointer import SyncCheckpoint
from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.map_task import MapTaskResolver
from flytekit.core.legacy_map_task import MapTaskResolver
from flytekit.core.promise import VoidPromise
from flytekit.deck.deck import _output_deck
from flytekit.exceptions import scopes as _scoped_exceptions
Expand Down Expand Up @@ -383,7 +383,7 @@ def _execute_map_task(
prev_checkpoint: Optional[str] = None,
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
experimental: Optional[bool] = False,
legacy: Optional[bool] = False,
):
"""
This function should be called by map task and aws-batch task
Expand All @@ -409,9 +409,9 @@ def _execute_map_task(
raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir
) as ctx:
task_index = _compute_array_job_index()
if experimental:
mtr = ArrayNodeMapTaskResolver()
else:
mtr = ArrayNodeMapTaskResolver()
# TODO: (https://github.com/flyteorg/flyte/issues/5011) Remove legacy map task
if legacy:
mtr = MapTaskResolver()
output_prefix = os.path.join(output_prefix, str(task_index))

Expand Down Expand Up @@ -548,7 +548,7 @@ def handle_sigterm(signum, frame):
@_click.option("--resolver", required=True)
@_click.option("--checkpoint-path", required=False)
@_click.option("--prev-checkpoint", required=False)
@_click.option("--experimental", is_flag=True, default=False, required=False)
@_click.option("--legacy", is_flag=True, default=False, required=False)
@_click.argument(
"resolver-args",
type=_click.UNPROCESSED,
Expand All @@ -565,7 +565,7 @@ def map_execute_task_cmd(
resolver,
resolver_args,
prev_checkpoint,
experimental,
legacy,
checkpoint_path,
):
logger.info(get_version_message())
Expand All @@ -586,7 +586,7 @@ def map_execute_task_cmd(
resolver_args=resolver_args,
checkpoint_path=checkpoint_path,
prev_checkpoint=prev_checkpoint,
experimental=experimental,
legacy=legacy,
)


Expand Down
3 changes: 1 addition & 2 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
"{{.checkpointOutputPrefix}}",
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--experimental",
"--resolver",
mt.name(),
"--",
Expand Down Expand Up @@ -248,7 +247,7 @@ def _outputs_interface(self) -> Dict[Any, Variable]:
"""

ctx = FlyteContextManager.current_context()
if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state and ctx.execution_state.is_local_execution():
# In workflow execution mode we actually need to use the parent (mapper) task output interface.
return self.interface.outputs
return self.python_function_task.interface.outputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
"{{.checkpointOutputPrefix}}",
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--legacy",
"--resolver",
mt.name(),
"--",
Expand All @@ -150,7 +151,6 @@ def set_command_prefix(self, cmd: typing.Optional[typing.List[str]]):
@contextmanager
def prepare_target(self):
"""
TODO: why do we do this?
Alters the underlying run_task command to modify it for map task execution and then resets it after.
"""
self._run_task.set_command_fn(self.get_command)
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def flyte_entity(self) -> Any:
@property
def run_entity(self) -> Any:
from flytekit.core.array_node_map_task import ArrayNodeMapTask
from flytekit.core.map_task import MapPythonTask
from flytekit.core.legacy_map_task import MapPythonTask

if isinstance(self.flyte_entity, MapPythonTask):
return self.flyte_entity.run_task
Expand Down
2 changes: 2 additions & 0 deletions flytekit/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Experimental features of flytekit."""

# TODO(eapolinario): Remove this once a new flytekit release is out and
# references are updated in the monodocs build.
from flytekit.core.array_node_map_task import map_task # noqa: F401
from flytekit.experimental.eager_function import EagerException, eager
2 changes: 1 addition & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from flytekit.core.container_task import ContainerTask
from flytekit.core.gate import Gate
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.core.map_task import MapPythonTask
from flytekit.core.legacy_map_task import MapPythonTask
from flytekit.core.node import Node
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-k8s-pod/tests/test_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def simple_pod_task(i: int):
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--resolver",
"MapTaskResolver",
"ArrayNodeMapTaskResolver",
"--",
"vars",
"",
Expand Down
41 changes: 19 additions & 22 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

import pytest

from flytekit import task, workflow
from flytekit import map_task, task, workflow
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
from flytekit.core.task import TaskMetadata
from flytekit.experimental import map_task as array_node_map_task
from flytekit.tools.translator import get_serializable


Expand All @@ -32,7 +31,7 @@ def say_hello(name: str) -> str:

@workflow
def wf() -> List[str]:
return array_node_map_task(say_hello)(name=["abc", "def"])
return map_task(say_hello)(name=["abc", "def"])

res = wf()
assert res is not None
Expand All @@ -49,8 +48,8 @@ def create_input_list() -> List[str]:

@workflow
def wf() -> List[str]:
xs = array_node_map_task(say_hello)(name=create_input_list())
return array_node_map_task(say_hello)(name=xs)
xs = map_task(say_hello)(name=create_input_list())
return map_task(say_hello)(name=xs)

assert wf() == ["hello hello earth!!", "hello hello mars!!"]

Expand All @@ -60,7 +59,7 @@ def test_serialization(serialization_settings):
def t1(a: int) -> int:
return a + 1

arraynode_maptask = array_node_map_task(t1, metadata=TaskMetadata(retries=2))
arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2))
task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask)

assert task_spec.template.metadata.retries.retries == 2
Expand All @@ -79,7 +78,6 @@ def t1(a: int) -> int:
"{{.checkpointOutputPrefix}}",
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--experimental",
"--resolver",
"ArrayNodeMapTaskResolver",
"--",
Expand All @@ -101,7 +99,7 @@ def test_fast_serialization(serialization_settings):
def t1(a: int) -> int:
return a + 1

arraynode_maptask = array_node_map_task(t1, metadata=TaskMetadata(retries=2))
arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2))
task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask)

assert task_spec.template.container.args == [
Expand All @@ -122,7 +120,6 @@ def t1(a: int) -> int:
"{{.checkpointOutputPrefix}}",
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--experimental",
"--resolver",
"ArrayNodeMapTaskResolver",
"--",
Expand Down Expand Up @@ -172,8 +169,8 @@ def test_metadata_in_task_name(kwargs1, kwargs2, same):
def say_hello(name: str) -> str:
return f"hello {name}!"

t1 = array_node_map_task(say_hello, **kwargs1)
t2 = array_node_map_task(say_hello, **kwargs2)
t1 = map_task(say_hello, **kwargs1)
t2 = map_task(say_hello, **kwargs2)

assert (t1.name == t2.name) is same

Expand All @@ -183,7 +180,7 @@ def test_inputs_outputs_length():
def many_inputs(a: int, b: str, c: float) -> str:
return f"{a} - {b} - {c}"

m = array_node_map_task(many_inputs)
m = map_task(many_inputs)
assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": List[float]}
assert (
m.name
Expand All @@ -193,7 +190,7 @@ def many_inputs(a: int, b: str, c: float) -> str:
assert str(r_m.python_interface) == str(m.python_interface)

p1 = functools.partial(many_inputs, c=1.0)
m = array_node_map_task(p1)
m = map_task(p1)
assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": float}
assert (
m.name
Expand All @@ -203,7 +200,7 @@ def many_inputs(a: int, b: str, c: float) -> str:
assert str(r_m.python_interface) == str(m.python_interface)

p2 = functools.partial(p1, b="hello")
m = array_node_map_task(p2)
m = map_task(p2)
assert m.python_interface.inputs == {"a": List[int], "b": str, "c": float}
assert (
m.name
Expand All @@ -213,7 +210,7 @@ def many_inputs(a: int, b: str, c: float) -> str:
assert str(r_m.python_interface) == str(m.python_interface)

p3 = functools.partial(p2, a=1)
m = array_node_map_task(p3)
m = map_task(p3)
assert m.python_interface.inputs == {"a": int, "b": str, "c": float}
assert (
m.name
Expand All @@ -230,7 +227,7 @@ def many_outputs(a: int) -> (int, str):
return a, f"{a}"

with pytest.raises(ValueError):
_ = array_node_map_task(many_outputs)
_ = map_task(many_outputs)


def test_parameter_order():
Expand All @@ -250,9 +247,9 @@ def task3(c: str, a: int, b: float) -> str:
param_b = [0.1, 0.2, 0.3]
param_c = "c"

m1 = array_node_map_task(functools.partial(task1, c=param_c))(a=param_a, b=param_b)
m2 = array_node_map_task(functools.partial(task2, c=param_c))(a=param_a, b=param_b)
m3 = array_node_map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b)
m1 = map_task(functools.partial(task1, c=param_c))(a=param_a, b=param_b)
m2 = map_task(functools.partial(task2, c=param_c))(a=param_a, b=param_b)
m3 = map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b)

assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"]

Expand All @@ -262,7 +259,7 @@ def test_bounded_inputs_vars_order(serialization_settings):
def task1(a: int, b: float, c: str) -> str:
return f"{a} - {b} - {c}"

mt = array_node_map_task(functools.partial(task1, c=1.0, b="hello", a=1))
mt = map_task(functools.partial(task1, c=1.0, b="hello", a=1))
mtr = ArrayNodeMapTaskResolver()
args = mtr.loader_args(serialization_settings, mt)

Expand All @@ -287,7 +284,7 @@ def some_task1(inputs: int) -> int:

@workflow
def my_wf1() -> typing.List[typing.Optional[int]]:
return array_node_map_task(some_task1, min_success_ratio=min_success_ratio)(inputs=[1, 2, 3, 4])
return map_task(some_task1, min_success_ratio=min_success_ratio)(inputs=[1, 2, 3, 4])

if should_raise_error:
with pytest.raises(ValueError):
Expand All @@ -303,6 +300,6 @@ def my_mappable_task(a: int) -> typing.Optional[str]:

@workflow
def wf(x: typing.List[int]):
array_node_map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")
map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")

assert wf.nodes[0]._container_image == "random:image"
5 changes: 3 additions & 2 deletions tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pytest

import flytekit.configuration
from flytekit import LaunchPlan, Resources, map_task
from flytekit import LaunchPlan, Resources
from flytekit.configuration import Image, ImageConfig
from flytekit.core.map_task import MapPythonTask, MapTaskResolver
from flytekit.core.legacy_map_task import MapPythonTask, MapTaskResolver, map_task
from flytekit.core.task import TaskMetadata, task
from flytekit.core.workflow import workflow
from flytekit.tools.translator import get_serializable
Expand Down Expand Up @@ -95,6 +95,7 @@ def test_serialization(serialization_settings):
"{{.checkpointOutputPrefix}}",
"--prev-checkpoint",
"{{.prevCheckpointPrefix}}",
"--legacy",
"--resolver",
"MapTaskResolver",
"--",
Expand Down
16 changes: 8 additions & 8 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,13 @@ def my_wf(a: typing.List[str]) -> typing.List[str]:
)
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].task_node.overrides is not None
assert wf_spec.template.nodes[0].task_node.overrides.resources.requests == [
assert wf_spec.template.nodes[0].array_node.node.task_node.overrides is not None
assert wf_spec.template.nodes[0].array_node.node.task_node.overrides.resources.requests == [
_resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "1"),
_resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "100"),
_resources_models.ResourceEntry(_resources_models.ResourceName.EPHEMERAL_STORAGE, "500Mi"),
]
assert wf_spec.template.nodes[0].task_node.overrides.resources.limits == []
assert wf_spec.template.nodes[0].array_node.node.task_node.overrides.resources.limits == []


def test_resource_limits_override():
Expand All @@ -254,8 +254,8 @@ def my_wf(a: typing.List[str]) -> typing.List[str]:
)
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].task_node.overrides.resources.requests == []
assert wf_spec.template.nodes[0].task_node.overrides.resources.limits == [
assert wf_spec.template.nodes[0].array_node.node.task_node.overrides.resources.requests == []
assert wf_spec.template.nodes[0].array_node.node.task_node.overrides.resources.limits == [
_resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "2"),
_resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "200"),
_resources_models.ResourceEntry(_resources_models.ResourceName.EPHEMERAL_STORAGE, "1Gi"),
Expand Down Expand Up @@ -285,14 +285,14 @@ def my_wf(a: typing.List[str]) -> typing.List[str]:
)
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].task_node.overrides is not None
assert wf_spec.template.nodes[0].task_node.overrides.resources.requests == [
assert wf_spec.template.nodes[0].array_node.node.task_node.overrides is not None
assert wf_spec.template.nodes[0].array_node.node.task_node.overrides.resources.requests == [
_resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "1"),
_resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "100"),
_resources_models.ResourceEntry(_resources_models.ResourceName.EPHEMERAL_STORAGE, "500Mi"),
]

assert wf_spec.template.nodes[0].task_node.overrides.resources.limits == [
assert wf_spec.template.nodes[0].array_node.node.task_node.overrides.resources.limits == [
_resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "2"),
_resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "200"),
_resources_models.ResourceEntry(_resources_models.ResourceName.EPHEMERAL_STORAGE, "1Gi"),
Expand Down
Loading
Loading