Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support selection by nearest value in JaxDataArray #1671

Merged
merged 1 commit into from
May 3, 2024

Conversation

yaugenst-flex
Copy link
Contributor

Closes #1617

Also took the liberty to rework the mechanics of the old selection by vectorizing everything using numpy.
Note that there is a slight difference between xarray's exact selection and ours now, because I base selections off of np.isclose instead of doing exact float comparisions (x == y). I think this is more intuitive, as otherwise something like x == 0.2 fails if x is a numpy float (0.20000000000000018).
I can change this back to match xarray's behavior exactly if that's what we want to do.

@yaugenst-flex yaugenst-flex marked this pull request as draft May 3, 2024 10:43
@yaugenst-flex yaugenst-flex removed the request for review from tylerflex May 3, 2024 10:44
Copy link
Collaborator

@tylerflex tylerflex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just some preliminary comments, looking good though

@@ -312,28 +313,37 @@ def isel(self, **isel_kwargs) -> JaxDataArray:

return self_sel

def sel(self, indexers: dict = None, method: str = "nearest", **sel_kwargs) -> JaxDataArray:
def sel(self, indexers: dict = None, method: str | None = None, **sel_kwargs) -> JaxDataArray:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this | fail in earlier versions of python? we actually recently stopped supporting 3.8 so maybe it's ok now

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, maybe we want it to be something like method: typing.Literal[None, "nearest"]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this | fail in earlier versions of python? we actually recently stopped supporting 3.8 so maybe it's ok now

Yes absolutely! Just checked, it was introduced in 3.10, so I'll have to change that. I'd really like to use this & match statements though...

Also, maybe we want it to be something like method: typing.Literal[None, "nearest"]?

Yes I was thinking about the Literal too, I'll add it. Not sure why I didn't go with it in the end. Probably it seemed redundant given the if/else check.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my main concern was the docs, but actually I realized methods get documented using their docstrings (not the type annotations). We probably want to just add a proper Parameters and Returns format in the docstring, mentioning method. Could you add that while you're at it?

@@ -312,28 +313,37 @@ def isel(self, **isel_kwargs) -> JaxDataArray:

return self_sel

def sel(self, indexers: dict = None, method: str = "nearest", **sel_kwargs) -> JaxDataArray:
def sel(self, indexers: dict = None, method: str | None = None, **sel_kwargs) -> JaxDataArray:
"""Select a value from the :class:`.JaxDataArray` by indexing into coordinate values."""
Copy link
Collaborator

@tylerflex tylerflex May 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a short mention of method option?

EDIT: actually let's just add the full Parameters and Returns here. I should have done that earlier..

else:
raise NotImplementedError(f"Unkown selection method: {method}.")

def _sel_exact(self, sel_kwargs: dict) -> JaxDataArray:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a short docstring (and to _sel_nearest)

coord_list = np.asarray(self.get_coord_list(coord_name))
vals = np.atleast_1d(vals)
dist = np.abs(coord_list[None] - vals[:, None])
indices = np.where(np.isclose(dist, 0))[1]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just so I understand, the only difference between _sel_exact and _sel_nearest seems to be how the indices are computed given the coord_list and vals

# exact
vals = np.atleast_1d(vals)
dist = np.abs(coord_list[None] - vals[:, None])
indices = np.where(np.isclose(dist, 0))[1]
if indices.size == 0:
    raise DataError(f"Could not select '{coord_name}={vals}', some values were not found.")
# nearest
vals = np.asarray(vals)
dist = np.abs(coord_list[:, None] - vals[None])
indices = np.argmin(dist, axis=0)

Is my understanding correct?

Furthermore, it could probably be written to just compute indices given some dist array.

Would it make sense then to refactor a bit to either just stick these in the same method (with an if / else) or just localize the _get_indices(dist) -> incides logic?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes the calculation of the distance matrix should probably be refactored into its own method, you're right.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, after looking at it (see comments below) probably we can just keep the same distance method calculation, but we just need to refactor so that we compute the indices differently given the distance and some method?

for coord_name, vals in sel_kwargs.items():
coord_list = np.asarray(self.get_coord_list(coord_name))
vals = np.asarray(vals)
dist = np.abs(coord_list[:, None] - vals[None])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain the indexing here? so it seems to me that we add a dimension to vals so it would be shaped (N,) -> (1, N) and then coord_list would be (M,) -> (M, 1). And then we have a distance (M, N). We take argmin with respect to the coord_list axis, giving (N,) values corresponding to the smallest distance between each value to all of the coords?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes exactly, which is then an index array of shape (M,) that can be used to index into coord_list. I came across a bug though, in that this does not handle non-numeric types such as coords of the form ["+", "-"]. Have to come up with a fix for that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah.. maybe in that case we should just revert to the exact selection? not sure how xarray handles this.

for coord_name, vals in sel_kwargs.items():
coord_list = np.asarray(self.get_coord_list(coord_name))
vals = np.atleast_1d(vals)
dist = np.abs(coord_list[None] - vals[:, None])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment below but trying to understand the idea here, so coord_list[None] gives an array shaped (1, M) and then vals[:, None] gives (1, N), we take the difference and get an array shaped (N, M).

The next line, by applying np.where(...)[1], we select out all of the N indices into coord_list?

What's a little confusing is that in the nearest method, it seems the dist array is transposed compared to this one, it might be better to either keep them consistent, or probably just to refactor so that we compute this dist method for both nearest and exact and then have a separate logic for computing indices given the dist (for nearest and exact)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the reason for doing this was that it preserves the ordering of the coordinates this way. I'll have another look if I can do that without swapping the indices.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is great 👍

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add a test for the direction coords too

@yaugenst-flex yaugenst-flex marked this pull request as ready for review May 3, 2024 15:47
Copy link
Collaborator

@tylerflex tylerflex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking almost ready, just a couple comments.

tidy3d/plugins/adjoint/components/data/data_array.py Outdated Show resolved Hide resolved
f"Could not select '{coord_name}={vals}', some values were not found."
)

isel_kwargs[coord_name] = jnp.squeeze(indices)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just want to make sure: this squeeze doesn't have any unexpected effects if say my values array has a dimension of size 1? I want to make sure we dont "squeeze" out any dimensions in the values that still have a coordinate in coords. Can we add a test for this edge case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice yea good catch, there is a difference between .sel(x=0) and .sel(x=[0]). Fixed with tests, matches xarray behavior.

Copy link
Collaborator

@tylerflex tylerflex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good, just a comment on the docstring and then after this is addressed we can squash and merge.

Copy link
Collaborator

@tylerflex tylerflex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @yaugenst I'll squash and merge this once the tests pass!

@tylerflex tylerflex merged commit b1fef3f into pre/2.7 May 3, 2024
16 checks passed
@tylerflex tylerflex deleted the yaugenst-flex/issue1617 branch May 3, 2024 18:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants