Skip to content

Commit

Permalink
Add arrow curvature to sources and monitors with bend radius (#272)
Browse files Browse the repository at this point in the history
Curved arrows are only possible through FancyArrowPatch, but those
define the arrow head size in display coordinates and end points in data
coordinates.  Furthermore, the arrow head is included in the total
length, as opposed to what happens with Axes.arrow, that was used
before.  The result is that the arrow length calculated based on the
axes limits or simulation boundaries can end up too small for the arrow
head (and for visualization as well).

The solution here, which might actually be preferable even without the
bending, is to define the arrow size in real units (inches, following
matplotlib's DPI setting) so that the arrows are always the same size.
That can only be accomplished by waiting for the transformations to be
set, at drawing time, so we use a `draw_event` callback to calculate the
arrow shape.

If the drawing window is resized, the arrow length will change, though,
because the callback is only called once (to avoid the danger of inifite
loops).

Signed-off-by: Lucas Heitzmann Gabrielli <[email protected]>
  • Loading branch information
lucas-flexcompute authored and momchil-flex committed Feb 24, 2023
1 parent d064762 commit df008ef
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 94 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Saving and loading of `.hdf5` files is made orders of magnitude faster due to an internal refactor.
### Changed
- Sources and monitors with bend radii are displayed with curved arrows.

## [1.8.4] - 2023-2-13

### Fixed
Expand Down
137 changes: 68 additions & 69 deletions tidy3d/components/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@

import pydantic
import numpy as np
from matplotlib import patches
from shapely.geometry import Point, Polygon, box, MultiPolygon
from shapely.validation import make_valid

from .base import Tidy3dBaseModel, cached_property
from .types import Bound, Size, Coordinate, Axis, Coordinate2D, ArrayLike, PlanePosition
from .types import Vertices, Ax, Shapely, annotate_type
from .viz import add_ax_if_none, equal_aspect
from .viz import PLOT_BUFFER, ARROW_LENGTH_FACTOR, ARROW_WIDTH_FACTOR, MAX_ARROW_WIDTH_FACTOR
from .viz import PLOT_BUFFER, ARROW_LENGTH, arrow_style
from .viz import PlotParams, plot_params_geometry, polygon_patch
from ..log import Tidy3dKeyError, SetupError, ValidationError, log
from ..constants import MICROMETER, LARGE_NUMBER, RADIAN, fp_eps, inf
Expand Down Expand Up @@ -1383,10 +1384,9 @@ def _plot_arrow( # pylint:disable=too-many-arguments, too-many-locals
z: float = None,
color: str = None,
alpha: float = None,
length_factor: float = ARROW_LENGTH_FACTOR,
width_factor: float = ARROW_WIDTH_FACTOR,
bend_radius: float = None,
bend_axis: Axis = None,
both_dirs: bool = False,
sim_bounds: Bound = None,
ax: Ax = None,
) -> Ax:
"""Adds an arrow to the axis if with options if certain conditions met.
Expand All @@ -1405,10 +1405,10 @@ def _plot_arrow( # pylint:disable=too-many-arguments, too-many-locals
Color of the arrow.
alpha : float = None
Opacity of the arrow (0, 1)
length_factor : float = None
How long the (3D, unprojected) arrow is compared to the min(height, width) of the axes.
width_factor : float = None
How wide the (3D, unprojected) arrow is compared to the min(height, width) of the axes.
bend_radius : float = None
Radius of curvature for this arrow.
bend_axis : Axis = None
Axis of curvature of `bend_radius`.
both_dirs : bool = False
If True, plots an arrow ponting in direction and one in -direction.
Expand All @@ -1419,84 +1419,92 @@ def _plot_arrow( # pylint:disable=too-many-arguments, too-many-locals
"""

plot_axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z)

arrow_length, arrow_width = self._arrow_dims(
ax=ax,
length_factor=length_factor,
width_factor=width_factor,
sim_bounds=sim_bounds,
plot_axis=plot_axis,
)
_, (dx, dy) = self.pop_axis(direction, axis=plot_axis)

# conditions to check to determine whether to plot arrow
arrow_intersecting_plane = len(self.intersections_plane(x=x, y=y, z=z)) > 0
_, (dx, dy) = self.pop_axis(direction, axis=plot_axis)
components_in_plane = any(not np.isclose(component, 0) for component in (dx, dy))

# plot if arrow in plotting plane and some non-zero component can be displayed.
if arrow_intersecting_plane and components_in_plane:
_, (x0, y0) = self.pop_axis(self.center, axis=plot_axis)

def add_arrow(sign=1.0):
"""Add an arrow to the axes and include a sign to direction."""
ax.arrow(
x=x0,
y=y0,
dx=sign * arrow_length * dx,
dy=sign * arrow_length * dy,
width=arrow_width,
# Reasonable value for temporary arrow size. The correct size and direction
# have to be calculated after all transforms have been set. That is why we
# use a callback to do these calculations only at the drawing phase.
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
v_x = (xmax - xmin) / 10
v_y = (ymax - ymin) / 10

directions = (1.0, -1.0) if both_dirs else (1.0,)
for sign in directions:
arrow = patches.FancyArrowPatch(
(x0, y0),
(x0 + v_x, y0 + v_y),
arrowstyle=arrow_style,
color=color,
alpha=alpha,
zorder=np.inf,
)
# Don't draw this arrow until it's been reshaped
arrow.set_visible(False)

add_arrow(sign=1.0)
if both_dirs:
add_arrow(sign=-1.0)

return ax

def _arrow_dims( # pylint: disable=too-many-locals
self,
ax: Ax,
length_factor: float = ARROW_LENGTH_FACTOR,
width_factor: float = ARROW_WIDTH_FACTOR,
sim_bounds: Bound = None,
plot_axis: Axis = None,
) -> Tuple[float, float]:
"""Length and width of arrow based on axes size and length and width factors."""
callback = self._arrow_shape_cb(
arrow, (x0, y0), (dx, dy), sign, bend_radius if bend_axis == plot_axis else None
)
callback_id = ax.figure.canvas.mpl_connect("draw_event", callback)

if sim_bounds is not None:
# Store a reference to the callback because mpl_connect does not.
arrow.set_shape_cb = (callback_id, callback)

# use the sim_bounds to get sizes
rmin, rmax = sim_bounds
_, (xmin, ymin) = self.pop_axis(rmin, axis=plot_axis)
_, (xmax, ymax) = self.pop_axis(rmax, axis=plot_axis)
ax.add_patch(arrow)

else:
# get the sizes of the matplotlib axes
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
return ax

width = xmax - xmin
height = ymax - ymin
@staticmethod
def _arrow_shape_cb(arrow, pos, direction, sign, bend_radius):
def _cb(event):
# We only want to set the shape once, so we disconnect ourselves
event.canvas.mpl_disconnect(arrow.set_shape_cb[0])

transform = arrow.axes.transData.transform
scale = transform((1, 0))[0] - transform((0, 0))[0]
arrow_length = ARROW_LENGTH * event.canvas.figure.get_dpi() / scale

if bend_radius:
v_norm = (direction[0] ** 2 + direction[1] ** 2) ** 0.5
vx_norm = direction[0] / v_norm
vy_norm = direction[1] / v_norm
bend_angle = -sign * arrow_length / bend_radius
t_x = 1 - np.cos(bend_angle)
t_y = np.sin(bend_angle)
v_x = -bend_radius * (vx_norm * t_y - vy_norm * t_x)
v_y = -bend_radius * (vx_norm * t_x + vy_norm * t_y)
tangent_angle = np.arctan2(direction[1], direction[0])
arrow.set_connectionstyle(
patches.ConnectionStyle.Angle3(
angleA=180 / np.pi * tangent_angle,
angleB=180 / np.pi * (tangent_angle + bend_angle),
)
)

# apply length factor to the minimum size to get arrow length
arrow_length = length_factor * min(width, height)
else:
v_x = sign * arrow_length * direction[0]
v_y = sign * arrow_length * direction[1]

# constrain arrow width by the maximum size and the max arrow width factor
arrow_width = width_factor * arrow_length
arrow_width = min(arrow_width, MAX_ARROW_WIDTH_FACTOR * max(width, height))
arrow.set_positions(pos, (pos[0] + v_x, pos[1] + v_y))
arrow.set_visible(True)
arrow.draw(event.renderer)

return arrow_length, arrow_width
return _cb

def _volume(self, bounds: Bound) -> float:
"""Returns object's volume given bounds."""

volume = 1

for axis in range(3):

min_bound = max(self.bounds[0][axis], bounds[0][axis])
max_bound = min(self.bounds[1][axis], bounds[1][axis])

Expand All @@ -1514,7 +1522,6 @@ def _surface_area(self, bounds: Bound) -> float:
length = [0, 0, 0]

for axis in (0, 1, 2):

if min_bounds[axis] < bounds[0][axis]:
min_bounds[axis] = bounds[0][axis]
in_bounds_factor[axis] -= 1
Expand Down Expand Up @@ -1616,9 +1623,7 @@ def _volume(self, bounds: Bound) -> float:

# a very loose upper bound on how much of sphere is in bounds
for axis in range(3):

if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]:

volume *= 0.5

return volume
Expand All @@ -1630,9 +1635,7 @@ def _surface_area(self, bounds: Bound) -> float:

# a very loose upper bound on how much of sphere is in bounds
for axis in range(3):

if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]:

