Skip to content

Commit

Permalink
Fix .nest.get_*_series
Browse files Browse the repository at this point in the history
  • Loading branch information
hombit committed Sep 6, 2024
1 parent 0d8c2f3 commit 8fca944
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 10 deletions.
29 changes: 19 additions & 10 deletions src/nested_pandas/series/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,18 @@ def get_flat_series(self, field: str) -> pd.Series:
The flat-array field.
"""

# TODO: we should make proper missed values handling here
flat_chunks = []
for nested_chunk in self._series.array._chunked_array.iterchunks():
struct_array = cast(pa.StructArray, nested_chunk)
list_array = cast(pa.ListArray, struct_array.field(field))
flat_array = list_array.flatten()
flat_chunks.append(flat_array)

flat_chunked_array = pa.chunked_array(flat_chunks)

struct_array = cast(pa.StructArray, pa.array(self._series))
list_array = cast(pa.ListArray, struct_array.field(field))
flat_array = list_array.flatten()
return pd.Series(
flat_array,
dtype=pd.ArrowDtype(flat_array.type),
flat_chunked_array,
dtype=pd.ArrowDtype(flat_chunked_array.type),
index=self.get_flat_index(),
name=field,
copy=False,
Expand All @@ -277,11 +281,16 @@ def get_list_series(self, field: str) -> pd.Series:
pd.Series
The list-array field.
"""
struct_array = cast(pa.StructArray, pa.array(self._series))
list_array = struct_array.field(field)

list_chunks = []
for nested_chunk in self._series.array._chunked_array.iterchunks():
struct_array = cast(pa.StructArray, nested_chunk)
list_array = struct_array.field(field)
list_chunks.append(list_array)
list_chunked_array = pa.chunked_array(list_chunks)
return pd.Series(
list_array,
dtype=pd.ArrowDtype(list_array.type),
list_chunked_array,
dtype=pd.ArrowDtype(list_chunked_array.type),
index=self._series.index,
name=field,
copy=False,
Expand Down
53 changes: 53 additions & 0 deletions tests/nested_pandas/series/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,32 @@ def test_get_list_series():
)


def test_get_list_series_multiple_chunks():
"""Test that .nest.get_list_series() works when underlying array is chunked"""
struct_array = pa.StructArray.from_arrays(
arrays=[
[np.array([1, 2, 3]), np.array([4, 5, 6])],
[np.array([6, 4, 2]), np.array([1, 2, 3])],
],
names=["a", "b"],
)
chunked_array = pa.chunked_array([struct_array] * 3)
series = pd.Series(chunked_array, dtype=NestedDtype(chunked_array.type), index=[5, 7, 9, 11, 13, 15])
assert series.array.num_chunks == 3

lists = series.nest.get_list_series("a")

assert_series_equal(
lists,
pd.Series(
data=[np.array([1, 2, 3]), np.array([4, 5, 6])] * 3,
dtype=pd.ArrowDtype(pa.list_(pa.int64())),
index=[5, 7, 9, 11, 13, 15],
name="a",
),
)


def test_get():
"""Test .nest.get() which is implemented by the base class"""
series = pack_seq(
Expand Down Expand Up @@ -588,6 +614,33 @@ def test___getitem___single_field():
)


def test___getitem___single_field_multiple_chunks():
"""Reproduces issue 142
https://github.com/lincc-frameworks/nested-pandas/issues/142
"""
struct_array = pa.StructArray.from_arrays(
arrays=[
[np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0])],
[np.array([4.0, 5.0, 6.0]), np.array([3.0, 4.0, 5.0])],
],
names=["a", "b"],
)
chunked_array = pa.chunked_array([struct_array] * 3)
series = pd.Series(chunked_array, dtype=NestedDtype(chunked_array.type), index=[0, 1, 2, 3, 4, 5])
assert series.array.num_chunks == 3

assert_series_equal(
series.nest["a"],
pd.Series(
np.array([1.0, 2.0, 3.0, 1.0, 2.0, 1.0] * 3),
dtype=pd.ArrowDtype(pa.float64()),
index=[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5],
name="a",
),
)


def test___getitem___multiple_fields():
"""Test that the .nest[["b", "a"]] works for multiple fields."""
arrays = [
Expand Down

0 comments on commit 8fca944

Please sign in to comment.