From f3548541b9e47f54237b207000cf68c693546159 Mon Sep 17 00:00:00 2001 From: Carsten Schnober Date: Mon, 18 Nov 2024 16:23:40 +0100 Subject: [PATCH] Fix left-over comments/returns. --- tempo_embeddings/visualization/jscatter.py | 41 ++++++++++------------ 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/tempo_embeddings/visualization/jscatter.py b/tempo_embeddings/visualization/jscatter.py index 54f7cd8..acec8cd 100644 --- a/tempo_embeddings/visualization/jscatter.py +++ b/tempo_embeddings/visualization/jscatter.py @@ -14,8 +14,6 @@ class JScatterVisualizer: """A class for creating interactive scatter plots with Jupyter widgets.""" - # DEFAULT_DERIVED_FIELDS: set[str] = {"label", "top words", "collection"} - # """Fields that are not part of the corpus metadata, but should still be shown in the tooltip.""" DEFAULT_CONTINUOUS_FIELDS: set[str] = {"year"} def __init__( @@ -279,31 +277,30 @@ def _continuous_field_filter( Returns: widgets.VBox: A widget containing a RangeSlider widget and the output widget """ - if field not in self._df.columns: - logging.warning(f"Categorical field '{field}' not found, ignoring") - return - - min_value = self._df[field].min() - max_value = self._df[field].max() - - selection = widgets.SelectionRangeSlider( - options=[str(i) for i in range(min_value, max_value + 1)], - index=(0, max_value - min_value), - description=field, - continuous_update=True, - ) + if field in self._df.columns: + min_value = self._df[field].min() + max_value = self._df[field].max() + + selection = widgets.SelectionRangeSlider( + options=[str(i) for i in range(min_value, max_value + 1)], + index=(0, max_value - min_value), + description=field, + continuous_update=True, + ) - selection_output = widgets.Output() + selection_output = widgets.Output() - def handle_slider_change(change): - start = int(change.new[0]) # noqa: F841 - end = int(change.new[1]) # noqa: F841 + def handle_slider_change(change): + start = int(change.new[0]) # noqa: F841 + end = int(change.new[1]) # noqa: F841 - self._filter(field, self._df.query("year > @start & year < @end").index) + self._filter(field, self._df.query("year > @start & year < @end").index) - selection.observe(handle_slider_change, names="value") + selection.observe(handle_slider_change, names="value") - return selection, selection_output + return selection, selection_output + else: + logging.warning(f"Categorical field '{field}' not found, ignoring") def _filter(self, field, index): """Filter the scatter plot based on the given field and index.