diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 7021a339e879b..658e8c8b89318 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -29,15 +29,16 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleHandle - -private[spark] sealed trait CoGroupSplitDep extends Serializable +/** The references to rdd and splitIndex are transient because redundant information is stored + * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from + * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the + * task closure. */ private[spark] case class NarrowCoGroupSplitDep( - rdd: RDD[_], - splitIndex: Int, + @transient rdd: RDD[_], + @transient splitIndex: Int, var split: Partition - ) extends CoGroupSplitDep { + ) extends Serializable { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -47,9 +48,16 @@ private[spark] case class NarrowCoGroupSplitDep( } } -private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep - -private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) +/** + * Stores information about the narrow dependencies used by a CoGroupedRdd. + * + * @param narrowDeps maps to the dependencies variable in the parent RDD: for each one to one + * dependency in dependencies, narrowDeps has a NarrowCoGroupSplitDep (describing + * the partition for that dependency) at the corresponding index. The size of + * narrowDeps should always be equal to the number of parents. + */ +private[spark] class CoGroupPartition( + idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx @@ -105,9 +113,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: // Assume each RDD contributed a single dependency, and get it dependencies(j) match { case s: ShuffleDependency[_, _, _] => - new ShuffleCoGroupSplitDep(s.shuffleHandle) + None case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))) } }.toArray) } @@ -120,20 +128,21 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.length + val numRdds = dependencies.length // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] - for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => + for ((dep, depNum) <- dependencies.zipWithIndex) dep match { + case oneToOneDependency: OneToOneDependency[Product2[K, Any]] => + val dependencyPartition = split.narrowDeps(depNum).get.split // Read them from the parent - val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] + val it = oneToOneDependency.rdd.iterator(dependencyPartition, context) rddIterators += ((it, depNum)) - case ShuffleCoGroupSplitDep(handle) => + case shuffleDependency: ShuffleDependency[_, _, _] => // Read map outputs of shuffle val it = SparkEnv.get.shuffleManager - .getReader(handle, split.index, split.index + 1, context) + .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context) .read() rddIterators += ((it, depNum)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index e9d745588ee9a..633aeba3bbae6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -81,9 +81,9 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { case s: ShuffleDependency[_, _, _] => - new ShuffleCoGroupSplitDep(s.shuffleHandle) + None case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))) } }.toArray) } @@ -105,20 +105,26 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit): Unit = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) + def integrate(depNum: Int, op: Product2[K, V] => Unit) = { + dependencies(depNum) match { + case oneToOneDependency: OneToOneDependency[_] => + val dependencyPartition = partition.narrowDeps(depNum).get.split + oneToOneDependency.rdd.iterator(dependencyPartition, context) + .asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - case ShuffleCoGroupSplitDep(handle) => - val iter = SparkEnv.get.shuffleManager - .getReader(handle, partition.index, partition.index + 1, context) - .read() - iter.foreach(op) + case shuffleDependency: ShuffleDependency[_, _, _] => + val iter = SparkEnv.get.shuffleManager + .getReader( + shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context) + .read() + iter.foreach(op) + } } + // the first dep is rdd1; add all values to the map - integrate(partition.deps(0), t => getSeq(t._1) += t._2) + integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys - integrate(partition.deps(1), t => map.remove(t._1)) + integrate(1, t => map.remove(t._1)) map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten }