Skip to content

Commit

Permalink
Disallow partial lists in map tasks (#1577)
Browse files Browse the repository at this point in the history
* Disallow partial lists in map tasks

Signed-off-by: eduardo apolinario <[email protected]>

* Lint

Signed-off-by: eduardo apolinario <[email protected]>

---------

Signed-off-by: eduardo apolinario <[email protected]>
Co-authored-by: eduardo apolinario <[email protected]>
  • Loading branch information
eapolinario and eapolinario authored Apr 3, 2023
1 parent 6530f12 commit 10bd38f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 12 deletions.
4 changes: 4 additions & 0 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def __init__(
"""
self._partial = None
if isinstance(python_function_task, functools.partial):
# TODO: We should be able to support partial tasks with lists as inputs
for arg in python_function_task.keywords.values():
if isinstance(arg, list):
raise ValueError("Map tasks do not support partial tasks with lists as inputs. ")
self._partial = python_function_task
actual_task = self._partial.func
else:
Expand Down
62 changes: 50 additions & 12 deletions tests/flytekit/unit/core/test_partials.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,29 @@ def wf_in(a: typing.List[int]):
assert wf_spec.template.nodes[1].inputs[2].binding.scalar is not None


def test_lists_cannot_be_used_in_partials():
@task
def t(a: int, b: typing.List[str]) -> str:
return str(a) + str(b)

with pytest.raises(ValueError):
map_task(partial(t, b=["hello", "world"]))(a=[1, 2, 3])

@task
def t_multilist(a: int, b: typing.List[float], c: typing.List[int]) -> str:
return str(a) + str(b) + str(c)

with pytest.raises(ValueError):
map_task(partial(t_multilist, b=[3.14, 12.34, 9876.5432], c=[42, 99]))(a=[1, 2, 3, 4])

@task
def t_list_of_lists(a: typing.List[typing.List[float]], b: int) -> str:
return str(a) + str(b)

with pytest.raises(ValueError):
map_task(partial(t_list_of_lists, a=[[3.14]]))(b=[1, 2, 3, 4])


def test_everything():
@task
def get_static_list() -> typing.List[float]:
Expand All @@ -140,33 +163,48 @@ def get_list_of_pd(s: int) -> typing.List[pd.DataFrame]:
def t3(a: int, b: str, c: typing.List[float], d: typing.List[float], a2: pd.DataFrame) -> str:
return str(a) + f"pdsize{len(a2)}" + b + str(c) + "&&" + str(d)

t3_bind_b1 = partial(t3, b="hello")
t3_bind_b2 = partial(t3, b="world")
t3_bind_c1 = partial(t3_bind_b1, c=[6.674, 1.618, 6.626], d=[1.0])
# TODO: partial lists are not supported yet.
# t3_bind_b1 = partial(t3, b="hello")
# t3_bind_c1 = partial(t3_bind_b1, c=[6.674, 1.618, 6.626], d=[1.0])
# mt1 = map_task(t3_bind_c1)

mt1 = map_task(t3_bind_c1)
mt1 = map_task(t3_bind_b2)

mr = MapTaskResolver()
aa = mr.loader_args(serialization_settings, mt1)
# Check bound vars
aa = aa[1].split(",")
aa.sort()
assert aa == ["b", "c", "d"]
assert aa == ["b"]

@task
def print_lists(i: typing.List[str], j: typing.List[str]) -> str:
def print_lists(i: typing.List[str], j: typing.List[str], k: typing.List[str]) -> str:
print(f"First: {i}")
print(f"Second: {j}")
return f"{i}-{j}"
print(f"Third: {k}")
return f"{i}-{j}-{k}"

@dynamic
def dt1(a: typing.List[int], a2: typing.List[pd.DataFrame], sl: typing.List[float]) -> str:
i = mt1(a=a, a2=a2)
t3_bind_c2 = partial(t3_bind_b2, c=[1.0, 2.0, 3.0], d=sl)
mt_in2 = map_task(t3_bind_c2)
i = mt1(a=a, a2=a2, c=[[1.1, 2.0, 3.0], [1.1, 2.0, 3.0]], d=[sl, sl])
mt_in2 = map_task(t3_bind_b2)
dfs = get_list_of_pd(s=3)
j = mt_in2(a=[3, 4, 5], a2=dfs)
return print_lists(i=i, j=j)
j = mt_in2(a=[3, 4, 5], a2=dfs, c=[[1.0], [2.0], [3.0]], d=[sl, sl, sl])

# Test a2 bound to a fixed dataframe
t3_bind_a2 = partial(t3_bind_b2, a2=a2[0])

mt_in3 = map_task(t3_bind_a2)

aa = mr.loader_args(serialization_settings, mt_in3)
# Check bound vars
aa = aa[1].split(",")
aa.sort()
assert aa == ["a2", "b"]

k = mt_in3(a=[3, 4, 5], c=[[1.0], [2.0], [3.0]], d=[sl, sl, sl])
return print_lists(i=i, j=j, k=k)

@workflow
def wf_dt(a: typing.List[int]) -> str:
Expand All @@ -177,5 +215,5 @@ def wf_dt(a: typing.List[int]) -> str:
print(wf_dt(a=[1, 2]))
assert (
wf_dt(a=[1, 2])
== "['1pdsize2hello[6.674, 1.618, 6.626]&&[1.0]', '2pdsize3hello[6.674, 1.618, 6.626]&&[1.0]']-['3pdsize2world[1.0, 2.0, 3.0]&&[3.14, 2.718]', '4pdsize3world[1.0, 2.0, 3.0]&&[3.14, 2.718]', '5pdsize2world[1.0, 2.0, 3.0]&&[3.14, 2.718]']"
== "['1pdsize2world[1.1, 2.0, 3.0]&&[3.14, 2.718]', '2pdsize3world[1.1, 2.0, 3.0]&&[3.14, 2.718]']-['3pdsize2world[1.0]&&[3.14, 2.718]', '4pdsize3world[2.0]&&[3.14, 2.718]', '5pdsize2world[3.0]&&[3.14, 2.718]']-['3pdsize2world[1.0]&&[3.14, 2.718]', '4pdsize2world[2.0]&&[3.14, 2.718]', '5pdsize2world[3.0]&&[3.14, 2.718]']"
)

0 comments on commit 10bd38f

Please sign in to comment.