diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 487eb19c3b98a..4a3941de8a6a6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -900,22 +900,6 @@ def __call__(self, x):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
-    def test_stopiteration_in_udf(self):
-        # test for SPARK-23754
-        from pyspark.sql.functions import udf
-        from py4j.protocol import Py4JJavaError
-
-        def foo(x):
-            raise StopIteration()
-
-        with self.assertRaises(Py4JJavaError) as cm:
-            self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()
-
-        self.assertIn(
-            "Caught StopIteration thrown from user's code; failing the task",
-            cm.exception.java_exception.toString()
-        )
-
     def test_validate_column_types(self):
         from pyspark.sql.functions import udf, to_json
         from pyspark.sql.column import _to_java_column
@@ -4144,6 +4128,61 @@ def foo(df):
         def foo(k, v, w):
             return k
 
+    def test_stopiteration_in_udf(self):
+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+        from py4j.protocol import Py4JJavaError
+
+        def foo(x):
+            raise StopIteration()
+
+        def foofoo(x, y):
+            raise StopIteration()
+
+        exc_message = "Caught StopIteration thrown from user's code; failing the task"
+        df = self.spark.range(0, 100)
+
+        # plain udf (test for SPARK-23754)
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.withColumn('v', udf(foo)('id')).collect
+        )
+
+        # pandas scalar udf
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.withColumn(
+                'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
+            ).collect
+        )
+
+        # pandas grouped map
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.groupBy('id').apply(
+                pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
+            ).collect
+        )
+
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.groupBy('id').apply(
+                pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP)
+            ).collect
+        )
+
+        # pandas grouped agg
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.groupBy('id').agg(
+                pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id')
+            ).collect
+        )
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index c8fb49d7c2b65..9dbe49b831cef 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -25,7 +25,7 @@
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
     to_arrow_type, to_arrow_schema
-from pyspark.util import _get_argspec, fail_on_stopiteration
+from pyspark.util import _get_argspec
 
 __all__ = ["UDFRegistration"]
 
@@ -157,17 +157,7 @@ def _create_judf(self):
         spark = SparkSession.builder.getOrCreate()
         sc = spark.sparkContext
 
-        func = fail_on_stopiteration(self.func)
-
-        # for pandas UDFs the worker needs to know if the function takes
-        # one or two arguments, but the signature is lost when wrapping with
-        # fail_on_stopiteration, so we store it here
-        if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
-                             PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
-                             PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
-            func._argspec = _get_argspec(self.func)
-
-        wrapped_func = _wrap_function(sc, func, self.returnType)
+        wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
             self._name, wrapped_func, jdt, self.evalType, self.deterministic)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 30723b8e15b36..18b2f251dc9fd 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1291,27 +1291,34 @@ def test_pipe_unicode(self):
         result = rdd.pipe('cat').collect()
         self.assertEqual(data, result)
 
-    def test_stopiteration_in_client_code(self):
+    def test_stopiteration_in_user_code(self):
 
         def stopit(*x):
             raise StopIteration()
 
         seq_rdd = self.sc.parallelize(range(10))
         keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
-
-        self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
-        self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
-        self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
-        self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)
-
-        # the exception raised is non-deterministic
-        self.assertRaises((Py4JJavaError, RuntimeError),
-                          seq_rdd.aggregate, 0, stopit, lambda *x: 1)
-        self.assertRaises((Py4JJavaError, RuntimeError),
-                          seq_rdd.aggregate, 0, lambda *x: 1, stopit)
+        msg = "Caught StopIteration thrown from user's code; failing the task"
+
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg,
+                                seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
+
+        # these methods call the user function both in the driver and in the executor
+        # the exception raised is different according to where the StopIteration happens
+        # RuntimeError is raised if in the driver
+        # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                keyed_rdd.reduceByKeyLocally, stopit)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                seq_rdd.aggregate, 0, stopit, lambda *x: 1)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                seq_rdd.aggregate, 0, lambda *x: 1, stopit)
 
 
 class ProfilerTests(PySparkTestCase):
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index e95a9b523393f..f015542c8799d 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -53,12 +53,7 @@ def _get_argspec(f):
     """
     Get argspec of a function. Supports both Python 2 and Python 3.
     """
-
-    if hasattr(f, '_argspec'):
-        # only used for pandas UDF: they wrap the user function, losing its signature
-        # workers need this signature, so UDF saves it here
-        argspec = f._argspec
-    elif sys.version_info[0] < 3:
+    if sys.version_info[0] < 3:
         argspec = inspect.getargspec(f)
     else:
         # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
@@ -97,7 +92,7 @@ def majorMinorVersion(sparkVersion):
 def fail_on_stopiteration(f):
     """
     Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
-    prevents silent loss of data when 'f' is used in a for loop
+    prevents silent loss of data when 'f' is used in a for loop in Spark code
     """
     def wrapper(*args, **kwargs):
         try:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index fbcb8af8bfb24..a30d6bf523a50 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -35,7 +35,7 @@
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
-from pyspark.util import _get_argspec
+from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -92,10 +92,9 @@ def verify_result_length(*a):
     return lambda *a: (verify_result_length(*a), arrow_return_type)
 
 
-def wrap_grouped_map_pandas_udf(f, return_type):
+def wrap_grouped_map_pandas_udf(f, return_type, argspec):
     def wrapped(key_series, value_series):
         import pandas as pd
-        argspec = _get_argspec(f)
 
         if len(argspec.args) == 1:
             result = f(pd.concat(value_series, axis=1))
@@ -140,15 +139,20 @@ def read_single_udf(pickleSer, infile, eval_type):
         else:
             row_func = chain(row_func, f)
 
+    # make sure StopIteration's raised in the user code are not ignored
+    # when they are processed in a for loop, raise them as RuntimeError's instead
+    func = fail_on_stopiteration(row_func)
+
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type)
+        argspec = _get_argspec(row_func)  # signature was lost when wrapping it
+        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
-        return arg_offsets, wrap_udf(row_func, return_type)
+        return arg_offsets, wrap_udf(func, return_type)
     else:
         raise ValueError("Unknown eval type: {}".format(eval_type))
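
For context on the behavior this patch relies on: when user code raises StopIteration inside one of the evaluation loops in pyspark/worker.py, the loop reads it as normal iterator exhaustion and the rest of the partition is silently dropped. Wrapping the user function with fail_on_stopiteration (pyspark.util) turns that StopIteration into a RuntimeError so the task fails loudly instead, and the patch moves that wrapping from the driver (udf.py) to the worker. The snippet below is a minimal standalone sketch of the idea, not the Spark implementation; stop_after_two, naive_map and fail_on_stopiteration_sketch are hypothetical names introduced only for illustration.

# Standalone sketch (not Spark code): why an escaping StopIteration loses data
# and how a fail_on_stopiteration-style wrapper surfaces it as a RuntimeError.

def stop_after_two(x):
    # user function that (incorrectly) raises StopIteration
    if x >= 2:
        raise StopIteration()
    return x * 10


def naive_map(f, data):
    # stands in for a worker-side evaluation loop: on Python 3, a StopIteration
    # escaping f() is read as "iterator exhausted" and the remaining items are dropped
    return list(map(f, data))


def fail_on_stopiteration_sketch(f):
    # wrap f so a StopIteration from user code re-surfaces as a RuntimeError
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task", exc)
    return wrapper


print(naive_map(stop_after_two, range(5)))        # [0, 10] -- rows for x >= 2 silently lost
try:
    naive_map(fail_on_stopiteration_sketch(stop_after_two), range(5))
except RuntimeError as e:
    print(e)                                      # the loop now fails with a clear message

This is also why the tests above switch to assertRaisesRegexp on the message: on an executor the RuntimeError surfaces to the driver wrapped in a Py4JJavaError, while code paths that also run the user function on the driver (reduceByKeyLocally, aggregate) may raise the RuntimeError directly.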