Skip to content

Commit

Permalink
Support enum in dataclass (flyteorg#753)
Browse files Browse the repository at this point in the history
* Add support enum in dataclass

Signed-off-by: Kevin Su <[email protected]>

* Update test

Signed-off-by: Kevin Su <[email protected]>

* Fixed lint

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Robert Everson <[email protected]>
  • Loading branch information
pingsutw authored and Robert Everson committed May 27, 2022
1 parent 5aaf5e3 commit 389afe9
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
9 changes: 8 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from google.protobuf.json_format import MessageToDict as _MessageToDict
from google.protobuf.json_format import ParseDict as _ParseDict
from google.protobuf.struct_pb2 import Struct
from marshmallow_enum import EnumField, LoadDumpOptions
from marshmallow_jsonschema import JSONSchema

from flytekit.common.exceptions import user as user_exceptions
Expand Down Expand Up @@ -226,7 +227,13 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
)
schema = None
try:
schema = JSONSchema().dump(cast(DataClassJsonMixin, t).schema())
s = cast(DataClassJsonMixin, t).schema()
for _, v in s.fields.items():
# marshmallow-jsonschema only supports enums loaded by name.
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
if isinstance(v, EnumField):
v.load_by = LoadDumpOptions.name
schema = JSONSchema().dump(s)
except Exception as e:
logger.warn("failed to extract schema for object %s, (will run schemaless) error: %s", str(t), e)

Expand Down
23 changes: 23 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flyteidl.core import errors_pb2
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from marshmallow_enum import LoadDumpOptions
from marshmallow_jsonschema import JSONSchema

from flytekit.common.exceptions import user as user_exceptions
Expand Down Expand Up @@ -547,6 +548,28 @@ def test_enum_type():
TypeEngine.to_literal_type(UnsupportedEnumValues)


def test_enum_in_dataclass():
@dataclass_json
@dataclass
class Datum(object):
x: int
y: Color

lt = TypeEngine.to_literal_type(Datum)
schema = Datum.schema()
schema.fields["y"].load_by = LoadDumpOptions.name
assert lt.metadata == JSONSchema().dump(schema)

transformer = DataclassTransformer()
ctx = FlyteContext.current_context()
datum = Datum(5, Color.RED)
lv = transformer.to_literal(ctx, datum, Datum, lt)
gt = transformer.guess_python_type(lt)
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
assert datum.x == pv.x
assert datum.y.value == pv.y


@pytest.mark.parametrize(
"python_value,python_types,expected_literal_map",
[
Expand Down
24 changes: 24 additions & 0 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import typing
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum

import pandas
import pytest
Expand Down Expand Up @@ -1063,6 +1064,29 @@ def wf(x: int, y: int) -> Datum:
wf(x=10, y=20)


def test_enum_in_dataclass():
class Color(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"

@dataclass_json
@dataclass
class Datum(object):
x: int
y: Color

@task
def t1(x: int) -> Datum:
return Datum(x=x, y=Color.RED)

@workflow
def wf(x: int) -> Datum:
return t1(x=x)

assert wf(x=10) == Datum(10, Color.RED)


def test_environment():
@task(environment={"FOO": "foofoo", "BAZ": "baz"})
def t1(a: int) -> str:
Expand Down

0 comments on commit 389afe9

Please sign in to comment.