diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index e6df70dbdb9f2..9a60d94af0cc5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -50,42 +50,39 @@ private[spark] class HashShuffleReader[K, C]( val serializerInstance = ser.newInstance() // Create a key/value iterator for each stream - val recordIterator = wrappedStreams.flatMap { wrappedStream => + val recordIter = wrappedStreams.flatMap { wrappedStream => serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator } + // Update the context task metrics for each record read. val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - // Update read metrics for each record materialized - val metricIter = new InterruptibleIterator[(Any, Any)](context, recordIterator) { - override def next(): (Any, Any) = { + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map(record => { readMetrics.incRecordsRead(1) - delegate.next() - } - } + record + }), + context.taskMetrics().updateShuffleReadMetrics()) - val iter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](metricIter, { - context.taskMetrics().updateShuffleReadMetrics() - }) + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { // We are reading values that are already combined - val combinedKeyValuesIterator = iter.asInstanceOf[Iterator[(K, C)]] - new InterruptibleIterator(context, - dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)) + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) } else { // We don't know the value type, but also don't care -- the dependency *should* // have made sure its compatible w/ this aggregator, which will convert the value // type to the combined type C - val keyValuesIterator = iter.asInstanceOf[Iterator[(K, Nothing)]] - new InterruptibleIterator(context, - dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)) + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") // Convert the Product2s to pairs since this is what downstream RDDs currently expect - iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) } // Sort the output if there is a sort ordering defined. diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index f7dc651e6d5d0..89f6713946b4e 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -28,10 +28,10 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.apache.spark.{SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.{SparkFunSuite, TaskContextImpl} class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { @@ -61,11 +61,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Create a mock managed buffer for testing def createMockManagedBuffer(): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) - when(mockManagedBuffer.createInputStream()).thenAnswer(new Answer[InputStream] { - override def answer(invocation: InvocationOnMock): InputStream = { - mock(classOf[InputStream]) - } - }) + when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream])) mockManagedBuffer } @@ -76,9 +72,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure blockManager.getBlockData would return the blocks val localBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) localBlocks.foreach { case (blockId, buf) => doReturn(buf).when(blockManager).getBlockData(meq(blockId)) } @@ -86,9 +82,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val remoteBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer()) val transfer = createMockTransfer(remoteBlocks) @@ -109,13 +104,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val (blockId, subIterator) = iterator.next() - assert(subIterator.isSuccess, + val (blockId, inputStream) = iterator.next() + assert(inputStream.isSuccess, s"iterator should have 5 elements defined but actually has $i elements") // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) - val wrappedInputStream = new BufferReleasingInputStream(mock(classOf[InputStream]), iterator) + val wrappedInputStream = new BufferReleasingInputStream(inputStream.get, iterator) verify(mockBuf, times(0)).release() wrappedInputStream.close() verify(mockBuf, times(1)).release()