Set default values to map task template #841

Merged · 5 commits · Feb 8, 2022
4 changes: 2 additions & 2 deletions flytekit/core/map_task.py
@@ -209,7 +209,7 @@ def _raw_execute(self, **kwargs) -> Any:
         return outputs
 
 
-def map_task(task_function: PythonFunctionTask, concurrency: int = None, min_success_ratio: float = None, **kwargs):
+def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs):
     """
     Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of
     any individual :py:class:`flytekit.PythonFunctionTask`.
@@ -231,7 +231,7 @@ def map_task(task_function: PythonFunctionTask, concurrency: int = None, min_success_ratio: float = None, **kwargs):
     :param task_function: This argument is implicitly passed and represents the repeatable function
     :param concurrency: If specified, this limits the number of mapped tasks that can run in parallel to the given batch
         size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until
-        all inputs are processed.
+        all inputs are processed. If left unspecified, this means unbounded concurrency.
     :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete
         successfully before terminating this task and marking it successful.

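The two new defaults encode sentinel semantics: a concurrency of 0 means unbounded parallelism, and a min_success_ratio of 1.0 means every subtask must succeed. A minimal stdlib sketch of those semantics (names here are illustrative, not the flytekit implementation):

```python
from concurrent.futures import ThreadPoolExecutor


def run_mapped(fn, inputs, concurrency=0):
    # concurrency == 0 is the sentinel for "unbounded": spawn one worker
    # per input element instead of capping parallelism at a fixed size.
    max_workers = concurrency if concurrency > 0 else max(len(inputs), 1)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(fn, inputs))


def meets_min_success_ratio(succeeded, total, min_success_ratio=1.0):
    # The default of 1.0 requires every subtask to succeed; a ratio of
    # 0.0 would accept a run in which all subtasks fail.
    return succeeded >= min_success_ratio * total
```

With 0 treated as a legal sentinel, a programmatically generated configuration can always emit a concurrency value, using 0 when no cap is wanted.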
31 changes: 23 additions & 8 deletions flytekit/models/array_job.py
@@ -70,13 +70,21 @@ def to_dict(self):
         """
         :rtype: dict[T, Text]
         """
-        return _json_format.MessageToDict(
-            _array_job.ArrayJob(
+        array_job = None
+        if self.min_successes is not None:
+            array_job = _array_job.ArrayJob(
                 parallelism=self.parallelism,
                 size=self.size,
                 min_successes=self.min_successes,
             )
-        )
+        elif self.min_success_ratio is not None:
Contributor:

quick thought, should this be changed to check for 1.0 instead of None? If we default to 1.0 in both flytekit and flyteplugins then there is no reason to write the custom arrayjob if it is 1.0.

Collaborator (Author) @eapolinario — Feb 2, 2022:

That is a great alternative IMO.

edit: if we go that route we are essentially saying that min_success_ratio set to 0 does not make sense, right? In protobuf we won't be able to differentiate between that case and the special case where we assume that 0.0 means 1.0 when we deserialize this in the plugin.

Contributor:

I suspect I don't have enough context for where this function is used. It seems to check if the provided values are different from the default and to populate the array_job only if necessary. If they are the default values, then we don't write the custom array job parameters and the backend uses the default.

I think the "0 actually doesn't mean 0" behavior only applies to parallelism. Though probably unlikely, having a 0.0 min_success_ratio means the map task should succeed even if all subtasks fail, which could be valid. However, a parallelism of 0, meaning no tasks can execute, doesn't make any sense. Do we need to add another check here if parallelism != 0?

Collaborator (Author) @eapolinario — Feb 7, 2022:

> I suspect I don't have enough context for where this function is used. It seems to check if the provided values are different then the default to populate the array_job only if necessary. If they are the default values, then we don't write the custom array job parameters and the backend used the default.

This function translates the python values to protobuf, which ends up hitting the issue of default values, specifically how the default value for numeric fields is 0. There's no way to set the default value of a numeric field to anything else.

> I think the 0 actually doesn't mean 0 only applies for parallelism. Though probably unlikely, having a 0.0 min_success_ratio means the map task should succeed even if all subtasks fail, which could be valid. However, a parallelism of 0, meaning no tasks can execute, doesn't make any sense.

The key here is exactly what you said. In other words, a parallelism of 0 makes sense as a magic value (and not as a user-input). However a min_success_ratio of 0 might make sense as a user-input and since we use a numeric field to represent this in protobuf we cannot distinguish between the two cases in the case of min_success_ratio.

> Do we need to add another check here if parallelism != 0?

IMO, to simplify usage we should allow users to set parallelism to 0 with the caveat that this means unbounded concurrency. The scenario I'm thinking of is someone generating these configurations programmatically: if we don't allow parallelism to be set to 0, they will have to special-case that in their configuration (which is very annoying).

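The ambiguity discussed in this thread stems from proto3 scalar fields carrying no presence information: `MessageToDict` drops zero-valued fields, so "explicitly set to 0" and "never set" serialize identically. A dependency-free emulation of that behavior (`message_to_dict` is a hypothetical stand-in, not the protobuf API):

```python
def message_to_dict(fields: dict) -> dict:
    # Emulate proto3 JSON serialization of scalar fields: zero values
    # (the proto3 defaults) are omitted from the output, so a consumer
    # cannot tell "explicitly set to 0" apart from "never set".
    return {key: value for key, value in fields.items() if value}


# parallelism=0 ("unbounded") and min_success_ratio=0.0 ("tolerate all
# failures") both vanish on the wire, which is why a 0.0 ratio cannot be
# round-tripped through a plain numeric field.
wire = message_to_dict({"parallelism": 0, "size": 5, "minSuccessRatio": 0.0})
```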
+            array_job = _array_job.ArrayJob(
+                parallelism=self.parallelism,
+                size=self.size,
+                min_success_ratio=self.min_success_ratio,
+            )
+
+        return _json_format.MessageToDict(array_job)
 
     @classmethod
     def from_dict(cls, idl_dict):
@@ -86,8 +94,15 @@ def from_dict(cls, idl_dict):
         """
         pb2_object = _json_format.Parse(_json.dumps(idl_dict), _array_job.ArrayJob())
 
-        return cls(
-            parallelism=pb2_object.parallelism,
-            size=pb2_object.size,
-            min_successes=pb2_object.min_successes,
-        )
+        if pb2_object.HasField("min_successes"):
+            return cls(
+                parallelism=pb2_object.parallelism,
+                size=pb2_object.size,
+                min_successes=pb2_object.min_successes,
+            )
+        else:
+            return cls(
+                parallelism=pb2_object.parallelism,
+                size=pb2_object.size,
+                min_success_ratio=pb2_object.min_success_ratio,
+            )
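The serialize/deserialize dispatch can be sketched with plain dicts (hypothetical helpers, not the flytekit classes): exactly one success-criteria field is written, and the reader branches on which one is present, just as the real code branches on `pb2_object.HasField("min_successes")`:

```python
from typing import Optional


def array_job_to_dict(parallelism: int = 0, size: int = 0,
                      min_successes: Optional[int] = None,
                      min_success_ratio: Optional[float] = None) -> dict:
    # Serialize exactly one member of the success-criteria "oneof".
    body = {"parallelism": parallelism, "size": size}
    if min_successes is not None:
        body["minSuccesses"] = min_successes
    elif min_success_ratio is not None:
        body["minSuccessRatio"] = min_success_ratio
    return body


def array_job_from_dict(d: dict) -> dict:
    # Branch on which member is present, mirroring HasField("min_successes");
    # an absent ratio falls back to the new default of 1.0.
    common = {"parallelism": d.get("parallelism", 0), "size": d.get("size", 0)}
    if "minSuccesses" in d:
        return {**common, "min_successes": d["minSuccesses"]}
    return {**common, "min_success_ratio": d.get("minSuccessRatio", 1.0)}
```

Because only one of the two fields ever appears in the serialized form, the reader never has to guess whether a numeric zero was set deliberately.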
50 changes: 32 additions & 18 deletions tests/flytekit/unit/core/test_map_task.py
@@ -12,6 +12,18 @@
 from flytekit.tools.translator import get_serializable
 
 
+@pytest.fixture
+def serialization_settings():
+    default_img = Image(name="default", fqn="test", tag="tag")
+    return context_manager.SerializationSettings(
+        project="project",
+        domain="domain",
+        version="version",
+        env=None,
+        image_config=ImageConfig(default_image=default_img, images=[default_img]),
+    )
+
+
 @task
 def t1(a: int) -> str:
     b = a + 2
@@ -54,18 +66,12 @@ def test_map_task_types():
         _ = map_task(t1, metadata=TaskMetadata(retries=1))(a=["invalid", "args"])
 
 
-def test_serialization():
+def test_serialization(serialization_settings):
     maptask = map_task(t1, metadata=TaskMetadata(retries=1))
-    default_img = Image(name="default", fqn="test", tag="tag")
-    serialization_settings = context_manager.SerializationSettings(
-        project="project",
-        domain="domain",
-        version="version",
-        env=None,
-        image_config=ImageConfig(default_image=default_img, images=[default_img]),
-    )
     task_spec = get_serializable(OrderedDict(), serialization_settings, maptask)
 
+    # By default all map_task tasks will have their custom fields set.
+    assert task_spec.template.custom["minSuccessRatio"] == 1.0
     assert task_spec.template.type == "container_array"
     assert task_spec.template.task_type_version == 1
     assert task_spec.template.container.args == [
@@ -90,7 +96,23 @@ def test_serialization():
     ]


-def test_serialization_workflow_def():
+@pytest.mark.parametrize(
+    "custom_fields_dict, expected_custom_fields",
+    [
+        ({}, {"minSuccessRatio": 1.0}),
+        ({"concurrency": 99}, {"parallelism": "99", "minSuccessRatio": 1.0}),
+        ({"min_success_ratio": 0.271828}, {"minSuccessRatio": 0.271828}),
+        ({"concurrency": 42, "min_success_ratio": 0.31415}, {"parallelism": "42", "minSuccessRatio": 0.31415}),
+    ],
+)
+def test_serialization_of_custom_fields(custom_fields_dict, expected_custom_fields, serialization_settings):
+    maptask = map_task(t1, **custom_fields_dict)
+    task_spec = get_serializable(OrderedDict(), serialization_settings, maptask)
+
+    assert task_spec.template.custom == expected_custom_fields
+
+
+def test_serialization_workflow_def(serialization_settings):
     @task
     def complex_task(a: int) -> str:
         b = a + 2
@@ -106,14 +128,6 @@ def w1(a: typing.List[int]) -> typing.List[str]:
     def w2(a: typing.List[int]) -> typing.List[str]:
         return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)
 
-    default_img = Image(name="default", fqn="test", tag="tag")
-    serialization_settings = context_manager.SerializationSettings(
-        project="project",
-        domain="domain",
-        version="version",
-        env=None,
-        image_config=ImageConfig(default_image=default_img, images=[default_img]),
-    )
     serialized_control_plane_entities = OrderedDict()
     wf1_spec = get_serializable(serialized_control_plane_entities, serialization_settings, w1)
     assert wf1_spec.template is not None