SPARK-1162: Implemented takeOrdered in PySpark #97

Closed
107 changes: 102 additions & 5 deletions python/pyspark/rdd.py
@@ -29,7 +29,7 @@
from tempfile import NamedTemporaryFile
from threading import Thread
import warnings
from heapq import heappush, heappop, heappushpop
import heapq

from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -41,9 +41,9 @@

from py4j.java_collections import ListConverter, MapConverter


__all__ = ["RDD"]


def _extract_concise_traceback():
"""
This function returns the traceback info for a callsite, returns a dict
@@ -91,6 +91,73 @@ def __exit__(self, type, value, tb):
if _spark_stack_depth == 0:
self._context._jsc.setCallSite(None)

class MaxHeapQ(object):
"""
A max-heap of bounded size: inserting more than `maxsize` values keeps only the `maxsize` smallest of them.
>>> import pyspark.rdd
>>> heap = pyspark.rdd.MaxHeapQ(5)
>>> [heap.insert(i) for i in range(10)]
[None, None, None, None, None, None, None, None, None, None]
>>> sorted(heap.getElements())
[0, 1, 2, 3, 4]
>>> heap = pyspark.rdd.MaxHeapQ(5)
>>> [heap.insert(i) for i in range(9, -1, -1)]
[None, None, None, None, None, None, None, None, None, None]
>>> sorted(heap.getElements())
[0, 1, 2, 3, 4]
>>> heap = pyspark.rdd.MaxHeapQ(1)
>>> [heap.insert(i) for i in range(9, -1, -1)]
[None, None, None, None, None, None, None, None, None, None]
>>> heap.getElements()
[0]
"""

def __init__(self, maxsize):
# Index 0 is unused; storing elements from q[1] makes the children of node k simply 2 * k and 2 * k + 1.
self.q = [0]
self.maxsize = maxsize

def _swim(self, k):
while (k > 1) and (self.q[k // 2] < self.q[k]):
self._swap(k, k // 2)
k = k // 2

def _swap(self, i, j):
t = self.q[i]
self.q[i] = self.q[j]
self.q[j] = t

def _sink(self, k):
N = self.size()
while 2 * k <= N:
j = 2 * k
# Pick the larger of the two children; if the parent is already
# larger than it, the heap property holds, otherwise swap and keep sinking.
if j < N and self.q[j] < self.q[j + 1]:
j = j + 1
if self.q[k] > self.q[j]:
break
self._swap(k, j)
k = j

def size(self):
return len(self.q) - 1

def insert(self, value):
if self.size() < self.maxsize:
self.q.append(value)
self._swim(self.size())
else:
self._replaceRoot(value)

def getElements(self):
return self.q[1:]

def _replaceRoot(self, value):
if self.q[1] > value:
self.q[1] = value
self._sink(1)
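
For reference, the same "keep the N smallest values" behaviour can be sketched with the standard heapq module by negating values, since heapq only provides a min-heap. This is an illustrative alternative to MaxHeapQ, not part of the patch; the function name keep_n_smallest is made up for the example.

import heapq

def keep_n_smallest(iterable, n):
    # Store negated values so the root of the min-heap corresponds to the
    # largest value currently kept.
    heap = []
    for x in iterable:
        if len(heap) < n:
            heapq.heappush(heap, -x)
        elif -x > heap[0]:
            # x is smaller than the largest kept value, so it replaces it.
            heapq.heapreplace(heap, -x)
    return sorted(-v for v in heap)

print(keep_n_smallest(range(9, -1, -1), 5))  # [0, 1, 2, 3, 4]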

class RDD(object):
"""
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
@@ -696,23 +763,53 @@ def top(self, num):
Note: It returns the list sorted in descending order.
>>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
[12]
>>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2)
>>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2)
[6, 5]
"""
def topIterator(iterator):
q = []
for k in iterator:
if len(q) < num:
heappush(q, k)
heapq.heappush(q, k)
else:
heappushpop(q, k)
heapq.heappushpop(q, k)
yield q

def merge(a, b):
return next(topIterator(a + b))

return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
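
To see the per-partition pattern of top() in isolation: a min-heap bounded at num keeps the num largest values seen so far, because heapq.heappushpop evicts the current minimum whenever a larger value arrives. A minimal standalone sketch, assuming a plain Python iterable rather than an RDD partition:

import heapq

def top_n(iterator, num):
    q = []
    for k in iterator:
        if len(q) < num:
            heapq.heappush(q, k)
        else:
            # Push k, then pop the smallest; q[0] is evicted only when k is larger.
            heapq.heappushpop(q, k)
    return sorted(q, reverse=True)

print(top_n([10, 4, 2, 12, 3], 2))  # [12, 10]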

def takeOrdered(self, num, key=None):
"""
Get the N smallest elements from an RDD, ordered in ascending order or as
specified by the optional key function.

>>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
[1, 2, 3, 4, 5, 6]
>>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x)
[10, 9, 7, 6, 5, 4]
"""

def topNKeyedElems(iterator, key_=None):
q = MaxHeapQ(num)
for k in iterator:
if key_ is not None:
k = (key_(k), k)
q.insert(k)
yield q.getElements()

def unKey(x, key_=None):
if key_ is not None:
x = [i[1] for i in x]
return x

def merge(a, b):
return next(topNKeyedElems(a + b))
result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge)
return sorted(unKey(result, key), key=key)
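
The key handling relies on wrapping each element as a (key(x), x) tuple, so tuple comparison orders elements by their key first, and unKey strips the wrapper again after the final merge. A small illustration of that keyed-tuple trick, using sorted() in place of the bounded heap:

data = [10, 1, 2, 9, 3]
key = lambda x: -x                        # "largest first" expressed as a key

# Wrapping each element makes comparisons be driven by the key.
keyed = [(key(x), x) for x in data]       # (-10, 10), (-1, 1), (-2, 2), ...
smallest_by_key = sorted(keyed)[:3]       # the bounded heap keeps this same set
print([x for (_, x) in smallest_by_key])  # [10, 9, 3]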


def take(self, num):
"""
Take the first num elements of the RDD.