Skip to content

BUG plot_clusters #707

Merged
merged 3 commits into from
May 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Fix bug in plot_clusters ([#675](https://github.com/tinkoff-ai/etna/pull/675))
-
- Fix bugs and documentation for cross_corr_plot ([#691](https://github.com/tinkoff-ai/etna/pull/691))
-
Expand Down
13 changes: 5 additions & 8 deletions etna/analysis/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,25 +788,22 @@ def plot_clusters(
size of the figure per subplot with one segment in inches
"""
unique_clusters = sorted(set(segment2cluster.values()))
rows_num = math.ceil(len(unique_clusters) / columns_num)
figsize = (figsize[0] * columns_num, figsize[1] * rows_num)
fig, axs = plt.subplots(rows_num, columns_num, constrained_layout=True, figsize=figsize)
_, ax = prepare_axes(num_plots=len(unique_clusters), columns_num=columns_num, figsize=figsize)
for i, cluster in enumerate(unique_clusters):
segments = [segment for segment in segment2cluster if segment2cluster[segment] == cluster]
h, w = i // columns_num, i % columns_num
for segment in segments:
segment_slice = ts[:, segment, "target"]
axs[h][w].plot(
ax[i].plot(
segment_slice.index.values,
segment_slice.values,
alpha=1 / math.sqrt(len(segments)),
c="blue",
)
axs[h][w].set_title(f"cluster={cluster}\n{len(segments)} segments in cluster")
ax[i].set_title(f"cluster={cluster}\n{len(segments)} segments in cluster")
if centroids_df is not None:
centroid = centroids_df[cluster, "target"]
axs[h][w].plot(centroid.index.values, centroid.values, c="red", label="centroid")
axs[h][w].legend()
ax[i].plot(centroid.index.values, centroid.values, c="red", label="centroid")
ax[i].legend()


def plot_time_series_with_change_points(
Expand Down