Skip to content

Commit

Permalink
WIP add top words button.
Browse files Browse the repository at this point in the history
  • Loading branch information
carschno committed Nov 22, 2024
1 parent f55be00 commit 3c3bf6e
Showing 1 changed file with 29 additions and 9 deletions.
38 changes: 29 additions & 9 deletions tempo_embeddings/visualization/jscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,8 @@ def _init_scatter(self) -> jscatter.Scatter:
def _selected(self) -> list[int]:
"""Return the indices of currently selected/filtered/all rows.
1. If there are selected points, return their indices, OR
2. If there are filtered points, return their indices, OR
3. Return all indices.
Returns:
list[int]: The indices of the selected/filtered/all rows.
The indices of the selected/filtered/all rows.
"""
if self._scatter.selection().size > 0:
index = self._scatter.selection()
Expand All @@ -229,16 +225,39 @@ def _selected(self) -> list[int]:
# filter() raises error if it has not been set yet
filter_indices = self._scatter.filter()
except AttributeError:
logging.debug("No filter indices found")
logging.debug("No filter set.")
index = self._df.index
else:
index = filter_indices if len(filter_indices) > 0 else self._df.index
return self._df.iloc[index].to_csv(index=False, quoting=csv.QUOTE_ALL)
index = filter_indices if filter_indices.size > 0 else self._df.index
return index

def _export_button(self) -> DownloadButton:
def selected_rows():
return self._df.iloc[self._selected()].to_csv(
index=False, quoting=csv.QUOTE_ALL
)

return DownloadButton(
filename="scatter_plot.csv", contents=self._selected, description="Export"
filename="scatter_plot.csv", contents=selected_rows, description="Export"
)

def _top_words_button(self) -> widgets.Button:
def _show_top_words(b):
# TODO: create a link between self._df and the corpora
# TODO: keep/unify text column names
# Corpus.from_csv_stream(self._df.iloc[self._scatter.selection()].to_csv())
corpus = Corpus.from_dataframe(self._df[self._selected()])
top_words = self._keyword_extractor.top_words(corpus)
print(top_words)

button = widgets.Button(
description="Top words",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Show top words",
)
button.on_click(_show_top_words)
return button

def _init_widgets(self) -> tuple[jscatter.Scatter, widgets.HBox, widgets.HBox]:
"""Create the widgets for filtering the scatter plot."""
Expand All @@ -261,6 +280,7 @@ def _init_widgets(self) -> tuple[jscatter.Scatter, widgets.HBox, widgets.HBox]:
widgets.HBox(continuous_filters),
widgets.HBox(category_filters),
self._export_button(),
self._top_words_button(),
]

return self._widgets
Expand Down

0 comments on commit 3c3bf6e

Please sign in to comment.