diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py
index bc398ab7a7..a7bdf43153 100644
--- a/flytekit/models/literals.py
+++ b/flytekit/models/literals.py
@@ -549,7 +549,7 @@ def from_flyte_idl(cls, pb2_object):
 
 
 class StructuredDatasetMetadata(_common.FlyteIdlEntity):
-    def __init__(self, structured_dataset_type: StructuredDatasetType = None):
+    def __init__(self, structured_dataset_type: StructuredDatasetType):
         self._structured_dataset_type = structured_dataset_type
 
     @property
diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py
index 49b2f13ed9..7e5a3d5e89 100644
--- a/flytekit/types/structured/basic_dfs.py
+++ b/flytekit/types/structured/basic_dfs.py
@@ -55,14 +55,13 @@ def decode(
         self,
         ctx: FlyteContext,
         flyte_value: literals.StructuredDataset,
+        current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
     ) -> pd.DataFrame:
         path = flyte_value.uri
         local_dir = ctx.file_access.get_random_local_directory()
         ctx.file_access.get_data(path, local_dir, is_multipart=True)
-        if flyte_value.metadata.structured_dataset_type.columns:
-            columns = []
-            for c in flyte_value.metadata.structured_dataset_type.columns:
-                columns.append(c.name)
+        if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
+            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
             return pd.read_parquet(local_dir, columns=columns)
         return pd.read_parquet(local_dir)
 
@@ -94,14 +93,13 @@ def decode(
         self,
         ctx: FlyteContext,
         flyte_value: literals.StructuredDataset,
+        current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
     ) -> pa.Table:
         path = flyte_value.uri
         local_dir = ctx.file_access.get_random_local_directory()
         ctx.file_access.get_data(path, local_dir, is_multipart=True)
-        if flyte_value.metadata.structured_dataset_type.columns:
-            columns = []
-            for c in flyte_value.metadata.structured_dataset_type.columns:
-                columns.append(c.name)
+        if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
+            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
             return pq.read_table(local_dir, columns=columns)
         return pq.read_table(local_dir)
 
diff --git a/flytekit/types/structured/bigquery.py b/flytekit/types/structured/bigquery.py
index 923ea06e9e..280210ead2 100644
--- a/flytekit/types/structured/bigquery.py
+++ b/flytekit/types/structured/bigquery.py
@@ -11,7 +11,6 @@ from flytekit.models.types import StructuredDatasetType
 from flytekit.types.structured.structured_dataset import (
     BIGQUERY,
-    DF,
     StructuredDataset,
     StructuredDatasetDecoder,
     StructuredDatasetEncoder,
 )
@@ -29,7 +28,9 @@ def _write_to_bq(structured_dataset: StructuredDataset):
     client.load_table_from_dataframe(df, table_id)
 
 
-def _read_from_bq(flyte_value: literals.StructuredDataset) -> pd.DataFrame:
+def _read_from_bq(
+    flyte_value: literals.StructuredDataset, current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None
+) -> pd.DataFrame:
     path = flyte_value.uri
     _, project_id, dataset_id, table_id = re.split("\\.|://|:", path)
     client = bigquery_storage.BigQueryReadClient()
@@ -37,10 +38,8 @@ def _read_from_bq(flyte_value: literals.StructuredDataset) -> pd.DataFrame:
     parent = "projects/{}".format(project_id)
 
     read_options = None
-    if flyte_value.metadata.structured_dataset_type.columns:
-        columns = []
-        for c in flyte_value.metadata.structured_dataset_type.columns:
-            columns.append(c.name)
+    if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
+        columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
         read_options = types.ReadSession.TableReadOptions(selected_fields=columns)
 
     requested_session = types.ReadSession(table=table, data_format=types.DataFormat.ARROW, read_options=read_options)
@@ -78,8 +77,9 @@ def decode(
         self,
         ctx: FlyteContext,
         flyte_value: literals.StructuredDataset,
-    ) -> typing.Union[DF, typing.Generator[DF, None, None]]:
-        return _read_from_bq(flyte_value)
+        current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
+    ) -> pd.DataFrame:
+        return _read_from_bq(flyte_value, current_task_metadata)
 
 
 class ArrowToBQEncodingHandlers(StructuredDatasetEncoder):
@@ -106,7 +106,8 @@ def decode(
         self,
         ctx: FlyteContext,
         flyte_value: literals.StructuredDataset,
-    ) -> typing.Union[DF, typing.Generator[DF, None, None]]:
+        current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
+    ) -> pa.Table:
         return pa.Table.from_pandas(_read_from_bq(flyte_value))
 
 
diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py
index 819bc012cc..11e816187a 100644
--- a/flytekit/types/structured/structured_dataset.py
+++ b/flytekit/types/structured/structured_dataset.py
@@ -261,14 +261,17 @@ def decode(
         self,
         ctx: FlyteContext,
         flyte_value: literals.StructuredDataset,
+        current_task_metadata: Optional[StructuredDatasetMetadata] = None,
     ) -> Union[DF, Generator[DF, None, None]]:
         """
         This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal value
         into a Python instance.
 
-        :param ctx:
+        :param ctx: A FlyteContext, useful in accessing the filesystem and other attributes
         :param flyte_value: This will be a Flyte IDL StructuredDataset Literal - do not confuse this with the
           StructuredDataset class defined also in this module.
+        :param current_task_metadata: Metadata containing the column names and types. The decoder will use it to
+          read specific columns of the parquet file or database table. It might differ from the metadata in the incoming literal.
         :return: This function can either return an instance of the dataframe that this decoder handles, or an
           iterator of those dataframes.
         """
@@ -596,8 +599,7 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
         if column_dict is None or len(column_dict) == 0:
             # but if it does, then we just copy it over
             if incoming_columns is not None and incoming_columns != []:
-                for c in incoming_columns:
-                    final_dataset_columns.append(c)
+                final_dataset_columns = incoming_columns.copy()
         # If the current running task's input does have columns defined
         else:
             final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
@@ -614,7 +616,7 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
         # t1(input_a: StructuredDataset)  # or
         # t1(input_a: Annotated[StructuredDataset, my_cols])
         if issubclass(expected_python_type, StructuredDataset):
-            sd = expected_python_type(
+            sd = StructuredDataset(
                 dataframe=None,
                 # Note here that the type being passed in
                 metadata=metad,
@@ -634,19 +636,15 @@ def open_as(
         updated_metadata: Optional[StructuredDatasetMetadata] = None,
     ) -> DF:
         """
-
-        :param ctx:
+        :param ctx: A FlyteContext, useful in accessing the filesystem and other attributes
         :param sd:
         :param df_type:
-        :param meta: New metadata type, since it might be different from the metadata in the literal.
-        :return:
+        :param updated_metadata: New metadata type, since it might be different from the metadata in the literal.
+        :return: The dataframe. It could be a pandas dataframe, an arrow table, etc.
         """
         protocol = protocol_prefix(sd.uri)
         decoder = self.get_decoder(df_type, protocol, sd.metadata.structured_dataset_type.format)
-        # todo: revisit this, we probably should add a new field to the decoder interface
-        if updated_metadata:
-            sd._metadata = updated_metadata
-        result = decoder.decode(ctx, sd)
+        result = decoder.decode(ctx, sd, updated_metadata)
         if isinstance(result, types.GeneratorType):
             raise ValueError(f"Decoder {decoder} returned iterator {result} but whole value requested from {sd}")
         return result
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
index 2466d3fc13..7c9a2af0f5 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
@@ -39,8 +39,12 @@ def decode(
         self,
         ctx: FlyteContext,
         flyte_value: literals.StructuredDataset,
+        current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
     ) -> DataFrame:
         user_ctx = FlyteContext.current_context().user_space_params
+        if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
+            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
+            return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
         return user_ctx.spark_session.read.parquet(flyte_value.uri)
 
 
diff --git a/plugins/flytekit-spark/tests/test_wf.py b/plugins/flytekit-spark/tests/test_wf.py
index a0a624fec7..8c42a6162f 100644
--- a/plugins/flytekit-spark/tests/test_wf.py
+++ b/plugins/flytekit-spark/tests/test_wf.py
@@ -6,6 +6,11 @@
 from flytekit import kwtypes, task, workflow
 from flytekit.types.schema import FlyteSchema
 
+try:
+    from typing import Annotated
+except ImportError:
+    from typing_extensions import Annotated
+
 
 def test_wf1_with_spark():
     @task(task_config=Spark())
@@ -53,27 +58,6 @@ def my_wf() -> my_schema:
     assert df2 is not None
 
 
-def test_ddwf1_with_spark():
-    @task(task_config=Spark())
-    def my_spark(a: int) -> (int, str):
-        session = flytekit.current_context().spark_session
-        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
-        return a + 2, "world"
-
-    @task
-    def t2(a: str, b: str) -> str:
-        return b + a
-
-    @workflow
-    def my_wf(a: int, b: str) -> (int, str):
-        x, y = my_spark(a=a)
-        d = t2(a=y, b=b)
-        return x, d
-
-    x = my_wf(a=5, b="hello ")
-    assert x == (7, "hello world")
-
-
 def test_fs_sd_compatibility():
     my_schema = FlyteSchema[kwtypes(name=str, age=int)]
 
@@ -108,7 +92,6 @@ def test_spark_dataframe_return():
     def my_spark(a: int) -> my_schema:
         session = flytekit.current_context().spark_session
         df = session.createDataFrame([("Alice", a)], my_schema.column_names())
-        print(type(df))
         return df
 
     @workflow
@@ -120,3 +103,19 @@ def my_wf(a: int) -> my_schema:
     df2 = reader.all()
     result_df = df2.reset_index(drop=True) == pd.DataFrame(data={"name": ["Alice"], "age": [5]}).reset_index(drop=True)
     assert result_df.all().all()
+
+
+def test_read_spark_subset_columns():
+    @task
+    def t1() -> pd.DataFrame:
+        return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
+
+    @task(task_config=Spark())
+    def t2(df: Annotated[pyspark.sql.DataFrame, kwtypes(Name=str)]) -> int:
+        return len(df.columns)
+
+    @workflow()
+    def wf() -> int:
+        return t2(df=t1())
+
+    assert wf() == 1
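Note on usage: below is a minimal sketch of how the new column-pruning path is exercised with the built-in pandas/parquet handler, mirroring test_read_spark_subset_columns above. The task and column names are illustrative only and are not part of this patch. Annotating the consuming task's input with a column subset causes the engine to pass that subset to the decoder as current_task_metadata, so only the requested columns are read back.

import pandas as pd
from typing_extensions import Annotated

from flytekit import kwtypes, task, workflow


@task
def produce() -> pd.DataFrame:
    # Written out as a structured dataset (parquet) with two columns.
    return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})


@task
def consume(df: Annotated[pd.DataFrame, kwtypes(Name=str)]) -> int:
    # With this patch, the pandas decoder receives the consuming task's metadata
    # and reads only the "Name" column from the stored parquet files.
    return len(df.columns)


@workflow
def wf() -> int:
    return consume(df=produce())  # expected to return 1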
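Custom handlers should account for the widened decoder signature as well. The following is a sketch of a third-party decoder written against the new interface; the class name, protocol, and format are hypothetical, and the super().__init__ arguments are assumed to follow the (python_type, protocol, supported_format) pattern used by the handlers in basic_dfs.py.

import typing

import pandas as pd

from flytekit.core.context_manager import FlyteContext
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.types.structured.structured_dataset import StructuredDatasetDecoder


class CSVToPandasDecodingHandler(StructuredDatasetDecoder):
    # Hypothetical handler: decodes a CSV-backed structured dataset into pandas.
    def __init__(self):
        super().__init__(pd.DataFrame, "file", "csv")

    def decode(
        self,
        ctx: FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
    ) -> pd.DataFrame:
        # Prefer the consuming task's column subset over the columns recorded in
        # the incoming literal, matching the handlers updated in this patch.
        if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
            names = [c.name for c in current_task_metadata.structured_dataset_type.columns]
            return pd.read_csv(flyte_value.uri, usecols=names)
        return pd.read_csv(flyte_value.uri)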