
Commit

WIP added test case
giwa committed Sep 20, 2014
1 parent 9ad6855 commit ce2acd2
Showing 8 changed files with 132 additions and 30 deletions.
@@ -350,8 +350,6 @@ private[spark] object PythonRDD extends Logging {
} catch {
case eof: EOFException => {}
}
println("RDDDD ==================")
println(objs)
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}

25 changes: 16 additions & 9 deletions examples/src/main/python/streaming/test_oprations.py
@@ -9,15 +9,22 @@
conf = SparkConf()
conf.setAppName("PythonStreamingNetworkWordCount")
ssc = StreamingContext(conf=conf, duration=Seconds(1))
ssc.checkpoint("/tmp/spark_ckp")

test_input = ssc._testInputStream([[1],[1],[1]])
# ssc.checkpoint("/tmp/spark_ckp")
fm_test = test_input.flatMap(lambda x: x.split(" "))
mapped_test = fm_test.map(lambda x: (x, 1))
test_input = ssc._testInputStream([1,2,3])
class buff:
    pass

fm_test = test_input.map(lambda x: (x, 1))
fm_test.test_output(buff)


mapped_test.print_()
ssc.start()
# ssc.awaitTermination()
# ssc.stop()
while True:
    ssc.awaitTermination(50)
    try:
        buff.result
        break
    except AttributeError:
        pass

ssc.stop()
print buff.result
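
The example above busy-waits until the output buffer is filled. Purely as an illustration (the wait_for_result helper below is hypothetical, not part of this commit), the loop can be read as:

def wait_for_result(ssc, buff, poll=50):
    # Wake up periodically instead of blocking forever, then check whether
    # the output collector has stored a result yet.
    while True:
        ssc.awaitTermination(poll)
        if getattr(buff, "result", None) is not None:
            return buff.result

The test suite later in this commit uses the same pattern, with an overall timeout added.
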
16 changes: 7 additions & 9 deletions python/pyspark/streaming/context.py
@@ -100,10 +100,10 @@ def awaitTermination(self, timeout=None):
"""
Wait for the execution to stop.
"""
if timeout:
self._jssc.awaitTermination(timeout)
else:
if timeout is None:
self._jssc.awaitTermination()
else:
self._jssc.awaitTermination(timeout)

# start from simple one. storageLevel is not passed for now.
def socketTextStream(self, hostname, port):
@@ -137,6 +137,7 @@ def stop(self, stopSparkContext=True):

def checkpoint(self, directory):
"""
Not tested
"""
self._jssc.checkpoint(directory)

@@ -147,8 +148,7 @@ def _testInputStream(self, test_input, numSlices=None):
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().

#tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
tempFile = open("/tmp/spark_rdd", "wb")
tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)

# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(test_input):
@@ -160,10 +160,8 @@
else:
serializer = self._sc._unbatched_serializer
serializer.dump_stream(test_input, tempFile)
tempFile.flush()
tempFile.close()
print tempFile.name

jinput_stream = self._jvm.PythonTestInputStream(self._jssc,
tempFile.name,
numSlices).asJavaDStream()
return DStream(jinput_stream, self, UTF8Deserializer())
return DStream(jinput_stream, self, PickleSerializer())
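
The comment at the top of _testInputStream explains the design: instead of sending O(n) Py4J commands, the test input is serialized to a local file that the JVM-side PythonTestInputStream reads back. A minimal standalone sketch of that hand-off (write_test_input is a hypothetical name, used only to illustrate the approach):

from tempfile import NamedTemporaryFile
from pyspark.serializers import PickleSerializer

def write_test_input(test_input, temp_dir):
    # Serialize the Python objects once to a local file; the JVM side then
    # only needs the file name (and the number of slices) to build the stream.
    temp_file = NamedTemporaryFile(delete=False, dir=temp_dir)
    PickleSerializer().dump_stream(test_input, temp_file)
    temp_file.close()
    return temp_file.name
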
22 changes: 18 additions & 4 deletions python/pyspark/streaming/dstream.py
@@ -47,7 +47,7 @@ def _sum(self):
"""
return self._mapPartitions(lambda x: [sum(x)]).reduce(operator.add)

def print_(self):
def print_(self, label=None):
"""
Since print is a reserved name in Python, we cannot define a print method function.
This function prints the serialized data in the DStream's RDDs because Scala and Java cannot
@@ -56,7 +56,7 @@ def print_(self):
Call DStream.print().
"""
# a hack to call print function in DStream
getattr(self._jdstream, "print")()
getattr(self._jdstream, "print")(label)

def filter(self, f):
"""
@@ -230,6 +230,7 @@ def pyprint(self):
"""
def takeAndPrint(rdd, time):
print "take and print ==================="
taken = rdd.take(11)
print "-------------------------------------------"
print "Time: %s" % (str(time))
@@ -242,11 +243,24 @@ def takeAndPrint(rdd, time):

self.foreachRDD(takeAndPrint)

#def transform(self, func):
#def transform(self, func): - TD
# from utils import RDDFunction
# wrapped_func = RDDFunction(self.ctx, self._jrdd_deserializer, func)
# jdstream = self.ctx._jvm.PythonTransformedDStream(self._jdstream.dstream(), wrapped_func).toJavaDStream
# return DStream(jdstream, self._ssc, ...) ## DO NOT KNOW HOW

def _test_output(self, buff):
    """
    This function is only for test cases.
    It stores the data in the DStream into a buffer to verify the result in a test case.
    """
    def get_output(rdd, time):
        taken = rdd.take(11)
        buff.result = taken
    self.foreachRDD(get_output)

def output(self):
    self._jdstream.outputToFile()


class PipelinedDStream(DStream):
62 changes: 57 additions & 5 deletions python/pyspark/streaming_tests.py
@@ -19,12 +19,13 @@
Unit tests for PySpark; additional tests are implemented as doctests in
individual modules.
This file will be merged into tests.py. For now, it is kept separate to focus
on the streaming test cases.
"""
from fileinput import input
from glob import glob
from itertools import chain
import os
import re
import shutil
@@ -41,18 +42,69 @@

SPARK_HOME = os.environ["SPARK_HOME"]

class buff:
    """
    Buffer to store the output from the stream.
    """
    result = None

class PySparkStreamingTestCase(unittest.TestCase):

    def setUp(self):
        self._old_sys_path = list(sys.path)
        print "set up"
        class_name = self.__class__.__name__
        self.ssc = StreamingContext(appName=class_name, duration=Seconds(1))

    def tearDown(self):
        print "tear down"
        self.ssc.stop()
        sys.path = self._old_sys_path
        time.sleep(10)

class TestBasicOperationsSuite(PySparkStreamingTestCase):
    def setUp(self):
        PySparkStreamingTestCase.setUp(self)
        buff.result = None
        self.timeout = 10  # seconds

    def tearDown(self):
        PySparkStreamingTestCase.tearDown(self)

    def test_map(self):
        test_input = [range(1, 5), range(5, 9), range(9, 13)]
        def test_func(dstream):
            return dstream.map(lambda x: str(x))
        expected = map(str, test_input)
        output = self.run_stream(test_input, test_func)
        self.assertEqual(output, expected)

    def test_flatMap(self):
        test_input = [range(1, 5), range(5, 9), range(9, 13)]
        def test_func(dstream):
            return dstream.flatMap(lambda x: (x, x * 2))
        # There may be a better way to build the expected flatMap output
        excepted = map(lambda x: list(chain.from_iterable(map(lambda y: [y, y * 2], x))),
                       test_input)
        output = self.run_stream(test_input, test_func)

    def run_stream(self, test_input, test_func):
        # Generate an input stream from the user-defined input
        test_input_stream = self.ssc._testInputStream(test_input)
        # Apply the test function to the stream
        test_stream = test_func(test_input_stream)
        # Add a job that stores the stream's output in buff
        test_stream._test_output(buff)
        self.ssc.start()

        start_time = time.time()
        while True:
            current_time = time.time()
            # check for timeout
            if (current_time - start_time) > self.timeout:
                self.ssc.stop()
                break
            self.ssc.awaitTermination(50)
            if buff.result is not None:
                break
        return buff.result

if __name__ == "__main__":
    unittest.main()
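
As a quick sanity check of the expected value built in test_flatMap above (plain Python, no Spark; assumes a single input batch), flatMap(lambda x: (x, x * 2)) should interleave each element with its double:

from itertools import chain

batch = range(1, 5)                               # one input batch: [1, 2, 3, 4]
expected = list(chain.from_iterable([y, y * 2] for y in batch))
print expected                                    # [1, 2, 2, 4, 3, 6, 4, 8]
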
@@ -54,6 +54,15 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
dstream.print()
}

def print(label: String = null): Unit = {
  dstream.print(label)
}

def outputToFile(): Unit = {
  dstream.outputToFile()
}


/**
* Return a new DStream in which each RDD has a single element generated by counting each RDD
* of this DStream.
@@ -17,9 +17,14 @@

package org.apache.spark.streaming.api.python

import java.io._
import java.io.{ObjectInputStream, IOException}
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.collection.JavaConversions._


import org.apache.spark._
import org.apache.spark.rdd.RDD
@@ -51,6 +56,8 @@ class PythonDStream[T: ClassTag](
override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
parent.getOrCompute(validTime) match{
case Some(rdd) =>
logInfo("RDD ID in python DStream ===========")
logInfo("RDD id " + rdd.id)
val pythonRDD = new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning, pythonExec, broadcastVars, accumulator)
Some(pythonRDD.asJavaRDD.rdd)
case None => None
@@ -152,7 +159,7 @@ DStream[Array[Byte]](prev.ssc){
val pairwiseRDD = new PairwiseRDD(rdd)
/*
* Since the Python operation is executed by Scala after StreamingContext.start,
* what PythonPairwiseDStream does is equivalent to the following Python code in PySpark.
*
* with _JavaStackTrace(self.context) as st:
* pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
@@ -623,6 +623,23 @@ abstract class DStream[T: ClassTag] (
new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
}


def print(label: String = null) {
  def foreachFunc = (rdd: RDD[T], time: Time) => {
    val first11 = rdd.take(11)
    println("-------------------------------------------")
    println("Time: " + time)
    println("-------------------------------------------")
    if (label != null) {
      println(label)
    }
    first11.take(10).foreach(println)
    if (first11.size > 10) println("...")
    println()
  }
  new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
}

/**
* Return a new DStream in which each RDD contains all the elements in seen in a
* sliding window of time over this DStream. The new DStream generates RDDs with
