Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Differentiable n-dimensional nearest and linear interpolation for DataArray #1769

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for differentiation with respect to `GeometryGroup.geometries` elements.
- Users can now export `SimulationData` to MATLAB `.mat` files with the `to_mat_file` method.
- `ModeSolver` methods to plot the mode plane simulation components, including `.plot()`, `.plot_eps()`, `.plot_structures_eps()`, `.plot_grid()`, and `.plot_pml()`.
- Support for differentiation with respect to monitor attributes that require interpolation, such as flux and intensity.

### Changed

Expand Down
31 changes: 30 additions & 1 deletion tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def field_vol_postprocess_fn(sim_data, mnt_data):
value += abs(anp.sum(val.values))
intensity = anp.nan_to_num(anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values))
value += intensity
# value += anp.sum(mnt_data.flux.values) # not yet supported
value += anp.sum(mnt_data.flux.values)
return value
yaugenst-flex marked this conversation as resolved.
Show resolved Hide resolved

field_point = td.FieldMonitor(
Expand Down Expand Up @@ -747,6 +747,35 @@ def objective(*args):
ag.grad(objective)(params0)


@pytest.mark.parametrize("colocate", [True, False])
@pytest.mark.parametrize("objtype", ["flux", "intensity"])
def test_interp_objectives(use_emulated_run, colocate, objtype):
monitor = td.FieldMonitor(
center=(0, 0, 0),
size=(td.inf, td.inf, 0),
freqs=[FREQ0],
name="monitor",
colocate=colocate,
)

def objective(args):
structures_traced_dict = make_structures(args)
structures = list(SIM_BASE.structures)
for structure_key in structure_keys_:
structures.append(structures_traced_dict[structure_key])

sim = SIM_BASE.updated_copy(monitors=[monitor], structures=structures)
data = run(sim, task_name="autograd_test", verbose=False)

if objtype == "flux":
return anp.sum(data[monitor.name].flux.values)
elif objtype == "intensity":
return anp.sum(data.get_intensity(monitor.name).values)

tylerflex marked this conversation as resolved.
Show resolved Hide resolved
grads = ag.grad(objective)(params0)
assert np.any(grads > 0)


def test_autograd_deepcopy():
"""make sure deepcopy works as expected in autograd."""

Expand Down
37 changes: 37 additions & 0 deletions tests/test_plugins/autograd/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import numpy.testing as npt
import pytest
import scipy.interpolate
import scipy.ndimage
from autograd.test_util import check_grads
from scipy.signal import convolve as convolve_sp
Expand All @@ -10,6 +11,7 @@
grey_dilation,
grey_erosion,
grey_opening,
interpn,
morphological_gradient,
morphological_gradient_external,
morphological_gradient_internal,
Expand Down Expand Up @@ -300,3 +302,38 @@ def test_threshold_exceptions(array, vmin, vmax, level, expected_message):
"""Test threshold function for expected exceptions."""
with pytest.raises(ValueError, match=expected_message):
threshold(array, vmin, vmax, level)


@pytest.mark.parametrize("dim", [1, 2, 3, 4])
@pytest.mark.parametrize("method", ["linear", "nearest"])
class TestInterpn:
tylerflex marked this conversation as resolved.
Show resolved Hide resolved
@staticmethod
def generate_points_values_xi(rng, dim):
points = tuple(np.linspace(0, 1, 10) for _ in range(dim))
values = rng.random([p.size for p in points])
xi = tuple(np.linspace(0, 1, 5) for _ in range(dim))
return points, values, xi

def test_interpn_val(self, rng, dim, method):
points, values, xi = self.generate_points_values_xi(rng, dim)
xi_grid = np.meshgrid(*xi, indexing="ij")

result_custom = interpn(points, values, xi, method=method)
result_scipy = scipy.interpolate.interpn(points, values, tuple(xi_grid), method=method)
npt.assert_allclose(result_custom, result_scipy)

@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("mode", ["fwd", "rev"])
def test_interpn_values_grad(self, rng, dim, method, order, mode):
points, values, xi = self.generate_points_values_xi(rng, dim)
check_grads(lambda v: interpn(points, v, xi, method=method), modes=[mode], order=order)(
values
)


class TestInterpnExceptions:
def test_invalid_method(self, rng):
"""Test that an exception is raised for an invalid interpolation method."""
points, values, xi = TestInterpn.generate_points_values_xi(rng, 2)
with pytest.raises(ValueError, match="interpolation method"):
interpn(points, values, xi, method="invalid_method")
18 changes: 4 additions & 14 deletions tests/test_plugins/test_microwave.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from tidy3d.plugins.microwave.models import coupled_microstrip, microstrip

from ..utils import run_emulated
from ..utils import get_spatial_coords_dict, run_emulated

# Using similar code as "test_data/test_data_arrays.py"
MON_SIZE = (2, 1, 0)
Expand All @@ -28,9 +28,7 @@
F0 = (FSTART + FSTOP) / 2
FWIDTH = FSTOP - FSTART
FS = np.linspace(FSTART, FSTOP, 3)
FIELD_MONITOR = td.FieldMonitor(
size=MON_SIZE, fields=FIELDS, name="strip_field", freqs=FS, colocate=False
)
FIELD_MONITOR = td.FieldMonitor(size=MON_SIZE, fields=FIELDS, name="strip_field", freqs=FS)
STRIP_WIDTH = 1.5
STRIP_HEIGHT = 0.5

Expand All @@ -39,7 +37,7 @@
grid_spec=td.GridSpec.uniform(dl=0.04),
monitors=[
FIELD_MONITOR,
td.FieldMonitor(center=(0, 0, 0), size=(1, 1, 1), freqs=FS, name="field"),
td.FieldMonitor(center=(0, 0, 0), size=(1, 1, 1), freqs=FS, name="field", colocate=False),
td.FieldMonitor(
center=(0, 0, 0), size=(1, 1, 1), freqs=FS, fields=["Ex", "Hx"], name="ExHx"
),
Expand Down Expand Up @@ -67,17 +65,9 @@
""" Generate the data arrays for testing path integral computations """


def get_xyz(
monitor: td.components.monitor.MonitorType, grid_key: str
) -> tuple[list[float], list[float], list[float]]:
grid = SIM_Z.discretize_monitor(monitor)
x, y, z = grid[grid_key].to_list
return x, y, z


def make_stripline_scalar_field_data_array(grid_key: str):
"""Populate FIELD_MONITOR with a idealized stripline mode, where fringing fields are assumed 0."""
XS, YS, ZS = get_xyz(FIELD_MONITOR, grid_key)
XS, YS, ZS = get_spatial_coords_dict(SIM_Z, FIELD_MONITOR, grid_key).values()
XGRID, YGRID = np.meshgrid(XS, YS, indexing="ij")
XGRID = XGRID.reshape((len(XS), len(YS), 1, 1))
YGRID = YGRID.reshape((len(XS), len(YS), 1, 1))
Expand Down
52 changes: 24 additions & 28 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def make_custom_data(lims, unstructured):
# Make a few autograd ArrayBoxes for testing
start_node = VJPNode.new_root()
tracer = new_box(1.0, 0, start_node)
tracer_arr = new_box([[[1.0]]], 0, start_node)
tracer_arr = new_box(np.array([[[1.0]]]), 0, start_node)

SIM_FULL = td.Simulation(
size=(8.0, 8.0, 8.0),
Expand Down Expand Up @@ -852,6 +852,24 @@ def make_custom_data(lims, unstructured):
)


def get_spatial_coords_dict(simulation: td.Simulation, monitor: td.Monitor, field_name: str):
"""Returns MonitorData coordinates associated with a Monitor object"""
grid = simulation.discretize_monitor(monitor)
spatial_coords = grid.boundaries if monitor.colocate else grid[field_name]
spatial_coords_dict = spatial_coords.dict()

coords = {}
for axis, dim in enumerate("xyz"):
if monitor.size[axis] == 0:
coords[dim] = [monitor.center[axis]]
elif monitor.colocate:
coords[dim] = spatial_coords_dict[dim][:-1]
else:
coords[dim] = spatial_coords_dict[dim]

return coords


def run_emulated(simulation: td.Simulation, path=None, **kwargs) -> td.SimulationData:
"""Emulates a simulation run."""
from scipy.ndimage.filters import gaussian_filter
Expand All @@ -872,19 +890,11 @@ def make_data(
def make_field_data(monitor: td.FieldMonitor) -> td.FieldData:
"""make a random FieldData from a FieldMonitor."""
field_cmps = {}
coords = {}
grid = simulation.discretize_monitor(monitor)

for field_name in monitor.fields:
spatial_coords_dict = grid[field_name].dict()

for axis, dim in enumerate("xyz"):
if monitor.size[axis] == 0:
yaugenst-flex marked this conversation as resolved.
Show resolved Hide resolved
coords[dim] = [monitor.center[axis]]
else:
coords[dim] = np.array(spatial_coords_dict[dim])

coords = get_spatial_coords_dict(simulation, monitor, field_name)
coords["f"] = list(monitor.freqs)

field_cmps[field_name] = make_data(
coords=coords, data_array_type=td.ScalarFieldDataArray, is_complex=True
)
Expand All @@ -900,17 +910,10 @@ def make_field_data(monitor: td.FieldMonitor) -> td.FieldData:
def make_field_time_data(monitor: td.FieldTimeMonitor) -> td.FieldTimeData:
"""make a random FieldTimeData from a FieldTimeMonitor."""
field_cmps = {}
coords = {}
grid = simulation.discretize_monitor(monitor)
tmesh = simulation.tmesh
for field_name in monitor.fields:
spatial_coords_dict = grid[field_name].dict()

for axis, dim in enumerate("xyz"):
if monitor.size[axis] == 0:
coords[dim] = [monitor.center[axis]]
else:
coords[dim] = np.array(spatial_coords_dict[dim])
coords = get_spatial_coords_dict(simulation, monitor, field_name)

(idx_begin, idx_end) = monitor.time_inds(tmesh)
tcoords = tmesh[idx_begin:idx_end]
Expand All @@ -930,7 +933,6 @@ def make_field_time_data(monitor: td.FieldTimeMonitor) -> td.FieldTimeData:
def make_mode_solver_data(monitor: td.ModeSolverMonitor) -> td.ModeSolverData:
"""make a random ModeSolverData from a ModeSolverMonitor."""
field_cmps = {}
coords = {}
grid = simulation.discretize_monitor(monitor)
index_coords = {}
index_coords["f"] = list(monitor.freqs)
Expand All @@ -940,16 +942,10 @@ def make_mode_solver_data(monitor: td.ModeSolverMonitor) -> td.ModeSolverData:
(1 + 1j) * np.random.random(index_data_shape), coords=index_coords
)
for field_name in ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]:
spatial_coords_dict = grid[field_name].dict()

for axis, dim in enumerate("xyz"):
if monitor.size[axis] == 0:
coords[dim] = [monitor.center[axis]]
else:
coords[dim] = np.array(spatial_coords_dict[dim])

coords = get_spatial_coords_dict(simulation, monitor, field_name)
coords["f"] = list(monitor.freqs)
coords["mode_index"] = index_coords["mode_index"]

field_cmps[field_name] = make_data(
coords=coords, data_array_type=td.ScalarModeFieldDataArray, is_complex=True
)
Expand Down
4 changes: 2 additions & 2 deletions tidy3d/components/autograd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .functions import interpn
from .types import (
AutogradFieldMap,
AutogradTraced,
Expand All @@ -18,6 +19,5 @@
"AutogradTraced",
"AutogradFieldMap",
"get_static",
"integrate_within_bounds",
"DerivativeInfo",
"interpn",
]
1 change: 0 additions & 1 deletion tidy3d/components/autograd/derivative_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def integrate_within_bounds(arr: xr.DataArray, dims: list[str], bounds: Bound) -

# uses trapezoidal rule
# https://docs.xarray.dev/en/stable/generated/xarray.DataArray.integrate.html

dims_integrate = [dim for dim in dims if len(_arr.coords[dim]) > 1]
return _arr.integrate(coord=dims_integrate)

Expand Down
Loading
Loading