diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 925f79c770..c95d6a1576 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -17,6 +17,7 @@ from google.protobuf.json_format import MessageToDict as _MessageToDict from google.protobuf.json_format import ParseDict as _ParseDict from google.protobuf.struct_pb2 import Struct +from marshmallow_enum import EnumField, LoadDumpOptions from marshmallow_jsonschema import JSONSchema from flytekit.common.exceptions import user as user_exceptions @@ -226,7 +227,13 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: ) schema = None try: - schema = JSONSchema().dump(cast(DataClassJsonMixin, t).schema()) + s = cast(DataClassJsonMixin, t).schema() + for _, v in s.fields.items(): + # marshmallow-jsonschema only supports enums loaded by name. + # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 + if isinstance(v, EnumField): + v.load_by = LoadDumpOptions.name + schema = JSONSchema().dump(s) except Exception as e: logger.warn("failed to extract schema for object %s, (will run schemaless) error: %s", str(t), e) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 6c660b7106..3633cb194a 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -10,6 +10,7 @@ from flyteidl.core import errors_pb2 from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct +from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from flytekit.common.exceptions import user as user_exceptions @@ -547,6 +548,28 @@ def test_enum_type(): TypeEngine.to_literal_type(UnsupportedEnumValues) +def test_enum_in_dataclass(): + @dataclass_json + @dataclass + class Datum(object): + x: int + y: Color + + lt = TypeEngine.to_literal_type(Datum) + schema = Datum.schema() + schema.fields["y"].load_by = LoadDumpOptions.name + assert lt.metadata == JSONSchema().dump(schema) + + transformer = DataclassTransformer() + ctx = FlyteContext.current_context() + datum = Datum(5, Color.RED) + lv = transformer.to_literal(ctx, datum, Datum, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum.x == pv.x + assert datum.y.value == pv.y + + @pytest.mark.parametrize( "python_value,python_types,expected_literal_map", [ diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 17b46844c2..ef4bca3981 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -6,6 +6,7 @@ import typing from collections import OrderedDict from dataclasses import dataclass +from enum import Enum import pandas import pytest @@ -1063,6 +1064,29 @@ def wf(x: int, y: int) -> Datum: wf(x=10, y=20) +def test_enum_in_dataclass(): + class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + @dataclass_json + @dataclass + class Datum(object): + x: int + y: Color + + @task + def t1(x: int) -> Datum: + return Datum(x=x, y=Color.RED) + + @workflow + def wf(x: int) -> Datum: + return t1(x=x) + + assert wf(x=10) == Datum(10, Color.RED) + + def test_environment(): @task(environment={"FOO": "foofoo", "BAZ": "baz"}) def t1(a: int) -> str: