diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 9f7ebae3e9af3..b7ec85fb8a5a1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1080,6 +1080,21 @@ abstract class RDD[T: ClassTag](
       var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
       var numPartitions = partiallyAggregated.partitions.length
       val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+
+      // Do one level of aggregation based on executorId before starting the tree
+      // NOTE: exclude the driver from list of executors
+      val numExecutors = math.max(context.getExecutorStorageStatus.length - 1, 1)
+      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { case (idx, iter) =>
+        def isAllDigits(x: String) = x forall Character.isDigit
+        val execId = SparkEnv.get.executorId
+        if (isAllDigits(execId)) {
+          iter.map((execId.toInt, _))
+        } else {
+          iter.map((execId.hashCode, _))
+        }
+      }.reduceByKey(new HashPartitioner(numExecutors), combOp).values
+      numPartitions = numExecutors
+
       // If creating an extra level doesn't help reduce
       // the wall-clock time, we stop tree aggregation.
       while (numPartitions > scale + numPartitions / scale) {