From a5c44cd1344a0f4d1c1209c2defb48b289f3532a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 3 Sep 2024 15:42:28 -0700 Subject: [PATCH] Better error for min_success_ratio<1 (#2724) Signed-off-by: Yee Hing Tong --- flytekit/core/promise.py | 25 ++++++++++++---- flytekit/core/type_engine.py | 2 +- plugins/flytekit-pandera/tests/test_plugin.py | 2 +- .../unit/core/test_array_node_map_task.py | 29 +++++++++++++++++++ tests/flytekit/unit/core/test_type_engine.py | 2 +- 5 files changed, 52 insertions(+), 8 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 44195be6f3..9a8a853981 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -765,7 +765,20 @@ def binding_data_from_python_std( # This handles the case where the given value is the output of another task if isinstance(t_value, Promise): if not t_value.is_ready: - nodes.append(t_value.ref.node) # keeps track of upstream nodes + node = t_value.ref.node + if node.flyte_entity and hasattr(node.flyte_entity, "interface"): + upstream_lt_type = node.flyte_entity.interface.outputs[t_value.ref.var].type + # if an upstream type is a list of unions, make sure the downstream type is a list of unions + # this is just a very limited test case for handling common map task type mis-matches so that we can show + # the user more information without relying on the user to register with Admin to trigger the compiler + if upstream_lt_type.collection_type and upstream_lt_type.collection_type.union_type: + if not (expected_literal_type.collection_type and expected_literal_type.collection_type.union_type): + upstream_python_type = node.flyte_entity.python_interface.outputs[t_value.ref.var] + raise AssertionError( + f"Expected type '{t_value_type}' does not match upstream type '{upstream_python_type}'" + ) + + nodes.append(node) # keeps track of upstream nodes return _literals_models.BindingData(promise=t_value.ref) elif isinstance(t_value, VoidPromise): @@ -1079,8 +1092,9 @@ def create_and_link_node_from_remote( bindings.append(b) nodes.extend(n) used_inputs.add(k) - except Exception as e: - raise AssertionError(f"Failed to Bind variable {k} for function {entity.name}.") from e + except Exception as exc: + exc.args = (f"Failed to Bind variable '{k}' for function '{entity.name}':\n {exc.args[0]}",) + raise extra_inputs = used_inputs ^ set(kwargs.keys()) if len(extra_inputs) > 0: @@ -1186,8 +1200,9 @@ def create_and_link_node( bindings.append(b) nodes.extend(n) used_inputs.add(k) - except Exception as e: - raise AssertionError(f"Failed to Bind variable {k} for function {entity.name}.") from e + except Exception as exc: + exc.args = (f"Failed to Bind variable '{k}' for function '{entity.name}':\n {exc.args[0]}",) + raise extra_inputs = used_inputs ^ set(kwargs.keys()) if len(extra_inputs) > 0: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 2218ed430a..5948c0beef 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1477,7 +1477,7 @@ def _is_union_type(t): else: UnionType = None - return t is typing.Union or get_origin(t) is Union or UnionType and isinstance(t, UnionType) + return t is typing.Union or get_origin(t) is typing.Union or UnionType and isinstance(t, UnionType) class UnionTransformer(TypeTransformer[T]): diff --git a/plugins/flytekit-pandera/tests/test_plugin.py b/plugins/flytekit-pandera/tests/test_plugin.py index e29a28157d..3c7a5107d4 100644 --- a/plugins/flytekit-pandera/tests/test_plugin.py +++ b/plugins/flytekit-pandera/tests/test_plugin.py @@ -44,7 +44,7 @@ def my_wf() -> pandera.typing.DataFrame[OutSchema]: # raise error when defining workflow using invalid data invalid_df = pandas.DataFrame({"col1": [1, 2, 3], "col2": list("abc")}) - with pytest.raises(AssertionError): + with pytest.raises(pandera.errors.SchemaError): @workflow def invalid_wf() -> pandera.typing.DataFrame[OutSchema]: diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index 74f1868eb4..fa964a71ef 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -1,12 +1,15 @@ import functools +import os import typing from collections import OrderedDict from typing import List from typing_extensions import Annotated +import tempfile import pytest from flytekit import dynamic, map_task, task, workflow +from flytekit.types.directory import FlyteDirectory from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver @@ -435,3 +438,29 @@ def test_wf(): with pytest.raises(ValueError): map_task(test_wf) + + +def test_mis_match(): + @task + def generate_directory(word: str) -> FlyteDirectory: + temp_dir1 = tempfile.TemporaryDirectory(delete=False) + with open(os.path.join(temp_dir1.name, "file.txt"), "w") as tmp: + tmp.write(f"Hello world {word}!\n") + return FlyteDirectory(path=temp_dir1.name) + + @task + def consume_directories(dirs: List[FlyteDirectory]): + for d in dirs: + print(f"Directory: {d.path} {d._remote_source}") + for path_info, other_info in d.crawl(): + print(path_info) + + mt = map_task(generate_directory, min_success_ratio=0.1) + + @workflow + def wf(): + dirs = mt(word=["one", "two", "three"]) + consume_directories(dirs=dirs) + + with pytest.raises(AssertionError): + wf.compile() diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 8370f96e94..63cfb68e21 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3035,7 +3035,7 @@ def wf3() -> Base: assert child_data.b == 12 assert isinstance(child_data, Child1) - with pytest.raises(AssertionError): + with pytest.raises(AttributeError): wf2() base_data = wf3()