Commit 7001b51

refactor of queueStream()

davies committed Sep 28, 2014
1 parent 26ea396 commit 7001b51

Showing 2 changed files with 19 additions and 47 deletions.
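In short: the bespoke PythonDataInputStream input stream is deleted, and the Python queueStream() now converts its list of RDDs into a java.util.Queue via the new PythonDStream.toRDDQueue helper and delegates to JavaStreamingContext.queueStream. The oneAtATime default also flips from False to True, matching the default of the Scala StreamingContext.queueStream API.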
python/pyspark/streaming/context.py (7 additions, 4 deletions)

@@ -184,7 +184,7 @@ def _check_serialzers(self, rdds):
                 # reset them to sc.serializer
                 rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)

-    def queueStream(self, queue, oneAtATime=False, default=None):
+    def queueStream(self, queue, oneAtATime=True, default=None):
         """
         Create an input stream from a queue of RDDs or lists. In each batch,
         it will process either one or all of the RDDs returned by the queue.
@@ -200,9 +200,12 @@ def queueStream(self, queue, oneAtATime=False, default=None):
         self._check_serialzers(rdds)
         jrdds = ListConverter().convert([r._jrdd for r in rdds],
                                         SparkContext._gateway._gateway_client)
-        jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds, oneAtATime,
-                                                   default and default._jrdd)
-        return DStream(jdstream.asJavaDStream(), self, rdds[0]._jrdd_deserializer)
+        queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
+        if default:
+            jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
+        else:
+            jdstream = self._jssc.queueStream(queue, oneAtATime)
+        return DStream(jdstream, self, rdds[0]._jrdd_deserializer)

     def transform(self, dstreams, transformFunc):
         """
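For context, here is a minimal sketch of how the refactored queueStream() is driven from PySpark. The app name, batch interval, and data are illustrative, not part of this commit, and the imports assume the usual PySpark entry points:

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext("local[2]", "queueStreamExample")
    ssc = StreamingContext(sc, 1)  # 1-second batches (illustrative)

    # Each RDD in the list becomes one entry in the queue.
    rdds = [sc.parallelize(range(i * 10, (i + 1) * 10)) for i in range(3)]

    # With oneAtATime=True (the new default), each batch consumes a single
    # RDD from the queue; with oneAtATime=False, each batch processes all
    # of the queued RDDs at once.
    stream = ssc.queueStream(rdds, oneAtATime=True)
    stream.pprint()

    ssc.start()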
streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala (12 additions, 43 deletions)

@@ -18,6 +18,7 @@
 package org.apache.spark.streaming.api.python

 import java.util.{ArrayList => JArrayList}
+import scala.collection.JavaConversions._

 import org.apache.spark.rdd.RDD
 import org.apache.spark.api.java._
@@ -65,6 +66,16 @@ abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (p
   val asJavaDStream = JavaDStream.fromDStream(this)
 }

+object PythonDStream {
+
+  // convert a list of RDDs into a queue of RDDs, for ssc.queueStream()
+  def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
+    val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]]
+    rdds.forall(queue.add(_))
+    queue
+  }
+}

 /**
  * Transformed DStream in Python.
  *
@@ -243,46 +254,4 @@ class PythonForeachDStream(
 ) {

   this.register()
 }
-
-
-/**
- * similar to QueueInputStream
- */
-class PythonDataInputStream(
-    ssc_ : JavaStreamingContext,
-    inputRDDs: JArrayList[JavaRDD[Array[Byte]]],
-    oneAtAtime: Boolean,
-    defaultRDD: JavaRDD[Array[Byte]]
-  ) extends InputDStream[Array[Byte]](JavaStreamingContext.toStreamingContext(ssc_)) {
-
-  val emptyRDD = if (defaultRDD != null) {
-    Some(defaultRDD.rdd)
-  } else {
-    Some(ssc.sparkContext.emptyRDD[Array[Byte]])
-  }
-
-  def start() {}
-
-  def stop() {}
-
-  def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
-    val index = ((validTime - zeroTime) / slideDuration - 1).toInt
-    if (oneAtAtime) {
-      if (index == 0) {
-        val rdds = inputRDDs.toArray.map(_.asInstanceOf[JavaRDD[Array[Byte]]].rdd).toSeq
-        Some(ssc.sparkContext.union(rdds))
-      } else {
-        emptyRDD
-      }
-    } else {
-      if (index < inputRDDs.size()) {
-        Some(inputRDDs.get(index).rdd)
-      } else {
-        emptyRDD
-      }
-    }
-  }
-
-  val asJavaDStream = JavaDStream.fromDStream(this)
-}
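The default argument follows the second branch in the Python change above: when supplied, its _jrdd is forwarded to JavaStreamingContext.queueStream, which emits that RDD in batches where the queue has been drained. A minimal sketch, reusing the illustrative sc and ssc from the earlier example:

    # Queue of one RDD; `default` is returned for every batch after the
    # queue is drained (when no default is given, such batches are empty).
    rdds = [sc.parallelize([1, 2, 3])]
    default = sc.parallelize([0])
    stream = ssc.queueStream(rdds, oneAtATime=True, default=default)
    stream.pprint()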
