Skip to content

Commit

Permalink
Add field selector for plot_by_field
Browse files: browse the repository at this point in the history
  • Loading branch information
carschno committed Dec 9, 2024
1 parent b2aa6df commit 3754559
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 25 deletions.
60 changes: 36 additions & 24 deletions tempo_embeddings/visualization/jscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def get_widgets(self) -> list[widgets.Widget]:
_widgets.append(self._cluster_button())

_widgets.append(self._top_words_button())
_widgets.append(self._plot_by_label_button())
_widgets.append(self._plot_by_field_button())

return _widgets

Expand Down Expand Up @@ -233,9 +233,8 @@ def cluster(button): # pragma: no cover

return button

def _plot_by_label_button(self) -> widgets.Button:
def _plot_by_field_button(self) -> widgets.Button:
field = "year"
groups_field = "label"

window_size_slider = widgets.BoundedIntText(
value=5,
Expand All @@ -244,35 +243,48 @@ def _plot_by_label_button(self) -> widgets.Button:
description="Rolling Window over Years:",
layout={"width": "max-content"},
)
# TODO: update option to match selection
groups_field_selector = widgets.Dropdown(
description="Field to plot",
options=self._df.columns,
value="label",
layout={"width": "max-content"},
)

corpus_per_year = self._df[field].value_counts()

def _plot_labels(b):
for label, group in self._df.loc[self._plot_widgets.selected()].groupby(
groups_field
):
window = window_size_slider.value
if label != OUTLIERS_LABEL:
s = (
(group[field].value_counts() / corpus_per_year)
.sort_index()
.rolling(window)
.mean()
)
s.name = label
ax = s.plot(kind="line", legend=label)
ax.set_title(
f"Relative Frequency of {field} by {groups_field} (Rolling Window over {window} {field}s)"
)
ax.set_xlabel(field)
ax.set_ylabel("Relative Frequency")
def _plot_by_field(b):
_selection = self._df.loc[self._plot_widgets.selected()]
groups_field = groups_field_selector.value

if groups_field in _selection.columns:
for label, group in _selection.groupby(groups_field):
window = window_size_slider.value
if label != OUTLIERS_LABEL:
_series = (
(group[field].value_counts() / corpus_per_year)
.sort_index()
.rolling(window)
.mean()
)
_series.name = label
ax = _series.plot(kind="line", legend=label)
ax.set_title(
f"Relative Frequency by '{groups_field}' (Rolling Window over {window} {field}s)"
)
ax.set_xlabel(field)
ax.set_ylabel("Relative Frequency")
else:
# TODO: this should never happen if the dropdown is updated
raise ValueError(f"Field '{groups_field}' not found in selection.")

button = widgets.Button(
description="Plot by Corpus",
tooltip="Plot (selected) corpora frequencies over years by Corpus",
)
button.on_click(_plot_labels)
button.on_click(_plot_by_field)

return widgets.HBox((button, window_size_slider))
return widgets.HBox((button, window_size_slider, groups_field_selector))

def _top_words_button(self) -> widgets.Button:
def _show_top_words(b): # pragma: no cover
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/visualization/test_jscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ipywidgets.widgets import (
BoundedIntText,
Button,
Dropdown,
HBox,
SelectionRangeSlider,
SelectMultiple,
Expand Down Expand Up @@ -188,7 +189,8 @@ def test_plot_button(self, corpus):
visualizer = JScatterVisualizer([corpus])
widgets = visualizer.get_widgets()

assert [type(w) for w in widgets[-1].children] == [Button, BoundedIntText]
expected_widgets = [Button, BoundedIntText, Dropdown]
assert [type(w) for w in widgets[-1].children] == expected_widgets

button = widgets[-1].children[0]

Expand Down

0 comments on commit 3754559

Please sign in to comment.