Commit 217e730

[SPARK-23754][PYTHON][FOLLOWUP] Move UDF stop iteration wrapping from driver to executor

SPARK-23754 was fixed in apache#21383 by changing the UDF code to wrap the user function, but this required a hack to save its argspec. This PR reverts that change and instead fixes the `StopIteration` bug in the worker.

The root of the problem is that when a user-supplied function raises a `StopIteration`, PySpark may silently stop processing data if that function is called inside a for loop. The solution is to catch `StopIteration` exceptions and re-raise them as `RuntimeError`s, so that the task fails and the error is reported to the user. This is done with the `fail_on_stopiteration` wrapper (a minimal sketch follows the list below), applied in different ways depending on where the function is used:
 - In RDDs, the user function is wrapped in the driver, because this function is also called in the driver itself.
 - In SQL UDFs, the function is wrapped in the worker, since all processing happens there. Moreover, the worker needs the signature of the user function, which would be lost if the wrapping were done in the driver, and passing that signature down to the worker separately would require a rather ugly hack (the one this PR removes).
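
To make the failure mode concrete, here is a minimal, self-contained sketch of the problem and of the wrapper (the wrapper body is simplified from `python/pyspark/util.py`; `buggy_udf` and the plain `map`/`list` driving loop are illustrative stand-ins for a UDF and the worker's processing loop, not the actual PySpark internals):

```python
def buggy_udf(x):
    if x == 3:
        raise StopIteration()  # user bug: StopIteration escapes the function
    return x * 2

# Without the wrapper: map() propagates the StopIteration and list() treats it
# as a normal end of iteration, so records 3..9 are silently dropped.
print(list(map(buggy_udf, range(10))))  # [0, 2, 4]

def fail_on_stopiteration(f):
    """Re-raise StopIteration from user code as RuntimeError (sketch)."""
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task",
                exc)
    return wrapper

# With the wrapper: the same input now fails loudly instead of losing data.
list(map(fail_on_stopiteration(buggy_udf), range(10)))
# RuntimeError: ('Caught StopIteration thrown from user's code; failing the task', ...)
```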

Tested with the same tests as before, plus new tests for pandas UDFs.

Author: edorigatti <[email protected]>

Closes apache#21467 from e-dorigatti/fix_udf_hack.
e-dorigatti committed Jun 12, 2018
1 parent e7db468 commit 217e730
Showing 5 changed files with 67 additions and 35 deletions.
54 changes: 38 additions & 16 deletions python/pyspark/sql/tests.py
@@ -853,22 +853,6 @@ def __call__(self, x):
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)

def test_stopiteration_in_udf(self):
# test for SPARK-23754
from pyspark.sql.functions import udf
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

with self.assertRaises(Py4JJavaError) as cm:
self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()

self.assertIn(
"Caught StopIteration thrown from user's code; failing the task",
cm.exception.java_exception.toString()
)

def test_validate_column_types(self):
from pyspark.sql.functions import udf, to_json
from pyspark.sql.column import _to_java_column
@@ -3917,6 +3901,44 @@ def foo(df):
def foo(k, v):
return k

def test_stopiteration_in_udf(self):
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

def foofoo(x, y):
raise StopIteration()

exc_message = "Caught StopIteration thrown from user's code; failing the task"
df = self.spark.range(0, 100)

# plain udf (test for SPARK-23754)
self.assertRaisesRegexp(
Py4JJavaError,
exc_message,
df.withColumn('v', udf(foo)('id')).collect
)

# pandas scalar udf
self.assertRaisesRegexp(
Py4JJavaError,
exc_message,
df.withColumn(
'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
).collect
)

# pandas grouped map
self.assertRaisesRegexp(
Py4JJavaError,
exc_message,
df.groupBy('id').apply(
pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
).collect
)


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
4 changes: 1 addition & 3 deletions python/pyspark/sql/udf.py
@@ -24,7 +24,6 @@
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string, \
to_arrow_type, to_arrow_schema
from pyspark.util import fail_on_stopiteration

__all__ = ["UDFRegistration"]

@@ -155,8 +154,7 @@ def _create_judf(self):
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

func = fail_on_stopiteration(self.func)
wrapped_func = _wrap_function(sc, func, self.returnType)
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self.evalType, self.deterministic)
37 changes: 22 additions & 15 deletions python/pyspark/tests.py
@@ -1270,27 +1270,34 @@ def test_pipe_functions(self):
self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
self.assertEqual([], rdd.pipe('grep 4').collect())

def test_stopiteration_in_client_code(self):
def test_stopiteration_in_user_code(self):

def stopit(*x):
raise StopIteration()

seq_rdd = self.sc.parallelize(range(10))
keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))

self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)

# the exception raised is non-deterministic
self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, lambda *x: 1, stopit)
msg = "Caught StopIteration thrown from user's code; failing the task"

self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg,
seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)

# these methods call the user function both in the driver and in the executor
# the exception raised is different according to where the StopIteration happens
# RuntimeError is raised if in the driver
# Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, lambda *x: 1, stopit)


class ProfilerTests(PySparkTestCase):
2 changes: 1 addition & 1 deletion python/pyspark/util.py
@@ -48,7 +48,7 @@ def _exception_message(excp):
def fail_on_stopiteration(f):
"""
Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
prevents silent loss of data when 'f' is used in a for loop
prevents silent loss of data when 'f' is used in a for loop in Spark code
"""
def wrapper(*args, **kwargs):
try:
5 changes: 5 additions & 0 deletions python/pyspark/worker.py
@@ -35,6 +35,7 @@
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
from pyspark.util import fail_on_stopiteration
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -122,6 +123,10 @@ def read_single_udf(pickleSer, infile, eval_type):
else:
row_func = chain(row_func, f)

# make sure StopIteration exceptions raised in the user code are not ignored
# when they are processed in a for loop; raise them as RuntimeErrors instead
row_func = fail_on_stopiteration(row_func)

# the last returnType will be the return type of UDF
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
