Skip to content

Commit

Permalink
Fix left-over comments/returns.
Browse files Browse the repository at this point in the history
  • Loading branch information
carschno committed Nov 18, 2024
1 parent 1eb7618 commit f354854
Showing 1 changed file with 19 additions and 22 deletions.
41 changes: 19 additions & 22 deletions tempo_embeddings/visualization/jscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
class JScatterVisualizer:
"""A class for creating interactive scatter plots with Jupyter widgets."""

# DEFAULT_DERIVED_FIELDS: set[str] = {"label", "top words", "collection"}
# """Fields that are not part of the corpus metadata, but should still be shown in the tooltip."""
DEFAULT_CONTINUOUS_FIELDS: set[str] = {"year"}

def __init__(
Expand Down Expand Up @@ -279,31 +277,30 @@ def _continuous_field_filter(
Returns:
widgets.VBox: A widget containing a RangeSlider widget and the output widget
"""
if field not in self._df.columns:
logging.warning(f"Categorical field '{field}' not found, ignoring")
return

min_value = self._df[field].min()
max_value = self._df[field].max()

selection = widgets.SelectionRangeSlider(
options=[str(i) for i in range(min_value, max_value + 1)],
index=(0, max_value - min_value),
description=field,
continuous_update=True,
)
if field in self._df.columns:
min_value = self._df[field].min()
max_value = self._df[field].max()

selection = widgets.SelectionRangeSlider(
options=[str(i) for i in range(min_value, max_value + 1)],
index=(0, max_value - min_value),
description=field,
continuous_update=True,
)

selection_output = widgets.Output()
selection_output = widgets.Output()

def handle_slider_change(change):
start = int(change.new[0]) # noqa: F841
end = int(change.new[1]) # noqa: F841
def handle_slider_change(change):
start = int(change.new[0]) # noqa: F841
end = int(change.new[1]) # noqa: F841

self._filter(field, self._df.query("year > @start & year < @end").index)
self._filter(field, self._df.query("year > @start & year < @end").index)

selection.observe(handle_slider_change, names="value")
selection.observe(handle_slider_change, names="value")

return selection, selection_output
return selection, selection_output
else:
logging.warning(f"Categorical field '{field}' not found, ignoring")

def _filter(self, field, index):
"""Filter the scatter plot based on the given field and index.
Expand Down

0 comments on commit f354854

Please sign in to comment.