Skip to content

Commit

Permalink
fix: fix broken multiindex loc cases (#467)
Browse files Browse the repository at this point in the history
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code!  That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
  • Loading branch information
TrevorBergeron committed Mar 20, 2024
1 parent 3189bf1 commit 7bd6820
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 105 deletions.
157 changes: 57 additions & 100 deletions bigframes/core/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

import typing
from typing import List, Tuple, Union
from typing import Tuple, Union

import ibis
import pandas as pd
Expand Down Expand Up @@ -147,19 +147,22 @@ def __getitem__(
...

def __getitem__(self, key):
# TODO(swast): If the DataFrame has a MultiIndex, we'll need to
# disambiguate this from a single row selection.
# TODO(tbergeron): Pandas will try both splitting 2-tuple into row, index or as 2-part
# row key. We must choose one, so bias towards treating as multi-part row label
if isinstance(key, tuple) and len(key) == 2:
df = typing.cast(
bigframes.dataframe.DataFrame,
_loc_getitem_series_or_dataframe(self._dataframe, key[0]),
)
is_row_multi_index = self._dataframe.index.nlevels > 1
is_first_item_tuple = isinstance(key[0], tuple)
if not is_row_multi_index or is_first_item_tuple:
df = typing.cast(
bigframes.dataframe.DataFrame,
_loc_getitem_series_or_dataframe(self._dataframe, key[0]),
)

columns = key[1]
if isinstance(columns, pd.Series) and columns.dtype == "bool":
columns = df.columns[columns]
columns = key[1]
if isinstance(columns, pd.Series) and columns.dtype == "bool":
columns = df.columns[columns]

return df[columns]
return df[columns]

return typing.cast(
bigframes.dataframe.DataFrame,
Expand Down Expand Up @@ -283,94 +286,40 @@ def _loc_getitem_series_or_dataframe(
pd.Series,
bigframes.core.scalar.Scalar,
]:
if isinstance(key, bigframes.series.Series) and key.dtype == "boolean":
return series_or_dataframe[key]
elif isinstance(key, bigframes.series.Series):
temp_name = guid.generate_guid(prefix="temp_series_name_")
if len(series_or_dataframe.index.names) > 1:
temp_name = series_or_dataframe.index.names[0]
key = key.rename(temp_name)
keys_df = key.to_frame()
keys_df = keys_df.set_index(temp_name, drop=True)
return _perform_loc_list_join(series_or_dataframe, keys_df)
elif isinstance(key, bigframes.core.indexes.Index):
block = key._block
block = block.select_columns(())
keys_df = bigframes.dataframe.DataFrame(block)
return _perform_loc_list_join(series_or_dataframe, keys_df)
elif pd.api.types.is_list_like(key):
key = typing.cast(List, key)
if len(key) == 0:
return typing.cast(
Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
series_or_dataframe.iloc[0:0],
)
if pd.api.types.is_list_like(key[0]):
original_index_names = series_or_dataframe.index.names
num_index_cols = len(original_index_names)

entry_col_count_correct = [len(entry) == num_index_cols for entry in key]
if not all(entry_col_count_correct):
# pandas usually throws TypeError in these cases- tuple causes IndexError, but that
# seems like unintended behavior
raise TypeError(
"All entries must be of equal length when indexing by list of listlikes"
)
temporary_index_names = [
guid.generate_guid(prefix="temp_loc_index_")
for _ in range(len(original_index_names))
]
index_cols_dict = {}
for i in range(num_index_cols):
index_name = temporary_index_names[i]
values = [entry[i] for entry in key]
index_cols_dict[index_name] = values
keys_df = bigframes.dataframe.DataFrame(
index_cols_dict, session=series_or_dataframe._get_block().expr.session
)
keys_df = keys_df.set_index(temporary_index_names, drop=True)
keys_df = keys_df.rename_axis(original_index_names)
else:
# We can't upload a DataFrame with None as the column name, so set it
# an arbitrary string.
index_name = series_or_dataframe.index.name
index_name_is_none = index_name is None
if index_name_is_none:
index_name = "unnamed_col"
keys_df = bigframes.dataframe.DataFrame(
{index_name: key},
session=series_or_dataframe._get_block().expr.session,
)
keys_df = keys_df.set_index(index_name, drop=True)
if index_name_is_none:
keys_df.index.name = None
return _perform_loc_list_join(series_or_dataframe, keys_df)
elif isinstance(key, slice):
if isinstance(key, slice):
if (key.start is None) and (key.stop is None) and (key.step is None):
return series_or_dataframe.copy()
raise NotImplementedError(
f"loc does not yet support indexing with a slice. {constants.FEEDBACK_LINK}"
)
elif callable(key):
if callable(key):
raise NotImplementedError(
f"loc does not yet support indexing with a callable. {constants.FEEDBACK_LINK}"
)
elif pd.api.types.is_scalar(key):
index_name = "unnamed_col"
keys_df = bigframes.dataframe.DataFrame(
{index_name: [key]}, session=series_or_dataframe._get_block().expr.session
)
keys_df = keys_df.set_index(index_name, drop=True)
keys_df.index.name = None
result = _perform_loc_list_join(series_or_dataframe, keys_df)
pandas_result = result.to_pandas()
# although loc[scalar_key] returns multiple results when scalar_key
# is not unique, we download the results here and return the computed
# individual result (as a scalar or pandas series) when the key is unique,
# since we expect unique index keys to be more common. loc[[scalar_key]]
# can be used to retrieve one-item DataFrames or Series.
if len(pandas_result) == 1:
return pandas_result.iloc[0]
elif isinstance(key, bigframes.series.Series) and key.dtype == "boolean":
return series_or_dataframe[key]
elif (
isinstance(key, bigframes.series.Series)
or isinstance(key, indexes.Index)
or (pd.api.types.is_list_like(key) and not isinstance(key, tuple))
):
index = indexes.Index(key, session=series_or_dataframe._session)
index.names = series_or_dataframe.index.names[: index.nlevels]
return _perform_loc_list_join(series_or_dataframe, index)
elif pd.api.types.is_scalar(key) or isinstance(key, tuple):
index = indexes.Index([key], session=series_or_dataframe._session)
index.names = series_or_dataframe.index.names[: index.nlevels]
result = _perform_loc_list_join(series_or_dataframe, index, drop_levels=True)

if index.nlevels == series_or_dataframe.index.nlevels:
pandas_result = result.to_pandas()
# although loc[scalar_key] returns multiple results when scalar_key
# is not unique, we download the results here and return the computed
# individual result (as a scalar or pandas series) when the key is unique,
# since we expect unique index keys to be more common. loc[[scalar_key]]
# can be used to retrieve one-item DataFrames or Series.
if len(pandas_result) == 1:
return pandas_result.iloc[0]
# when the key is not unique, we return a bigframes data type
# as usual for methods that return dataframes/series
return result
Expand All @@ -385,39 +334,47 @@ def _loc_getitem_series_or_dataframe(
@typing.overload
def _perform_loc_list_join(
series_or_dataframe: bigframes.series.Series,
keys_df: bigframes.dataframe.DataFrame,
keys_index: indexes.Index,
drop_levels: bool = False,
) -> bigframes.series.Series:
...


@typing.overload
def _perform_loc_list_join(
series_or_dataframe: bigframes.dataframe.DataFrame,
keys_df: bigframes.dataframe.DataFrame,
keys_index: indexes.Index,
drop_levels: bool = False,
) -> bigframes.dataframe.DataFrame:
...


def _perform_loc_list_join(
series_or_dataframe: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
keys_df: bigframes.dataframe.DataFrame,
keys_index: indexes.Index,
drop_levels: bool = False,
) -> Union[bigframes.series.Series, bigframes.dataframe.DataFrame]:
# right join based on the old index so that the matching rows from the user's
# original dataframe will be duplicated and reordered appropriately
original_index_names = series_or_dataframe.index.names
if isinstance(series_or_dataframe, bigframes.series.Series):
original_name = series_or_dataframe.name
name = series_or_dataframe.name if series_or_dataframe.name is not None else "0"
result = typing.cast(
bigframes.series.Series,
series_or_dataframe.to_frame()._perform_join_by_index(keys_df, how="right")[
name
],
series_or_dataframe.to_frame()._perform_join_by_index(
keys_index, how="right"
)[name],
)
result = result.rename(original_name)
else:
result = series_or_dataframe._perform_join_by_index(keys_df, how="right") # type: ignore
result = result.rename_axis(original_index_names)
result = series_or_dataframe._perform_join_by_index(keys_index, how="right") # type: ignore

if drop_levels and series_or_dataframe.index.nlevels > keys_index.nlevels:
# drop common levels
levels_to_drop = [
name for name in series_or_dataframe.index.names if name in keys_index.names
]
result = result.droplevel(levels_to_drop) # type: ignore
return result


Expand Down
3 changes: 2 additions & 1 deletion bigframes/core/indexes/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
dtype=None,
*,
name=None,
session=None,
):
import bigframes.dataframe as df
import bigframes.series as series
Expand All @@ -75,7 +76,7 @@ def __init__(
else:
pd_index = pandas.Index(data=data, dtype=dtype, name=name)
pd_df = pandas.DataFrame(index=pd_index)
block = df.DataFrame(pd_df)._block
block = df.DataFrame(pd_df, session=session)._block
self._query_job = None
self._block: blocks.Block = block

Expand Down
4 changes: 3 additions & 1 deletion bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2310,7 +2310,9 @@ def join(

return left._perform_join_by_index(right, how=how)

def _perform_join_by_index(self, other: DataFrame, *, how: str = "left"):
def _perform_join_by_index(
self, other: Union[DataFrame, indexes.Index], *, how: str = "left"
):
block, _ = self._block.join(other._block, how=how, block_identity_join=True)
return DataFrame(block)

Expand Down
4 changes: 4 additions & 0 deletions bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ def T(self) -> Series:
def _info_axis(self) -> indexes.Index:
return self.index

@property
def _session(self) -> bigframes.Session:
return self._get_block().expr.session

def transpose(self) -> Series:
return self

Expand Down
25 changes: 22 additions & 3 deletions tests/system/small/test_multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,34 @@ def test_concat_multi_indices_ignore_index(scalars_df_index, scalars_pandas_df_i
pandas.testing.assert_frame_equal(bf_result.to_pandas(), pd_result)


def test_multi_index_loc(scalars_df_index, scalars_pandas_df_index):
@pytest.mark.parametrize(
("key"),
[
(2),
([2, 0]),
([(2, "capitalize, This "), (-2345, "Hello, World!")]),
],
)
def test_multi_index_loc_multi_row(scalars_df_index, scalars_pandas_df_index, key):
bf_result = (
scalars_df_index.set_index(["int64_too", "bool_col"]).loc[[2, 0]].to_pandas()
scalars_df_index.set_index(["int64_too", "string_col"]).loc[key].to_pandas()
)
pd_result = scalars_pandas_df_index.set_index(["int64_too", "bool_col"]).loc[[2, 0]]
pd_result = scalars_pandas_df_index.set_index(["int64_too", "string_col"]).loc[key]

pandas.testing.assert_frame_equal(bf_result, pd_result)


def test_multi_index_loc_single_row(scalars_df_index, scalars_pandas_df_index):
bf_result = scalars_df_index.set_index(["int64_too", "string_col"]).loc[
(2, "capitalize, This ")
]
pd_result = scalars_pandas_df_index.set_index(["int64_too", "string_col"]).loc[
(2, "capitalize, This ")
]

pandas.testing.assert_series_equal(bf_result, pd_result)


def test_multi_index_getitem_bool(scalars_df_index, scalars_pandas_df_index):
bf_frame = scalars_df_index.set_index(["int64_too", "bool_col"])
pd_frame = scalars_pandas_df_index.set_index(["int64_too", "bool_col"])
Expand Down

0 comments on commit 7bd6820

Please sign in to comment.