Skip to content

Commit

Permalink
improved tests
Browse files Browse the repository at this point in the history
  • Loading branch information
e-dorigatti committed May 24, 2018
1 parent d739eea commit f0f80ed
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 20 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,7 @@ def test_stopiteration_in_udf(self):
def foo(x):
raise StopIteration()

with self.assertRaises(Py4JJavaError) as cm:
with self.assertRaises(Py4JJavaError):
self.spark.range(0, 1000).withColumn('v', udf(foo)).show()

def test_validate_column_types(self):
Expand Down
32 changes: 13 additions & 19 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,28 +1279,22 @@ def test_pipe_unicode(self):

def test_stopiteration_in_client_code(self):

def a_rdd(keyed=False):
return self.sc.parallelize(
((x % 2, x) if keyed else x)
for x in range(10)
)

def stopit(*x):
raise StopIteration()

def do_test(action, *args, **kwargs):
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
action(*args, **kwargs)

do_test(a_rdd().map(stopit).collect)
do_test(a_rdd().filter(stopit).collect)
do_test(a_rdd().cartesian(a_rdd()).flatMap(stopit).collect)
do_test(a_rdd().foreach, stopit)
do_test(a_rdd(keyed=True).reduceByKeyLocally, stopit)
do_test(a_rdd().reduce, stopit)
do_test(a_rdd().fold, 0, stopit)
do_test(a_rdd().aggregate, 0, stopit, lambda *x: 1)
do_test(a_rdd().aggregate, 0, lambda *x: 1, stopit)
seq_rdd = self.sc.parallelize(range(10))
keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
exc = Py4JJavaError, RuntimeError

self.assertRaises(exc, seq_rdd.map(stopit).collect)
self.assertRaises(exc, seq_rdd.filter(stopit).collect)
self.assertRaises(exc, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
self.assertRaises(exc, seq_rdd.foreach, stopit)
self.assertRaises(exc, keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaises(exc, seq_rdd.reduce, stopit)
self.assertRaises(exc, seq_rdd.fold, 0, stopit)
self.assertRaises(exc, seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaises(exc, seq_rdd.aggregate, 0, lambda *x: 1, stopit)


class ProfilerTests(PySparkTestCase):
Expand Down

0 comments on commit f0f80ed

Please sign in to comment.