diff --git a/assembly/pom.xml b/assembly/pom.xml
index 71333f734d9a1..3f955997fd925 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r39</version>
+    <version>2.4.1-kylin-r40</version>
    <relativePath>../pom.xml</relativePath>
diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml
index 292da55179322..5828d2e5d512a 100644
--- a/common/kvstore/pom.xml
+++ b/common/kvstore/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r39</version>
+    <version>2.4.1-kylin-r40</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index 22ad07bf581c7..a1a884a91433c 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r39</version>
+    <version>2.4.1-kylin-r40</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml
index 5f7784d21e236..b1db8a14d05bd 100644
--- a/common/network-shuffle/pom.xml
+++ b/common/network-shuffle/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r39</version>
+    <version>2.4.1-kylin-r40</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 098fa7974b87b..d24c87e035ca5 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -197,48 +197,60 @@ public Map<String, Metric> getMetrics() {
}
}
+ private boolean isShuffleBlock(String[] blockIdParts) {
+ // length == 4: ShuffleBlockId
+ // length == 5: ContinuousShuffleBlockId
+ return (blockIdParts.length == 4 || blockIdParts.length == 5) &&
+ blockIdParts[0].equals("shuffle");
+ }
+
private class ManagedBufferIterator implements Iterator<ManagedBuffer> {
private int index = 0;
private final String appId;
private final String execId;
private final int shuffleId;
- // An array containing mapId and reduceId pairs.
- private final int[] mapIdAndReduceIds;
+ // A flattened array of (mapId, reduceId, numBlocks) tuples.
+ private final int[] shuffleBlockIds;
ManagedBufferIterator(String appId, String execId, String[] blockIds) {
this.appId = appId;
this.execId = execId;
String[] blockId0Parts = blockIds[0].split("_");
- if (blockId0Parts.length != 4 || !blockId0Parts[0].equals("shuffle")) {
+ if (!isShuffleBlock(blockId0Parts)) {
throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[0]);
}
this.shuffleId = Integer.parseInt(blockId0Parts[1]);
- mapIdAndReduceIds = new int[2 * blockIds.length];
+ shuffleBlockIds = new int[3 * blockIds.length];
for (int i = 0; i < blockIds.length; i++) {
String[] blockIdParts = blockIds[i].split("_");
- if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+ if (!isShuffleBlock(blockIdParts)) {
throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
}
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
- ", got:" + blockIds[i]);
+ ", got:" + blockIds[i]);
+ }
+ shuffleBlockIds[3 * i] = Integer.parseInt(blockIdParts[2]);
+ shuffleBlockIds[3 * i + 1] = Integer.parseInt(blockIdParts[3]);
+ if (blockIdParts.length == 4) {
+ shuffleBlockIds[3 * i + 2] = 1;
+ } else {
+ shuffleBlockIds[3 * i + 2] = Integer.parseInt(blockIdParts[4]);
}
- mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
- mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
}
}
@Override
public boolean hasNext() {
- return index < mapIdAndReduceIds.length;
+ return index < shuffleBlockIds.length;
}
@Override
public ManagedBuffer next() {
final ManagedBuffer block = blockManager.getBlockData(appId, execId, shuffleId,
- mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
- index += 2;
+ shuffleBlockIds[index], shuffleBlockIds[index + 1], shuffleBlockIds[index + 2]);
+ index += 3;
metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
return block;
}
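
Note on the handler change above: the parsing rule implies two block id name formats, the existing four-part "shuffle_<shuffleId>_<mapId>_<reduceId>" (ShuffleBlockId) and a five-part "shuffle_<shuffleId>_<mapId>_<reduceId>_<numBlocks>" (ContinuousShuffleBlockId). The sketch below is not patch code; the object and method names are invented. It only reproduces the parsing rule and the stride-3 [mapId, reduceId, numBlocks] layout that ManagedBufferIterator walks.

object BlockIdLayoutSketch {
  // Parse "shuffle_<shuffleId>_<mapId>_<reduceId>[_<numBlocks>]" into
  // (shuffleId, mapId, reduceId, numBlocks); a 4-part id means a single block.
  def parse(blockId: String): (Int, Int, Int, Int) = {
    val parts = blockId.split("_")
    require((parts.length == 4 || parts.length == 5) && parts(0) == "shuffle",
      s"Unexpected shuffle block id format: $blockId")
    val numBlocks = if (parts.length == 4) 1 else parts(4).toInt
    (parts(1).toInt, parts(2).toInt, parts(3).toInt, numBlocks)
  }

  def main(args: Array[String]): Unit = {
    val ids = Seq("shuffle_2_7_3", "shuffle_2_7_3_4")
    // Flatten into the stride-3 int array the iterator walks: mapId, reduceId, numBlocks.
    val flat = ids.map(parse).flatMap { case (_, mapId, reduceId, numBlocks) =>
      Seq(mapId, reduceId, numBlocks)
    }
    println(flat.mkString(", ")) // 7, 3, 1, 7, 3, 4
  }
}
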
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index 0b7a27402369d..9589544e3ab03 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -162,21 +162,22 @@ public void registerExecutor(
}
/**
- * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions
- * about how the hash and sort based shuffles store their data.
+ * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId, numBlocks).
+ * We make assumptions about how the hash and sort based shuffles store their data.
*/
public ManagedBuffer getBlockData(
- String appId,
- String execId,
- int shuffleId,
- int mapId,
- int reduceId) {
+ String appId,
+ String execId,
+ int shuffleId,
+ int mapId,
+ int reduceId,
+ int numBlocks) {
ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
if (executor == null) {
throw new RuntimeException(
- String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId));
+ String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId));
}
- return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId);
+ return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId, numBlocks);
}
/**
@@ -280,19 +281,19 @@ public boolean accept(File dir, String name) {
* and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId.
*/
private ManagedBuffer getSortBasedShuffleBlockData(
- ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) {
+ ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId, int numBlocks) {
File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir,
- "shuffle_" + shuffleId + "_" + mapId + "_0.index");
+ "shuffle_" + shuffleId + "_" + mapId + "_0.index");
try {
ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile);
- ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId);
+ ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId, numBlocks);
return new FileSegmentManagedBuffer(
- conf,
- getFile(executor.localDirs, executor.subDirsPerLocalDir,
- "shuffle_" + shuffleId + "_" + mapId + "_0.data"),
- shuffleIndexRecord.getOffset(),
- shuffleIndexRecord.getLength());
+ conf,
+ getFile(executor.localDirs, executor.subDirsPerLocalDir,
+ "shuffle_" + shuffleId + "_" + mapId + "_0.data"),
+ shuffleIndexRecord.getOffset(),
+ shuffleIndexRecord.getLength());
} catch (ExecutionException e) {
throw new RuntimeException("Failed to open file: " + indexFile, e);
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java
index 386738ece51a6..470e1040e97e5 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java
@@ -59,9 +59,9 @@ public int getSize() {
/**
* Get index offset for a particular reducer.
*/
- public ShuffleIndexRecord getIndex(int reduceId) {
+ public ShuffleIndexRecord getIndex(int reduceId, int numBlocks) {
long offset = offsets.get(reduceId);
- long nextOffset = offsets.get(reduceId + 1);
+ long nextOffset = offsets.get(reduceId + numBlocks);
return new ShuffleIndexRecord(offset, nextOffset - offset);
}
}
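
The one-line change above is what makes a batched read possible: with numBlocks consecutive partitions, the returned record spans from offsets(reduceId) to offsets(reduceId + numBlocks). A minimal sketch with a worked example (not patch code; names invented):

object IndexLookupSketch {
  // (offset, length) of the segment covering numBlocks consecutive partitions
  // starting at reduceId, given the cumulative offsets read from an index file.
  def getIndex(offsets: Array[Long], reduceId: Int, numBlocks: Int): (Long, Long) = {
    val offset = offsets(reduceId)
    val nextOffset = offsets(reduceId + numBlocks)
    (offset, nextOffset - offset)
  }

  def main(args: Array[String]): Unit = {
    // Index for three partitions of 10, 15 and 5 bytes: offsets are cumulative.
    val offsets = Array(0L, 10L, 25L, 30L)
    println(getIndex(offsets, 0, 1)) // (0,10)  single block, previous behaviour
    println(getIndex(offsets, 0, 2)) // (0,25)  partitions 0 and 1 read as one segment
  }
}
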
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 7846b71d5a8b1..baa7146604ef6 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -83,8 +83,8 @@ public void testOpenShuffleBlocks() {
ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3]));
ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
- when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(block0Marker);
- when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(block1Marker);
+ when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0, 1)).thenReturn(block0Marker);
+ when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1, 1)).thenReturn(block1Marker);
ByteBuffer openBlocks = new OpenBlocks("app0", "exec1",
new String[] { "shuffle_0_0_0", "shuffle_0_0_1" })
.toByteBuffer();
@@ -106,8 +106,8 @@ public void testOpenShuffleBlocks() {
assertEquals(block0Marker, buffers.next());
assertEquals(block1Marker, buffers.next());
assertFalse(buffers.hasNext());
- verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0);
- verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1);
+ verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0, 1);
+ verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1, 1);
// Verify open block request latency metrics
Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler)
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
index d2072a54fa415..05ca1fab7c9e0 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
@@ -66,7 +66,7 @@ public void testBadRequests() throws IOException {
ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
// Unregistered executor
try {
- resolver.getBlockData("app0", "exec1", 1, 1, 0);
+ resolver.getBlockData("app0", "exec1", 1, 1, 0, 1);
fail("Should have failed");
} catch (RuntimeException e) {
assertTrue("Bad error message: " + e, e.getMessage().contains("not registered"));
@@ -75,7 +75,7 @@ public void testBadRequests() throws IOException {
// Invalid shuffle manager
try {
resolver.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar"));
- resolver.getBlockData("app0", "exec2", 1, 1, 0);
+ resolver.getBlockData("app0", "exec2", 1, 1, 0, 1);
fail("Should have failed");
} catch (UnsupportedOperationException e) {
// pass
@@ -85,7 +85,7 @@ public void testBadRequests() throws IOException {
resolver.registerExecutor("app0", "exec3",
dataContext.createExecutorInfo(SORT_MANAGER));
try {
- resolver.getBlockData("app0", "exec3", 1, 1, 0);
+ resolver.getBlockData("app0", "exec3", 1, 1, 0, 1);
fail("Should have failed");
} catch (Exception e) {
// pass
@@ -99,18 +99,25 @@ public void testSortShuffleBlocks() throws IOException {
dataContext.createExecutorInfo(SORT_MANAGER));
InputStream block0Stream =
- resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream();
+ resolver.getBlockData("app0", "exec0", 0, 0, 0, 1).createInputStream();
String block0 = CharStreams.toString(
new InputStreamReader(block0Stream, StandardCharsets.UTF_8));
block0Stream.close();
assertEquals(sortBlock0, block0);
InputStream block1Stream =
- resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream();
+ resolver.getBlockData("app0", "exec0", 0, 0, 1, 1).createInputStream();
String block1 = CharStreams.toString(
new InputStreamReader(block1Stream, StandardCharsets.UTF_8));
block1Stream.close();
assertEquals(sortBlock1, block1);
+
+ InputStream block01Stream =
+ resolver.getBlockData("app0", "exec0", 0, 0, 0, 2).createInputStream();
+ String block01 = CharStreams.toString(
+ new InputStreamReader(block01Stream, StandardCharsets.UTF_8));
+ block01Stream.close();
+ assertEquals(sortBlock0 + sortBlock1, block01);
}
@Test
diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml
index fa0293207f8a4..612b38e34e1a1 100644
--- a/common/network-yarn/pom.xml
+++ b/common/network-yarn/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r39</version>
+    <version>2.4.1-kylin-r40</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml
index c5341af740707..a5482b6c9f766 100644
--- a/common/sketch/pom.xml
+++ b/common/sketch/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r39</version>
+    <version>2.4.1-kylin-r40</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/tags/pom.xml b/common/tags/pom.xml
index cb9a1cb8600f6..52bfe2fd9b3e5 100644
--- a/common/tags/pom.xml
+++ b/common/tags/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r39</version>
+    <version>2.4.1-kylin-r40</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml
index 8afac700b4e78..6d44060436be6 100644
--- a/common/unsafe/pom.xml
+++ b/common/unsafe/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r39</version>
+    <version>2.4.1-kylin-r40</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/core/pom.xml b/core/pom.xml
index 508d393ebc4e8..ea1748f5412b8 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-    <version>2.4.1-kylin-r39</version>
+    <version>2.4.1-kylin-r40</version>
    <relativePath>../pom.xml</relativePath>
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 323a5d3c52831..cf63088151be4 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -125,7 +125,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, new long[numPartitions]);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -159,15 +159,18 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
File tmp = Utils.tempFileWith(output);
+ MapInfo mapInfo;
try {
- partitionLengths = writePartitionedFile(tmp);
+ mapInfo = writePartitionedFile(tmp);
+ partitionLengths = mapInfo.lengths;
shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
} finally {
if (tmp.exists() && !tmp.delete()) {
logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
}
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(), mapInfo.lengths, mapInfo.records);
}
@VisibleForTesting
@@ -180,12 +183,13 @@ long[] getPartitionLengths() {
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
*/
- private long[] writePartitionedFile(File outputFile) throws IOException {
+ private MapInfo writePartitionedFile(File outputFile) throws IOException {
// Track location of the partition starts in the output file
final long[] lengths = new long[numPartitions];
+ final long[] records = new long[numPartitions];
if (partitionWriters == null) {
// We were passed an empty iterator
- return lengths;
+ return new MapInfo(lengths, records);
}
final FileOutputStream out = new FileOutputStream(outputFile, true);
@@ -194,6 +198,7 @@ private long[] writePartitionedFile(File outputFile) throws IOException {
try {
for (int i = 0; i < numPartitions; i++) {
final File file = partitionWriterSegments[i].file();
+ records[i] = partitionWriterSegments[i].record();
if (file.exists()) {
final FileInputStream in = new FileInputStream(file);
boolean copyThrewException = true;
@@ -214,7 +219,7 @@ private long[] writePartitionedFile(File outputFile) throws IOException {
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
}
partitionWriters = null;
- return lengths;
+ return new MapInfo(lengths, records);
}
@Override
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/MapInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/MapInfo.java
new file mode 100644
index 0000000000000..87bbba8a48268
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/MapInfo.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+public final class MapInfo {
+ final long[] lengths;
+ final long[] records;
+
+ public MapInfo(long[] lengths, long[] records) {
+ this.lengths = lengths;
+ this.records = records;
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index c7d2db4217d96..59abefa1db78d 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -195,6 +195,7 @@ private void writeSortedFile(boolean isLastFile) {
if (currentPartition != -1) {
final FileSegment fileSegment = writer.commitAndGet();
spillInfo.partitionLengths[currentPartition] = fileSegment.length();
+ spillInfo.partitionRecords[currentPartition] = fileSegment.record();
}
currentPartition = partition;
}
@@ -222,6 +223,7 @@ private void writeSortedFile(boolean isLastFile) {
// writeSortedFile() in that case.
if (currentPartition != -1) {
spillInfo.partitionLengths[currentPartition] = committedSegment.length();
+ spillInfo.partitionRecords[currentPartition] = committedSegment.record();
spills.add(spillInfo);
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
index 865def6b83c53..5f0875fe38f66 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
@@ -26,11 +26,13 @@
*/
final class SpillInfo {
final long[] partitionLengths;
+ final long[] partitionRecords;
final File file;
final TempShuffleBlockId blockId;
SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
this.partitionLengths = new long[numPartitions];
+ this.partitionRecords = new long[numPartitions];
this.file = file;
this.blockId = blockId;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 4839d04522f10..21a2c2ad42ad9 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -229,12 +229,12 @@ void closeAndWriteOutput() throws IOException {
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
- final long[] partitionLengths;
+ final MapInfo mapInfo;
final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
final File tmp = Utils.tempFileWith(output);
try {
try {
- partitionLengths = mergeSpills(spills, tmp);
+ mapInfo = mergeSpills(spills, tmp);
} finally {
for (SpillInfo spill : spills) {
if (spill.file.exists() && ! spill.file.delete()) {
@@ -242,13 +242,13 @@ void closeAndWriteOutput() throws IOException {
}
}
}
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
+ shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, mapInfo.lengths, tmp);
} finally {
if (tmp.exists() && !tmp.delete()) {
logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
}
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), mapInfo.lengths, mapInfo.records);
}
@VisibleForTesting
@@ -280,25 +280,25 @@ void forceSorterToSpill() throws IOException {
*
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
+ private MapInfo mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
final boolean fastMergeEnabled =
- sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
+ sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
final boolean fastMergeIsSupported = !compressionEnabled ||
- CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
+ CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
try {
if (spills.length == 0) {
new FileOutputStream(outputFile).close(); // Create an empty file
- return new long[partitioner.numPartitions()];
+ return new MapInfo(new long[partitioner.numPartitions()], new long[partitioner.numPartitions()]);
} else if (spills.length == 1) {
// Here, we don't need to perform any metrics updates because the bytes written to this
// output file would have already been counted as shuffle bytes written.
Files.move(spills[0].file, outputFile);
- return spills[0].partitionLengths;
+ return new MapInfo(spills[0].partitionLengths, spills[0].partitionRecords);
} else {
- final long[] partitionLengths;
+ final MapInfo mapInfo;
// There are multiple spills to merge, so none of these spill files' lengths were counted
// towards our shuffle write count or shuffle write time. If we use the slow merge path,
// then the final output file's size won't necessarily be equal to the sum of the spill
@@ -315,14 +315,14 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
// that doesn't need to interpret the spilled bytes.
if (transferToEnabled && !encryptionEnabled) {
logger.debug("Using transferTo-based fast merge");
- partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
+ mapInfo = mergeSpillsWithTransferTo(spills, outputFile);
} else {
logger.debug("Using fileStream-based fast merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
+ mapInfo = mergeSpillsWithFileStream(spills, outputFile, null);
}
} else {
logger.debug("Using slow merge");
- partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+ mapInfo = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
}
// When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
// in-memory records, we write out the in-memory records to a file but do not count that
@@ -331,7 +331,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
// SpillInfo's bytes.
writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
writeMetrics.incBytesWritten(outputFile.length());
- return partitionLengths;
+ return mapInfo;
}
} catch (IOException e) {
if (outputFile.exists() && !outputFile.delete()) {
@@ -357,13 +357,14 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
* @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpillsWithFileStream(
- SpillInfo[] spills,
- File outputFile,
- @Nullable CompressionCodec compressionCodec) throws IOException {
+ private MapInfo mergeSpillsWithFileStream(
+ SpillInfo[] spills,
+ File outputFile,
+ @Nullable CompressionCodec compressionCodec) throws IOException {
assert (spills.length >= 2);
final int numPartitions = partitioner.numPartitions();
final long[] partitionLengths = new long[numPartitions];
+ final long[] partitionRecords = new long[numPartitions];
final InputStream[] spillInputStreams = new InputStream[spills.length];
final OutputStream bos = new BufferedOutputStream(
@@ -377,8 +378,8 @@ private long[] mergeSpillsWithFileStream(
try {
for (int i = 0; i < spills.length; i++) {
spillInputStreams[i] = new NioBufferedFileInputStream(
- spills[i].file,
- inputBufferSizeInBytes);
+ spills[i].file,
+ inputBufferSizeInBytes);
}
for (int partition = 0; partition < numPartitions; partition++) {
final long initialFileLength = mergedFileOutputStream.getByteCount();
@@ -386,19 +387,20 @@ private long[] mergeSpillsWithFileStream(
// the higher level streams to make sure all data is really flushed and internal state is
// cleaned.
OutputStream partitionOutput = new CloseAndFlushShieldOutputStream(
- new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
+ new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
if (compressionCodec != null) {
partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
}
+ long records = 0;
for (int i = 0; i < spills.length; i++) {
final long partitionLengthInSpill = spills[i].partitionLengths[partition];
if (partitionLengthInSpill > 0) {
InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
- partitionLengthInSpill, false);
+ partitionLengthInSpill, false);
try {
partitionInputStream = blockManager.serializerManager().wrapForEncryption(
- partitionInputStream);
+ partitionInputStream);
if (compressionCodec != null) {
partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
}
@@ -407,10 +409,12 @@ private long[] mergeSpillsWithFileStream(
partitionInputStream.close();
}
}
+ records += spills[i].partitionRecords[partition];
}
partitionOutput.flush();
partitionOutput.close();
partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
+ partitionRecords[partition] = records;
}
threwException = false;
} finally {
@@ -421,7 +425,7 @@ private long[] mergeSpillsWithFileStream(
}
Closeables.close(mergedFileOutputStream, threwException);
}
- return partitionLengths;
+ return new MapInfo(partitionLengths, partitionRecords);
}
/**
@@ -431,10 +435,11 @@ private long[] mergeSpillsWithFileStream(
*
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+ private MapInfo mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
assert (spills.length >= 2);
final int numPartitions = partitioner.numPartitions();
final long[] partitionLengths = new long[numPartitions];
+ final long[] partitionRecords = new long[numPartitions];
final FileChannel[] spillInputChannels = new FileChannel[spills.length];
final long[] spillInputChannelPositions = new long[spills.length];
FileChannel mergedFileOutputChannel = null;
@@ -452,17 +457,19 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
for (int partition = 0; partition < numPartitions; partition++) {
for (int i = 0; i < spills.length; i++) {
final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+ final long partitionRecordInSpill = spills[i].partitionRecords[partition];
final FileChannel spillInputChannel = spillInputChannels[i];
final long writeStartTime = System.nanoTime();
Utils.copyFileStreamNIO(
- spillInputChannel,
- mergedFileOutputChannel,
- spillInputChannelPositions[i],
- partitionLengthInSpill);
+ spillInputChannel,
+ mergedFileOutputChannel,
+ spillInputChannelPositions[i],
+ partitionLengthInSpill);
spillInputChannelPositions[i] += partitionLengthInSpill;
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
bytesWrittenToMergedFile += partitionLengthInSpill;
partitionLengths[partition] += partitionLengthInSpill;
+ partitionRecords[partition] += partitionRecordInSpill;
}
}
// Check the position after transferTo loop to see if it is in the right position and raise an
@@ -471,11 +478,11 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
// https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
throw new IOException(
- "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
- "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
- " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
- "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
- "to disable this NIO feature."
+ "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
+ "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
+ " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
+ "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
+ "to disable this NIO feature."
);
}
threwException = false;
@@ -488,7 +495,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
}
Closeables.close(mergedFileOutputChannel, threwException);
}
- return partitionLengths;
+ return new MapInfo(partitionLengths, partitionRecords);
}
@Override
diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala
index f8a6f1d0d8cbb..b20b1cf7cc26e 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala
@@ -24,4 +24,8 @@ package org.apache.spark
* @param bytesByPartitionId approximate number of output bytes for each map output partition
* (may be inexact due to use of compressed map statuses)
*/
-private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long])
+private[spark] class MapOutputStatistics(
+ val shuffleId: Int,
+ val bytesByPartitionId: Array[Long],
+ val recordsByPartitionId: Array[Long] = Array[Long]())
+ extends Serializable
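
With the change above, MapOutputStatistics optionally carries per-partition record counts alongside byte sizes and is marked Serializable. Below is a hypothetical sketch of how such record counts could feed the "data skew judgement" mentioned in the new config docs later in this patch; the object, the function name, and the median-based heuristic are invented for illustration only.

object SkewCheckSketch {
  // Flag partitions whose record count exceeds `factor` times the median record
  // count of the non-empty partitions; returns the offending partition ids.
  def skewedPartitions(recordsByPartitionId: Array[Long], factor: Double = 5.0): Seq[Int] = {
    val nonEmpty = recordsByPartitionId.filter(_ > 0)
    if (nonEmpty.isEmpty) {
      Nil
    } else {
      val median = nonEmpty.sorted.apply(nonEmpty.length / 2)
      recordsByPartitionId.zipWithIndex.collect {
        case (records, id) if records > factor * median => id
      }.toSeq
    }
  }

  def main(args: Array[String]): Unit = {
    println(skewedPartitions(Array(100L, 120L, 5000L, 90L))) // partition 2 is flagged
  }
}
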
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 681ab27d0d368..d1dcce5807d4d 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.JavaConverters._
-import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
@@ -31,10 +31,12 @@ import scala.util.control.NonFatal
import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
-import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, MapStatus}
+import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.MetadataFetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, BlockManagerId, ContinuousShuffleBlockId, ShuffleBlockId}
import org.apache.spark.util._
/**
@@ -280,10 +282,23 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
}
+ protected def supportsContinuousBlockBatchFetch(serializerRelocatable: Boolean): Boolean = {
+ if (!serializerRelocatable) {
+ false
+ } else {
+ if (!conf.getBoolean("spark.shuffle.compress", true)) {
+ true
+ } else {
+ val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
+ CompressionCodec.supportsConcatenationOfSerializedStreams(compressionCodec)
+ }
+ }
+ }
+
// For testing
def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
- getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
+ : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1, serializerRelocatable = false)
}
/**
@@ -295,8 +310,28 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
* and the second item is a sequence of (shuffle block id, shuffle block size) tuples
* describing the shuffle blocks that are stored at that block manager.
*/
- def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])]
+ def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startPartition: Int,
+ endPartition: Int,
+ serializerRelocatable: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])]
+
+ /**
+ * Called from executors to get the server URIs and output sizes for each shuffle block that
+ * needs to be read from a given range of map output partitions (startPartition is included but
+ * endPartition is excluded from the range) and from the map outputs in [startMapId, endMapId).
+ *
+ * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+ * and the second item is a sequence of (shuffle block id, shuffle block size) tuples
+ * describing the shuffle blocks that are stored at that block manager.
+ */
+ def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startPartition: Int,
+ endPartition: Int,
+ startMapId: Int,
+ endMapId: Int,
+ serializerRelocatable: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])]
/**
* Deletes map output status information for the specified shuffle stage.
@@ -521,29 +556,57 @@ private[spark] class MapOutputTrackerMaster(
val parallelism = math.min(
Runtime.getRuntime.availableProcessors(),
statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt
+
+ var totalRecords = new Array[Long](0)
+ val records = statuses(0).getRecordForBlock(0)
+
if (parallelism <= 1) {
for (s <- statuses) {
for (i <- 0 until totalSizes.length) {
totalSizes(i) += s.getSizeForBlock(i)
}
}
+ // records == -1 means record-count info is not available
+ if (records != -1) {
+ totalRecords = new Array[Long](dep.partitioner.numPartitions)
+ for (s <- statuses) {
+ for (i <- totalRecords.indices) {
+ totalRecords(i) += s.getRecordForBlock(i)
+ }
+ }
+ }
} else {
+ val (sizeParallelism, recordParallelism) = if (records != -1) {
+ (parallelism / 2, parallelism - parallelism / 2)
+ } else {
+ (parallelism, 0)
+ }
val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate")
try {
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
- val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map {
+ var mapStatusSubmitTasks = equallyDivide(totalSizes.length, sizeParallelism).map {
reduceIds => Future {
for (s <- statuses; i <- reduceIds) {
totalSizes(i) += s.getSizeForBlock(i)
}
}
}
+ if (records != -1) {
+ totalRecords = new Array[Long](dep.partitioner.numPartitions)
+ mapStatusSubmitTasks ++= equallyDivide(totalRecords.length, recordParallelism).map {
+ reduceIds => Future {
+ for (s <- statuses; i <- reduceIds) {
+ totalRecords(i) += s.getRecordForBlock(i)
+ }
+ }
+ }
+ }
ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf)
} finally {
threadPool.shutdown()
}
}
- new MapOutputStatistics(dep.shuffleId, totalSizes)
+ new MapOutputStatistics(dep.shuffleId, totalSizes, totalRecords)
}
}
@@ -624,6 +687,35 @@ private[spark] class MapOutputTrackerMaster(
None
}
+ /**
+ * Return the locations where the mappers ran. Each location includes both a host and an
+ * executor id on that host.
+ *
+ * @param dep shuffle dependency object
+ * @param startMapId the start map id
+ * @param endMapId the end map id
+ * @return a sequence of locations that each includes both a host and an executor id on that
+ * host.
+ */
+ def getMapLocation(dep: ShuffleDependency[_, _, _], startMapId: Int, endMapId: Int): Seq[String] =
+ {
+ val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+ if (shuffleStatus != null) {
+ shuffleStatus.withMapStatuses { statuses =>
+ if (startMapId >= 0 && endMapId <= statuses.length) {
+ val statusesPicked = statuses.slice(startMapId, endMapId).filter(_ != null)
+ statusesPicked.map { status =>
+ ExecutorCacheTaskLocation(status.location.host, status.location.executorId).toString
+ }
+ } else {
+ Nil
+ }
+ }
+ } else {
+ Nil
+ }
+ }
+
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
@@ -638,18 +730,47 @@ private[spark] class MapOutputTrackerMaster(
}
}
- // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
// This method is only called in local-mode.
- def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startPartition: Int,
+ endPartition: Int,
+ serializerRelocatable: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
shuffleStatuses.get(shuffleId) match {
case Some (shuffleStatus) =>
shuffleStatus.withMapStatuses { statuses =>
- MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+ MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses,
+ supportsContinuousBlockBatchFetch(serializerRelocatable))
+ }
+ case None =>
+ Seq.empty
+ }
+ }
+
+ override def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startPartition: Int,
+ endPartition: Int,
+ startMapId: Int,
+ endMapId: Int,
+ serializerRelocatable: Boolean) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ logDebug(s"Fetching outputs for shuffle $shuffleId, startMapId $startMapId endMapId $endMapId" +
+ s"partitions $startPartition-$endPartition")
+ shuffleStatuses.get(shuffleId) match {
+ case Some (shuffleStatus) =>
+ shuffleStatus.withMapStatuses { statuses =>
+ MapOutputTracker.convertMapStatuses(
+ shuffleId,
+ startPartition,
+ endPartition,
+ statuses,
+ startMapId,
+ endMapId,
+ supportsContinuousBlockBatchFetch(serializerRelocatable))
}
case None =>
- Iterator.empty
+ Seq.empty
}
}
@@ -676,13 +797,37 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
/** Remembers which map output locations are currently being fetched on an executor. */
private val fetching = new HashSet[Int]
- // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
- override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ override def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startPartition: Int,
+ endPartition: Int,
+ serializerRelocatable: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
val statuses = getStatuses(shuffleId)
try {
- MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+ MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses,
+ supportsContinuousBlockBatchFetch(serializerRelocatable))
+ } catch {
+ case e: MetadataFetchFailedException =>
+ // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
+ mapStatuses.clear()
+ throw e
+ }
+ }
+
+ override def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startPartition: Int,
+ endPartition: Int,
+ startMapId: Int,
+ endMapId: Int,
+ serializerRelocatable: Boolean) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ logDebug(s"Fetching outputs for shuffle $shuffleId, startMapId $startMapId endMapId $endMapId" +
+ s"partitions $startPartition-$endPartition")
+ val statuses = getStatuses(shuffleId)
+ try {
+ MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses,
+ startMapId, endMapId, supportsContinuousBlockBatchFetch(serializerRelocatable))
} catch {
case e: MetadataFetchFailedException =>
// We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
@@ -868,28 +1013,95 @@ private[spark] object MapOutputTracker extends Logging {
* and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
* describing the shuffle blocks that are stored at that block manager.
*/
+ def convertMapStatuses(
+ shuffleId: Int,
+ startPartition: Int,
+ endPartition: Int,
+ statuses: Array[MapStatus],
+ supportsContinuousBlockBatchFetch: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ assert (statuses != null)
+ val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
+ for ((status, mapId) <- statuses.zipWithIndex) {
+ if (status == null) {
+ val errorMessage = s"Missing an output location for shuffle $shuffleId"
+ logError(errorMessage)
+ throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
+ } else {
+ if (endPartition - startPartition > 1 && supportsContinuousBlockBatchFetch) {
+ val totalSize: Long = (startPartition until endPartition).map(status.getSizeForBlock).sum
+ if (totalSize != 0) {
+ splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
+ ((ContinuousShuffleBlockId(shuffleId, mapId,
+ startPartition, endPartition - startPartition), totalSize))
+ }
+ } else {
+ for (part <- startPartition until endPartition) {
+ val size = status.getSizeForBlock(part)
+ if (size != 0) {
+ splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
+ ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
+ }
+ }
+ }
+ }
+ }
+
+ splitsByAddress.toSeq
+ }
+
+ /**
+ * Given an array of map statuses, a start and end map Id, and a range of map output
+ * partitions, returns a sequence that, for each block manager, lists the shuffle block IDs
+ * and corresponding shuffle block sizes stored at that block manager.
+ *
+ * If the status of the map is null (indicating a missing location due to a failed mapper),
+ * throws a FetchFailedException.
+ *
+ * @param shuffleId Identifier for the shuffle
+ * @param startPartition Start of map output partition ID range (included in range)
+ * @param endPartition End of map output partition ID range (excluded from range)
+ * @param statuses List of map statuses, indexed by map ID.
+ * @param startMapId Start of map Id range (included in range)
+ * @param endMapId End of map Id range (excluded from range)
+ * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+ * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
+ * describing the shuffle blocks that are stored at that block manager.
+ */
def convertMapStatuses(
shuffleId: Int,
startPartition: Int,
endPartition: Int,
- statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
- assert (statuses != null)
- val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]]
- for ((status, mapId) <- statuses.iterator.zipWithIndex) {
+ statuses: Array[MapStatus],
+ startMapId: Int,
+ endMapId: Int,
+ supportsContinuousBlockBatchFetch: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ assert (statuses != null && statuses.length >= endMapId && startMapId >= 0)
+ val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
+ for (mapId <- startMapId until endMapId) {
+ val status = statuses(mapId)
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
logError(errorMessage)
throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
} else {
- for (part <- startPartition until endPartition) {
- val size = status.getSizeForBlock(part)
- if (size != 0) {
- splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
- ((ShuffleBlockId(shuffleId, mapId, part), size))
+ if (endPartition - startPartition > 1 && supportsContinuousBlockBatchFetch) {
+ val totalSize: Long = (startPartition until endPartition).map(status.getSizeForBlock).sum
+ if (totalSize != 0) {
+ splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
+ ((ContinuousShuffleBlockId(shuffleId, mapId,
+ startPartition, endPartition - startPartition), totalSize))
+ }
+ } else {
+ for (part <- startPartition until endPartition) {
+ val size = status.getSizeForBlock(part)
+ if (size != 0) {
+ splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
+ ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
+ }
}
}
}
}
- splitsByAddress.iterator
+ splitsByAddress.toSeq
}
}
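
The coalescing rule added to convertMapStatuses above: when more than one consecutive partition is requested and continuous batch fetch is supported, a single ContinuousShuffleBlockId per mapper replaces the per-partition ShuffleBlockIds, with its size being the sum of the per-partition sizes. A minimal sketch (not patch code; it assumes ContinuousShuffleBlockId's name matches the five-part id parsed by the shuffle service earlier in this patch):

object BlockCoalesceSketch {
  // Block names requested from one mapper for partitions [startPartition, endPartition):
  // one continuous id when batch fetch is supported, per-partition ids otherwise.
  def blockNames(
      shuffleId: Int,
      mapId: Int,
      startPartition: Int,
      endPartition: Int,
      continuous: Boolean): Seq[String] = {
    if (endPartition - startPartition > 1 && continuous) {
      Seq(s"shuffle_${shuffleId}_${mapId}_${startPartition}_${endPartition - startPartition}")
    } else {
      (startPartition until endPartition).map(p => s"shuffle_${shuffleId}_${mapId}_$p")
    }
  }

  def main(args: Array[String]): Unit = {
    println(blockNames(0, 3, 2, 5, continuous = true))  // List(shuffle_0_3_2_3)
    println(blockNames(0, 3, 2, 5, continuous = false)) // shuffle_0_3_2, shuffle_0_3_3, shuffle_0_3_4
  }
}
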
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 0e39eb874aafd..fde656985dc49 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -418,14 +418,6 @@ package object config {
.booleanConf
.createWithDefault(false)
- private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD =
- ConfigBuilder("spark.shuffle.accurateBlockThreshold")
- .doc("Threshold in bytes above which the size of shuffle blocks in " +
- "HighlyCompressedMapStatus is accurately recorded. This helps to prevent OOM " +
- "by avoiding underestimating shuffle block size when fetch shuffle blocks.")
- .bytesConf(ByteUnit.BYTE)
- .createWithDefault(100 * 1024 * 1024)
-
private[spark] val SHUFFLE_REGISTRATION_TIMEOUT =
ConfigBuilder("spark.shuffle.registration.timeout")
.doc("Timeout in milliseconds for registration to the external shuffle service.")
@@ -664,4 +656,33 @@ package object config {
.stringConf
.toSequence
.createWithDefault(Nil)
+
+ private[spark] val SHUFFLE_HIGHLY_COMPRESSED_MAP_STATUS_THRESHOLD =
+ ConfigBuilder("spark.shuffle.highlyCompressedMapStatusThreshold")
+ .doc("HighlyCompressedMapStatus is used if shuffle partition number is larger than the " +
+ "threshold. Otherwise CompressedMapStatus is used.")
+ .intConf
+ .createWithDefault(2000)
+
+ private[spark] val SHUFFLE_STATISTICS_VERBOSE =
+ ConfigBuilder("spark.shuffle.statistics.verbose")
+ .doc("Collect shuffle statistics in verbose mode, including row counts etc.")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val SHUFFLE_ACCURATE_BLOCK_SIZE_THRESHOLD =
+ ConfigBuilder("spark.shuffle.accurateBlockThreshold")
+ .doc("Threshold in bytes above which the size of shuffle blocks in " +
+ "HighlyCompressedMapStatus is accurately recorded. This helps to prevent OOM " +
+ "by avoiding underestimating shuffle block size when fetch shuffle blocks.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(100 * 1024 * 1024)
+
+ private[spark] val SHUFFLE_ACCURATE_BLOCK_RECORD_THRESHOLD =
+ ConfigBuilder("spark.shuffle.accurateBlockRecordThreshold")
+ .doc("When we compress the records number of shuffle blocks in HighlyCompressedMapStatus, " +
+ "we will record the number accurately if it's above this config. The record number will " +
+ "be used for data skew judgement when spark.shuffle.statistics.verbose is set true.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(2 * 1024 * 1024)
}
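
Minimal usage sketch for the configs added above (values are illustrative only; note that spark.shuffle.accurateBlockRecordThreshold is declared with bytesConf, so size suffixes such as "2m" are accepted):

import org.apache.spark.SparkConf

object ShuffleStatsConfSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.shuffle.statistics.verbose", "true")                 // collect record counts in MapStatus
      .set("spark.shuffle.highlyCompressedMapStatusThreshold", "2000") // above this, use HighlyCompressedMapStatus
      .set("spark.shuffle.accurateBlockRecordThreshold", "2m")         // bytesConf, so size suffixes work
    println(conf.get("spark.shuffle.statistics.verbose"))
  }
}
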
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 0e221edf3965a..7a9a38f05c01d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -44,24 +44,31 @@ private[spark] sealed trait MapStatus {
* necessary for correctness, since block fetchers are allowed to skip zero-size blocks.
*/
def getSizeForBlock(reduceId: Int): Long
+
+ def getRecordForBlock(reduceId: Int): Long
}
private[spark] object MapStatus {
- /**
- * Min partition number to use [[HighlyCompressedMapStatus]]. A bit ugly here because in test
- * code we can't assume SparkEnv.get exists.
- */
- private lazy val minPartitionsToUseHighlyCompressMapStatus = Option(SparkEnv.get)
- .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS))
- .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)
+ // We use Array[Long]() as uncompressedRecords's default value because
+ // serialization and deserialization do not accept null.
+ def apply(
+ loc: BlockManagerId,
+ uncompressedSizes: Array[Long],
+ uncompressedRecords: Array[Long] = Array[Long]()): MapStatus = {
+ val verbose = Option(SparkEnv.get)
+ .map(_.conf.get(config.SHUFFLE_STATISTICS_VERBOSE))
+ .getOrElse(config.SHUFFLE_STATISTICS_VERBOSE.defaultValue.get)
+ val threshold = Option(SparkEnv.get)
+ .map(_.conf.get(config.SHUFFLE_HIGHLY_COMPRESSED_MAP_STATUS_THRESHOLD))
+ .getOrElse(config.SHUFFLE_HIGHLY_COMPRESSED_MAP_STATUS_THRESHOLD.defaultValue.get)
- def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
- if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
- HighlyCompressedMapStatus(loc, uncompressedSizes)
+ val newRecords = if (verbose) uncompressedRecords else Array[Long]()
+ if (uncompressedSizes.length > threshold) {
+ HighlyCompressedMapStatus(loc, uncompressedSizes, newRecords)
} else {
- new CompressedMapStatus(loc, uncompressedSizes)
+ new CompressedMapStatus(loc, uncompressedSizes, newRecords)
}
}
@@ -104,13 +111,17 @@ private[spark] object MapStatus {
*/
private[spark] class CompressedMapStatus(
private[this] var loc: BlockManagerId,
- private[this] var compressedSizes: Array[Byte])
+ private[this] var compressedSizes: Array[Byte],
+ private[this] var compressedRecords: Array[Byte])
extends MapStatus with Externalizable {
- protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only
+ protected def this() = this(null, null.asInstanceOf[Array[Byte]],
+ null.asInstanceOf[Array[Byte]]) // For deserialization only
- def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
- this(loc, uncompressedSizes.map(MapStatus.compressSize))
+ def this(loc: BlockManagerId, uncompressedSizes: Array[Long],
+ uncompressedRecords: Array[Long] = Array[Long]()) {
+ this(loc, uncompressedSizes.map(MapStatus.compressSize),
+ uncompressedRecords.map(MapStatus.compressSize))
}
override def location: BlockManagerId = loc
@@ -119,10 +130,20 @@ private[spark] class CompressedMapStatus(
MapStatus.decompressSize(compressedSizes(reduceId))
}
+ override def getRecordForBlock(reduceId: Int): Long = {
+ if (compressedRecords.nonEmpty) {
+ MapStatus.decompressSize(compressedRecords(reduceId))
+ } else {
+ -1
+ }
+ }
+
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
+ out.writeInt(compressedRecords.length)
+ out.write(compressedRecords)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -130,6 +151,9 @@ private[spark] class CompressedMapStatus(
val len = in.readInt()
compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
+ val recordsLen = in.readInt()
+ compressedRecords = new Array[Byte](recordsLen)
+ in.readFully(compressedRecords)
}
}
@@ -145,18 +169,20 @@ private[spark] class CompressedMapStatus(
* @param hugeBlockSizes sizes of huge blocks by their reduceId.
*/
private[spark] class HighlyCompressedMapStatus private (
- private[this] var loc: BlockManagerId,
- private[this] var numNonEmptyBlocks: Int,
- private[this] var emptyBlocks: RoaringBitmap,
- private[this] var avgSize: Long,
- private var hugeBlockSizes: Map[Int, Byte])
+ private[this] var loc: BlockManagerId,
+ private[this] var numNonEmptyBlocks: Int,
+ private[this] var emptyBlocks: RoaringBitmap,
+ private[this] var avgSize: Long,
+ private var hugeBlockSizes: Map[Int, Byte],
+ private[this] var avgRecord: Long,
+ private var hugeBlockRecords: Map[Int, Byte])
extends MapStatus with Externalizable {
// loc could be null when the default constructor is called during deserialization
require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, -1, null, -1, null) // For deserialization only
+ protected def this() = this(null, -1, null, -1, null, -1, null) // For deserialization only
override def location: BlockManagerId = loc
@@ -172,6 +198,22 @@ private[spark] class HighlyCompressedMapStatus private (
}
}
+ override def getRecordForBlock(reduceId: Int): Long = {
+ assert(hugeBlockRecords != null)
+ if (avgRecord != -1) {
+ if (emptyBlocks.contains(reduceId)) {
+ 0
+ } else {
+ hugeBlockRecords.get(reduceId) match {
+ case Some(record) => MapStatus.decompressSize(record)
+ case None => avgRecord
+ }
+ }
+ } else {
+ -1
+ }
+ }
+
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
emptyBlocks.writeExternal(out)
@@ -181,6 +223,12 @@ private[spark] class HighlyCompressedMapStatus private (
out.writeInt(kv._1)
out.writeByte(kv._2)
}
+ out.writeLong(avgRecord)
+ out.writeInt(hugeBlockRecords.size)
+ hugeBlockRecords.foreach { kv =>
+ out.writeInt(kv._1)
+ out.writeByte(kv._2)
+ }
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -196,11 +244,23 @@ private[spark] class HighlyCompressedMapStatus private (
hugeBlockSizesArray += Tuple2(block, size)
}
hugeBlockSizes = hugeBlockSizesArray.toMap
+ avgRecord = in.readLong()
+ val recordCount = in.readInt()
+ val hugeBlockRecordsArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]()
+ (0 until recordCount).foreach { _ =>
+ val block = in.readInt()
+ val record = in.readByte()
+ hugeBlockRecordsArray += Tuple2(block, record)
+ }
+ hugeBlockRecords = hugeBlockRecordsArray.toMap
}
}
private[spark] object HighlyCompressedMapStatus {
- def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+ def apply(
+ loc: BlockManagerId,
+ uncompressedSizes: Array[Long],
+ uncompressedRecords: Array[Long] = Array[Long]()): HighlyCompressedMapStatus = {
// We must keep track of which blocks are empty so that we don't report a zero-sized
// block as being non-empty (or vice-versa) when using the average block size.
var i = 0
@@ -213,8 +273,8 @@ private[spark] object HighlyCompressedMapStatus {
val emptyBlocks = new RoaringBitmap()
val totalNumBlocks = uncompressedSizes.length
val threshold = Option(SparkEnv.get)
- .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD))
- .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get)
+ .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_SIZE_THRESHOLD))
+ .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_SIZE_THRESHOLD.defaultValue.get)
val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]()
while (i < totalNumBlocks) {
val size = uncompressedSizes(i)
@@ -238,9 +298,38 @@ private[spark] object HighlyCompressedMapStatus {
} else {
0
}
+
+ var recordSmallBlocks: Long = 0
+ numSmallBlocks = 0
+ var avgRecord: Long = -1
+ val recordThreshold = Option(SparkEnv.get)
+ .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_RECORD_THRESHOLD))
+ .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_RECORD_THRESHOLD.defaultValue.get)
+ val hugeBlockRecordsArray = ArrayBuffer[Tuple2[Int, Byte]]()
+ if (uncompressedRecords.nonEmpty) {
+ i = 0
+ while (i < totalNumBlocks) {
+ val record = uncompressedRecords(i)
+ if (record > 0) {
+ if (record < recordThreshold) {
+ recordSmallBlocks += record
+ numSmallBlocks += 1
+ } else {
+ hugeBlockRecordsArray += Tuple2(i, MapStatus.compressSize(uncompressedRecords(i)))
+ }
+ }
+ i += 1
+ }
+ avgRecord = if (numSmallBlocks > 0) {
+ recordSmallBlocks / numSmallBlocks
+ } else {
+ 0
+ }
+ }
+
emptyBlocks.trim()
emptyBlocks.runOptimize()
new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
- hugeBlockSizesArray.toMap)
+ hugeBlockSizesArray.toMap, avgRecord, hugeBlockRecordsArray.toMap)
}
}
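The record bookkeeping added above mirrors the existing size bookkeeping: per-reduce record counts below a configurable threshold are folded into a single average, and only the outliers are stored exactly, log-compressed into one byte each. A minimal standalone sketch of that summarization, assuming illustrative names and the base-1.1 one-byte encoding that stock Spark's MapStatus.compressSize/decompressSize use (this is not the patched class itself):

object RecordStatsSketch {
  private val LogBase = 1.1

  // One-byte log encoding, as MapStatus uses for block sizes and, in this patch, record counts.
  def compress(value: Long): Byte =
    if (value == 0L) 0
    else if (value <= 1L) 1
    else math.min(255, math.ceil(math.log(value.toDouble) / math.log(LogBase)).toInt).toByte

  def decompress(b: Byte): Long =
    if (b == 0) 0L else math.pow(LogBase, b & 0xFF).toLong

  /** Returns (avgRecord, hugeBlockRecords); avgRecord is -1 when no record info was supplied. */
  def summarize(records: Array[Long], threshold: Long): (Long, Map[Int, Byte]) = {
    if (records.isEmpty) return (-1L, Map.empty)
    var smallSum = 0L
    var smallCount = 0
    val huge = Map.newBuilder[Int, Byte]
    var i = 0
    while (i < records.length) {
      val r = records(i)
      if (r > 0) {
        if (r < threshold) { smallSum += r; smallCount += 1 }
        else huge += (i -> compress(r))
      }
      i += 1
    }
    val avg = if (smallCount > 0) smallSum / smallCount else 0L
    (avg, huge.result())
  }
}

The RoaringBitmap of empty blocks and the SparkEnv threshold lookup are omitted here; they are unchanged from the size path shown above.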
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 1d4b05caaa143..99ee035c375ae 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -109,6 +109,7 @@ private[spark] class SerializerManager(
private def shouldCompress(blockId: BlockId): Boolean = {
blockId match {
case _: ShuffleBlockId => compressShuffle
+ case _: ContinuousShuffleBlockId => compressShuffle
case _: BroadcastBlockId => compressBroadcast
case _: RDDBlockId => compressRdds
case _: TempLocalBlockId => compressShuffleSpill
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 74b0e0b3a741a..ce89eee31f376 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -35,18 +35,37 @@ private[spark] class BlockStoreShuffleReader[K, C](
context: TaskContext,
serializerManager: SerializerManager = SparkEnv.get.serializerManager,
blockManager: BlockManager = SparkEnv.get.blockManager,
- mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
+ mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
+ startMapId: Option[Int] = None,
+ endMapId: Option[Int] = None)
extends ShuffleReader[K, C] with Logging {
private val dep = handle.dependency
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
+ val blocksByAddress = (startMapId, endMapId) match {
+ case (Some(startId), Some(endId)) => mapOutputTracker.getMapSizesByExecutorId(
+ handle.shuffleId,
+ startPartition,
+ endPartition,
+ startId,
+ endId,
+ dep.serializer.supportsRelocationOfSerializedObjects)
+ case (None, None) => mapOutputTracker.getMapSizesByExecutorId(
+ handle.shuffleId,
+ startPartition,
+ endPartition,
+ dep.serializer.supportsRelocationOfSerializedObjects)
+ case (_, _) => throw new IllegalArgumentException(
+ "startMapId and endMapId should be both set or unset")
+ }
+
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
- mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+ blocksByAddress,
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
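The reader now accepts an optional map-id range and requires that both bounds are supplied together; anything else is rejected up front. A small standalone helper, hypothetical and not part of the patch, that spells out the same contract:

object MapRange {
  // Resolve the optional bounds into a half-open map-id range [start, end);
  // (None, None) keeps the old behaviour of reading every map output.
  def resolve(startMapId: Option[Int], endMapId: Option[Int], numMaps: Int): (Int, Int) =
    (startMapId, endMapId) match {
      case (Some(s), Some(e)) => (s, e)
      case (None, None)       => (0, numMaps)
      case _ => throw new IllegalArgumentException(
        "startMapId and endMapId should be both set or unset")
    }
}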
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index d3f1c7ec1bbee..696662c1945a3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -190,7 +190,7 @@ private[spark] class IndexShuffleBlockResolver(
}
}
- override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
+ override def getBlockData(blockId: ShuffleBlockIdBase): ManagedBuffer = {
// The block is actually going to be a range of a single map output file for this map, so
// find out the consolidated file, then the offset within that from our index
val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
@@ -205,10 +205,19 @@ private[spark] class IndexShuffleBlockResolver(
channel.position(blockId.reduceId * 8L)
val in = new DataInputStream(Channels.newInputStream(channel))
try {
+ channel.position(blockId.reduceId * 8L)
val offset = in.readLong()
+ var expectedPosition = 0L
+ blockId match {
+ case bid: ContinuousShuffleBlockId =>
+ val tempId = blockId.reduceId + bid.numBlocks
+ channel.position(tempId * 8L)
+ expectedPosition = tempId * 8L + 8
+ case _ =>
+ expectedPosition = blockId.reduceId * 8L + 16
+ }
val nextOffset = in.readLong()
val actualPosition = channel.position()
- val expectedPosition = blockId.reduceId * 8L + 16
if (actualPosition != expectedPosition) {
throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
s"expected $expectedPosition but actual position was $actualPosition.")
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
index d1ecbc1bf0178..8b62c00bc6b9c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
@@ -18,7 +18,7 @@
package org.apache.spark.shuffle
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.ShuffleBlockIdBase
private[spark]
/**
@@ -34,7 +34,7 @@ trait ShuffleBlockResolver {
* Retrieve the data for the specified block. If the data for that block is not available,
* throws an unspecified exception.
*/
- def getBlockData(blockId: ShuffleBlockId): ManagedBuffer
+ def getBlockData(blockId: ShuffleBlockIdBase): ManagedBuffer
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index 4ea8a7120a9cc..f61461fede671 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -41,7 +41,8 @@ private[spark] trait ShuffleManager {
def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]
/**
- * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive),
+ * reading from every map output of the shuffle.
* Called on executors by reduce tasks.
*/
def getReader[K, C](
@@ -50,6 +51,19 @@ private[spark] trait ShuffleManager {
endPartition: Int,
context: TaskContext): ShuffleReader[K, C]
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to
+ * read from map outputs (startMapId to endMapId - 1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ startMapId: Int,
+ endMapId: Int): ShuffleReader[K, C]
+
/**
* Remove a shuffle's metadata from the ShuffleManager.
* @return true if the metadata removed successfully, otherwise false.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 0caf84c6050a8..cc8e615d5da28 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -119,6 +119,27 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to
+ * read from map outputs (startMapId to endMapId - 1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ startMapId: Int,
+ endMapId: Int): ShuffleReader[K, C] = {
+ new BlockStoreShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+ startPartition,
+ endPartition,
+ context,
+ startMapId = Some(startMapId),
+ endMapId = Some(endMapId))
+ }
+
/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](
handle: ShuffleHandle,
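With the trait method above and this implementation, callers that know which mappers produced the data they need (for example, adaptive execution handling skewed partitions) can fetch just that slice. A hypothetical caller sketch; the helper name is an assumption, and because ShuffleManager is private[spark] this only compiles inside Spark's own packages:

import org.apache.spark.TaskContext
import org.apache.spark.shuffle.{ShuffleHandle, ShuffleManager}

object MapRangeReads {
  // Read reduce partitions [startPartition, endPartition) restricted to the output of
  // mappers [startMapId, endMapId), via the overload added in this patch.
  def readMapRange[K, C](
      shuffleManager: ShuffleManager,
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      startMapId: Int,
      endMapId: Int,
      context: TaskContext): Iterator[Product2[K, C]] = {
    shuffleManager
      .getReader[K, C](handle, startPartition, endPartition, context, startMapId, endMapId)
      .read()
  }
}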
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 274399b9cc1f3..e4b9bd3caea7f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -68,9 +68,9 @@ private[spark] class SortShuffleWriter[K, V, C](
val tmp = Utils.tempFileWith(output)
try {
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
- val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
- shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
- mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
+ val mapInfo = sorter.writePartitionedFile(blockId, tmp)
+ shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, mapInfo.lengths, tmp)
+ mapStatus = MapStatus(blockManager.shuffleServerId, mapInfo.lengths, mapInfo.records)
} finally {
if (tmp.exists() && !tmp.delete()) {
logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 7ac2c71c18eb3..14a4df53931fc 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -38,7 +38,7 @@ sealed abstract class BlockId {
// convenience methods
def asRDDId: Option[RDDBlockId] = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
def isRDD: Boolean = isInstanceOf[RDDBlockId]
- def isShuffle: Boolean = isInstanceOf[ShuffleBlockId]
+ def isShuffle: Boolean = isInstanceOf[ShuffleBlockIdBase]
def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId]
override def toString: String = name
@@ -51,11 +51,25 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
// Format of the shuffle block ids (including data and index) should be kept in sync with
// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getBlockData().
+trait ShuffleBlockIdBase extends BlockId {
+ def shuffleId: Int
+ def mapId: Int
+ def reduceId: Int
+}
+
@DeveloperApi
-case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
+ extends ShuffleBlockIdBase {
override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
+@DeveloperApi
+case class ContinuousShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int, numBlocks: Int)
+ extends ShuffleBlockIdBase {
+ override def name: String =
+ "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + "_" + numBlocks
+}
+
@DeveloperApi
case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"
@@ -104,6 +118,7 @@ class UnrecognizedBlockId(name: String)
object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+ val CONTINUE_SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
@@ -118,6 +133,8 @@ object BlockId {
RDDBlockId(rddId.toInt, splitIndex.toInt)
case SHUFFLE(shuffleId, mapId, reduceId) =>
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+ case CONTINUE_SHUFFLE(shuffleId, mapId, reduceId, length) =>
+ ContinuousShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt, length.toInt)
case SHUFFLE_DATA(shuffleId, mapId, reduceId) =>
ShuffleDataBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
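The naming scheme stays backward compatible: a plain shuffle block is named shuffle_<shuffleId>_<mapId>_<reduceId>, and a continuous block appends _<numBlocks>. A standalone sketch of the round trip between names and ids, mirroring the two regexes above without depending on Spark's classes:

object BlockIdNamesSketch {
  sealed trait Id { def name: String }
  final case class Single(shuffleId: Int, mapId: Int, reduceId: Int) extends Id {
    def name: String = s"shuffle_${shuffleId}_${mapId}_$reduceId"
  }
  final case class Continuous(shuffleId: Int, mapId: Int, reduceId: Int, numBlocks: Int)
      extends Id {
    def name: String = s"shuffle_${shuffleId}_${mapId}_${reduceId}_$numBlocks"
  }

  private val Shuffle = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
  private val ContinuousShuffle = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r

  // Scala's regex extractor matches the whole string, so the 3-field pattern never
  // swallows a 4-field name and the two forms stay unambiguous.
  def parse(name: String): Id = name match {
    case Shuffle(s, m, r)              => Single(s.toInt, m.toInt, r.toInt)
    case ContinuousShuffle(s, m, r, n) => Continuous(s.toInt, m.toInt, r.toInt, n.toInt)
    case _ => throw new IllegalArgumentException(s"Unrecognized block id: $name")
  }
}

// BlockIdNamesSketch.parse("shuffle_3_7_2")   == Single(3, 7, 2)
// BlockIdNamesSketch.parse("shuffle_3_7_2_4") == Continuous(3, 7, 2, 4)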
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 8912e971e50d0..cd2a14fec417a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -378,7 +378,7 @@ private[spark] class BlockManager(
*/
override def getBlockData(blockId: BlockId): ManagedBuffer = {
if (blockId.isShuffle) {
- shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
+ shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockIdBase])
} else {
getLocalBytes(blockId) match {
case Some(blockData) =>
@@ -638,7 +638,7 @@ private[spark] class BlockManager(
// TODO: This should gracefully handle case where local block is not available. Currently
// downstream code will throw an exception.
val buf = new ChunkedByteBuffer(
- shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
+ shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockIdBase]).nioByteBuffer())
Some(new ByteBufferBlockData(buf, true))
} else {
blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) }
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index a024c83d8d8b7..d15fe396e4033 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -181,7 +181,8 @@ private[spark] class DiskBlockObjectWriter(
}
val pos = channel.position()
- val fileSegment = new FileSegment(file, committedPosition, pos - committedPosition)
+ val fileSegment =
+ new FileSegment(file, committedPosition, pos - committedPosition, numRecordsWritten)
committedPosition = pos
// In certain compression codecs, more bytes are written after streams are closed
writeMetrics.incBytesWritten(committedPosition - reportedPosition)
@@ -189,7 +190,7 @@ private[spark] class DiskBlockObjectWriter(
numRecordsWritten = 0
fileSegment
} else {
- new FileSegment(file, committedPosition, 0)
+ new FileSegment(file, committedPosition, 0, 0)
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
index 021a9facfb0b2..063c1e26a1e53 100644
--- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -21,12 +21,17 @@ import java.io.File
/**
* References a particular segment of a file (potentially the entire file),
- * based off an offset and a length.
+ * based off an offset, a length and a record count.
*/
-private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) {
+private[spark] class FileSegment(
+ val file: File,
+ val offset: Long,
+ val length: Long,
+ val record: Long) {
require(offset >= 0, s"File segment offset cannot be negative (got $offset)")
require(length >= 0, s"File segment length cannot be negative (got $length)")
+ require(record >= 0, s"File segment record cannot be negative (got $record)")
override def toString: String = {
- "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
+ "(name=%s, offset=%d, length=%d, record=%d)".format(file.getName, offset, length, record)
}
}
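DiskBlockObjectWriter already counted records for write metrics; the change above simply snapshots that counter into each FileSegment at commit time and then resets it. A simplified in-memory sketch of that bookkeeping, with a ByteArrayOutputStream standing in for the real file channel (illustrative only):

import java.io.ByteArrayOutputStream

final case class SegmentSketch(offset: Long, length: Long, records: Long)

final class CountingWriterSketch {
  private val out = new ByteArrayOutputStream()
  private var committedPosition = 0L
  private var recordsWritten = 0L

  def write(record: Array[Byte]): Unit = {
    out.write(record)
    recordsWritten += 1
  }

  // Capture the byte range and the record count written since the last commit, then reset
  // the record counter, mirroring what DiskBlockObjectWriter.commitAndGet does above.
  def commitAndGet(): SegmentSketch = {
    val pos = out.size().toLong
    val segment = SegmentSketch(committedPosition, pos - committedPosition, recordsWritten)
    committedPosition = pos
    recordsWritten = 0L
    segment
  }
}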
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index aecc2284a9588..fcf215026673c 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -46,13 +46,11 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream
*
* @param context [[TaskContext]], used for metrics update
* @param shuffleClient [[ShuffleClient]] for fetching remote blocks
- * @param blockManager [[BlockManager]] for reading local blocks
+ * @param blockManager [[BlockManager]] for reading local blocks
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
- * order to throttle the memory usage. Note that zero-sized blocks are
- * already excluded, which happened in
- * [[MapOutputTracker.convertMapStatuses]].
- * @param streamWrapper A function to wrap the returned input stream.
+ * order to throttle the memory usage.
+ * @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
* @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
* @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
@@ -65,7 +63,7 @@ final class ShuffleBlockFetcherIterator(
context: TaskContext,
shuffleClient: ShuffleClient,
blockManager: BlockManager,
- blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])],
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
maxReqsInFlight: Int,
@@ -550,8 +548,9 @@ final class ShuffleBlockFetcherIterator(
private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = {
blockId match {
- case ShuffleBlockId(shufId, mapId, reduceId) =>
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
+ case blockId: ShuffleBlockIdBase =>
+ throw new FetchFailedException(
+ address, blockId.shuffleId, blockId.mapId, blockId.reduceId, e)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block", e)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 547a862467c88..c4fe743d3adaf 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -29,6 +29,7 @@ import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.serializer._
+import org.apache.spark.shuffle.sort.MapInfo
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
/**
@@ -682,10 +683,11 @@ private[spark] class ExternalSorter[K, V, C](
*/
def writePartitionedFile(
blockId: BlockId,
- outputFile: File): Array[Long] = {
+ outputFile: File): MapInfo = {
// Track location of each range in the output file
val lengths = new Array[Long](numPartitions)
+ val records = new Array[Long](numPartitions)
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics().shuffleWriteMetrics)
@@ -700,6 +702,7 @@ private[spark] class ExternalSorter[K, V, C](
}
val segment = writer.commitAndGet()
lengths(partitionId) = segment.length
+ records(partitionId) = segment.record
}
} else {
// We must perform merge-sort; get an iterator by partition and write everything directly.
@@ -710,6 +713,7 @@ private[spark] class ExternalSorter[K, V, C](
}
val segment = writer.commitAndGet()
lengths(id) = segment.length
+ records(id) = segment.record
}
}
}
@@ -719,7 +723,7 @@ private[spark] class ExternalSorter[K, V, C](
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
- lengths
+ new MapInfo(lengths, records)
}
def stop(): Unit = {
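writePartitionedFile now hands back the per-partition byte lengths and record counts as one value instead of a bare Array[Long]. MapInfo itself is introduced elsewhere in this change set; judging only from the call sites here (mapInfo.lengths, mapInfo.records, new MapInfo(lengths, records)), a minimal equivalent would be the following, with the shape assumed rather than copied from the patch:

package org.apache.spark.shuffle.sort

// Assumed shape only, inferred from the uses above; the real definition ships elsewhere
// in this change set.
private[spark] class MapInfo(val lengths: Array[Long], val records: Array[Long]) {
  require(lengths.length == records.length,
    "lengths and records must describe the same number of reduce partitions")
}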
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 0d5c5ea7903e9..929289dbfa26f 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -21,6 +21,7 @@
import java.nio.ByteBuffer;
import java.util.*;
+import org.apache.spark.*;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
@@ -35,10 +36,6 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
-import org.apache.spark.HashPartitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkConf;
-import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.io.CompressionCodec$;
@@ -81,6 +78,8 @@ public class UnsafeShuffleWriterSuite {
@Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
@Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
@Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency