From 216f4eed7f2221b25441d705493508fb74da0431 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 1 Feb 2022 11:57:46 -0800 Subject: [PATCH 1/5] Set sane defaults in map task templates Signed-off-by: Eduardo Apolinario --- flytekit/core/map_task.py | 2 +- flytekit/models/array_job.py | 15 +++++-- tests/flytekit/unit/core/test_map_task.py | 51 +++++++++++++++-------- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index f760be5d3c..a93d1eef89 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -209,7 +209,7 @@ def _raw_execute(self, **kwargs) -> Any: return outputs -def map_task(task_function: PythonFunctionTask, concurrency: int = None, min_success_ratio: float = None, **kwargs): +def map_task(task_function: PythonFunctionTask, concurrency: int = 1, min_success_ratio: float = 1.0, **kwargs): """ Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of any individual :py:class:`flytekit.PythonFunctionTask`. diff --git a/flytekit/models/array_job.py b/flytekit/models/array_job.py index 4e4bf99cc7..03718c8768 100644 --- a/flytekit/models/array_job.py +++ b/flytekit/models/array_job.py @@ -70,13 +70,21 @@ def to_dict(self): """ :rtype: dict[T, Text] """ - return _json_format.MessageToDict( - _array_job.ArrayJob( + array_job = None + if self.min_successes is not None: + array_job = _array_job.ArrayJob( parallelism=self.parallelism, size=self.size, min_successes=self.min_successes, ) - ) + elif self.min_success_ratio is not None: + array_job = _array_job.ArrayJob( + parallelism=self.parallelism, + size=self.size, + min_success_ratio=self.min_success_ratio, + ) + + return _json_format.MessageToDict(array_job) @classmethod def from_dict(cls, idl_dict): @@ -90,4 +98,5 @@ def from_dict(cls, idl_dict): parallelism=pb2_object.parallelism, size=pb2_object.size, min_successes=pb2_object.min_successes, + min_success_ratio=pb2_object.min_success_ratio, ) diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index d1f95852c1..670088463f 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -12,6 +12,18 @@ from flytekit.tools.translator import get_serializable +@pytest.fixture +def serialization_settings(): + default_img = Image(name="default", fqn="test", tag="tag") + return context_manager.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + @task def t1(a: int) -> str: b = a + 2 @@ -54,18 +66,13 @@ def test_map_task_types(): _ = map_task(t1, metadata=TaskMetadata(retries=1))(a=["invalid", "args"]) -def test_serialization(): +def test_serialization(serialization_settings): maptask = map_task(t1, metadata=TaskMetadata(retries=1)) - default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = context_manager.SerializationSettings( - project="project", - domain="domain", - version="version", - env=None, - image_config=ImageConfig(default_image=default_img, images=[default_img]), - ) task_spec = get_serializable(OrderedDict(), serialization_settings, maptask) + # By default all map_task tasks will have their custom fields set. + assert task_spec.template.custom["parallelism"] == "1" + assert task_spec.template.custom["minSuccessRatio"] == 1.0 assert task_spec.template.type == "container_array" assert task_spec.template.task_type_version == 1 assert task_spec.template.container.args == [ @@ -90,7 +97,23 @@ def test_serialization(): ] -def test_serialization_workflow_def(): +@pytest.mark.parametrize( + "custom_fields_dict, expected_custom_fields", + [ + ({}, {"parallelism": "1", "minSuccessRatio": 1.0}), + ({"concurrency": 99}, {"parallelism": "99", "minSuccessRatio": 1.0}), + ({"min_success_ratio": 0.271828}, {"parallelism": "1", "minSuccessRatio": 0.271828}), + ({"concurrency": 42, "min_success_ratio": 0.31415}, {"parallelism": "42", "minSuccessRatio": 0.31415}), + ], +) +def test_serialization_of_custom_fields(custom_fields_dict, expected_custom_fields, serialization_settings): + maptask = map_task(t1, **custom_fields_dict) + task_spec = get_serializable(OrderedDict(), serialization_settings, maptask) + + assert task_spec.template.custom == expected_custom_fields + + +def test_serialization_workflow_def(serialization_settings): @task def complex_task(a: int) -> str: b = a + 2 @@ -106,14 +129,6 @@ def w1(a: typing.List[int]) -> typing.List[str]: def w2(a: typing.List[int]) -> typing.List[str]: return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a) - default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = context_manager.SerializationSettings( - project="project", - domain="domain", - version="version", - env=None, - image_config=ImageConfig(default_image=default_img, images=[default_img]), - ) serialized_control_plane_entities = OrderedDict() wf1_spec = get_serializable(serialized_control_plane_entities, serialization_settings, w1) assert wf1_spec.template is not None From 2e1636dabcef561432bb504b8480765a673cc380 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 1 Feb 2022 12:05:32 -0800 Subject: [PATCH 2/5] Remove unused method Signed-off-by: Eduardo Apolinario --- flytekit/models/array_job.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/flytekit/models/array_job.py b/flytekit/models/array_job.py index 03718c8768..f2112a4163 100644 --- a/flytekit/models/array_job.py +++ b/flytekit/models/array_job.py @@ -85,18 +85,3 @@ def to_dict(self): ) return _json_format.MessageToDict(array_job) - - @classmethod - def from_dict(cls, idl_dict): - """ - :param dict[T, Text] idl_dict: - :rtype: ArrayJob - """ - pb2_object = _json_format.Parse(_json.dumps(idl_dict), _array_job.ArrayJob()) - - return cls( - parallelism=pb2_object.parallelism, - size=pb2_object.size, - min_successes=pb2_object.min_successes, - min_success_ratio=pb2_object.min_success_ratio, - ) From 9b043ce34460e455c2fde92939a7ca0019c8e707 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 1 Feb 2022 12:34:27 -0800 Subject: [PATCH 3/5] Put ArrayJob.from_dict back Signed-off-by: Eduardo Apolinario --- flytekit/models/array_job.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/flytekit/models/array_job.py b/flytekit/models/array_job.py index f2112a4163..2c86acdd7e 100644 --- a/flytekit/models/array_job.py +++ b/flytekit/models/array_job.py @@ -85,3 +85,24 @@ def to_dict(self): ) return _json_format.MessageToDict(array_job) + + @classmethod + def from_dict(cls, idl_dict): + """ + :param dict[T, Text] idl_dict: + :rtype: ArrayJob + """ + pb2_object = _json_format.Parse(_json.dumps(idl_dict), _array_job.ArrayJob()) + + if pb2_object.HasField("min_successes"): + return cls( + parallelism=pb2_object.parallelism, + size=pb2_object.size, + min_successes=pb2_object.min_successes, + ) + else: + return cls( + parallelism=pb2_object.parallelism, + size=pb2_object.size, + min_success_ratio=pb2_object.min_success_ratio, + ) From 629511833411fca68e6d102b8e32d61a3e031024 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Tue, 1 Feb 2022 17:49:09 -0800 Subject: [PATCH 4/5] Define parallelism=0 as unbounded Signed-off-by: Eduardo Apolinario --- flytekit/core/map_task.py | 4 ++-- flytekit/models/array_job.py | 5 ++++- tests/flytekit/unit/core/test_map_task.py | 6 +++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index a93d1eef89..4ec9f64a2d 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -209,7 +209,7 @@ def _raw_execute(self, **kwargs) -> Any: return outputs -def map_task(task_function: PythonFunctionTask, concurrency: int = 1, min_success_ratio: float = 1.0, **kwargs): +def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs): """ Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of any individual :py:class:`flytekit.PythonFunctionTask`. @@ -231,7 +231,7 @@ def map_task(task_function: PythonFunctionTask, concurrency: int = 1, min_succes :param task_function: This argument is implicitly passed and represents the repeatable function :param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given batch size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until - all inputs are processed. + all inputs are processed. If left unspecified, this means unbounded concurrency. :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete successfully before terminating this task and marking it successful. diff --git a/flytekit/models/array_job.py b/flytekit/models/array_job.py index 2c86acdd7e..b2e53e3884 100644 --- a/flytekit/models/array_job.py +++ b/flytekit/models/array_job.py @@ -84,7 +84,10 @@ def to_dict(self): min_success_ratio=self.min_success_ratio, ) - return _json_format.MessageToDict(array_job) + array_job_dict = _json_format.MessageToDict(array_job) + if self.parallelism is not None and self.parallelism == 0: + array_job_dict['parallelism'] = '0' + return array_job_dict @classmethod def from_dict(cls, idl_dict): diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 670088463f..952f2ffbfd 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -71,7 +71,7 @@ def test_serialization(serialization_settings): task_spec = get_serializable(OrderedDict(), serialization_settings, maptask) # By default all map_task tasks will have their custom fields set. - assert task_spec.template.custom["parallelism"] == "1" + assert task_spec.template.custom["parallelism"] == "0" assert task_spec.template.custom["minSuccessRatio"] == 1.0 assert task_spec.template.type == "container_array" assert task_spec.template.task_type_version == 1 @@ -100,9 +100,9 @@ def test_serialization(serialization_settings): @pytest.mark.parametrize( "custom_fields_dict, expected_custom_fields", [ - ({}, {"parallelism": "1", "minSuccessRatio": 1.0}), + ({}, {"parallelism": "0", "minSuccessRatio": 1.0}), ({"concurrency": 99}, {"parallelism": "99", "minSuccessRatio": 1.0}), - ({"min_success_ratio": 0.271828}, {"parallelism": "1", "minSuccessRatio": 0.271828}), + ({"min_success_ratio": 0.271828}, {"parallelism": "0", "minSuccessRatio": 0.271828}), ({"concurrency": 42, "min_success_ratio": 0.31415}, {"parallelism": "42", "minSuccessRatio": 0.31415}), ], ) From bb98e76d1af255927d8b6dd8e9394f0cd286b66f Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario Date: Thu, 3 Feb 2022 19:57:44 -0800 Subject: [PATCH 5/5] Remove special case to handle 0 Signed-off-by: Eduardo Apolinario --- flytekit/models/array_job.py | 5 +---- tests/flytekit/unit/core/test_map_task.py | 5 ++--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/flytekit/models/array_job.py b/flytekit/models/array_job.py index b2e53e3884..2c86acdd7e 100644 --- a/flytekit/models/array_job.py +++ b/flytekit/models/array_job.py @@ -84,10 +84,7 @@ def to_dict(self): min_success_ratio=self.min_success_ratio, ) - array_job_dict = _json_format.MessageToDict(array_job) - if self.parallelism is not None and self.parallelism == 0: - array_job_dict['parallelism'] = '0' - return array_job_dict + return _json_format.MessageToDict(array_job) @classmethod def from_dict(cls, idl_dict): diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 952f2ffbfd..4eb44d6e76 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -71,7 +71,6 @@ def test_serialization(serialization_settings): task_spec = get_serializable(OrderedDict(), serialization_settings, maptask) # By default all map_task tasks will have their custom fields set. - assert task_spec.template.custom["parallelism"] == "0" assert task_spec.template.custom["minSuccessRatio"] == 1.0 assert task_spec.template.type == "container_array" assert task_spec.template.task_type_version == 1 @@ -100,9 +99,9 @@ def test_serialization(serialization_settings): @pytest.mark.parametrize( "custom_fields_dict, expected_custom_fields", [ - ({}, {"parallelism": "0", "minSuccessRatio": 1.0}), + ({}, {"minSuccessRatio": 1.0}), ({"concurrency": 99}, {"parallelism": "99", "minSuccessRatio": 1.0}), - ({"min_success_ratio": 0.271828}, {"parallelism": "0", "minSuccessRatio": 0.271828}), + ({"min_success_ratio": 0.271828}, {"minSuccessRatio": 0.271828}), ({"concurrency": 42, "min_success_ratio": 0.31415}, {"parallelism": "42", "minSuccessRatio": 0.31415}), ], )