From 4d4d241b66db711eb680cc6d5be9d5efc0c18578 Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev <hombit@gmail.com> Date: Wed, 8 May 2024 14:46:21 -0400 Subject: [PATCH] Fix .nest[...] = pd.Series --- src/nested_pandas/series/accessor.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/nested_pandas/series/accessor.py b/src/nested_pandas/series/accessor.py index c4dbd5f..09d1913 100644 --- a/src/nested_pandas/series/accessor.py +++ b/src/nested_pandas/series/accessor.py @@ -92,7 +92,7 @@ def to_flat(self, fields: list[str] | None = None) -> pd.DataFrame: for field in fields: list_array = cast(pa.ListArray, struct_array.field(field)) if index is None: - index = np.repeat(self._series.index.values, np.diff(list_array.offsets)) + index = self.get_flat_index() flat_series[field] = pd.Series( list_array.flatten(), index=pd.Series(index, name=self._series.index.name), @@ -178,6 +178,13 @@ def query_flat(self, query: str) -> pd.Series: return pd.Series([], dtype=self._series.dtype) return pack_sorted_df_into_struct(flat) + def get_flat_index(self) -> pd.Index: + """Index of the flat arrays""" + return pd.Index( + np.repeat(self._series.index.values, np.diff(self._series.array.list_offsets)), + name=self._series.index.name, + ) + def get_flat_series(self, field: str) -> pd.Series: """Get the flat-array field as a Series @@ -200,7 +207,7 @@ def get_flat_series(self, field: str) -> pd.Series: return pd.Series( flat_array, dtype=pd.ArrowDtype(flat_array.type), - index=np.repeat(self._series.index.values, np.diff(self._series.array.list_offsets)), + index=self.get_flat_index(), name=field, copy=False, ) @@ -252,7 +259,9 @@ def __setitem__(self, key: str, value: ArrayLike) -> None: self.set_flat_field(key, value) return - if isinstance(value, pd.Series) and not np.array_equal(self._series.index.values, value.index.values): + if isinstance(value, pd.Series) and not self.get_flat_index().equals(value.index): + print(self._series.index) + print(value.index) raise ValueError("Cannot set field with a Series of different index") pa_array = pa.array(value, from_pandas=True)