Skip to content

Commit

Permalink
Hide partitions from users (#677)
Browse files Browse the repository at this point in the history
PR that hides the partitions from the users. Users can now scroll
through the dataset through a next/previous buttons.

In addition, a small changes was made to the statistics pages.
Previously it had a mix of both local and global statistics which made
it confusing. Now it's only restricted to global statistics.


https://github.com/ml6team/fondant/assets/47530815/05a32c54-7fbe-45c3-b7ba-d13e976674e2
  • Loading branch information
PhilippeMoussalli authored Nov 28, 2023
1 parent 3eeb815 commit 9d1abed
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 114 deletions.
3 changes: 3 additions & 0 deletions data_explorer/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@
"selected_component_path",
"partition",
]

ROWS_TO_RETURN = 20
ROWS_PER_PAGE = 10
167 changes: 136 additions & 31 deletions data_explorer/app/interfaces/dataset_interface.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Dataset interface for the data explorer app."""

import os
import typing as t

import dask.dataframe as dd
import pandas as pd
import streamlit as st
from config import ROWS_TO_RETURN
from fondant.core.manifest import Manifest
from interfaces.common_interface import MainInterface
from interfaces.utils import get_default_index
Expand Down Expand Up @@ -84,51 +85,155 @@ def _get_subset_path(self, manifest, subset):

return subset_path

@staticmethod
@st.cache_data
def _load_dask_dataframe(subset_path, fields):
return dd.read_parquet(subset_path, columns=list(fields.keys()))
def load_dask_dataframe(self):
"""Loads a Dask DataFrame from a subset of the dataset."""
manifest, subset, fields = self._get_manifest_fields_and_subset()
subset_path = self._get_subset_path(manifest, subset)
dask_df = dd.read_parquet(subset_path, columns=list(fields.keys())).reset_index(
drop=False,
)

return dask_df, fields

# TODO: change later to accept range of partitions
@staticmethod
def _get_partition_to_load(dask_df: dd.DataFrame) -> t.Union[int, None]:
"""Get the partition of the dataframe to load from a slider."""
partition = None
def get_pandas_from_dask(
dask_df: dd.DataFrame,
rows_to_return: int,
partition_index: int,
rows_from_last_partition: int,
):
"""
Converts a Dask DataFrame into a Pandas DataFrame with specified number of rows.
if dask_df.npartitions > 1:
if st.session_state["partition"] is None:
starting_value = 0
Args:
dask_df: Input Dask DataFrame.
rows_to_return: Number of rows needed in the resulting Pandas DataFrame.
partition_index: Index of the partition to start from.
rows_from_last_partition: Number of rows to take from the last partition.
Returns:
result_df: Pandas DataFrame with the specified number of rows.
last_partition_index: Index of the last used partition.
rows_from_last_partition: Number of rows taken from the last partition.
"""
rows_returned = 0
data_to_return = []

dask_df = dask_df.partitions[partition_index:]

for partition_index, partition in enumerate(
dask_df.partitions,
start=partition_index,
):
# Materialize partition as a pandas DataFrame
partition_df = partition.compute()
partition_df = partition_df[rows_from_last_partition:]

# Check if adding this partition exceeds the required rows
if rows_returned + len(partition_df) <= rows_to_return:
data_to_return.append(partition_df)
rows_returned += len(partition_df)
else:
starting_value = st.session_state["partition"]
# Calculate how many rows to take from this partition
rows_from_last_partition = rows_to_return - rows_returned
partition_df = partition_df.head(rows_from_last_partition)
data_to_return.append(partition_df)
break

partition = st.slider("partition", 1, dask_df.npartitions, starting_value)
st.session_state["partition"] = partition
# Concatenate the selected partitions into a single pandas DataFrame
df = pd.concat(data_to_return)

return partition
return df, partition_index, rows_from_last_partition

@staticmethod
def _get_dataframe_partition(
dask_df: dd.DataFrame,
partition: t.Union[int, None],
) -> dd.DataFrame:
"""Get the partition of the dataframe to load."""
if partition is not None:
return dask_df.get_partition(partition)
def _initialize_page_view_dict():
return st.session_state.get(
"page_view_dict",
{
0: {
"start_index": 0,
"start_partition": 0,
},
},
)

return dask_df
@staticmethod
def _update_page_view_dict(
page_view_dict,
page_index,
start_index,
start_partition,
):
page_view_dict[page_index] = {
"start_index": start_index,
"start_partition": start_partition,
}
st.session_state["page_view_dict"] = page_view_dict
return page_view_dict

def create_loader_widget(self):
def load_pandas_dataframe(self):
"""
Provides common widgets for loading a dataframe and selecting a partition to load. Uses
Cached dataframes to avoid reloading the dataframe when changing the partition.
Returns:
Dataframe and fields
"""
manifest, subset, fields = self._get_manifest_fields_and_subset()
subset_path = self._get_subset_path(manifest, subset)
df = self._load_dask_dataframe(subset_path, fields)
partition = self._get_partition_to_load(df)
dataframe = self._get_dataframe_partition(df, partition)
previous_button_disabled = True
next_button_disabled = False

# Get the manifest, subset, and fields
dask_df, fields = self.load_dask_dataframe()

# Initialize page view dict if it doesn't exist
page_index = st.session_state.get("page_index", 0)
page_view_dict = self._initialize_page_view_dict()

# Get the starting index and partition for the current page
start_index = page_view_dict[page_index]["start_index"]
start_partition = page_view_dict[page_index]["start_partition"]

pandas_df, next_partition, next_index = self.get_pandas_from_dask(
dask_df,
ROWS_TO_RETURN,
start_index,
start_partition,
)
self._update_page_view_dict(
page_view_dict,
page_index + 1,
next_index,
next_partition,
)

st.info(
f"Showing {len(pandas_df)} rows. Click on the 'next' and 'previous' "
f"buttons to navigate through the dataset.",
)
previous_col, _, next_col = st.columns([0.2, 0.6, 0.2])

if page_index != 0:
previous_button_disabled = False

if len(pandas_df) < ROWS_TO_RETURN:
next_button_disabled = True

if previous_col.button(
"⏮️ Previous",
use_container_width=True,
disabled=previous_button_disabled,
):
st.session_state["page_index"] = page_index - 1
st.rerun()

if next_col.button(
"Next ⏭️",
use_container_width=True,
disabled=next_button_disabled,
):
st.session_state["page_index"] = page_index + 1
st.rerun()

st.markdown(f"Page {page_index}")

return dataframe, fields
return pandas_df, fields
31 changes: 6 additions & 25 deletions data_explorer/app/pages/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import typing as t

import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
from bs4 import BeautifulSoup
Expand Down Expand Up @@ -43,31 +44,17 @@ def create_pdf_from_text(raw_text):

class DatasetExplorerApp(DatasetLoaderApp):
@staticmethod
def setup_app_page(dataframe, fields) -> AgGridReturn:
def setup_app_page(dataframe: pd.DataFrame, fields) -> AgGridReturn:
"""Build the dataframe explorer table."""
image_fields = get_image_fields(fields)

# get the first rows of the dataframe
cols = st.columns(2)
with cols[0]:
rows = st.slider(
"Dataframe rows to load",
1,
len(dataframe),
min(len(dataframe), 20),
)
with cols[1]:
rows_per_page = st.slider("Amount of rows per page", 5, 50, 10)

dataframe_explorer = dataframe.head(rows).reset_index(drop=False)

for field in image_fields:
dataframe_explorer = convert_image_column(dataframe_explorer, field)
dataframe = convert_image_column(dataframe, field)

# TODO: add formatting for other datatypes?

# configure builder
options_builder = GridOptionsBuilder.from_dataframe(dataframe_explorer)
options_builder = GridOptionsBuilder.from_dataframe(dataframe)

# Add tooltip hover for all fields
for field in fields:
Expand All @@ -91,12 +78,6 @@ def setup_app_page(dataframe, fields) -> AgGridReturn:
for field in image_fields:
configure_image_builder(options_builder, field)

# configure pagination and sidebar
options_builder.configure_pagination(
enabled=True,
paginationPageSize=rows_per_page,
paginationAutoPageSize=False,
)
options_builder.configure_default_column(
editable=False,
groupable=True,
Expand All @@ -113,7 +94,7 @@ def setup_app_page(dataframe, fields) -> AgGridReturn:

# display the Ag Grid table
return AgGrid(
dataframe_explorer,
dataframe,
gridOptions=options_builder.build(),
allow_unsafe_jscode=True,
columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW,
Expand Down Expand Up @@ -162,6 +143,6 @@ def setup_viewer_widget(self, grid_dict: AgGridReturn, fields: t.Dict[str, t.Any

app = DatasetExplorerApp()
app.create_common_interface()
df, df_fields = app.create_loader_widget()
df, df_fields = app.load_pandas_dataframe()
grid_data_dict = app.setup_app_page(df, df_fields)
app.setup_viewer_widget(grid_data_dict, df_fields)
8 changes: 3 additions & 5 deletions data_explorer/app/pages/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@ def setup_app_page(dataframe, fields):

image_field = st.selectbox("Image field", image_fields)

images = dataframe[image_field].compute()
images = dataframe[image_field]
images = [Image.open(io.BytesIO(x)).resize((256, 256)) for x in images]

image_slider = st.slider("image range", 0, len(images), (0, 10))

# show images in a gallery
cols = st.columns(5)
for i, image in enumerate(images[image_slider[0] : image_slider[1]]):
for i, image in enumerate(images):
cols[i % 5].image(image, use_column_width=True)


app = ImageGalleryApp()
app.create_common_interface()
df, df_fields = app.create_loader_widget()
df, df_fields = app.load_pandas_dataframe()
app.setup_app_page(df, df_fields)
61 changes: 8 additions & 53 deletions data_explorer/app/pages/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def make_numeric_statistics_table(
"""
# make a new dataframe with statistics
# for each numeric field
dataframe[numeric_fields] = dataframe[numeric_fields].fillna(0)
statistics = dataframe[numeric_fields].describe().compute()
statistics = statistics.transpose()
# add a column with the field name
Expand All @@ -47,43 +48,16 @@ def make_numeric_statistics_table(
statistics = statistics[column_order]
return statistics

@staticmethod
def make_numeric_plot(
def build_numeric_analysis_table(
self,
dataframe: dd.DataFrame,
numeric_field: t.List[str],
plot_type: str,
):
"""Plots a numeric dataframe column with streamlit."""
if plot_type == "histogram":
data = dataframe[numeric_field].compute() # .hist(bins=30)
st.plotly_chart(data.hist(), use_container_width=True)

elif plot_type == "violin":
data = dataframe[numeric_field].compute()
st.plotly_chart(data.plot(kind="violin"), use_container_width=True)

elif plot_type == "density":
data = dataframe[numeric_field].compute()
st.plotly_chart(data.plot(kind="density_heatmap"), use_container_width=True)

elif plot_type == "strip":
data = dataframe[numeric_field].compute()
st.plotly_chart(data.plot(kind="strip"), use_container_width=True)

elif plot_type == "categorical":
data = dataframe[numeric_field].value_counts()
st.bar_chart(data.compute())

else:
msg = "Aggregation type not supported"
raise ValueError(msg)

def build_numeric_analysis_table(self, dataframe, numeric_fields) -> None:
numeric_fields: t.List[str],
) -> None:
"""Build the numeric analysis table."""
# check if there are numeric fields
if len(numeric_fields) == 0:
st.warning("There are no numeric fields in this subset")

else:
# make numeric statistics table
aggregation_dataframe = self.make_numeric_statistics_table(
dataframe,
Expand All @@ -106,31 +80,12 @@ def build_numeric_analysis_table(self, dataframe, numeric_fields) -> None:
columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW,
)

def build_numeric_analysis_plots(self, dataframe, numeric_fields):
"""Build the numeric analysis plots."""
# check if there are numeric fields
if len(numeric_fields) == 0:
st.warning("There are no numeric fields in this subset")

# choose a numeric field in dropdown
cols = st.columns(2)
with cols[0]:
numeric_field = st.selectbox("Field", numeric_fields)
with cols[1]:
plot_type = st.selectbox(
"Plot type",
["histogram", "violin", "density", "categorical"],
)

self.make_numeric_plot(dataframe, numeric_field, plot_type)

def setup_app_page(self, dataframe, fields):
def setup_app_page(self, dataframe: dd.DataFrame, fields):
numeric_fields = get_numeric_fields(fields)
self.build_numeric_analysis_table(dataframe, numeric_fields)
self.build_numeric_analysis_plots(dataframe, numeric_fields)


app = NumericAnalysisApp()
app.create_common_interface()
df, df_fields = app.create_loader_widget()
df, df_fields = app.load_dask_dataframe()
app.setup_app_page(df, df_fields)

0 comments on commit 9d1abed

Please sign in to comment.