From 243adb789cbf427dbc5c81d325a73c448df7ec0a Mon Sep 17 00:00:00 2001 From: Kevin Su <pingsutw@gmail.com> Date: Thu, 7 Oct 2021 04:16:28 +0800 Subject: [PATCH] Failed to load json_data to dataclass (#684) Signed-off-by: Kevin Su <pingsutw@apache.org> --- flytekit/core/type_engine.py | 88 +++++++++++++++----- tests/flytekit/unit/core/test_type_engine.py | 51 ++++++++++-- 2 files changed, 112 insertions(+), 27 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 7b30f86cf0..d02016a465 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1,6 +1,5 @@ from __future__ import annotations -import copy import dataclasses import datetime as _datetime import enum @@ -32,6 +31,7 @@ from flytekit.models.types import LiteralType, SimpleType T = typing.TypeVar("T") +DEFINITIONS = "definitions" class TypeTransformer(typing.Generic[T]): @@ -284,6 +284,14 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: dc = cast(DataClassJsonMixin, expected_python_type).from_json(_json_format.MessageToJson(lv.scalar.generic)) return self._fix_dataclass_int(expected_python_type, dc) + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + if literal_type.simple == SimpleType.STRUCT: + if literal_type.metadata is not None and DEFINITIONS in literal_type.metadata: + schema_name = literal_type.metadata["$ref"].split("/")[-1] + return convert_json_schema_to_python_class(literal_type.metadata[DEFINITIONS], schema_name) + + raise ValueError(f"Dataclass transformer cannot reverse {literal_type}") + class ProtobufTransformer(TypeTransformer[_proto_reflection.GeneratedProtocolMessageType]): PB_FIELD_KEY = "pb_type" @@ -509,6 +517,13 @@ def guess_python_type(cls, flyte_type: LiteralType) -> type: return transformer.guess_python_type(flyte_type) except ValueError: logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}") + + # Because the dataclass transformer is handled explicity in the get_transformer code, we have to handle it + # separately here too. + try: + return cls._DATACLASS_TRANSFORMER.guess_python_type(literal_type=flyte_type) + except ValueError: + logger.debug(f"Skipping transformer {cls._DATACLASS_TRANSFORMER.name} for {flyte_type}") raise ValueError(f"No transformers could reverse Flyte literal type {flyte_type}") @@ -640,8 +655,6 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]: if literal_type.simple == SimpleType.STRUCT: if literal_type.metadata is None: return dict - if "definitions" in literal_type.metadata: - return convert_json_schema_to_python_class(literal_type.metadata) raise ValueError(f"Dictionary transformer cannot reverse {literal_type}") @@ -734,28 +747,63 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return expected_python_type(lv.scalar.primitive.string_value) -def convert_json_schema_to_python_class(schema): +def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: """Generate a model class based on the provided JSON Schema :param schema: dict representing valid JSON schema + :param schema_name: dataclass name of return type """ - schema = copy.deepcopy(schema) - - class Model(dict): - def __init__(self, *args, **kwargs): - self.__dict__["schema"] = schema - d = dict(*args, **kwargs) - dict.__init__(self, d) - - def __setitem__(self, key, value): - dict.__setitem__(self, key, value) + attribute_list = [] + for property_key, property_val in schema[schema_name]["properties"].items(): + # Handle list + if property_val["type"] == "array": + attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) + # Handle dataclass and dict + elif property_val["type"] == "object": + if "$ref" in property_val: + name = property_val["$ref"].split("/")[-1] + attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name))) + else: + attribute_list.append( + (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) + ) + # Handle int, float, bool or str + else: + attribute_list.append([property_key, _get_element_type(property_val)]) + + return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) + + +def _get_element_type(element_property: typing.Dict[str, str]) -> Type[T]: + element_type = element_property["type"] + element_format = element_property["format"] if "format" in element_property else None + if element_type == "string": + return str + elif element_type == "integer": + return int + elif element_type == "boolean": + return bool + elif element_type == "number": + if element_format == "integer": + return int + else: + return float + return str + + +def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing.Any: + """ + Utility function to construct a dataclass object from dict + """ + field_types_lookup = {field.name: field.type for field in dataclasses.fields(cls)} - def __getattr__(self, key): - try: - return self.__getitem__(key) - except KeyError: - raise AttributeError(key) + constructor_inputs = {} + for field_name, value in src.items(): + if dataclasses.is_dataclass(field_types_lookup[field_name]): + constructor_inputs[field_name] = dataclass_from_dict(field_types_lookup[field_name], value) + else: + constructor_inputs[field_name] = value - return dataclass_json(dataclasses.dataclass(Model)) + return cls(**constructor_inputs) def _check_and_covert_float(lv: Literal) -> float: diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 5f2f7c61a8..cb9512dfee 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1,7 +1,7 @@ import datetime import os import typing -from dataclasses import dataclass +from dataclasses import asdict, dataclass from datetime import timedelta from enum import Enum @@ -21,6 +21,7 @@ SimpleTransformer, TypeEngine, convert_json_schema_to_python_class, + dataclass_from_dict, ) from flytekit.models import types as model_types from flytekit.models.core.types import BlobType @@ -118,12 +119,21 @@ def test_list_of_dict_getting_python_value(): def test_list_of_dataclass_getting_python_value(): @dataclass_json @dataclass() - class Foo(object): - x: int + class Bar(object): + w: typing.Optional[str] + x: float y: str - z: typing.Dict[int, str] + z: typing.Dict[str, bool] + + @dataclass_json + @dataclass() + class Foo(object): + w: int + x: typing.List[int] + y: typing.Dict[str, str] + z: Bar - foo = Foo(x=1, y="boo", z={3: "10"}) + foo = Foo(w=1, x=[1], y={"hello": "10"}, z=Bar(w=None, x=1.0, y="hello", z={"world": False})) generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct()) lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))])) @@ -131,10 +141,18 @@ class Foo(object): ctx = FlyteContext.current_context() schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema()) - foo_class = convert_json_schema_to_python_class(schema) + foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema") pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class]) assert isinstance(pv, list) + assert pv[0].w == foo.w + assert pv[0].x == foo.x + assert pv[0].y == foo.y + assert pv[0].z.x == foo.z.x + assert type(pv[0].z.x) == float + assert pv[0].z.y == foo.z.y + assert pv[0].z.z == foo.z.z + assert foo == dataclass_from_dict(Foo, asdict(pv[0])) def test_file_non_downloadable(): @@ -257,7 +275,7 @@ class Foo(object): y: str schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema()) - foo_class = convert_json_schema_to_python_class(schema) + foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema") foo = foo_class(x=1, y="hello") foo.x = 2 assert foo.x == 2 @@ -563,6 +581,25 @@ def test_enum_type(): } ), ), + ( + {"p1": TestStructD(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": [5]})}, + {"p1": TestStructD}, + LiteralMap( + literals={ + "p1": Literal( + scalar=Scalar( + generic=_json_format.Parse( + typing.cast( + DataClassJsonMixin, + TestStructD(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": [5]}), + ).to_json(), + _struct.Struct(), + ) + ) + ) + } + ), + ), ], ) def test_dict_to_literal_map(python_value, python_types, expected_literal_map):