Revisit StructuredDatasetDecoder interface
Signed-off-by: Kevin Su <[email protected]>
pingsutw committed Feb 24, 2022
1 parent 9477e1f commit eb1c75e
Showing 6 changed files with 52 additions and 52 deletions.
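In short: StructuredDatasetDecoder.decode gains an explicit current_task_metadata argument carrying the running task's column and format information, open_as passes it through instead of mutating sd._metadata in place, and the built-in pandas, Arrow, BigQuery, and Spark decoders use it for column selection. A minimal sketch of the interface change, using abstract placeholder types rather than flytekit's actual classes:

# Sketch only: how the decoder interface changes in this commit.
from typing import Any, Optional


class OldStyleDecoder:
    def decode(self, ctx: Any, flyte_value: Any) -> Any:
        # Before: column selection had to be derived from flyte_value.metadata,
        # which the engine patched before invoking the decoder.
        ...


class NewStyleDecoder:
    def decode(self, ctx: Any, flyte_value: Any, current_task_metadata: Optional[Any] = None) -> Any:
        # After: the running task's StructuredDatasetMetadata arrives explicitly
        # and takes precedence when deciding which columns to read.
        ...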
2 changes: 1 addition & 1 deletion flytekit/models/literals.py
@@ -549,7 +549,7 @@ def from_flyte_idl(cls, pb2_object):


class StructuredDatasetMetadata(_common.FlyteIdlEntity):
def __init__(self, structured_dataset_type: StructuredDatasetType = None):
def __init__(self, structured_dataset_type: StructuredDatasetType):
self._structured_dataset_type = structured_dataset_type

@property
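With the default removed, anything constructing a StructuredDatasetMetadata must now pass the StructuredDatasetType explicitly. A small illustrative sketch (the column name is made up; import paths follow flytekit's models package):

from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import LiteralType, SimpleType, StructuredDatasetType

# Hypothetical single-column, parquet-formatted dataset type.
sdt = StructuredDatasetType(
    columns=[
        StructuredDatasetType.DatasetColumn(
            name="age", literal_type=LiteralType(simple=SimpleType.INTEGER)
        )
    ],
    format="parquet",
)
metadata = StructuredDatasetMetadata(structured_dataset_type=sdt)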
14 changes: 6 additions & 8 deletions flytekit/types/structured/basic_dfs.py
@@ -55,14 +55,13 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pd.DataFrame:
path = flyte_value.uri
local_dir = ctx.file_access.get_random_local_directory()
ctx.file_access.get_data(path, local_dir, is_multipart=True)
if flyte_value.metadata.structured_dataset_type.columns:
columns = []
for c in flyte_value.metadata.structured_dataset_type.columns:
columns.append(c.name)
if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return pd.read_parquet(local_dir, columns=columns)
return pd.read_parquet(local_dir)

@@ -94,14 +93,13 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pa.Table:
path = flyte_value.uri
local_dir = ctx.file_access.get_random_local_directory()
ctx.file_access.get_data(path, local_dir, is_multipart=True)
if flyte_value.metadata.structured_dataset_type.columns:
columns = []
for c in flyte_value.metadata.structured_dataset_type.columns:
columns.append(c.name)
if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return pq.read_table(local_dir, columns=columns)
return pq.read_table(local_dir)

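The user-visible effect: columns declared on a task input via Annotated end up in current_task_metadata, so only those columns are materialized. A hedged usage sketch for the pandas path, mirroring the Spark test added at the bottom of this commit (task names and columns are made up, and it assumes pandas dataframes are routed through the structured dataset transformer in your flytekit version):

import pandas as pd

from flytekit import kwtypes, task, workflow

try:
    from typing import Annotated
except ImportError:
    from typing_extensions import Annotated


@task
def make_df() -> pd.DataFrame:
    return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})


# Only "Name" is declared, so the parquet decoder above is expected to call
# pd.read_parquet(..., columns=["Name"]) based on current_task_metadata.
@task
def count_columns(df: Annotated[pd.DataFrame, kwtypes(Name=str)]) -> int:
    return len(df.columns)


@workflow
def wf() -> int:
    return count_columns(df=make_df())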
19 changes: 10 additions & 9 deletions flytekit/types/structured/bigquery.py
@@ -11,7 +11,6 @@
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
BIGQUERY,
DF,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
@@ -29,18 +28,18 @@ def _write_to_bq(structured_dataset: StructuredDataset):
client.load_table_from_dataframe(df, table_id)


def _read_from_bq(flyte_value: literals.StructuredDataset) -> pd.DataFrame:
def _read_from_bq(
flyte_value: literals.StructuredDataset, current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None
) -> pd.DataFrame:
path = flyte_value.uri
_, project_id, dataset_id, table_id = re.split("\\.|://|:", path)
client = bigquery_storage.BigQueryReadClient()
table = f"projects/{project_id}/datasets/{dataset_id}/tables/{table_id}"
parent = "projects/{}".format(project_id)

read_options = None
if flyte_value.metadata.structured_dataset_type.columns:
columns = []
for c in flyte_value.metadata.structured_dataset_type.columns:
columns.append(c.name)
if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
read_options = types.ReadSession.TableReadOptions(selected_fields=columns)

requested_session = types.ReadSession(table=table, data_format=types.DataFormat.ARROW, read_options=read_options)
@@ -78,8 +77,9 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
) -> typing.Union[DF, typing.Generator[DF, None, None]]:
return _read_from_bq(flyte_value)
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pd.DataFrame:
return _read_from_bq(flyte_value, current_task_metadata)


class ArrowToBQEncodingHandlers(StructuredDatasetEncoder):
@@ -106,7 +106,8 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
) -> typing.Union[DF, typing.Generator[DF, None, None]]:
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pa.Table:
return pa.Table.from_pandas(_read_from_bq(flyte_value))


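On the BigQuery side, the column subset from current_task_metadata becomes the read session's selected_fields. A standalone sketch of just that construction (the project/dataset/table path is a placeholder and the import path is an assumption; flytekit builds the same objects inside _read_from_bq):

from google.cloud.bigquery_storage_v1 import types

# Suppose the running task annotated its input with a single column, "name".
columns = ["name"]
read_options = types.ReadSession.TableReadOptions(selected_fields=columns)

requested_session = types.ReadSession(
    table="projects/my-project/datasets/my_dataset/tables/my_table",
    data_format=types.DataFormat.ARROW,
    read_options=read_options,
)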
22 changes: 10 additions & 12 deletions flytekit/types/structured/structured_dataset.py
@@ -261,14 +261,17 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: Optional[StructuredDatasetMetadata] = None,
) -> Union[DF, Generator[DF, None, None]]:
"""
This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal
value into a Python instance.
:param ctx:
:param ctx: A FlyteContext, useful in accessing the filesystem and other attributes
:param flyte_value: This will be a Flyte IDL StructuredDataset Literal - do not confuse this with the
StructuredDataset class defined also in this module.
:param current_task_metadata: Metadata of the currently running task, containing column names and types; the decoder uses it to
read specific columns of the parquet file or database table. It might be different from the metadata in the incoming literal.
:return: This function can either return an instance of the dataframe that this decoder handles, or an iterator
of those dataframes.
"""
@@ -596,8 +599,7 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
if column_dict is None or len(column_dict) == 0:
# but if it does, then we just copy it over
if incoming_columns is not None and incoming_columns != []:
for c in incoming_columns:
final_dataset_columns.append(c)
final_dataset_columns = incoming_columns.copy()
# If the current running task's input does have columns defined
else:
final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
@@ -614,7 +616,7 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
# t1(input_a: StructuredDataset) # or
# t1(input_a: Annotated[StructuredDataset, my_cols])
if issubclass(expected_python_type, StructuredDataset):
sd = expected_python_type(
sd = StructuredDataset(
dataframe=None,
# Note here that the type being passed in
metadata=metad,
@@ -634,19 +636,15 @@ def open_as(
updated_metadata: Optional[StructuredDatasetMetadata] = None,
) -> DF:
"""
:param ctx:
:param ctx: A FlyteContext, useful in accessing the filesystem and other attributes
:param sd:
:param df_type:
:param meta: New metadata type, since it might be different from the metadata in the literal.
:return:
:param updated_metadata: New metadata type, since it might be different from the metadata in the literal.
:return: The dataframe; it could be a pandas DataFrame, an Arrow Table, etc.
"""
protocol = protocol_prefix(sd.uri)
decoder = self.get_decoder(df_type, protocol, sd.metadata.structured_dataset_type.format)
# todo: revisit this, we probably should add a new field to the decoder interface
if updated_metadata:
sd._metadata = updated_metadata
result = decoder.decode(ctx, sd)
result = decoder.decode(ctx, sd, updated_metadata)
if isinstance(result, types.GeneratorType):
raise ValueError(f"Decoder {decoder} returned iterator {result} but whole value requested from {sd}")
return result
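For handler authors, the practical consequence is to accept the new argument and prefer it over the literal's metadata when selecting columns. Below is a minimal sketch of a custom decoder written against the revised interface; the LOCAL/PARQUET constants and the register() call are assumptions based on this module, and registering over the built-in pandas handler may require override=True depending on the flytekit version.

import typing

import pandas as pd

from flytekit.core.context_manager import FlyteContext
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.types.structured.structured_dataset import (
    LOCAL,
    PARQUET,
    StructuredDatasetDecoder,
    StructuredDatasetTransformerEngine,
)


class MyParquetToPandasDecoder(StructuredDatasetDecoder):
    def __init__(self):
        super().__init__(pd.DataFrame, LOCAL, PARQUET)

    def decode(
        self,
        ctx: FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
    ) -> pd.DataFrame:
        local_dir = ctx.file_access.get_random_local_directory()
        ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
        # Column subsetting now comes from the running task's metadata rather
        # than from the incoming literal's metadata.
        if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
            return pd.read_parquet(local_dir, columns=columns)
        return pd.read_parquet(local_dir)


# Hypothetical registration; the built-in pandas/parquet decoder already covers
# this combination, hence override=True.
StructuredDatasetTransformerEngine.register(MyParquetToPandasDecoder(), override=True)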
@@ -39,8 +39,12 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> DataFrame:
user_ctx = FlyteContext.current_context().user_space_params
if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
return user_ctx.spark_session.read.parquet(flyte_value.uri)


43 changes: 21 additions & 22 deletions plugins/flytekit-spark/tests/test_wf.py
@@ -6,6 +6,11 @@
from flytekit import kwtypes, task, workflow
from flytekit.types.schema import FlyteSchema

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated


def test_wf1_with_spark():
@task(task_config=Spark())
@@ -53,27 +58,6 @@ def my_wf() -> my_schema:
assert df2 is not None


def test_ddwf1_with_spark():
@task(task_config=Spark())
def my_spark(a: int) -> (int, str):
session = flytekit.current_context().spark_session
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
return a + 2, "world"

@task
def t2(a: str, b: str) -> str:
return b + a

@workflow
def my_wf(a: int, b: str) -> (int, str):
x, y = my_spark(a=a)
d = t2(a=y, b=b)
return x, d

x = my_wf(a=5, b="hello ")
assert x == (7, "hello world")


def test_fs_sd_compatibility():
my_schema = FlyteSchema[kwtypes(name=str, age=int)]

Expand Down Expand Up @@ -108,7 +92,6 @@ def test_spark_dataframe_return():
def my_spark(a: int) -> my_schema:
session = flytekit.current_context().spark_session
df = session.createDataFrame([("Alice", a)], my_schema.column_names())
print(type(df))
return df

@workflow
@@ -120,3 +103,19 @@ def my_wf(a: int) -> my_schema:
df2 = reader.all()
result_df = df2.reset_index(drop=True) == pd.DataFrame(data={"name": ["Alice"], "age": [5]}).reset_index(drop=True)
assert result_df.all().all()


def test_read_spark_subset_columns():
@task
def t1() -> pd.DataFrame:
return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

@task(task_config=Spark())
def t2(df: Annotated[pyspark.sql.DataFrame, kwtypes(Name=str)]) -> int:
return len(df.columns)

@workflow()
def wf() -> int:
return t2(df=t1())

assert wf() == 1
