diff --git a/src/nested_pandas/series/accessor.py b/src/nested_pandas/series/accessor.py index b7936bb..cf12b99 100644 --- a/src/nested_pandas/series/accessor.py +++ b/src/nested_pandas/series/accessor.py @@ -251,14 +251,18 @@ def get_flat_series(self, field: str) -> pd.Series: The flat-array field. """ - # TODO: we should make proper missed values handling here + flat_chunks = [] + for nested_chunk in self._series.array._chunked_array.iterchunks(): + struct_array = cast(pa.StructArray, nested_chunk) + list_array = cast(pa.ListArray, struct_array.field(field)) + flat_array = list_array.flatten() + flat_chunks.append(flat_array) + + flat_chunked_array = pa.chunked_array(flat_chunks) - struct_array = cast(pa.StructArray, pa.array(self._series)) - list_array = cast(pa.ListArray, struct_array.field(field)) - flat_array = list_array.flatten() return pd.Series( - flat_array, - dtype=pd.ArrowDtype(flat_array.type), + flat_chunked_array, + dtype=pd.ArrowDtype(flat_chunked_array.type), index=self.get_flat_index(), name=field, copy=False, @@ -277,11 +281,16 @@ def get_list_series(self, field: str) -> pd.Series: pd.Series The list-array field. """ - struct_array = cast(pa.StructArray, pa.array(self._series)) - list_array = struct_array.field(field) + + list_chunks = [] + for nested_chunk in self._series.array._chunked_array.iterchunks(): + struct_array = cast(pa.StructArray, nested_chunk) + list_array = struct_array.field(field) + list_chunks.append(list_array) + list_chunked_array = pa.chunked_array(list_chunks) return pd.Series( - list_array, - dtype=pd.ArrowDtype(list_array.type), + list_chunked_array, + dtype=pd.ArrowDtype(list_chunked_array.type), index=self._series.index, name=field, copy=False, diff --git a/tests/nested_pandas/series/test_accessor.py b/tests/nested_pandas/series/test_accessor.py index b63c879..9510ec1 100644 --- a/tests/nested_pandas/series/test_accessor.py +++ b/tests/nested_pandas/series/test_accessor.py @@ -543,6 +543,32 @@ def test_get_list_series(): ) +def test_get_list_series_multiple_chunks(): + """Test that .nest.get_list_series() works when underlying array is chunked""" + struct_array = pa.StructArray.from_arrays( + arrays=[ + [np.array([1, 2, 3]), np.array([4, 5, 6])], + [np.array([6, 4, 2]), np.array([1, 2, 3])], + ], + names=["a", "b"], + ) + chunked_array = pa.chunked_array([struct_array] * 3) + series = pd.Series(chunked_array, dtype=NestedDtype(chunked_array.type), index=[5, 7, 9, 11, 13, 15]) + assert series.array.num_chunks == 3 + + lists = series.nest.get_list_series("a") + + assert_series_equal( + lists, + pd.Series( + data=[np.array([1, 2, 3]), np.array([4, 5, 6])] * 3, + dtype=pd.ArrowDtype(pa.list_(pa.int64())), + index=[5, 7, 9, 11, 13, 15], + name="a", + ), + ) + + def test_get(): """Test .nest.get() which is implemented by the base class""" series = pack_seq( @@ -588,6 +614,33 @@ def test___getitem___single_field(): ) +def test___getitem___single_field_multiple_chunks(): + """Reproduces issue 142 + + https://github.com/lincc-frameworks/nested-pandas/issues/142 + """ + struct_array = pa.StructArray.from_arrays( + arrays=[ + [np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0])], + [np.array([4.0, 5.0, 6.0]), np.array([3.0, 4.0, 5.0])], + ], + names=["a", "b"], + ) + chunked_array = pa.chunked_array([struct_array] * 3) + series = pd.Series(chunked_array, dtype=NestedDtype(chunked_array.type), index=[0, 1, 2, 3, 4, 5]) + assert series.array.num_chunks == 3 + + assert_series_equal( + series.nest["a"], + pd.Series( + np.array([1.0, 2.0, 3.0, 1.0, 2.0, 1.0] * 3), + dtype=pd.ArrowDtype(pa.float64()), + index=[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5], + name="a", + ), + ) + + def test___getitem___multiple_fields(): """Test that the .nest[["b", "a"]] works for multiple fields.""" arrays = [