-
Notifications
You must be signed in to change notification settings - Fork 765
/
_topics.py
198 lines (174 loc) · 6.63 KB
/
_topics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import numpy as np
import pandas as pd
from umap import UMAP
from typing import List, Union
from sklearn.preprocessing import MinMaxScaler
from bertopic._utils import select_topic_representation
import plotly.express as px
import plotly.graph_objects as go
def visualize_topics(
    topic_model,
    topics: List[int] = None,
    top_n_topics: int = None,
    use_ctfidf: bool = False,
    custom_labels: Union[bool, str] = False,
    title: str = "<b>Intertopic Distance Map</b>",
    width: int = 650,
    height: int = 650,
) -> go.Figure:
    """Visualize topics, their sizes, and their corresponding words.

    This visualization is highly inspired by LDAvis, a great visualization
    technique typically reserved for LDA.

    Arguments:
        topic_model: A fitted BERTopic instance.
        topics: A selection of topics to visualize. The order in which the
                topics are passed does not matter: they are always processed
                in ascending topic order.
        top_n_topics: Only select the top n most frequent topics. Ignored
                      when `topics` is given.
        use_ctfidf: Whether to use c-TF-IDF representations instead of the embeddings from the embedding model.
        custom_labels: If bool, whether to use custom topic labels that were defined using
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
        title: Title of the plot.
        width: The width of the figure.
        height: The height of the figure.

    Returns:
        The plotly figure produced by `_plotly_topic_visualization`.

    Examples:
    To visualize the topics simply run:

    ```python
    topic_model.visualize_topics()
    ```

    Or if you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_topics()
    fig.write_html("path/to/file.html")
    ```
    <iframe src="../../getting_started/visualization/viz.html"
    style="width:1000px; height: 680px; border: 0px;"></iframe>
    """
    # Select topics based on the `topics` / `top_n_topics` arguments. Topic -1
    # is dropped from the automatic selections only; an explicit `topics`
    # argument is used as given.
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        selected_topics = list(topics)
    elif top_n_topics is not None:
        # NOTE(review): "top n" relies on get_topic_freq() returning its rows
        # ordered by frequency -- confirm against its implementation.
        selected_topics = freq_df.Topic.to_list()[:top_n_topics]
    else:
        selected_topics = freq_df.Topic.to_list()

    # One canonical, sorted ordering drives every per-topic sequence below
    # (sizes, labels AND embedding rows) so that they always stay aligned.
    topic_list = sorted(selected_topics)
    frequencies = [topic_model.topic_sizes_[topic] for topic in topic_list]

    # Build one hover label per topic.
    if isinstance(custom_labels, str):
        # Aspect labels: prefix the topic id, join the first four entries with
        # "_" and truncate labels of 30 or more characters to 27 plus "...".
        words = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topic_list]
        words = ["_".join([label[0] for label in labels[:4]]) for labels in words]
        words = [label if len(label) < 30 else label[:27] + "..." for label in words]
    elif custom_labels and topic_model.custom_labels_ is not None:
        # custom_labels_ is indexed by position; `_outliers` presumably shifts
        # the index to account for the outlier topic -- verify in BERTopic.
        words = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topic_list]
    else:
        words = [" | ".join([word[0] for word in topic_model.get_topic(topic)[:5]]) for topic in topic_list]

    # Locate the representation row of every selected topic.
    # Assumes the rows of c_tf_idf_ / topic_embeddings_ follow the sorted
    # topic ids, as the original lookup did -- TODO confirm.
    all_topics = sorted(topic_model.get_topics().keys())
    row_of_topic = {topic: row for row, topic in enumerate(all_topics)}
    # Index with `topic_list` (not the caller-ordered selection): otherwise an
    # unsorted `topics` argument pairs coordinates with the wrong labels.
    indices = np.array([row_of_topic[topic] for topic in topic_list], dtype=int)

    # Reduce the chosen topic representation (c-TF-IDF or embeddings) to 2D.
    embeddings, c_tfidf_used = select_topic_representation(
        topic_model.c_tf_idf_,
        topic_model.topic_embeddings_,
        use_ctfidf=use_ctfidf,
        output_ndarray=True,
    )
    embeddings = embeddings[indices]

    if c_tfidf_used:
        # c-TF-IDF features are min-max scaled before the Hellinger-metric UMAP.
        embeddings = MinMaxScaler().fit_transform(embeddings)
        embeddings = UMAP(n_neighbors=2, n_components=2, metric="hellinger", random_state=42).fit_transform(embeddings)
    else:
        embeddings = UMAP(n_neighbors=2, n_components=2, metric="cosine", random_state=42).fit_transform(embeddings)

    # Visualize with plotly
    df = pd.DataFrame(
        {
            "x": embeddings[:, 0],
            "y": embeddings[:, 1],
            "Topic": topic_list,
            "Words": words,
            "Size": frequencies,
        }
    )
    return _plotly_topic_visualization(df, topic_list, title, width, height)
