From 243adb789cbf427dbc5c81d325a73c448df7ec0a Mon Sep 17 00:00:00 2001
From: Kevin Su <pingsutw@gmail.com>
Date: Thu, 7 Oct 2021 04:16:28 +0800
Subject: [PATCH] Failed to load json_data to dataclass (#684)

Signed-off-by: Kevin Su <pingsutw@apache.org>
---
 flytekit/core/type_engine.py                 | 88 +++++++++++++++-----
 tests/flytekit/unit/core/test_type_engine.py | 51 ++++++++++--
 2 files changed, 112 insertions(+), 27 deletions(-)

diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 7b30f86cf0..d02016a465 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import copy
 import dataclasses
 import datetime as _datetime
 import enum
@@ -32,6 +31,7 @@
 from flytekit.models.types import LiteralType, SimpleType
 
 T = typing.TypeVar("T")
+DEFINITIONS = "definitions"
 
 
 class TypeTransformer(typing.Generic[T]):
@@ -284,6 +284,14 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
         dc = cast(DataClassJsonMixin, expected_python_type).from_json(_json_format.MessageToJson(lv.scalar.generic))
         return self._fix_dataclass_int(expected_python_type, dc)
 
+    def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
+        if literal_type.simple == SimpleType.STRUCT:
+            if literal_type.metadata is not None and DEFINITIONS in literal_type.metadata:
+                schema_name = literal_type.metadata["$ref"].split("/")[-1]
+                return convert_json_schema_to_python_class(literal_type.metadata[DEFINITIONS], schema_name)
+
+        raise ValueError(f"Dataclass transformer cannot reverse {literal_type}")
+
 
 class ProtobufTransformer(TypeTransformer[_proto_reflection.GeneratedProtocolMessageType]):
     PB_FIELD_KEY = "pb_type"
@@ -509,6 +517,13 @@ def guess_python_type(cls, flyte_type: LiteralType) -> type:
                 return transformer.guess_python_type(flyte_type)
             except ValueError:
                 logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}")
+
+        # Because the dataclass transformer is handled explicity in the get_transformer code, we have to handle it
+        # separately here too.
+        try:
+            return cls._DATACLASS_TRANSFORMER.guess_python_type(literal_type=flyte_type)
+        except ValueError:
+            logger.debug(f"Skipping transformer {cls._DATACLASS_TRANSFORMER.name} for {flyte_type}")
         raise ValueError(f"No transformers could reverse Flyte literal type {flyte_type}")
 
 
@@ -640,8 +655,6 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
         if literal_type.simple == SimpleType.STRUCT:
             if literal_type.metadata is None:
                 return dict
-            if "definitions" in literal_type.metadata:
-                return convert_json_schema_to_python_class(literal_type.metadata)
 
         raise ValueError(f"Dictionary transformer cannot reverse {literal_type}")
 
@@ -734,28 +747,63 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
         return expected_python_type(lv.scalar.primitive.string_value)
 
 
-def convert_json_schema_to_python_class(schema):
+def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]:
     """Generate a model class based on the provided JSON Schema
     :param schema: dict representing valid JSON schema
+    :param schema_name: dataclass name of return type
     """
-    schema = copy.deepcopy(schema)
-
-    class Model(dict):
-        def __init__(self, *args, **kwargs):
-            self.__dict__["schema"] = schema
-            d = dict(*args, **kwargs)
-            dict.__init__(self, d)
-
-        def __setitem__(self, key, value):
-            dict.__setitem__(self, key, value)
+    attribute_list = []
+    for property_key, property_val in schema[schema_name]["properties"].items():
+        # Handle list
+        if property_val["type"] == "array":
+            attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])]))
+        # Handle dataclass and dict
+        elif property_val["type"] == "object":
+            if "$ref" in property_val:
+                name = property_val["$ref"].split("/")[-1]
+                attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name)))
+            else:
+                attribute_list.append(
+                    (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])])
+                )
+        # Handle int, float, bool or str
+        else:
+            attribute_list.append([property_key, _get_element_type(property_val)])
+
+    return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list))
+
+
+def _get_element_type(element_property: typing.Dict[str, str]) -> Type[T]:
+    element_type = element_property["type"]
+    element_format = element_property["format"] if "format" in element_property else None
+    if element_type == "string":
+        return str
+    elif element_type == "integer":
+        return int
+    elif element_type == "boolean":
+        return bool
+    elif element_type == "number":
+        if element_format == "integer":
+            return int
+        else:
+            return float
+    return str
+
+
+def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing.Any:
+    """
+    Utility function to construct a dataclass object from dict
+    """
+    field_types_lookup = {field.name: field.type for field in dataclasses.fields(cls)}
 
