Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-materialized dft fields for adjoint calculations #1855

Merged
merged 29 commits into from
Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bfd8633
init
smartalecH Dec 13, 2021
96f30c7
checkpoint
smartalecH Dec 28, 2021
f507d56
work with no smoothing, small errors
smartalecH Jan 3, 2022
39672b3
no smoothing, all diag now works completely
smartalecH Jan 4, 2022
5d914f9
everything is in place, just need to debug the subvolume routine
smartalecH Jan 5, 2022
07d4ff2
rebase
smartalecH Jan 5, 2022
b67db7e
consistent with complex types
smartalecH Jan 5, 2022
6c7a1e4
center issue
smartalecH Jan 5, 2022
b00279a
Merge branch 'master' of https://github.com/NanoComp/meep into dist_dft
smartalecH Jan 5, 2022
33f338d
minor fixes
smartalecH Jan 6, 2022
9308a4f
pass tests locally
smartalecH Jan 17, 2022
3b61a98
restore geom.py averaging default and rebase
smartalecH Jan 17, 2022
8946ac4
get single proc working for broadband by fixing checkpointing
smartalecH Jan 19, 2022
546b0d8
update visualization to work with new pyplot
smartalecH Feb 10, 2022
85387af
rebase
smartalecH Mar 17, 2022
9c732e6
Merge branch 'dist_dft' of https://github.com/smartalecH/meep into di…
smartalecH Mar 21, 2022
d4f6abb
Merge branch 'master' of https://github.com/NanoComp/meep into dist_dft
smartalecH Mar 21, 2022
da5fc4f
make final fixes, like formatting and remove expand
smartalecH Mar 21, 2022
8bc8bc6
change dft array to vector and fix bounds errors
smartalecH Apr 6, 2022
933a10b
remove vector dft fields
smartalecH Apr 6, 2022
e94d855
finish reverting
smartalecH Apr 7, 2022
d194012
raise tolerance and refactor using macros
smartalecH Apr 7, 2022
039a2cb
whoops
smartalecH Apr 7, 2022
d4bd08f
Merge branch 'master' of https://github.com/NanoComp/meep into dist_dft
smartalecH Apr 9, 2022
2d5d807
increase tolerance
smartalecH Apr 9, 2022
9aa8296
Merge branch 'dist_dft' of https://github.com/smartalecH/meep into di…
smartalecH Apr 9, 2022
5ee390e
add filtering to adjoint test
smartalecH Apr 12, 2022
12ff030
minor refactoring and test source change
smartalecH Apr 12, 2022
2671aaa
relax cylindrical adjoint test
smartalecH Apr 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/adjoint/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def place_adjoint_source(self, dJ):
*(np.eye(3)[self._monitor.normal_direction] *
np.abs(center_frequency)))
eig_kpoint = -1 * direction if self.forward else direction

