Add custom labels to get topic tree #2125
base: master
@@ -1847,11 +1847,12 @@ def get_representative_docs(self, topic: int = None) -> List[str]:
         else:
             return self.representative_docs_

-    @staticmethod
     def get_topic_tree(
+        self,
         hier_topics: pd.DataFrame,
         max_distance: float = None,
         tight_layout: bool = False,
+        custom_labels: Union[bool, str] = False,
     ) -> str:
         """Extract the topic tree such that it can be printed.

@@ -1862,6 +1863,11 @@ def get_topic_tree(
                           based on the Distance column in `hier_topics`.
             tight_layout: Whether to use a tight layout (narrow width) for
                           easier readability if you have hundreds of topics.
+            custom_labels: If bool, whether to use custom topic labels that were defined using
+                           `topic_model.set_topic_labels`.
+                           If `str`, it uses labels from other aspects, e.g., "Aspect1".
+                           NOTE: Custom labels are only generated for the original
+                           un-merged topics.

         Returns:
             A tree that has the following structure when printed:
@@ -1897,9 +1903,40 @@ def get_topic_tree(

         max_original_topic = hier_topics.Parent_ID.astype(int).min() - 1

+        # Prepare tree labels to print
+        child_left_ids = hier_topics.Child_Left_ID.astype(int)
+        child_right_ids = hier_topics.Child_Right_ID.astype(int)
+
+        # Get the new parent labels generated from `hierarchical_topics`
+        new_left_labels = {int(row["Child_Left_ID"]): row["Child_Left_Name"] for idx, row in hier_topics.iterrows()}
+        new_right_labels = {int(row["Child_Right_ID"]): row["Child_Right_Name"] for idx, row in hier_topics.iterrows()}
+
+        if custom_labels:
+            left_labels = {}
+            if isinstance(custom_labels, str):
+                for topic, kws_info in self.topic_aspects_[custom_labels].items():
+                    label = "_".join([kw[0] for kw in kws_info[:5]])  # displaying top 5 kws
+                    left_labels[topic] = label
Comment on lines +1916 to +1919

Should we add a check for when a user attempts to use a specific aspect that cannot be found?

This makes sense, I can add a check here. One thing I noticed when looking through how custom labels were handled in the plotting functions is that they all follow a pretty similar pattern:
BERTopic/bertopic/plotting/_topics.py Line 67 in 510c15e
BERTopic/bertopic/plotting/_heatmap.py Line 99 in 510c15e
BERTopic/bertopic/plotting/_datamap.py Line 123 in 510c15e
BERTopic/bertopic/plotting/_documents.py Line 139 in 510c15e
Maybe a more generic and reusable function can be created for this, which would check whether the user-specified aspect exists.

You're right, it makes no sense to do that check here but not at all of those other instances. For now, I'm alright either way. I can imagine wanting to stick with what I already use everywhere else for consistency, as this might be a bit out of the scope of this particular PR. So it's up to you.
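A minimal sketch of the reusable check discussed in this thread. The helper name `resolve_aspect_labels`, its signature, and the error message are hypothetical and not part of BERTopic; the sketch only assumes that aspects are stored as `{aspect: {topic_id: [(keyword, score), ...]}}`, mirroring how the diff indexes `self.topic_aspects_[custom_labels]`.

```python
from typing import Dict, List, Tuple

# Hypothetical helper, not part of BERTopic: name, signature, and error
# message are assumptions made for illustration only.
def resolve_aspect_labels(
    topic_aspects: Dict[str, Dict[int, List[Tuple[str, float]]]],
    aspect: str,
    top_n: int = 5,
) -> Dict[int, str]:
    """Build {topic_id: label} for one aspect, failing early if it is missing."""
    if aspect not in topic_aspects:
        available = ", ".join(sorted(topic_aspects)) or "none"
        raise ValueError(f"Aspect '{aspect}' not found. Available aspects: {available}")
    return {
        topic: "_".join(word for word, _ in keywords[:top_n])
        for topic, keywords in topic_aspects[aspect].items()
    }


# Toy stand-in for `topic_model.topic_aspects_`
aspects = {"Aspect1": {0: [("cat", 0.9), ("dog", 0.8)], 1: [("car", 0.7), ("bus", 0.6)]}}
labels = resolve_aspect_labels(aspects, "Aspect1")
```

If something like this were adopted, the plotting functions listed above could call the same helper instead of repeating the pattern, which is the consistency concern raised in the thread.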
+            elif self.custom_labels_ is not None and custom_labels:
+                left_labels = {topic_id: label for topic_id, label in enumerate(self.custom_labels_, -self._outliers)}
I just wanted to mention that I am going to steal that enumerate trick. That is so much nicer coding than how I have approached it thus far.
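For readers unfamiliar with the trick, here is a small standalone illustration. The label list and the `outliers` value are toy stand-ins for `self.custom_labels_` and `self._outliers`; the sketch assumes `outliers` is 1 when an outlier topic (-1) exists and 0 otherwise.

```python
# Toy values; in the diff these come from `self.custom_labels_` and `self._outliers`.
custom_labels = ["Outliers", "Pets", "Vehicles"]
outliers = 1  # assumption: 1 when an outlier topic (-1) exists, else 0

# Starting enumerate at -outliers aligns the first label with topic -1,
# so no manual offset arithmetic is needed.
labels = {topic_id: label for topic_id, label in enumerate(custom_labels, -outliers)}
print(labels)  # {-1: 'Outliers', 0: 'Pets', 1: 'Vehicles'}
```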
+
+            right_labels = left_labels.copy()
+
+            # We want to preserve the original labels from `topic_aspects_` or `custom_labels_`
+            # while adding in those generated from `hierarchical_topics`
+            new_left_labels.update(left_labels)
+            new_right_labels.update(right_labels)
+
+            child_left_names = [new_left_labels[topic] for topic in child_left_ids]
+            child_right_names = [new_right_labels[topic] for topic in child_right_ids]
+
+        else:
+            child_left_names = hier_topics.Child_Left_Name
+            child_right_names = hier_topics.Child_Right_Name
+
         # Extract mapping from ID to name
-        topic_to_name = dict(zip(hier_topics.Child_Left_ID, hier_topics.Child_Left_Name))
-        topic_to_name.update(dict(zip(hier_topics.Child_Right_ID, hier_topics.Child_Right_Name)))
+        topic_to_name = dict(zip(child_left_ids.astype(str), child_left_names))
+        topic_to_name.update(dict(zip(child_right_ids.astype(str), child_right_names)))
         topic_to_name = {topic: name[:100] for topic, name in topic_to_name.items()}

         # Create tree
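The label-merging step in this hunk can be illustrated with plain dictionaries. The topic IDs and names below are invented toy data; the sketch only shows why `dict.update` keeps the custom labels for original topics while merged topics fall back to the generated names, as the docstring NOTE describes.

```python
# Labels generated by `hierarchical_topics`: cover original and merged topics.
new_left_labels = {0: "cat_dog_pet", 1: "car_bus_road", 5: "cat_dog_car_bus"}

# Custom labels: defined only for the original, un-merged topics.
left_labels = {0: "Pets", 1: "Vehicles"}

# update() overwrites entries for keys present in both dicts, so custom
# labels win for topics 0 and 1, while merged topic 5 keeps its generated name.
new_left_labels.update(left_labels)

child_left_ids = [0, 5]
child_left_names = [new_left_labels[topic] for topic in child_left_ids]
print(child_left_names)  # ['Pets', 'cat_dog_car_bus']
```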
Note that some topic aspects might not have 5 keywords but actually a single label. In those cases, `kws_info[0]` is likely to be filled with a label (or even a summary) and each instance in `kws_info[1:]` will be an empty string. As a result, you might get the following label: "artificial intelligence____".
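One possible way to avoid the trailing underscores is to drop empty keywords before joining. This is a sketch only: `build_label` is a hypothetical name, and the `(keyword, score)` tuple layout with empty-string padding is an assumption based on the description above.

```python
from typing import List, Tuple


# Hypothetical helper, not part of BERTopic.
def build_label(kws_info: List[Tuple[str, float]], top_n: int = 5) -> str:
    """Join the top keywords with underscores, skipping empty entries."""
    words = [word for word, _ in kws_info[:top_n] if word]
    return "_".join(words)


# Assumed shape of a single-label aspect: one label padded with empty strings.
single = [("artificial intelligence", 1.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0)]

naive = "_".join([kw[0] for kw in single[:5]])  # the join used in the diff
fixed = build_label(single)

print(repr(naive))  # 'artificial intelligence____'
print(repr(fixed))  # 'artificial intelligence'
```

Regular multi-keyword aspects are unaffected by the filter, since none of their entries are empty.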