[SPARK-10797] RDD's coalesce should not write out the temporary key #8979

Closed · wants to merge 4 commits
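With shuffle = true, RDD.coalesce spreads elements across partitions by pairing each one with a temporary Int key produced by distributePartition and hash-partitioning on that key. Before this patch the key was serialized into the shuffle files and then discarded on the read side with .values; this change threads a dropKeys flag from the ShuffledRDD through ShuffleDependency and ShuffleHandle down to the sort-shuffle writers and the shuffle reader, so only the values are written and read. For orientation before the diffs, here is a minimal, hypothetical driver program that exercises the affected code path (the object and app names are illustrative, not part of this patch):

import org.apache.spark.{SparkConf, SparkContext}

object CoalesceShuffleExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[4]").setAppName("coalesce-demo")
    val sc = new SparkContext(conf)
    // coalesce(shuffle = true) goes through the ShuffledRDD[Int, T, T] path changed in
    // this patch: each element is paired with a throwaway Int key, hash-partitioned,
    // and (after this change) only the value is written to the shuffle files.
    val data = sc.parallelize(1 to 1000000, 100)
    val repartitioned = data.coalesce(10, shuffle = true)
    println(repartitioned.partitions.length) // 10
    sc.stop()
  }
}

Nothing changes in the user-facing API; only the bytes written by the shuffle behind coalesce shrink.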
@@ -17,27 +17,28 @@

package org.apache.spark.shuffle.sort;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;

import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;

import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.Partitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.*;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.TempShuffleBlockId;
import org.apache.spark.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;

/**
* This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
@@ -73,6 +74,7 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
private final Serializer serializer;
private final boolean dropKeys;

/** Array of file writers, one for each partition */
private DiskBlockObjectWriter[] partitionWriters;
@@ -82,7 +84,8 @@ public BypassMergeSortShuffleWriter(
BlockManager blockManager,
Partitioner partitioner,
ShuffleWriteMetrics writeMetrics,
Serializer serializer) {
Serializer serializer,
boolean dropKeys) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -91,10 +94,11 @@ public BypassMergeSortShuffleWriter(
this.partitioner = partitioner;
this.writeMetrics = writeMetrics;
this.serializer = serializer;
this.dropKeys = dropKeys;
}

@Override
public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
public void insertAll(Iterator<Product2<K, V>> records, boolean dropKeys) throws IOException {
assert (partitionWriters == null);
if (!records.hasNext()) {
return;
@@ -118,7 +122,12 @@ public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
while (records.hasNext()) {
final Product2<K, V> record = records.next();
final K key = record._1();
partitionWriters[partitioner.getPartition(key)].write(key, record._2());

if (dropKeys) {
partitionWriters[partitioner.getPartition(key)].write(record._2());
} else {
partitionWriters[partitioner.getPartition(key)].write(key, record._2());
}
}

for (DiskBlockObjectWriter writer : partitionWriters) {
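The branch above is where the saving happens on the bypass-merge path: with dropKeys set, each per-partition DiskBlockObjectWriter receives only record._2(), so the temporary Int key never reaches disk. A standalone sketch of the effect, using plain Java serialization as a stand-in for Spark's serializer (DropKeySizeSketch and the sample values are hypothetical):

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object DropKeySizeSketch {
  private def serializedSize(write: ObjectOutputStream => Unit): Int = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    write(out)
    out.close()
    bytes.size()
  }

  def main(args: Array[String]): Unit = {
    val key = 42                        // the temporary partition index from distributePartition
    val value = "some record payload"   // the element the user actually asked to coalesce
    val withKey = serializedSize { out =>
      out.writeObject(Int.box(key))
      out.writeObject(value)
    }
    val valueOnly = serializedSize(out => out.writeObject(value))
    println(s"key+value: $withKey bytes, value only: $valueOnly bytes")
  }
}

For coalesce the key is a small Int, but it is written once per record, so the overhead scales with the number of records shuffled.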
@@ -17,23 +17,22 @@

package org.apache.spark.shuffle.sort;

import java.io.File;
import java.io.IOException;

import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
import org.apache.spark.storage.BlockId;
import scala.Product2;
import scala.collection.Iterator;

import org.apache.spark.annotation.Private;
import org.apache.spark.TaskContext;
import org.apache.spark.storage.BlockId;
import java.io.File;
import java.io.IOException;

/**
* Interface for objects that {@link SortShuffleWriter} uses to write its output files.
*/
@Private
public interface SortShuffleFileWriter<K, V> {

void insertAll(Iterator<Product2<K, V>> records) throws IOException;
void insertAll(Iterator<Product2<K, V>> records, boolean dropKeys) throws IOException;

/**
* Write all the data added into this shuffle sorter into a file in the disk store. This is
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/Dependency.scala
@@ -73,7 +73,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false)
val mapSideCombine: Boolean = false,
val dropKeys: Boolean = false)
extends Dependency[Product2[K, V]] {

override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]]
@@ -88,7 +89,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val shuffleId: Int = _rdd.context.newShuffleId()

val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
shuffleId, _rdd.partitions.size, this)
shuffleId, _rdd.partitions.size, this).setDropKeys(dropKeys)

_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
8 changes: 5 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -398,9 +398,11 @@ abstract class RDD[T: ClassTag](

// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(
new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
new HashPartitioner(numPartitions)),
numPartitions).values
new ShuffledRDD[Int, T, T](
mapPartitionsWithIndex(distributePartition),
new HashPartitioner(numPartitions)
).setDropKeys(true).mapPartitions(_.asInstanceOf[Iterator[T]]),
numPartitions)
} else {
new CoalescedRDD(this, numPartitions)
}
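In the coalesce hunk above, the old code recovered values with .values on an RDD of (Int, T) pairs; with dropKeys the shuffle reader already yields bare values, and mapPartitions(_.asInstanceOf[Iterator[T]]) only fixes up the static element type. A small standalone sketch of why such a cast is free at runtime (hypothetical example, not Spark code):

object IteratorCastSketch {
  def main(args: Array[String]): Unit = {
    // Generic type arguments are erased on the JVM, so the cast below does nothing at
    // runtime; elements are only checked when they are actually consumed.
    val raw: Iterator[Any] = Iterator("a", "b", "c")
    val typed: Iterator[String] = raw.asInstanceOf[Iterator[String]]
    println(typed.mkString(", "))
  }
}

The cast only works out because the writer and the reader are switched to the value-only format together; a mismatch would surface as a ClassCastException when elements are consumed.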
11 changes: 10 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -52,6 +52,8 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](

private var mapSideCombine: Boolean = false

private var dropKeys: Boolean = false

/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = {
this.serializer = Option(serializer)
@@ -76,8 +78,15 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
this
}

/** Set dropKeys flag for RDD's shuffle. */
def setDropKeys(dropKeys: Boolean): ShuffledRDD[K, V, C] = {
this.dropKeys = dropKeys
this
}

override def getDependencies: Seq[Dependency[_]] = {
List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator,
mapSideCombine, dropKeys))
}

override val partitioner = Some(part)
@@ -61,20 +61,24 @@ private[spark] class BlockStoreShuffleReader[K, C](
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
if (!handle.dropKeys) {
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
} else {
serializerInstance.deserializeStream(wrappedStream).asIterator
}
}

// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
val metricIter = CompletionIterator[Any, Iterator[Any]](
recordIter.map(record => {
readMetrics.incRecordsRead(1)
record
}),
context.taskMetrics().updateShuffleReadMetrics())

// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val interruptibleIter = new InterruptibleIterator[Any](context, metricIter)

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
@@ -99,7 +103,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
sorter.insertAll(aggregatedIter, handle.dropKeys)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.internalMetricsToAccumulators(
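On the read side, the change above switches the deserialization stream from asKeyValueIterator to asIterator when keys were dropped, so the record framing matches what the writers produced. A self-contained illustration of that framing contract, again using plain Java serialization rather than Spark's SerializerInstance (the object name and sample data are hypothetical):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object ValueOnlyStreamSketch {
  def main(args: Array[String]): Unit = {
    // Writer side with dropKeys = true: one object per record, no interleaved keys.
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    Seq("x", "y", "z").foreach(v => out.writeObject(v))
    out.close()

    // Reader side: consume the stream object-by-object (the analogue of asIterator);
    // reading it pairwise (the analogue of asKeyValueIterator) would misframe records.
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val values = (1 to 3).iterator.map(_ => in.readObject().asInstanceOf[String])
    println(values.mkString(", "))
    in.close()
  }
}

Because the iterator elements are no longer guaranteed to be pairs, the surrounding metric and interruptible iterators are also widened from (Any, Any) to Any, as the rest of this hunk shows.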
13 changes: 12 additions & 1 deletion core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala
@@ -25,4 +25,15 @@ import org.apache.spark.annotation.DeveloperApi
* @param shuffleId ID of the shuffle
*/
@DeveloperApi
abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {}
abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {
private var _dropKeys: Boolean = false

def setDropKeys(dropKeys: Boolean): ShuffleHandle = {
this._dropKeys = dropKeys
this
}

def dropKeys: Boolean = {
_dropKeys
}
}
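Taken together, ShuffledRDD.setDropKeys, the new ShuffleDependency parameter, and the setter added to ShuffleHandle here form the plumbing that carries the flag from the RDD that requested it down to the writers and the reader. A condensed toy model of that flow (stand-in classes, not Spark's real types or registration logic):

object DropKeysFlowSketch {
  class Handle {
    private var _dropKeys = false
    def setDropKeys(b: Boolean): Handle = { _dropKeys = b; this }
    def dropKeys: Boolean = _dropKeys
  }

  class Dep(dropKeys: Boolean) {
    // Mirrors registerShuffle(...).setDropKeys(dropKeys) in Dependency.scala.
    val handle: Handle = (new Handle).setDropKeys(dropKeys)
  }

  class ShuffledLikeRDD {
    private var dropKeys = false
    def setDropKeys(b: Boolean): this.type = { dropKeys = b; this }
    // Mirrors getDependencies in ShuffledRDD.scala.
    def dependency: Dep = new Dep(dropKeys)
  }

  def main(args: Array[String]): Unit = {
    val dep = (new ShuffledLikeRDD).setDropKeys(true).dependency
    println(dep.handle.dropKeys) // true: the writers and the reader see the same flag
  }
}

Because the handle is what both SortShuffleWriter and BlockStoreShuffleReader consult, stamping the flag onto it at registration time keeps the write and read sides of a given shuffle consistent.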
@@ -52,8 +52,8 @@ private[spark] class SortShuffleWriter[K, V, C](
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
new ExternalSorter[K, V, C](dep.aggregator, Some(dep.partitioner),
dep.keyOrdering, dep.serializer, dep.shuffleHandle.dropKeys)
} else if (SortShuffleWriter.shouldBypassMergeSort(
SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
// If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
@@ -62,15 +62,15 @@ override def write(records: Iterator[Product2[K, V]]): Unit = {
// together the spilled files, which would happen with the normal code path. The downside is
// having multiple files open at a time and thus more memory allocated to buffers.
new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
writeMetrics, Serializer.getSerializer(dep.serializer))
writeMetrics, Serializer.getSerializer(dep.serializer), dep.shuffleHandle.dropKeys)
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
// if the operation being run is sortByKey.
new ExternalSorter[K, V, V](
aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
new ExternalSorter[K, V, V](aggregator = None, Some(dep.partitioner),
ordering = None, dep.serializer, dep.shuffleHandle.dropKeys)
}
sorter.insertAll(records)
sorter.insertAll(records, handle.dropKeys)

// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
@@ -185,6 +185,18 @@ private[spark] class DiskBlockObjectWriter(
recordWritten()
}

/**
* Writes a single object, without the separate key written by write(key, value).
*/
def write(obj: Any) {
if (!initialized) {
open()
}

objOut.writeObject(obj)
recordWritten()
}

override def write(b: Int): Unit = throw new UnsupportedOperationException()

override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
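The new single-argument write mirrors the existing key/value overload: it serializes one object instead of two but still counts as one record for the writer's bookkeeping. A rough stand-in showing that symmetry (plain Java serialization and AnyRef parameters are used only to keep the sketch self-contained; Spark's real class writes through its SerializationStream and updates ShuffleWriteMetrics):

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object WriteFramingSketch {
  class TinyObjectWriter {
    private val buffer = new ByteArrayOutputStream()
    private val objOut = new ObjectOutputStream(buffer)
    private var recordsWritten = 0L

    // Key/value framing used by the normal shuffle path.
    def write(key: AnyRef, value: AnyRef): Unit = {
      objOut.writeObject(key)
      objOut.writeObject(value)
      recordsWritten += 1
    }

    // Value-only framing used when keys are dropped.
    def write(obj: AnyRef): Unit = {
      objOut.writeObject(obj)
      recordsWritten += 1
    }

    def records: Long = recordsWritten
    def bytes: Int = { objOut.flush(); buffer.size() }
  }

  def main(args: Array[String]): Unit = {
    val w = new TinyObjectWriter
    w.write("key", "value") // two objects, one record
    w.write("value")        // one object, one record
    println(s"records = ${w.records}, bytes = ${w.bytes}")
  }
}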
@@ -91,7 +91,8 @@ private[spark] class ExternalSorter[K, V, C](
aggregator: Option[Aggregator[K, V, C]] = None,
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
serializer: Option[Serializer] = None)
serializer: Option[Serializer] = None,
dropKeys: Boolean = false)
extends Logging
with Spillable[WritablePartitionedPairCollection[K, C]]
with SortShuffleFileWriter[K, V] {
@@ -192,7 +193,7 @@ private[spark] class ExternalSorter[K, V, C](
*/
private[spark] def numSpills: Int = spills.size

override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
override def insertAll(records: Iterator[Product2[K, V]], dropKeys: Boolean): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined

@@ -670,7 +671,7 @@ private[spark] class ExternalSorter[K, V, C](
if (spills.isEmpty) {
// Case where we only have in-memory data
val collection = if (aggregator.isDefined) map else buffer
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
val it = collection.destructiveSortedWritablePartitionedIterator(comparator, dropKeys)
while (it.hasNext) {
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics.shuffleWriteMetrics.get)
@@ -689,7 +690,11 @@ private[spark] class ExternalSorter[K, V, C](
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics.shuffleWriteMetrics.get)
for (elem <- elements) {
writer.write(elem._1, elem._2)
if (dropKeys) {
writer.write(elem._2)
} else {
writer.write(elem._1, elem._2)
}
}
writer.commitAndClose()
val segment = writer.fileSegment()
@@ -128,7 +128,8 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](

override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity

override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]],
dropKeys: Boolean = false)
: WritablePartitionedIterator = {
sort(keyComparator)
new WritablePartitionedIterator {
@@ -45,14 +45,19 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
* returned in order of their partition ID and then the given comparator.
* This may destroy the underlying collection.
*/
def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]],
dropKeys: Boolean = false)
: WritablePartitionedIterator = {
val it = partitionedDestructiveSortedIterator(keyComparator)
new WritablePartitionedIterator {
private[this] var cur = if (it.hasNext) it.next() else null

def writeNext(writer: DiskBlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
if (dropKeys) {
writer.write(cur._2)
} else {
writer.write(cur._1._2, cur._2)
}
cur = if (it.hasNext) it.next() else null
}

@@ -111,9 +111,10 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
blockManager,
new HashPartitioner(7),
shuffleWriteMetrics,
serializer
serializer,
false
)
writer.insertAll(Iterator.empty)
writer.insertAll(Iterator.empty, false)
val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
assert(partitionLengths.sum === 0)
assert(outputFile.exists())
@@ -133,9 +134,10 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
blockManager,
new HashPartitioner(7),
shuffleWriteMetrics,
serializer
serializer,
false
)
writer.insertAll(records)
writer.insertAll(records, false)
assert(temporaryFilesCreated.nonEmpty)
val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
assert(partitionLengths.sum === outputFile.length())
@@ -152,15 +154,16 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
blockManager,
new HashPartitioner(7),
shuffleWriteMetrics,
serializer
serializer,
false
)
intercept[SparkException] {
writer.insertAll((0 until 100000).iterator.map(i => {
if (i == 99990) {
throw new SparkException("Intentional failure")
}
(i, i)
}))
}), false)
}
assert(temporaryFilesCreated.nonEmpty)
writer.stop()