Skip to content

Commit

Permalink
Failed to load json_data to dataclass (#684)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Oct 6, 2021
1 parent 652fd22 commit 243adb7
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 27 deletions.
88 changes: 68 additions & 20 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import copy
import dataclasses
import datetime as _datetime
import enum
Expand Down Expand Up @@ -32,6 +31,7 @@
from flytekit.models.types import LiteralType, SimpleType

T = typing.TypeVar("T")
DEFINITIONS = "definitions"


class TypeTransformer(typing.Generic[T]):
Expand Down Expand Up @@ -284,6 +284,14 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
dc = cast(DataClassJsonMixin, expected_python_type).from_json(_json_format.MessageToJson(lv.scalar.generic))
return self._fix_dataclass_int(expected_python_type, dc)

def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
if literal_type.simple == SimpleType.STRUCT:
if literal_type.metadata is not None and DEFINITIONS in literal_type.metadata:
schema_name = literal_type.metadata["$ref"].split("/")[-1]
return convert_json_schema_to_python_class(literal_type.metadata[DEFINITIONS], schema_name)

raise ValueError(f"Dataclass transformer cannot reverse {literal_type}")


class ProtobufTransformer(TypeTransformer[_proto_reflection.GeneratedProtocolMessageType]):
PB_FIELD_KEY = "pb_type"
Expand Down Expand Up @@ -509,6 +517,13 @@ def guess_python_type(cls, flyte_type: LiteralType) -> type:
return transformer.guess_python_type(flyte_type)
except ValueError:
logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}")

# Because the dataclass transformer is handled explicity in the get_transformer code, we have to handle it
# separately here too.
try:
return cls._DATACLASS_TRANSFORMER.guess_python_type(literal_type=flyte_type)
except ValueError:
logger.debug(f"Skipping transformer {cls._DATACLASS_TRANSFORMER.name} for {flyte_type}")
raise ValueError(f"No transformers could reverse Flyte literal type {flyte_type}")


Expand Down Expand Up @@ -640,8 +655,6 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
if literal_type.simple == SimpleType.STRUCT:
if literal_type.metadata is None:
return dict
if "definitions" in literal_type.metadata:
return convert_json_schema_to_python_class(literal_type.metadata)

raise ValueError(f"Dictionary transformer cannot reverse {literal_type}")

Expand Down Expand Up @@ -734,28 +747,63 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
return expected_python_type(lv.scalar.primitive.string_value)


def convert_json_schema_to_python_class(schema):
def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]:
"""Generate a model class based on the provided JSON Schema
:param schema: dict representing valid JSON schema
:param schema_name: dataclass name of return type
"""
schema = copy.deepcopy(schema)

class Model(dict):
def __init__(self, *args, **kwargs):
self.__dict__["schema"] = schema
d = dict(*args, **kwargs)
dict.__init__(self, d)

def __setitem__(self, key, value):
dict.__setitem__(self, key, value)
attribute_list = []
for property_key, property_val in schema[schema_name]["properties"].items():
# Handle list
if property_val["type"] == "array":
attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])]))
# Handle dataclass and dict
elif property_val["type"] == "object":
if "$ref" in property_val:
name = property_val["$ref"].split("/")[-1]
attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name)))
else:
attribute_list.append(
(property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])])
)
# Handle int, float, bool or str
else:
attribute_list.append([property_key, _get_element_type(property_val)])

return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list))


def _get_element_type(element_property: typing.Dict[str, str]) -> Type[T]:
element_type = element_property["type"]
element_format = element_property["format"] if "format" in element_property else None
if element_type == "string":
return str
elif element_type == "integer":
return int
elif element_type == "boolean":
return bool
elif element_type == "number":
if element_format == "integer":
return int
else:
return float
return str


def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing.Any:
"""
Utility function to construct a dataclass object from dict
"""
field_types_lookup = {field.name: field.type for field in dataclasses.fields(cls)}

def __getattr__(self, key):
try:
return self.__getitem__(key)
except KeyError:
raise AttributeError(key)
constructor_inputs = {}
for field_name, value in src.items():
if dataclasses.is_dataclass(field_types_lookup[field_name]):
constructor_inputs[field_name] = dataclass_from_dict(field_types_lookup[field_name], value)
else:
constructor_inputs[field_name] = value

return dataclass_json(dataclasses.dataclass(Model))
return cls(**constructor_inputs)


def _check_and_covert_float(lv: Literal) -> float:
Expand Down
51 changes: 44 additions & 7 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import os
import typing
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from datetime import timedelta
from enum import Enum

Expand All @@ -21,6 +21,7 @@
SimpleTransformer,
TypeEngine,
convert_json_schema_to_python_class,
dataclass_from_dict,
)
from flytekit.models import types as model_types
from flytekit.models.core.types import BlobType
Expand Down Expand Up @@ -118,23 +119,40 @@ def test_list_of_dict_getting_python_value():
def test_list_of_dataclass_getting_python_value():
@dataclass_json
@dataclass()
class Foo(object):
x: int
class Bar(object):
w: typing.Optional[str]
x: float
y: str
z: typing.Dict[int, str]
z: typing.Dict[str, bool]

@dataclass_json
@dataclass()
class Foo(object):
w: int
x: typing.List[int]
y: typing.Dict[str, str]
z: Bar

foo = Foo(x=1, y="boo", z={3: "10"})
foo = Foo(w=1, x=[1], y={"hello": "10"}, z=Bar(w=None, x=1.0, y="hello", z={"world": False}))
generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct())
lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))]))

transformer = TypeEngine.get_transformer(typing.List)
ctx = FlyteContext.current_context()

schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema())
foo_class = convert_json_schema_to_python_class(schema)
foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema")

pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class])
assert isinstance(pv, list)
assert pv[0].w == foo.w
assert pv[0].x == foo.x
assert pv[0].y == foo.y
assert pv[0].z.x == foo.z.x
assert type(pv[0].z.x) == float
assert pv[0].z.y == foo.z.y
assert pv[0].z.z == foo.z.z
assert foo == dataclass_from_dict(Foo, asdict(pv[0]))


def test_file_non_downloadable():
Expand Down Expand Up @@ -257,7 +275,7 @@ class Foo(object):
y: str

schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema())
foo_class = convert_json_schema_to_python_class(schema)
foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema")
foo = foo_class(x=1, y="hello")
foo.x = 2
assert foo.x == 2
Expand Down Expand Up @@ -563,6 +581,25 @@ def test_enum_type():
}
),
),
(
{"p1": TestStructD(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": [5]})},
{"p1": TestStructD},
LiteralMap(
literals={
"p1": Literal(
scalar=Scalar(
generic=_json_format.Parse(
typing.cast(
DataClassJsonMixin,
TestStructD(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": [5]}),
).to_json(),
_struct.Struct(),
)
)
)
}
),
),
],
)
def test_dict_to_literal_map(python_value, python_types, expected_literal_map):
Expand Down

0 comments on commit 243adb7

Please sign in to comment.