From 442b921bc3a8e81000774ec348dd7b9709faf57c Mon Sep 17 00:00:00 2001 From: Thomas Dorch <53101766+thomasdorch@users.noreply.github.com> Date: Fri, 29 Jul 2022 17:21:39 -0700 Subject: [PATCH] Type hinting (#2176) * add type hinting to visualization * import visualization on-the-fly in Simulation --- python/simulation.py | 9 +- python/visualization.py | 290 +++++++++++++++++++++++----------------- 2 files changed, 178 insertions(+), 121 deletions(-) diff --git a/python/simulation.py b/python/simulation.py index 0a4453c0a..6a58cd601 100644 --- a/python/simulation.py +++ b/python/simulation.py @@ -15,7 +15,6 @@ except ImportError: from collections.abc import Sequence -import meep.visualization as vis import numpy as np from meep.geom import GeometricObject, Medium, Vector3, init_do_averaging from meep.source import ( @@ -4680,6 +4679,8 @@ def plot2D( - `post_process=np.real`: post processing function to apply to fields (must be a function object) """ + import meep.visualization as vis + return vis.plot2D( self, ax=ax, @@ -4700,6 +4701,8 @@ def plot2D( ) def plot_fields(self, **kwargs): + import meep.visualization as vis + return vis.plot_fields(self, **kwargs) def plot3D(self): @@ -4707,6 +4710,8 @@ def plot3D(self): Uses Mayavi to render a 3D simulation domain. The simulation object must be 3D. Can also be embedded in Jupyter notebooks. """ + import meep.visualization as vis + return vis.plot3D(self) def visualize_chunks(self): @@ -4715,6 +4720,8 @@ def visualize_chunks(self): rectangular region is a chunk, and each color represents a different processor. Requires [matplotlib](https://matplotlib.org). """ + import meep.visualization as vis + vis.visualize_chunks(self) diff --git a/python/visualization.py b/python/visualization.py index 4f4879e8e..1fe46f0a6 100644 --- a/python/visualization.py +++ b/python/visualization.py @@ -1,16 +1,21 @@ -import warnings from collections import namedtuple +import warnings import numpy as np + +import meep as mp from meep.geom import Vector3, init_do_averaging from meep.source import EigenModeSource, check_positive +from meep.simulation import Simulation, Volume -import meep as mp +## Typing imports +from matplotlib.axes import Axes +from typing import Callable, Union, Any, Iterable # ------------------------------------------------------- # # Visualization # ------------------------------------------------------- # -# Contains all necesarry visualation routines for use with +# Contains all necessary visualization routines for use with # pymeep and pympb. # ------------------------------------------------------- # @@ -74,7 +79,7 @@ # don't correspond to the keyword arguments of a particular # function (func_with_kwargs.) # Adapted from https://stackoverflow.com/questions/26515595/how-does-one-ignore-unexpected-keyword-arguments-passed-to-a-function/44052550 -def filter_dict(dict_to_filter, func_with_kwargs): +def filter_dict(dict_to_filter: dict, func_with_kwargs: Callable) -> dict: import inspect filter_keys = [] @@ -86,18 +91,21 @@ def filter_dict(dict_to_filter, func_with_kwargs): # Python2 ... filter_keys = inspect.getargspec(func_with_kwargs)[0] - return { + filtered_dict = { filter_key: dict_to_filter[filter_key] for filter_key in filter_keys if filter_key in dict_to_filter } + return filtered_dict # ------------------------------------------------------- # # Routines to add legends to plot -def place_label(ax, label_text, x, y, centerx, centery, label_parameters=None): +def place_label( + ax: Axes, label_text: str, x, y, centerx, centery, label_parameters: dict = None +) -> Axes: if label_parameters is None: label_parameters = default_label_parameters @@ -108,8 +116,15 @@ def place_label(ax, label_text, x, y, centerx, centery, label_parameters=None): alpha = label_parameters["label_alpha"] color = label_parameters["label_color"] - xtext = -offset if x > centerx else offset - ytext = -offset if y > centery else offset + if x > centerx: + xtext = -offset + else: + xtext = offset + if y > centery: + ytext = -offset + else: + ytext = offset + ax.annotate( label_text, xy=(x, y), @@ -129,7 +144,7 @@ def place_label(ax, label_text, x, y, centerx, centery, label_parameters=None): # Returns the intersection points of two Volumes. # Volumes must be a line, plane, or rectangular prism # (since they are volume objects) -def intersect_volume_volume(volume1, volume2): +def intersect_volume_volume(volume1: Volume, volume2: Volume) -> list: # volume1 ............... [volume] # volume2 ............... [volume] @@ -161,11 +176,11 @@ def intersect_volume_volume(volume1, volume2): L = np.max([L1, L2], axis=0) # For single points we have to check manually - if np.all(U - L == 0) and ( - (not volume1.pt_in_volume(Vector3(*U))) - or (not volume2.pt_in_volume(Vector3(*U))) - ): - return [] + if np.all(U - L == 0): + if (not volume1.pt_in_volume(Vector3(*U))) or ( + not volume2.pt_in_volume(Vector3(*U)) + ): + return [] # Check for two volumes that don't intersect if np.any(U - L < 0): @@ -175,7 +190,9 @@ def intersect_volume_volume(volume1, volume2): vertices = [] for x_vals in [L[0], U[0]]: for y_vals in [L[1], U[1]]: - vertices.extend(Vector3(x_vals, y_vals, z_vals) for z_vals in [L[2], U[2]]) + for z_vals in [L[2], U[2]]: + vertices.append(Vector3(x_vals, y_vals, z_vals)) + # Remove any duplicate points caused by coplanar lines vertices = [ vertices[i] for i, x in enumerate(vertices) if x not in vertices[i + 1 :] @@ -193,28 +210,27 @@ def intersect_volume_volume(volume1, volume2): # Not only do we need to check for all of these possibilities, but we also need # to check if the user accidentally specifies a plane that stretches beyond the # simulation domain. -def get_2D_dimensions(sim, output_plane): - from meep.simulation import Volume - +def get_2D_dimensions(sim: Simulation, output_plane: Volume) -> (Vector3, Vector3): # Pull correct plane from user if output_plane: plane_center, plane_size = (output_plane.center, output_plane.size) elif sim.output_volume: plane_center, plane_size = mp.get_center_and_size(sim.output_volume) - elif (sim.dimensions == mp.CYLINDRICAL) or sim.is_cylindrical: - plane_center, plane_size = ( - sim.geometry_center + mp.Vector3(sim.cell_size.x / 2), - sim.cell_size, - ) else: - plane_center, plane_size = (sim.geometry_center, sim.cell_size) + if (sim.dimensions == mp.CYLINDRICAL) or sim.is_cylindrical: + plane_center, plane_size = ( + sim.geometry_center + Vector3(sim.cell_size.x / 2), + sim.cell_size, + ) + else: + plane_center, plane_size = (sim.geometry_center, sim.cell_size) plane_volume = Volume(center=plane_center, size=plane_size) if plane_size.x != 0 and plane_size.y != 0 and plane_size.z != 0: raise ValueError("Plane volume must be 2D (a plane).") if (sim.dimensions == mp.CYLINDRICAL) or sim.is_cylindrical: - center = sim.geometry_center + mp.Vector3(sim.cell_size.x / 2) - check_volume = mp.Volume(center=center, size=sim.cell_size) + center = sim.geometry_center + Vector3(sim.cell_size.x / 2) + check_volume = Volume(center=center, size=sim.cell_size) else: check_volume = Volume(center=sim.geometry_center, size=sim.cell_size) vertices = intersect_volume_volume(check_volume, plane_volume) @@ -237,7 +253,9 @@ def get_2D_dimensions(sim, output_plane): return sim_center, sim_size -def box_vertices(box_center, box_size, is_cylindrical=False): +def box_vertices( + box_center: Vector3, box_size: Vector3, is_cylindrical: bool = False +) -> (float, float, float, float, float, float): # in cylindrical coordinates, radial (R) axis # is in the range (0,R) rather than (-R/2,+R/2) # as in Cartesian coordinates. @@ -257,14 +275,16 @@ def box_vertices(box_center, box_size, is_cylindrical=False): # ------------------------------------------------------- # # actual plotting routines - - def plot_volume( - sim, ax, volume, output_plane=None, plotting_parameters=None, label=None -): + sim: Simulation, + ax: Axes, + volume: Volume, + output_plane: Volume = None, + plotting_parameters: dict = None, + label: str = None, +) -> Axes: import matplotlib.patches as patches from matplotlib import pyplot as plt - from meep.simulation import Volume # Set up the plotting parameters if plotting_parameters is None: @@ -357,7 +377,7 @@ def sort_points(xy): ax.plot( [a.y for a in intersection], [a.z for a in intersection], - **line_args, + **line_args ) return ax # Plot XZ @@ -365,7 +385,7 @@ def sort_points(xy): ax.plot( [a.x for a in intersection], [a.z for a in intersection], - **line_args, + **line_args ) return ax # Plot XY @@ -373,7 +393,7 @@ def sort_points(xy): ax.plot( [a.x for a in intersection], [a.y for a in intersection], - **line_args, + **line_args ) return ax else: @@ -417,7 +437,13 @@ def sort_points(xy): return ax -def plot_eps(sim, ax, output_plane=None, eps_parameters=None, frequency=None): +def plot_eps( + sim: Simulation, + ax: Axes, + output_plane: Volume = None, + eps_parameters: dict = None, + frequency: float = None, +) -> Axes: # consolidate plotting parameters if eps_parameters is None: eps_parameters = default_eps_parameters @@ -454,7 +480,11 @@ def plot_eps(sim, ax, output_plane=None, eps_parameters=None, frequency=None): sim_center, sim_size, sim.is_cylindrical ) - grid_resolution = eps_parameters["resolution"] or sim.resolution + if eps_parameters["resolution"]: + grid_resolution = eps_parameters["resolution"] + else: + grid_resolution = sim.resolution + Nx = int((xmax - xmin) * grid_resolution + 1) Ny = int((ymax - ymin) * grid_resolution + 1) Nz = int((zmax - zmin) * grid_resolution + 1) @@ -470,10 +500,10 @@ def plot_eps(sim, ax, output_plane=None, eps_parameters=None, frequency=None): elif sim_size.y == 0: # Plot x on x axis, z on y axis (XZ plane) extent = [xmin, xmax, zmin, zmax] - xlabel = ( - "R" if (sim.dimensions == mp.CYLINDRICAL) or sim.is_cylindrical else "X" - ) - + if (sim.dimensions == mp.CYLINDRICAL) or sim.is_cylindrical: + xlabel = "R" + else: + xlabel = "X" ylabel = "Z" xtics = np.linspace(xmin, xmax, Nx) ytics = np.array([sim_center.y]) @@ -512,16 +542,19 @@ def plot_eps(sim, ax, output_plane=None, eps_parameters=None, frequency=None): return ax -def plot_boundaries(sim, ax, output_plane=None, boundary_parameters=None): +def plot_boundaries( + sim: Simulation, + ax: Axes, + output_plane: Volume = None, + boundary_parameters: dict = None, +) -> Axes: # consolidate plotting parameters if boundary_parameters is None: boundary_parameters = default_boundary_parameters else: boundary_parameters = dict(default_boundary_parameters, **boundary_parameters) - def get_boundary_volumes(thickness, direction, side): - from meep.simulation import Volume - + def get_boundary_volumes(thickness: float, direction: float, side) -> Volume: thickness = boundary.thickness xmin, xmax, ymin, ymax, zmin, zmax = box_vertices( @@ -625,8 +658,13 @@ def get_boundary_volumes(thickness, direction, side): return ax -def plot_sources(sim, ax, output_plane=None, labels=False, source_parameters=None): - from meep.simulation import Volume +def plot_sources( + sim: Simulation, + ax: Axes, + output_plane: Volume = None, + labels: bool = False, + source_parameters: dict = None, +): # consolidate plotting parameters if source_parameters is None: @@ -649,14 +687,18 @@ def plot_sources(sim, ax, output_plane=None, labels=False, source_parameters=Non return ax -def plot_monitors(sim, ax, output_plane=None, labels=False, monitor_parameters=None): - from meep.simulation import Volume - +def plot_monitors( + sim: Simulation, + ax: Axes, + output_plane: Volume = None, + labels: bool = False, + monitor_parameters: dict = None, +) -> Axes: # consolidate plotting parameters if monitor_parameters is None: monitor_parameters = default_monitor_parameters else: - monitor_parametesr = dict(default_monitor_parameters, **monitor_parameters) + monitor_parameters = dict(default_monitor_parameters, **monitor_parameters) label = "monitor" if labels else None @@ -674,7 +716,13 @@ def plot_monitors(sim, ax, output_plane=None, labels=False, monitor_parameters=N return ax -def plot_fields(sim, ax=None, fields=None, output_plane=None, field_parameters=None): +def plot_fields( + sim: Simulation, + ax: Axes = None, + fields=None, + output_plane: Volume = None, + field_parameters: dict = None, +) -> Union[Axes, Any]: if not sim._is_initialized: sim.init_sim() @@ -686,7 +734,8 @@ def plot_fields(sim, ax=None, fields=None, output_plane=None, field_parameters=N else: field_parameters = dict(default_field_parameters, **field_parameters) - if fields not in [ + # user specifies a field component + if fields in [ mp.Ex, mp.Ey, mp.Ez, @@ -699,64 +748,68 @@ def plot_fields(sim, ax=None, fields=None, output_plane=None, field_parameters=N mp.Hy, mp.Hz, ]: - raise ValueError("Please specify a valid field component (mp.Ex, mp.Ey, ...") - - # Get domain measurements - sim_center, sim_size = get_2D_dimensions(sim, output_plane) + # Get domain measurements + sim_center, sim_size = get_2D_dimensions(sim, output_plane) - xmin, xmax, ymin, ymax, zmin, zmax = box_vertices( - sim_center, sim_size, sim.is_cylindrical - ) - - if sim_size.x == 0: - # Plot y on x axis, z on y axis (YZ plane) - extent = [ymin, ymax, zmin, zmax] - xlabel = "Y" - ylabel = "Z" - elif sim_size.y == 0: - # Plot x on x axis, z on y axis (XZ plane) - extent = [xmin, xmax, zmin, zmax] - xlabel = ( - "R" if (sim.dimensions == mp.CYLINDRICAL) or sim.is_cylindrical else "X" + xmin, xmax, ymin, ymax, zmin, zmax = box_vertices( + sim_center, sim_size, sim.is_cylindrical ) - ylabel = "Z" - elif sim_size.z == 0: - # Plot x on x axis, y on y axis (XY plane) - extent = [xmin, xmax, ymin, ymax] - xlabel = "X" - ylabel = "Y" - fields = sim.get_array(center=sim_center, size=sim_size, component=fields) + if sim_size.x == 0: + # Plot y on x axis, z on y axis (YZ plane) + extent = [ymin, ymax, zmin, zmax] + xlabel = "Y" + ylabel = "Z" + elif sim_size.y == 0: + # Plot x on x axis, z on y axis (XZ plane) + extent = [xmin, xmax, zmin, zmax] + if (sim.dimensions == mp.CYLINDRICAL) or sim.is_cylindrical: + xlabel = "R" + else: + xlabel = "X" + ylabel = "Z" + elif sim_size.z == 0: + # Plot x on x axis, y on y axis (XY plane) + extent = [xmin, xmax, ymin, ymax] + xlabel = "X" + ylabel = "Y" + fields = sim.get_array(center=sim_center, size=sim_size, component=fields) + else: + raise ValueError("Please specify a valid field component (mp.Ex, mp.Ey, ...") + fields = field_parameters["post_process"](fields) if (sim.dimensions == mp.CYLINDRICAL) or sim.is_cylindrical: fields = np.flipud(fields) else: fields = np.rot90(fields) - if not ax: + # Either plot the field, or return the array + if ax: + if mp.am_master(): + ax.imshow(fields, extent=extent, **filter_dict(field_parameters, ax.imshow)) + return ax + else: return fields - if mp.am_master(): - ax.imshow(fields, extent=extent, **filter_dict(field_parameters, ax.imshow)) return ax def plot2D( - sim, - ax=None, - output_plane=None, + sim: Simulation, + ax: Axes = None, + output_plane: Volume = None, fields=None, - labels=False, - eps_parameters=None, - boundary_parameters=None, - source_parameters=None, - monitor_parameters=None, - field_parameters=None, - frequency=None, - plot_eps_flag=True, - plot_sources_flag=True, - plot_monitors_flag=True, - plot_boundaries_flag=True, -): + labels: bool = False, + eps_parameters: dict = None, + boundary_parameters: dict = None, + source_parameters: dict = None, + monitor_parameters: dict = None, + field_parameters: dict = None, + frequency: float = None, + plot_eps_flag: bool = True, + plot_sources_flag: bool = True, + plot_monitors_flag: bool = True, + plot_boundaries_flag: bool = True, +) -> Axes: # Ensure a figure axis exists if ax is None and mp.am_master(): @@ -765,8 +818,6 @@ def plot2D( ax = plt.gca() # validate the output plane to ensure proper 2D coordinates - from meep.simulation import Volume - sim_center, sim_size = get_2D_dimensions(sim, output_plane) output_plane = Volume(center=sim_center, size=sim_size) @@ -819,7 +870,7 @@ def plot2D( return ax -def plot3D(sim): +def plot3D(sim: Simulation): from mayavi import mlab if sim.dimensions < 3: @@ -838,16 +889,17 @@ def plot3D(sim): ztics = np.linspace(zmin, zmax, Nz) eps_data = sim.get_epsilon_grid(xtics, ytics, ztics) - return mlab.contour3d(eps_data, colormap="YlGnBu") + s = mlab.contour3d(eps_data, colormap="YlGnBu") + return s -def visualize_chunks(sim): +def visualize_chunks(sim: Simulation): if sim.structure is None: sim.init_sim() + import matplotlib.pyplot as plt import matplotlib.cm import matplotlib.colors - import matplotlib.pyplot as plt if sim.structure.gv.dim == 2: from mpl_toolkits.mplot3d import Axes3D @@ -860,13 +912,13 @@ def visualize_chunks(sim): def plot_box(box, proc, fig, ax): if sim.structure.gv.dim == 2: - low = mp.Vector3(box.low.x, box.low.y, box.low.z) - high = mp.Vector3(box.high.x, box.high.y, box.high.z) + low = Vector3(box.low.x, box.low.y, box.low.z) + high = Vector3(box.high.x, box.high.y, box.high.z) points = [low, high] - x_len = mp.Vector3(high.x) - mp.Vector3(low.x) - y_len = mp.Vector3(y=high.y) - mp.Vector3(y=low.y) - xy_len = mp.Vector3(high.x, high.y) - mp.Vector3(low.x, low.y) + x_len = Vector3(high.x) - Vector3(low.x) + y_len = Vector3(y=high.y) - Vector3(y=low.y) + xy_len = Vector3(high.x, high.y) - Vector3(low.x, low.y) points += [low + x_len] points += [low + y_len] @@ -893,12 +945,12 @@ def plot_box(box, proc, fig, ax): # Plot the points themselves to force the scaling of the axes ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=0) else: - low = mp.Vector3(box.low.x, box.low.y) - high = mp.Vector3(box.high.x, box.high.y) + low = Vector3(box.low.x, box.low.y) + high = Vector3(box.high.x, box.high.y) points = [low, high] - x_len = mp.Vector3(high.x) - mp.Vector3(low.x) - y_len = mp.Vector3(y=high.y) - mp.Vector3(y=low.y) + x_len = Vector3(high.x) - Vector3(low.x) + y_len = Vector3(y=high.y) - Vector3(y=low.y) points += [low + x_len] points += [low + y_len] @@ -1003,7 +1055,7 @@ def __init__( realtime=False, normalize=False, plot_modifiers=None, - **customization_args, + **customization_args ): """ Construct an `Animate2D` object. @@ -1175,7 +1227,6 @@ def to_jshtml(self, fps): # Only works with Python3 and matplotlib > 3.1.0 from distutils.version import LooseVersion - import matplotlib if LooseVersion(matplotlib.__version__) < LooseVersion("3.1.0"): @@ -1187,7 +1238,6 @@ def to_jshtml(self, fps): return if mp.am_master(): from uuid import uuid4 - from matplotlib._animation_data import ( DISPLAY_TEMPLATE, INCLUDED_FRAMES, @@ -1199,7 +1249,7 @@ def to_jshtml(self, fps): fill_frames = self._embedded_frames(self._saved_frames, self.frame_format) Nframes = len(self._saved_frames) mode_dict = dict(once_checked="", loop_checked="", reflect_checked="") - mode_dict[f"{self.default_mode}_checked"] = "checked" + mode_dict[self.default_mode + "_checked"] = "checked" interval = 1000 // fps @@ -1211,7 +1261,7 @@ def to_jshtml(self, fps): Nframes=Nframes, fill_frames=fill_frames, interval=interval, - **mode_dict, + **mode_dict ) return JS_Animation(html_string) @@ -1227,8 +1277,8 @@ def to_gif(self, fps, filename): # requires ffmpeg to be installed # modified from the matplotlib library if mp.am_master(): - from io import BytesIO, TextIOWrapper - from subprocess import PIPE, Popen + from subprocess import Popen, PIPE + from io import TextIOWrapper, BytesIO FFMPEG_BIN = "ffmpeg" command = [ @@ -1273,8 +1323,8 @@ def to_mp4(self, fps, filename): # requires ffmpeg to be installed # modified from the matplotlib library if mp.am_master(): - from io import BytesIO, TextIOWrapper - from subprocess import PIPE, Popen + from subprocess import Popen, PIPE + from io import TextIOWrapper, BytesIO FFMPEG_BIN = "ffmpeg" command = [