diff --git a/avae/vis.py b/avae/vis.py index 2d975f5..17180d5 100644 --- a/avae/vis.py +++ b/avae/vis.py @@ -5,6 +5,7 @@ import typing import altair +import matplotlib import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import numpy as np @@ -1886,7 +1887,9 @@ def plot_affinity_matrix( ax = plt.subplot(gs[0]) ax.set_title("Affinity Matrix", fontsize=16) - im = ax.imshow(lookup, vmin=-1, vmax=1, cmap=plt.get_cmap("RdBu")) + im = ax.imshow( + lookup, vmin=-1, vmax=1, cmap=matplotlib.colormaps.get_cmap("RdBu") + ) ax.set_xticks(np.arange(0, len(all_classes))) ax.set_xticklabels(all_classes)