-
Notifications
You must be signed in to change notification settings - Fork 634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: revamped adjoint filters #1625
base: master
Are you sure you want to change the base?
Conversation
I have been trying to use the new code. I don't know if I missed something, but currently conic_filter(x,radius) gives me nan. Specifically, meep/python/adjoint/filters.py Lines 236 to 239 in 8903b50
Thus, |
No, the The input array |
Thanks! So for |
No, that function remains unchanged. So the user must manually multiply the The idea behind the new filtering API is to keep things as simple as possible. |
Thanks! |
@@ -3,14 +3,15 @@ | |||
""" | |||
|
|||
import numpy as np | |||
from autograd import numpy as npa | |||
import jax | |||
from jax import numpy as npj |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just FYI- convention is import jax.numpy as jnp
rather than npj
: https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
Can we avoid that? As much as possible, I'd like distances to be specified in "real" units. I guess the problem with filters is that the same design variables (material grid) can be applied to multiple objects, and so effectively have different resolutions in the same simulation. |
I agree, but it comes at the cost of a more complicated API with lots more parameters to pass (and generally less flexibility).
Exactly. I think it's better for the user to handle things on their side so that any required bookkeeping doesn't restrict any applications. |
axis=1) | ||
|
||
return out | ||
def atleast_3d(ary): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there are already two other versions of a similar function for this (one that I added in utils and one that was already in another module). It might be worth consolidating these into a single function?
In some places it will be used for standard numpy arrays and here you're using it for jax arrays. I assume you'll want to track gradients through this, so I think to support both numpy and jax arrays, the var.reshape(...)
approach could be used. This could work with a slight modification to:
Lines 15 to 17 in 4d27231
def _make_at_least_nd(x: onp.ndarray, dims: int = 3) -> onp.ndarray: | |
"""Makes an array have at least the specified number of dimensions.""" | |
return onp.reshape(x, x.shape + onp.maximum(dims - x.ndim, 0) * (1, )) |
It seems that the Near2FarFields now fails. Specifically, the adjoint source amplitudes ( meep/python/adjoint/objective.py Lines 333 to 338 in 74c9752
However, in optimization_problem.py , if I import jacobian from autograd instead of jax , it would work as usual. The values of the jacobian dJ are the same, and I tried to convert the jax jacobian to a numpy array with jnp.asarray , but it doesn't make a difference.
|
Looks like this is a result of earlier PRs that transitioned the solver to jax. (Note the only actual additions offered by this PR are in What I typically do is use autograd for the main adjoint solver, but use jax for the filter functions used here. It's not ideal, but easier than fiddling with jax. Although more recently I had to abandon jax altogether to get it to work on multi-node systems. In an ideal world, we would have a package consisting of jax's autodiff function library, without losing autograd's simplicity. |
Thanks Alec! |
Any update on this? |
I bet this issue is due to the off-by-one bug we've been fixing in various PRs (#1769, #1760). Also, all our filters are separable, which means we can actually just perform a simple 1D filter with the corresponding 1D kernel in each dimension. This not only speeds things up significantly, but should clean up the code quite a bit too (we can keep the full |
Updates all the adjoint filters/projections to:
jax
rather thanautograd
.TODO
@mochen4 to pull into an existing dev branch, make sure you rebase on top of
master
, as described here.Note: it looks that while rebasing, not all of the formatting changes were correctly committed (hence some of the extraneous diffs). I'll clean those up too.