From 4b6645505d3fb076939fff5884ec77f58528b1de Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Mon, 19 Aug 2019 15:52:26 -0700
Subject: [PATCH] Fix reindex.

Resolve the requested labels through the column index in
`_reindex_columns` so that reindexing works with multi-level columns:
labels must be tuples whose length matches the number of levels, existing
labels reuse their data column, and missing labels become NaN columns.

Always build the projected Spark DataFrame, so that `sdf` is bound even
when every requested label already exists, and cover that case with a test.
---
 databricks/koalas/frame.py                | 49 ++++++++++++++++-------
 databricks/koalas/tests/test_dataframe.py | 16 ++++++++
 2 files changed, 51 insertions(+), 14 deletions(-)

diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py
index 0a9336f8d0..e470569eca 100644
--- a/databricks/koalas/frame.py
+++ b/databricks/koalas/frame.py
@@ -5827,22 +5827,43 @@ def _reindex_index(self, index):
         return internal
 
     def _reindex_columns(self, columns):
-        label_columns = list(columns)
-        null_columns = [
-            F.lit(np.nan).alias(label_column) for label_column
-            in label_columns if label_column not in self.columns]
+        """
+        Build the internal frame whose data columns follow the labels in `columns`.
+
+        A label that already exists reuses its Spark column; any other label becomes a
+        NaN literal column named ``str(label)``.
+
+        Raises TypeError when the column index is multi-level and a label is not a tuple,
+        and ValueError when a label's length differs from the number of levels.
+        """
+        level = self._internal.column_index_level
+        if level > 1:
+            label_columns = list(columns)
+            for col in label_columns:
+                if not isinstance(col, tuple):
+                    raise TypeError('Expected tuple, got {}'.format(type(col)))
+        else:
+            label_columns = [(col,) for col in columns]
+        for col in label_columns:
+            if len(col) != level:
+                raise ValueError("shape (1,{}) doesn't match the shape (1,{})"
+                                 .format(len(col), level))
+        index_to_column = dict(zip(self._internal.column_index, self._internal.data_columns))
+        scols, columns, idx = [], [], []
+        for label in label_columns:
+            data_column = index_to_column.get(label)
+            if data_column is not None:
+                scols.append(self._internal.scol_for(data_column))
+                columns.append(data_column)
+            else:
+                scols.append(F.lit(np.nan).alias(str(label)))
+                columns.append(str(label))
+            idx.append(label)
 
-        # Concatenate all fields
-        sdf = self._sdf.select(
-            self._internal.index_scols +
-            list(map(self._internal.scol_for, self.columns)) +
-            null_columns)
+        # Always project, so that `sdf` is bound even when every requested label exists.
+        sdf = self._sdf.select(self._internal.index_scols + scols)
 
-        # Only select label_columns (with index columns)
-        sdf = sdf.select(self._internal.index_scols + [scol_for(sdf, col) for col in label_columns])
-        return self._internal.copy(
-            sdf=sdf,
-            data_columns=label_columns)
+        return self._internal.copy(sdf=sdf, data_columns=columns, column_index=idx)
 
     def melt(self, id_vars=None, value_vars=None, var_name='variable',
              value_name='value'):
diff --git a/databricks/koalas/tests/test_dataframe.py b/databricks/koalas/tests/test_dataframe.py
index d3066e581e..1dc315236d 100644
--- a/databricks/koalas/tests/test_dataframe.py
+++ b/databricks/koalas/tests/test_dataframe.py
@@ -1245,6 +1245,22 @@ def test_reindex(self):
         self.assertRaises(TypeError, lambda: kdf.reindex(index=['A', 'B', 'C'], axis=1))
         self.assertRaises(TypeError, lambda: kdf.reindex(index=123))
 
+        columns = pd.MultiIndex.from_tuples([('X', 'numbers')])
+        pdf.columns = columns
+        kdf.columns = columns
+
+        self.assert_eq(
+            pdf.reindex(columns=[('X', 'numbers'), ('Y', '2'), ('Y', '3')]).sort_index(),
+            kdf.reindex(columns=[('X', 'numbers'), ('Y', '2'), ('Y', '3')]).sort_index())
+
+        # Every requested label already exists, so no NaN column has to be added.
+        self.assert_eq(
+            pdf.reindex(columns=[('X', 'numbers')]).sort_index(),
+            kdf.reindex(columns=[('X', 'numbers')]).sort_index())
+
+        self.assertRaises(TypeError, lambda: kdf.reindex(columns=['X']))
+        self.assertRaises(ValueError, lambda: kdf.reindex(columns=[('X',)]))
+
     def test_rank(self):
         pdf = pd.DataFrame(data={'col1': [1, 2, 3, 1], 'col2': [3, 4, 3, 1]},
                            columns=['col1', 'col2'])