Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

automated DFT decimation for adjoint sources #1753

Merged
merged 9 commits into from
Oct 13, 2021
6 changes: 6 additions & 0 deletions doc/docs/Python_User_Interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -6506,6 +6506,7 @@ def __init__(self,
start_time=-1e+20,
end_time=1e+20,
center_frequency=0,
fwidth=0,
**kwargs):
```

Expand Down Expand Up @@ -6537,6 +6538,11 @@ Construct a `CustomSource`.
+ **`center_frequency` [`number`]** — Optional center frequency so that the
`CustomSource` can be used within an `EigenModeSource`. Defaults to 0.

+ **`fwidth` [`number`]** — Optional bandwidth in frequency units.
Default is 0. For bandwidth-limited sources, this parameter is used to
automatically determine the decimation factor of the time-series updates
of the DFT fields monitors (if any).

</div>

</div>
Expand Down
24 changes: 22 additions & 2 deletions python/adjoint/filter_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def __init__(
self.t = np.arange(0, dt * (self.N), dt)
self.n = np.arange(self.N)
f = self.func()

# frequency bandwidth of the Nuttall window function
fwidth = self.nuttall_bandwidth()

self.bf = [
lambda t, i=i: 0
if t > self.T else (self.nuttall(t, self.center_frequencies) /
Expand All @@ -36,7 +40,8 @@ def __init__(
CustomSource(src_func=bfi,
center_frequency=center_frequency,
is_integrated=False,
end_time=self.T) for bfi in self.bf
end_time=self.T,
fwidth=fwidth) for bfi in self.bf
]

if time_src:
Expand All @@ -58,7 +63,8 @@ def __init__(
super(FilteredSource, self).__init__(src_func=f,
center_frequency=center_frequency,
is_integrated=False,
end_time=self.T)
end_time=self.T,
fwidth=fwidth)

def cos_window_td(self, a, t, f0):
cos_sum = np.sum([(-1)**k * a[k] * np.cos(2 * np.pi * t * k / self.T)
Expand Down Expand Up @@ -103,6 +109,20 @@ def nuttall_dtft(self, f, f0):
a = [0.355768, 0.4873960, 0.144232, 0.012604]
return self.cos_window_fd(a, f, f0)

## compute the bandwidth of the DTFT of the Nuttall window function
## (magnitude) assuming it has decayed from its peak value by some
## tolerance by fitting it to an asymptotic power law of the form
## C / f^3 where C is a constant and f is the frequency
def nuttall_bandwidth(self):
tol = 1e-7
fwidth = 1/(self.N * self.dt)
frq_inf = 10000*fwidth
na_dtft = self.nuttall_dtft(frq_inf, 0)
coeff = frq_inf**3 * np.abs(na_dtft)
na_dtft_max = self.nuttall_dtft(0, 0)
bw = 2 * np.power(coeff / (tol * na_dtft_max), 1/3)
return bw.real

def dtft(self, y, f):
return np.matmul(
np.exp(1j * 2 * np.pi * f[:, np.newaxis] * np.arange(y.size) *
Expand Down
10 changes: 6 additions & 4 deletions python/meep-python.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ namespace meep {
class custom_py_src_time : public src_time {
public:
custom_py_src_time(PyObject *fun, double st = -infinity, double et = infinity,
std::complex<double> f = 0)
: func(fun), freq(f), start_time(float(st)), end_time(float(et)) {
std::complex<double> f = 0, double fw = 0)
: func(fun), freq(f), start_time(float(st)), end_time(float(et)), fwidth(fw) {
SWIG_PYTHON_THREAD_SCOPED_BLOCK;
Py_INCREF(func);
}
Expand Down Expand Up @@ -50,17 +50,19 @@ class custom_py_src_time : public src_time {
const custom_py_src_time *tp = dynamic_cast<const custom_py_src_time *>(&t);
if (tp)
return (tp->start_time == start_time && tp->end_time == end_time && tp->func == func &&
tp->freq == freq);
tp->freq == freq && tp->fwidth == fwidth);
else
return 0;
}
virtual std::complex<double> frequency() const { return freq; }
virtual void set_frequency(std::complex<double> f) { freq = f; }
virtual double get_fwidth() const { return fwidth; };
virtual void set_fwidth(double fw) { fwidth = fw; }

private:
PyObject *func;
std::complex<double> freq;
double start_time, end_time;
double start_time, end_time, fwidth;
};

} // namespace meep
11 changes: 9 additions & 2 deletions python/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class CustomSource(SourceTime):
[`examples/chirped_pulse.py`](https://github.com/NanoComp/meep/blob/master/python/examples/chirped_pulse.py).
"""

def __init__(self, src_func, start_time=-1.0e20, end_time=1.0e20, center_frequency=0, **kwargs):
def __init__(self, src_func, start_time=-1.0e20, end_time=1.0e20, center_frequency=0, fwidth=0, **kwargs):
"""
Construct a `CustomSource`.

Expand All @@ -296,13 +296,20 @@ def __init__(self, src_func, start_time=-1.0e20, end_time=1.0e20, center_frequen

+ **`center_frequency` [`number`]** — Optional center frequency so that the
`CustomSource` can be used within an `EigenModeSource`. Defaults to 0.

+ **`fwidth` [`number`]** — Optional bandwidth in frequency units.
Default is 0. For bandwidth-limited sources, this parameter is used to
automatically determine the decimation factor of the time-series updates
of the DFT fields monitors (if any).
"""
super(CustomSource, self).__init__(**kwargs)
self.src_func = src_func
self.start_time = start_time
self.end_time = end_time
self.fwidth = fwidth
self.center_frequency = center_frequency
self.swigobj = mp.custom_py_src_time(src_func, start_time, end_time, center_frequency)
self.swigobj = mp.custom_py_src_time(src_func, start_time, end_time,
center_frequency, fwidth)
self.swigobj.is_integrated = self.is_integrated


Expand Down
35 changes: 20 additions & 15 deletions python/tests/test_adjoint_cyl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
design_region_resolution = int(2*resolution)
design_r = 4.8
design_z = 2
Nx = int(design_region_resolution*design_r)
Nz = int(design_region_resolution*design_z)
Nr = int(design_region_resolution*design_r) + 1
Nz = int(design_region_resolution*design_z) + 1

fcen = 1/1.55
width = 0.2
Expand All @@ -37,20 +37,20 @@
src = mp.GaussianSource(frequency=fcen,fwidth=fwidth)
source = [mp.Source(src,component=mp.Er,
center=source_center,
size=source_size)]
size=source_size)]

## random design region
p = np.random.rand(Nx*Nz)
p = np.random.rand(Nr*Nz)
## random epsilon perturbation for design region
deps = 1e-5
dp = deps*np.random.rand(Nx*Nz)
dp = deps*np.random.rand(Nr*Nz)


def forward_simulation(design_params):
matgrid = mp.MaterialGrid(mp.Vector3(Nx,0,Nz),
matgrid = mp.MaterialGrid(mp.Vector3(Nr,0,Nz),
SiO2,
Si,
weights=design_params.reshape(Nx,1,Nz))
weights=design_params.reshape(Nr,1,Nz))

geometry = [mp.Block(center=mp.Vector3(0.1+design_r/2,0,0),
size=mp.Vector3(design_r,0,design_z),
Expand All @@ -68,9 +68,8 @@ def forward_simulation(design_params):
far_x = [mp.Vector3(5,0,20)]
mode = sim.add_near2far(
frequencies,
mp.Near2FarRegion(center=mp.Vector3(0.1+design_r/2,0 ,(sz/2-dpml+design_z/2)/2),size=mp.Vector3(design_r,0,0), weight=+1),
decimation_factor=10
)
mp.Near2FarRegion(center=mp.Vector3(0.1+design_r/2,0,(sz/2-dpml+design_z/2)/2),
size=mp.Vector3(design_r,0,0),weight=+1))

sim.run(until_after_sources=1200)
Er = sim.get_farfield(mode, far_x[0])
Expand All @@ -81,9 +80,13 @@ def forward_simulation(design_params):

def adjoint_solver(design_params):

design_variables = mp.MaterialGrid(mp.Vector3(Nx,0,Nz),SiO2,Si)
design_region = mpa.DesignRegion(design_variables,volume=mp.Volume(center=mp.Vector3(0.1+design_r/2,0,0), size=mp.Vector3(design_r, 0,design_z)))
geometry = [mp.Block(center=design_region.center, size=design_region.size, material=design_variables)]
design_variables = mp.MaterialGrid(mp.Vector3(Nr,0,Nz),SiO2,Si)
design_region = mpa.DesignRegion(design_variables,
volume=mp.Volume(center=mp.Vector3(0.1+design_r/2,0,0),
size=mp.Vector3(design_r,0,design_z)))
geometry = [mp.Block(center=design_region.center,
size=design_region.size,
material=design_variables)]

sim = mp.Simulation(cell_size=cell_size,
boundary_layers=boundary_layers,
Expand All @@ -94,8 +97,10 @@ def adjoint_solver(design_params):
m=m)

far_x = [mp.Vector3(5,0,20)]
NearRegions = [mp.Near2FarRegion(center=mp.Vector3(0.1+design_r/2,0 ,(sz/2-dpml+design_z/2)/2),size=mp.Vector3(design_r,0,0), weight=+1)]
FarFields = mpa.Near2FarFields(sim, NearRegions ,far_x, decimation_factor=5)
NearRegions = [mp.Near2FarRegion(center=mp.Vector3(0.1+design_r/2,0,(sz/2-dpml+design_z/2)/2),
size=mp.Vector3(design_r,0,0),
weight=+1)]
FarFields = mpa.Near2FarFields(sim, NearRegions ,far_x)
ob_list = [FarFields]

def J(alpha):
Expand Down
3 changes: 1 addition & 2 deletions python/tests/test_adjoint_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ def build_straight_wg_simulation(
mpa.EigenmodeCoefficient(simulation,
mp.Volume(center=center, size=monitor_size),
mode=1,
forward=forward,
decimation_factor=5)
forward=forward)
for center in monitor_centers for forward in [True, False]
]
return simulation, sources, monitors, design_regions, frequencies
Expand Down
17 changes: 6 additions & 11 deletions python/tests/test_adjoint_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,14 @@ def forward_simulation(design_params,mon_type, frequencies=None, use_complex=Fal
mp.ModeRegion(center=mp.Vector3(0.5*sxy-dpml-0.1),
size=mp.Vector3(0,sxy-2*dpml,0)),
yee_grid=True,
decimation_factor=10,
eig_parity=eig_parity)

elif mon_type.name == 'DFT':
mode = sim.add_dft_fields([mp.Ez],
frequencies,
center=mp.Vector3(1.25),
size=mp.Vector3(0.25,1,0),
yee_grid=False,
decimation_factor=10)
yee_grid=False)

sim.run(until_after_sources=mp.stop_when_dft_decayed())

Expand Down Expand Up @@ -145,7 +143,6 @@ def adjoint_solver(design_params, mon_type, frequencies=None, use_complex=False,
mp.Volume(center=mp.Vector3(0.5*sxy-dpml-0.1),
size=mp.Vector3(0,sxy-2*dpml,0)),
1,
decimation_factor=5,
eig_parity=eig_parity)]

def J(mode_mon):
Expand All @@ -155,8 +152,7 @@ def J(mode_mon):
obj_list = [mpa.FourierFields(sim,
mp.Volume(center=mp.Vector3(1.25),
size=mp.Vector3(0.25,1,0)),
mp.Ez,
decimation_factor=5)]
mp.Ez)]

def J(mode_mon):
return npa.power(npa.abs(mode_mon[:,4,10]),2)
Expand All @@ -166,8 +162,7 @@ def J(mode_mon):
objective_functions=J,
objective_arguments=obj_list,
design_regions=[matgrid_region],
frequencies=frequencies,
decimation_factor=10)
frequencies=frequencies)

f, dJ_du = opt([design_params])

Expand Down Expand Up @@ -213,7 +208,7 @@ def test_adjoint_solver_DFT_fields(self):
adj_scale = (dp[None,:]@adjsol_grad).flatten()
fd_grad = S12_perturbed-S12_unperturbed
print("Directional derivative -- adjoint solver: {}, FD: {}".format(adj_scale,fd_grad))
tol = 0.04 if mp.is_single_precision() else 0.005
tol = 0.0461 if mp.is_single_precision() else 0.005
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the one place where the tolerance needed to be slightly adjusted.

self.assertClose(adj_scale,fd_grad,epsilon=tol)


Expand Down Expand Up @@ -267,14 +262,14 @@ def test_gradient_backpropagation(self):
bp_adjsol_grad = tensor_jacobian_product(mapping,0)(p,filter_radius,eta,beta,adjsol_grad)

## compute unperturbed S12
S12_unperturbed = forward_simulation(mapped_p, MonitorObject.EIGENMODE,frequencies)
S12_unperturbed = forward_simulation(mapped_p,MonitorObject.EIGENMODE,frequencies)

## compare objective results
print("S12 -- adjoint solver: {}, traditional simulation: {}".format(adjsol_obj,S12_unperturbed))
self.assertClose(adjsol_obj,S12_unperturbed,epsilon=1e-6)

## compute perturbed S12
S12_perturbed = forward_simulation(mapping(p+dp,filter_radius,eta,beta), MonitorObject.EIGENMODE,frequencies)
S12_perturbed = forward_simulation(mapping(p+dp,filter_radius,eta,beta),MonitorObject.EIGENMODE,frequencies)

if bp_adjsol_grad.ndim < 2:
bp_adjsol_grad = np.expand_dims(bp_adjsol_grad,axis=1)
Expand Down
7 changes: 3 additions & 4 deletions src/dft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,12 @@ dft_chunk *fields::add_dft(component c, const volume &where, const double *freq,
data.vc = vc;

if (decimation_factor == 0) {
double tol = 1e-7;
double src_freq_max = 0;
for (src_time *s = sources; s; s = s->next) {
if (s->get_fwidth(tol) == 0)
if (s->get_fwidth() == 0)
decimation_factor = 1;
else
src_freq_max = std::max(src_freq_max, std::abs(s->frequency().real())+0.5*s->get_fwidth(tol));
src_freq_max = std::max(src_freq_max, std::abs(s->frequency().real())+0.5*s->get_fwidth());
}
double freq_max = 0;
for (size_t i = 0; i < Nfreq; ++i)
Expand Down Expand Up @@ -1376,4 +1375,4 @@ void fields::get_mode_mode_overlap(void *mode1_data, void *mode2_data, dft_flux
get_overlap(mode1_data, mode2_data, flux, 0, overlaps);
}

} // namespace meep
} // namespace meep
19 changes: 11 additions & 8 deletions src/meep.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,8 @@ class src_time {
return 1;
}
virtual std::complex<double> frequency() const { return 0.0; }
virtual double get_fwidth(double tol) const { (void)tol; return 0.0; }
virtual double get_fwidth() const { return 0.0; }
virtual void set_fwidth(double fw) { (void)fw; }
virtual void set_frequency(std::complex<double> f) { (void)f; }

private:
Expand All @@ -1010,12 +1011,13 @@ class gaussian_src_time : public src_time {
virtual src_time *clone() const { return new gaussian_src_time(*this); }
virtual bool is_equal(const src_time &t) const;
virtual std::complex<double> frequency() const { return freq; }
virtual double get_fwidth(double tol) const;
virtual double get_fwidth() const { return fwidth; };
virtual void set_fwidth(double fw) { fwidth = fw; };
virtual void set_frequency(std::complex<double> f) { freq = real(f); }
std::complex<double> fourier_transform(const double f);

private:
double freq, width, peak_time, cutoff;
double freq, fwidth, width, peak_time, cutoff;
};

// Continuous (CW) source with (optional) slow turn-on and/or turn-off.
Expand All @@ -1031,7 +1033,7 @@ class continuous_src_time : public src_time {
virtual src_time *clone() const { return new continuous_src_time(*this); }
virtual bool is_equal(const src_time &t) const;
virtual std::complex<double> frequency() const { return freq; }
virtual double get_fwidth(double tol) const { (void)tol; return 0.0; };
virtual double get_fwidth() const { return 0.0; };
virtual void set_frequency(std::complex<double> f) { freq = f; }

private:
Expand All @@ -1043,8 +1045,8 @@ class continuous_src_time : public src_time {
class custom_src_time : public src_time {
public:
custom_src_time(std::complex<double> (*func)(double t, void *), void *data, double st = -infinity,
double et = infinity, std::complex<double> f = 0)
: func(func), data(data), freq(f), start_time(float(st)), end_time(float(et)) {}
double et = infinity, std::complex<double> f = 0, double fw = 0)
: func(func), data(data), freq(f), start_time(float(st)), end_time(float(et)), fwidth(fw) {}
virtual ~custom_src_time() {}

virtual std::complex<double> current(double time, double dt) const {
Expand All @@ -1064,14 +1066,15 @@ class custom_src_time : public src_time {
virtual src_time *clone() const { return new custom_src_time(*this); }
virtual bool is_equal(const src_time &t) const;
virtual std::complex<double> frequency() const { return freq; }
virtual double get_fwidth(double tol) const { (void)tol; return 0.0; };
virtual void set_frequency(std::complex<double> f) { freq = f; }
virtual double get_fwidth() const { return fwidth; };
virtual void set_fwidth(double fw) { fwidth = fw; }

private:
std::complex<double> (*func)(double t, void *);
void *data;
std::complex<double> freq;
double start_time, end_time;
double start_time, end_time, fwidth;
};

class monitor_point {
Expand Down
Loading