def _plotly_topic_visualization(df: pd.DataFrame, topic_list: List[str], title: str, width: int, height: int):
    """Build the topic scatter plot with a slider that highlights one topic.

    Arguments:
        df: Frame with the columns "x", "y", "Topic", "Words" and "Size";
            "x"/"y" position the markers and "Size" scales them.
        topic_list: Topic identifiers; one slider step is created per entry.
                    NOTE(review): annotated as `List[str]`, but the values are
                    compared against -1 below, so ids are presumably ints.
        title: Title of the plot.
        width: The width of the figure.
        height: The height of the figure.

    Returns:
        The assembled plotly figure.
    """
    neutral = "#B0BEC5"

    def _padded(values):
        # Widen the data span by 15% of the magnitude of each extreme.
        low, high = values.min(), values.max()
        return (low - abs(low * 0.15), high + abs(high * 0.15))

    def _highlight(selected):
        # Restyle payload for one slider step: the selected topic turns red,
        # all others stay neutral. A selection of -1 highlights nothing.
        colors = ["red" if selected != -1 and topic == selected else neutral for topic in topic_list]
        return [{"marker.color": [colors]}]

    # Axis ranges and the mid-lines used for the "plus"-shaped grid
    x_range = _padded(df.x)
    y_range = _padded(df.y)
    x_mid = sum(x_range) / 2
    y_mid = sum(y_range) / 2

    # Scatter plot of the topics; marker area follows the "Size" column
    fig = px.scatter(
        df,
        x="x",
        y="y",
        size="Size",
        size_max=40,
        template="simple_white",
        labels={"x": "", "y": ""},
        hover_data={"Topic": True, "Words": True, "Size": True, "x": False, "y": False},
    )
    fig.update_traces(marker={"color": neutral, "line": {"width": 2, "color": "DarkSlateGrey"}})

    # Hover text: customdata follows the hover_data order (Topic, Words, Size)
    hover_lines = (
        "<b>Topic %{customdata[0]}</b>",
        "%{customdata[1]}",
        "Size: %{customdata[2]}",
    )
    fig.update_traces(hovertemplate="<br>".join(hover_lines))

    # One slider step per topic
    slider_steps = []
    for topic in topic_list:
        slider_steps.append({"label": f"Topic {topic}", "method": "update", "args": _highlight(topic)})
    slider = {"active": 0, "pad": {"t": 50}, "steps": slider_steps}

    # Layout: centered title, hidden axes, hover styling and the slider
    fig.update_layout(
        title={
            "text": f"{title}",
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
            "font": {"size": 22, "color": "Black"},
        },
        width=width,
        height=height,
        hoverlabel={"bgcolor": "white", "font_size": 16, "font_family": "Rockwell"},
        xaxis={"visible": False},
        yaxis={"visible": False},
        sliders=[slider],
    )
    fig.update_xaxes(range=x_range)
    fig.update_yaxes(range=y_range)

    # Vertical, then horizontal, center line of the "plus" grid
    fig.add_shape(
        type="line",
        x0=x_mid,
        y0=y_range[0],
        x1=x_mid,
        y1=y_range[1],
        line={"color": "#CFD8DC", "width": 2},
    )
    fig.add_shape(
        type="line",
        x0=x_range[0],
        y0=y_mid,
        x1=x_range[1],
        y1=y_mid,
        line={"color": "#9E9E9E", "width": 2},
    )

    # Axis captions placed at the ends of the center lines
    fig.add_annotation(x=x_range[0], y=y_mid, text="D1", showarrow=False, yshift=10)
    fig.add_annotation(y=y_range[1], x=x_mid, text="D2", showarrow=False, xshift=10)

    # Reverse the trace order before returning
    fig.data = fig.data[::-1]
    return fig