Revisit StructuredDatasetDecoder interface #865

Merged: 9 commits, Mar 1, 2022
Changes from 2 commits
2 changes: 1 addition & 1 deletion flytekit/models/literals.py
@@ -549,7 +549,7 @@ def from_flyte_idl(cls, pb2_object):


class StructuredDatasetMetadata(_common.FlyteIdlEntity):
def __init__(self, structured_dataset_type: StructuredDatasetType = None):
def __init__(self, structured_dataset_type: StructuredDatasetType):
self._structured_dataset_type = structured_dataset_type

@property
14 changes: 6 additions & 8 deletions flytekit/types/structured/basic_dfs.py
@@ -55,14 +55,13 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pd.DataFrame:
path = flyte_value.uri
local_dir = ctx.file_access.get_random_local_directory()
ctx.file_access.get_data(path, local_dir, is_multipart=True)
if flyte_value.metadata.structured_dataset_type.columns:
columns = []
for c in flyte_value.metadata.structured_dataset_type.columns:
columns.append(c.name)
if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
[Contributor] same here, do we still need "if current_task_metadata"?

columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return pd.read_parquet(local_dir, columns=columns)
return pd.read_parquet(local_dir)

@@ -94,14 +93,13 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pa.Table:
path = flyte_value.uri
local_dir = ctx.file_access.get_random_local_directory()
ctx.file_access.get_data(path, local_dir, is_multipart=True)
if flyte_value.metadata.structured_dataset_type.columns:
columns = []
for c in flyte_value.metadata.structured_dataset_type.columns:
columns.append(c.name)
if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return pq.read_table(local_dir, columns=columns)
return pq.read_table(local_dir)

19 changes: 10 additions & 9 deletions flytekit/types/structured/bigquery.py
@@ -11,7 +11,6 @@
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
BIGQUERY,
DF,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
@@ -29,18 +28,18 @@ def _write_to_bq(structured_dataset: StructuredDataset):
client.load_table_from_dataframe(df, table_id)


def _read_from_bq(flyte_value: literals.StructuredDataset) -> pd.DataFrame:
def _read_from_bq(
flyte_value: literals.StructuredDataset, current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None
) -> pd.DataFrame:
path = flyte_value.uri
_, project_id, dataset_id, table_id = re.split("\\.|://|:", path)
client = bigquery_storage.BigQueryReadClient()
table = f"projects/{project_id}/datasets/{dataset_id}/tables/{table_id}"
parent = "projects/{}".format(project_id)

read_options = None
if flyte_value.metadata.structured_dataset_type.columns:
columns = []
for c in flyte_value.metadata.structured_dataset_type.columns:
columns.append(c.name)
if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
read_options = types.ReadSession.TableReadOptions(selected_fields=columns)

requested_session = types.ReadSession(table=table, data_format=types.DataFormat.ARROW, read_options=read_options)
@@ -78,8 +77,9 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
) -> typing.Union[DF, typing.Generator[DF, None, None]]:
return _read_from_bq(flyte_value)
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pd.DataFrame:
return _read_from_bq(flyte_value, current_task_metadata)


class ArrowToBQEncodingHandlers(StructuredDatasetEncoder):
Expand All @@ -106,7 +106,8 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
) -> typing.Union[DF, typing.Generator[DF, None, None]]:
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pa.Table:
return pa.Table.from_pandas(_read_from_bq(flyte_value, current_task_metadata))


22 changes: 10 additions & 12 deletions flytekit/types/structured/structured_dataset.py
@@ -261,14 +261,17 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: Optional[StructuredDatasetMetadata] = None,
) -> Union[DF, Generator[DF, None, None]]:
"""
This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal
value into a Python instance.

:param ctx:
:param ctx: A FlyteContext, useful in accessing the filesystem and other attributes
:param flyte_value: This will be a Flyte IDL StructuredDataset Literal - do not confuse this with the
StructuredDataset class defined also in this module.
:param current_task_metadata: Metadata containing column names and types, which the decoder uses to
read specific columns from a parquet file or database table. It might be different from the metadata in the incoming literal.
:return: This function can either return an instance of the dataframe that this decoder handles, or an iterator
of those dataframes.
"""
@@ -596,8 +599,7 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
if column_dict is None or len(column_dict) == 0:
# but if it does, then we just copy it over
if incoming_columns is not None and incoming_columns != []:
for c in incoming_columns:
final_dataset_columns.append(c)
final_dataset_columns = incoming_columns.copy()
# If the current running task's input does have columns defined
else:
final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
@@ -631,22 +633,18 @@ def open_as(
ctx: FlyteContext,
sd: literals.StructuredDataset,
df_type: Type[DF],
updated_metadata: Optional[StructuredDatasetMetadata] = None,
metadata: Optional[StructuredDatasetMetadata] = None,
) -> DF:
"""

:param ctx:
:param ctx: A FlyteContext, useful in accessing the filesystem and other attributes
:param sd:
:param df_type:
:param meta: New metadata type, since it might be different from the metadata in the literal.
:return:
:param metadata: New metadata type, since it might be different from the metadata in the literal.
:return: The dataframe, which could be a pandas DataFrame, an Arrow table, etc.
"""
protocol = protocol_prefix(sd.uri)
decoder = self.get_decoder(df_type, protocol, sd.metadata.structured_dataset_type.format)
# todo: revisit this, we probably should add a new field to the decoder interface
if updated_metadata:
sd._metadata = updated_metadata
result = decoder.decode(ctx, sd)
result = decoder.decode(ctx, sd, metadata)
if isinstance(result, types.GeneratorType):
raise ValueError(f"Decoder {decoder} returned iterator {result} but whole value requested from {sd}")
return result
@@ -47,17 +47,16 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pa.Table:
uri = flyte_value.uri
if not ctx.file_access.is_remote(uri):
Path(uri).parent.mkdir(parents=True, exist_ok=True)
_, path = split_protocol(uri)

