[SPARK-20685] Fix BatchPythonEvaluation bug in case of single UDF w/ repeated arg.

## What changes were proposed in this pull request?

There's a latent corner-case bug in PySpark UDF evaluation: executing a `BatchPythonEvaluation` containing a single multi-argument UDF in which _at least one argument value is repeated_ crashes at execution time with a confusing error.
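
A minimal reproduction sketch, mirroring the regression test added below (it assumes an already-running `SparkSession` named `spark`):

```python
# Reproduction sketch (assumes an existing SparkSession named `spark`):
# a single two-argument UDF called with the same value for both arguments.
from pyspark.sql.types import IntegerType

spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
spark.sql("SELECT add(1, 1)").first()  # crashed with a confusing error before this fix
```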

This problem was introduced in #12057: the code there has a fast path for handling a "batch UDF evaluation consisting of a single Python UDF", but that branch incorrectly assumes that a single UDF won't have repeated arguments. As a result it skips the code for unpacking arguments from the input row, whose schema may not match the UDF's inputs because repeated arguments are de-duplicated in the JVM before the UDF inputs are sent to Python.
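
To illustrate the failure mode, here is a hedged sketch; the row contents and offsets are illustrative assumptions based on the description above, not excerpts from the Spark source. For `add(1, 1)` the JVM sends a de-duplicated one-column row together with argument offsets that reference that column twice, so the fast path's naive unpacking breaks:

```python
# Sketch of the failure mode (values here are illustrative assumptions).
add = lambda x, y: x + y     # the registered two-argument UDF
row = (1,)                   # de-duplicated input row sent from the JVM
arg_offsets = [0, 0]         # both UDF arguments point at the same column

# Old fast path assumed the row's columns line up one-to-one with the UDF's args:
# add(*row) == add(1)  ->  TypeError: missing 1 required positional argument: 'y'

# Offset-based unpacking (what the "multiple UDFs" branch does) handles repetition fine.
assert add(*(row[o] for o in arg_offsets)) == 2
```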

The fix here is simply to remove this special-casing: it turns out that the code in the "multiple UDFs" branch happens to work for the single-UDF case as well, because Python treats `(x)` as equivalent to `x`, not as a single-argument tuple.
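
For instance (a small sketch with made-up UDFs, not code from the patch), the generated mapper string for a single UDF evaluates to a bare value rather than a one-element tuple, which is exactly what the JVM side expects:

```python
# The generated mapper wraps the calls in parentheses, so a single UDF yields a
# bare value while multiple UDFs yield a tuple (the UDFs below are made up).
single = eval("lambda a: (f0(a[0], a[0]))", {"f0": lambda x, y: x + y})
multi = eval("lambda a: (f0(a[0]), f1(a[1]))",
             {"f0": lambda x: x * 2, "f1": lambda x: x + 1})
assert single((1,)) == 2         # (x) is just x, not a 1-tuple
assert multi((3, 4)) == (6, 5)   # two calls produce a real tuple
```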

## How was this patch tested?

New regression test in the `pyspark.sql.tests` module (tested and confirmed that it fails before my fix).

Author: Josh Rosen <[email protected]>

Closes #17927 from JoshRosen/SPARK-20685.
JoshRosen authored and gatorsmile committed May 10, 2017
1 parent af8b6cc commit 8ddbc43
Showing 2 changed files with 19 additions and 16 deletions.
6 changes: 6 additions & 0 deletions python/pyspark/sql/tests.py
@@ -330,6 +330,12 @@ def test_chained_udf(self):
         [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)
 
+    def test_single_udf_with_repeated_argument(self):
+        # regression test for SPARK-20685
+        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
+        row = self.spark.sql("SELECT add(1, 1)").first()
+        self.assertEqual(tuple(row), (2, ))
+
     def test_multiple_udfs(self):
         self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
         [row] = self.spark.sql("SELECT double(1), double(2)").collect()
29 changes: 13 additions & 16 deletions python/pyspark/worker.py
@@ -87,22 +87,19 @@ def read_single_udf(pickleSer, infile):

 def read_udfs(pickleSer, infile):
     num_udfs = read_int(infile)
-    if num_udfs == 1:
-        # fast path for single UDF
-        _, udf = read_single_udf(pickleSer, infile)
-        mapper = lambda a: udf(*a)
-    else:
-        udfs = {}
-        call_udf = []
-        for i in range(num_udfs):
-            arg_offsets, udf = read_single_udf(pickleSer, infile)
-            udfs['f%d' % i] = udf
-            args = ["a[%d]" % o for o in arg_offsets]
-            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
-        # Create function like this:
-        # lambda a: (f0(a0), f1(a1, a2), f2(a3))
-        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
-        mapper = eval(mapper_str, udfs)
+    udfs = {}
+    call_udf = []
+    for i in range(num_udfs):
+        arg_offsets, udf = read_single_udf(pickleSer, infile)
+        udfs['f%d' % i] = udf
+        args = ["a[%d]" % o for o in arg_offsets]
+        call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+    # Create function like this:
+    # lambda a: (f0(a0), f1(a1, a2), f2(a3))
+    # In the special case of a single UDF this will return a single result rather
+    # than a tuple of results; this is the format that the JVM side expects.
+    mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+    mapper = eval(mapper_str, udfs)
 
     func = lambda _, it: map(mapper, it)
     ser = BatchedSerializer(PickleSerializer(), 100)
