Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register dataframe renderers in structured dataset #1140

Merged
merged 6 commits into from
Aug 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions flytekit/deck/renderer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Optional

import pandas
import pyarrow
from typing_extensions import Protocol, runtime_checkable


@@ -24,3 +25,13 @@ def __init__(self, max_rows: Optional[int] = None):
def to_html(self, df: pandas.DataFrame) -> str:
assert isinstance(df, pandas.DataFrame)
return df.to_html(max_rows=self._max_rows)


class ArrowRenderer:
    """
    Renderer for Arrow tables.

    Note that the result is whatever ``pyarrow.Table.to_string()`` returns,
    not HTML table markup.
    """

    def to_html(self, df: pyarrow.Table) -> str:
        """Return the ``to_string()`` form of ``df``, which must be a ``pyarrow.Table``."""
        assert isinstance(df, pyarrow.Table)
        rendered = df.to_string()
        return rendered
5 changes: 5 additions & 0 deletions flytekit/types/structured/basic_dfs.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,8 @@
import pyarrow.parquet as pq

from flytekit import FlyteContext
from flytekit.deck import TopFrameRenderer
from flytekit.deck.renderer import ArrowRenderer
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
@@ -103,3 +105,6 @@ def decode(
StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler())
StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler())

StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer())
StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a quick unit test verifying that these are loaded into the Renderers dict by default?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did. here

31 changes: 10 additions & 21 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import collections
import importlib
import os
import types
import typing
@@ -13,20 +12,14 @@
import numpy as _np
import pandas as pd
import pyarrow as pa

from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence

if importlib.util.find_spec("pyspark") is not None:
import pyspark
if importlib.util.find_spec("polars") is not None:
import polars as pl

from dataclasses_json import config, dataclass_json
from marshmallow import fields
from typing_extensions import Annotated, TypeAlias, get_args, get_origin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.deck.renderer import Renderable
from flytekit.loggers import logger
from flytekit.models import literals
from flytekit.models import types as type_models
@@ -339,6 +332,7 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):
DEFAULT_FORMATS: Dict[Type, str] = {}

Handlers = Union[StructuredDatasetEncoder, StructuredDatasetDecoder]
Renderers: Dict[Type, Renderable] = {}

@staticmethod
def _finder(handler_map, df_type: Type, protocol: str, format: str):
@@ -385,6 +379,10 @@ def __init__(self):
# Instances of StructuredDataset opt-in to the ability of being cached.
self._hash_overridable = True

@classmethod
def register_renderer(cls, python_type: Type, renderer: Renderable):
    """
    Associate ``renderer`` with ``python_type`` in the class-level ``Renderers`` mapping.

    Any renderer already stored under ``python_type`` is replaced.
    """
    registry = cls.Renderers
    registry[python_type] = renderer

@classmethod
def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False):
"""
@@ -698,19 +696,10 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ
else:
df = python_val

if isinstance(df, pd.DataFrame):
return df.describe().to_html()
elif isinstance(df, pa.Table):
return df.to_string()
elif isinstance(df, _np.ndarray):
return pd.DataFrame(df).describe().to_html()
elif importlib.util.find_spec("pyspark") is not None and isinstance(df, pyspark.sql.DataFrame):
return pd.DataFrame(df.schema, columns=["StructField"]).to_html()
elif importlib.util.find_spec("polars") is not None and isinstance(df, pl.DataFrame):
describe_df = df.describe()
return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False)
if type(df) in self.Renderers:
return self.Renderers[type(df)].to_html(df)
else:
raise NotImplementedError("Conversion to html string should be implemented")
raise NotImplementedError(f"Could not find a renderer for {type(df)} in {self.Renderers}")

def open_as(
self,
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing

import pandas as pd
import polars as pl

from flytekit import FlyteContext
@@ -15,6 +16,17 @@
)


class PolarsDataFrameRenderer:
    """
    Renders the ``describe()`` summary of a Polars DataFrame as an HTML table.
    """

    def to_html(self, df: pl.DataFrame) -> str:
        """Return the HTML for the ``describe()`` output of ``df``, which must be a ``pl.DataFrame``."""
        assert isinstance(df, pl.DataFrame)
        summary = df.describe()
        # Rebuild the summary as a pandas frame so pandas' HTML writer can be used.
        pandas_summary = pd.DataFrame(summary.transpose(), columns=summary.columns)
        return pandas_summary.to_html(index=False)


class PolarsDataFrameToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self):
super().__init__(pl.DataFrame, None, PARQUET)
@@ -61,3 +73,4 @@ def decode(

StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToPolarsDataFrameDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(pl.DataFrame, PolarsDataFrameRenderer())
9 changes: 9 additions & 0 deletions plugins/flytekit-polars/tests/test_polars_plugin_sd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pandas as pd
import polars as pl
from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer
from typing_extensions import Annotated

from flytekit import kwtypes, task, workflow
@@ -57,3 +59,10 @@ def wf() -> full_schema:

result = wf()
assert result is not None


def test_polars_renderer():
    """The renderer's output equals the HTML built by hand from ``df.describe()``."""
    df = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")})
    summary = df.describe()
    expected_html = pd.DataFrame(summary.transpose(), columns=summary.columns).to_html(index=False)
    assert PolarsDataFrameRenderer().to_html(df) == expected_html
