diff --git a/bin/pyspark b/bin/pyspark index 118e6851af7a0..5142411e36974 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -87,7 +87,11 @@ export PYSPARK_SUBMIT_ARGS if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR - exec "$PYSPARK_PYTHON" $1 + if [[ -n "$PYSPARK_DOC_TEST" ]]; then + exec "$PYSPARK_PYTHON" -m doctest $1 + else + exec "$PYSPARK_PYTHON" $1 + fi exit fi diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 9aa3db7ccf1dd..ccbca67656c8d 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -256,8 +256,3 @@ def _start_update_server(): thread.daemon = True thread.start() return server - - -if __name__ == "__main__": - import doctest - doctest.testmod() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index e666dd9800256..94bebc310bad6 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -526,8 +526,3 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) - - -if __name__ == "__main__": - import doctest - doctest.testmod() diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 698978e61ffad..09d2670cc1962 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -341,6 +341,25 @@ def func(a, b): expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]] self._test_func(input, func, expected, True, input2) + def update_state_by_key(self): + + def updater(it): + for k, vs, s in it: + if not s: + s = vs + else: + s.extend(vs) + yield (k, s) + + input = [[('k', i)] for i in range(5)] + + def func(dstream): + return dstream.updateStateByKey(updater) + + expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + class TestWindowFunctions(PySparkStreamingTestCase): @@ -398,25 +417,6 @@ def test_reduce_by_invalid_window(self): self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) - def update_state_by_key(self): - - def updater(it): - for k, vs, s in it: - if not s: - s = vs - else: - s.extend(vs) - yield (k, s) - - input = [[('k', i)] for i in range(5)] - - def func(dstream): - return dstream.updateStateByKey(updater) - - expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] - expected = [[('k', v)] for v in expected] - self._test_func(input, func, expected) - class TestStreamingContext(PySparkStreamingTestCase): diff --git a/python/run-tests b/python/run-tests index e86e0729cf65e..c5cb580f77fd2 100755 --- a/python/run-tests +++ b/python/run-tests @@ -48,39 +48,6 @@ function run_test() { fi } -function run_core_tests() { - run_test "pyspark/conf.py" - run_test "pyspark/context.py" - run_test "pyspark/broadcast.py" - run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - run_test "pyspark/shuffle.py" - run_test "pyspark/rdd.py" - run_test "pyspark/tests.py" -} - -function run_sql_tests() { - run_test "pyspark/sql.py" -} - -function run_mllib_tests() { - run_test "pyspark/mllib/util.py" - run_test "pyspark/mllib/linalg.py" - run_test "pyspark/mllib/classification.py" - run_test "pyspark/mllib/clustering.py" - run_test "pyspark/mllib/random.py" - run_test "pyspark/mllib/recommendation.py" - run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat.py" - run_test "pyspark/mllib/tree.py" - run_test "pyspark/mllib/tests.py" -} - -function run_streaming_tests() { - run_test "pyspark/streaming/util.py" - run_test "pyspark/streaming/tests.py" -} - echo "Running PySpark tests. Output is in python/unit-tests.log." export PYSPARK_PYTHON="python" @@ -93,10 +60,31 @@ fi echo "Testing with Python version:" $PYSPARK_PYTHON --version -run_core_tests -run_sql_tests -run_mllib_tests -run_streaming_tests +run_test "pyspark/rdd.py" +run_test "pyspark/context.py" +run_test "pyspark/conf.py" +run_test "pyspark/sql.py" +# These tests are included in the module-level docs, and so must +# be handled on a higher level rather than within the python file. +export PYSPARK_DOC_TEST=1 +run_test "pyspark/broadcast.py" +run_test "pyspark/accumulators.py" +run_test "pyspark/serializers.py" +unset PYSPARK_DOC_TEST +run_test "pyspark/shuffle.py" +run_test "pyspark/tests.py" +run_test "pyspark/mllib/classification.py" +run_test "pyspark/mllib/clustering.py" +run_test "pyspark/mllib/linalg.py" +run_test "pyspark/mllib/random.py" +run_test "pyspark/mllib/recommendation.py" +run_test "pyspark/mllib/regression.py" +run_test "pyspark/mllib/stat.py" +run_test "pyspark/mllib/tests.py" +run_test "pyspark/mllib/tree.py" +run_test "pyspark/mllib/util.py" +run_test "pyspark/streaming/util.py" +run_test "pyspark/streaming/tests.py" # Try to test with PyPy if [ $(which pypy) ]; then @@ -104,10 +92,21 @@ if [ $(which pypy) ]; then echo "Testing with PyPy version:" $PYSPARK_PYTHON --version - run_core_tests - run_sql_tests - run_mllib_tests - run_streaming_tests + run_test "pyspark/rdd.py" + run_test "pyspark/context.py" + run_test "pyspark/conf.py" + run_test "pyspark/sql.py" + # These tests are included in the module-level docs, and so must + # be handled on a higher level rather than within the python file. + export PYSPARK_DOC_TEST=1 + run_test "pyspark/broadcast.py" + run_test "pyspark/accumulators.py" + run_test "pyspark/serializers.py" + unset PYSPARK_DOC_TEST + run_test "pyspark/shuffle.py" + run_test "pyspark/tests.py" + run_test "pyspark/streaming/util.py" + run_test "pyspark/streaming/tests.py" fi if [[ $FAILED == 0 ]]; then