diff --git a/databricks/koalas/indexing.py b/databricks/koalas/indexing.py index 5ff219e6fb..00f3a4ec04 100644 --- a/databricks/koalas/indexing.py +++ b/databricks/koalas/indexing.py @@ -142,6 +142,11 @@ def __getitem__(self, key): raise SparkPandasIndexingError('Too many indexers') key = key[0] + if isinstance(key, Series) and key._kdf is not self._kdf_or_kser._kdf: + kdf = self._kdf_or_kser.to_frame() + kdf['__temp_col__'] = key + return type(self)(kdf[self._kdf_or_kser.name])[kdf['__temp_col__']] + cond, limit = self._select_rows(key) if cond is None and limit is None: return self._kdf_or_kser @@ -159,6 +164,12 @@ def __getitem__(self, key): rows_sel = key cols_sel = None + if isinstance(rows_sel, Series) and rows_sel._kdf is not self._kdf_or_kser: + kdf = self._kdf_or_kser.copy() + kdf['__temp_col__'] = rows_sel + return type(self)(kdf)[kdf['__temp_col__'], + cols_sel][list(self._kdf_or_kser.columns)] + cond, limit = self._select_rows(rows_sel) column_index, column_scols, returns_series = self._select_cols(cols_sel) @@ -719,7 +730,7 @@ def _select_rows(self, rows_sel): return None, rows_sel.stop else: ILocIndexer._raiseNotImplemented(".iloc requires numeric slice or conditional " - "boolean Index, got {}".format(rows_sel)) + "boolean Index, got {}".format(type(rows_sel))) def _select_cols(self, cols_sel): from databricks.koalas.series import Series diff --git a/databricks/koalas/series.py b/databricks/koalas/series.py index b38ac97207..2e05e20027 100644 --- a/databricks/koalas/series.py +++ b/databricks/koalas/series.py @@ -4278,19 +4278,7 @@ def __len__(self): def __getitem__(self, key): if isinstance(key, Series) and isinstance(key.spark_type, BooleanType): - should_try_ops_on_diff_frame = key._kdf is not self._kdf - - if should_try_ops_on_diff_frame: - kdf = self.to_frame() - kdf["__temp_col__"] = key - sdf = kdf._sdf.filter(F.col("__temp_col__")).drop("__temp_col__") - return _col(ks.DataFrame(_InternalFrame( - sdf=sdf, - index_map=self._internal.index_map, - column_index=self._internal.column_index, - column_index_names=self._internal.column_index_names))) - else: - return _col(DataFrame(self._internal.copy(sdf=self._kdf._sdf.filter(key._scol)))) + return self.loc[key] if not isinstance(key, tuple): key = (key,) diff --git a/databricks/koalas/tests/test_ops_on_diff_frames.py b/databricks/koalas/tests/test_ops_on_diff_frames.py index e83666458d..5d9ba418b2 100644 --- a/databricks/koalas/tests/test_ops_on_diff_frames.py +++ b/databricks/koalas/tests/test_ops_on_diff_frames.py @@ -298,16 +298,38 @@ def test_arithmetic_chain(self): (pser1 + pser2 * pser3).sort_index(), almost=True) def test_getitem_boolean_series(self): - pdf1 = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B': [100, 200, 300, 400, 500]}) - pdf2 = pd.DataFrame({'A': [0, -1, -2, -3, -4], 'B': [-100, -200, -300, -400, -500]}) + pdf1 = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B': [100, 200, 300, 400, 500]}, + index=[20, 10, 30, 0, 50]) + pdf2 = pd.DataFrame({'A': [0, -1, -2, -3, -4], 'B': [-100, -200, -300, -400, -500]}, + index=[0, 30, 10, 20, 50]) kdf1 = ks.from_pandas(pdf1) kdf2 = ks.from_pandas(pdf2) - self.assert_eq(pdf1.A[pdf2.A > 100], - kdf1.A[kdf2.A > 100].sort_index()) + self.assert_eq(pdf1[pdf2.A > -3].sort_index(), + kdf1[kdf2.A > -3].sort_index()) + + self.assert_eq(pdf1.A[pdf2.A > -3].sort_index(), + kdf1.A[kdf2.A > -3].sort_index()) + + self.assert_eq((pdf1.A + 1)[pdf2.A > -3].sort_index(), + (kdf1.A + 1)[kdf2.A > -3].sort_index()) + + def test_loc_getitem_boolean_series(self): + pdf1 = pd.DataFrame({'A': [0, 1, 2, 3, 4], 'B': [100, 200, 300, 400, 500]}, + index=[20, 10, 30, 0, 50]) + pdf2 = pd.DataFrame({'A': [0, -1, -2, -3, -4], 'B': [-100, -200, -300, -400, -500]}, + index=[20, 10, 30, 0, 50]) + kdf1 = ks.from_pandas(pdf1) + kdf2 = ks.from_pandas(pdf2) + + self.assert_eq(pdf1.loc[pdf2.A > -3].sort_index(), + kdf1.loc[kdf2.A > -3].sort_index()) + + self.assert_eq(pdf1.A.loc[pdf2.A > -3].sort_index(), + kdf1.A.loc[kdf2.A > -3].sort_index()) - self.assert_eq((pdf1.A + 1)[pdf2.A > 100], - (kdf1.A + 1)[kdf2.A > 100].sort_index()) + self.assert_eq((pdf1.A + 1).loc[pdf2.A > -3].sort_index(), + (kdf1.A + 1).loc[kdf2.A > -3].sort_index()) def test_bitwise(self): pser1 = pd.Series([True, False, True, False, np.nan, np.nan, True, False, np.nan])