Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add colorbars for plot2D #2289

Merged
merged 7 commits into from
Nov 3, 2022
Merged
104 changes: 85 additions & 19 deletions python/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
"cmap": "RdBu",
"alpha": 0.8,
"post_process": np.real,
"colorbar": False,
"colorbar_label": None,
}

default_eps_parameters = {
Expand All @@ -57,6 +59,7 @@
"contour_linewidth": 1,
"frequency": None,
"resolution": None,
"colorbar": False,
}

default_boundary_parameters = {
Expand Down Expand Up @@ -447,6 +450,31 @@ def sort_points(xy):
return ax


def _add_colorbar(
ax: Axes,
cmap: str,
vmin: float,
vmax: float,
label: str,
clip_values: bool = False,
) -> None:
"""Add a colorbar to the parent Figure of 'ax' by creating an additional Axes."""
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Create a map between field/eps values and colors in the colormap.
# Note: cm.get_cmap() is deprecated for matplotlib>=3.6, use mpl.colormaps[cmap] instead if necessary.
sm = mpl.cm.ScalarMappable(
norm=mpl.colors.Normalize(vmin, vmax, clip_values), cmap=mpl.cm.get_cmap(cmap)
)
plt.colorbar(
mappable=sm,
cax=make_axes_locatable(ax).append_axes(position="right", size="5%", pad=0.05),
thomasdorch marked this conversation as resolved.
Show resolved Hide resolved
orientation="vertical",
label=label,
)


def plot_eps(
sim: Simulation,
ax: Axes = None,
Expand Down Expand Up @@ -549,6 +577,16 @@ def plot_eps(
ax.imshow(
eps_data, extent=extent, **filter_dict(eps_parameters, ax.imshow)
)

if eps_parameters["colorbar"]:
_add_colorbar(
ax=ax,
cmap=eps_parameters["cmap"],
vmin=np.amin(eps_data),
vmax=np.amax(eps_data),
label=r"$\epsilon_r$",
)

ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
return ax
Expand Down Expand Up @@ -735,6 +773,34 @@ def plot_fields(
output_plane: Volume = None,
field_parameters: dict = None,
) -> Union[Axes, Any]:
components = {
mp.Ex,
mp.Ey,
mp.Ez,
mp.Er,
mp.Ep,
mp.Dx,
mp.Dy,
mp.Dz,
mp.Dr,
mp.Dp,
mp.Hx,
mp.Hy,
mp.Hz,
mp.Hr,
mp.Hp,
mp.Bx,
mp.By,
mp.Bz,
mp.Br,
mp.Bp,
mp.Sx,
mp.Sy,
mp.Sz,
mp.Sr,
mp.Sp,
}

if not sim._is_initialized:
sim.init_sim()

Expand All @@ -747,19 +813,7 @@ def plot_fields(
field_parameters = dict(default_field_parameters, **field_parameters)

# user specifies a field component
if fields in [
mp.Ex,
mp.Ey,
mp.Ez,
mp.Er,
mp.Ep,
mp.Dx,
mp.Dy,
mp.Dz,
mp.Hx,
mp.Hy,
mp.Hz,
]:
if fields in components:
# Get domain measurements
sim_center, sim_size = get_2D_dimensions(sim, output_plane)

Expand All @@ -785,22 +839,34 @@ def plot_fields(
extent = [xmin, xmax, ymin, ymax]
xlabel = "X"
ylabel = "Y"
fields = sim.get_array(center=sim_center, size=sim_size, component=fields)
field_data = sim.get_array(center=sim_center, size=sim_size, component=fields)
else:
raise ValueError("Please specify a valid field component (mp.Ex, mp.Ey, ...")

fields = field_parameters["post_process"](fields)
field_data = field_parameters["post_process"](field_data)
if (sim.dimensions == mp.CYLINDRICAL) or sim.is_cylindrical:
fields = np.flipud(fields)
field_data = np.flipud(field_data)
else:
fields = np.rot90(fields)
field_data = np.rot90(field_data)

# Either plot the field, or return the array
if ax:
if mp.am_master():
ax.imshow(fields, extent=extent, **filter_dict(field_parameters, ax.imshow))
ax.imshow(
field_data, extent=extent, **filter_dict(field_parameters, ax.imshow)
)

if field_parameters["colorbar"]:

_add_colorbar(
ax=ax,
cmap=field_parameters.get("cmap", default_field_parameters["cmap"]),
vmin=np.amin(field_data),
vmax=np.amax(field_data),
label=field_parameters["colorbar_label"] or "Field Value",
)
return ax
return fields
return field_data


def plot2D(
Expand Down