diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index af0114bee3f49..a9a1f6046cb82 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.api.java
 
-import java.util.{Comparator, List => JList}
+import java.util.{Comparator, Iterator => JIterator, List => JList}
+import java.lang.{Iterable => JIterable}
 
 import scala.Tuple2
 import scala.collection.JavaConversions._
@@ -281,6 +282,21 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     new java.util.ArrayList(arr)
   }
 
+  /**
+   * Return an iterable that contains all of the elements in this RDD.
+   *
+   * Iterating over the result consumes as much memory as the largest partition in this RDD.
+   */
+  def toLocallyIterable(): JIterable[T] = {
+    new JIterable[T]() {
+      def iterator(): JIterator[T] = {
+        import scala.collection.JavaConversions._
+        asJavaIterator(rdd.toLocallyIterable.iterator)
+      }
+    }
+  }
+
+
   /**
    * Return an array that contains all of the elements in this RDD.
    */
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1aa494f6a4125..d787cf110c394 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -668,11 +668,13 @@ abstract class RDD[T: ClassTag](
    *
    * In case of iterating it consumes memory as the biggest partition in cluster.
    */
-  def toStream(): Stream[T] = {
-    def collectPartition(p: Int): Array[T] = sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head
+  def toLocallyIterable: Stream[T] = {
+    def collectPartition(p: Int): Array[T] = {
+      sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head
+    }
     var buffer = Stream.empty[T]
     for (p <- 0 until this.partitions.length) {
-      buffer = buffer #::: {
+      buffer = buffer append {
         collectPartition(p).toStream
       }
     }
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index c7d0e2d577726..4de856f498bec 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -24,6 +24,7 @@
 
 import scala.Tuple2;
 
+import com.google.common.collect.Lists;
 import com.google.common.base.Optional;
 import com.google.common.base.Charsets;
 import com.google.common.io.Files;
@@ -149,6 +150,14 @@ public void call(String s) {
     Assert.assertEquals(2, foreachCalls);
   }
 
+  @Test
+  public void toLocallyIterable() {
+    List<Integer> correct = Arrays.asList(1, 2, 3, 4);
+    JavaRDD<Integer> rdd = sc.parallelize(correct);
+    List<Integer> result = Lists.newArrayList(rdd.toLocallyIterable());
+    Assert.assertEquals(correct, result);
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void lookup() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ef88f2bc467e3..3b3d9b83d85a1 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -33,7 +33,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
   test("basic operations") {
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
     assert(nums.collect().toList === List(1, 2, 3, 4))
-    assert(nums.toStream().toList === List(1, 2, 3, 4))
+    assert(nums.toLocallyIterable.toList === List(1, 2, 3, 4))
     val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
     assert(dups.distinct().count() === 4)
     assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses?
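
Usage note (editorial context, not part of the patch): on the Scala side, toLocallyIterable builds a lazy Stream, and Stream#append takes its argument by name, so the job that collects partition p is only submitted when iteration actually reaches that partition. A minimal driver-side sketch, assuming an existing SparkContext named sc; the RDD contents and partition count are illustrative only:

    // Build an RDD spread across many partitions.
    val rdd = sc.parallelize(1 to 1000000, 100)

    // Unlike collect(), which materializes every partition on the driver at
    // once, iterating this result fetches one partition at a time, so the
    // driver needs memory for roughly the largest single partition.
    for (x <- rdd.toLocallyIterable) {
      println(x)
    }

One caveat: Scala's Stream memoizes cells it has already produced, so holding a reference to the head of the returned stream while iterating will retain every partition fetched so far on the driver.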