[SPARK-19454][PYTHON][SQL] DataFrame.replace improvements #16793

Closed · wants to merge 5 commits
81 changes: 56 additions & 25 deletions python/pyspark/sql/dataframe.py
@@ -26,6 +26,8 @@
 else:
     from itertools import imap as map
 
+import warnings
+
 from pyspark import copy_func, since
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
@@ -1268,7 +1270,7 @@ def fillna(self, value, subset=None):
         return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
 
     @since(1.4)
-    def replace(self, to_replace, value, subset=None):
+    def replace(self, to_replace, value=None, subset=None):
         """Returns a new :class:`DataFrame` replacing a value with another value.
         :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
         aliases of each other.
@@ -1307,43 +1309,72 @@ def replace(self, to_replace, value, subset=None):
         |null| null|null|
         +----+------+----+
         """
-        if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)):
+        # Helper functions
+        def all_of(types):
Review comment (Contributor): Maybe give this a docstring to clarify what all_of does; even though it's not user-facing, it's better to have a docstring than not.

"""Given a type or tuple of types and a sequence of xs
check if each x is instance of type(s)

>>> all_of(bool)([True, False])
True
>>> all_of(basestring)(["a", 1])
False
"""
def all_of_(xs):
return all(isinstance(x, types) for x in xs)
return all_of_

all_of_bool = all_of(bool)
all_of_str = all_of(basestring)
all_of_numeric = all_of((float, int, long))

# Validate input types
valid_types = (bool, float, int, long, basestring, list, tuple)
if not isinstance(to_replace, valid_types + (dict, )):
raise ValueError(
"to_replace should be a float, int, long, string, list, tuple, or dict")
"to_replace should be a float, int, long, string, list, tuple, or dict. "
"Got {0}".format(type(to_replace)))

if not isinstance(value, (float, int, long, basestring, list, tuple)):
raise ValueError("value should be a float, int, long, string, list, or tuple")
if not isinstance(value, valid_types) and not isinstance(to_replace, dict):
raise ValueError("If to_replace is not a dict, value should be "
"a float, int, long, string, list, or tuple. "
"Got {0}".format(type(value)))

if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
if len(to_replace) != len(value):
raise ValueError("to_replace and value lists should be of the same length. "
"Got {0} and {1}".format(len(to_replace), len(value)))

rep_dict = dict()
if not (subset is None or isinstance(subset, (list, tuple, basestring))):
raise ValueError("subset should be a list or tuple of column names, "
"column name or None. Got {0}".format(type(subset)))

# Reshape input arguments if necessary
if isinstance(to_replace, (float, int, long, basestring)):
to_replace = [to_replace]

if isinstance(to_replace, tuple):
to_replace = list(to_replace)
if isinstance(value, (float, int, long, basestring)):
value = [value for _ in range(len(to_replace))]

if isinstance(value, tuple):
value = list(value)

if isinstance(to_replace, list) and isinstance(value, list):
if len(to_replace) != len(value):
raise ValueError("to_replace and value lists should be of the same length")
rep_dict = dict(zip(to_replace, value))
elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)):
rep_dict = dict([(tr, value) for tr in to_replace])
elif isinstance(to_replace, dict):
if isinstance(to_replace, dict):
rep_dict = to_replace
if value is not None:
warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
else:
rep_dict = dict(zip(to_replace, value))

if subset is None:
return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
elif isinstance(subset, basestring):
if isinstance(subset, basestring):
subset = [subset]

if not isinstance(subset, (list, tuple)):
raise ValueError("subset should be a list or tuple of column names")
# Verify we were not passed in mixed type generics."
if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values())
for all_of_type in [all_of_bool, all_of_str, all_of_numeric]):
raise ValueError("Mixed type replacements are not supported")

return DataFrame(
self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
if subset is None:
return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
else:
return DataFrame(
self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)

     @since(2.0)
     def approxQuantile(self, col, probabilities, relativeError):
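
For context, here is a minimal usage sketch of the call shapes this change enables. This is an editor's illustration rather than code from the PR; the Spark session, column names, and data are assumed:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(u'Alice', 10, 80.1)], ['name', 'age', 'height'])

    # Unchanged call shape: explicit to_replace and value.
    df.replace(10, 20).first()                # Row(name=u'Alice', age=20, height=80.1)

    # New in this PR: a dict alone, now that value defaults to None.
    df.replace({10: 20, 80.1: 90.5}).first()  # Row(name=u'Alice', age=20, height=90.5)

    # A dict plus a non-None value now emits a warning and ignores the value.
    df.replace({u'Alice': u'Bob'}, 1).first() # Row(name=u'Bob', age=10, height=80.1)

    # Mixed-type replacements, e.g. df.replace({u'Alice': 10}), raise
    # ValueError("Mixed type replacements are not supported").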
72 changes: 72 additions & 0 deletions python/pyspark/sql/tests.py
@@ -1591,6 +1591,78 @@ def test_replace(self):
         self.assertEqual(row.age, 10)
         self.assertEqual(row.height, None)
 
+        # replace with lists
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first()
+        self.assertTupleEqual(row, (u'Ann', 10, 80.1))
+
+        # replace with dict
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
+        self.assertTupleEqual(row, (u'Alice', 11, 80.1))
Review comment (Contributor): This is the only test of "new" functionality (excluding error cases), correct?

Reply (@zero323, Member and Author, Feb 12, 2017): These tests are mostly a side effect of discussions related to #16792. Right now test coverage is low, and we depend on a certain behavior of Py4j and the Scala counterpart. Also, I wanted to be sure that all the expected types are still accepted after the changes I've made. So maybe not necessary, but I will argue it is a good idea to have these.

Reply (Contributor): I think (and I could be wrong) that @nchammas was suggesting it might make sense to have some more tests with dict, not that the other additional new tests are bad.
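
Following up on that suggestion, one more dict-focused case might look like the sketch below. This is an editor's illustration, not part of the PR, and it assumes the same schema fixture as the surrounding tests:

    # Hypothetical additional test: dict replacement restricted to one column via subset.
    row = self.spark.createDataFrame(
        [(u'Alice', 10, 80.1)], schema).replace({80.1: 90.5}, subset=['height']).first()
    self.assertTupleEqual(row, (u'Alice', 10, 90.5))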


+        # test backward compatibility with dummy value
+        dummy_value = 1
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first()
+        self.assertTupleEqual(row, (u'Bob', 10, 80.1))
+
+        # test dict with mixed numerics
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first()
+        self.assertTupleEqual(row, (u'Alice', -10, 90.5))
+
+        # replace with tuples
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first()
+        self.assertTupleEqual(row, (u'Bob', 10, 80.1))
+
+        # replace multiple columns
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.0))
+
+        # test for mixed numerics
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.5))
+
+        row = self.spark.createDataFrame(
+            [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
+        self.assertTupleEqual(row, (u'Alice', 20, 90.5))
+
+        # replace with boolean
+        row = (self
+               .spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
+               .selectExpr("name = 'Bob'", 'age <= 15')
+               .replace(False, True).first())
+        self.assertTupleEqual(row, (True, True))
+
+        # should fail if subset is not list, tuple or None
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first()
+
+        # should fail if to_replace and value have different lengths
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first()
+
+        # should fail when given an unexpected type
+        with self.assertRaises(ValueError):
+            from datetime import datetime
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first()
+
+        # should fail if provided mixed type replacements
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first()
+
+        with self.assertRaises(ValueError):
+            self.spark.createDataFrame(
+                [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()

     def test_capture_analysis_exception(self):
         self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
         self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))