Explorer new dataset format (#682)
PR that changes the explorer to match the new dataset format. 


![image](https://github.com/ml6team/fondant/assets/47530815/233170df-3e75-4949-b26b-f3bfdc5227a9)

There is a strange refresh error occurring that might get resolved, or at least
become easier to debug, after removing the partitions.
Let's focus on getting this
[PR](#677) merged first, since it might
fix the issue.
PhilippeMoussalli authored Nov 28, 2023
1 parent 9d1abed commit 6a84677
Showing 8 changed files with 182 additions and 117 deletions.
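For context on what "new dataset format" means here: subsets are gone, fields hang directly off the manifest, and each field resolves to its own storage location. A rough before/after sketch based only on the calls visible in this diff (the `"image"` field name is made up for illustration):

```python
from fondant.core.manifest import Manifest

manifest = Manifest.from_file("manifest.json")

# Old format: fields were grouped under named subsets.
# fields = manifest.subsets["images"].fields
# location = manifest.subsets["images"].location

# New format: fields are top-level, and each one resolves to its own location.
fields = manifest.fields
location = manifest.get_field_location("image")
```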
3 changes: 2 additions & 1 deletion data_explorer/app/config.py
```diff
@@ -11,5 +11,6 @@
     "partition",
 ]
 
+DEFAULT_INDEX_NAME = "id"
+
 ROWS_TO_RETURN = 20
-ROWS_PER_PAGE = 10
```
22 changes: 13 additions & 9 deletions data_explorer/app/df_helpers/fields.py
```diff
@@ -2,25 +2,29 @@
 
 import typing as t
 
+from fondant.core.schema import Field
+
 
 def get_fields_by_types(
-    fields: t.Dict[str, str],
+    fields: t.Dict[str, Field],
     field_types: t.List[str],
 ) -> t.List[str]:
-    return [
-        field
-        for field, f_type in fields.items()
-        if any(ftype in f_type for ftype in field_types)
-    ]
+    filtered_fields = []
+
+    for field, f_type in fields.items():
+        if any(ftype in f_type.type.to_json()["type"] for ftype in field_types):
+            filtered_fields.append(field)
+
+    return filtered_fields
 
 
-def get_string_fields(fields: t.Dict[str, str]) -> t.List[str]:
+def get_string_fields(fields: t.Dict[str, Field]) -> t.List[str]:
     return get_fields_by_types(fields, ["string", "utf8"])
 
 
-def get_image_fields(fields: t.Dict[str, str]) -> t.List[str]:
+def get_image_fields(fields: t.Dict[str, Field]) -> t.List[str]:
     return get_fields_by_types(fields, ["binary"])
 
 
-def get_numeric_fields(fields: t.Dict[str, str]) -> t.List[str]:
+def get_numeric_fields(fields: t.Dict[str, Field]) -> t.List[str]:
     return get_fields_by_types(fields, ["int", "float"])
```
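Note that the helpers now match against the JSON serialization of each Fondant type instead of a plain string. A minimal sketch of the substring matching, assuming the helpers above are in scope and using hypothetical stand-ins for `fondant.core.schema.Field` (judging from the code, `Field.type.to_json()` returns a dict with a `"type"` entry such as `{"type": "binary"}`):

```python
from dataclasses import dataclass


# Hypothetical stand-ins mirroring the parts of the fondant API used above.
@dataclass
class FakeType:
    name: str

    def to_json(self) -> dict:
        return {"type": self.name}


@dataclass
class FakeField:
    type: FakeType


fields = {
    "caption": FakeField(FakeType("string")),
    "image": FakeField(FakeType("binary")),
    "width": FakeField(FakeType("int32")),
}

get_image_fields(fields)    # ["image"]   ("binary" in "binary")
get_numeric_fields(fields)  # ["width"]   ("int" in "int32")
```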
3 changes: 2 additions & 1 deletion data_explorer/app/interfaces/common_interface.py
```diff
@@ -64,6 +64,7 @@ class MainInterface:
     def __init__(self):
         app_interface = AppStateInterface()
         app_interface.initialize()
+        st.set_page_config(layout="wide")
         self.fs, _ = fsspec.core.url_to_fs(st.session_state["base_path"])
 
     def _display_base_info(self):
@@ -116,7 +117,7 @@ def create_common_interface(self):
         add_logo("content/fondant_logo.png")
 
         with st.sidebar:
-            # Increase the width of the sidebar to accomodate logo
+            # Increase the width of the sidebar to accommodate logo
             st.markdown(
                 """
                 <style>
```
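A note on the `st.set_page_config(layout="wide")` addition: Streamlit requires this to be the first Streamlit command a page executes, which is why it sits in the interface constructor before any widgets are drawn. The `"wide"` layout lets the data tables span the full browser width:

```python
import streamlit as st

# Must be called once, before any other Streamlit command on the page;
# calling it later raises a StreamlitAPIException.
st.set_page_config(layout="wide")
```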
212 changes: 141 additions & 71 deletions data_explorer/app/interfaces/dataset_interface.py
```diff
@@ -1,12 +1,15 @@
 """Dataset interface for the data explorer app."""
 
 import os
+import typing as t
+from collections import defaultdict
 
 import dask.dataframe as dd
 import pandas as pd
 import streamlit as st
-from config import ROWS_TO_RETURN
+from config import DEFAULT_INDEX_NAME, ROWS_TO_RETURN
 from fondant.core.manifest import Manifest
+from fondant.core.schema import Field
 from interfaces.common_interface import MainInterface
 from interfaces.utils import get_default_index
@@ -36,80 +39,111 @@ def _select_component():
 
         return selected_component_path
 
-    def _get_manifest_fields_and_subset(self):
-        """Get fields and subset from manifest and store them in session state."""
-        cols = st.columns(3)
+    def _get_manifest_fields(self) -> t.Tuple[Manifest, t.Dict[str, Field]]:
+        """Get fields from manifest and store them in session state."""
+        cols = st.columns(2)
 
         with cols[0]:
             selected_component_path = self._select_component()
 
         manifest_path = os.path.join(selected_component_path, "manifest.json")
         manifest = Manifest.from_file(manifest_path)
-        subsets = manifest.subsets.keys()
+        fields = manifest.fields
+        field_list = list(fields.keys())
 
         with cols[1]:
-            subset = st.selectbox("Select subset", subsets)
-
-        fields = manifest.subsets[subset].fields
-
-        with cols[2]:
-            fields = st.multiselect("Fields", fields, default=fields)
+            selected_field_names = st.multiselect(
+                "Fields",
+                field_list,
+                default=field_list,
+            )
 
-        field_types = {
-            f"{field}": manifest.subsets[subset].fields[field].type.name
-            for field in fields
+        selected_fields = {
+            field_name: fields[field_name] for field_name in selected_field_names
         }
-        return manifest, subset, field_types
 
-    def _get_subset_path(self, manifest, subset):
+        return manifest, selected_fields
+
+    def _get_field_location(self, manifest: Manifest, field_name: str) -> str:
         """
-        Get path to subset from manifest. If the base path is not mounted, the subset path is
-        assumed to be relative to the base path. If the base path is mounted, the subset path is
+        Get the path to a field from the manifest. If the base path is not mounted, field paths
+        are assumed to be relative to the base path. If the base path is mounted, field paths are
         assumed to be absolute.
         """
         base_path = st.session_state["base_path"]
-        subset = manifest.subsets[subset]
+        field_location = manifest.get_field_location(field_name)
 
         if (
             os.path.ismount(base_path) is False
             and self.fs.__class__.__name__ == "LocalFileSystem"
         ):
             # Used for local development when running the app locally
-            subset_path = os.path.join(
+            field_location = os.path.join(
                 os.path.dirname(base_path),
-                subset.location.lstrip("/"),
+                field_location.lstrip("/"),
             )
-        else:
-            # Used for mounted data (production)
-            subset_path = subset.location
-
-        return subset_path
-
-    def load_dask_dataframe(self):
-        """Loads a Dask DataFrame from a subset of the dataset."""
-        manifest, subset, fields = self._get_manifest_fields_and_subset()
-        subset_path = self._get_subset_path(manifest, subset)
-        dask_df = dd.read_parquet(subset_path, columns=list(fields.keys())).reset_index(
-            drop=False,
-        )
-
-        return dask_df, fields
+
+        return field_location
+
+    def get_fields_mapping(self):
+        field_mapping = defaultdict(list)
+
+        manifest, selected_fields = self._get_manifest_fields()
+        # Add the index field first so that reading starts with the index dataframe
+        index_location = self._get_field_location(manifest, DEFAULT_INDEX_NAME)
+        field_mapping[index_location].append(DEFAULT_INDEX_NAME)
+
+        for field_name, field in selected_fields.items():
+            field_location = self._get_field_location(manifest, field_name)
+            field_mapping[field_location].append(field_name)
+
+        return field_mapping, selected_fields
+
+    @staticmethod
+    @st.cache_data
+    def load_dask_dataframe(field_mapping):
+        dataframe = None
+        for location, fields in field_mapping.items():
+            if DEFAULT_INDEX_NAME in fields:
+                fields.remove(DEFAULT_INDEX_NAME)
+
+            partial_df = dd.read_parquet(
+                location,
+                columns=fields,
+                index=DEFAULT_INDEX_NAME,
+                calculate_divisions=True,
+            )
+
+            if dataframe is None:
+                # ensure that the index is set correctly and divisions are known.
+                dataframe = partial_df
+            else:
+                dataframe = dataframe.merge(
+                    partial_df,
+                    how="left",
+                    left_index=True,
+                    right_index=True,
+                )
+        return dataframe
```
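Since each component now writes its fields to its own location, `get_fields_mapping` groups the selected fields by location and `load_dask_dataframe` stitches the per-location frames back together with index-aligned left joins. A self-contained sketch of that join pattern (toy in-memory data standing in for the parquet files):

```python
import dask.dataframe as dd
import pandas as pd

# Two per-location frames sharing the "id" index, as dd.read_parquet(location,
# index="id") would produce. The data itself is made up for illustration.
images = pd.DataFrame(
    {"id": ["a", "b", "c"], "image": [b"x", b"y", b"z"]},
).set_index("id")
captions = pd.DataFrame(
    {"id": ["a", "b", "c"], "caption": ["cat", "dog", "bird"]},
).set_index("id")

dataframe = None
for partial_df in (
    dd.from_pandas(images, npartitions=1),
    dd.from_pandas(captions, npartitions=1),
):
    if dataframe is None:
        dataframe = partial_df
    else:
        # Left join on the shared index, mirroring the merge above.
        dataframe = dataframe.merge(
            partial_df, how="left", left_index=True, right_index=True,
        )

print(dataframe.compute())  # one frame indexed by "id" with both columns
```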

```diff
     @staticmethod
     @st.cache_data
     def get_pandas_from_dask(
-        dask_df: dd.DataFrame,
+        field_mapping,
+        _dask_df: dd.DataFrame,
         rows_to_return: int,
         partition_index: int,
-        rows_from_last_partition: int,
+        last_partition_index: int,
     ):
         """
         Converts a Dask DataFrame into a Pandas DataFrame with specified number of rows.
         Args:
-            dask_df: Input Dask DataFrame.
+            field_mapping: Mapping of fields to their location. Used as a unique cache key.
+            _dask_df: Input Dask DataFrame.
             rows_to_return: Number of rows needed in the resulting Pandas DataFrame.
             partition_index: Index of the partition to start from.
-            rows_from_last_partition: Number of rows to take from the last partition.
+            last_partition_index: Row offset within the starting partition to resume from.
         Returns:
             result_df: Pandas DataFrame with the specified number of rows.
@@ -119,56 +153,80 @@ def get_pandas_from_dask(
         rows_returned = 0
         data_to_return = []
 
-        dask_df = dask_df.partitions[partition_index:]
+        dask_df = _dask_df.partitions[partition_index:]
 
         for partition_index, partition in enumerate(
             dask_df.partitions,
             start=partition_index,
         ):
             # Materialize partition as a pandas DataFrame
-            partition_df = partition.compute()
-            partition_df = partition_df[rows_from_last_partition:]
+            partition_df = partition.compute().reset_index(drop=False)
+            partition_length = len(partition_df)
+            partition_df = partition_df[last_partition_index:]
 
             # Check if adding this partition exceeds the required rows
-            if rows_returned + len(partition_df) <= rows_to_return:
+            if rows_returned + partition_length <= rows_to_return:
                 data_to_return.append(partition_df)
-                rows_returned += len(partition_df)
+                rows_returned += partition_length
             else:
                 # Calculate how many rows to take from this partition
-                rows_from_last_partition = rows_to_return - rows_returned
-                partition_df = partition_df.head(rows_from_last_partition)
-                data_to_return.append(partition_df)
-                break
+                rows_to_take = rows_to_return - rows_returned
+                sliced_partition_df = partition_df.head(rows_to_take)
+
+                # Check if the partition is empty
+                if len(sliced_partition_df) == 0:
+                    last_partition_index = 0
+                    continue
+
+                data_to_return.append(sliced_partition_df)
+                rows_returned += len(sliced_partition_df)
+
+                # Check if we have reached the required number of rows
+                if rows_returned >= rows_to_return or len(sliced_partition_df) == 0:
+                    last_partition_index = last_partition_index + len(
+                        sliced_partition_df,
+                    )
+                    break
+
+                # Check if the last row of the sliced partition is the same as the last row
+                # of the original partition. If so, we have reached the end of the dataframe.
+                if partition_df.tail(1).equals(sliced_partition_df.tail(1)):
+                    last_partition_index = 0
 
         # Concatenate the selected partitions into a single pandas DataFrame
         df = pd.concat(data_to_return)
 
-        return df, partition_index, rows_from_last_partition
+        return df, partition_index, last_partition_index
```
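For intuition about the bookkeeping: the pager starts at `partition_index`, skips the first `last_partition_index` rows of that partition, collects rows until `rows_to_return` is reached, and hands the stopping point back so the next page can resume there. A simplified, self-contained version of that walk (without the caching and edge-case handling above):

```python
import dask.dataframe as dd
import pandas as pd


def paginate(ddf, rows_to_return, partition_index, row_offset):
    """Collect `rows_to_return` rows starting at (partition_index, row_offset)."""
    chunks, rows = [], 0
    for p_idx in range(partition_index, ddf.npartitions):
        part = ddf.partitions[p_idx].compute().reset_index(drop=False)
        chunk = part.iloc[row_offset : row_offset + rows_to_return - rows]
        chunks.append(chunk)
        rows += len(chunk)
        if rows == rows_to_return and row_offset + len(chunk) < len(part):
            # Page is full and this partition still has rows: resume here.
            return pd.concat(chunks), p_idx, row_offset + len(chunk)
        row_offset = 0
    return pd.concat(chunks), ddf.npartitions - 1, 0  # ran off the end


pdf = pd.DataFrame({"id": range(24), "value": range(24)}).set_index("id")
ddf = dd.from_pandas(pdf, npartitions=3)  # three partitions of 8 rows

page1, p, off = paginate(ddf, 10, 0, 0)    # rows 0-9,  resumes at (1, 2)
page2, p, off = paginate(ddf, 10, p, off)  # rows 10-19, resumes at (2, 4)
```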

```diff
     @staticmethod
-    def _initialize_page_view_dict():
-        return st.session_state.get(
-            "page_view_dict",
-            {
+    def _initialize_page_view_dict(component):
+        page_view_dict = st.session_state.get("page_view_dict", {})
+
+        if component not in page_view_dict:
+            page_view_dict[component] = {
                 0: {
                     "start_index": 0,
                     "start_partition": 0,
                 },
-            },
-        )
+            }
+            st.session_state["page_index"] = 0
+
+        return page_view_dict
 
     @staticmethod
     def _update_page_view_dict(
         page_view_dict,
         page_index,
         start_index,
         start_partition,
+        component,
     ):
-        page_view_dict[page_index] = {
+        page_view_dict[component][page_index] = {
             "start_index": start_index,
             "start_partition": start_partition,
         }
         st.session_state["page_view_dict"] = page_view_dict
 
         return page_view_dict
 
     def load_pandas_dataframe(self):
```
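The paging state is now keyed per component, so switching between components in the sidebar no longer clobbers another component's position. After browsing two components the session state would look roughly like this (the component paths are invented for illustration):

```python
# Hypothetical contents of st.session_state["page_view_dict"]:
page_view_dict = {
    "/data/run-1/load_from_hub": {
        0: {"start_index": 0, "start_partition": 0},
        1: {"start_index": 4, "start_partition": 2},  # where page 1 starts
    },
    "/data/run-1/caption_images": {
        0: {"start_index": 0, "start_partition": 0},
    },
}
# _update_page_view_dict writes the entry for page N+1 each time page N is
# rendered, so "Next" always knows which (partition, row) to resume from.
```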
```diff
@@ -182,28 +240,35 @@ def load_pandas_dataframe(self):
         previous_button_disabled = True
         next_button_disabled = False
 
+        # Get field mapping from manifest and selected fields
+        field_mapping, selected_fields = self.get_fields_mapping()
+
         # Get the manifest, subset, and fields
-        dask_df, fields = self.load_dask_dataframe()
+        dask_df = self.load_dask_dataframe(field_mapping)
 
         # Initialize page view dict if it doesn't exist
+        component = st.session_state["component"]
+        page_view_dict = self._initialize_page_view_dict(component)
         page_index = st.session_state.get("page_index", 0)
-        page_view_dict = self._initialize_page_view_dict()
 
         # Get the starting index and partition for the current page
-        start_index = page_view_dict[page_index]["start_index"]
-        start_partition = page_view_dict[page_index]["start_partition"]
+        start_index = page_view_dict[component][page_index]["start_index"]
+        start_partition = page_view_dict[component][page_index]["start_partition"]
 
         pandas_df, next_partition, next_index = self.get_pandas_from_dask(
-            dask_df,
-            ROWS_TO_RETURN,
-            start_index,
-            start_partition,
+            field_mapping=field_mapping,
+            _dask_df=dask_df,
+            rows_to_return=ROWS_TO_RETURN,
+            partition_index=start_partition,
+            last_partition_index=start_index,
         )
 
         self._update_page_view_dict(
-            page_view_dict,
-            page_index + 1,
-            next_index,
-            next_partition,
+            page_view_dict=page_view_dict,
+            page_index=page_index + 1,
+            start_index=next_index,
+            start_partition=next_partition,
+            component=component,
         )
 
         st.info(
@@ -215,8 +280,13 @@ def load_pandas_dataframe(self):
         if page_index != 0:
             previous_button_disabled = False
 
-        if len(pandas_df) < ROWS_TO_RETURN:
-            next_button_disabled = True
+        if next_partition == dask_df.npartitions - 1:
+            # Check if the last row of the last partition is the same as the last row of the
+            # dataframe. If so, we have reached the end of the dataframe.
+            partition_index = dask_df.tail(1).index[0]
+            dask_df_index = pandas_df.tail(1)[DEFAULT_INDEX_NAME].iloc[0]
+            if partition_index == dask_df_index:
+                next_button_disabled = True
 
         if previous_col.button(
             "⏮️ Previous",
@@ -236,4 +306,4 @@ def load_pandas_dataframe(self):
 
         st.markdown(f"Page {page_index}")
 
-        return pandas_df, fields
+        return pandas_df, selected_fields
```
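Presumably the motivation for the new end-of-dataset check: the old `len(pandas_df) < ROWS_TO_RETURN` heuristic only disables "Next" on a short page, so a dataset whose size is an exact multiple of the page size would offer one extra, empty page. Comparing the last `id` on the current page with the dataset's last `id` avoids that. A sketch of the arithmetic, with made-up numbers:

```python
ROWS_TO_RETURN = 20
dataset_rows = 40

# Old heuristic on the final (full) page:
rows_on_page = 20
old_disables_next = rows_on_page < ROWS_TO_RETURN  # False -> spurious empty page

# New check: disable "Next" once the page's last id equals the dataset's last id.
last_id_on_page, last_id_in_dataset = 39, 39
new_disables_next = last_id_on_page == last_id_in_dataset  # True
```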