diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index 08584fe0d3..4d63d3d0e4 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -9129,7 +9129,7 @@ def rename_axis( columns: Optional[Any] = None, axis: Optional[Union[int, str]] = 0, inplace: Optional[bool] = False, - ): + ) -> Optional["DataFrame"]: """ Set the name of the axis for the index or columns. @@ -9186,6 +9186,7 @@ def rename_axis( dog 4 0 cat 4 0 monkey 2 2 + >>> df = df.rename_axis("animal").sort_index() >>> df # doctest: +NORMALIZE_WHITESPACE num_legs num_arms @@ -9193,6 +9194,7 @@ def rename_axis( cat 4 0 dog 4 0 monkey 2 2 + >>> df = df.rename_axis("limbs", axis="columns").sort_index() >>> df # doctest: +NORMALIZE_WHITESPACE limbs num_legs num_arms @@ -9265,8 +9267,6 @@ def gen_names(v, curnames): index = mapper elif axis == 1: columns = mapper - else: - raise ValueError("No axis named %s for object type %s." % (axis, type(axis))) column_label_names = ( gen_names(columns, self.columns.names) @@ -9278,19 +9278,20 @@ def gen_names(v, curnames): ) index_map = OrderedDict(zip(self._internal.index_spark_column_names, index_names)) + spark_frame = self._internal.resolved_copy.spark_frame internal = InternalFrame( - self._internal.spark_frame, + spark_frame=spark_frame, index_map=index_map, column_labels=self._internal.column_labels, data_spark_columns=[ - scol_for(self._internal.spark_frame, col) - for col in self._internal.data_spark_column_names + scol_for(spark_frame, col) for col in self._internal.data_spark_column_names ], column_label_names=column_label_names, ) if inplace: self._update_internal_frame(internal) + return None else: return DataFrame(internal) diff --git a/databricks/koalas/series.py b/databricks/koalas/series.py index 3488a3d946..ef1383f9f9 100644 --- a/databricks/koalas/series.py +++ b/databricks/koalas/series.py @@ -1117,7 +1117,7 @@ def rename(self, index=None, **kwargs): def rename_axis( self, mapper: Optional[Any] = None, index: Optional[Any] = None, inplace: bool = False - ): + ) -> Optional["Series"]: """ Set the name of the axis for the index or columns. diff --git a/databricks/koalas/tests/test_series.py b/databricks/koalas/tests/test_series.py index 837b63b093..3cc48eaa12 100644 --- a/databricks/koalas/tests/test_series.py +++ b/databricks/koalas/tests/test_series.py @@ -180,6 +180,11 @@ def test_rename_axis(self): pser.rename_axis("index2").sort_index(), kser.rename_axis("index2").sort_index(), ) + self.assert_eq( + (pser + 1).rename_axis("index2").sort_index(), + (kser + 1).rename_axis("index2").sort_index(), + ) + self.assertRaises(ValueError, lambda: kser.rename_axis(["index2", "index3"])) # index/columns parameters and dict_like/functions mappers introduced in pandas 0.24.0