columns = None
if flyte_value.metadata.structured_dataset_type.columns:
columns = []
for c in flyte_value.metadata.structured_dataset_type.columns:
columns.append(c.name)
if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
[Contributor] this is no longer optional right? Can we remove the first if check?

[Member, Author] yeah, updated it to:

if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns

columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
try:
fs = FSSpecPersistence.get_filesystem(uri)
return pq.read_table(path, filesystem=fs, columns=columns)
@@ -58,14 +58,13 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pd.DataFrame:
uri = flyte_value.uri
columns = None
kwargs = get_storage_options(uri)
if flyte_value.metadata.structured_dataset_type.columns:
columns = []
for c in flyte_value.metadata.structured_dataset_type.columns:
columns.append(c.name)
if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
try:
return pd.read_parquet(uri, columns=columns, storage_options=kwargs)
except NoCredentialsError:
@@ -39,8 +39,12 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> DataFrame:
user_ctx = FlyteContext.current_context().user_space_params
if current_task_metadata and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
return user_ctx.spark_session.read.parquet(flyte_value.uri)


43 changes: 21 additions & 22 deletions plugins/flytekit-spark/tests/test_wf.py
@@ -6,6 +6,11 @@
from flytekit import kwtypes, task, workflow
from flytekit.types.schema import FlyteSchema

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated


def test_wf1_with_spark():
@task(task_config=Spark())
@@ -53,27 +58,6 @@ def my_wf() -> my_schema:
assert df2 is not None


def test_ddwf1_with_spark():
[Contributor] was this test bad for some reason?

[Member, Author] We have the same test here:

def test_wf1_with_spark():
    @task(task_config=Spark())
    def my_spark(a: int) -> (int, str):
        session = flytekit.current_context().spark_session
        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
        return a + 2, "world"

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = my_spark(a=a)
        d = t2(a=y, b=b)
        return x, d

    x = my_wf(a=5, b="hello ")
    assert x == (7, "hello world")

@task(task_config=Spark())
def my_spark(a: int) -> (int, str):
session = flytekit.current_context().spark_session
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
return a + 2, "world"

@task
def t2(a: str, b: str) -> str:
return b + a

@workflow
def my_wf(a: int, b: str) -> (int, str):
x, y = my_spark(a=a)
d = t2(a=y, b=b)
return x, d

x = my_wf(a=5, b="hello ")
assert x == (7, "hello world")


def test_fs_sd_compatibility():
my_schema = FlyteSchema[kwtypes(name=str, age=int)]

@@ -108,7 +92,6 @@ def test_spark_dataframe_return():
def my_spark(a: int) -> my_schema:
session = flytekit.current_context().spark_session
df = session.createDataFrame([("Alice", a)], my_schema.column_names())
print(type(df))
return df

@workflow
@@ -120,3 +103,19 @@ def my_wf(a: int) -> my_schema:
df2 = reader.all()
result_df = df2.reset_index(drop=True) == pd.DataFrame(data={"name": ["Alice"], "age": [5]}).reset_index(drop=True)
assert result_df.all().all()


def test_read_spark_subset_columns():
@task
def t1() -> pd.DataFrame:
return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

@task(task_config=Spark())
def t2(df: Annotated[pyspark.sql.DataFrame, kwtypes(Name=str)]) -> int:
return len(df.columns)

@workflow()
def wf() -> int:
return t2(df=t1())

assert wf() == 1
4 changes: 3 additions & 1 deletion tests/flytekit/unit/core/test_structured_dataset.py
@@ -221,6 +221,7 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> typing.Union[typing.Generator[pd.DataFrame, None, None]]:
yield pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

@@ -241,8 +242,9 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pd.DataFrame:
pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

StructuredDatasetTransformerEngine.register(
MockPandasDecodingHandlers(pd.DataFrame, "tmpfs"), default_for_type=False, override=True
@@ -54,6 +54,7 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> pd.DataFrame:
return pd_df

@@ -86,6 +87,7 @@ def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: typing.Optional[StructuredDatasetMetadata] = None,
) -> typing.Union[DF, typing.Generator[DF, None, None]]:
path = flyte_value.uri
local_dir = ctx.file_access.get_random_local_directory()