Skip to content

Commit

Permalink
add ImageRender class for Flyte Decks (#1599)
Browse files Browse the repository at this point in the history
Signed-off-by: esad <[email protected]>
Signed-off-by: Fabio Graetz <[email protected]>
  • Loading branch information
peridotml authored and fg91 committed Apr 23, 2023
1 parent 9c9fb5d commit c7d69fa
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
42 changes: 38 additions & 4 deletions plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import base64
from io import BytesIO
from typing import Union

import markdown
import pandas
import pandas as pd
import plotly.express as px
from PIL import Image
from ydata_profiling import ProfileReport

from flytekit.types.file import FlyteFile


class FrameProfilingRenderer:
"""
Expand All @@ -12,8 +19,8 @@ class FrameProfilingRenderer:
def __init__(self, title: str = "Pandas Profiling Report"):
self._title = title

def to_html(self, df: pandas.DataFrame) -> str:
assert isinstance(df, pandas.DataFrame)
def to_html(self, df: pd.DataFrame) -> str:
assert isinstance(df, pd.DataFrame)
profile = ProfileReport(df, title=self._title)
return profile.to_html()

Expand Down Expand Up @@ -45,6 +52,33 @@ class BoxRenderer:
def __init__(self, column_name):
self._column_name = column_name

def to_html(self, df: pandas.DataFrame) -> str:
def to_html(self, df: pd.DataFrame) -> str:
fig = px.box(df, y=self._column_name)
return fig.to_html()


class ImageRenderer:
"""Converts a FlyteFile or PIL.Image.Image object to an HTML string with the image data
represented as a base64-encoded string.
"""

def to_html(cls, image_src: Union[FlyteFile, Image.Image]) -> str:
img = cls._get_image_object(image_src)
return cls._image_to_html_string(img)

@staticmethod
def _get_image_object(image_src: Union[FlyteFile, Image.Image]) -> Image.Image:
if isinstance(image_src, FlyteFile):
local_path = image_src.download()
return Image.open(local_path)
elif isinstance(image_src, Image.Image):
return image_src
else:
raise ValueError("Unsupported image source type")

@staticmethod
def _image_to_html_string(img: Image.Image) -> str:
buffered = BytesIO()
img.save(buffered, format="PNG")
img_base64 = base64.b64encode(buffered.getvalue()).decode()
return f'<img src="data:image/png;base64,{img_base64}" alt="Rendered Image" />'
34 changes: 33 additions & 1 deletion plugins/flytekit-deck-standard/tests/test_renderer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import tempfile

import markdown
import pandas as pd
from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, MarkdownRenderer
import pytest
from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, ImageRenderer, MarkdownRenderer
from PIL import Image

from flytekit.types.file import FlyteFile, JPEGImageFile, PNGImageFile

df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]})

Expand All @@ -19,3 +25,29 @@ def test_markdown_renderer():
def test_box_renderer():
renderer = BoxRenderer("Name")
assert "Plotlyconfig = {Mathjaxconfig: 'Local'}" in renderer.to_html(df).title()


def create_simple_image(fmt: str):
"""Create a simple PNG image using PIL"""
img = Image.new("RGB", (100, 100), color="black")
tmp = tempfile.mktemp()
img.save(tmp, fmt)
return tmp


png_image = create_simple_image(fmt="png")
jpeg_image = create_simple_image(fmt="jpeg")


@pytest.mark.parametrize(
"image_src",
[
FlyteFile(path=png_image),
JPEGImageFile(path=jpeg_image),
PNGImageFile(path=png_image),
Image.open(png_image),
],
)
def test_image_renderer(image_src):
renderer = ImageRenderer()
assert "<img" in renderer.to_html(image_src)

0 comments on commit c7d69fa

Please sign in to comment.