-        def __getattr__(self, key):
-            try:
-                return self.__getitem__(key)
-            except KeyError:
-                raise AttributeError(key)
+    constructor_inputs = {}
+    for field_name, value in src.items():
+        if dataclasses.is_dataclass(field_types_lookup[field_name]):
+            constructor_inputs[field_name] = dataclass_from_dict(field_types_lookup[field_name], value)
+        else:
+            constructor_inputs[field_name] = value
 
-    return dataclass_json(dataclasses.dataclass(Model))
+    return cls(**constructor_inputs)
 
 
 def _check_and_covert_float(lv: Literal) -> float:
diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py
index 5f2f7c61a8..cb9512dfee 100644
--- a/tests/flytekit/unit/core/test_type_engine.py
+++ b/tests/flytekit/unit/core/test_type_engine.py
@@ -1,7 +1,7 @@
 import datetime
 import os
 import typing
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from datetime import timedelta
 from enum import Enum
 
@@ -21,6 +21,7 @@
     SimpleTransformer,
     TypeEngine,
     convert_json_schema_to_python_class,
+    dataclass_from_dict,
 )
 from flytekit.models import types as model_types
 from flytekit.models.core.types import BlobType
@@ -118,12 +119,21 @@ def test_list_of_dict_getting_python_value():
 def test_list_of_dataclass_getting_python_value():
     @dataclass_json
     @dataclass()
-    class Foo(object):
-        x: int
+    class Bar(object):
+        w: typing.Optional[str]
+        x: float
         y: str
-        z: typing.Dict[int, str]
+        z: typing.Dict[str, bool]
+
+    @dataclass_json
+    @dataclass()
+    class Foo(object):
+        w: int
+        x: typing.List[int]
+        y: typing.Dict[str, str]
+        z: Bar
 
-    foo = Foo(x=1, y="boo", z={3: "10"})
+    foo = Foo(w=1, x=[1], y={"hello": "10"}, z=Bar(w=None, x=1.0, y="hello", z={"world": False}))
     generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct())
     lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))]))
 
@@ -131,10 +141,18 @@ class Foo(object):
     ctx = FlyteContext.current_context()
 
     schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema())
-    foo_class = convert_json_schema_to_python_class(schema)
+    foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema")
 
     pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class])
     assert isinstance(pv, list)
+    assert pv[0].w == foo.w
+    assert pv[0].x == foo.x
+    assert pv[0].y == foo.y
+    assert pv[0].z.x == foo.z.x
+    assert type(pv[0].z.x) == float
+    assert pv[0].z.y == foo.z.y
+    assert pv[0].z.z == foo.z.z
+    assert foo == dataclass_from_dict(Foo, asdict(pv[0]))
 
 
 def test_file_non_downloadable():
@@ -257,7 +275,7 @@ class Foo(object):
         y: str
 
     schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema())
-    foo_class = convert_json_schema_to_python_class(schema)
+    foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema")
     foo = foo_class(x=1, y="hello")
     foo.x = 2
     assert foo.x == 2
@@ -563,6 +581,25 @@ def test_enum_type():
                 }
             ),
         ),
+        (
+            {"p1": TestStructD(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": [5]})},
+            {"p1": TestStructD},
+            LiteralMap(
+                literals={
+                    "p1": Literal(
+                        scalar=Scalar(
+                            generic=_json_format.Parse(
+                                typing.cast(
+                                    DataClassJsonMixin,
+                                    TestStructD(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": [5]}),
+                                ).to_json(),
+                                _struct.Struct(),
+                            )
+                        )
+                    )
+                }
+            ),
+        ),
     ],
 )
 def test_dict_to_literal_map(python_value, python_types, expected_literal_map):