Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
mawc2019 committed Jun 17, 2022
1 parent bc70d31 commit a816bc9
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions python/tests/test_adjoint_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def forward_simulation_complex_fields(design_params, frequencies=None):
material=matgrid)]

sim = mp.Simulation(resolution=resolution,
cell_size=cell_size,
cell_size=cell_size,default_material=silicon,
k_point=k_point,
boundary_layers=pml_x,
sources=pt_source,
Expand All @@ -213,23 +213,23 @@ def forward_simulation_complex_fields(design_params, frequencies=None):
if not frequencies:
frequencies = [fcen]

mode = sim.add_dft_fields([mp.Ez],
mode = sim.add_dft_fields([mp.Dz],
frequencies,
center=mp.Vector3(0.9),
size=mp.Vector3(0.2,0.5),
yee_grid=False)

sim.run(until_after_sources=mp.stop_when_dft_decayed())

Ez2 = []
Dz2 = []
for f in range(len(frequencies)):
Ez_dft = sim.get_dft_array(mode, mp.Ez, f)
Ez2.append(np.power(np.abs(Ez_dft[3,9]),2))
Ez2 = np.array(Ez2)
Dz_dft = sim.get_dft_array(mode, mp.Dz, f)
Dz2.append(np.power(np.abs(Dz_dft[3,9]),2))
Dz2 = np.array(Dz2)

sim.reset_meep()

return Ez2
return Dz2


def adjoint_solver_complex_fields(design_params, frequencies=None):
Expand All @@ -249,7 +249,7 @@ def adjoint_solver_complex_fields(design_params, frequencies=None):
material=matgrid)]

sim = mp.Simulation(resolution=resolution,
cell_size=cell_size,
cell_size=cell_size,default_material=silicon,
k_point=k_point,
boundary_layers=pml_x,
sources=pt_source,
Expand All @@ -261,7 +261,7 @@ def adjoint_solver_complex_fields(design_params, frequencies=None):
obj_list = [mpa.FourierFields(sim,
mp.Volume(center=mp.Vector3(0.9),
size=mp.Vector3(0.2,0.5)),
mp.Ez)]
mp.Dz)]

def J(dft_mon):
return npa.power(npa.abs(dft_mon[:,3,9]),2)
Expand Down Expand Up @@ -505,21 +505,21 @@ def test_complex_fields(self):
## compute gradient using adjoint solver
adjsol_obj, adjsol_grad = adjoint_solver_complex_fields(p, frequencies)

## compute unperturbed |Ez|^2
Ez2_unperturbed = forward_simulation_complex_fields(p, frequencies)
## compute unperturbed |Dz|^2
Dz2_unperturbed = forward_simulation_complex_fields(p, frequencies)

## compare objective results
print("Ez2 -- adjoint solver: {}, traditional simulation: {}".format(adjsol_obj,Ez2_unperturbed))
self.assertClose(adjsol_obj,Ez2_unperturbed,epsilon=1e-6)
print("Dz2 -- adjoint solver: {}, traditional simulation: {}".format(adjsol_obj,Dz2_unperturbed))
self.assertClose(adjsol_obj,Dz2_unperturbed,epsilon=1e-6)

## compute perturbed |Ez|^2
Ez2_perturbed = forward_simulation_complex_fields(p+dp, frequencies)
## compute perturbed |Dz|^2
Dz2_perturbed = forward_simulation_complex_fields(p+dp, frequencies)

## compare gradients
if adjsol_grad.ndim < 2:
adjsol_grad = np.expand_dims(adjsol_grad,axis=1)
adj_scale = (dp[None,:]@adjsol_grad).flatten()
fd_grad = Ez2_perturbed-Ez2_unperturbed
fd_grad = Dz2_perturbed-Dz2_unperturbed
print("Directional derivative -- adjoint solver: {}, FD: {}".format(adj_scale,fd_grad))
tol = 0.018 if mp.is_single_precision() else 0.002
self.assertClose(adj_scale,fd_grad,epsilon=tol)
Expand Down

0 comments on commit a816bc9

Please sign in to comment.