diff --git a/python/visualization.py b/python/visualization.py index 58d13e350..60d586f47 100644 --- a/python/visualization.py +++ b/python/visualization.py @@ -1,5 +1,6 @@ from collections import namedtuple import warnings +import copy from time import sleep @@ -472,7 +473,7 @@ def _add_colorbar( from mpl_toolkits.axes_grid1 import make_axes_locatable if colorbar_parameters is None: - colorbar_parameters = default_colorbar_parameters + colorbar_parameters = copy.deepcopy(default_colorbar_parameters) else: colorbar_parameters = dict(default_colorbar_parameters, **colorbar_parameters) @@ -486,13 +487,17 @@ def _add_colorbar( norm=mpl.colors.Normalize(vmin, vmax), cmap=mpl.cm.get_cmap(cmap), ) + # Pop specific values out of colorbar params so user can add any kwargs to plt.colorbar - cax = make_axes_locatable(ax).append_axes( + # ref: https://matplotlib.org/stable/gallery/axes_grid1/demo_colorbar_with_axes_divider.html#colorbar-with-axesdivider + ax_divider = make_axes_locatable(ax) + cax = ax_divider.append_axes( pad=colorbar_parameters.pop("pad"), size=colorbar_parameters.pop("size"), position=colorbar_parameters.pop("position"), ) - plt.colorbar(mappable=sm, cax=cax, **colorbar_parameters) + fig = ax.get_figure() + fig.colorbar(mappable=sm, cax=cax, **colorbar_parameters) def plot_eps( @@ -881,7 +886,6 @@ def plot_fields( ax.imshow(field_data, extent=extent, **filter_dict(field_parameters, ax.imshow)) if field_parameters["colorbar"]: - _add_colorbar( ax=ax, cmap=field_parameters["cmap"],