Skip to content

Commit

Permalink
Additional constraints and validators for invdes plugin
Browse files Browse the repository at this point in the history
 - Constraints for design parameters to be in range [0, 1]
 - Constraints on optimizer, penalty, transformation arguments such as
   `beta`, `eta`, `learning_rate`
 - Validation for forward and gradient evaluation of both
   transformations and penalties
 - Validation for metrics included in `Design` objects
 - Renamed `Metric.freqs` --> `Metric.f`
 - Make `Metric.f` optional
  • Loading branch information
yaugenst-flex authored and momchil-flex committed Oct 7, 2024
1 parent 80026e9 commit 9fd7cf2
Show file tree
Hide file tree
Showing 12 changed files with 308 additions and 62 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `TopologyDesignRegion` is now invariant in `z` by default and supports assigning dimensions along which a design should be uniform via `TopologyDesignRegion(uniform=(bool, bool, bool))`.
- Support for arbitrary padding sizes for all padding modes in `tidy3d.plugins.autograd.functions.pad`.
- `Expression.filter(target_type, target_field)` method for extracting object instances and fields from nested expressions.
- Additional constraints and validation logic to ensure correct setup of optimization problems in `invdes` plugin.

### Changed
- Renamed `Metric.freqs` --> `Metric.f` and made frequency argument optional, in which case all frequencies from the relevant monitor will be extracted.

