Skip to content

Commit

Permalink
Set sane defaults in map task templates
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario committed Feb 1, 2022
1 parent 11c59a0 commit 9684689
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 22 deletions.
2 changes: 1 addition & 1 deletion flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _raw_execute(self, **kwargs) -> Any:
return outputs


def map_task(task_function: PythonFunctionTask, concurrency: int = None, min_success_ratio: float = None, **kwargs):
def map_task(task_function: PythonFunctionTask, concurrency: int = 1, min_success_ratio: float = 1.0, **kwargs):
"""
Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of
any individual :py:class:`flytekit.PythonFunctionTask`.
Expand Down
15 changes: 12 additions & 3 deletions flytekit/models/array_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,21 @@ def to_dict(self):
"""
:rtype: dict[T, Text]
"""
return _json_format.MessageToDict(
_array_job.ArrayJob(
array_job = None
if self.min_successes is not None:
array_job = _array_job.ArrayJob(
parallelism=self.parallelism,
size=self.size,
min_successes=self.min_successes,
)
)
elif self.min_success_ratio is not None:
array_job = _array_job.ArrayJob(
parallelism=self.parallelism,
size=self.size,
min_success_ratio=self.min_success_ratio,
)

return _json_format.MessageToDict(array_job)

@classmethod
def from_dict(cls, idl_dict):
Expand All @@ -90,4 +98,5 @@ def from_dict(cls, idl_dict):
parallelism=pb2_object.parallelism,
size=pb2_object.size,
min_successes=pb2_object.min_successes,
min_success_ratio=pb2_object.min_success_ratio,
)
51 changes: 33 additions & 18 deletions tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
from flytekit.tools.translator import get_serializable


@pytest.fixture
def serialization_settings():
    """Build the SerializationSettings object shared by the serialization tests.

    A single image (name "default", fqn "test", tag "tag") is used both as the
    default image and as the only entry in the image list.
    """
    image = Image(name="default", fqn="test", tag="tag")
    image_config = ImageConfig(default_image=image, images=[image])
    return context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=image_config,
    )


@task
def t1(a: int) -> str:
b = a + 2
Expand Down Expand Up @@ -54,18 +66,13 @@ def test_map_task_types():
_ = map_task(t1, metadata=TaskMetadata(retries=1))(a=["invalid", "args"])


def test_serialization():
def test_serialization(serialization_settings):
maptask = map_task(t1, metadata=TaskMetadata(retries=1))
default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
project="project",
domain="domain",
version="version",
env=None,
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)
task_spec = get_serializable(OrderedDict(), serialization_settings, maptask)

# By default all map_task tasks will have their custom fields set.
assert task_spec.template.custom["parallelism"] == "1"
assert task_spec.template.custom["minSuccessRatio"] == 1.0
assert task_spec.template.type == "container_array"
assert task_spec.template.task_type_version == 1
assert task_spec.template.container.args == [
Expand All @@ -90,7 +97,23 @@ def test_serialization():
]


def test_serialization_workflow_def():
@pytest.mark.parametrize(
    "custom_fields_dict, expected_custom_fields",
    [
        ({}, {"parallelism": "1", "minSuccessRatio": 1.0}),
        ({"concurrency": 99}, {"parallelism": "99", "minSuccessRatio": 1.0}),
        ({"min_success_ratio": 0.271828}, {"parallelism": "1", "minSuccessRatio": 0.271828}),
        ({"concurrency": 42, "min_success_ratio": 0.31415}, {"parallelism": "42", "minSuccessRatio": 0.31415}),
    ],
)
def test_serialization_of_custom_fields(custom_fields_dict, expected_custom_fields, serialization_settings):
    """Check the serialized template's custom fields for each map_task keyword combination.

    Each case passes ``custom_fields_dict`` as keyword arguments to ``map_task``
    and compares the whole ``template.custom`` mapping against the expected one,
    including the cases where one or both keywords are omitted.
    """
    mapped = map_task(t1, **custom_fields_dict)
    spec = get_serializable(OrderedDict(), serialization_settings, mapped)

    assert spec.template.custom == expected_custom_fields


def test_serialization_workflow_def(serialization_settings):
@task
def complex_task(a: int) -> str:
b = a + 2
Expand All @@ -106,14 +129,6 @@ def w1(a: typing.List[int]) -> typing.List[str]:
def w2(a: typing.List[int]) -> typing.List[str]:
return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
project="project",
domain="domain",
version="version",
env=None,
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)
serialized_control_plane_entities = OrderedDict()
wf1_spec = get_serializable(serialized_control_plane_entities, serialization_settings, w1)
assert wf1_spec.template is not None
Expand Down

0 comments on commit 9684689

Please sign in to comment.