diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index a911ecf714..3d7f807d7a 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -5261,37 +5261,24 @@ def replace( if len(value) != len(to_replace): raise ValueError("Length of to_replace and value must be same") - subset = self._internal.data_spark_column_names - - sdf = self._internal.resolved_copy.spark_frame - if ( - isinstance(to_replace, dict) - and value is None - and (not any(isinstance(i, dict) for i in to_replace.values())) + if isinstance(to_replace, dict) and ( + value is not None or all(isinstance(i, dict) for i in to_replace.values()) ): - sdf = sdf.replace(to_replace, value, subset) - elif isinstance(to_replace, dict): - for name, replacement in to_replace.items(): - if isinstance(name, str): - name = (name,) - df_column = self._internal.spark_column_name_for(name) - if isinstance(replacement, dict): - sdf = sdf.replace(replacement, subset=df_column) + + def op(kser): + if kser.name in to_replace: + return kser.replace(to_replace=to_replace[kser.name], value=value, regex=regex) else: - sdf = sdf.withColumn( - df_column, - F.when(scol_for(sdf, df_column) == replacement, value).otherwise( - scol_for(sdf, df_column) - ), - ) + return kser + else: - sdf = sdf.replace(to_replace, value, subset) + op = lambda kser: kser.replace(to_replace=to_replace, value=value, regex=regex) - internal = self._internal.with_new_sdf(sdf) + kdf = self._apply_series_op(op) if inplace: - self._internal = internal + self._internal = kdf._internal else: - return DataFrame(internal) + return kdf def clip(self, lower: Union[float, int] = None, upper: Union[float, int] = None) -> "DataFrame": """ diff --git a/databricks/koalas/tests/test_dataframe.py b/databricks/koalas/tests/test_dataframe.py index a12788f872..63b852833b 100644 --- a/databricks/koalas/tests/test_dataframe.py +++ b/databricks/koalas/tests/test_dataframe.py @@ -1577,6 +1577,7 @@ def test_replace(self): self.assert_eq(kdf.replace({"A": 0, "B": 5}, 100), pdf.replace({"A": 0, "B": 5}, 100)) self.assert_eq(kdf.replace({"A": {0: 100, 4: 400}}), pdf.replace({"A": {0: 100, 4: 400}})) + self.assert_eq(kdf.replace({"X": {0: 100, 4: 400}}), pdf.replace({"X": {0: 100, 4: 400}})) # multi-index columns columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C")]) @@ -1600,6 +1601,9 @@ def test_replace(self): self.assert_eq( kdf.replace({("X", "A"): {0: 100, 4: 400}}), pdf.replace({("X", "A"): {0: 100, 4: 400}}) ) + self.assert_eq( + kdf.replace({("X", "B"): {0: 100, 4: 400}}), pdf.replace({("X", "B"): {0: 100, 4: 400}}) + ) def test_update(self): # check base function