Skip to content

Commit

Permalink
Adjoint typehints (#2202)
Browse files Browse the repository at this point in the history
* precommit fixes

* add more typehints

* fix precommit errors

* fix last few typos

* fix one last typo

* remove unused imports

* fix other typos

Co-authored-by: Alec Hammond <[email protected]>
  • Loading branch information
smartalecH and Alec Hammond authored Sep 22, 2022
1 parent 0a5aaee commit 7f0469f
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 34 deletions.
42 changes: 32 additions & 10 deletions python/adjoint/objective.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Handling of objective functions and objective quantities."""
import abc
from collections import namedtuple
from typing import Callable, List, Optional

import numpy as np
from meep.simulation import py_v3_to_vec
Expand Down Expand Up @@ -161,15 +162,18 @@ class EigenmodeCoefficient(ObjectiveQuantity):

def __init__(
self,
sim,
volume,
mode,
forward=True,
kpoint_func=None,
kpoint_func_overlap_idx=0,
decimation_factor=0,
sim: mp.Simulation,
volume: mp.Volume,
mode: int,
forward: Optional[bool] = True,
kpoint_func: Optional[Callable] = None,
kpoint_func_overlap_idx: Optional[int] = 0,
decimation_factor: Optional[int] = 0,
**kwargs
):
"""
+ **`sim` [ `Simulation` ]** —
"""
super().__init__(sim)
if kpoint_func_overlap_idx not in [0, 1]:
raise ValueError(
Expand Down Expand Up @@ -268,7 +272,15 @@ def __call__(self):


class FourierFields(ObjectiveQuantity):
def __init__(self, sim, volume, component, yee_grid=False, decimation_factor=0):
def __init__(
self,
sim: mp.Simulation,
volume: mp.Volume,
component: List[int],
yee_grid: Optional[bool] = False,
decimation_factor: Optional[int] = 0,
):
""" """
super().__init__(sim)
self.volume = sim._fit_volume_to_simulation(volume)
self.component = component
Expand Down Expand Up @@ -354,7 +366,14 @@ def __call__(self):


class Near2FarFields(ObjectiveQuantity):
def __init__(self, sim, Near2FarRegions, far_pts, decimation_factor=0):
def __init__(
self,
sim: mp.Simulation,
Near2FarRegions: List[mp.Near2FarRegion],
far_pts: List[mp.Vector3],
decimation_factor: Optional[int] = 0,
):
""" """
super().__init__(sim)
self.Near2FarRegions = Near2FarRegions
self.far_pts = far_pts # list of far pts
Expand Down Expand Up @@ -420,7 +439,10 @@ def __call__(self):


class LDOS(ObjectiveQuantity):
def __init__(self, sim, decimation_factor=0, **kwargs):
def __init__(
self, sim: mp.Simulation, decimation_factor: Optional[int] = 0, **kwargs
):
""" """
super().__init__(sim)
self.decimation_factor = decimation_factor
self.srckwarg = kwargs
Expand Down
62 changes: 38 additions & 24 deletions python/adjoint/optimization_problem.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from collections import namedtuple
from typing import Callable, List, Union, Optional, Tuple

import numpy as np
from autograd import grad, jacobian

import meep as mp

from . import LDOS, DesignRegion, utils
from . import LDOS, DesignRegion, utils, ObjectiveQuantity


class OptimizationProblem:
Expand All @@ -24,21 +25,28 @@ class OptimizationProblem:

def __init__(
self,
simulation,
objective_functions,
objective_arguments,
design_regions,
frequencies=None,
fcen=None,
df=None,
nf=None,
decay_by=1e-11,
decimation_factor=0,
minimum_run_time=0,
maximum_run_time=None,
finite_difference_step=utils.FD_DEFAULT,
simulation: mp.Simulation,
objective_functions: List[Callable],
objective_arguments: List[ObjectiveQuantity],
design_regions: List[DesignRegion],
frequencies: Optional[Union[float, List[float]]] = None,
fcen: Optional[float] = None,
df: Optional[float] = None,
nf: Optional[int] = None,
decay_by: Optional[float] = 1e-11,
decimation_factor: Optional[int] = 0,
minimum_run_time: Optional[float] = 0,
maximum_run_time: Optional[float] = None,
finite_difference_step: Optional[float] = utils.FD_DEFAULT,
step_funcs: list = None,
):
"""
+ **`simulation` [ `Simulation` ]** — The corresponding Meep
`Simulation` object that describes the problem (e.g. sources,
geometry)
+ **`objective_functions` [ `list of ` ]** —
"""
self.step_funcs = step_funcs if step_funcs is not None else []
self.sim = simulation

Expand Down Expand Up @@ -108,7 +116,13 @@ def __init__(

self.gradient = []

def __call__(self, rho_vector=None, need_value=True, need_gradient=True, beta=None):
def __call__(
self,
rho_vector: List[List[float]] = None,
need_value: bool = True,
need_gradient: bool = True,
beta: float = None,
) -> Tuple[List[float], List[float]]:
"""Evaluate value and/or gradient of objective function."""
if rho_vector:
self.update_design(rho_vector=rho_vector, beta=beta)
Expand Down Expand Up @@ -140,7 +154,7 @@ def __call__(self, rho_vector=None, need_value=True, need_gradient=True, beta=No

return self.f0, self.gradient

def get_fdf_funcs(self):
def get_fdf_funcs(self) -> Tuple[Callable, Callable]:
"""construct callable functions for objective function value and gradient
Returns
Expand Down Expand Up @@ -306,11 +320,11 @@ def calculate_gradient(self):

def calculate_fd_gradient(
self,
num_gradients=1,
db=1e-4,
design_variables_idx=0,
filter=None,
):
num_gradients: int = 1,
db: float = 1e-4,
design_variables_idx: int = 0,
filter: Callable = None,
) -> List[float]:
"""
Estimate central difference gradients.
Expand Down Expand Up @@ -448,7 +462,7 @@ def calculate_fd_gradient(

return fd_gradient, fd_gradient_idx

def update_design(self, rho_vector, beta=None):
def update_design(self, rho_vector: List[float], beta: float = None) -> None:
"""Update the design permittivity function.
rho_vector ....... a list of numpy arrays that maps to each design region
Expand All @@ -465,11 +479,11 @@ def update_design(self, rho_vector, beta=None):
self.sim.reset_meep()
self.current_state = "INIT"

def get_objective_arguments(self):
def get_objective_arguments(self) -> List[float]:
"""Return list of evaluated objective arguments."""
return [m.get_evaluation() for m in self.objective_arguments]

def plot2D(self, init_opt=False, **kwargs):
def plot2D(self, init_opt=False, **kwargs) -> None:
"""Produce a graphical visualization of the geometry and/or fields,
as appropriately autodetermined based on the current state of
progress.
Expand Down

0 comments on commit 7f0469f

Please sign in to comment.