-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Support selection by nearest value in JaxDataArray
#1671
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just some preliminary comments, looking good though
@@ -312,28 +313,37 @@ def isel(self, **isel_kwargs) -> JaxDataArray: | |||
|
|||
return self_sel | |||
|
|||
def sel(self, indexers: dict = None, method: str = "nearest", **sel_kwargs) -> JaxDataArray: | |||
def sel(self, indexers: dict = None, method: str | None = None, **sel_kwargs) -> JaxDataArray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this |
fail in earlier versions of python? we actually recently stopped supporting 3.8 so maybe it's ok now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, maybe we want it to be something like method: typing.Literal[None, "nearest"]
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this | fail in earlier versions of python? we actually recently stopped supporting 3.8 so maybe it's ok now
Yes absolutely! Just checked, it was introduced in 3.10, so I'll have to change that. I'd really like to use this & match
statements though...
Also, maybe we want it to be something like method: typing.Literal[None, "nearest"]?
Yes I was thinking about the Literal
too, I'll add it. Not sure why I didn't go with it in the end. Probably it seemed redundant given the if/else
check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my main concern was the docs, but actually I realized methods get documented using their docstrings (not the type annotations). We probably want to just add a proper Parameters
and Returns
format in the docstring, mentioning method
. Could you add that while you're at it?
@@ -312,28 +313,37 @@ def isel(self, **isel_kwargs) -> JaxDataArray: | |||
|
|||
return self_sel | |||
|
|||
def sel(self, indexers: dict = None, method: str = "nearest", **sel_kwargs) -> JaxDataArray: | |||
def sel(self, indexers: dict = None, method: str | None = None, **sel_kwargs) -> JaxDataArray: | |||
"""Select a value from the :class:`.JaxDataArray` by indexing into coordinate values.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you add a short mention of method
option?
EDIT: actually let's just add the full Parameters
and Returns
here. I should have done that earlier..
else: | ||
raise NotImplementedError(f"Unkown selection method: {method}.") | ||
|
||
def _sel_exact(self, sel_kwargs: dict) -> JaxDataArray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you add a short docstring (and to _sel_nearest
)
coord_list = np.asarray(self.get_coord_list(coord_name)) | ||
vals = np.atleast_1d(vals) | ||
dist = np.abs(coord_list[None] - vals[:, None]) | ||
indices = np.where(np.isclose(dist, 0))[1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just so I understand, the only difference between _sel_exact and _sel_nearest seems to be how the indices are computed given the coord_list
and vals
# exact
vals = np.atleast_1d(vals)
dist = np.abs(coord_list[None] - vals[:, None])
indices = np.where(np.isclose(dist, 0))[1]
if indices.size == 0:
raise DataError(f"Could not select '{coord_name}={vals}', some values were not found.")
# nearest
vals = np.asarray(vals)
dist = np.abs(coord_list[:, None] - vals[None])
indices = np.argmin(dist, axis=0)
Is my understanding correct?
Furthermore, it could probably be written to just compute indices given some dist
array.
Would it make sense then to refactor a bit to either just stick these in the same method (with an if / else) or just localize the _get_indices(dist) -> incides
logic?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes the calculation of the distance matrix should probably be refactored into its own method, you're right.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
well, after looking at it (see comments below) probably we can just keep the same distance method calculation, but we just need to refactor so that we compute the indices differently given the distance and some method
?
for coord_name, vals in sel_kwargs.items(): | ||
coord_list = np.asarray(self.get_coord_list(coord_name)) | ||
vals = np.asarray(vals) | ||
dist = np.abs(coord_list[:, None] - vals[None]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you explain the indexing here? so it seems to me that we add a dimension to vals
so it would be shaped (N,) -> (1, N)
and then coord_list
would be (M,) -> (M, 1)
. And then we have a distance (M, N)
. We take argmin with respect to the coord_list
axis, giving (N,)
values corresponding to the smallest distance between each value to all of the coords?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes exactly, which is then an index array of shape (M,)
that can be used to index into coord_list
. I came across a bug though, in that this does not handle non-numeric types such as coords of the form ["+", "-"]
. Have to come up with a fix for that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah.. maybe in that case we should just revert to the exact selection? not sure how xarray handles this.
for coord_name, vals in sel_kwargs.items(): | ||
coord_list = np.asarray(self.get_coord_list(coord_name)) | ||
vals = np.atleast_1d(vals) | ||
dist = np.abs(coord_list[None] - vals[:, None]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar comment below but trying to understand the idea here, so coord_list[None]
gives an array shaped (1, M)
and then vals[:, None]
gives (1, N)
, we take the difference and get an array shaped (N, M)
.
The next line, by applying np.where(...)[1]
, we select out all of the N
indices into coord_list
?
What's a little confusing is that in the nearest method, it seems the dist
array is transposed compared to this one, it might be better to either keep them consistent, or probably just to refactor so that we compute this dist
method for both nearest
and exact
and then have a separate logic for computing indices
given the dist
(for nearest and exact)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the reason for doing this was that it preserves the ordering of the coordinates this way. I'll have another look if I can do that without swapping the indices.
tests/test_plugins/test_adjoint.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is great 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's add a test for the direction
coords too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking almost ready, just a couple comments.
f"Could not select '{coord_name}={vals}', some values were not found." | ||
) | ||
|
||
isel_kwargs[coord_name] = jnp.squeeze(indices) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just want to make sure: this squeeze
doesn't have any unexpected effects if say my values array has a dimension of size 1? I want to make sure we dont "squeeze" out any dimensions in the values that still have a coordinate in coords
. Can we add a test for this edge case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh nice yea good catch, there is a difference between .sel(x=0)
and .sel(x=[0])
. Fixed with tests, matches xarray behavior.
4b8f0d7
to
14acc0b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good, just a comment on the docstring and then after this is addressed we can squash and merge.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @yaugenst I'll squash and merge this once the tests pass!
2cc3afc
to
5f7523f
Compare
Closes #1617
Also took the liberty to rework the mechanics of the old selection by vectorizing everything using numpy.
Note that there is a slight difference between xarray's exact selection and ours now, because I base selections off of
np.isclose
instead of doing exact float comparisions (x == y
). I think this is more intuitive, as otherwise something likex == 0.2
fails ifx
is a numpy float (0.20000000000000018
).I can change this back to match xarray's behavior exactly if that's what we want to do.