From de1c0b2d6838220b57dbd8eacdb092681f953e8f Mon Sep 17 00:00:00 2001
From: Yee Hing Tong
Date: Mon, 24 Jan 2022 09:40:22 -0800
Subject: [PATCH] [pr into #822] (#827)

Signed-off-by: Yee Hing Tong
Signed-off-by: Kevin Su
---
 flytekit/core/base_task.py                    |   4 +-
 .../types/structured/structured_dataset.py    | 253 ++++++++++++------
 setup.py                                      |   1 +
 .../unit/core/test_structured_dataset.py      |  26 +-
 tests/flytekit/unit/core/test_type_engine.py  |  14 +-
 5 files changed, 213 insertions(+), 85 deletions(-)

diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py
index f590ba2033..688b212144 100644
--- a/flytekit/core/base_task.py
+++ b/flytekit/core/base_task.py
@@ -21,7 +21,7 @@
 import datetime
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union
 
 from flytekit.core.context_manager import (
     ExecutionParameters,
@@ -53,7 +53,7 @@
 from flytekit.models.security import SecurityContext
 
 
-def kwtypes(**kwargs) -> Dict[str, Type]:
+def kwtypes(**kwargs) -> OrderedDict[str, Type]:
     """
     This is a small helper function to convert the keyword arguments to an OrderedDict of types.
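
The kwtypes change above is what lets the structured dataset machinery below tell column annotations (an OrderedDict) apart from plain format strings and pyarrow schemas inside Annotated. A minimal sketch of what the helper now advertises (the column names here are made up for illustration):

    import collections

    from flytekit import kwtypes

    # Keyword order is preserved, which is what column ordering relies on.
    cols = kwtypes(name=str, age=int)
    assert isinstance(cols, collections.OrderedDict)
    assert list(cols.items()) == [("name", str), ("age", int)]
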
diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py
index a21c71d9f9..53e541e6da 100644
--- a/flytekit/types/structured/structured_dataset.py
+++ b/flytekit/types/structured/structured_dataset.py
@@ -9,13 +9,14 @@
 from dataclasses import dataclass, field
 from typing import Dict, Generator, Optional, Type, Union
 
+import pyarrow
 from dataclasses_json import config, dataclass_json
 from marshmallow import fields
 
 try:
-    from typing import Annotated, get_args, get_origin
+    from typing import Annotated, TypeAlias, get_args, get_origin
 except ImportError:
-    from typing_extensions import Annotated, get_origin, get_args
+    from typing_extensions import Annotated, get_origin, get_args, TypeAlias
 
 import _datetime
 import numpy as _np
@@ -38,20 +39,24 @@
 S3 = "s3"
 LOCAL = "/"
 
+# For specifying the storage formats of StructuredDatasets. It's just a string, nothing fancy.
+StructuredDatasetFormat: TypeAlias = str
+
 # Storage formats
-PARQUET = "parquet"
+PARQUET: StructuredDatasetFormat = "parquet"
 
 
 @dataclass_json
 @dataclass
 class StructuredDataset(object):
-    uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String()))
-    file_format: typing.Optional[str] = field(default=PARQUET, metadata=config(mm_field=fields.String()))
     """
     This is the user facing StructuredDataset class. Please don't confuse it with the literals.StructuredDataset
     class (that is just a model, a Python class representation of the protobuf).
     """
 
+    uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String()))
+    file_format: typing.Optional[str] = field(default=PARQUET, metadata=config(mm_field=fields.String()))
+
     DEFAULT_FILE_FORMAT = PARQUET
 
     @classmethod
@@ -109,6 +114,53 @@ def iter(self) -> Generator[DF, None, None]:
         return FLYTE_DATASET_TRANSFORMER.iter_as(ctx, self.literal, self._dataframe_type)
 
 
+def extract_cols_and_format(
+    t: typing.Any,
+) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional[pa.lib.Schema]]:
+    """
+    Helper function, just used to iterate through Annotations and extract out the following information:
+      - base type, if not Annotated, it will just be the type that was passed in.
+      - column information, as a collections.OrderedDict,
+      - the storage format, as a ``StructuredDatasetFormat`` (str),
+      - pa.lib.Schema
+
+    If more than one of any type of thing is found, an error will be raised.
+    If no instances of a given type are found, then None will be returned.
+
+    If we add more things, we should put all the returned items in a dataclass instead of just a tuple.
+
+    :param t: The incoming type which may or may not be Annotated
+    :return: Tuple representing
+        the original type,
+        optional OrderedDict of columns,
+        optional str for the format,
+        optional pyarrow Schema
+    """
+    fmt = None
+    ordered_dict_cols = None
+    pa_schema = None
+    if get_origin(t) is Annotated:
+        base_type, *annotate_args = get_args(t)
+        for aa in annotate_args:
+            if isinstance(aa, StructuredDatasetFormat):
+                if fmt is not None:
+                    raise ValueError(f"A format was already specified {fmt}, cannot use {aa}")
+                fmt = aa
+            elif isinstance(aa, collections.OrderedDict):
+                if ordered_dict_cols is not None:
+                    raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {aa}")
+                ordered_dict_cols = aa
+            elif isinstance(aa, pyarrow.Schema):
+                if pa_schema is not None:
+                    raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {aa}")
+                pa_schema = aa
+        return base_type, ordered_dict_cols, fmt, pa_schema
+
+    # We return None as the format instead of parquet or something because the transformer engine may find
+    # a better default for the given dataframe type.
+    return t, ordered_dict_cols, fmt, pa_schema
+
+
 class StructuredDatasetEncoder(ABC):
     def __init__(self, python_type: Type[T], protocol: str, supported_format: Optional[str] = None):
         """
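
As a quick illustration of the extraction rules above (this mirrors the new test_annotate_extraction test added later in this patch; the column names are hypothetical):

    import collections

    import pandas as pd
    from typing_extensions import Annotated

    from flytekit import kwtypes
    from flytekit.types.structured.structured_dataset import extract_cols_and_format

    # Annotated can carry columns, a format string, and/or a pyarrow schema, in any order.
    t = Annotated[pd.DataFrame, kwtypes(name=str, age=int), "parquet"]
    base, cols, fmt, schema = extract_cols_and_format(t)
    assert base is pd.DataFrame
    assert isinstance(cols, collections.OrderedDict) and list(cols) == ["name", "age"]
    assert fmt == "parquet"
    assert schema is None

    # A bare type passes through untouched, with None for everything else.
    assert extract_cols_and_format(pd.DataFrame) == (pd.DataFrame, None, None, None)
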
@@ -227,8 +279,8 @@ def protocol_prefix(uri: str) -> str:
 
 
 def convert_schema_type_to_structured_dataset_type(
-    column_type: SchemaType.SchemaColumn.SchemaColumnType,
-) -> type_models.SimpleType:
+    column_type: int,
+) -> int:
     if column_type == SchemaType.SchemaColumn.SchemaColumnType.INTEGER:
         return type_models.SimpleType.INTEGER
     if column_type == SchemaType.SchemaColumn.SchemaColumnType.FLOAT:
         return type_models.SimpleType.FLOAT
@@ -352,9 +404,10 @@ def to_literal(
         expected: LiteralType,
     ) -> Literal:
         # Make a copy in case we need to hand off to encoders, since we can't be sure of mutations.
-        # Check first to see if it's even an SD type. For backwards compatibility, we may be getting a
+        # Check first to see if it's even an SD type. For backwards compatibility, we may be getting a FlyteSchema
         if get_origin(python_type) is Annotated:
             python_type = get_args(python_type)[0]
+        # In case it's a FlyteSchema
         sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, None))
 
         if expected and expected.structured_dataset_type:
@@ -449,58 +502,106 @@ def encode(
         return Literal(scalar=Scalar(structured_dataset=sd_model))
 
     def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
+        """
+        The only tricky thing with converting a Literal (say the output of an earlier task) to a Python value at
+        the start of a task execution is the column subsetting behavior. For example, if you have,
+
+        def t1() -> Annotated[StructuredDataset, kwtypes(col_a=int, col_b=float)]: ...
+        def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
+
+        where t2(in_a=t1()), when t2 does in_a.open(pd.DataFrame).all(), it should get a DataFrame
+        with only one column.
+
+        +-----------------------------+-----------------------------------------+--------------------------------------+
+        |                             |         StructuredDatasetType of the incoming Literal                         |
+        +-----------------------------+-----------------------------------------+--------------------------------------+
+        | StructuredDatasetType       | Has columns defined                     | [] columns or None                   |
+        | of currently running task   |                                         |                                      |
+        +=============================+=========================================+======================================+
+        | Has columns                 | The StructuredDatasetType passed to the decoder will have the columns         |
+        | defined                     | as defined by the type annotation of the currently running task.              |
+        |                             |                                                                                |
+        |                             | Decoders **should** then subset the incoming data to the columns requested.   |
+        |                             |                                                                                |
+        +-----------------------------+-----------------------------------------+--------------------------------------+
+        | [] columns or None          | StructuredDatasetType passed to decoder | StructuredDatasetType passed to the  |
+        |                             | will have the columns from the incoming | decoder will have an empty list of   |
+        |                             | Literal. This is the scenario where     | columns.                             |
+        |                             | the Literal returned by the running     |                                      |
+        |                             | task will have more information than    |                                      |
+        |                             | the running task's signature.           |                                      |
+        +-----------------------------+-----------------------------------------+--------------------------------------+
+        """
+        # Detect annotations and extract out all the relevant information that the user might supply
+        expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type)
+
         # The literal that we get in might be an old FlyteSchema.
-        # We'll continue to support this for the time being.
-        subset_columns = {}
-        if get_origin(expected_python_type) is Annotated:
-            subset_columns = get_args(expected_python_type)[1]
-            expected_python_type = get_args(expected_python_type)[0]
+        # We'll continue to support this for the time being. There is some duplicated logic here but let's
+        # keep it copy/pasted for clarity
         if lv.scalar.schema is not None:
-            sd = StructuredDataset()
+            schema_columns = lv.scalar.schema.type.columns
+
+            # See the repeated logic below for comments
+            if column_dict is None or len(column_dict) == 0:
+                final_dataset_columns = []
+                if schema_columns is not None and schema_columns != []:
+                    for c in schema_columns:
+                        final_dataset_columns.append(
+                            StructuredDatasetType.DatasetColumn(
+                                name=c.name,
+                                literal_type=LiteralType(
+                                    simple=convert_schema_type_to_structured_dataset_type(c.type),
+                                ),
+                            )
+                        )
+                # Dataframe will always be serialized to parquet file by FlyteSchema transformer
+                new_sdt = StructuredDatasetType(columns=final_dataset_columns, format=PARQUET)
+            else:
+                final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
+                # Dataframe will always be serialized to parquet file by FlyteSchema transformer
+                new_sdt = StructuredDatasetType(columns=final_dataset_columns, format=PARQUET)
+
+            metad = literals.StructuredDatasetMetadata(structured_dataset_type=new_sdt)
             sd_literal = literals.StructuredDataset(
                 uri=lv.scalar.schema.uri,
-                metadata=literals.StructuredDatasetMetadata(
-                    # Dataframe will always be serialized to parquet file by FlyteSchema transformer
-                    structured_dataset_type=StructuredDatasetType(format=PARQUET)
-                ),
+                metadata=metad,
             )
-            sd._literal_sd = sd_literal
+
             if issubclass(expected_python_type, StructuredDataset):
-                schema_columns = lv.scalar.schema.type.columns
-                if schema_columns is not None:
-                    subset_dataset_columns = []
-                    for c in schema_columns:
-                        if c.name in subset_columns:
-                            subset_dataset_columns.append(
-                                StructuredDatasetType.DatasetColumn(
-                                    name=c.name,
-                                    literal_type=LiteralType(
-                                        simple=convert_schema_type_to_structured_dataset_type(c.type)
-                                    ),
-                                )
-                            )
-                    sd._literal_sd.metadata.structured_dataset_type.columns = subset_dataset_columns
+                sd = StructuredDataset(dataframe=None, metadata=metad)
+                sd._literal_sd = sd_literal
                 return sd
             else:
                 return self.open_as(ctx, sd_literal, df_type=expected_python_type)
 
-        # Either a StructuredDataset type or some dataframe type.
+        # A StructuredDataset type, for example
+        #   t1(input_a: StructuredDataset)  # or
+        #   t1(input_a: Annotated[StructuredDataset, my_cols])
         if issubclass(expected_python_type, StructuredDataset):
-            # Just save the literal for now. If in the future we find that we need the StructuredDataset type hint
-            # type also, we can add it.
-            dataset_columns = lv.scalar.structured_dataset.metadata.structured_dataset_type.columns
-            if dataset_columns is not None:
-                subset_dataset_columns = []
-                for c in dataset_columns:
-                    if c.name in subset_columns:
-                        subset_dataset_columns.append(c)
-                lv.scalar.structured_dataset.metadata.structured_dataset_type.columns = subset_dataset_columns
+            incoming_columns = lv.scalar.structured_dataset.metadata.structured_dataset_type.columns
+
+            # If the incoming literal also doesn't have columns, then we just have an empty list, so initialize here
+            final_dataset_columns = []
+            # If the current running task's input does not have columns defined, or has an empty list of columns
+            if column_dict is None or len(column_dict) == 0:
+                # but if the incoming literal does, then we just copy it over
+                if incoming_columns is not None and incoming_columns != []:
+                    for c in incoming_columns:
+                        final_dataset_columns.append(c)
+            # If the current running task's input does have columns defined
+            else:
+                final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
+
+            new_sdt = StructuredDatasetType(
+                columns=final_dataset_columns,
+                format=lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
+                external_schema_type=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_type,
+                external_schema_bytes=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_bytes,
+            )
             sd = expected_python_type(
                 dataframe=None,
-                # Specifying these two are just done for completeness. Kind of waste since
-                # we're saving the whole incoming literal to _literal_sd.
-                metadata=lv.scalar.structured_dataset.metadata,
+                # Note here that the type being passed in is the subsetted one, not the full incoming type
+                metadata=StructuredDatasetMetadata(structured_dataset_type=new_sdt),
             )
             sd._literal_sd = lv.scalar.structured_dataset
             return sd
@@ -527,7 +628,7 @@ def iter_as(
             raise ValueError(f"Decoder {decoder} didn't return iterator {result} but should have from {sd}")
         return result
 
-    def _get_dataset_column_literal_type(self, t: Type):
+    def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType:
         if t in self._SUPPORTED_TYPES:
             return self._SUPPORTED_TYPES[t]
         if hasattr(t, "__origin__") and t.__origin__ == list:
             return type_models.LiteralType(collection_type=self._get_dataset_column_literal_type(t.__args__[0]))
         if hasattr(t, "__origin__") and t.__origin__ == dict:
             return type_models.LiteralType(map_value_type=self._get_dataset_column_literal_type(t.__args__[1]))
         raise AssertionError(f"type {t} is currently not supported by StructuredDataset")
 
-    def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) -> StructuredDatasetType:
+    def _convert_ordered_dict_of_columns_to_list(
+        self, column_map: typing.OrderedDict[str, Type]
+    ) -> typing.List[StructuredDatasetType.DatasetColumn]:
         converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = []
-        # Handle different kinds of annotation
-        #   my_cols = kwtypes(x=int, y=str)
-        # 1. Fill in format correctly by checking for typing.annotated. For example, Annotated[pd.Dataframe, my_cols]
-        if get_origin(t) is Annotated:
-            _, *hint_args = get_args(t)
-            if type(hint_args[0]) in (collections.OrderedDict, dict):
-                for k, v in hint_args[0].items():
-                    lt = self._get_dataset_column_literal_type(v)
-                    converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt))
-                return StructuredDatasetType(
-                    columns=converted_cols, format=PARQUET if len(hint_args) == 1 else hint_args[1]
-                )
-            # 2. Fill in external schema type and bytes by checking for typing.annotated metadata.
-            #    For example, Annotated[pd.Dataframe, pa.schema([("col1", pa.int32()), ("col2", pa.string())])]
-            elif type(hint_args[0]) is pa.lib.Schema:
-                return StructuredDatasetType(
-                    format=PARQUET,
-                    external_schema_type="arrow",
-                    external_schema_bytes=typing.cast(pa.lib.Schema, hint_args[0]).to_string().encode(),
-                )
-            raise ValueError(f"Unrecognized Annotated type for StructuredDataset {t}")
+        if column_map is None or len(column_map) == 0:
+            return converted_cols
+        for k, v in column_map.items():
+            lt = self._get_dataset_column_literal_type(v)
+            converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt))
+        return converted_cols
 
-        elif issubclass(t, StructuredDataset):
-            return StructuredDatasetType(columns=None, format=t.DEFAULT_FILE_FORMAT)
-
-        # 3. pd.Dataframe
-        else:
-            fmt = self.DEFAULT_FORMATS.get(t, PARQUET)
-            return StructuredDatasetType(columns=converted_cols, format=fmt)
+    def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) -> StructuredDatasetType:
+        original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t)
+
+        # Get the column information
+        converted_cols = self._convert_ordered_dict_of_columns_to_list(column_map)
+
+        # Get the format
+        default_format = (
+            original_python_type.DEFAULT_FILE_FORMAT
+            if issubclass(original_python_type, StructuredDataset)
+            else self.DEFAULT_FORMATS.get(original_python_type, PARQUET)
+        )
+        fmt = storage_format or default_format
+
+        return StructuredDatasetType(
+            columns=converted_cols,
+            format=fmt,
+            external_schema_type="arrow" if pa_schema else None,
+            external_schema_bytes=typing.cast(pa.lib.Schema, pa_schema).to_string().encode() if pa_schema else None,
+        )
 
     def get_literal_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) -> LiteralType:
         """
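
With that, _get_dataset_type derives every field of the resulting StructuredDatasetType from the annotation in one pass instead of branching per annotation shape. A rough sketch of the mapping, assuming the pandas default format resolves to parquet (column names are made up):

    import pandas as pd
    import pyarrow as pa
    from typing_extensions import Annotated

    from flytekit import kwtypes
    from flytekit.core.type_engine import TypeEngine

    # Columns plus an explicit format: both land on the literal type.
    lt = TypeEngine.to_literal_type(Annotated[pd.DataFrame, kwtypes(x=int), "parquet"])
    assert lt.structured_dataset_type.format == "parquet"
    assert lt.structured_dataset_type.columns[0].name == "x"

    # A pyarrow schema instead populates the external schema fields.
    lt = TypeEngine.to_literal_type(Annotated[pd.DataFrame, pa.schema([("x", pa.int32())])])
    assert lt.structured_dataset_type.external_schema_type == "arrow"
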
diff --git a/setup.py b/setup.py
index ae028038bb..740a1d81e5 100644
--- a/setup.py
+++ b/setup.py
@@ -61,6 +61,7 @@
         "natsort>=7.0.1",
         "docker-image-py>=0.1.10",
         "singledispatchmethod; python_version < '3.8.0'",
+        "typing_extensions",
         "docstring-parser>=0.9.0",
         "diskcache>=5.2.1",
         "checksumdir>=1.2.0",
diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py
index 59dcdaf96f..59b20cc3b6 100644
--- a/tests/flytekit/unit/core/test_structured_dataset.py
+++ b/tests/flytekit/unit/core/test_structured_dataset.py
@@ -25,6 +25,7 @@
     StructuredDatasetDecoder,
     StructuredDatasetEncoder,
     convert_schema_type_to_structured_dataset_type,
+    extract_cols_and_format,
     protocol_prefix,
 )
 
@@ -59,6 +60,21 @@ def test_types_pandas():
     assert lt.structured_dataset_type.columns == []
 
 
+def test_annotate_extraction():
+    xyz = Annotated[pd.DataFrame, "myformat"]
+    a, b, c, d = extract_cols_and_format(xyz)
+    assert a is pd.DataFrame
+    assert b is None
+    assert c == "myformat"
+    assert d is None
+
+    a, b, c, d = extract_cols_and_format(pd.DataFrame)
+    assert a is pd.DataFrame
+    assert b is None
+    assert c is None
+    assert d is None
+
+
 def test_types_annotated():
     pt = Annotated[pd.DataFrame, my_cols]
     lt = TypeEngine.to_literal_type(pt)
@@ -70,7 +86,7 @@ def test_types_annotated():
     assert lt.structured_dataset_type.columns[2].literal_type.simple == SimpleType.INTEGER
     assert lt.structured_dataset_type.columns[3].literal_type.simple == SimpleType.STRING
 
-    pt = Annotated[pd.DataFrame, arrow_schema]
+    pt = Annotated[pd.DataFrame, PARQUET, arrow_schema]
     lt = TypeEngine.to_literal_type(pt)
     assert lt.structured_dataset_type.external_schema_type == "arrow"
     assert "some_string" in str(lt.structured_dataset_type.external_schema_bytes)
@@ -79,10 +95,6 @@ def test_types_annotated():
     with pytest.raises(AssertionError, match="type None is currently not supported by StructuredDataset"):
         TypeEngine.to_literal_type(pt)
 
-    pt = Annotated[pd.DataFrame, None]
-    with pytest.raises(ValueError, match="Unrecognized Annotated type for StructuredDataset"):
-        TypeEngine.to_literal_type(pt)
-
 
 def test_types_sd():
     pt = StructuredDataset
@@ -251,3 +263,7 @@ def test_convert_schema_type_to_structured_dataset_type():
     assert convert_schema_type_to_structured_dataset_type(schema_ct.BOOLEAN) == SimpleType.BOOLEAN
     with pytest.raises(AssertionError, match="Unrecognized SchemaColumnType"):
         convert_schema_type_to_structured_dataset_type(int)
+
+
+def test_to_python_value():
+    ...
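
test_to_python_value is added only as a stub here. One sketch of what it might eventually assert, following the subsetting table documented in to_python_value above (a hypothetical test body, not part of this patch):

    import pandas as pd
    from typing_extensions import Annotated

    from flytekit import kwtypes
    from flytekit.core.context_manager import FlyteContextManager
    from flytekit.core.type_engine import TypeEngine

    def test_to_python_value_subsets_columns():
        ctx = FlyteContextManager.current_context()
        df = pd.DataFrame({"Name": ["Tom"], "Age": [20]})
        full = Annotated[pd.DataFrame, kwtypes(Name=str, Age=int)]
        lv = TypeEngine.to_literal(ctx, df, full, TypeEngine.to_literal_type(full))
        # Asking for a subset type should hand back only the requested column.
        subset = Annotated[pd.DataFrame, kwtypes(Name=str)]
        out = TypeEngine.to_python_value(ctx, lv, subset)
        assert list(out.columns) == ["Name"]
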
diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py
index 27f5425686..d738d12ea5 100644
--- a/tests/flytekit/unit/core/test_type_engine.py
+++ b/tests/flytekit/unit/core/test_type_engine.py
@@ -639,12 +639,14 @@ def test_structured_dataset_type():
     name = "Name"
     age = "Age"
     data = {name: ["Tom", "Joseph"], age: [20, 22]}
+    superset_cols = kwtypes(Name=str, Age=int)
+    subset_cols = kwtypes(Name=str)
     df = pd.DataFrame(data)
 
     from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine
 
     tf = StructuredDatasetTransformerEngine()
-    lt = tf.get_literal_type(Annotated[StructuredDataset, {name: str, age: int}, "parquet"])
+    lt = tf.get_literal_type(Annotated[StructuredDataset, superset_cols, "parquet"])
     assert lt.structured_dataset_type is not None
 
     ctx = FlyteContextManager.current_context()
@@ -657,7 +659,7 @@ def test_structured_dataset_type():
     assert_frame_equal(df, v1)
     assert_frame_equal(df, v2.to_pandas())
 
-    subset_lt = tf.get_literal_type(Annotated[StructuredDataset, {name: str}, "parquet"])
+    subset_lt = tf.get_literal_type(Annotated[StructuredDataset, subset_cols, "parquet"])
     assert subset_lt.structured_dataset_type is not None
 
     subset_lv = tf.to_literal(ctx, df, pd.DataFrame, subset_lt)
@@ -668,6 +670,14 @@ def test_structured_dataset_type():
     assert_frame_equal(subset_data, v1)
     assert_frame_equal(subset_data, v2.to_pandas())
 
+    empty_lt = tf.get_literal_type(Annotated[StructuredDataset, "parquet"])
+    assert empty_lt.structured_dataset_type is not None
+    empty_lv = tf.to_literal(ctx, df, pd.DataFrame, empty_lt)
+    v1 = tf.to_python_value(ctx, empty_lv, pd.DataFrame)
+    v2 = tf.to_python_value(ctx, empty_lv, pa.Table)
+    assert_frame_equal(df, v1)
+    assert_frame_equal(df, v2.to_pandas())
+
 
 def test_enum_type():
     t = TypeEngine.to_literal_type(Color)
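
End to end, the column subsetting contract this patch sets up reads roughly like the following user code (a sketch with hypothetical task names, not taken from this patch):

    import pandas as pd
    from typing_extensions import Annotated

    from flytekit import kwtypes, task, workflow
    from flytekit.types.structured.structured_dataset import StructuredDataset

    @task
    def t1() -> Annotated[StructuredDataset, kwtypes(col_a=int, col_b=float)]:
        # The full two-column frame is encoded; its literal records both columns.
        df = pd.DataFrame({"col_a": [1, 2], "col_b": [1.0, 2.0]})
        return StructuredDataset(dataframe=df)

    @task
    def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]) -> float:
        # The decoder is handed only the subset type, so opening the dataset
        # should yield just col_b, per the table in to_python_value above.
        df = in_a.open(pd.DataFrame).all()
        return float(df["col_b"].sum())

    @workflow
    def wf() -> float:
        return t2(in_a=t1())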