From 249abdec9a964178675aab6eb91dca9adc112872 Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Tue, 25 Mar 2014 17:45:13 -0700
Subject: [PATCH] org.apache.spark.rdd.PairRDDFunctionsSuite passes

---
 .../org/apache/spark/rdd/PairRDDFunctions.scala      |  9 ++++++---
 .../org/apache/spark/rdd/PairRDDFunctionsSuite.scala | 12 ++++++------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 35e10c14f43c0..58afa72e099bb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -298,7 +298,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
    */
   def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
     this.cogroup(other, partitioner).flatMapValues { case (vs, ws) =>
-      for (v <- vs; w <- ws) yield (v, w)
+      val wlist = ws.toList
+      for (v <- vs; w <- wlist.iterator) yield (v, w)
     }
   }
 
@@ -313,7 +314,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
       if (ws.isEmpty) {
         vs.map(v => (v, None))
       } else {
-        for (v <- vs; w <- ws) yield (v, Some(w))
+        val wlist = ws.toList
+        for (v <- vs; w <- wlist.iterator) yield (v, Some(w))
       }
     }
   }
@@ -330,7 +332,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
       if (vs.isEmpty) {
         ws.map(w => (None, w))
       } else {
-        for (v <- vs; w <- ws) yield (Some(v), w)
+        val wlist = ws.toList
+        for (v <- vs; w <- wlist) yield (Some(v), w)
      }
    }
  }
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index f9e994b13dfbc..8f3e6bd21b752 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -225,11 +225,12 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
     val joined = rdd1.groupWith(rdd2).collect()
     assert(joined.size === 4)
-    assert(joined.toSet === Set(
-      (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
-      (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
-      (3, (ArrayBuffer(1), ArrayBuffer())),
-      (4, (ArrayBuffer(), ArrayBuffer('w')))
+    val joinedSet = joined.map(x => (x._1, (x._2._1.toList, x._2._2.toList))).toSet
+    assert(joinedSet === Set(
+      (1, (List(1, 2), List('x'))),
+      (2, (List(1), List('y', 'z'))),
+      (3, (List(1), List())),
+      (4, (List(), List('w')))
     ))
   }
 
@@ -447,4 +448,3 @@ class ConfigTestFormat() extends FakeFormat() with Configurable {
     super.getRecordWriter(p1)
   }
 }
-
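
Note (not part of the patch): this change appears to accompany a switch in which cogroup/groupWith expose grouped values as Iterable rather than ArrayBuffer, so join, leftOuterJoin, and rightOuterJoin materialize ws with toList before nesting it inside the loop over vs, and the test compares List values instead of ArrayBuffer. The sketch below is a standalone illustration of why that materialization matters; the names IterableJoinSketch and oneShot are hypothetical and only the general Iterable behavior is assumed, not Spark's actual cogroup internals.

// Self-contained sketch: nesting a possibly one-shot Iterable inside a
// for-comprehension silently drops pairs unless it is copied first.
object IterableJoinSketch {
  // Hypothetical one-shot Iterable: every call to iterator returns the same
  // underlying Iterator, so the collection can only be walked once.
  def oneShot[A](elems: A*): Iterable[A] = new Iterable[A] {
    private val it = elems.iterator
    def iterator: Iterator[A] = it
  }

  def main(args: Array[String]): Unit = {
    val vs = Seq(1, 2)

    // Without materialization: the inner traversal exhausts ws on the first
    // pass over vs, so v = 2 contributes no pairs.
    val wsBroken = oneShot('x', 'y')
    val broken = for (v <- vs; w <- wsBroken) yield (v, w)
    println(broken) // List((1,x), (1,y))

    // Patched approach: copy ws into a List once, then re-traverse it freely.
    val wsFixed = oneShot('x', 'y')
    val wlist = wsFixed.toList
    val fixed = for (v <- vs; w <- wlist) yield (v, w)
    println(fixed) // List((1,x), (1,y), (2,x), (2,y))
  }
}

The toList copy trades a small amount of memory per key for the guarantee that the inner collection can be traversed once per element of the outer one, which is what the patched join methods rely on.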