From 3f0049ffc51e4c709256cf174c435f741370148d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 9 Oct 2019 19:01:29 +0100 Subject: [PATCH] Speed up isel and __getitem__ (#3375) * Variable.isel cleanup/speedup * Dataset.isel code cleanup * Speed up isel * What's New * Better error checks * Speedup * type annotations * Update doc/whats-new.rst Co-Authored-By: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> * What's New * What's New * Always shallow-copy variables --- doc/whats-new.rst | 10 +++- xarray/core/dataset.py | 108 ++++++++++++++++++++++------------------ xarray/core/indexes.py | 3 +- xarray/core/variable.py | 35 +++++++++---- 4 files changed, 93 insertions(+), 63 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 81206cc5cc1..a3cdcbdc7f5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -42,6 +42,9 @@ Breaking changes (:issue:`3222`, :issue:`3293`, :issue:`3340`, :issue:`3346`, :issue:`3358`). By `Guido Imperiale `_. +- Dropped the 'drop=False' optional parameter from :meth:`Variable.isel`. + It was unused and doesn't make sense for a Variable. + (:pull:`3375`) by `Guido Imperiale `_. New functions/methods ~~~~~~~~~~~~~~~~~~~~~ @@ -49,14 +52,17 @@ New functions/methods Enhancements ~~~~~~~~~~~~ -- Add a repr for :py:class:`~xarray.core.GroupBy` objects (:issue:`3344`). +- Add a repr for :py:class:`~xarray.core.GroupBy` objects. Example:: >>> da.groupby("time.season") DataArrayGroupBy, grouped over 'season' 4 groups with labels 'DJF', 'JJA', 'MAM', 'SON' - By `Deepak Cherian `_. + (:issue:`3344`) by `Deepak Cherian `_. +- Speed up :meth:`Dataset.isel` up to 33% and :meth:`DataArray.isel` up to 25% for small + arrays (:issue:`2799`, :pull:`3375`) by + `Guido Imperiale `_. Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1d9ef6f7a72..7b4c7b441bd 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1745,8 +1745,8 @@ def maybe_chunk(name, var, chunks): return self._replace(variables) def _validate_indexers( - self, indexers: Mapping - ) -> List[Tuple[Any, Union[slice, Variable]]]: + self, indexers: Mapping[Hashable, Any] + ) -> Iterator[Tuple[Hashable, Union[int, slice, np.ndarray, Variable]]]: """ Here we make sure + indexer has a valid keys + indexer is in a valid data type @@ -1755,50 +1755,61 @@ def _validate_indexers( """ from .dataarray import DataArray - invalid = [k for k in indexers if k not in self.dims] + invalid = indexers.keys() - self.dims.keys() if invalid: raise ValueError("dimensions %r do not exist" % invalid) # all indexers should be int, slice, np.ndarrays, or Variable - indexers_list: List[Tuple[Any, Union[slice, Variable]]] = [] for k, v in indexers.items(): - if isinstance(v, slice): - indexers_list.append((k, v)) - continue - - if isinstance(v, Variable): - pass + if isinstance(v, (int, slice, Variable)): + yield k, v elif isinstance(v, DataArray): - v = v.variable + yield k, v.variable elif isinstance(v, tuple): - v = as_variable(v) + yield k, as_variable(v) elif isinstance(v, Dataset): raise TypeError("cannot use a Dataset as an indexer") elif isinstance(v, Sequence) and len(v) == 0: - v = Variable((k,), np.zeros((0,), dtype="int64")) + yield k, np.empty((0,), dtype="int64") else: v = np.asarray(v) - if v.dtype.kind == "U" or v.dtype.kind == "S": + if v.dtype.kind in "US": index = self.indexes[k] if isinstance(index, pd.DatetimeIndex): v = v.astype("datetime64[ns]") elif isinstance(index, xr.CFTimeIndex): v = _parse_array_of_cftime_strings(v, index.date_type) - if 
v.ndim == 0: - v = Variable((), v) - elif v.ndim == 1: - v = Variable((k,), v) - else: + if v.ndim > 1: raise IndexError( "Unlabeled multi-dimensional array cannot be " "used for indexing: {}".format(k) ) + yield k, v - indexers_list.append((k, v)) - - return indexers_list + def _validate_interp_indexers( + self, indexers: Mapping[Hashable, Any] + ) -> Iterator[Tuple[Hashable, Variable]]: + """Variant of _validate_indexers to be used for interpolation + """ + for k, v in self._validate_indexers(indexers): + if isinstance(v, Variable): + if v.ndim == 1: + yield k, v.to_index_variable() + else: + yield k, v + elif isinstance(v, int): + yield k, Variable((), v) + elif isinstance(v, np.ndarray): + if v.ndim == 0: + yield k, Variable((), v) + elif v.ndim == 1: + yield k, IndexVariable((k,), v) + else: + raise AssertionError() # Already tested by _validate_indexers + else: + raise TypeError(type(v)) def _get_indexers_coords_and_indexes(self, indexers): """Extract coordinates and indexes from indexers. @@ -1885,10 +1896,10 @@ def isel( Dataset.sel DataArray.isel """ - indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") - - indexers_list = self._validate_indexers(indexers) + # Note: we need to preserve the original indexers variable in order to merge the + # coords below + indexers_list = list(self._validate_indexers(indexers)) variables = OrderedDict() # type: OrderedDict[Hashable, Variable] indexes = OrderedDict() # type: OrderedDict[Hashable, pd.Index] @@ -1904,19 +1915,21 @@ def isel( ) if new_index is not None: indexes[name] = new_index - else: + elif var_indexers: new_var = var.isel(indexers=var_indexers) + else: + new_var = var.copy(deep=False) variables[name] = new_var - coord_names = set(variables).intersection(self._coord_names) + coord_names = self._coord_names & variables.keys() selected = self._replace_with_new_dims(variables, coord_names, indexes) # Extract coordinates from indexers coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(indexers) variables.update(coord_vars) indexes.update(new_indexes) - coord_names = set(variables).intersection(self._coord_names).union(coord_vars) + coord_names = self._coord_names & variables.keys() | coord_vars.keys() return self._replace_with_new_dims(variables, coord_names, indexes=indexes) def sel( @@ -2478,11 +2491,9 @@ def interp( if kwargs is None: kwargs = {} + coords = either_dict_or_kwargs(coords, coords_kwargs, "interp") - indexers = OrderedDict( - (k, v.to_index_variable() if isinstance(v, Variable) and v.ndim == 1 else v) - for k, v in self._validate_indexers(coords) - ) + indexers = OrderedDict(self._validate_interp_indexers(coords)) obj = self if assume_sorted else self.sortby([k for k in coords]) @@ -2507,26 +2518,25 @@ def _validate_interp_indexer(x, new_x): "strings or datetimes. 
" "Instead got\n{}".format(new_x) ) - else: - return (x, new_x) + return x, new_x variables = OrderedDict() # type: OrderedDict[Hashable, Variable] for name, var in obj._variables.items(): - if name not in indexers: - if var.dtype.kind in "uifc": - var_indexers = { - k: _validate_interp_indexer(maybe_variable(obj, k), v) - for k, v in indexers.items() - if k in var.dims - } - variables[name] = missing.interp( - var, var_indexers, method, **kwargs - ) - elif all(d not in indexers for d in var.dims): - # keep unrelated object array - variables[name] = var + if name in indexers: + continue + + if var.dtype.kind in "uifc": + var_indexers = { + k: _validate_interp_indexer(maybe_variable(obj, k), v) + for k, v in indexers.items() + if k in var.dims + } + variables[name] = missing.interp(var, var_indexers, method, **kwargs) + elif all(d not in indexers for d in var.dims): + # keep unrelated object array + variables[name] = var - coord_names = set(variables).intersection(obj._coord_names) + coord_names = obj._coord_names & variables.keys() indexes = OrderedDict( (k, v) for k, v in obj.indexes.items() if k not in indexers ) @@ -2546,7 +2556,7 @@ def _validate_interp_indexer(x, new_x): variables.update(coord_vars) indexes.update(new_indexes) - coord_names = set(variables).intersection(obj._coord_names).union(coord_vars) + coord_names = obj._coord_names & variables.keys() | coord_vars.keys() return self._replace_with_new_dims(variables, coord_names, indexes=indexes) def interp_like( diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 94188fabc92..a9f0d802da6 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -2,6 +2,7 @@ from collections import OrderedDict from typing import Any, Hashable, Iterable, Mapping, Optional, Tuple, Union +import numpy as np import pandas as pd from . import formatting @@ -63,7 +64,7 @@ def isel_variable_and_index( name: Hashable, variable: Variable, index: pd.Index, - indexers: Mapping[Any, Union[slice, Variable]], + indexers: Mapping[Hashable, Union[int, slice, np.ndarray, Variable]], ) -> Tuple[Variable, Optional[pd.Index]]: """Index a Variable and pandas.Index together.""" if not indexers: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index b4b01f7ee49..6d7a07c6791 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -3,7 +3,7 @@ from collections import OrderedDict, defaultdict from datetime import timedelta from distutils.version import LooseVersion -from typing import Any, Hashable, Mapping, Union +from typing import Any, Hashable, Mapping, Union, TypeVar import numpy as np import pandas as pd @@ -41,6 +41,18 @@ # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore +VariableType = TypeVar("VariableType", bound="Variable") +"""Type annotation to be used when methods of Variable return self or a copy of self. +When called from an instance of a subclass, e.g. IndexVariable, mypy identifies the +output as an instance of the subclass. + +Usage:: + + class Variable: + def f(self: VariableType, ...) -> VariableType: + ... +""" + class MissingDimensionsError(ValueError): """Error class used when we can't safely guess a dimension name. 
@@ -663,8 +675,8 @@ def _broadcast_indexes_vectorized(self, key): return out_dims, VectorizedIndexer(tuple(out_key)), new_order - def __getitem__(self, key): - """Return a new Array object whose contents are consistent with + def __getitem__(self: VariableType, key) -> VariableType: + """Return a new Variable object whose contents are consistent with getting the provided key from the underlying data. NB. __getitem__ and __setitem__ implement xarray-style indexing, @@ -682,7 +694,7 @@ def __getitem__(self, key): data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) - def _finalize_indexing_result(self, dims, data): + def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType: """Used by IndexVariable to return IndexVariable objects when possible. """ return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) @@ -957,7 +969,11 @@ def chunk(self, chunks=None, name=None, lock=False): return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) - def isel(self, indexers=None, drop=False, **indexers_kwargs): + def isel( + self: VariableType, + indexers: Mapping[Hashable, Any] = None, + **indexers_kwargs: Any + ) -> VariableType: """Return a new array indexed along the specified dimension(s). Parameters @@ -976,15 +992,12 @@ def isel(self, indexers=None, drop=False, **indexers_kwargs): """ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel") - invalid = [k for k in indexers if k not in self.dims] + invalid = indexers.keys() - set(self.dims) if invalid: raise ValueError("dimensions %r do not exist" % invalid) - key = [slice(None)] * self.ndim - for i, dim in enumerate(self.dims): - if dim in indexers: - key[i] = indexers[dim] - return self[tuple(key)] + key = tuple(indexers.get(dim, slice(None)) for dim in self.dims) + return self[key] def squeeze(self, dim=None): """Return a new object with squeezed data.
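The ``VariableType`` TypeVar introduced in ``xarray/core/variable.py`` uses the standard self-typed TypeVar pattern, so that methods such as ``Variable.isel`` and ``Variable.__getitem__`` are typed as returning the calling subclass (e.g. ``IndexVariable``) rather than plain ``Variable``. A minimal standalone sketch of the pattern (hypothetical ``Base``/``Sub`` names, not xarray classes; it only assumes mypy or a similar checker is run over the annotations)::

    from typing import Any, Mapping, TypeVar

    T = TypeVar("T", bound="Base")


    class Base:
        def isel(self: T, indexers: Mapping[str, Any]) -> T:
            # Annotating ``self`` with the bound TypeVar makes a type checker
            # propagate the most derived type: Sub().isel(...) is typed as Sub.
            return self


    class Sub(Base):
        pass


    obj = Sub().isel({})
    assert isinstance(obj, Sub)  # mypy likewise infers ``Sub`` for ``obj``, not ``Base``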
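The rewritten ``Variable.isel`` also replaces the mutable list-of-slices loop with a single tuple comprehension: one key entry per dimension, falling back to ``slice(None)`` for dimensions that are not being indexed. A minimal NumPy-only sketch of that key construction (hypothetical ``dims``/``indexers`` values, no xarray import required)::

    import numpy as np

    dims = ("time", "y", "x")
    data = np.arange(24).reshape(2, 3, 4)
    indexers = {"x": 1, "time": slice(0, 1)}

    # One key entry per dimension, defaulting to a full slice for dimensions
    # that are not indexed, mirroring
    # ``tuple(indexers.get(dim, slice(None)) for dim in self.dims)``.
    key = tuple(indexers.get(dim, slice(None)) for dim in dims)

    subset = data[key]
    assert subset.shape == (1, 3)  # "time" sliced to length 1, "x" dropped, "y" kept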