Skip to content

Commit

Permalink
Merge pull request #153 from Semantics-of-Sustainability/feature/search
Browse files Browse the repository at this point in the history
Add search filter.
  • Loading branch information
carschno authored Dec 18, 2024
2 parents e38709a + 70c2640 commit 5e559dc
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 20 deletions.
45 changes: 45 additions & 0 deletions tempo_embeddings/visualization/jscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,50 @@ def handle_slider_change(change):

return selection

def _search_filter(self) -> widgets.HBox:
    """Create a search widget for filtering on a field.

    The widget pairs a free-text box with a dropdown listing the columns of
    ``self._df`` whose dtype is ``"string"``. Whenever either value changes,
    the rows whose selected column contains the (whitespace-stripped) search
    term are passed to ``self._filter``; an empty term passes the full index.

    The term is interpreted as a regular expression. A term that is not a
    valid pattern (e.g. a lone "(" while the user is still typing) is matched
    literally instead of raising from inside the widget callback.

    Returns:
        widgets.HBox: A widget containing a search box and a field selector
    """
    # Function-scope import keeps this block self-contained.
    import re

    search = widgets.Text(
        description="Search:",
        placeholder="Enter search term",
        continuous_update=True,
    )

    string_fields = sorted(
        c for c in self._df.columns if self._df[c].dtype == "string"
    )
    # Prefer the "text" column, but do not fail widget construction when it
    # is absent or not string-typed.
    if "text" in string_fields:
        default_field = "text"
    else:
        default_field = string_fields[0] if string_fields else None

    field_selector = widgets.Dropdown(
        options=string_fields,
        value=default_field,
        description="In field",
        layout={"width": "max-content"},
    )
    _widgets = [search, field_selector]

    def handle_search_change(change):
        for w in _widgets:
            w.disabled = True

        # Re-enable the controls even if filtering raises; otherwise a
        # single failure would leave the search UI locked.
        try:
            search_term = search.value.strip()
            field = field_selector.value

            if search_term and field is not None:
                # continuous_update fires on every keystroke, so partially
                # typed patterns are common; fall back to a literal match
                # when the term does not compile.
                try:
                    re.compile(search_term)
                except re.error:
                    use_regex = False
                else:
                    use_regex = True

                matches = self._df[field].str.contains(
                    search_term, regex=use_regex, na=False
                )
                filtered = self._df.loc[matches]
            else:
                filtered = self._df

            self._filter(field_selector, filtered.index)
        finally:
            for w in _widgets:
                w.disabled = False

    search.observe(handle_search_change, names="value")
    field_selector.observe(handle_search_change, names="value")

    return widgets.HBox(_widgets)

def _filter(self, field, index):
"""Filter the scatter plot based on the given field and index.
Expand Down Expand Up @@ -601,6 +645,7 @@ def get_widgets(
widgets.HBox(
[widget for widget in category_filters if widget is not None]
),
self._search_filter(),
self._color_by_dropdown(),
self._select_tooltips(),
self._export_button(),
Expand Down
61 changes: 41 additions & 20 deletions tests/unit/visualization/test_jscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from contextlib import nullcontext as does_not_raise
from unittest import mock

import numpy as np
import pandas as pd
import pytest
from ipywidgets.widgets import (
Expand Down Expand Up @@ -40,6 +41,7 @@ def test_init(self, container):
HBox,
HBox,
HBox,
HBox,
Dropdown,
SelectMultiple,
HBox,
Expand Down Expand Up @@ -139,6 +141,7 @@ def test_get_widgets(
HBox,
HBox,
HBox,
HBox,
Dropdown,
SelectMultiple,
HBox,
Expand Down Expand Up @@ -220,7 +223,15 @@ def test_plot_button(self, corpus):


class TestPlotWidgets:
def test_export_button(self, corpus, tmp_path):
@pytest.fixture
def plot_widgets(self, corpus):
    """Provide a PlotWidgets instance built from the corpus dataframe.

    The dataframe is passed through ``convert_dtypes()`` before being handed
    to the widgets; points are coloured by "corpus" and no tooltip fields
    are preselected.
    """
    frame = corpus.to_dataframe().convert_dtypes()
    return JScatterVisualizer.PlotWidgets(
        df=frame, color_by="corpus", tooltip_fields=set()
    )

def test_export_button(self, plot_widgets, tmp_path):
expected_columns = [
"text",
"ID_DB",
Expand All @@ -235,10 +246,7 @@ def test_export_button(self, corpus, tmp_path):
"distance_to_centroid",
]

pw = JScatterVisualizer.PlotWidgets(
df=corpus.to_dataframe(), color_by="corpus", tooltip_fields=set()
)
export_button = pw._export_button()
export_button = plot_widgets._export_button()
assert isinstance(export_button, HBox)

assert [type(w) for w in export_button.children] == [Button, Text, Checkbox]
Expand All @@ -251,27 +259,40 @@ def test_export_button(self, corpus, tmp_path):

df = pd.read_csv(target_file)
assert df.columns.to_list() == expected_columns
assert len(df) == len(corpus)

def test_color_by(self, corpus):
pw = JScatterVisualizer.PlotWidgets(
df=corpus.to_dataframe(), color_by="corpus", tooltip_fields=set()
)
assert len(df) == 5

color_box = pw._color_by_dropdown()
def test_color_by(self, plot_widgets):
color_box = plot_widgets._color_by_dropdown()
assert isinstance(color_box, Dropdown)
assert color_box.value == "corpus"

color_box.value = "year"
assert pw._scatter_plot.color()["by"] == "year"

def test_select_tooltips(self, corpus):
pw = JScatterVisualizer.PlotWidgets(
df=corpus.to_dataframe(), color_by="corpus", tooltip_fields=set()
)
assert plot_widgets._scatter_plot.color()["by"] == "year"

select_tooltips = pw._select_tooltips()
def test_select_tooltips(self, plot_widgets):
select_tooltips = plot_widgets._select_tooltips()
assert isinstance(select_tooltips, SelectMultiple)

select_tooltips.value = ["provenance"]
assert pw._scatter_plot.tooltip()["properties"] == ["provenance"]
assert plot_widgets._scatter_plot.tooltip()["properties"] == ["provenance"]

@pytest.mark.parametrize(
    "search_term,field,expected",
    [
        ("test", "text", np.arange(5)),
        ("test", "provenance", np.arange(5)),
        ("invalid", "text", []),
    ],
)
def test_search_filter(self, plot_widgets, search_term, field, expected):
    """Check that the search widget filters the plot and resets on an empty term."""
    container = plot_widgets._search_filter()
    assert isinstance(container, HBox)

    child_types = [type(child) for child in container.children]
    assert child_types == [Text, Dropdown]

    # First child is the text box, second the field selector (asserted above).
    text_box, field_dropdown = container.children
    text_box.value = search_term
    field_dropdown.value = field

    active_filter = plot_widgets._scatter_plot.filter()
    np.testing.assert_equal(active_filter, expected)

    # Clearing the term is expected to restore the full index.
    text_box.value = ""
    restored_filter = plot_widgets._scatter_plot.filter()
    np.testing.assert_equal(restored_filter, np.arange(5))

0 comments on commit 5e559dc

Please sign in to comment.