From 1441ac15c273db4bb34f4a16ed21d46b7e9f3e9f Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Fri, 22 Mar 2024 16:50:00 -0700 Subject: [PATCH 1/2] break the main __getitem__ method down into smaller methods --- xarray/core/indexing.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 82ee4ccb0e4..c35d17c2270 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1686,6 +1686,28 @@ def _oindex_get(self, indexer: OuterIndexer): def _vindex_get(self, indexer: VectorizedIndexer): return self.__getitem__(indexer) + def _prepare_key(self, key: tuple[Any, ...]) -> tuple[Any, ...]: + if isinstance(key, tuple) and len(key) == 1: + # unpack key so it can index a pandas.Index object (pandas.Index + # objects don't like tuples) + (key,) = key + + return key + + def _handle_result( + self, result: Any + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + if isinstance(result, pd.Index): + return type(self)(result, dtype=self.dtype) + else: + return self._convert_scalar(result) + def __getitem__( self, indexer: ExplicitIndexer ) -> ( @@ -1695,11 +1717,7 @@ def __getitem__( | np.datetime64 | np.timedelta64 ): - key = indexer.tuple - if isinstance(key, tuple) and len(key) == 1: - # unpack key so it can index a pandas.Index object (pandas.Index - # objects don't like tuples) - (key,) = key + key = self._prepare_key(indexer.tuple) if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional indexable = NumpyIndexingAdapter(np.asarray(self)) @@ -1707,10 +1725,7 @@ def __getitem__( result = self.array[key] - if isinstance(result, pd.Index): - return type(self)(result, dtype=self.dtype) - else: - return self._convert_scalar(result) + return self._handle_result(result) def transpose(self, order) -> pd.Index: return self.array # self.array should be always one-dimensional From 626f41563419e75af7e93d63425b52e12acf16b7 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Fri, 22 Mar 2024 17:07:31 -0700 Subject: [PATCH 2/2] handle the _oindex_get, _vindex_get, and __getitem__ methods for the PandasMultiIndexingAdapter --- xarray/core/indexing.py | 69 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 62 insertions(+), 7 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index c35d17c2270..92ace8573b8 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1680,12 +1680,6 @@ def _convert_scalar(self, item): # a NumPy array. return to_0d_array(item) - def _oindex_get(self, indexer: OuterIndexer): - return self.__getitem__(indexer) - - def _vindex_get(self, indexer: VectorizedIndexer): - return self.__getitem__(indexer) - def _prepare_key(self, key: tuple[Any, ...]) -> tuple[Any, ...]: if isinstance(key, tuple) and len(key) == 1: # unpack key so it can index a pandas.Index object (pandas.Index @@ -1708,7 +1702,7 @@ def _handle_result( else: return self._convert_scalar(result) - def __getitem__( + def _get_item( self, indexer: ExplicitIndexer ) -> ( PandasIndexingAdapter @@ -1727,6 +1721,39 @@ def __getitem__( return self._handle_result(result) + def _oindex_get( + self, indexer: OuterIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + return self._get_item(indexer) + + def _vindex_get( + self, indexer: VectorizedIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + return self._get_item(indexer) + + def __getitem__( + self, indexer: ExplicitIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + return self._get_item(indexer) + def transpose(self, order) -> pd.Index: return self.array # self.array should be always one-dimensional @@ -1781,6 +1808,34 @@ def _convert_scalar(self, item): item = item[idx] return super()._convert_scalar(item) + def _oindex_get( + self, indexer: OuterIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + result = super()._oindex_get(indexer) + if isinstance(result, type(self)): + result.level = self.level + return result + + def _vindex_get( + self, indexer: VectorizedIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + result = super()._vindex_get(indexer) + if isinstance(result, type(self)): + result.level = self.level + return result + def __getitem__(self, indexer: ExplicitIndexer): result = super().__getitem__(indexer) if isinstance(result, type(self)):