area *= 0.5

return area
Expand Down Expand Up @@ -1847,10 +1850,8 @@ def _volume(self, bounds: Bound) -> float:

# a very loose upper bound on how much of the cylinder is in bounds
for axis in range(3):

if axis != self.axis:
if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]:

volume *= 0.5

return volume
Expand Down Expand Up @@ -1879,10 +1880,8 @@ def _surface_area(self, bounds: Bound) -> float:

# a very loose upper bound on how much of the cylinder is in bounds
for axis in range(3):

if axis != self.axis:
if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]:

area *= 0.5

return area
Expand Down Expand Up @@ -2765,7 +2764,7 @@ def _find_intersecting_ys_angle_vertical( # pylint:disable=too-many-locals
# intersecting positions and angles
ints_y = []
ints_angle = []
for (vertices_f_local, vertices_b_local) in zip(iverts_b, iverts_f):
for vertices_f_local, vertices_b_local in zip(iverts_b, iverts_f):
x1, y1 = vertices_f_local
x2, y2 = vertices_b_local
slope = (y2 - y1) / (x2 - x1)
Expand Down Expand Up @@ -2836,7 +2835,7 @@ def _find_intersecting_ys_angle_slant( # pylint:disable=too-many-locals, too-ma
shift_val = shift_x if axis == 0 else shift_y
shift_val = shift_val[intersects_on]

for (vertices_f_local, vertices_b_local, vertices_on_local, shift_local) in zip(
for vertices_f_local, vertices_b_local, vertices_on_local, shift_local in zip(
iverts_f, iverts_b, iverts_on, shift_val
):
x_on, y_on = vertices_on_local
Expand Down
28 changes: 18 additions & 10 deletions tidy3d/components/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pydantic
import numpy as np

from .types import Ax, EMField, ArrayLike, Bound, FreqArray
from .types import Ax, EMField, ArrayLike, FreqArray
from .types import Literal, Direction, Coordinate, Axis, ObsGridArray
from .geometry import Box
from .validators import assert_plane
Expand Down Expand Up @@ -214,6 +214,11 @@ class PlanarMonitor(Monitor, ABC):

_plane_validator = assert_plane()

@cached_property
def normal_axis(self) -> Axis:
"""Axis normal to the monitor's plane."""
return self.size.index(0.0)


class AbstractModeMonitor(PlanarMonitor, FreqMonitor):
""":class:`Monitor` that records mode-related data."""
Expand All @@ -230,10 +235,8 @@ def plot( # pylint:disable=too-many-arguments
y: float = None,
z: float = None,
ax: Ax = None,
sim_bounds: Bound = None,
**patch_kwargs,
) -> Ax:

# call the monitor.plot() function first
ax = super().plot(x=x, y=y, z=z, ax=ax, **patch_kwargs)

Expand All @@ -247,10 +250,11 @@ def plot( # pylint:disable=too-many-arguments
z=z,
ax=ax,
direction=self._dir_arrow,
bend_radius=self.mode_spec.bend_radius,
bend_axis=self._bend_axis,
color=ARROW_COLOR_MONITOR,
alpha=arrow_alpha,
both_dirs=True,
sim_bounds=sim_bounds,
)
return ax

Expand All @@ -260,7 +264,16 @@ def _dir_arrow(self) -> Tuple[float, float, float]:
dx = np.cos(self.mode_spec.angle_phi) * np.sin(self.mode_spec.angle_theta)
dy = np.sin(self.mode_spec.angle_phi) * np.sin(self.mode_spec.angle_theta)
dz = np.cos(self.mode_spec.angle_theta)
return self.unpop_axis(dz, (dx, dy), axis=self.size.index(0.0))
return self.unpop_axis(dz, (dx, dy), axis=self.normal_axis)

@cached_property
def _bend_axis(self) -> Axis:
if self.mode_spec.bend_radius is None:
return None
in_plane = [0, 0]
in_plane[self.mode_spec.bend_axis] = 1
direction = self.unpop_axis(0, in_plane, axis=self.normal_axis)
return direction.index(1)


class FieldMonitor(AbstractFieldMonitor, FreqMonitor):
Expand Down Expand Up @@ -771,11 +784,6 @@ def diffraction_monitor_size(cls, val):
)
return val

@property
def normal_axis(self) -> Axis:
"""Axis normal to the monitor's plane."""
return self.size.index(0)

def storage_size(self, num_cells: int, tmesh: ArrayLike[float, 1]) -> int:
"""Size of monitor storage given the number of points after discretization."""
# assumes 1 diffraction order per frequency; actual size will be larger
Expand Down
Loading

0 comments on commit df008ef

Please sign in to comment.