From a02e4ff65e0ed4d785dbadbb2bbabde5b5fb3f91 Mon Sep 17 00:00:00 2001 From: zero323 Date: Thu, 2 Feb 2017 01:16:05 +0100 Subject: [PATCH 1/5] DataFrame.replace improvements --- python/pyspark/sql/dataframe.py | 73 ++++++++++++++++++++++----------- python/pyspark/sql/tests.py | 61 +++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 10e42d0f9d322..b5c50844b404d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1268,7 +1268,7 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @since(1.4) - def replace(self, to_replace, value, subset=None): + def replace(self, to_replace, value=None, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. @@ -1307,43 +1307,66 @@ def replace(self, to_replace, value, subset=None): |null| null|null| +----+------+----+ """ - if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)): + # Helper functions + def all_of(types): + def all_of_(xs): + return all(isinstance(x, types) for x in xs) + return all_of_ + + all_of_bool = all_of(bool) + all_of_str = all_of(basestring) + all_of_numeric = all_of((float, int, long)) + + # Validate input types + valid_types = (bool, float, int, long, basestring, list, tuple) + if not isinstance(to_replace, valid_types + (dict, )): raise ValueError( - "to_replace should be a float, int, long, string, list, tuple, or dict") + "to_replace should be a float, int, long, string, list, tuple, or dict. " + "Got {0}".format(type(to_replace))) - if not isinstance(value, (float, int, long, basestring, list, tuple)): - raise ValueError("value should be a float, int, long, string, list, or tuple") + if (not isinstance(value, valid_types) and + not isinstance(to_replace, dict)): + raise ValueError("If to_replace is not a dict, value should be " + "a float, int, long, string, list, or tuple. " + "Got {0}".format(type(value))) + + if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)): + if len(to_replace) != len(value): + raise ValueError("to_replace and value lists should be of the same length. " + "Got {0} and {1}".format(len(to_replace), len(value))) - rep_dict = dict() + if not (subset is None or isinstance(subset, (list, tuple, basestring))): + raise ValueError("subset should be a list or tuple of column names, " + "column name or None. Got {0}".format(type(subset))) + # Reshape input arguments if necessary if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(to_replace, tuple): - to_replace = list(to_replace) + if isinstance(value, (float, int, long, basestring)): + value = [value for _ in range(len(to_replace))] - if isinstance(value, tuple): - value = list(value) - - if isinstance(to_replace, list) and isinstance(value, list): - if len(to_replace) != len(value): - raise ValueError("to_replace and value lists should be of the same length") - rep_dict = dict(zip(to_replace, value)) - elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): - rep_dict = dict([(tr, value) for tr in to_replace]) - elif isinstance(to_replace, dict): + if isinstance(to_replace, dict): rep_dict = to_replace + if value is not None: + warnings.warn("to_replace is a dict, but value is not None. " + "value will be ignored.") + else: + rep_dict = dict(zip(to_replace, value)) - if subset is None: - return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) - elif isinstance(subset, basestring): + if isinstance(subset, basestring): subset = [subset] - if not isinstance(subset, (list, tuple)): - raise ValueError("subset should be a list or tuple of column names") + # Check if we won't pass mixed type generics + if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values()) + for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): + raise ValueError("Mixed type replacements are not supported") - return DataFrame( - self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + if subset is None: + return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) + else: + return DataFrame( + self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) @since(2.0) def approxQuantile(self, col, probabilities, relativeError): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2fea4ac41f0d3..eb1ea503b6e69 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1591,6 +1591,67 @@ def test_replace(self): self.assertEqual(row.age, 10) self.assertEqual(row.height, None) + # replace with lists + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first() + self.assertTupleEqual(row, (u'Ann', 10, 80.1)) + + # replace with dict + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first() + self.assertTupleEqual(row, (u'Alice', 11, 80.1)) + + # replace with tuples + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # replace multiple columns + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.0)) + + # test for mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + # replace with boolean + row = (self + .spark.createDataFrame([(u'Alice', 10, 80.0)], schema) + .selectExpr("name = 'Bob'", 'age <= 15') + .replace(False, True).first()) + self.assertTupleEqual(row, (True, True)) + + # should fail if subset is not list, tuple or None + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first() + + # should fail if to_replace and value have different length + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first() + + # should fail if when received unexpected type + with self.assertRaises(ValueError): + from datetime import datetime + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first() + + # should fail if provided mixed type replacements + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first() + + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first() + def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) From db8f4c924d579bea186ac30ca6b0d06044a4be0a Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 27 Feb 2017 04:29:26 +0100 Subject: [PATCH 2/5] Improve internal docs for DataFrame.replace --- python/pyspark/sql/dataframe.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b5c50844b404d..9168ad20dbbcc 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1309,6 +1309,15 @@ def replace(self, to_replace, value=None, subset=None): """ # Helper functions def all_of(types): + """Given a type or tuple of types + and sequence of xs check if each x + is instance of type(s) + + >>> all_of(bool)([True, False]) + True + >>> all_of(basestring)(["a", 1]) + False + """ def all_of_(xs): return all(isinstance(x, types) for x in xs) return all_of_ @@ -1357,7 +1366,7 @@ def all_of_(xs): if isinstance(subset, basestring): subset = [subset] - # Check if we won't pass mixed type generics + # Verify we were not passed in mixed type generics." if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values()) for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): raise ValueError("Mixed type replacements are not supported") From e014867271099b8450369fd591fd765c530b083d Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 27 Feb 2017 04:41:17 +0100 Subject: [PATCH 3/5] Add more tests for replace with dict --- python/pyspark/sql/tests.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index eb1ea503b6e69..276a84a4ca1a8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1601,6 +1601,17 @@ def test_replace(self): [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first() self.assertTupleEqual(row, (u'Alice', 11, 80.1)) + # test backward compatibility with dummy value + dummy_value = 1 + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # test dict with mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', -10, 90.5)) + # replace with tuples row = self.spark.createDataFrame( [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first() From 17e68205ef639893902c65c0394c8aa4406191be Mon Sep 17 00:00:00 2001 From: zero323 Date: Mon, 27 Feb 2017 18:32:13 +0100 Subject: [PATCH 4/5] Add missing warnings import --- python/pyspark/sql/dataframe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9168ad20dbbcc..da75193aba30d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -26,6 +26,8 @@ else: from itertools import imap as map +import warnings + from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer From 03303dfba528f78f7c9118e8a98cca49371993f7 Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 8 Mar 2017 00:45:25 +0100 Subject: [PATCH 5/5] Adjust formatting --- python/pyspark/sql/dataframe.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index da75193aba30d..807fe38bda4f8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1311,9 +1311,8 @@ def replace(self, to_replace, value=None, subset=None): """ # Helper functions def all_of(types): - """Given a type or tuple of types - and sequence of xs check if each x - is instance of type(s) + """Given a type or tuple of types and a sequence of xs + check if each x is instance of type(s) >>> all_of(bool)([True, False]) True @@ -1335,8 +1334,7 @@ def all_of_(xs): "to_replace should be a float, int, long, string, list, tuple, or dict. " "Got {0}".format(type(to_replace))) - if (not isinstance(value, valid_types) and - not isinstance(to_replace, dict)): + if not isinstance(value, valid_types) and not isinstance(to_replace, dict): raise ValueError("If to_replace is not a dict, value should be " "a float, int, long, string, list, or tuple. " "Got {0}".format(type(value))) @@ -1360,8 +1358,7 @@ def all_of_(xs): if isinstance(to_replace, dict): rep_dict = to_replace if value is not None: - warnings.warn("to_replace is a dict, but value is not None. " - "value will be ignored.") + warnings.warn("to_replace is a dict and value is not None. value will be ignored.") else: rep_dict = dict(zip(to_replace, value))