Update upstream #375
Merged: 1 commit, merged on Jun 11, 2018
71 changes: 55 additions & 16 deletions python/pyspark/sql/tests.py
@@ -900,22 +900,6 @@ def __call__(self, x):
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)

def test_stopiteration_in_udf(self):
# test for SPARK-23754
from pyspark.sql.functions import udf
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

with self.assertRaises(Py4JJavaError) as cm:
self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()

self.assertIn(
"Caught StopIteration thrown from user's code; failing the task",
cm.exception.java_exception.toString()
)

def test_validate_column_types(self):
from pyspark.sql.functions import udf, to_json
from pyspark.sql.column import _to_java_column
@@ -4144,6 +4128,61 @@ def foo(df):
def foo(k, v, w):
return k

def test_stopiteration_in_udf(self):
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

def foofoo(x, y):
raise StopIteration()

exc_message = "Caught StopIteration thrown from user's code; failing the task"
df = self.spark.range(0, 100)

# plain udf (test for SPARK-23754)
self.assertRaisesRegexp(
Py4JJavaError,
exc_message,
df.withColumn('v', udf(foo)('id')).collect
)

# pandas scalar udf
self.assertRaisesRegexp(
Py4JJavaError,
exc_message,
df.withColumn(
'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
).collect
)

# pandas grouped map
self.assertRaisesRegexp(
Py4JJavaError,
exc_message,
df.groupBy('id').apply(
pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
).collect
)

self.assertRaisesRegexp(
Py4JJavaError,
exc_message,
df.groupBy('id').apply(
pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP)
).collect
)

# pandas grouped agg
self.assertRaisesRegexp(
Py4JJavaError,
exc_message,
df.groupBy('id').agg(
pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id')
).collect
)


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
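
The message asserted in the tests above exists because, without the guard this change exercises, a StopIteration escaping user code is indistinguishable from normal exhaustion of the generator that drives it, so a task silently truncates its output instead of failing. A minimal, Spark-free sketch of that failure mode (this is the pre-PEP 479 behavior of Python < 3.7, which the 2018-era interpreters targeted here still exhibit):

```python
def user_func(x):
    if x == 3:
        raise StopIteration()  # buggy user code
    return x * 2


def mapper(records, f):
    # generator-style driver loop, similar in spirit to how a worker streams records
    for r in records:
        yield f(r)


# On Python < 3.7 the StopIteration simply ends the generator, so the result is
# silently truncated to [0, 2, 4] instead of raising an error for records 3..9.
print(list(mapper(range(10), user_func)))
```
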
14 changes: 2 additions & 12 deletions python/pyspark/sql/udf.py
@@ -25,7 +25,7 @@
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
to_arrow_type, to_arrow_schema
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark.util import _get_argspec

__all__ = ["UDFRegistration"]

@@ -157,17 +157,7 @@ def _create_judf(self):
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

func = fail_on_stopiteration(self.func)

# for pandas UDFs the worker needs to know if the function takes
# one or two arguments, but the signature is lost when wrapping with
# fail_on_stopiteration, so we store it here
if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
func._argspec = _get_argspec(self.func)

wrapped_func = _wrap_function(sc, func, self.returnType)
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self.evalType, self.deterministic)
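
The comment removed from `_create_judf` above points at a plain-Python fact worth spelling out: wrapping a function inside another `def` hides its argspec, which grouped-map pandas UDFs need in order to know whether the user function takes only the data or the grouping key as well. A small illustration using only the standard `inspect` module (no Spark, hypothetical names):

```python
import inspect


def wrap(f):
    # a wrapper in the style of fail_on_stopiteration, without functools.wraps
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper


def grouped(key, pdf):
    return pdf


print(inspect.getfullargspec(grouped).args)        # ['key', 'pdf']
print(inspect.getfullargspec(wrap(grouped)).args)  # [] -- the original signature is gone
```

This is why the change below moves the wrapping into the worker, which extracts the argspec from the unwrapped function first and passes it along explicitly.
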
37 changes: 22 additions & 15 deletions python/pyspark/tests.py
@@ -1291,27 +1291,34 @@ def test_pipe_unicode(self):
result = rdd.pipe('cat').collect()
self.assertEqual(data, result)

def test_stopiteration_in_client_code(self):
def test_stopiteration_in_user_code(self):

def stopit(*x):
raise StopIteration()

seq_rdd = self.sc.parallelize(range(10))
keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))

self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)

# the exception raised is non-deterministic
self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, lambda *x: 1, stopit)
msg = "Caught StopIteration thrown from user's code; failing the task"

self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg,
seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)

# these methods call the user function both in the driver and in the executor
# the exception raised is different according to where the StopIteration happens
# RuntimeError is raised if in the driver
# Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, lambda *x: 1, stopit)


class ProfilerTests(PySparkTestCase):
9 changes: 2 additions & 7 deletions python/pyspark/util.py
@@ -53,12 +53,7 @@ def _get_argspec(f):
"""
Get argspec of a function. Supports both Python 2 and Python 3.
"""

if hasattr(f, '_argspec'):
# only used for pandas UDF: they wrap the user function, losing its signature
# workers need this signature, so UDF saves it here
argspec = f._argspec
elif sys.version_info[0] < 3:
if sys.version_info[0] < 3:
argspec = inspect.getargspec(f)
else:
# `getargspec` is deprecated since python3.0 (incompatible with function annotations).
@@ -97,7 +92,7 @@ def majorMinorVersion(sparkVersion):
def fail_on_stopiteration(f):
"""
Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
prevents silent loss of data when 'f' is used in a for loop
prevents silent loss of data when 'f' is used in a for loop in Spark code
"""
def wrapper(*args, **kwargs):
try:
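
The body of `fail_on_stopiteration` is folded out of the hunk above. A minimal sketch of what such a wrapper looks like, assuming the error text used by the new tests in this PR; the actual implementation in pyspark/util.py may differ in detail:

```python
def fail_on_stopiteration(f):
    """
    Wrap `f` so that a StopIteration raised inside it cannot be mistaken for
    normal exhaustion of the for loop that drives `f` in Spark's worker code.
    """
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            # re-raise as RuntimeError so the task fails loudly instead of
            # silently dropping the remaining records
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task", exc
            )
    return wrapper
```
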
18 changes: 11 additions & 7 deletions python/pyspark/worker.py
@@ -35,7 +35,7 @@
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
from pyspark.util import _get_argspec
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -92,10 +92,9 @@ def verify_result_length(*a):
return lambda *a: (verify_result_length(*a), arrow_return_type)


def wrap_grouped_map_pandas_udf(f, return_type):
def wrap_grouped_map_pandas_udf(f, return_type, argspec):
def wrapped(key_series, value_series):
import pandas as pd
argspec = _get_argspec(f)

if len(argspec.args) == 1:
result = f(pd.concat(value_series, axis=1))
@@ -140,15 +139,20 @@ def read_single_udf(pickleSer, infile, eval_type):
else:
row_func = chain(row_func, f)

# make sure StopIteration's raised in the user code are not ignored
# when they are processed in a for loop, raise them as RuntimeError's instead
func = fail_on_stopiteration(row_func)

# the last returnType will be the return type of UDF
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type)
argspec = _get_argspec(row_func) # signature was lost when wrapping it
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type)
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
return arg_offsets, wrap_udf(row_func, return_type)
return arg_offsets, wrap_udf(func, return_type)
else:
raise ValueError("Unknown eval type: {}".format(eval_type))

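
The rest of `wrap_grouped_map_pandas_udf` is folded in the hunk above. Based on the visible lines and the new `argspec` parameter, the key/no-key dispatch presumably looks roughly like the sketch below; the two-argument branch is an assumption and result-schema checks are omitted:

```python
def wrap_grouped_map_pandas_udf(f, return_type, argspec):
    def wrapped(key_series, value_series):
        import pandas as pd

        if len(argspec.args) == 1:
            # user function only takes the group's data: f(pdf)
            result = f(pd.concat(value_series, axis=1))
        elif len(argspec.args) == 2:
            # user function also takes the grouping key: f(key, pdf)
            key = tuple(s[0] for s in key_series)
            result = f(key, pd.concat(value_series, axis=1))
        return result

    return wrapped
```
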