Skip to content

Commit

Permalink
create a duplicate of the dict for the colorbar parameters of plot2D (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
oskooi authored Jan 24, 2023
1 parent f448ebd commit 85e1077
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions python/visualization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import namedtuple
import warnings
import copy

from time import sleep

Expand Down Expand Up @@ -472,7 +473,7 @@ def _add_colorbar(
from mpl_toolkits.axes_grid1 import make_axes_locatable

if colorbar_parameters is None:
colorbar_parameters = default_colorbar_parameters
colorbar_parameters = copy.deepcopy(default_colorbar_parameters)
else:
colorbar_parameters = dict(default_colorbar_parameters, **colorbar_parameters)

Expand All @@ -486,13 +487,17 @@ def _add_colorbar(
norm=mpl.colors.Normalize(vmin, vmax),
cmap=mpl.cm.get_cmap(cmap),
)

# Pop specific values out of colorbar params so user can add any kwargs to plt.colorbar
cax = make_axes_locatable(ax).append_axes(
# ref: https://matplotlib.org/stable/gallery/axes_grid1/demo_colorbar_with_axes_divider.html#colorbar-with-axesdivider
ax_divider = make_axes_locatable(ax)
cax = ax_divider.append_axes(
pad=colorbar_parameters.pop("pad"),
size=colorbar_parameters.pop("size"),
position=colorbar_parameters.pop("position"),
)
plt.colorbar(mappable=sm, cax=cax, **colorbar_parameters)
fig = ax.get_figure()
fig.colorbar(mappable=sm, cax=cax, **colorbar_parameters)


def plot_eps(
Expand Down Expand Up @@ -881,7 +886,6 @@ def plot_fields(
ax.imshow(field_data, extent=extent, **filter_dict(field_parameters, ax.imshow))

if field_parameters["colorbar"]:

_add_colorbar(
ax=ax,
cmap=field_parameters["cmap"],
Expand Down

0 comments on commit 85e1077

Please sign in to comment.