From 81eedf37829e854ae02a466e7b294da591e01526 Mon Sep 17 00:00:00 2001
From: Kevin Su <pingsutw@apache.org>
Date: Fri, 9 Dec 2022 09:15:13 +0800
Subject: [PATCH] Set default format of structured dataset to empty (#1159)

* Set default format of structured dataset to empty

Signed-off-by: Kevin Su <pingsutw@apache.org>

* Fix tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* Fix tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* lint

Signed-off-by: Kevin Su <pingsutw@apache.org>

* last error (#1364)

Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
Signed-off-by: Kevin Su <pingsutw@apache.org>

Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
Co-authored-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
---
 flytekit/core/type_engine.py                  |   7 +-
 flytekit/types/structured/basic_dfs.py        |   8 +-
 .../types/structured/structured_dataset.py    | 136 ++++++++++++------
 tests/flytekit/unit/core/test_flyte_pickle.py |   2 +-
 tests/flytekit/unit/core/test_imperative.py   |   5 +-
 .../unit/core/test_structured_dataset.py      |  81 +++++++++--
 6 files changed, 172 insertions(+), 67 deletions(-)

diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 98969e41b3..38d78b5fcb 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -388,7 +388,9 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A
         if issubclass(python_type, FlyteFile) or issubclass(python_type, FlyteDirectory):
             return python_type(path=lv.scalar.blob.uri)
         elif issubclass(python_type, StructuredDataset):
-            return python_type(uri=lv.scalar.structured_dataset.uri)
+            sd = python_type(uri=lv.scalar.structured_dataset.uri)
+            sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
+            return sd
         else:
             return python_val
     else:
@@ -534,7 +536,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
                 f"serialized correctly"
             )

-        dc = cast(DataClassJsonMixin, expected_python_type).from_json(_json_format.MessageToJson(lv.scalar.generic))
+        json_str = _json_format.MessageToJson(lv.scalar.generic)
+        dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str)
         return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type))

     # This ensures that calls with the same literal type returns the same dataclass. For example, `pyflyte run``
diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py
index 71dff61c5e..39f8d11e24 100644
--- a/flytekit/types/structured/basic_dfs.py
+++ b/flytekit/types/structured/basic_dfs.py
@@ -101,10 +101,10 @@ def decode(
         return pq.read_table(local_dir)


-StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler())
-StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler())
-StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler())
-StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler())
+StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True)
+StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True)
+StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True)
+StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True)
 StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer())
 StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer())

diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py
index 99fcb49d7b..f0fd917340 100644
--- a/flytekit/types/structured/structured_dataset.py
+++ b/flytekit/types/structured/structured_dataset.py
@@ -34,6 +34,7 @@

 # Storage formats
 PARQUET: StructuredDatasetFormat = "parquet"
+GENERIC_FORMAT: StructuredDatasetFormat = ""


 @dataclass_json
@@ -45,9 +46,7 @@ class (that is just a model, a Python class representation of the protobuf).
     """

     uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String()))
-    file_format: typing.Optional[str] = field(default=PARQUET, metadata=config(mm_field=fields.String()))
-
-    DEFAULT_FILE_FORMAT = PARQUET
+    file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String()))

     @classmethod
     def columns(cls) -> typing.Dict[str, typing.Type]:
@@ -68,6 +67,8 @@ def __init__(
         # Make these fields public, so that the dataclass transformer can set a value for it
         # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
         self.uri = uri
+        # When dataclass_json runs from_json, we need to set it here, otherwise the format will be empty string
+        self.file_format = kwargs["file_format"] if "file_format" in kwargs else GENERIC_FORMAT
         # This is a special attribute that indicates if the data was either downloaded or uploaded
         self._metadata = metadata
         # This is not for users to set, the transformer will set this.
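# ---------------------------------------------------------------------------
# Illustration only, not part of the patch: what the empty default format
# means for declared literal types. This mirrors test_types_pandas further
# down and assumes the pandas handlers registered above are available.
import pandas as pd
from typing_extensions import Annotated

from flytekit.core.type_engine import TypeEngine

# An unannotated DataFrame now advertises the generic ("") format...
lt = TypeEngine.to_literal_type(pd.DataFrame)
assert lt.structured_dataset_type.format == ""

# ...while an explicit format annotation still takes precedence.
lt = TypeEngine.to_literal_type(Annotated[pd.DataFrame, "csv"])
assert lt.structured_dataset_type.format == "csv"
# ---------------------------------------------------------------------------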
@@ -128,14 +129,14 @@ def extract_cols_and_format(
     optional str for the format, optional pyarrow Schema
     """

-    fmt = None
+    fmt = ""
     ordered_dict_cols = None
     pa_schema = None
     if get_origin(t) is Annotated:
         base_type, *annotate_args = get_args(t)
         for aa in annotate_args:
             if isinstance(aa, StructuredDatasetFormat):
-                if fmt is not None:
+                if fmt != "":
                     raise ValueError(f"A format was already specified {fmt}, cannot use {aa}")
                 fmt = aa
             elif isinstance(aa, collections.OrderedDict):
@@ -334,21 +335,44 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):
     Handlers = Union[StructuredDatasetEncoder, StructuredDatasetDecoder]
     Renderers: Dict[Type, Renderable] = {}

-    @staticmethod
-    def _finder(handler_map, df_type: Type, protocol: str, format: str):
-        try:
-            return handler_map[df_type][protocol][format]
-        except KeyError:
+    @classmethod
+    def _finder(cls, handler_map, df_type: Type, protocol: str, format: str):
+        # If the incoming format requested is a specific format (e.g. "avro"), then look for that specific handler
+        # if missing, see if there's a generic format handler. Error if missing.
+        # If the incoming format requested is the generic format (""), then see if it's present,
+        # if not, look to see if there is a default format for the df_type and a handler for that format.
+        # if still missing, look to see if there's only _one_ handler for that type, if so then use that.
+        if format != GENERIC_FORMAT:
             try:
-                hh = handler_map[df_type][protocol][""]
-                logger.info(
-                    f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}"
-                    f" format {format}, using default instead."
-                )
-                return hh
+                return handler_map[df_type][protocol][format]
+            except KeyError:
+                try:
+                    return handler_map[df_type][protocol][GENERIC_FORMAT]
+                except KeyError:
+                    ...
+        else:
+            try:
+                return handler_map[df_type][protocol][GENERIC_FORMAT]
             except KeyError:
-                ...
-        raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt {format}")
+                if df_type in cls.DEFAULT_FORMATS and cls.DEFAULT_FORMATS[df_type] in handler_map[df_type][protocol]:
+                    hh = handler_map[df_type][protocol][cls.DEFAULT_FORMATS[df_type]]
+                    logger.debug(
+                        f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}"
+                        f" using the generic handler {hh} instead."
+                    )
+                    return hh
+                if len(handler_map[df_type][protocol]) == 1:
+                    hh = list(handler_map[df_type][protocol].values())[0]
+                    logger.debug(
+                        f"Using {hh} with format {hh.supported_format} as it's the only one available for {df_type}"
+                    )
+                    return hh
+                else:
+                    logger.warning(
+                        f"Did not automatically pick a handler for {df_type},"
+                        f" more than one detected {handler_map[df_type][protocol].keys()}"
+                    )
+        raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt |{format}|")

     @classmethod
     def get_encoder(cls, df_type: Type, protocol: str, format: str):
@@ -381,7 +405,14 @@ def register_renderer(cls, python_type: Type, renderer: Renderable):
         cls.Renderers[python_type] = renderer

     @classmethod
-    def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False):
+    def register(
+        cls,
+        h: Handlers,
+        default_for_type: bool = False,
+        override: bool = False,
+        default_format_for_type: bool = False,
+        default_storage_for_type: bool = False,
+    ):
         """
         Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not
         specify a protocol (e.g. s3, gs, etc.) field, then
@@ -395,6 +426,10 @@ def register(cls, h: Handlers, default_for_type: Optional[bool] = False, overrid
         In these cases, the protocol is determined by the raw output data prefix set in the active context.
         :param override: Override any previous registrations. If default_for_type is also set, this will also
           override the default.
+        :param default_format_for_type: Unlike the default_for_type arg that will set this handler's format and storage
+          as the default, this will only set the format. Error if already set, unless override is specified.
+        :param default_storage_for_type: Same as above but only for the storage format. Error if already set,
+          unless override is specified.
         """
         if not (isinstance(h, StructuredDatasetEncoder) or isinstance(h, StructuredDatasetDecoder)):
             raise TypeError(f"We don't support this type of handler {h}")
@@ -409,17 +444,29 @@ def register(cls, h: Handlers, default_for_type: Optional[bool] = False, overrid
                 stripped = DataPersistencePlugins.get_protocol(persistence_protocol)
                 logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}")
                 try:
-                    cls.register_for_protocol(h, stripped, False, override)
+                    cls.register_for_protocol(
+                        h, stripped, False, override, default_format_for_type, default_storage_for_type
+                    )
                 except DuplicateHandlerError:
                     logger.debug(f"Skipping {persistence_protocol}/{stripped} for {h} because duplicate")
         elif h.protocol == "":
             raise ValueError(f"Use None instead of empty string for registering handler {h}")
         else:
-            cls.register_for_protocol(h, h.protocol, default_for_type, override)
+            cls.register_for_protocol(
+                h, h.protocol, default_for_type, override, default_format_for_type, default_storage_for_type
+            )

     @classmethod
-    def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: bool, override: bool):
+    def register_for_protocol(
+        cls,
+        h: Handlers,
+        protocol: str,
+        default_for_type: bool,
+        override: bool,
+        default_format_for_type: bool,
+        default_storage_for_type: bool,
+    ):
         """
         See the main register function instead.
         """

@@ -434,12 +481,24 @@ def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: boo
         lowest_level[h.supported_format] = h
         logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}")

-        if default_for_type:
-            logger.debug(
-                f"Using storage {protocol} and format {h.supported_format} for dataframes of type {h.python_type} from handler {h}"
-            )
-            cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
-            cls.DEFAULT_PROTOCOLS[h.python_type] = protocol
+        if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT:
+            if h.python_type in cls.DEFAULT_FORMATS and not override:
+                logger.warning(
+                    f"Not using handler {h} with format {h.supported_format} as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified."
+                )
+            else:
+                logger.debug(
+                    f"Setting format {h.supported_format} for dataframes of type {h.python_type} from handler {h}"
+                )
+                cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
+        if default_storage_for_type or default_for_type:
+            if h.protocol in cls.DEFAULT_PROTOCOLS and not override:
+                logger.warning(
+                    f"Not using handler {h} with storage protocol {h.protocol} as default for {h.python_type}, {cls.DEFAULT_PROTOCOLS[h.python_type]} already specified."
+                )
+            else:
+                logger.debug(f"Using storage {protocol} for dataframes of type {h.python_type} from handler {h}")
+                cls.DEFAULT_PROTOCOLS[h.python_type] = protocol

         # Register with the type engine as well
         # The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as
@@ -461,7 +520,7 @@ def to_literal(

         # Check first to see if it's even an SD type. For backwards compatibility, we may be getting a FlyteSchema
         python_type, *attrs = extract_cols_and_format(python_type)  # In case it's a FlyteSchema
-        sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, None))
+        sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT))

         if expected and expected.structured_dataset_type:
             sdt = StructuredDatasetType(
@@ -514,16 +573,12 @@ def to_literal(
                 python_val,
                 df_type,
                 protocol,
-                sdt.format or typing.cast(StructuredDataset, python_val).DEFAULT_FILE_FORMAT,
+                sdt.format,
                 sdt,
             )

         # Otherwise assume it's a dataframe instance. Wrap it with some defaults
-        if python_type in self.DEFAULT_FORMATS:
-            fmt = self.DEFAULT_FORMATS[python_type]
-        else:
-            logger.debug(f"No default format for type {python_type}, using system default.")
-            fmt = StructuredDataset.DEFAULT_FILE_FORMAT
+        fmt = self.DEFAULT_FORMATS.get(python_type, "")
         protocol = self._protocol_from_type_or_prefix(ctx, python_type)
         meta = StructuredDatasetMetadata(structured_dataset_type=expected.structured_dataset_type if expected else None)

@@ -760,18 +815,9 @@ def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]

         # Get the column information
         converted_cols = self._convert_ordered_dict_of_columns_to_list(column_map)
-
-        # Get the format
-        default_format = (
-            original_python_type.DEFAULT_FILE_FORMAT
-            if issubclass(original_python_type, StructuredDataset)
-            else self.DEFAULT_FORMATS.get(original_python_type, PARQUET)
-        )
-        fmt = storage_format or default_format
-
         return StructuredDatasetType(
             columns=converted_cols,
-            format=fmt,
+            format=storage_format,
             external_schema_type="arrow" if pa_schema else None,
             external_schema_bytes=typing.cast(pa.lib.Schema, pa_schema).to_string().encode() if pa_schema else None,
         )
diff --git a/tests/flytekit/unit/core/test_flyte_pickle.py b/tests/flytekit/unit/core/test_flyte_pickle.py
index 318a6b76f3..7ceec809b1 100644
--- a/tests/flytekit/unit/core/test_flyte_pickle.py
+++ b/tests/flytekit/unit/core/test_flyte_pickle.py
@@ -95,5 +95,5 @@ def t1(data: Annotated[Union[np.ndarray, pd.DataFrame, Sequence], "some annotati
     task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
     variants = task_spec.template.interface.inputs["data"].type.union_type.variants
     assert variants[0].blob.format == "NumpyArray"
-    assert variants[1].structured_dataset_type.format == "parquet"
+    assert variants[1].structured_dataset_type.format == ""
     assert variants[2].blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py
index ab90d991b1..db4b32f6a9 100644
--- a/tests/flytekit/unit/core/test_imperative.py
+++ b/tests/flytekit/unit/core/test_imperative.py
@@ -16,7 +16,6 @@
 from flytekit.tools.translator import get_serializable
 from flytekit.types.file import FlyteFile
 from flytekit.types.schema import FlyteSchema
-from flytekit.types.structured.structured_dataset import StructuredDatasetType

 default_img = Image(name="default", fqn="test", tag="tag")
 serialization_settings = flytekit.configuration.SerializationSettings(
@@ -373,6 +372,4 @@ def ref_t2(

     assert len(wf_spec.template.interface.outputs) == 1
     assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type is not None
-    assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type == StructuredDatasetType(
-        format="parquet"
-    )
+    assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type.format == ""
diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py
index 7793df430f..bfb41d0fef 100644
--- a/tests/flytekit/unit/core/test_structured_dataset.py
+++ b/tests/flytekit/unit/core/test_structured_dataset.py
@@ -7,11 +7,13 @@
 from typing_extensions import Annotated

 import flytekit.configuration
-from flytekit import kwtypes, task
 from flytekit.configuration import Image, ImageConfig
-from flytekit.core.context_manager import FlyteContext, FlyteContextManager
+from flytekit.core.base_task import kwtypes
+from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
 from flytekit.core.data_persistence import FileAccessProvider
+from flytekit.core.task import task
 from flytekit.core.type_engine import TypeEngine
+from flytekit.core.workflow import workflow
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import SchemaType, SimpleType, StructuredDatasetType
@@ -38,6 +40,7 @@
     image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
     env={},
 )
+df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})


 def test_protocol():
@@ -49,13 +52,67 @@ def generate_pandas() -> pd.DataFrame:
     return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})


+def test_formats_make_sense():
+    @task
+    def t1(a: pd.DataFrame) -> pd.DataFrame:
+        print(a)
+        return generate_pandas()
+
+    # this should be an empty string format
+    assert t1.interface.outputs["o0"].type.structured_dataset_type.format == ""
+    assert t1.interface.inputs["a"].type.structured_dataset_type.format == ""
+
+    ctx = FlyteContextManager.current_context()
+    with FlyteContextManager.with_context(
+        ctx.with_execution_state(
+            ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION)
+        )
+    ):
+        result = t1(a=generate_pandas())
+        val = result.val.scalar.value
+        assert val.metadata.structured_dataset_type.format == "parquet"
+
+
+def test_setting_of_unset_formats():
+
+    custom = Annotated[StructuredDataset, "parquet"]
+    example = custom(dataframe=df, uri="/path")
+    # It's okay that the annotation is not used here yet.
+    assert example.file_format == ""
+
+    @task
+    def t2(path: str) -> StructuredDataset:
+        sd = StructuredDataset(dataframe=df, uri=path)
+        return sd
+
+    @workflow
+    def wf(path: str) -> StructuredDataset:
+        return t2(path=path)
+
+    res = wf(path="/tmp/somewhere")
+    # Now that it's passed through an encoder however, it should be set.
+    assert res.file_format == "parquet"
+
+
+def test_json():
+    sd = StructuredDataset(dataframe=df, uri="/some/path")
+    sd.file_format = "myformat"
+    json_str = sd.to_json()
+    new_sd = StructuredDataset.from_json(json_str)
+    assert new_sd.file_format == "myformat"
+
+
 def test_types_pandas():
     pt = pd.DataFrame
     lt = TypeEngine.to_literal_type(pt)
     assert lt.structured_dataset_type is not None
-    assert lt.structured_dataset_type.format == PARQUET
+    assert lt.structured_dataset_type.format == ""
     assert lt.structured_dataset_type.columns == []

+    pt = Annotated[pd.DataFrame, "csv"]
+    lt = TypeEngine.to_literal_type(pt)
+    assert lt.structured_dataset_type.format == "csv"
+

 def test_annotate_extraction():
     xyz = Annotated[pd.DataFrame, "myformat"]
@@ -68,7 +125,7 @@ def test_annotate_extraction():

     a, b, c, d = extract_cols_and_format(pd.DataFrame)
     assert a is pd.DataFrame
     assert b is None
-    assert c is None
+    assert c == ""
     assert d is None

@@ -115,9 +172,10 @@ def test_types_sd():

 def test_retrieving():
     assert StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET) is not None
-    with pytest.raises(ValueError):
-        # We don't have a default "" format encoder
-        StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", "")
+    # Asking for a generic means you're okay with any one registered for that type assuming there's just one.
+    assert StructuredDatasetTransformerEngine.get_encoder(
+        pd.DataFrame, "file", ""
+    ) is StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET)

     class TempEncoder(StructuredDatasetEncoder):
         def __init__(self, protocol):
@@ -188,9 +246,10 @@ def encode(
     ) -> literals.StructuredDataset:
         return literals.StructuredDataset(uri="")

-    StructuredDatasetTransformerEngine.register(TempEncoder("myavro"), default_for_type=True)
+    default_encoder = TempEncoder("myavro")
+    StructuredDatasetTransformerEngine.register(default_encoder, default_for_type=True)
     lt = TypeEngine.to_literal_type(MyDF)
-    assert lt.structured_dataset_type.format == "myavro"
+    assert lt.structured_dataset_type.format == ""

     ctx = FlyteContextManager.current_context()
     fdt = StructuredDatasetTransformerEngine()
@@ -228,7 +287,7 @@ def encode(
 def test_sd():
     sd = StructuredDataset(dataframe="hi")
     sd.uri = "my uri"
-    assert sd.file_format == PARQUET
+    assert sd.file_format == ""

     with pytest.raises(ValueError, match="No dataframe type set"):
         sd.all()
@@ -383,7 +442,7 @@ def encode(
     assert df_literal_type.structured_dataset_type.format == "avro"

     sd = annotated_sd_type(df)
-    with pytest.raises(ValueError):
+    with pytest.raises(ValueError, match="Failed to find a handler"):
         TypeEngine.to_literal(ctx, sd, python_type=annotated_sd_type, expected=df_literal_type)

     StructuredDatasetTransformerEngine.register(TempEncoder(), default_for_type=False)
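Taken together, the fallback chain that the new _finder implements can be exercised
directly through get_encoder. A minimal sketch follows (illustration only, not part of
the patch; it mirrors test_retrieving above and assumes the built-in pandas handlers
from basic_dfs.py are registered):

import pandas as pd

from flytekit.types.structured.structured_dataset import (
    PARQUET,
    StructuredDatasetTransformerEngine,
)

# A concrete format returns its format-specific handler.
parquet_encoder = StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET)

# The generic format ("") falls back to the type's registered default format
# (parquet for pandas), or to the sole registered handler if exactly one exists;
# if several candidates remain, it raises ValueError("Failed to find a handler ...").
generic_encoder = StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", "")
assert generic_encoder is parquet_encoder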