Skip to content

Commit

Permalink
Better error for min_success_ratio<1 (#2724)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
pingsutw authored Sep 3, 2024
1 parent 952a17a commit a5c44cd
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 8 deletions.
25 changes: 20 additions & 5 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,20 @@ def binding_data_from_python_std(
# This handles the case where the given value is the output of another task
if isinstance(t_value, Promise):
if not t_value.is_ready:
nodes.append(t_value.ref.node) # keeps track of upstream nodes
node = t_value.ref.node
if node.flyte_entity and hasattr(node.flyte_entity, "interface"):
upstream_lt_type = node.flyte_entity.interface.outputs[t_value.ref.var].type
# if an upstream type is a list of unions, make sure the downstream type is a list of unions
# this is just a very limited test case for handling common map task type mis-matches so that we can show
# the user more information without relying on the user to register with Admin to trigger the compiler
if upstream_lt_type.collection_type and upstream_lt_type.collection_type.union_type:
if not (expected_literal_type.collection_type and expected_literal_type.collection_type.union_type):
upstream_python_type = node.flyte_entity.python_interface.outputs[t_value.ref.var]
raise AssertionError(
f"Expected type '{t_value_type}' does not match upstream type '{upstream_python_type}'"
)

nodes.append(node) # keeps track of upstream nodes
return _literals_models.BindingData(promise=t_value.ref)

elif isinstance(t_value, VoidPromise):
Expand Down Expand Up @@ -1079,8 +1092,9 @@ def create_and_link_node_from_remote(
bindings.append(b)
nodes.extend(n)
used_inputs.add(k)
except Exception as e:
raise AssertionError(f"Failed to Bind variable {k} for function {entity.name}.") from e
except Exception as exc:
exc.args = (f"Failed to Bind variable '{k}' for function '{entity.name}':\n {exc.args[0]}",)
raise

extra_inputs = used_inputs ^ set(kwargs.keys())
if len(extra_inputs) > 0:
Expand Down Expand Up @@ -1186,8 +1200,9 @@ def create_and_link_node(
bindings.append(b)
nodes.extend(n)
used_inputs.add(k)
except Exception as e:
raise AssertionError(f"Failed to Bind variable {k} for function {entity.name}.") from e
except Exception as exc:
exc.args = (f"Failed to Bind variable '{k}' for function '{entity.name}':\n {exc.args[0]}",)
raise

extra_inputs = used_inputs ^ set(kwargs.keys())
if len(extra_inputs) > 0:
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,7 +1477,7 @@ def _is_union_type(t):
else:
UnionType = None

return t is typing.Union or get_origin(t) is Union or UnionType and isinstance(t, UnionType)
return t is typing.Union or get_origin(t) is typing.Union or UnionType and isinstance(t, UnionType)


class UnionTransformer(TypeTransformer[T]):
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-pandera/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def my_wf() -> pandera.typing.DataFrame[OutSchema]:
# raise error when defining workflow using invalid data
invalid_df = pandas.DataFrame({"col1": [1, 2, 3], "col2": list("abc")})

with pytest.raises(AssertionError):
with pytest.raises(pandera.errors.SchemaError):

@workflow
def invalid_wf() -> pandera.typing.DataFrame[OutSchema]:
Expand Down
29 changes: 29 additions & 0 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import functools
import os
import typing
from collections import OrderedDict
from typing import List
from typing_extensions import Annotated
import tempfile

import pytest

from flytekit import dynamic, map_task, task, workflow
from flytekit.types.directory import FlyteDirectory
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
from flytekit.core import context_manager
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
Expand Down Expand Up @@ -435,3 +438,29 @@ def test_wf():

with pytest.raises(ValueError):
map_task(test_wf)


def test_mis_match():
@task
def generate_directory(word: str) -> FlyteDirectory:
temp_dir1 = tempfile.TemporaryDirectory(delete=False)
with open(os.path.join(temp_dir1.name, "file.txt"), "w") as tmp:
tmp.write(f"Hello world {word}!\n")
return FlyteDirectory(path=temp_dir1.name)

@task
def consume_directories(dirs: List[FlyteDirectory]):
for d in dirs:
print(f"Directory: {d.path} {d._remote_source}")
for path_info, other_info in d.crawl():
print(path_info)

mt = map_task(generate_directory, min_success_ratio=0.1)

@workflow
def wf():
dirs = mt(word=["one", "two", "three"])
consume_directories(dirs=dirs)

with pytest.raises(AssertionError):
wf.compile()
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3035,7 +3035,7 @@ def wf3() -> Base:
assert child_data.b == 12
assert isinstance(child_data, Child1)

with pytest.raises(AssertionError):
with pytest.raises(AttributeError):
wf2()

base_data = wf3()
Expand Down

0 comments on commit a5c44cd

Please sign in to comment.