Implement Mridul's ExternalAppendOnlyMap fixes in ExternalSorter too
Modified ExternalSorterSuite to also set a low object stream reset and
batch size, and verified that the suite failed before these changes and
passed after them.
mateiz committed Aug 1, 2014
1 parent: 0d6dad7 · commit: bda37bb
Showing 3 changed files with 120 additions and 48 deletions.
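The core of the change, carried over from Mridul's earlier ExternalAppendOnlyMap fix, is to read each serializer batch of a spill file through its own length-limited stream instead of trusting a single shared buffered stream: batch sizes are turned into cumulative offsets, the file is reopened and positioned at the start of each batch, and reads are capped at the batch boundary, so any bytes the serializer wrote past the last object in a batch (for example Java serialization's TC_RESET markers, SPARK-2792) are never consumed. A minimal sketch of that read pattern, with illustrative names rather than the committed code, assuming Guava's ByteStreams is available:

import java.io.{BufferedInputStream, File, FileInputStream, InputStream}

import com.google.common.io.ByteStreams

// Cumulative start offsets of each batch in the spill file; length = numBatches + 1.
def batchOffsets(batchSizes: Seq[Long]): Seq[Long] = batchSizes.scanLeft(0L)(_ + _)

// Open a fresh stream covering exactly one batch: seek to its start offset and cap reads at
// its end, so trailing bytes written after the last object in the batch are never read.
def openBatch(file: File, offsets: Seq[Long], batchId: Int): InputStream = {
  val start = offsets(batchId)
  val end = offsets(batchId + 1)
  require(end >= start, s"invalid batch offsets: start=$start, end=$end")
  val fileStream = new FileInputStream(file)
  fileStream.getChannel.position(start)
  new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
}

The resulting stream is then wrapped for compression and deserialization exactly as in the diff below, and a new one is opened whenever a batch has been fully read.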
@@ -404,7 +404,8 @@ class ExternalAppendOnlyMap[K, V, C](
* An iterator that returns (K, C) pairs in sorted order from an on-disk map
*/
private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
extends Iterator[(K, C)] {
extends Iterator[(K, C)]
{
private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1
assert(file.length() == batchOffsets(batchOffsets.length - 1))

@@ -26,7 +26,7 @@ import scala.collection.mutable
import com.google.common.io.ByteStreams

import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
import org.apache.spark.serializer.Serializer
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.BlockId

/**
@@ -273,13 +273,16 @@ private[spark] class ExternalSorter[K, V, C](
// Flush the disk writer's contents to disk, and update relevant variables.
// The writer is closed at the end of this process, and cannot be reused.
def flush() = {
writer.commitAndClose()
val bytesWritten = writer.bytesWritten
val w = writer
writer = null
w.commitAndClose()
val bytesWritten = w.bytesWritten
batchSizes.append(bytesWritten)
_diskBytesSpilled += bytesWritten
objectsWritten = 0
}

var success = false
try {
val it = collection.destructiveSortedIterator(partitionKeyComparator)
while (it.hasNext) {
@@ -299,13 +302,23 @@
}
if (objectsWritten > 0) {
flush()
} else if (writer != null) {
val w = writer
writer = null
w.revertPartialWritesAndClose()
}
success = true
} finally {
if (!success) {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
if (writer != null) {
writer.revertPartialWritesAndClose()
}
if (file.exists()) {
file.delete()
}
}
writer.close()
} catch {
case e: Exception =>
writer.close()
file.delete()
throw e
}

if (usingMap) {
@@ -472,36 +485,58 @@ private[spark] class ExternalSorter[K, V, C](
* partitions to be requested in order.
*/
private[this] class SpillReader(spill: SpilledFile) {
val fileStream = new FileInputStream(spill.file)
val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
// Serializer batch offsets; size will be batchSize.length + 1
val batchOffsets = spill.serializerBatchSizes.scanLeft(0L)(_ + _)

// Track which partition and which batch stream we're in. These will be the indices of
// the next element we will read. We'll also store the last partition read so that
// readNextPartition() can figure out what partition that was from.
var partitionId = 0
var indexInPartition = 0L
var batchStreamsRead = 0
var batchId = 0
var indexInBatch = 0
var lastPartitionId = 0

skipToNextPartition()

// An intermediate stream that reads from exactly one batch

// Intermediate file and deserializer streams that read from exactly one batch
// This guards against pre-fetching and other arbitrary behavior of higher level streams
var batchStream = nextBatchStream()
var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
var deserStream = serInstance.deserializeStream(compressedStream)
var fileStream: FileInputStream = null
var deserializeStream = nextBatchStream() // Also sets fileStream

var nextItem: (K, C) = null
var finished = false

/** Construct a stream that only reads from the next batch */
def nextBatchStream(): InputStream = {
if (batchStreamsRead < spill.serializerBatchSizes.length) {
batchStreamsRead += 1
ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
def nextBatchStream(): DeserializationStream = {
// Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
// we're still in a valid batch.
if (batchId < batchOffsets.length - 1) {
if (deserializeStream != null) {
deserializeStream.close()
fileStream.close()
deserializeStream = null
fileStream = null
}

val start = batchOffsets(batchId)
fileStream = new FileInputStream(spill.file)
fileStream.getChannel.position(start)
batchId += 1

val end = batchOffsets(batchId)

assert(end >= start, "start = " + start + ", end = " + end +
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))

val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream)
serInstance.deserializeStream(compressedStream)
} else {
// No more batches left; give an empty stream
bufferedStream
// No more batches left
cleanup()
null
}
}

@@ -525,27 +560,27 @@
* If no more pairs are left, return null.
*/
private def readNextItem(): (K, C) = {
if (finished) {
if (finished || deserializeStream == null) {
return null
}
val k = deserStream.readObject().asInstanceOf[K]
val c = deserStream.readObject().asInstanceOf[C]
val k = deserializeStream.readObject().asInstanceOf[K]
val c = deserializeStream.readObject().asInstanceOf[C]
lastPartitionId = partitionId
// Start reading the next batch if we're done with this one
indexInBatch += 1
if (indexInBatch == serializerBatchSize) {
batchStream = nextBatchStream()
compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
deserStream = serInstance.deserializeStream(compressedStream)
indexInBatch = 0
deserializeStream = nextBatchStream()
}
// Update the partition location of the element we're reading
indexInPartition += 1
skipToNextPartition()
// If we've finished reading the last partition, remember that we're done
if (partitionId == numPartitions) {
finished = true
deserStream.close()
if (deserializeStream != null) {
deserializeStream.close()
}
}
(k, c)
}
@@ -578,6 +613,31 @@ private[spark] class ExternalSorter[K, V, C](
item
}
}

// Clean up our open streams and put us in a state where we can't read any more data
def cleanup() {
batchId = batchOffsets.length // Prevent reading any other batch
val ds = deserializeStream
val fs = fileStream
deserializeStream = null
fileStream = null

if (ds != null) {
try {
ds.close()
} catch {
case e: IOException =>
// Make sure we at least close the file handle
if (fs != null) {
try { fs.close() } catch { case e2: IOException => }
}
throw e
}
}

// NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop().
// This should also be fixed in ExternalAppendOnlyMap.
}
}

/**
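Both the spill-writing path and the SpillReader above follow the same cleanup discipline: null out stream references before closing, so a failed close() cannot lead to reuse or a second close, and release the underlying file handle even if closing the deserialization stream throws. A standalone sketch of that pattern (illustrative names only, not the committed code):

import java.io.{Closeable, FileInputStream, IOException}

// Close the deserialization stream; if that fails, still try to release the raw file handle.
def closeBatchStreams(deser: Closeable, fileStream: FileInputStream): Unit = {
  if (deser != null) {
    try {
      deser.close() // normally also closes the wrapped file stream
    } catch {
      case e: IOException =>
        // Best effort: free the file handle before rethrowing.
        if (fileStream != null) {
          try { fileStream.close() } catch { case _: IOException => }
        }
        throw e
    }
  }
}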
@@ -25,6 +25,17 @@ import org.apache.spark._
import org.apache.spark.SparkContext._

class ExternalSorterSuite extends FunSuite with LocalSparkContext {
private def createSparkConf(loadDefaults: Boolean): SparkConf = {
val conf = new SparkConf(loadDefaults)
// Make the Java serializer write a reset instruction (TC_RESET) after each object to test
// for a bug we had with bytes written past the last object in a batch (SPARK-2792)
conf.set("spark.serializer.objectStreamReset", "0")
conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
// Ensure that we actually have multiple batches per spill file
conf.set("spark.shuffle.spill.batchSize", "10")
conf
}

test("empty data stream") {
val conf = new SparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -60,7 +71,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("few elements per partition") {
val conf = new SparkConf(false)
val conf = createSparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -102,7 +113,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("empty partitions with spilling") {
val conf = new SparkConf(false)
val conf = createSparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -127,7 +138,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("spilling in local cluster") {
val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
@@ -198,7 +209,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("spilling in local cluster with many reduce tasks") {
val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
@@ -269,7 +280,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("cleanup of intermediate files in sorter") {
val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -290,7 +301,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("cleanup of intermediate files in sorter if there are errors") {
val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -311,7 +322,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("cleanup of intermediate files in shuffle") {
val conf = new SparkConf(false)
val conf = createSparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -326,7 +337,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("cleanup of intermediate files in shuffle with errors") {
val conf = new SparkConf(false)
val conf = createSparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -348,7 +359,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("no partial aggregation or sorting") {
val conf = new SparkConf(false)
val conf = createSparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -363,7 +374,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("partial aggregation without spill") {
val conf = new SparkConf(false)
val conf = createSparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -379,7 +390,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("partial aggregation with spill, no ordering") {
val conf = new SparkConf(false)
val conf = createSparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -395,7 +406,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("partial aggregation with spill, with ordering") {
val conf = new SparkConf(false)
val conf = createSparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -412,7 +423,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("sorting without aggregation, no spill") {
val conf = new SparkConf(false)
val conf = createSparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -429,7 +440,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("sorting without aggregation, with spill") {
val conf = new SparkConf(false)
val conf = createSparkConf(false)
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local", "test", conf)
@@ -446,7 +457,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("spilling with hash collisions") {
val conf = new SparkConf(true)
val conf = createSparkConf(true)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

@@ -503,7 +514,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("spilling with many hash collisions") {
val conf = new SparkConf(true)
val conf = createSparkConf(true)
conf.set("spark.shuffle.memoryFraction", "0.0001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

@@ -526,7 +537,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("spilling with hash collisions using the Int.MaxValue key") {
val conf = new SparkConf(true)
val conf = createSparkConf(true)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

@@ -547,7 +558,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}

test("spilling with null keys and values") {
val conf = new SparkConf(true)
val conf = createSparkConf(true)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

