diff --git a/flytekit/deck/renderer.py b/flytekit/deck/renderer.py index 8617ae4d12..0cf781d3da 100644 --- a/flytekit/deck/renderer.py +++ b/flytekit/deck/renderer.py @@ -1,6 +1,7 @@ from typing import Any, Optional import pandas +import pyarrow from typing_extensions import Protocol, runtime_checkable @@ -24,3 +25,13 @@ def __init__(self, max_rows: Optional[int] = None): def to_html(self, df: pandas.DataFrame) -> str: assert isinstance(df, pandas.DataFrame) return df.to_html(max_rows=self._max_rows) + + +class ArrowRenderer: + """ + Render a Arrow dataframe as an HTML table. + """ + + def to_html(self, df: pyarrow.Table) -> str: + assert isinstance(df, pyarrow.Table) + return df.to_string() diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 97964d0b63..71dff61c5e 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -7,6 +7,8 @@ import pyarrow.parquet as pq from flytekit import FlyteContext +from flytekit.deck import TopFrameRenderer +from flytekit.deck.renderer import ArrowRenderer from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType @@ -103,3 +105,6 @@ def decode( StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler()) StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler()) StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler()) + +StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer()) +StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer()) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index bfbc494bbb..bdad752b16 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -1,7 +1,6 @@ from __future__ import annotations import collections -import importlib import os import types import typing @@ -13,20 +12,14 @@ import numpy as _np import pandas as pd import pyarrow as pa - -from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence - -if importlib.util.find_spec("pyspark") is not None: - import pyspark -if importlib.util.find_spec("polars") is not None: - import polars as pl - from dataclasses_json import config, dataclass_json from marshmallow import fields from typing_extensions import Annotated, TypeAlias, get_args, get_origin from flytekit.core.context_manager import FlyteContext, FlyteContextManager +from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence from flytekit.core.type_engine import TypeEngine, TypeTransformer +from flytekit.deck.renderer import Renderable from flytekit.loggers import logger from flytekit.models import literals from flytekit.models import types as type_models @@ -339,6 +332,7 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): DEFAULT_FORMATS: Dict[Type, str] = {} Handlers = Union[StructuredDatasetEncoder, StructuredDatasetDecoder] + Renderers: Dict[Type, Renderable] = {} @staticmethod def _finder(handler_map, df_type: Type, protocol: str, format: str): @@ -385,6 +379,10 @@ def __init__(self): # Instances of StructuredDataset opt-in to the ability of being cached. self._hash_overridable = True + @classmethod + def register_renderer(cls, python_type: Type, renderer: Renderable): + cls.Renderers[python_type] = renderer + @classmethod def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False): """ @@ -698,19 +696,10 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ else: df = python_val - if isinstance(df, pd.DataFrame): - return df.describe().to_html() - elif isinstance(df, pa.Table): - return df.to_string() - elif isinstance(df, _np.ndarray): - return pd.DataFrame(df).describe().to_html() - elif importlib.util.find_spec("pyspark") is not None and isinstance(df, pyspark.sql.DataFrame): - return pd.DataFrame(df.schema, columns=["StructField"]).to_html() - elif importlib.util.find_spec("polars") is not None and isinstance(df, pl.DataFrame): - describe_df = df.describe() - return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False) + if type(df) in self.Renderers: + return self.Renderers[type(df)].to_html(df) else: - raise NotImplementedError("Conversion to html string should be implemented") + raise NotImplementedError(f"Could not find a renderer for {type(df)} in {self.Renderers}") def open_as( self, diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 6388bc4c9e..0dfd0c6516 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -1,5 +1,6 @@ import typing +import pandas as pd import polars as pl from flytekit import FlyteContext @@ -15,6 +16,17 @@ ) +class PolarsDataFrameRenderer: + """ + The Polars DataFrame summary statistics are rendered as an HTML table. + """ + + def to_html(self, df: pl.DataFrame) -> str: + assert isinstance(df, pl.DataFrame) + describe_df = df.describe() + return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False) + + class PolarsDataFrameToParquetEncodingHandler(StructuredDatasetEncoder): def __init__(self): super().__init__(pl.DataFrame, None, PARQUET) @@ -61,3 +73,4 @@ def decode( StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler()) StructuredDatasetTransformerEngine.register(ParquetToPolarsDataFrameDecodingHandler()) +StructuredDatasetTransformerEngine.register_renderer(pl.DataFrame, PolarsDataFrameRenderer()) diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index 94eb0bc735..b991cd5d13 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -1,4 +1,6 @@ +import pandas as pd import polars as pl +from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer from typing_extensions import Annotated from flytekit import kwtypes, task, workflow @@ -57,3 +59,10 @@ def wf() -> full_schema: result = wf() assert result is not None + + +def test_polars_renderer(): + df = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) + assert PolarsDataFrameRenderer().to_html(df) == pd.DataFrame( + df.describe().transpose(), columns=df.describe().columns + ).to_html(index=False) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py index 1a89b7b331..46079f40dd 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py @@ -1,5 +1,6 @@ import typing +import pandas as pd from pyspark.sql.dataframe import DataFrame from flytekit import FlyteContext @@ -15,6 +16,16 @@ ) +class SparkDataFrameRenderer: + """ + Render a Spark dataframe schema as an HTML table. + """ + + def to_html(self, df: DataFrame) -> str: + assert isinstance(df, DataFrame) + return pd.DataFrame(df.schema, columns=["StructField"]).to_html() + + class SparkToParquetEncodingHandler(StructuredDatasetEncoder): def __init__(self): super().__init__(DataFrame, None, PARQUET) @@ -50,3 +61,4 @@ def decode( StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler()) StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler()) +StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer()) diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index d78d129309..7793df430f 100644 --- a/tests/flytekit/unit/core/test_structured_dataset.py +++ b/tests/flytekit/unit/core/test_structured_dataset.py @@ -414,3 +414,18 @@ def test_protocol_detection(): protocol = e._protocol_from_type_or_prefix(ctx2, pd.DataFrame, "bq://foo") assert protocol == "bq" + + +def test_register_renderers(): + class DummyRenderer: + def to_html(self, input: str) -> str: + return "hello " + input + + renderers = StructuredDatasetTransformerEngine.Renderers + StructuredDatasetTransformerEngine.register_renderer(str, DummyRenderer()) + assert renderers[str].to_html("flyte") == "hello flyte" + assert pd.DataFrame in renderers + assert pa.Table in renderers + + with pytest.raises(NotImplementedError, match="Could not find a renderer for in"): + StructuredDatasetTransformerEngine().to_html(FlyteContextManager.current_context(), 3, int) diff --git a/tests/flytekit/unit/deck/test_renderer.py b/tests/flytekit/unit/deck/test_renderer.py index f1ebbcd873..3f597af416 100644 --- a/tests/flytekit/unit/deck/test_renderer.py +++ b/tests/flytekit/unit/deck/test_renderer.py @@ -1,9 +1,12 @@ import pandas as pd +import pyarrow as pa -from flytekit.deck.renderer import TopFrameRenderer +from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer -def test_frame_profiling_renderer(): +def test_renderer(): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]}) - renderer = TopFrameRenderer() - assert renderer.to_html(df) == df.to_html() + pa_df = pa.Table.from_pandas(df) + + assert TopFrameRenderer().to_html(df) == df.to_html() + assert ArrowRenderer().to_html(pa_df) == pa_df.to_string() diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py index c2394d7a7a..c849995b96 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py @@ -68,6 +68,15 @@ def decode( StructuredDatasetTransformerEngine.register(MockBQDecodingHandlers(), False, True) +class NumpyRenderer: + """ + The Polars DataFrame summary statistics are rendered as an HTML table. + """ + + def to_html(self, array: np.ndarray) -> str: + return pd.DataFrame(array).describe().to_html() + + @pytest.fixture(autouse=True) def numpy_type(): class NumpyEncodingHandlers(StructuredDatasetEncoder): @@ -101,9 +110,9 @@ def decode( table = pq.read_table(local_dir) return table.to_pandas().to_numpy() - for protocol in ["/", "s3"]: - StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray, protocol, PARQUET)) - StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray, protocol, PARQUET)) + StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray)) + StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray)) + StructuredDatasetTransformerEngine.register_renderer(np.ndarray, NumpyRenderer()) @task