[SPARK-23754][Python] Re-raising StopIteration in client code #21383
Changes from all commits: ec7854a, fddd031, ee54924, d739eea, f0f80ed, d59f0d5, b0af18e, 167a75b, 90b064d, 75316af, 026ecdd, f7b53c2, 8fac2a8, 5b5570b
```diff
@@ -900,6 +900,22 @@ def __call__(self, x):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
+    def test_stopiteration_in_udf(self):
+        # test for SPARK-23754
+        from pyspark.sql.functions import udf
+        from py4j.protocol import Py4JJavaError
+
+        def foo(x):
+            raise StopIteration()
+
+        with self.assertRaises(Py4JJavaError) as cm:
+            self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()
+
+        self.assertIn(
+            "Caught StopIteration thrown from user's code; failing the task",
+            cm.exception.java_exception.toString()
+        )
+
     def test_validate_column_types(self):
         from pyspark.sql.functions import udf, to_json
         from pyspark.sql.column import _to_java_column
```

Review comment on the `with self.assertRaises(Py4JJavaError) as cm:` line: ditto for
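For context on why the test insists on a loud failure: a `StopIteration` escaping user code that runs inside an iterator-driven evaluation loop is indistinguishable from normal exhaustion to the consuming loop, so the job would otherwise truncate its output silently. The following is a standalone illustration of that failure mode, not part of the PR; all names in it are invented for the demo.

```python
# Illustrative only: a hand-rolled mapping iterator, standing in for the
# iterator pipelines that PySpark workers evaluate user functions inside.
class MapIterator:
    def __init__(self, it, f):
        self.it = iter(it)
        self.f = f

    def __iter__(self):
        return self

    def __next__(self):
        # If f raises StopIteration, the caller cannot tell it apart from
        # this iterator being legitimately exhausted.
        return self.f(next(self.it))


def buggy(x):
    if x == 3:
        raise StopIteration()  # user bug; should be e.g. ValueError
    return x * 2


print(list(MapIterator(range(10), buggy)))
# Prints [0, 2, 4] -- the remaining elements are silently dropped instead
# of the error being reported, which is the data-loss scenario that
# SPARK-23754 guards against.
```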
```diff
@@ -25,7 +25,7 @@
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
     to_arrow_type, to_arrow_schema
-from pyspark.util import _get_argspec
+from pyspark.util import _get_argspec, fail_on_stopiteration
 
 __all__ = ["UDFRegistration"]
```
```diff
@@ -157,7 +157,17 @@ def _create_judf(self):
         spark = SparkSession.builder.getOrCreate()
         sc = spark.sparkContext
 
-        wrapped_func = _wrap_function(sc, self.func, self.returnType)
+        func = fail_on_stopiteration(self.func)
+
+        # for pandas UDFs the worker needs to know if the function takes
+        # one or two arguments, but the signature is lost when wrapping with
+        # fail_on_stopiteration, so we store it here
+        if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                             PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+                             PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
+            func._argspec = _get_argspec(self.func)
+
+        wrapped_func = _wrap_function(sc, func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
             self._name, wrapped_func, jdt, self.evalType, self.deterministic)
```

Review thread on the `func._argspec` workaround:

Does it make sense for

I am not sure how to do that, though, can you suggest a way? Originally this hack was in

I see. I saw @HyukjinKwon comment here:

That seems to be difficult in particular in Python 2. I'm not aware of a clean and straightforward way without a hack to copy the signature in Python 2 too.

I see. I remember we still support pandas udf with Python 2? Does the resolution here not work with Pandas UDF and Python 2?

Yup, we definitely support. The current approach probably wouldn't change anything we supported before. I believe the builtin functions in Python 2 don't already work with Pandas UDFs.

I see. Thanks for the clarification.
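The body of `fail_on_stopiteration` lives in `pyspark/util.py` and is not shown in this diff view. Below is a minimal sketch consistent with the import added above and with the error message the tests assert; treat the exact wording and structure as assumptions.

```python
def fail_on_stopiteration(f):
    """Wrap f so a StopIteration escaping it surfaces as a RuntimeError
    instead of silently terminating the iterator loop that drives f."""
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            # Re-raise as a RuntimeError so the consuming loop fails the
            # task loudly instead of treating it as end-of-iteration.
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task",
                exc
            )
    # Note: wrapper's inspectable signature is (*args, **kwargs), not f's.
    # That is why _create_judf above stashes _get_argspec(self.func) on the
    # wrapped function for pandas UDFs.
    return wrapper
```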
```diff
@@ -161,6 +161,37 @@ def gen_gs(N, step=1):
         self.assertEqual(k, len(vs))
         self.assertEqual(list(range(k)), list(vs))
 
+    def test_stopiteration_is_raised(self):
+
+        def stopit(*args, **kwargs):
+            raise StopIteration()
+
+        def legit_create_combiner(x):
+            return [x]
+
+        def legit_merge_value(x, y):
+            return x.append(y) or x
+
+        def legit_merge_combiners(x, y):
+            return x.extend(y) or x
+
+        data = [(x % 2, x) for x in range(100)]
+
+        # wrong create combiner
+        m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+            m.mergeValues(data)
+
+        # wrong merge value
+        m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+            m.mergeValues(data)
+
+        # wrong merge combiners
+        m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
+        with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+            m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
+
 
 class SorterTests(unittest.TestCase):
     def test_in_memory_sort(self):
```

Review thread on the first `assertRaises((Py4JJavaError, RuntimeError))` line:

Let's pick up one explicit exception here too.

Without the change in this test, you still can get a
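This test exercises the same guard on the aggregation path. The PR's actual call sites in `rdd.py`/`shuffle.py` are not part of this diff view; the following is a hedged sketch of the wiring the test implies, with each user-supplied callback wrapped by `fail_on_stopiteration` before the merge loop consumes it. The helper name `make_safe_merger` and the 20 MB memory limit are illustrative.

```python
from pyspark.shuffle import Aggregator, ExternalMerger
from pyspark.util import fail_on_stopiteration


def make_safe_merger(create_combiner, merge_value, merge_combiners,
                     memory_limit_mb=20):
    # Wrap each callback so a StopIteration raised inside it aborts the
    # merge loudly instead of quietly ending the merge loop early.
    safe_agg = Aggregator(
        fail_on_stopiteration(create_combiner),
        fail_on_stopiteration(merge_value),
        fail_on_stopiteration(merge_combiners),
    )
    return ExternalMerger(safe_agg, memory_limit_mb)
```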
```diff
@@ -1246,6 +1277,28 @@ def test_pipe_unicode(self):
         result = rdd.pipe('cat').collect()
         self.assertEqual(data, result)
 
+    def test_stopiteration_in_client_code(self):
+
+        def stopit(*x):
+            raise StopIteration()
+
+        seq_rdd = self.sc.parallelize(range(10))
+        keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
+
+        self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
+        self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
+        self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
+        self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
+        self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
+        self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
+        self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)
+
+        # the exception raised is non-deterministic
+        self.assertRaises((Py4JJavaError, RuntimeError),
+                          seq_rdd.aggregate, 0, stopit, lambda *x: 1)
+        self.assertRaises((Py4JJavaError, RuntimeError),
+                          seq_rdd.aggregate, 0, lambda *x: 1, stopit)
+
 
 class ProfilerTests(PySparkTestCase):
```
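On the "non-deterministic" comment: when the wrapped function fails inside an executor, the `RuntimeError` travels back to the driver as a `Py4JJavaError`; for actions like `aggregate` that may also run the combining function locally in the driver process, the `RuntimeError` can surface directly, so the test accepts either. A hedged interactive check of the executor-side case (it assumes a live SparkContext `sc` and that the worker's Python traceback is embedded in the Java exception message, as the UDF test above asserts):

```python
from py4j.protocol import Py4JJavaError


def stopit(*args):
    raise StopIteration()  # deliberately buggy user function


try:
    sc.parallelize(range(10)).map(stopit).collect()
except Py4JJavaError as e:
    # The executor-side RuntimeError surfaces through Py4J with the
    # wrapper's message included in the Java exception text.
    assert "Caught StopIteration thrown from user's code" in str(e)
```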
Review thread on the new tests:

Can we also add tests for pandas_udf?

We can also merge this PR first. I can follow up with the pandas udf tests.

Yes please, I am not really familiar with UDFs in general.
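As a hedged sketch of what that deferred follow-up might look like (a hypothetical test, not part of this PR; it assumes the scalar pandas_udf path surfaces the wrapper's RuntimeError the same way the plain udf test above does):

```python
def test_stopiteration_in_pandas_udf(self):
    # Hypothetical follow-up test, mirroring test_stopiteration_in_udf above.
    from pyspark.sql.functions import pandas_udf
    from py4j.protocol import Py4JJavaError

    @pandas_udf('double')
    def foo(x):
        raise StopIteration()

    with self.assertRaises(Py4JJavaError) as cm:
        self.spark.range(0, 100).withColumn('v', foo('id')).show()

    self.assertIn(
        "Caught StopIteration thrown from user's code; failing the task",
        cm.exception.java_exception.toString()
    )
```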