diff --git a/tempo_embeddings/visualization/jscatter.py b/tempo_embeddings/visualization/jscatter.py index 9bdf2ca..0e3c81c 100644 --- a/tempo_embeddings/visualization/jscatter.py +++ b/tempo_embeddings/visualization/jscatter.py @@ -379,6 +379,28 @@ def export(change): return widgets.HBox((button, csv_file, overwrite)) + def _color_by_dropdown(self): + current = self._scatter_plot.color()["by"] + columns = [ + c + for c in self._df.columns + if not self._df[c].hasnans and 1 < self._df[c].unique().size <= 50 + ] + if current not in columns: + columns.append(current) + # TODO: update columns by selection + + def handle_change(change): + self._scatter_plot.color(by=change["new"], map="auto") + + color_by_dropdown = widgets.Dropdown( + options=columns, value=current, description="Color by:", disabled=False + ) + + color_by_dropdown.observe(handle_change, names="value") + + return color_by_dropdown + def _category_field_filter( self, field: str ) -> Optional[widgets.SelectMultiple]: @@ -525,5 +547,6 @@ def get_widgets( widgets.HBox( [widget for widget in category_filters if widget is not None] ), + self._color_by_dropdown(), self._export_button(), ] diff --git a/tests/unit/visualization/test_jscatter.py b/tests/unit/visualization/test_jscatter.py index 946fe47..760ff18 100644 --- a/tests/unit/visualization/test_jscatter.py +++ b/tests/unit/visualization/test_jscatter.py @@ -233,3 +233,15 @@ def test_export_button(self, corpus, tmp_path): df = pd.read_csv(target_file) assert df.columns.to_list() == expected_columns assert len(df) == len(corpus) + + def test_color_by(self, corpus): + pw = JScatterVisualizer.PlotWidgets( + df=corpus.to_dataframe(), color_by="corpus", tooltip_fields=set() + ) + + color_box = pw._color_by_dropdown() + assert isinstance(color_box, Dropdown) + assert color_box.value == "corpus" + + color_box.value = "year" + assert pw._scatter_plot.color()["by"] == "year"