From 4d4d241b66db711eb680cc6d5be9d5efc0c18578 Mon Sep 17 00:00:00 2001
From: Konstantin Malanchev <hombit@gmail.com>
Date: Wed, 8 May 2024 14:46:21 -0400
Subject: [PATCH] Fix .nest[...] = pd.Series

---
 src/nested_pandas/series/accessor.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/nested_pandas/series/accessor.py b/src/nested_pandas/series/accessor.py
index c4dbd5f..09d1913 100644
--- a/src/nested_pandas/series/accessor.py
+++ b/src/nested_pandas/series/accessor.py
@@ -92,7 +92,7 @@ def to_flat(self, fields: list[str] | None = None) -> pd.DataFrame:
         for field in fields:
             list_array = cast(pa.ListArray, struct_array.field(field))
             if index is None:
-                index = np.repeat(self._series.index.values, np.diff(list_array.offsets))
+                index = self.get_flat_index()
             flat_series[field] = pd.Series(
                 list_array.flatten(),
                 index=pd.Series(index, name=self._series.index.name),
@@ -178,6 +178,13 @@ def query_flat(self, query: str) -> pd.Series:
             return pd.Series([], dtype=self._series.dtype)
         return pack_sorted_df_into_struct(flat)
 
+    def get_flat_index(self) -> pd.Index:
+        """Index of the flat arrays"""
+        return pd.Index(
+            np.repeat(self._series.index.values, np.diff(self._series.array.list_offsets)),
+            name=self._series.index.name,
+        )
+
     def get_flat_series(self, field: str) -> pd.Series:
         """Get the flat-array field as a Series
 
@@ -200,7 +207,7 @@ def get_flat_series(self, field: str) -> pd.Series:
         return pd.Series(
             flat_array,
             dtype=pd.ArrowDtype(flat_array.type),
-            index=np.repeat(self._series.index.values, np.diff(self._series.array.list_offsets)),
+            index=self.get_flat_index(),
             name=field,
             copy=False,
         )
@@ -252,7 +259,9 @@ def __setitem__(self, key: str, value: ArrayLike) -> None:
             self.set_flat_field(key, value)
             return
 
-        if isinstance(value, pd.Series) and not np.array_equal(self._series.index.values, value.index.values):
+        if isinstance(value, pd.Series) and not self.get_flat_index().equals(value.index):
+            print(self._series.index)
+            print(value.index)
             raise ValueError("Cannot set field with a Series of different index")
 
         pa_array = pa.array(value, from_pandas=True)