diff --git a/tempo_embeddings/visualization/jscatter.py b/tempo_embeddings/visualization/jscatter.py index 3965486..84f3503 100644 --- a/tempo_embeddings/visualization/jscatter.py +++ b/tempo_embeddings/visualization/jscatter.py @@ -28,12 +28,22 @@ def __init__(self, corpora: list[Corpus], **kwargs): self._visualizer = JScatterVisualizer(corpora, container=self, **kwargs) """The root visualizer.""" - self.add_tab(self._visualizer, **kwargs) + self.add_tab(self._visualizer) + + def add_tab(self, visualizer: list[Corpus], *, title: Optional[str] = None): + if title is None: + title = ( + f"Clusters {len(self._tab.children)}" + if self._tab.children + else "Full Corpus" + ) + + self._tab.children = list(self._tab.children) + [ + widgets.VBox(visualizer.get_widgets()) + ] - def add_tab(self, visualizer: list[Corpus], **kwargs): - children = list(self._tab.children) - children.append(widgets.VBox(visualizer.get_widgets())) - self._tab.children = children + self._tab.set_title(-1, title) + self._tab.selected_index = len(self._tab.children) - 1 def visualize(self): display(self._tab)