diff --git a/tempo_embeddings/visualization/jscatter.py b/tempo_embeddings/visualization/jscatter.py index 7a12f7c..70a991e 100644 --- a/tempo_embeddings/visualization/jscatter.py +++ b/tempo_embeddings/visualization/jscatter.py @@ -175,7 +175,7 @@ def get_widgets(self) -> list[widgets.Widget]: _widgets.append(self._cluster_button()) _widgets.append(self._top_words_button()) - _widgets.append(self._plot_by_label_button()) + _widgets.append(self._plot_by_field_button()) return _widgets @@ -233,9 +233,8 @@ def cluster(button): # pragma: no cover return button - def _plot_by_label_button(self) -> widgets.Button: + def _plot_by_field_button(self) -> widgets.Button: field = "year" - groups_field = "label" window_size_slider = widgets.BoundedIntText( value=5, @@ -244,35 +243,48 @@ def _plot_by_label_button(self) -> widgets.Button: description="Rolling Window over Years:", layout={"width": "max-content"}, ) + # TODO: update option to match selection + groups_field_selector = widgets.Dropdown( + description="Field to plot", + options=self._df.columns, + value="label", + layout={"width": "max-content"}, + ) + corpus_per_year = self._df[field].value_counts() - def _plot_labels(b): - for label, group in self._df.loc[self._plot_widgets.selected()].groupby( - groups_field - ): - window = window_size_slider.value - if label != OUTLIERS_LABEL: - s = ( - (group[field].value_counts() / corpus_per_year) - .sort_index() - .rolling(window) - .mean() - ) - s.name = label - ax = s.plot(kind="line", legend=label) - ax.set_title( - f"Relative Frequency of {field} by {groups_field} (Rolling Window over {window} {field}s)" - ) - ax.set_xlabel(field) - ax.set_ylabel("Relative Frequency") + def _plot_by_field(b): + _selection = self._df.loc[self._plot_widgets.selected()] + groups_field = groups_field_selector.value + + if groups_field in _selection.columns: + for label, group in _selection.groupby(groups_field): + window = window_size_slider.value + if label != OUTLIERS_LABEL: + _series = ( + (group[field].value_counts() / corpus_per_year) + .sort_index() + .rolling(window) + .mean() + ) + _series.name = label + ax = _series.plot(kind="line", legend=label) + ax.set_title( + f"Relative Frequency by '{groups_field}' (Rolling Window over {window} {field}s)" + ) + ax.set_xlabel(field) + ax.set_ylabel("Relative Frequency") + else: + # TODO: this should never happen if the dropdown is updated + raise ValueError(f"Field '{groups_field}' not found in selection.") button = widgets.Button( description="Plot by Corpus", tooltip="Plot (selected) corpora frequencies over years by Corpus", ) - button.on_click(_plot_labels) + button.on_click(_plot_by_field) - return widgets.HBox((button, window_size_slider)) + return widgets.HBox((button, window_size_slider, groups_field_selector)) def _top_words_button(self) -> widgets.Button: def _show_top_words(b): # pragma: no cover diff --git a/tests/unit/visualization/test_jscatter.py b/tests/unit/visualization/test_jscatter.py index 0bfd07e..1138448 100644 --- a/tests/unit/visualization/test_jscatter.py +++ b/tests/unit/visualization/test_jscatter.py @@ -7,6 +7,7 @@ from ipywidgets.widgets import ( BoundedIntText, Button, + Dropdown, HBox, SelectionRangeSlider, SelectMultiple, @@ -188,7 +189,8 @@ def test_plot_button(self, corpus): visualizer = JScatterVisualizer([corpus]) widgets = visualizer.get_widgets() - assert [type(w) for w in widgets[-1].children] == [Button, BoundedIntText] + expected_widgets = [Button, BoundedIntText, Dropdown] + assert [type(w) for w in widgets[-1].children] == expected_widgets button = widgets[-1].children[0]