12 changes: 12 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing

import pandas as pd
from pyspark.sql.dataframe import DataFrame

from flytekit import FlyteContext
@@ -15,6 +16,16 @@
)


class SparkDataFrameRenderer:
    """
    Renders the schema of a Spark DataFrame as an HTML table with a single
    ``StructField`` column.
    """

    def to_html(self, df: DataFrame) -> str:
        """Return the HTML for the schema of ``df``, which must be a Spark ``DataFrame``."""
        assert isinstance(df, DataFrame)
        schema_frame = pd.DataFrame(df.schema, columns=["StructField"])
        return schema_frame.to_html()


class SparkToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self):
super().__init__(DataFrame, None, PARQUET)
@@ -50,3 +61,4 @@ def decode(

StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer())
15 changes: 15 additions & 0 deletions tests/flytekit/unit/core/test_structured_dataset.py
Original file line number Diff line number Diff line change
@@ -414,3 +414,18 @@ def test_protocol_detection():

protocol = e._protocol_from_type_or_prefix(ctx2, pd.DataFrame, "bq://foo")
assert protocol == "bq"


def test_register_renderers():
    """
    Check renderer registration on StructuredDatasetTransformerEngine.

    Covers three things: a custom renderer can be registered and looked up,
    the pandas and Arrow renderers are present in the registry, and asking for
    HTML of a type without a renderer raises NotImplementedError.
    """

    class DummyRenderer:
        def to_html(self, value: str) -> str:
            return "hello " + value

    renderers = StructuredDatasetTransformerEngine.Renderers
    # Renderers is a class-level dict shared by every test in the process, so
    # remember what was stored under ``str`` and put it back afterwards.
    had_previous = str in renderers
    previous = renderers.get(str)
    StructuredDatasetTransformerEngine.register_renderer(str, DummyRenderer())
    try:
        assert renderers[str].to_html("flyte") == "hello flyte"
        assert pd.DataFrame in renderers
        assert pa.Table in renderers
    finally:
        if had_previous:
            renderers[str] = previous
        else:
            renderers.pop(str, None)

    with pytest.raises(NotImplementedError, match="Could not find a renderer for <class 'int'> in"):
        StructuredDatasetTransformerEngine().to_html(FlyteContextManager.current_context(), 3, int)
11 changes: 7 additions & 4 deletions tests/flytekit/unit/deck/test_renderer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import pandas as pd
import pyarrow as pa

from flytekit.deck.renderer import TopFrameRenderer
from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer


def test_frame_profiling_renderer():
def test_renderer():
df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]})
renderer = TopFrameRenderer()
assert renderer.to_html(df) == df.to_html()
pa_df = pa.Table.from_pandas(df)

assert TopFrameRenderer().to_html(df) == df.to_html()
assert ArrowRenderer().to_html(pa_df) == pa_df.to_string()
Original file line number Diff line number Diff line change
@@ -68,6 +68,15 @@ def decode(
StructuredDatasetTransformerEngine.register(MockBQDecodingHandlers(), False, True)


class NumpyRenderer:
    """
    Renders summary statistics of a NumPy array as an HTML table.

    The array is wrapped in a pandas DataFrame and the ``describe()`` output
    of that frame is what gets converted to HTML.
    """

    def to_html(self, array: np.ndarray) -> str:
        """Return the HTML for ``pd.DataFrame(array).describe()``."""
        frame = pd.DataFrame(array)
        summary = frame.describe()
        return summary.to_html()


@pytest.fixture(autouse=True)
def numpy_type():
class NumpyEncodingHandlers(StructuredDatasetEncoder):
@@ -101,9 +110,9 @@ def decode(
table = pq.read_table(local_dir)
return table.to_pandas().to_numpy()

for protocol in ["/", "s3"]:
StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray, protocol, PARQUET))
StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray, protocol, PARQUET))
StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray))
StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray))
StructuredDatasetTransformerEngine.register_renderer(np.ndarray, NumpyRenderer())


@task