if self._frequencies.size == 1:
amp = da_dE * dJ * scale
src = time_src
Expand Down
29 changes: 12 additions & 17 deletions python/adjoint/optimization_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,11 @@ def __call__(self, rho_vector=None, need_value=True, need_gradient=True, beta=No
print("Starting forward run...")
self.forward_run()
print("Starting adjoint run...")
self.D_a = []
self.adjoint_run()
print("Calculating gradient...")
self.calculate_gradient()
elif self.current_state == "FWD":
print("Starting adjoint run...")
self.D_a = []
self.adjoint_run()
print("Calculating gradient...")
self.calculate_gradient()
Expand Down Expand Up @@ -168,7 +166,7 @@ def prepare_forward_run(self):
self.forward_monitors.append(m.register_monitors(self.frequencies))

# register design region
self.design_region_monitors = utils.install_design_region_monitors(
self.forward_design_region_monitors = utils.install_design_region_monitors(
self.sim, self.design_regions, self.frequencies, self.decimation_factor
)

Expand All @@ -192,10 +190,7 @@ def forward_run(self):
self.f0 = [fi(*self.results_list) for fi in self.objective_functions]
if len(self.f0) == 1:
self.f0 = self.f0[0]

# Store forward fields for each set of design variables in array
self.D_f = utils.gather_design_region_fields(self.sim,self.design_region_monitors,self.frequencies)


# store objective function evaluation in memory
self.f_bank.append(self.f0)

Expand Down Expand Up @@ -225,19 +220,22 @@ def adjoint_run(self):

# flip the k point
if self.sim.k_point:
self.sim.k_point *= -1
self.sim.change_k_point(-1*self.sim.k_point)

self.adjoint_design_region_monitors = []
for ar in range(len(self.objective_functions)):
# Reset the fields
self.sim.reset_meep()
self.sim.restart_fields()
self.sim.clear_dft_monitors()

# Update the sources
self.sim.change_sources(self.adjoint_sources[ar])

# register design flux
self.design_region_monitors = utils.install_design_region_monitors(
# register design dft fields
self.adjoint_design_region_monitors.append(utils.install_design_region_monitors(
self.sim, self.design_regions, self.frequencies, self.decimation_factor
)
))
self.sim._evaluate_dft_objects()

# Adjoint run
self.sim.run(until_after_sources=mp.stop_when_dft_decayed(
Expand All @@ -246,9 +244,6 @@ def adjoint_run(self):
self.maximum_run_time
))

# Store adjoint fields for each design set of design variables
self.D_a.append(utils.gather_design_region_fields(self.sim,self.design_region_monitors,self.frequencies))

# reset the m number
if utils._check_if_cylindrical(self.sim):
self.sim.m = -self.sim.m
Expand All @@ -264,8 +259,8 @@ def calculate_gradient(self):
self.gradient = [[
dr.get_gradient(
self.sim,
self.D_a[ar][dri],
self.D_f[dri],
self.adjoint_design_region_monitors[ar][dri],
self.forward_design_region_monitors[dri],
self.frequencies,
self.finite_difference_step
) for dri, dr in enumerate(self.design_regions)
Expand Down
57 changes: 12 additions & 45 deletions python/adjoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,53 +43,19 @@ def update_beta(self,beta):

def get_gradient(self, sim, fields_a, fields_f, frequencies, finite_difference_step):
num_freqs = onp.array(frequencies).size
shapes = []
'''We have the option to linearly scale the gradients up front
using the scalegrad parameter (leftover from MPB API). Not
currently needed for any existing feature (but available for
future use)'''
scalegrad = 1
for component_idx, component in enumerate(_compute_components(sim)):
'''We need to correct for the rare cases that get_dft_array
returns a singleton element for the forward and adjoint fields.
This only occurs when we are in 2D and only working with a particular
polarization (as the other fields are never stored). For example, the
2D in-plane polarization consists of a single scalar Ez field
(technically, meep doesn't store anything for these cases, but get_dft_array
still returns 0).

Our get_gradient algorithm, however, requires we pass an array of
zeros with the proper shape as the design_region.'''
spatial_shape = sim.get_array_slice_dimensions(component, vol=self.volume)[0]
if (fields_a[component_idx][0,...].size == 1):
fields_a[component_idx] = onp.zeros(onp.insert(spatial_shape,0,num_freqs),
dtype=onp.complex64 if mp.is_single_precision() else onp.complex128)
fields_f[component_idx] = onp.zeros(onp.insert(spatial_shape,0,num_freqs),
dtype=onp.complex64 if mp.is_single_precision() else onp.complex128)
if _check_if_cylindrical(sim):
'''For some reason, get_dft_array returns the field
components in a different order than the convention used
throughout meep. So, we reorder them here so we can use
the same field macros later in our get_gradient function.
'''
fields_a[component_idx] = onp.transpose(fields_a[component_idx],(_FREQ_AXIS,_RHO_AXIS,_PHI_AXIS,_Z_AXIS))
fields_f[component_idx] = onp.transpose(fields_f[component_idx],(_FREQ_AXIS,_RHO_AXIS,_PHI_AXIS,_Z_AXIS))
shapes.append(fields_a[component_idx].shape)
fields_a[component_idx] = fields_a[component_idx].flatten(order='C')
fields_f[component_idx] = fields_f[component_idx].flatten(order='C')
shapes = onp.asarray(shapes).flatten(order='C')
fields_a = onp.concatenate(fields_a)
fields_f = onp.concatenate(fields_f)

grad = onp.zeros((num_freqs, self.num_design_params)) # preallocate
geom_list = sim.geometry
f = sim.fields
vol = sim._fit_volume_to_simulation(self.volume)

# compute the gradient
mp._get_gradient(grad,scalegrad,fields_a,fields_f,
sim.gv,vol.swigobj,onp.array(frequencies),
sim.geps,shapes,finite_difference_step)
mp._get_gradient(grad,scalegrad,
fields_a[0].swigobj,fields_a[1].swigobj,fields_a[2].swigobj,
fields_f[0].swigobj,fields_f[1].swigobj,fields_f[2].swigobj,
sim.gv,vol.swigobj,onp.array(frequencies),
sim.geps,finite_difference_step)
return onp.squeeze(grad).T

def _check_if_cylindrical(sim):
Expand Down Expand Up @@ -149,18 +115,19 @@ def install_design_region_monitors(
simulation: mp.Simulation,
design_regions: List[DesignRegion],
frequencies: List[float],
decimation_factor: int = 0
decimation_factor: int = 0,
) -> List[mp.DftFields]:
"""Installs DFT field monitors at the design regions of the simulation."""
design_region_monitors = [
design_region_monitors = [[
simulation.add_dft_fields(
_compute_components(simulation),
[comp],
frequencies,
where=design_region.volume,
yee_grid=True,
decimation_factor=decimation_factor
) for design_region in design_regions
]
decimation_factor=decimation_factor,
persist=True
) for comp in _compute_components(simulation)
] for design_region in design_regions ]
return design_region_monitors


Expand Down
36 changes: 13 additions & 23 deletions python/adjoint/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _run_fwd_simulation(self, design_variables):
self.simulation.reset_meep()
self.simulation.change_sources(self.sources)
utils.register_monitors(self.monitors, self.frequencies)
design_region_monitors = utils.install_design_region_monitors(
fwd_design_region_monitors = utils.install_design_region_monitors(
self.simulation,
self.design_regions,
self.frequencies,
Expand All @@ -156,13 +156,8 @@ def _run_fwd_simulation(self, design_variables):
self.simulation.run(**sim_run_args)

monitor_values = utils.gather_monitor_values(self.monitors)
fwd_fields = utils.gather_design_region_fields(
self.simulation,
design_region_monitors,
self.frequencies,
)
return (jnp.asarray(monitor_values),
jax.tree_map(jnp.asarray, fwd_fields))
fwd_design_region_monitors)
smartalecH marked this conversation as resolved.
Show resolved Hide resolved

def _run_adjoint_simulation(self, monitor_values_grad):
"""Runs adjoint simulation, returning design region fields."""
Expand All @@ -172,9 +167,12 @@ def _run_adjoint_simulation(self, monitor_values_grad):
'regions are present.')
adjoint_sources = utils.create_adjoint_sources(self.monitors,
monitor_values_grad)
self.simulation.reset_meep()
# TODO refactor with optimization_problem.py #
self.simulation.restart_fields()
self.simulation.clear_dft_monitors()
self.simulation.change_sources(adjoint_sources)
design_region_monitors = utils.install_design_region_monitors(
# #
smartalecH marked this conversation as resolved.
Show resolved Hide resolved
adj_design_region_monitors = utils.install_design_region_monitors(
self.simulation,
self.design_regions,
self.frequencies,
Expand All @@ -186,11 +184,7 @@ def _run_adjoint_simulation(self, monitor_values_grad):
}
self.simulation.run(**sim_run_args)

return utils.gather_design_region_fields(
self.simulation,
design_region_monitors,
self.frequencies,
)
return adj_design_region_monitors

def _calculate_vjps(
self,
Expand Down Expand Up @@ -221,20 +215,16 @@ def simulate(design_variables: List[jnp.ndarray]) -> jnp.ndarray:

def _simulate_fwd(design_variables):
"""Runs forward simulation, returning monitor values and fields."""
monitor_values, fwd_fields = self._run_fwd_simulation(
monitor_values, self.fwd_design_region_monitors = self._run_fwd_simulation(
design_variables)
design_variable_shapes = [x.shape for x in design_variables]
return monitor_values, (fwd_fields, design_variable_shapes)
return monitor_values, (design_variable_shapes)

def _simulate_rev(res, monitor_values_grad):
"""Runs adjoint simulation, returning VJP of design wrt monitor values."""
fwd_fields = jax.tree_map(
lambda x: onp.asarray(x,
dtype=onp.complex64 if mp.is_single_precision() else onp.complex128),
res[0])
design_variable_shapes = res[1]
adj_fields = self._run_adjoint_simulation(monitor_values_grad)
vjps = self._calculate_vjps(fwd_fields, adj_fields,
design_variable_shapes = res
self.adj_design_region_monitors = self._run_adjoint_simulation(monitor_values_grad)
vjps = self._calculate_vjps(self.fwd_design_region_monitors, self.adj_design_region_monitors,
design_variable_shapes)
return ([jnp.asarray(vjp) for vjp in vjps], )

Expand Down
2 changes: 1 addition & 1 deletion python/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def __init__(self,
medium2,
weights=None,
grid_type="U_DEFAULT",
do_averaging=False,
do_averaging=True,
smartalecH marked this conversation as resolved.
Show resolved Hide resolved
beta=0,
eta=0.5,
damping=0):
Expand Down
34 changes: 12 additions & 22 deletions python/meep.i
Original file line number Diff line number Diff line change
Expand Up @@ -843,9 +843,12 @@ meep::volume_list *make_volume_list(const meep::volume &v, int c,
//--------------------------------------------------

%inline %{
void _get_gradient(PyObject *grad, double scalegrad, PyObject *fields_a, PyObject *fields_f,
void _get_gradient(PyObject *grad, double scalegrad,
meep::dft_fields *fields_a_0, meep::dft_fields *fields_a_1, meep::dft_fields *fields_a_2,
meep::dft_fields *fields_f_0, meep::dft_fields *fields_f_1, meep::dft_fields *fields_f_2,
meep::grid_volume *grid_volume, meep::volume *where, PyObject *frequencies,
meep_geom::geom_epsilon *geps, PyObject *fields_shapes, double fd_step) {
meep_geom::geom_epsilon *geps, double fd_step) {

// clean the gradient array
PyArrayObject *pao_grad = (PyArrayObject *)grad;
if (!PyArray_Check(pao_grad)) meep::abort("grad parameter must be numpy array.");
Expand All @@ -854,25 +857,11 @@ void _get_gradient(PyObject *grad, double scalegrad, PyObject *fields_a, PyObjec
double *grad_c = (double *)PyArray_DATA(pao_grad);
npy_intp ng = PyArray_DIMS(pao_grad)[1]; // number of design parameters

// clean the adjoint fields array
PyArrayObject *pao_fields_a = (PyArrayObject *)fields_a;
if (!PyArray_Check(pao_fields_a)) meep::abort("adjoint fields parameter must be numpy array.");
if (!PyArray_ISCARRAY(pao_fields_a)) meep::abort("Numpy adjoint fields array must be C-style contiguous.");
if (PyArray_NDIM(pao_fields_a) !=1) {meep::abort("Numpy adjoint fields array must have 1 dimension.");}
std::complex<meep::realnum> *fields_a_c = (std::complex<meep::realnum> *)PyArray_DATA(pao_fields_a);

// clean the forward fields array
PyArrayObject *pao_fields_f = (PyArrayObject *)fields_f;
if (!PyArray_Check(pao_fields_f)) meep::abort("forward fields parameter must be numpy array.");
if (!PyArray_ISCARRAY(pao_fields_f)) meep::abort("Numpy forward fields array must be C-style contiguous.");
if (PyArray_NDIM(pao_fields_f) !=1) {meep::abort("Numpy forward fields array must have 1 dimension.");}
std::complex<meep::realnum> *fields_f_c = (std::complex<meep::realnum> *)PyArray_DATA(pao_fields_f);

// clean shapes array
PyArrayObject *pao_fields_shapes = (PyArrayObject *)fields_shapes;
if (!PyArray_Check(pao_fields_shapes)) meep::abort("fields shape parameter must be numpy array.");
if (!PyArray_ISCARRAY(pao_fields_shapes)) meep::abort("Numpy fields shape array must be C-style contiguous.");
size_t *fields_shapes_c = (size_t *)PyArray_DATA(pao_fields_shapes);
// clean the adjoint fields object
std::vector<meep::dft_fields *> adjoint_fields = {fields_a_0,fields_a_1,fields_a_2};

// clean the forward fields object
std::vector<meep::dft_fields *> forward_fields = {fields_f_0,fields_f_1,fields_f_2};

// clean the frequencies array
PyArrayObject *pao_freqs = (PyArrayObject *)frequencies;
Expand All @@ -883,7 +872,8 @@ void _get_gradient(PyObject *grad, double scalegrad, PyObject *fields_a, PyObjec
if (PyArray_DIMS(pao_grad)[0] != nf) meep::abort("Numpy grad array is allocated for %td frequencies; it should be allocated for %td.",PyArray_DIMS(pao_grad)[0],nf);

// calculate the gradient
meep_geom::material_grids_addgradient(grad_c,ng,fields_a_c,fields_f_c,fields_shapes_c,frequencies_c,scalegrad,*grid_volume,*where,geps,fd_step);
meep_geom::material_grids_addgradient(grad_c,ng,nf,adjoint_fields,forward_fields,frequencies_c,scalegrad,*grid_volume,*where,geps,fd_step);

}
%}

Expand Down
21 changes: 17 additions & 4 deletions python/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2543,9 +2543,10 @@ def _evaluate_dft_objects(self):
if dft.swigobj is None:
dft.swigobj = dft.func(*dft.args)


def add_dft_fields(self, *args, **kwargs):
"""
`add_dft_fields(cs, fcen, df, nfreq, freq, where=None, center=None, size=None, yee_grid=False, decimation_factor=0)` ##sig
`add_dft_fields(cs, fcen, df, nfreq, freq, where=None, center=None, size=None, yee_grid=False, decimation_factor=0, persist=False)` ##sig

Given a list of field components `cs`, compute the Fourier transform of these
fields for `nfreq` equally spaced frequencies covering the frequency range
Expand All @@ -2572,26 +2573,27 @@ def add_dft_fields(self, *args, **kwargs):
size = kwargs.get('size', None)
yee_grid = kwargs.get('yee_grid', False)
decimation_factor = kwargs.get('decimation_factor', 0)
persist = kwargs.get('persist', False)
center_v3 = Vector3(*center) if center is not None else None
size_v3 = Vector3(*size) if size is not None else None
use_centered_grid = not yee_grid
dftf = DftFields(self._add_dft_fields, [
components, where, center_v3, size_v3, freq, use_centered_grid,
decimation_factor
decimation_factor,persist
])
self.dft_objects.append(dftf)
return dftf

def _add_dft_fields(self, components, where, center, size, freq,
use_centered_grid, decimation_factor):
use_centered_grid, decimation_factor, persist):
if self.fields is None:
self.init_sim()
try:
where = self._volume_from_kwargs(where, center, size)
except ValueError:
where = self.fields.total_volume()
return self.fields.add_dft_fields(components, where, freq,
use_centered_grid, decimation_factor)
use_centered_grid, decimation_factor, persist)

def output_dft(self, dft_fields, fname):
"""
Expand Down Expand Up @@ -3828,6 +3830,17 @@ def restart_fields(self):
else:
self._is_initialized = False
self.init_sim()

def clear_dft_monitors(self):
"""
Remove all of the dft monitors from the simulation.
"""
for m in self.dft_objects:
if not (isinstance(m,DftFields) and (m.chunks) and (m.chunks.persist)):
m.remove()
self.fields.clear_dft_monitors()

self.dft_objects = []

def run(self, *step_funcs, **kwargs):
"""
Expand Down
Loading