Commit c28f520

support updateStateByKey
davies committed Sep 26, 2014
1 parent d357b70 commit c28f520
Showing 4 changed files with 83 additions and 21 deletions.
30 changes: 21 additions & 9 deletions python/pyspark/streaming/dstream.py
@@ -366,8 +366,9 @@ def reduceByKeyAndWindow(self, func, invFunc,
                              windowDuration, slideDuration, numPartitions=None):
         reduced = self.reduceByKey(func)
 
-        def reduceFunc(a, t):
-            return a.reduceByKey(func, numPartitions)
+        def reduceFunc(a, b, t):
+            b = b.reduceByKey(func, numPartitions)
+            return a.union(b).reduceByKey(func, numPartitions) if a else b
 
         def invReduceFunc(a, b, t):
             b = b.reduceByKey(func, numPartitions)
@@ -378,19 +379,30 @@ def invReduceFunc(a, b, t):
             windowDuration = Seconds(windowDuration)
         if not isinstance(slideDuration, Duration):
             slideDuration = Seconds(slideDuration)
-        serializer = reduced._jrdd_deserializer
-        jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
+        jreduceFunc = RDDFunction2(self.ctx, reduceFunc, reduced._jrdd_deserializer)
         jinvReduceFunc = RDDFunction2(self.ctx, invReduceFunc, reduced._jrdd_deserializer)
         dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
                                                              jreduceFunc, jinvReduceFunc,
                                                              windowDuration._jduration,
                                                              slideDuration._jduration)
-        return DStream(dstream.asJavaDStream(), self._ssc, serializer)
+        return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer)
 
-    def updateStateByKey(self, updateFunc):
-        # FIXME: convert updateFunc to java JFunction2
-        jFunc = updateFunc
-        return self._jdstream.updateStateByKey(jFunc)
+    def updateStateByKey(self, updateFunc, numPartitions=None):
+        """
+        :param updateFunc: [(k, vs, s)] -> [(k, s)]
+        """
+        def reduceFunc(a, b, t):
+            if a is None:
+                g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None))
+            else:
+                g = a.cogroup(b).map(lambda (k, (va, vb)):
+                                     (k, list(vb), list(va)[0] if len(va) else None))
+            return g.mapPartitions(lambda x: updateFunc(x) or [])
+
+        jreduceFunc = RDDFunction2(self.ctx, reduceFunc,
+                                   self.ctx.serializer, self._jrdd_deserializer)
+        dstream = self.ctx._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+        return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer)
 
 
 class TransformedDStream(DStream):
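
The reworked reduceFunc for reduceByKeyAndWindow now receives two RDDs, the previous window with expired values already subtracted out plus the union of the new batches, and merges them with one more reduceByKey. For the new updateStateByKey, the update function receives an iterator of (key, new_values, previous_state) tuples, where previous_state is None the first time a key appears, and yields (key, new_state) pairs. A minimal usage sketch, assuming an existing DStream of (word, 1) pairs named pairs (names invented for illustration, not part of this commit):

    def running_count(it):
        # it yields (key, new_values, previous_state); state is None at first
        for key, new_values, last_count in it:
            yield (key, (last_count or 0) + sum(new_values))

    counts = pairs.updateStateByKey(running_count)
    counts.pprint()  # assumed available; prints the cumulative count per word each batch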
19 changes: 19 additions & 0 deletions python/pyspark/streaming/tests.py
@@ -294,6 +294,25 @@ def func(dstream):
                     [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
         self._test_func(input, func, expected)
 
+    def test_update_state_by_key(self):
+
+        def updater(it):
+            for k, vs, s in it:
+                if not s:
+                    s = vs
+                else:
+                    s.extend(vs)
+                yield (k, s)
+
+        input = [[('k', i)] for i in range(5)]
+
+        def func(dstream):
+            return dstream.updateStateByKey(updater)
+
+        expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
+        expected = [[('k', v)] for v in expected]
+        self._test_func(input, func, expected)
+
 
 class TestStreamingContext(unittest.TestCase):
     def setUp(self):
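
The expected results can be sanity-checked without a StreamingContext. Here is a hypothetical pure-Python simulation of the same updater over the five single-record batches (the driver loop is invented for illustration; only the updater matches the test):

    def updater(it):
        for k, vs, s in it:
            if not s:
                s = vs
            else:
                s.extend(vs)
            yield (k, s)

    state = {}
    for i in range(5):
        batch = {'k': [i]}  # one record per batch, as in the test input
        triples = ((k, vs, state.get(k)) for k, vs in batch.items())
        state = dict(updater(triples))
        print(state)  # {'k': [0]}, then {'k': [0, 1]}, ... up to [0, 1, 2, 3, 4]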
11 changes: 6 additions & 5 deletions python/pyspark/streaming/util.py
@@ -50,15 +50,16 @@ class RDDFunction2(object):
     This class is for py4j callback. This class is related with
     org.apache.spark.streaming.api.python.PythonRDDFunction2.
     """
-    def __init__(self, ctx, func, jrdd_deserializer):
+    def __init__(self, ctx, func, jrdd_deserializer, jrdd_deserializer2=None):
         self.ctx = ctx
         self.func = func
-        self.deserializer = jrdd_deserializer
+        self.jrdd_deserializer = jrdd_deserializer
+        self.jrdd_deserializer2 = jrdd_deserializer2 or jrdd_deserializer
 
     def call(self, jrdd, jrdd2, milliseconds):
         try:
-            rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else None
-            other = RDD(jrdd2, self.ctx, self.deserializer) if jrdd2 else None
+            rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else None
+            other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else None
             r = self.func(rdd, other, milliseconds)
             if r:
                 return r._jrdd
@@ -67,7 +68,7 @@ def call(jrdd, jrdd2, milliseconds):
             traceback.print_exc()
 
     def __repr__(self):
-        return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func))
+        return "RDDFunction2(%s)" % (str(self.func))
 
     class Java:
         implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction2']
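
The optional jrdd_deserializer2 exists because the two incoming RDDs may be serialized differently: in updateStateByKey above, the state RDD (first argument) is produced with the context's default serializer, while the new-data RDD (second argument) still carries its parent DStream's deserializer. The class itself follows py4j's standard callback convention, in which a nested Java class declares the JVM interface the Python object implements. A stripped-down sketch of that convention (EchoFunction is invented for illustration):

    class EchoFunction(object):
        def call(self, jrdd, jrdd2, milliseconds):
            return jrdd  # hand the first (Java) RDD straight back to the JVM

        class Java:
            implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction2']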
44 changes: 37 additions & 7 deletions streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -118,7 +118,7 @@ private[spark] class PythonTransformed2DStream (parent: DStream[_], parent2: DSt
 private[spark]
 class PythonReducedWindowedDStream(
     parent: DStream[Array[Byte]],
-    reduceFunc: PythonRDDFunction,
+    reduceFunc: PythonRDDFunction2,
     invReduceFunc: PythonRDDFunction2,
     _windowDuration: Duration,
     _slideDuration: Duration
@@ -149,10 +149,6 @@ class PythonReducedWindowedDStream(
   override def parentRememberDuration: Duration = rememberDuration + windowDuration
 
   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
-    None
-    val reduceF = reduceFunc
-    val invReduceF = invReduceFunc
-
     val currentTime = validTime
     val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration,
       currentTime)
@@ -196,7 +192,7 @@ class PythonReducedWindowedDStream(
         parent.slice(previousWindow.endTime, currentWindow.endTime - parent.slideDuration)
 
       if (newRDDs.size > 0) {
-        Some(reduceFunc.call(JavaRDD.fromRDD(ssc.sc.union(newRDDs).union(subbed)), validTime.milliseconds))
+        Some(reduceFunc.call(JavaRDD.fromRDD(subbed), JavaRDD.fromRDD(ssc.sc.union(newRDDs)), validTime.milliseconds))
       } else {
         Some(subbed)
       }
@@ -205,7 +201,7 @@
       val currentRDDs =
         parent.slice(currentWindow.beginTime, currentWindow.endTime - parent.slideDuration)
       if (currentRDDs.size > 0) {
-        Some(reduceFunc.call(JavaRDD.fromRDD(ssc.sc.union(currentRDDs)), validTime.milliseconds))
+        Some(reduceFunc.call(null, JavaRDD.fromRDD(ssc.sc.union(currentRDDs)), validTime.milliseconds))
       } else {
         None
       }
@@ -216,6 +212,40 @@
   }
 
 
+/**
+ * Copied from ReducedWindowedDStream
+ */
+private[spark]
+class PythonStateDStream(
+    parent: DStream[Array[Byte]],
+    reduceFunc: PythonRDDFunction2
+  ) extends DStream[Array[Byte]](parent.ssc) {
+
+  super.persist(StorageLevel.MEMORY_ONLY)
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  override val mustCheckpoint = true
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    val lastState = getOrCompute(validTime - slideDuration)
+    val newRDD = parent.getOrCompute(validTime)
+    if (newRDD.isDefined) {
+      if (lastState.isDefined) {
+        Some(reduceFunc.call(JavaRDD.fromRDD(lastState.get), JavaRDD.fromRDD(newRDD.get), validTime.milliseconds))
+      } else {
+        Some(reduceFunc.call(null, JavaRDD.fromRDD(newRDD.get), validTime.milliseconds))
+      }
+    } else {
+      lastState
+    }
+  }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
 /**
  * This is used for foreachRDD() in Python
  */
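
PythonStateDStream.compute is a per-batch recurrence: the new state is reduceFunc(previous_state, new_data), the previous state is reused unchanged when a batch brings no new data, and the result stays empty until the first data arrives. A hypothetical pure-Python rendering of that control flow (not a Spark API):

    def compute(last_state, new_batch, reduce_func, time_ms):
        if new_batch is not None:
            # last_state may be None on the very first batch with data
            return reduce_func(last_state, new_batch, time_ms)
        return last_state  # no new data: carry the old state forward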
