From 10bd38f69c78d778232f307360833924fec590e8 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Mon, 3 Apr 2023 16:40:23 -0700 Subject: [PATCH] Disallow partial lists in map tasks (#1577) * Disallow partial lists in map tasks Signed-off-by: eduardo apolinario * Lint Signed-off-by: eduardo apolinario --------- Signed-off-by: eduardo apolinario Co-authored-by: eduardo apolinario --- flytekit/core/map_task.py | 4 ++ tests/flytekit/unit/core/test_partials.py | 62 ++++++++++++++++++----- 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 83b2542fe3..6a11c9dc50 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -54,6 +54,10 @@ def __init__( """ self._partial = None if isinstance(python_function_task, functools.partial): + # TODO: We should be able to support partial tasks with lists as inputs + for arg in python_function_task.keywords.values(): + if isinstance(arg, list): + raise ValueError("Map tasks do not support partial tasks with lists as inputs. ") self._partial = python_function_task actual_task = self._partial.func else: diff --git a/tests/flytekit/unit/core/test_partials.py b/tests/flytekit/unit/core/test_partials.py index 0a78c825f8..24e3908d1d 100644 --- a/tests/flytekit/unit/core/test_partials.py +++ b/tests/flytekit/unit/core/test_partials.py @@ -122,6 +122,29 @@ def wf_in(a: typing.List[int]): assert wf_spec.template.nodes[1].inputs[2].binding.scalar is not None +def test_lists_cannot_be_used_in_partials(): + @task + def t(a: int, b: typing.List[str]) -> str: + return str(a) + str(b) + + with pytest.raises(ValueError): + map_task(partial(t, b=["hello", "world"]))(a=[1, 2, 3]) + + @task + def t_multilist(a: int, b: typing.List[float], c: typing.List[int]) -> str: + return str(a) + str(b) + str(c) + + with pytest.raises(ValueError): + map_task(partial(t_multilist, b=[3.14, 12.34, 9876.5432], c=[42, 99]))(a=[1, 2, 3, 4]) + + @task + def t_list_of_lists(a: typing.List[typing.List[float]], b: int) -> str: + return str(a) + str(b) + + with pytest.raises(ValueError): + map_task(partial(t_list_of_lists, a=[[3.14]]))(b=[1, 2, 3, 4]) + + def test_everything(): @task def get_static_list() -> typing.List[float]: @@ -140,33 +163,48 @@ def get_list_of_pd(s: int) -> typing.List[pd.DataFrame]: def t3(a: int, b: str, c: typing.List[float], d: typing.List[float], a2: pd.DataFrame) -> str: return str(a) + f"pdsize{len(a2)}" + b + str(c) + "&&" + str(d) - t3_bind_b1 = partial(t3, b="hello") t3_bind_b2 = partial(t3, b="world") - t3_bind_c1 = partial(t3_bind_b1, c=[6.674, 1.618, 6.626], d=[1.0]) + # TODO: partial lists are not supported yet. + # t3_bind_b1 = partial(t3, b="hello") + # t3_bind_c1 = partial(t3_bind_b1, c=[6.674, 1.618, 6.626], d=[1.0]) + # mt1 = map_task(t3_bind_c1) - mt1 = map_task(t3_bind_c1) + mt1 = map_task(t3_bind_b2) mr = MapTaskResolver() aa = mr.loader_args(serialization_settings, mt1) # Check bound vars aa = aa[1].split(",") aa.sort() - assert aa == ["b", "c", "d"] + assert aa == ["b"] @task - def print_lists(i: typing.List[str], j: typing.List[str]) -> str: + def print_lists(i: typing.List[str], j: typing.List[str], k: typing.List[str]) -> str: print(f"First: {i}") print(f"Second: {j}") - return f"{i}-{j}" + print(f"Third: {k}") + return f"{i}-{j}-{k}" @dynamic def dt1(a: typing.List[int], a2: typing.List[pd.DataFrame], sl: typing.List[float]) -> str: - i = mt1(a=a, a2=a2) - t3_bind_c2 = partial(t3_bind_b2, c=[1.0, 2.0, 3.0], d=sl) - mt_in2 = map_task(t3_bind_c2) + i = mt1(a=a, a2=a2, c=[[1.1, 2.0, 3.0], [1.1, 2.0, 3.0]], d=[sl, sl]) + mt_in2 = map_task(t3_bind_b2) dfs = get_list_of_pd(s=3) - j = mt_in2(a=[3, 4, 5], a2=dfs) - return print_lists(i=i, j=j) + j = mt_in2(a=[3, 4, 5], a2=dfs, c=[[1.0], [2.0], [3.0]], d=[sl, sl, sl]) + + # Test a2 bound to a fixed dataframe + t3_bind_a2 = partial(t3_bind_b2, a2=a2[0]) + + mt_in3 = map_task(t3_bind_a2) + + aa = mr.loader_args(serialization_settings, mt_in3) + # Check bound vars + aa = aa[1].split(",") + aa.sort() + assert aa == ["a2", "b"] + + k = mt_in3(a=[3, 4, 5], c=[[1.0], [2.0], [3.0]], d=[sl, sl, sl]) + return print_lists(i=i, j=j, k=k) @workflow def wf_dt(a: typing.List[int]) -> str: @@ -177,5 +215,5 @@ def wf_dt(a: typing.List[int]) -> str: print(wf_dt(a=[1, 2])) assert ( wf_dt(a=[1, 2]) - == "['1pdsize2hello[6.674, 1.618, 6.626]&&[1.0]', '2pdsize3hello[6.674, 1.618, 6.626]&&[1.0]']-['3pdsize2world[1.0, 2.0, 3.0]&&[3.14, 2.718]', '4pdsize3world[1.0, 2.0, 3.0]&&[3.14, 2.718]', '5pdsize2world[1.0, 2.0, 3.0]&&[3.14, 2.718]']" + == "['1pdsize2world[1.1, 2.0, 3.0]&&[3.14, 2.718]', '2pdsize3world[1.1, 2.0, 3.0]&&[3.14, 2.718]']-['3pdsize2world[1.0]&&[3.14, 2.718]', '4pdsize3world[2.0]&&[3.14, 2.718]', '5pdsize2world[3.0]&&[3.14, 2.718]']-['3pdsize2world[1.0]&&[3.14, 2.718]', '4pdsize2world[2.0]&&[3.14, 2.718]', '5pdsize2world[3.0]&&[3.14, 2.718]']" )