Skip to content

Commit

Permalink
[Error Message] Dataclasses Mismatched Type (#2650)
Browse files Browse the repository at this point in the history
* Show different of types in dataclass when transforming error

Signed-off-by: Future-Outlier <[email protected]>

* add tests for dataclass

Signed-off-by: Future-Outlier <[email protected]>

* fix tests

Signed-off-by: Future-Outlier <[email protected]>

---------

Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier authored Aug 6, 2024
1 parent 243e1be commit d802c7e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
6 changes: 5 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,11 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte
except Exception as e:
# only show the name of output key if it's user-defined (by default Flyte names these as "o<n>")
key = k if k != f"o{i}" else i
msg = f"Failed to convert outputs of task '{self.name}' at position {key}:\n {e}"
msg = (
f"Failed to convert outputs of task '{self.name}' at position {key}.\n"
f"Failed to convert type {type(native_outputs_as_map[expected_output_names[i]])} to type {py_type}.\n"
f"Error Message: {e}."
)
logger.error(msg)
raise TypeError(msg) from e
# Now check if there is any output metadata associated with this output variable and attach it to the
Expand Down
30 changes: 27 additions & 3 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,6 +1568,17 @@ def t2() -> Bar:


def test_error_messages():
@dataclass
class DC1:
a: int
b: str

@dataclass
class DC2:
a: int
b: str
c: int

@task
def foo(a: int, b: str) -> typing.Tuple[int, str]:
return 10, "hello"
Expand All @@ -1580,6 +1591,10 @@ def foo2(a: int, b: str) -> typing.Tuple[int, str]:
def foo3(a: typing.Dict) -> typing.Dict:
return a

@task
def foo4(input: DC1=DC1(1, 'a')) -> DC2:
return input # type: ignore

# pytest-xdist uses `__channelexec__` as the top-level module
running_xdist = os.environ.get("PYTEST_XDIST_WORKER") is not None
prefix = "__channelexec__." if running_xdist else ""
Expand All @@ -1596,9 +1611,9 @@ def foo3(a: typing.Dict) -> typing.Dict:
with pytest.raises(
TypeError,
match=(
f"Failed to convert outputs of task '{prefix}tests.flytekit.unit.core.test_type_hints.foo2' "
"at position 0:\n"
" Expected value of type <class 'int'> but got 'hello' of type <class 'str'>"
f"Failed to convert outputs of task '{prefix}tests.flytekit.unit.core.test_type_hints.foo2' at position 0.\n"
f"Failed to convert type <class 'str'> to type <class 'int'>.\n"
"Error Message: Expected value of type <class 'int'> but got 'hello' of type <class 'str'>."
),
):
foo2(a=10, b="hello")
Expand All @@ -1610,6 +1625,15 @@ def foo3(a: typing.Dict) -> typing.Dict:
):
foo3(a=[{"hello": 2}])

with pytest.raises(
TypeError,
match=(
f"Failed to convert outputs of task '{prefix}tests.flytekit.unit.core.test_type_hints.foo4' at position 0.\n"
f"Failed to convert type <class 'tests.flytekit.unit.core.test_type_hints.test_error_messages.<locals>.DC1'> to type <class 'tests.flytekit.unit.core.test_type_hints.test_error_messages.<locals>.DC2'>.\n"
"Error Message: 'DC1' object has no attribute 'c'."
),
):
foo4()

def test_failure_node():
@task
Expand Down

0 comments on commit d802c7e

Please sign in to comment.