SPARK-1259 Make RDD locally iterable
epahomov committed Mar 18, 2014
1 parent 33ecb17 commit 8be3dcf
Showing 4 changed files with 32 additions and 5 deletions.
18 changes: 17 additions & 1 deletion core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -17,7 +17,8 @@

package org.apache.spark.api.java

-import java.util.{Comparator, List => JList}
+import java.util.{Comparator, Iterator => JIterator, List => JList}
+import java.lang.{Iterable => JIterable}

import scala.Tuple2
import scala.collection.JavaConversions._
@@ -281,6 +282,21 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}

+/**
+ * Return an iterable that contains all of the elements in this RDD.
+ *
+ * The iterator will consume as much memory as the largest partition in this RDD.
+ */
+def toLocallyIterable(): JIterable[T] = {
+  new JIterable[T] {
+    // JavaConversions is already imported at the top of this file.
+    def iterator(): JIterator[T] = asJavaIterator(rdd.toLocallyIterable.iterator)
+  }
+}


/**
* Return an array that contains all of the elements in this RDD.
*/
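The wrapper above adapts the Scala-side `toLocallyIterable` for Java callers: it hides the Scala `Stream`'s iterator behind a `java.lang.Iterable` whose `iterator()` hands back a `java.util.Iterator`. A minimal standalone sketch of the same adaptation pattern (not Spark's actual code; `asJIterable` is a hypothetical helper name):

import java.lang.{Iterable => JIterable}
import java.util.{Iterator => JIterator}

import scala.collection.JavaConversions.asJavaIterator

// Hypothetical helper mirroring the wrapper above: expose a lazy Scala
// Stream to Java callers as a java.lang.Iterable. The stream argument is
// by-name, so nothing is computed until iteration actually begins.
def asJIterable[T](stream: => Stream[T]): JIterable[T] = new JIterable[T] {
  // Each call to iterator() starts again from the first element.
  def iterator(): JIterator[T] = asJavaIterator(stream.iterator)
}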
8 changes: 5 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -668,11 +668,13 @@
*
* The iterator will consume as much memory as the largest partition in this RDD.
*/
-def toStream(): Stream[T] = {
-  def collectPartition(p: Int): Array[T] = sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head
+def toLocallyIterable: Stream[T] = {
+  def collectPartition(p: Int): Array[T] = {
+    sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head
+  }
var buffer = Stream.empty[T]
for (p <- 0 until this.partitions.length) {
-  buffer = buffer #::: {
+  buffer = buffer append {
collectPartition(p).toStream
}
}
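The laziness of `Stream` is what makes this local iteration cheap: `Stream#append` takes its argument by name, so `collectPartition(p)` only runs a job when iteration actually reaches partition `p`. A hedged usage sketch (assumes a running `SparkContext` named `sc`, as in the Spark shell):

// Usage sketch, assuming a live SparkContext `sc` (e.g. the Spark shell).
val rdd = sc.parallelize(1 to 1000, numSlices = 10)

// Forcing only the first five elements runs a job for the first partition
// alone; driver memory stays bounded by the largest single partition
// rather than the whole RDD, unlike collect().
println(rdd.toLocallyIterable.take(5).toList)  // List(1, 2, 3, 4, 5)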
9 changes: 9 additions & 0 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -24,6 +24,7 @@

import scala.Tuple2;

+import com.google.common.collect.Lists;
import com.google.common.base.Optional;
import com.google.common.base.Charsets;
import com.google.common.io.Files;
@@ -149,6 +150,14 @@ public void call(String s) {
Assert.assertEquals(2, foreachCalls);
}

+@Test
+public void toLocallyIterable() {
+  List<Integer> correct = Arrays.asList(1, 2, 3, 4);
+  JavaRDD<Integer> rdd = sc.parallelize(correct);
+  List<Integer> result = Lists.newArrayList(rdd.toLocallyIterable());
+  Assert.assertEquals(correct, result);
+}

@SuppressWarnings("unchecked")
@Test
public void lookup() {
2 changes: 1 addition & 1 deletion core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -33,7 +33,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
test("basic operations") {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
-  assert(nums.toStream().toList === List(1, 2, 3, 4))
+  assert(nums.toLocallyIterable.toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
assert(dups.distinct().count() === 4)
assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses?