### Fixed
- Some validation fixes for design region.
Expand Down
28 changes: 25 additions & 3 deletions tests/test_plugins/test_invdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def test_invdes_with_metric_objective(use_emulated_run, use_emulated_to_sim_data
"""Test using a metric as an objective function in InverseDesign."""

# Create a metric as the objective function
metric = 2 * ModePower(monitor_name=MNT_NAME2, freqs=[FREQ0]) ** 2
metric = 2 * ModePower(monitor_name=MNT_NAME2, f=[FREQ0]) ** 2

invdes = tdi.InverseDesign(
simulation=simulation,
Expand All @@ -529,7 +529,7 @@ def test_invdes_with_metric_objective(use_emulated_run, use_emulated_to_sim_data
[
(RandomInitializationSpec, {"min_value": 0.0, "max_value": 1.0}, (3, 3)),
(UniformInitializationSpec, {"value": 0.5}, (2, 2)),
(CustomInitializationSpec, {"params": np.array([[1, 2], [3, 4]])}, (2, 2)),
(CustomInitializationSpec, {"params": np.zeros((3, 3, 3))}, (3, 3, 3)),
],
)
def test_parameter_spec(spec_class, spec_kwargs, expected_shape):
Expand All @@ -542,7 +542,7 @@ def test_parameter_spec(spec_class, spec_kwargs, expected_shape):
def test_parameter_spec_with_inverse_design(use_emulated_run, use_emulated_to_sim_data): # noqa: F811
"""Test InitializationSpec with InverseDesign class."""

metric = 2 * ModePower(monitor_name=MNT_NAME2, freqs=[FREQ0]) ** 2
metric = 2 * ModePower(monitor_name=MNT_NAME2, f=[FREQ0]) ** 2

initialization_spec = RandomInitializationSpec()
design_region = make_design_region()
Expand Down Expand Up @@ -584,3 +584,25 @@ def test_initial_simulation_multi():
assert sim.structures[-1] == invdes_multi.design_region.to_structure(
invdes_multi.design_region.initial_parameters
)


def test_validate_invdes_metric():
"""Test the _validate_metric_monitor_name validator."""
invdes = make_invdes()
metric = ModePower(monitor_name="invalid_monitor", f=[FREQ0])
with pytest.raises(ValueError, match="monitors"):
invdes.updated_copy(metric=metric)

metric = ModePower(monitor_name=MNT_NAME2, mode_index=10, f=[FREQ0])
with pytest.raises(ValueError, match="mode index"):
invdes.updated_copy(metric=metric)

metric = ModePower(monitor_name=MNT_NAME2, mode_index=0, f=[FREQ0 / 2])
with pytest.raises(ValueError, match="frequencies"):
invdes.updated_copy(metric=metric)

metric = ModePower(monitor_name=MNT_NAME2, mode_index=0)
monitor = mnt2.updated_copy(freqs=[FREQ0, FREQ0 / 2])
invdes = invdes.updated_copy(simulation=simulation.updated_copy(monitors=[monitor]))
with pytest.raises(ValueError, match="single frequency"):
invdes.updated_copy(metric=metric)
6 changes: 5 additions & 1 deletion tidy3d/plugins/expressions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ class Function(Expression):
Base class for mathematical functions in expressions.
"""

operand: NumberOrExpression
operand: NumberOrExpression = pd.Field(
...,
title="Operand",
description="The operand for the function.",
)

_format: str = "{func}({operand})"

Expand Down
70 changes: 36 additions & 34 deletions tidy3d/plugins/expressions/metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any
from typing import Any, Optional, Union

import autograd.numpy as anp
import autograd.numpy as np
import pydantic.v1 as pd

from tidy3d.components.monitor import ModeMonitor
from tidy3d.components.types import Direction, FreqArray
from tidy3d.exceptions import ValidationError

from .types import NumberType
from .variables import Variable
Expand All @@ -27,48 +26,51 @@ class ModeAmp(Metric):
"""
Metric for calculating the mode coefficient from a ModeMonitor.
Parameters
----------
monitor_name : str
The name of the mode monitor.
freqs : FreqArray
The frequency array.
direction : Direction = "+"
The direction of the mode.
mode_index : pd.NonNegativeInt = 0
The index of the mode.
Examples
--------
>>> monitor = ModeMonitor(name="monitor1", freqs=[1.0])
>>> monitor = ModeMonitor(name="monitor1", f=[1.0])
>>> mode_coeff = ModeAmp.from_mode_monitor(monitor)
>>> data = SimulationData() # Assume this is a valid SimulationData object
>>> result = mode_coeff.evaluate(data)
"""

monitor_name: str
freqs: FreqArray
direction: Direction = "+"
mode_index: pd.NonNegativeInt = 0

@pd.validator("freqs", always=True)
def _single_frequency(cls, val: FreqArray) -> FreqArray:
if len(val) != 1:
raise ValidationError("Only a single frequency is supported at the moment.")
return val
monitor_name: str = pd.Field(
...,
title="Monitor Name",
description="The name of the mode monitor. This needs to match the name of the monitor in the simulation.",
)
f: Optional[Union[float, FreqArray]] = pd.Field( # type: ignore
None,
title="Frequency Array",
description="The frequency array. If None, all frequencies in the monitor will be used.",
)
direction: Direction = pd.Field(
"+",
title="Direction",
description="The direction of propagation of the mode.",
)
mode_index: pd.NonNegativeInt = pd.Field(
0,
title="Mode Index",
description="The index of the mode.",
)

@classmethod
def from_mode_monitor(cls, monitor: ModeMonitor):
return cls(monitor_name=monitor.name, freqs=monitor.freqs, mode_index=0)
def from_mode_monitor(
cls, monitor: ModeMonitor, mode_index: int = 0, direction: Direction = "+"
):
return cls(
monitor_name=monitor.name, f=monitor.freqs, mode_index=mode_index, direction=direction
)

def evaluate(self, *args: Any, **kwargs: Any) -> NumberType:
data = super().evaluate(*args, **kwargs)
amps = (
data[self.monitor_name]
.amps.sel(direction=self.direction, mode_index=self.mode_index)
.isel(f=0)
amps = data[self.monitor_name].amps.sel(
direction=self.direction, mode_index=self.mode_index
)
return anp.squeeze(amps.values.tolist())
if self.f is not None:
amps = amps.sel(f=list(self.f), method="nearest")
return np.squeeze(amps.values.tolist())


class ModePower(ModeAmp):
Expand All @@ -77,12 +79,12 @@ class ModePower(ModeAmp):
Examples
--------
>>> monitor = ModeMonitor(name="monitor1", freqs=[1.0])
>>> monitor = ModeMonitor(name="monitor1", f=[1.0])
>>> mode_power = ModePower.from_mode_monitor(monitor)
>>> data = SimulationData() # Assume this is a valid SimulationData object
>>> result = mode_power.evaluate(data)
"""

def evaluate(self, *args: Any, **kwargs: Any) -> NumberType:
amps = super().evaluate(*args, **kwargs)
return abs(amps) ** 2
return np.abs(amps) ** 2
18 changes: 15 additions & 3 deletions tidy3d/plugins/expressions/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ class UnaryOperator(Expression):
Subclasses should implement the evaluate method to define the specific operation.
"""

operand: NumberOrExpression
operand: NumberOrExpression = pd.Field(
...,
title="Operand",
description="The operand for the unary operator.",
)

_symbol: str
_format: str = "({symbol}{operand})"
Expand All @@ -37,8 +41,16 @@ class BinaryOperator(Expression):
Subclasses should implement the evaluate method to define the specific operation.
"""

left: NumberOrExpression
right: NumberOrExpression
left: NumberOrExpression = pd.Field(
...,
title="Left",
description="The left operand for the binary operator.",
)
right: NumberOrExpression = pd.Field(
...,
title="Right",
description="The right operand for the binary operator.",
)

_symbol: str
_format: str = "({left} {symbol} {right})"
Expand Down
14 changes: 12 additions & 2 deletions tidy3d/plugins/expressions/variables.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Optional

import pydantic.v1 as pd

from .base import Expression
from .types import NumberType

Expand Down Expand Up @@ -34,7 +36,11 @@ class Variable(Expression):
>>> expr(5, 3) # Raises ValueError
"""

name: Optional[str] = None
name: Optional[str] = pd.Field(
None,
title="Name",
description="The name of the variable used for lookup during evaluation.",
)

def evaluate(self, *args: Any, **kwargs: Any) -> NumberType:
if self.name:
Expand Down Expand Up @@ -72,7 +78,11 @@ class Constant(Variable):
>>> c.evaluate() # Returns 5
"""

value: NumberType
value: NumberType = pd.Field(
...,
title="Value",
description="The fixed value of the constant.",
)

def __init__(self, value: NumberType, **kwargs: dict[str, Any]) -> None:
super().__init__(value=value, **kwargs)
Expand Down
59 changes: 59 additions & 0 deletions tidy3d/plugins/invdes/design.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
import typing

import autograd.numpy as anp
import numpy as np
import pydantic.v1 as pd

import tidy3d as td
import tidy3d.web as web
from tidy3d.components.autograd import get_static
from tidy3d.exceptions import ValidationError
from tidy3d.plugins.expressions.metrics import Metric
from tidy3d.plugins.expressions.types import ExpressionType

from .base import InvdesBaseModel
Expand Down Expand Up @@ -119,6 +122,62 @@ class InverseDesign(AbstractInverseDesign):

_check_sim_pixel_size = check_pixel_size("simulation")

@pd.root_validator(pre=False)
def _validate_model(cls, values: dict) -> dict:
cls._validate_metric(values)
return values

@staticmethod
def _validate_metric(values: dict) -> dict:
metric_expr = values.get("metric")
if not metric_expr:
return values
simulation = values.get("simulation")
for metric in metric_expr.filter(Metric):
InverseDesign._validate_metric_monitor_name(metric, simulation)
InverseDesign._validate_metric_mode_index(metric, simulation)
InverseDesign._validate_metric_f(metric, simulation)
return values

@staticmethod
def _validate_metric_monitor_name(metric: Metric, simulation: td.Simulation) -> None:
"""Validate that the monitor name of the metric exists in the simulation."""
monitor = next((m for m in simulation.monitors if m.name == metric.monitor_name), None)
if monitor is None:
raise ValidationError(
f"Monitor named '{metric.monitor_name}' associated with the metric not found in the simulation monitors."
)

@staticmethod
def _validate_metric_mode_index(metric: Metric, simulation: td.Simulation) -> None:
"""Validate that the mode index of the metric is within the bounds of the monitor's ``ModeSpec.num_modes``."""
monitor = next((m for m in simulation.monitors if m.name == metric.monitor_name), None)
if metric.mode_index >= monitor.mode_spec.num_modes:
raise ValidationError(
f"Mode index '{metric.mode_index}' for metric associated with monitor "
f"'{metric.monitor_name}' is out of bounds. "
f"Maximum allowed mode index is '{monitor.mode_spec.num_modes - 1}'."
)

@staticmethod
def _validate_metric_f(metric: Metric, simulation: td.Simulation) -> None:
"""Validate that the frequencies of the metric are present in the monitor."""
monitor = next((m for m in simulation.monitors if m.name == metric.monitor_name), None)
if metric.f is not None:
if len(metric.f) != 1:
raise ValidationError("Only a single frequency is supported for the metric.")
for freq in metric.f:
if not any(np.isclose(freq, monitor.freqs, atol=1.0)):
raise ValidationError(
f"Frequency '{freq}' for metric associated with monitor "
f"'{metric.monitor_name}' not found in monitor frequencies."
)
else:
if len(monitor.freqs) != 1:
raise ValidationError(
f"Monitor '{metric.monitor_name}' must contain only a single frequency when metric.f is None."
)

def is_output_monitor(self, monitor: td.Monitor) -> bool:
"""Whether a monitor is added to the ``JaxSimulation`` as an ``output_monitor``."""

Expand Down
Loading

0 comments on commit 9fd7cf2

Please sign in to comment.