From bf5535ebda8001854f4e0c289560e8f838b41bc4 Mon Sep 17 00:00:00 2001 From: Shobhit Singh Date: Tue, 8 Oct 2024 17:38:56 +0000 Subject: [PATCH] fix: make `explode` respect the index labels --- bigframes/core/blocks.py | 2 +- tests/system/small/test_multiindex.py | 20 +++++++++++++++++++- tests/system/small/test_series.py | 22 ++++++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index 2b3734edd5..be3aaccd08 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -1390,7 +1390,7 @@ def explode( expr, column_labels=self.column_labels, index_columns=self.index_columns, - index_labels=self.column_labels.names, + index_labels=self._index_labels, ) def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.UnaryAggregateOp]: diff --git a/tests/system/small/test_multiindex.py b/tests/system/small/test_multiindex.py index ab2a9c19b8..cab74f617d 100644 --- a/tests/system/small/test_multiindex.py +++ b/tests/system/small/test_multiindex.py @@ -1178,7 +1178,7 @@ def test_column_multi_index_dot_not_supported(): bf1 @ bf2 -def test_explode_w_multi_index(): +def test_explode_w_column_multi_index(): data = [[[1, 1], np.nan, [3, 3]], [[2], [5], []]] multi_level_columns = pandas.MultiIndex.from_arrays( [["col0", "col0", "col1"], ["col00", "col01", "col11"]] @@ -1197,6 +1197,24 @@ def test_explode_w_multi_index(): ) +def test_explode_w_multi_index(): + data = [[[1, 1], np.nan, [3, 3]], [[2], [5], []]] + columns = ["col00", "col01", "col11"] + multi_index = pandas.MultiIndex.from_frame( + pandas.DataFrame({"idx0": [5, 1], "idx1": ["z", "x"]}) + ) + + df = bpd.DataFrame(data, index=multi_index, columns=columns) + pd_df = df.to_pandas() + + pandas.testing.assert_frame_equal( + df.explode("col00").to_pandas(), + pd_df.explode("col00"), + check_dtype=False, + check_index_type=False, + ) + + def test_column_multi_index_w_na_stack(scalars_df_index, scalars_pandas_df_index): columns = ["int64_too", "int64_col", "rowindex_2"] level1 = pandas.Index(["b", "c", "d"]) diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py index 624e287f8d..f1c60664a1 100644 --- a/tests/system/small/test_series.py +++ b/tests/system/small/test_series.py @@ -3852,6 +3852,28 @@ def test_series_explode(data): pytest.param([5, 1, 3, 2], False, id="ignore_unordered_index"), pytest.param(["z", "x", "a", "b"], True, id="str_index"), pytest.param(["z", "x", "a", "b"], False, id="ignore_str_index"), + pytest.param( + pd.Index(["z", "x", "a", "b"], name="idx"), True, id="str_named_index" + ), + pytest.param( + pd.Index(["z", "x", "a", "b"], name="idx"), + False, + id="ignore_str_named_index", + ), + pytest.param( + pd.MultiIndex.from_frame( + pd.DataFrame({"idx0": [5, 1, 3, 2], "idx1": ["z", "x", "a", "b"]}) + ), + True, + id="multi_index", + ), + pytest.param( + pd.MultiIndex.from_frame( + pd.DataFrame({"idx0": [5, 1, 3, 2], "idx1": ["z", "x", "a", "b"]}) + ), + False, + id="ignore_multi_index", + ), ], ) def test_series_explode_w_index(index, ignore_index):