Skip to content

Commit

Permalink
Support Series derived from a different dataframe as row_sel. (#1155)
Browse files Browse the repository at this point in the history
  • Loading branch information
ueshin authored and HyukjinKwon committed Dec 30, 2019
1 parent b3babf9 commit b3e9682
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
13 changes: 12 additions & 1 deletion databricks/koalas/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def __getitem__(self, key):
raise SparkPandasIndexingError('Too many indexers')
key = key[0]

if isinstance(key, Series) and key._kdf is not self._kdf_or_kser._kdf:
kdf = self._kdf_or_kser.to_frame()
kdf['__temp_col__'] = key
return type(self)(kdf[self._kdf_or_kser.name])[kdf['__temp_col__']]

cond, limit = self._select_rows(key)
if cond is None and limit is None:
return self._kdf_or_kser
Expand All @@ -160,6 +165,12 @@ def __getitem__(self, key):
rows_sel = key
cols_sel = None

if isinstance(rows_sel, Series) and rows_sel._kdf is not self._kdf_or_kser:
kdf = self._kdf_or_kser.copy()
kdf['__temp_col__'] = rows_sel
return type(self)(kdf)[kdf['__temp_col__'],
cols_sel][list(self._kdf_or_kser.columns)]

cond, limit = self._select_rows(rows_sel)
column_index, column_scols, returns_series = self._select_cols(cols_sel)

Expand Down Expand Up @@ -720,7 +731,7 @@ def _select_rows(self, rows_sel):
return None, rows_sel.stop
else:
ILocIndexer._raiseNotImplemented(".iloc requires numeric slice or conditional "
"boolean Index, got {}".format(rows_sel))
"boolean Index, got {}".format(type(rows_sel)))

def _select_cols(self, cols_sel):
from databricks.koalas.series import Series
Expand Down
14 changes: 1 addition & 13 deletions databricks/koalas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4271,19 +4271,7 @@ def __len__(self):

def __getitem__(self, key):
if isinstance(key, Series) and isinstance(key.spark_type, BooleanType):
should_try_ops_on_diff_frame = key._kdf is not self._kdf

if should_try_ops_on_diff_frame:
kdf = self.to_frame()
kdf["__temp_col__"] = key
sdf = kdf._sdf.filter(F.col("__temp_col__")).drop("__temp_col__")
return _col(ks.DataFrame(_InternalFrame(
sdf=sdf,
index_map=self._internal.index_map,
column_index=self._internal.column_index,
column_index_names=self._internal.column_index_names)))
else:
return _col(DataFrame(self._internal.copy(sdf=self._kdf._sdf.filter(key._scol))))
return self.loc[key]

if not isinstance(key, tuple):
key = (key,)
Expand Down
34 changes: 28 additions & 6 deletions databricks/koalas/tests/test_ops_on_diff_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,16 +298,38 @@ def test_arithmetic_chain(self):
(pser1 + pser2 * pser3).sort_index(), almost=True)

def test_getitem_boolean_series(self):
pdf1 = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B': [100, 200, 300, 400, 500]})
pdf2 = pd.DataFrame({'A': [0, -1, -2, -3, -4], 'B': [-100, -200, -300, -400, -500]})
pdf1 = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B': [100, 200, 300, 400, 500]},
index=[20, 10, 30, 0, 50])
pdf2 = pd.DataFrame({'A': [0, -1, -2, -3, -4], 'B': [-100, -200, -300, -400, -500]},
index=[0, 30, 10, 20, 50])
kdf1 = ks.from_pandas(pdf1)
kdf2 = ks.from_pandas(pdf2)

self.assert_eq(pdf1.A[pdf2.A > 100],
kdf1.A[kdf2.A > 100].sort_index())
self.assert_eq(pdf1[pdf2.A > -3].sort_index(),
kdf1[kdf2.A > -3].sort_index())

self.assert_eq(pdf1.A[pdf2.A > -3].sort_index(),
kdf1.A[kdf2.A > -3].sort_index())

self.assert_eq((pdf1.A + 1)[pdf2.A > -3].sort_index(),
(kdf1.A + 1)[kdf2.A > -3].sort_index())

def test_loc_getitem_boolean_series(self):
pdf1 = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B': [100, 200, 300, 400, 500]},
index=[20, 10, 30, 0, 50])
pdf2 = pd.DataFrame({'A': [0, -1, -2, -3, -4], 'B': [-100, -200, -300, -400, -500]},
index=[20, 10, 30, 0, 50])
kdf1 = ks.from_pandas(pdf1)
kdf2 = ks.from_pandas(pdf2)

self.assert_eq(pdf1.loc[pdf2.A > -3].sort_index(),
kdf1.loc[kdf2.A > -3].sort_index())

self.assert_eq(pdf1.A.loc[pdf2.A > -3].sort_index(),
kdf1.A.loc[kdf2.A > -3].sort_index())

self.assert_eq((pdf1.A + 1)[pdf2.A > 100],
(kdf1.A + 1)[kdf2.A > 100].sort_index())
self.assert_eq((pdf1.A + 1).loc[pdf2.A > -3].sort_index(),
(kdf1.A + 1).loc[kdf2.A > -3].sort_index())

def test_bitwise(self):
pser1 = pd.Series([True, False, True, False, np.nan, np.nan, True, False, np.nan])
Expand Down

0 comments on commit b3e9682

Please sign in to comment.