WIP: indexing with broadcasting #1473
Conversation
xarray/core/variable.py
Outdated
```python
        along self.dims.
        """
        if not utils.is_dict_like(key):
            key = {self.dims[0]: key}
```
Note that if `key` is a tuple, it should be paired with multiple dimensions.
xarray/core/variable.py
Outdated
```python
                example_v = v
                indexes[k] = v

        # When all the keys are array or integer, slice
```
I think we actually need two totally different code paths for basic vs. advanced indexing:

- If all indexers are integers/slices, we can use NumPy's "basic indexing". There's no need to futz with broadcasting; we can use a simpler version of the existing logic, dropping `_indexable_data`. (We need this path because "basic indexing" is much faster than "advanced indexing" in NumPy, as it avoids copying data.)
- If any indexers are arrays or `Variable` objects, we need to do "advanced indexing":
  - Normalize every `(key, value)` pair into `xarray.Variable` objects:
    - `xarray.Variable` values -> keep them unmodified
    - Integer values -> `xarray.Variable((), value)`
    - Slice values -> `xarray.Variable((key,), np.arange(*value.indices(size)))` where `size` is the size of the dimension corresponding to `key`
    - 1D list/numpy arrays -> `xarray.Variable((key,), value)`
    - N-D arrays -> `IndexError`
  - Make all indexers broadcast compatible with a single call to `_broadcast_compat_variables()`. All variables are now guaranteed to have the same `dims`, so take `dims` from the first variable to label the result of indexing.
  - Create a tuple using `data` from each of these variables, and use it to index this variable's data.
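The view/copy distinction motivating the basic-indexing fast path can be seen directly in plain NumPy (a minimal sketch of standard NumPy semantics, independent of xarray):

```python
import numpy as np

a = np.arange(10)

# Basic indexing (integers/slices) returns a *view*: no data is copied.
view = a[2:5]
assert view.base is a

# Advanced indexing (integer arrays) returns a *copy* of the data.
copy = a[np.array([2, 3, 4])]
assert copy.base is None

# Both select the same values here.
assert (view == copy).all()
```

This is why falling back to advanced indexing for every request would be a performance regression.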
Here's an alternative version (basically what I described above) that I think will work, at least if the variable's data is a NumPy array (we may need to jump through a few more hoops for dask arrays):
```python
def _broadcast_indexes(self, key):
    key = self._item_key_to_tuple(key)  # key is a tuple
    # key is a tuple of full size
    key = indexing.expanded_indexer(key, self.ndim)
    basic_indexing_types = integer_types + (slice,)
    if all(isinstance(k, basic_indexing_types) for k in key):
        return self._broadcast_indexes_basic(key)
    else:
        return self._broadcast_indexes_advanced(key)

def _broadcast_indexes_basic(self, key):
    dims = tuple(dim for k, dim in zip(key, self.dims)
                 if not isinstance(k, integer_types))
    return dims, key

def _broadcast_indexes_advanced(self, key):
    variables = []
    for dim, value in zip(self.dims, key):
        if isinstance(value, slice):
            value = np.arange(*value.indices(self.sizes[dim]))
        # NOTE: this is close to but not quite correct, since we want to
        # handle tuples differently than as_variable and want a different
        # error message (not referencing tuples)
        variable = as_variable(value, name=dim)
        variables.append(variable)
    variables = _broadcast_compat_variables(*variables)
    dims = variables[0].dims  # all variables have the same dims
    key = tuple(variable.data for variable in variables)
    return dims, key
```
> We need this path because "basic indexing" is much faster than "advanced indexing" in NumPy, as it avoids copying data.

Thanks for the suggestion! And also thanks for the reference implementation (such simple logic!).

Please let me make sure of one thing. What would be expected by the following?

```python
v = Variable(['x', 'y'], [[0, 1, 2], [3, 4, 5]])
ind_x = Variable(['a'], [0, 1])
v.getitem2(dict(x=ind_x, y=[1, 0]))
```

Should it be understood as `y=Variable(['y'], [1, 0])`? I thought `y=[1, 0]` would become `y=Variable(['a'], [1, 0])`...
Yes, `y=[1, 0]` -> `Variable(['y'], [1, 0])`.

I think the result here would look like `Variable(['a', 'y'], [[1, 0], [4, 3]])`:

- After making them broadcast compatible, the indexer variables have dimensions `['a', 'y']` and shapes `(2, 1)` / `(1, 2)`.
- For the `x` coordinate, each coordinate value is taken in order.
- For the `y` coordinate, the first two coordinates are taken in reverse order.
- The result is labeled by the dimensions of the indexers (i.e., `('a', 'y')`).
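At the NumPy level, the mechanism described here is plain vectorized ("advanced") indexing with broadcast-compatible indexer arrays. A minimal sketch with made-up indexers (not the exact arrays from this thread):

```python
import numpy as np

arr = np.arange(9).reshape(3, 3)  # [[0,1,2],[3,4,5],[6,7,8]]

# Two 1-D indexers reshaped to be broadcast compatible:
# shapes (2, 1) and (1, 2) broadcast to a (2, 2) result.
ix = np.array([0, 2])[:, np.newaxis]   # selects rows 0 and 2
iy = np.array([2, 0])[np.newaxis, :]   # selects columns 2 and 0 (reversed)

result = arr[ix, iy]
# result[i, j] == arr[ix[i, 0], iy[0, j]], so:
assert result.tolist() == [[2, 0], [8, 6]]
```

The result's shape is the broadcast shape of the indexers, which is exactly why the broadcast `dims` of the indexer variables can be used to label the output.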
Thanks @shoyer. As you pointed out, we may need a better error message here.
xarray/core/variable.py
Outdated
```python
        try:  # TODO we need our own Exception.
            variable = as_variable(value, name=dim)
        except ValueError as e:
            if "cannot set variable" in str(e):
```
My current implementation for this exception handling is rather bad.
I want to change this to something like:

```python
except DimensionMismatchError:
    raise IndexError('...')
```
I agree. Can you switch this for as_variable
?
Inherit from ValueError
for backwards compatibility. Something like MissingDimensionsError
should be more broadly useful for cases like this where we can't safely guess a dimension name.
I just realized that dask's indexing is limited, e.g. it does not support nd-array indexing.
xarray/core/variable.py
Outdated
```python
                else:
                    raise e
            if variable._isbool_type():  # boolean indexing case
                variables.extend(list(variable.nonzero()))
```
I'm a little hesitant to allow multi-dimensional boolean indexers here. The problem with using `extend()` here is that method calls like `a.isel(x=b)` get mapped into something like `a[b]`, so if `b` is multi-dimensional, the second dimension of `b` gets matched up with the second dimension of `a` in an unpredictable way. We would need some way to specify the mapping to multiple dimensions, something like `a.isel((x,y)=b)` (obviously not valid syntax).

Instead, I would error for boolean indexers with more than one dimension, and then convert with `nonzero()`, as you've done here.

Multi-dimensional boolean indexers are useful, but I think the main use-case is indexing with a single argument like `x[y > 0]`, so we don't need fancy remapping between dimensions.
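Converting a 1-D boolean indexer into integer positions via `nonzero()` is standard NumPy behavior; a minimal sketch of the `x[y > 0]` use-case:

```python
import numpy as np

y = np.array([-1, 3, 0, 5])
mask = y > 0                       # 1-D boolean indexer
positions, = np.nonzero(mask)      # integer positions of True entries

assert list(positions) == [1, 3]
# Boolean and integer indexing select the same values.
assert (y[mask] == y[positions]).all()
```

After this conversion, the boolean case reduces to the integer-array case that the broadcasting machinery already handles.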
I agree. I added a sanity check for the boolean array.
xarray/core/variable.py
Outdated
```python
                        "cannot be used for indexing.")
                else:
                    raise e
            if variable._isbool_type():  # boolean indexing case
```
Can we make this just `variable.dtype.kind == 'b'`?
xarray/core/variable.py
Outdated
```python
        for dim, value in zip(self.dims, key):
            if isinstance(value, slice):
                value = np.arange(self.sizes[dim])[value]
```
Use the `slice.indices` method here, to construct the desired array without indexing:

```python
In [14]: x = slice(0, 10, 2)

In [15]: np.arange(*x.indices(5))
Out[15]: array([0, 2, 4])
```
xarray/core/variable.py
Outdated
```python
        elif isinstance(self._data, dask_array_type):
            # TODO we should replace dask's native nonzero
            # after https://github.com/dask/dask/issues/1076 is implemented.
            nonzeros = np.nonzero(self.load()._data)
```
I would just do `nonzeros = np.nonzero(self.data)` instead of all these `isinstance` checks. You can leave the TODO, but I don't think we actually need it for indexing since currently we already load indexers into memory.
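`np.nonzero` already returns one integer array per dimension, which maps naturally onto one indexer variable per dimension; a plain-NumPy sketch of what the `nonzero` helper builds on:

```python
import numpy as np

mask = np.array([[True, False, True],
                 [False, True, False]])

nonzeros = np.nonzero(mask)   # one index array per dimension
assert len(nonzeros) == mask.ndim
assert list(nonzeros[0]) == [0, 0, 1]   # row indices of True values
assert list(nonzeros[1]) == [0, 2, 1]   # column indices of True values

# Pointwise indexing with these arrays recovers the selected values.
assert mask[nonzeros].all()
```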
xarray/core/variable.py
Outdated
```python
@@ -412,15 +473,15 @@ def __setitem__(self, key, value):

        See __getitem__ for more details.
        """
        key = self._item_key_to_tuple(key)
        dims, index_tuple = self._broadcast_indexes(key)
        if isinstance(self._data, dask_array_type):
            raise TypeError("this variable's data is stored in a dask array, "
```
This is actually no longer true -- `dask.array` supports assignment in recent versions.
xarray/core/variable.py
Outdated
```diff
-        data = orthogonally_indexable(self._data)
-        data[key] = value
+        data = broadcasted_indexable(self._data)
+        data[index_tuple] = value
```
`value` should be broadcast against the indexing `key` if possible.

If it's an `xarray.Variable`, I think we can just call `value.set_dims(dims)` on it. If it's a NumPy/Dask array, I think we can safely ignore it and let NumPy/Dask handle broadcasting.
xarray/core/indexing.py
Outdated
```python
    """
    def __init__(self, array):
        self.array = array

    def __getitem__(self, key):
        key = expanded_indexer(key, self.ndim)
        if any(not isinstance(k, integer_types + (slice,)) for k in key):
            """ key: tuple of Variable, slice, integer """
```
I think it should only be a NumPy or dask array here, not a Variable.
xarray/core/indexing.py
Outdated
```python
        if any(not isinstance(k, integer_types + (slice,)) for k in key):
            """ key: tuple of Variable, slice, integer """
            # basic or orthogonal indexing
            if all(isinstance(k, (integer_types, slice)) or k.squeeze().ndim <= 1
```
For simple cases where everything is an integer or slice, we want to just use a single indexing call like `self.array[key]` rather than this loop.

In the hard case when some arguments are arrays, we should try `self.array.vindex[key]`. If it doesn't work in some cases, we can either add work-arounds or try to fix it upstream in dask.
After looking at dask.array in a little more detail, I think we need to keep a work-around for "orthogonal" indexing in dask. It looks like `vindex` only works when each indexer is 1D.
See fujiisoup#1 (Refactor DaskArrayAdapter).

Let's not worry about supporting every indexing type with dask. I think that with my patch we can do everything we currently do. I think we'll also want to make a "vectorized to orthogonal" indexing adapter that we can use.
@shoyer Thanks for your help.
Yes, thanks to your patch, dask-based variables are now indexed fine. Some replies to your comments on the outdated code follow.
I will try to fit the other array wrappers,
I think it would be better to update
@shoyer and @fujiisoup - is this ready for a final review?
I think it's ready.
Wow! This PR is a behemoth. I think I made it through the whole thing. I'm really excited about all these changes, so nice work.
doc/indexing.rst
Outdated
```rst
Note that using ``sel`` it is not possible to mix a dimension
indexer with level indexers for that dimension
(e.g., ``mda.sel(x={'one': 'a'}, two=0)`` will raise a ``ValueError``).

In briefly, similar to how NumPy's `advanced indexing`_ works, vectorized
```
In briefly
doc/indexing.rst
Outdated
```rst
dimensions or use the ellipsis in the ``loc`` specifier, e.g. in the example
above, ``mda.loc[{'one': 'a', 'two': 0}, :]`` or ``mda.loc[('a', 0), ...]``.

    ind = xr.DataArray([['a', 'b'], ['b', 'a']], dims=['a', 'b'])
    da.loc[:, ind]  # same to da.sel(y=ind)
```
same as
doc/indexing.rst
Outdated
```rst
Multi-dimensional indexing
--------------------------

and also for ``Dataset``
```
These methods may also be applied to `Dataset` objects.
doc/indexing.rst
Outdated
```rst
using numpy's broadcasting rules to vectorize indexers. This means you can do
indexing like this, which would require slightly more awkward syntax with
numpy arrays:

As like numpy ndarray, value assignment sometimes works differently from what one may expect.
```
Like `numpy.ndarray`, ...
doc/indexing.rst
Outdated
```rst
.. note::

    Dask backend does not yet support value assignment
```
Dask arrays do not support value assignment
xarray/core/variable.py
Outdated
```diff
@@ -99,7 +111,7 @@ def as_variable(obj, name=None):
     if name is not None and name in obj.dims:
         # convert the Variable into an Index
         if obj.ndim != 1:
-            raise ValueError(
+            raise MissingDimensionsError(
                 '%r has more than 1-dimension and the same name as one of its '
                 'dimensions %r. xarray disallows such variables because they '
                 'conflict with the coordinates used to label dimensions.'
```
This may be missing `% (name, data.ndim)`.
```python
        # after https://github.com/dask/dask/issues/1076 is implemented.
        nonzeros = np.nonzero(self.data)
        return tuple(Variable((dim), nz) for nz, dim
                     in zip(nonzeros, self.dims))
```
So this is in dask now. Can we add a conditional import of the dask version here?
Side note, I think a nonzero method just like this would fit in xarray's public API (not part of this PR).
xarray/core/variable.py
Outdated
```python
        data = self._indexable_data[index_tuple]
        if new_order:
            data = np.moveaxis(data, range(len(new_order)), new_order)
        assert getattr(data, 'ndim', 0) == len(dims), (data.ndim, len(dims))
```
Can we raise a `ValueError` here instead and provide a more informative error message?
Actually, I can't think of a case where this assertion would fail. Maybe I will remove this line.
xarray/tests/test_dataset.py
Outdated
```python
        assert all(["Indexer has dimensions " not in
                    str(w.message) for w in ws])
        warnings.warn('dummy', FutureWarning, stacklevel=3)
```
Two things here:

- I'm not sure why we need the dummy warning if it should not warn. Can we assert that a warning was not raised?
- Move the `ind = ...` line out of the context manager so it is clear which line should be raising a warning.
I do not know a good way to make sure that the code does NOT warn.
If the dummy warning line is removed, the test fails because the code does not warn.
Does anyone know a better workaround?
I think this SO answer should work:

```python
ind = xr.DataArray([0.0, 1.0], dims=['dim2'], name='ind')
with pytest.warns(None) as ws:
    # Should not warn
    data.reindex(dim2=ind)
assert len(ws) == 0
```
Thanks. It now works without the dummy warning.
xarray/tests/test_dataset.py
Outdated
```python
        if pd.__version__ >= '0.17':
            with self.assertRaises(KeyError):
                data.sel_points(x=[2.5], y=[2.0], method='pad', tolerance=1e-3)
```
Move this to its own test and use `pytest.importorskip('pandas', minversion='0.17')` in place of your if statement.
I still need to look carefully at the logic for coordinate conflicts.
```python
        if current != previous + 1:
            return False
        previous = current
    return True
```
Actually, `all()` is defined in the right way for length-0 lists, so this can just be:

```python
def _is_contiguous(positions):
    return (np.diff(positions) == 1).all()
```

This is arguably an improvement for readability (at least for the numpy-familiar devs who work on xarray), but there's no performance advantage, because `positions` only has one element per indexer, so it is almost always a list of only a few items.
@jhamman Thanks for the review (and sorry for my late reply).

@shoyer Limitations of the current implementation:
`mda.sel(x=xr.DataArray(mda.indexes['x'][:3], dims='x'))` works as expected, but `mda.sel(x=xr.DataArray(mda.indexes['x'][:3], dims='z'))` will attach coordinate
@fujiisoup Thanks again for all your hard work on this and for my slow response. I've made another PR with tweaks to your logic for conflicting coordinates: fujiisoup#5 Mostly, my PR is about simplifying the logic by removing the special-case work-arounds you added that check object identity (things like
I will take a look at the multi-index issues. I suspect that many of these will be hard to resolve until we complete the refactor making
Simplify rules for indexing conflicts
@shoyer, thanks for your review.
OK, it makes sense to me as well. Merging your PR.
Actually, the vectorized label-indexing currently fails in almost all cases with a MultiIndex:

```python
In [1]: import xarray as xr
   ...: import pandas as pd
   ...:
   ...: midx = pd.MultiIndex.from_tuples(
   ...:     [(1, 'a'), (2, 'b'), (3, 'c')],
   ...:     names=['x0', 'x1'])
   ...: da = xr.DataArray([0, 1, 2], dims=['x'],
   ...:                   coords={'x': midx})
   ...: da
   ...:
Out[1]:
<xarray.DataArray (x: 3)>
array([0, 1, 2])
Coordinates:
  * x   (x) MultiIndex
  - x0  (x) int64 1 2 3
  - x1  (x) object 'a' 'b' 'c'
```

Some cases work as expected, some fail without appropriate error messages, and some silently destroy the MultiIndex structure. I will add better exceptions later today.
I think this is ready to go in. @jhamman @fujiisoup any reason to wait?
I'm happy with this :)
LGTM. |
@fujiisoup Can you open a new pull request with this branch? I'd like to give you credit on GitHub for this (since you did most of the work), but I think if I merge this with "Squash and Merge" everything will get credited to me. You can also try doing your own rebase to clean up the history into fewer commits if you like (or I could "squash and merge" locally in git), but I think a new PR would do a better job of preserving history for anyone who wants to look at this later.
I closed this intentionally since I think there is a good chance GitHub won't let you open a new PR otherwise. |
xref #974 (comment)