From 81eedf37829e854ae02a466e7b294da591e01526 Mon Sep 17 00:00:00 2001
From: Kevin Su <pingsutw@apache.org>
Date: Fri, 9 Dec 2022 09:15:13 +0800
Subject: [PATCH] Set default format of structured dataset to empty (#1159)

* Set default format of structured dataset to empty

Signed-off-by: Kevin Su <pingsutw@apache.org>

* Fix tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* Fix tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* lint

Signed-off-by: Kevin Su <pingsutw@apache.org>

* last error (#1364)

Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
Signed-off-by: Kevin Su <pingsutw@apache.org>

Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
Co-authored-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
---
 flytekit/core/type_engine.py                  |   7 +-
 flytekit/types/structured/basic_dfs.py        |   8 +-
 .../types/structured/structured_dataset.py    | 136 ++++++++++++------
 tests/flytekit/unit/core/test_flyte_pickle.py |   2 +-
 tests/flytekit/unit/core/test_imperative.py   |   5 +-
 .../unit/core/test_structured_dataset.py      |  81 +++++++++--
 6 files changed, 172 insertions(+), 67 deletions(-)

diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 98969e41b3..38d78b5fcb 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -388,7 +388,9 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A
         if issubclass(python_type, FlyteFile) or issubclass(python_type, FlyteDirectory):
             return python_type(path=lv.scalar.blob.uri)
         elif issubclass(python_type, StructuredDataset):
-            return python_type(uri=lv.scalar.structured_dataset.uri)
+            sd = python_type(uri=lv.scalar.structured_dataset.uri)
+            sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
+            return sd
         else:
             return python_val
     else:
@@ -534,7 +536,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
                 f"serialized correctly"
             )

-        dc = cast(DataClassJsonMixin, expected_python_type).from_json(_json_format.MessageToJson(lv.scalar.generic))
+        json_str = _json_format.MessageToJson(lv.scalar.generic)
+        dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str)
         return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type))

     # This ensures that calls with the same literal type returns the same dataclass. For example, `pyflyte run``
diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py
index 71dff61c5e..39f8d11e24 100644
--- a/flytekit/types/structured/basic_dfs.py
+++ b/flytekit/types/structured/basic_dfs.py
@@ -101,10 +101,10 @@ def decode(
         return pq.read_table(local_dir)


-StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler())
-StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler())
-StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler())
-StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler())
+StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True)
+StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True)
+StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True)
+StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True)
 StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer())
 StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer())

diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py
index 99fcb49d7b..f0fd917340 100644
--- a/flytekit/types/structured/structured_dataset.py
+++ b/flytekit/types/structured/structured_dataset.py
@@ -34,6 +34,7 @@

 # Storage formats
 PARQUET: StructuredDatasetFormat = "parquet"
+GENERIC_FORMAT: StructuredDatasetFormat = ""


 @dataclass_json
@@ -45,9 +46,7 @@ class (that is just a model, a Python class representation of the protobuf).
     """

     uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String()))
-    file_format: typing.Optional[str] = field(default=PARQUET, metadata=config(mm_field=fields.String()))
-
-    DEFAULT_FILE_FORMAT = PARQUET
+    file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String()))

     @classmethod
     def columns(cls) -> typing.Dict[str, typing.Type]:
@@ -68,6 +67,8 @@ def __init__(
         # Make these fields public, so that the dataclass transformer can set a value for it
         # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
         self.uri = uri
+        # When dataclass_json runs from_json, we need to set it here, otherwise the format will be empty string
+        self.file_format = kwargs["file_format"] if "file_format" in kwargs else GENERIC_FORMAT
         # This is a special attribute that indicates if the data was either downloaded or uploaded
         self._metadata = metadata
         # This is not for users to set, the transformer will set this.
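# ---------------------------------------------------------------------------
# Illustration only, not part of the patch: what the empty default format
# means for declared literal types. This mirrors test_types_pandas further
# down and assumes the pandas handlers registered above are available.
import pandas as pd
from typing_extensions import Annotated

from flytekit.core.type_engine import TypeEngine

# An unannotated DataFrame now advertises the generic ("") format...
lt = TypeEngine.to_literal_type(pd.DataFrame)
assert lt.structured_dataset_type.format == ""

# ...while an explicit format annotation still takes precedence.
lt = TypeEngine.to_literal_type(Annotated[pd.DataFrame, "csv"])
assert lt.structured_dataset_type.format == "csv"
# ---------------------------------------------------------------------------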
@@ -128,14 +129,14 @@ def extract_cols_and_format(
     optional str for the format, optional pyarrow Schema
     """

-    fmt = None
+    fmt = ""
     ordered_dict_cols = None
     pa_schema = None
     if get_origin(t) is Annotated:
         base_type, *annotate_args = get_args(t)
         for aa in annotate_args:
             if isinstance(aa, StructuredDatasetFormat):
-                if fmt is not None:
+                if fmt != "":
                     raise ValueError(f"A format was already specified {fmt}, cannot use {aa}")
                 fmt = aa
             elif isinstance(aa, collections.OrderedDict):
@@ -334,21 +335,44 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):
     Handlers = Union[StructuredDatasetEncoder, StructuredDatasetDecoder]
     Renderers: Dict[Type, Renderable] = {}

-    @staticmethod
-    def _finder(handler_map, df_type: Type, protocol: str, format: str):
-        try:
-            return handler_map[df_type][protocol][format]
-        except KeyError:
+    @classmethod
+    def _finder(cls, handler_map, df_type: Type, protocol: str, format: str):
+        # If the incoming format requested is a specific format (e.g. "avro"), then look for that specific handler
+        # if missing, see if there's a generic format handler. Error if missing.
+        # If the incoming format requested is the generic format (""), then see if it's present,
+        # if not, look to see if there is a default format for the df_type and a handler for that format.
+        # if still missing, look to see if there's only _one_ handler for that type, if so then use that.
+        if format != GENERIC_FORMAT:
             try:
-                hh = handler_map[df_type][protocol][""]
-                logger.info(
-                    f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}"
-                    f" format {format}, using default instead."
-                )
-                return hh
+                return handler_map[df_type][protocol][format]
+            except KeyError:
+                try:
+                    return handler_map[df_type][protocol][GENERIC_FORMAT]
+                except KeyError:
+                    ...
+        else:
+            try:
+                return handler_map[df_type][protocol][GENERIC_FORMAT]
             except KeyError:
-                ...
-        raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt {format}")
+                if df_type in cls.DEFAULT_FORMATS and cls.DEFAULT_FORMATS[df_type] in handler_map[df_type][protocol]:
+                    hh = handler_map[df_type][protocol][cls.DEFAULT_FORMATS[df_type]]
+                    logger.debug(
+                        f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}"
+                        f" using the generic handler {hh} instead."
+                    )
+                    return hh
+                if len(handler_map[df_type][protocol]) == 1:
+                    hh = list(handler_map[df_type][protocol].values())[0]
+                    logger.debug(
+                        f"Using {hh} with format {hh.supported_format} as it's the only one available for {df_type}"
+                    )
+                    return hh
+                else:
+                    logger.warning(
+                        f"Did not automatically pick a handler for {df_type},"
+                        f" more than one detected {handler_map[df_type][protocol].keys()}"
+                    )
+        raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt |{format}|")

     @classmethod
     def get_encoder(cls, df_type: Type, protocol: str, format: str):
@@ -381,7 +405,14 @@ def register_renderer(cls, python_type: Type, renderer: Renderable):
         cls.Renderers[python_type] = renderer

     @classmethod
-    def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False):
+    def register(
+        cls,
+        h: Handlers,
+        default_for_type: bool = False,
+        override: bool = False,
+        default_format_for_type: bool = False,
+        default_storage_for_type: bool = False,
+    ):
         """
         Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not
         specify a protocol (e.g. s3, gs, etc.) field, then
@@ -395,6 +426,10 @@ def register(cls, h: Handlers, default_for_type: Optional[bool] = False, overrid
         In these cases, the protocol is determined by the raw output data prefix set in the active context.
         :param override: Override any previous registrations. If default_for_type is also set, this will also
           override the default.
+        :param default_format_for_type: Unlike the default_for_type arg that will set this handler's format and storage
+          as the default, this will only set the format. Error if already set, unless override is specified.
+        :param default_storage_for_type: Same as above but only for the storage format. Error if already set,
+          unless override is specified.
         """
         if not (isinstance(h, StructuredDatasetEncoder) or isinstance(h, StructuredDatasetDecoder)):
             raise TypeError(f"We don't support this type of handler {h}")
@@ -409,17 +444,29 @@ def register(cls, h: Handlers, default_for_type: Optional[bool] = False, overrid
                 stripped = DataPersistencePlugins.get_protocol(persistence_protocol)
                 logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}")
                 try:
-                    cls.register_for_protocol(h, stripped, False, override)
+                    cls.register_for_protocol(
+                        h, stripped, False, override, default_format_for_type, default_storage_for_type
+                    )
                 except DuplicateHandlerError:
                     logger.debug(f"Skipping {persistence_protocol}/{stripped} for {h} because duplicate")
         elif h.protocol == "":
             raise ValueError(f"Use None instead of empty string for registering handler {h}")
         else:
-            cls.register_for_protocol(h, h.protocol, default_for_type, override)
+            cls.register_for_protocol(
+                h, h.protocol, default_for_type, override, default_format_for_type, default_storage_for_type
+            )

     @classmethod
-    def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: bool, override: bool):
+    def register_for_protocol(
+        cls,
+        h: Handlers,
+        protocol: str,
+        default_for_type: bool,
+        override: bool,
+        default_format_for_type: bool,
+        default_storage_for_type: bool,
+    ):
         """
         See the main register function instead.
         """

@@ -434,12 +481,24 @@ def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: boo
         lowest_level[h.supported_format] = h
         logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}")

-        if default_for_type:
-            logger.debug(
-                f"Using storage {protocol} and format {h.supported_format} for dataframes of type {h.python_type} from handler {h}"
-            )
-            cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
-            cls.DEFAULT_PROTOCOLS[h.python_type] = protocol
+        if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT:
+            if h.python_type in cls.DEFAULT_FORMATS and not override:
+                logger.warning(
+                    f"Not using handler {h} with format {h.supported_format} as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified."
+                )
+            else:
+                logger.debug(
+                    f"Setting format {h.supported_format} for dataframes of type {h.python_type} from handler {h}"
+                )
+                cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
+        if default_storage_for_type or default_for_type:
+            if h.protocol in cls.DEFAULT_PROTOCOLS and not override:
+                logger.warning(
+                    f"Not using handler {h} with storage protocol {h.protocol} as default for {h.python_type}, {cls.DEFAULT_PROTOCOLS[h.python_type]} already specified."
+                )
+            else:
+                logger.debug(f"Using storage {protocol} for dataframes of type {h.python_type} from handler {h}")
+                cls.DEFAULT_PROTOCOLS[h.python_type] = protocol

         # Register with the type engine as well
         # The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as
@@ -461,7 +520,7 @@ def to_literal(

         # Check first to see if it's even an SD type. For backwards compatibility, we may be getting a FlyteSchema
         python_type, *attrs = extract_cols_and_format(python_type)  # In case it's a FlyteSchema
-        sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, None))
+        sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT))

         if expected and expected.structured_dataset_type:
             sdt = StructuredDatasetType(
@@ -514,16 +573,12 @@ def to_literal(
                 python_val,
                 df_type,
                 protocol,
-                sdt.format or typing.cast(StructuredDataset, python_val).DEFAULT_FILE_FORMAT,
+                sdt.format,
                 sdt,
             )

         # Otherwise assume it's a dataframe instance. Wrap it with some defaults
-        if python_type in self.DEFAULT_FORMATS:
-            fmt = self.DEFAULT_FORMATS[python_type]
-        else:
-            logger.debug(f"No default format for type {python_type}, using system default.")
-            fmt = StructuredDataset.DEFAULT_FILE_FORMAT
+        fmt = self.DEFAULT_FORMATS.get(python_type, "")
         protocol = self._protocol_from_type_or_prefix(ctx, python_type)
         meta = StructuredDatasetMetadata(structured_dataset_type=expected.structured_dataset_type if expected else None)

@@ -760,18 +815,9 @@ def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]

         # Get the column information
         converted_cols = self._convert_ordered_dict_of_columns_to_list(column_map)
-
-        # Get the format
-        default_format = (
-            original_python_type.DEFAULT_FILE_FORMAT
-            if issubclass(original_python_type, StructuredDataset)
-            else self.DEFAULT_FORMATS.get(original_python_type, PARQUET)
-        )
-        fmt = storage_format or default_format
-
         return StructuredDatasetType(
             columns=converted_cols,
-            format=fmt,
+            format=storage_format,
             external_schema_type="arrow" if pa_schema else None,
             external_schema_bytes=typing.cast(pa.lib.Schema, pa_schema).to_string().encode() if pa_schema else None,
         )
diff --git a/tests/flytekit/unit/core/test_flyte_pickle.py b/tests/flytekit/unit/core/test_flyte_pickle.py
index 318a6b76f3..7ceec809b1 100644
--- a/tests/flytekit/unit/core/test_flyte_pickle.py
+++ b/tests/flytekit/unit/core/test_flyte_pickle.py
@@ -95,5 +95,5 @@ def t1(data: Annotated[Union[np.ndarray, pd.DataFrame, Sequence], "some annotati
     task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
     variants = task_spec.template.interface.inputs["data"].type.union_type.variants
     assert variants[0].blob.format == "NumpyArray"
-    assert variants[1].structured_dataset_type.format == "parquet"
+    assert variants[1].structured_dataset_type.format == ""
     assert variants[2].blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py
index ab90d991b1..db4b32f6a9 100644
--- a/tests/flytekit/unit/core/test_imperative.py
+++ b/tests/flytekit/unit/core/test_imperative.py
@@ -16,7 +16,6 @@
 from flytekit.tools.translator import get_serializable
 from flytekit.types.file import FlyteFile
 from flytekit.types.schema import FlyteSchema
-from flytekit.types.structured.structured_dataset import StructuredDatasetType

 default_img = Image(name="default", fqn="test", tag="tag")
 serialization_settings = flytekit.configuration.SerializationSettings(
@@ -373,6 +372,4 @@ def ref_t2(

     assert len(wf_spec.template.interface.outputs) == 1
     assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type is not None
-    assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type == StructuredDatasetType(
-        format="parquet"
-    )
+    assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type.format == ""
diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py
index 7793df430f..bfb41d0fef 100644
--- a/tests/flytekit/unit/core/test_structured_dataset.py
+++ b/tests/flytekit/unit/core/test_structured_dataset.py
@@ -7,11 +7,13 @@
 from typing_extensions import Annotated

 import flytekit.configuration
-from flytekit import kwtypes, task
 from flytekit.configuration import Image, ImageConfig
-from flytekit.core.context_manager import FlyteContext, FlyteContextManager
+from flytekit.core.base_task import kwtypes
+from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
 from flytekit.core.data_persistence import FileAccessProvider
+from flytekit.core.task import task
 from flytekit.core.type_engine import TypeEngine
+from flytekit.core.workflow import workflow
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import SchemaType, SimpleType, StructuredDatasetType
@@ -38,6 +40,7 @@
     image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
     env={},
 )
+df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})


 def test_protocol():
@@ -49,13 +52,67 @@ def generate_pandas() -> pd.DataFrame:
     return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})


+def test_formats_make_sense():
+    @task
+    def t1(a: pd.DataFrame) -> pd.DataFrame:
+        print(a)
+        return generate_pandas()
+
+    # this should be an empty string format
+    assert t1.interface.outputs["o0"].type.structured_dataset_type.format == ""
+    assert t1.interface.inputs["a"].type.structured_dataset_type.format == ""
+
+    ctx = FlyteContextManager.current_context()
+    with FlyteContextManager.with_context(
+        ctx.with_execution_state(
+            ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION)
+        )
+    ):
+        result = t1(a=generate_pandas())
+        val = result.val.scalar.value
+        assert val.metadata.structured_dataset_type.format == "parquet"
+
+
+def test_setting_of_unset_formats():
+
+    custom = Annotated[StructuredDataset, "parquet"]
+    example = custom(dataframe=df, uri="/path")
+    # It's okay that the annotation is not used here yet.
+    assert example.file_format == ""
+
+    @task
+    def t2(path: str) -> StructuredDataset:
+        sd = StructuredDataset(dataframe=df, uri=path)
+        return sd
+
+    @workflow
+    def wf(path: str) -> StructuredDataset:
+        return t2(path=path)
+
+    res = wf(path="/tmp/somewhere")
+    # Now that it's passed through an encoder however, it should be set.
+    assert res.file_format == "parquet"
+
+
+def test_json():
+    sd = StructuredDataset(dataframe=df, uri="/some/path")
+    sd.file_format = "myformat"
+    json_str = sd.to_json()
+    new_sd = StructuredDataset.from_json(json_str)
+    assert new_sd.file_format == "myformat"
+
+
 def test_types_pandas():
     pt = pd.DataFrame
     lt = TypeEngine.to_literal_type(pt)
     assert lt.structured_dataset_type is not None
-    assert lt.structured_dataset_type.format == PARQUET
+    assert lt.structured_dataset_type.format == ""
     assert lt.structured_dataset_type.columns == []

+    pt = Annotated[pd.DataFrame, "csv"]
+    lt = TypeEngine.to_literal_type(pt)
+    assert lt.structured_dataset_type.format == "csv"
+

 def test_annotate_extraction():
     xyz = Annotated[pd.DataFrame, "myformat"]
@@ -68,7 +125,7 @@ def test_annotate_extraction():

     a, b, c, d = extract_cols_and_format(pd.DataFrame)
     assert a is pd.DataFrame
     assert b is None
-    assert c is None
+    assert c == ""
     assert d is None

@@ -115,9 +172,10 @@ def test_types_sd():

 def test_retrieving():
     assert StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET) is not None
-    with pytest.raises(ValueError):
-        # We don't have a default "" format encoder
-        StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", "")
+    # Asking for a generic means you're okay with any one registered for that type assuming there's just one.
+    assert StructuredDatasetTransformerEngine.get_encoder(
+        pd.DataFrame, "file", ""
+    ) is StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET)

     class TempEncoder(StructuredDatasetEncoder):
         def __init__(self, protocol):
@@ -188,9 +246,10 @@ def encode(
     ) -> literals.StructuredDataset:
         return literals.StructuredDataset(uri="")

-    StructuredDatasetTransformerEngine.register(TempEncoder("myavro"), default_for_type=True)
+    default_encoder = TempEncoder("myavro")
+    StructuredDatasetTransformerEngine.register(default_encoder, default_for_type=True)
     lt = TypeEngine.to_literal_type(MyDF)
-    assert lt.structured_dataset_type.format == "myavro"
+    assert lt.structured_dataset_type.format == ""

     ctx = FlyteContextManager.current_context()
     fdt = StructuredDatasetTransformerEngine()
@@ -228,7 +287,7 @@ def encode(
 def test_sd():
     sd = StructuredDataset(dataframe="hi")
     sd.uri = "my uri"
-    assert sd.file_format == PARQUET
+    assert sd.file_format == ""

     with pytest.raises(ValueError, match="No dataframe type set"):
         sd.all()
@@ -383,7 +442,7 @@ def encode(
     assert df_literal_type.structured_dataset_type.format == "avro"

     sd = annotated_sd_type(df)
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Failed to find a handler"):
         TypeEngine.to_literal(ctx, sd, python_type=annotated_sd_type, expected=df_literal_type)

     StructuredDatasetTransformerEngine.register(TempEncoder(), default_for_type=False)
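Taken together, the fallback chain that the new _finder implements can be exercised
directly through get_encoder. A minimal sketch follows (illustration only, not part of
the patch; it mirrors test_retrieving above and assumes the built-in pandas handlers
from basic_dfs.py are registered):

import pandas as pd

from flytekit.types.structured.structured_dataset import (
    PARQUET,
    StructuredDatasetTransformerEngine,
)

# A concrete format returns its format-specific handler.
parquet_encoder = StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET)

# The generic format ("") falls back to the type's registered default format
# (parquet for pandas), or to the sole registered handler if exactly one exists;
# if several candidates remain, it raises ValueError("Failed to find a handler ...").
generic_encoder = StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", "")
assert generic_encoder is parquet_encoder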