diff --git a/assembly/pom.xml b/assembly/pom.xml index 71333f734d9a1..3f955997fd925 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index 292da55179322..5828d2e5d512a 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 22ad07bf581c7..a1a884a91433c 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 5f7784d21e236..b1db8a14d05bd 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 098fa7974b87b..d24c87e035ca5 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -197,48 +197,60 @@ public Map getMetrics() { } } + private boolean isShuffleBlock(String[] blockIdParts) { + // length == 4: ShuffleBlockId + // length == 5: ContinuousShuffleBlockId + return (blockIdParts.length == 4 || blockIdParts.length == 5) && + blockIdParts[0].equals("shuffle"); + } + private class ManagedBufferIterator implements Iterator { private int index = 0; private final String appId; private final String execId; private final int shuffleId; - // An array containing mapId and reduceId pairs. 
- private final int[] mapIdAndReduceIds; + // An array containing mapId, reduceId and numBlocks tuple + private final int[] shuffleBlockIds; ManagedBufferIterator(String appId, String execId, String[] blockIds) { this.appId = appId; this.execId = execId; String[] blockId0Parts = blockIds[0].split("_"); - if (blockId0Parts.length != 4 || !blockId0Parts[0].equals("shuffle")) { + if (!isShuffleBlock(blockId0Parts)) { throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[0]); } this.shuffleId = Integer.parseInt(blockId0Parts[1]); - mapIdAndReduceIds = new int[2 * blockIds.length]; + shuffleBlockIds = new int[3 * blockIds.length]; for (int i = 0; i < blockIds.length; i++) { String[] blockIdParts = blockIds[i].split("_"); - if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) { + if (!isShuffleBlock(blockIdParts)) { throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]); } if (Integer.parseInt(blockIdParts[1]) != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + - ", got:" + blockIds[i]); + ", got:" + blockIds[i]); + } + shuffleBlockIds[3 * i] = Integer.parseInt(blockIdParts[2]); + shuffleBlockIds[3 * i + 1] = Integer.parseInt(blockIdParts[3]); + if (blockIdParts.length == 4) { + shuffleBlockIds[3 * i + 2] = 1; + } else { + shuffleBlockIds[3 * i + 2] = Integer.parseInt(blockIdParts[4]); } - mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]); - mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]); } } @Override public boolean hasNext() { - return index < mapIdAndReduceIds.length; + return index < shuffleBlockIds.length; } @Override public ManagedBuffer next() { final ManagedBuffer block = blockManager.getBlockData(appId, execId, shuffleId, - mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); - index += 2; + shuffleBlockIds[index], shuffleBlockIds[index + 1], shuffleBlockIds[index + 2]); + index += 3; metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); return block; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0b7a27402369d..9589544e3ab03 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -162,21 +162,22 @@ public void registerExecutor( } /** - * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions - * about how the hash and sort based shuffles store their data. + * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId, numBlocks). + * We make assumptions about how the hash and sort based shuffles store their data. 
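For orientation, a minimal Scala sketch of the block id convention the handler and resolver now share: a four-part "shuffle_<shuffleId>_<mapId>_<reduceId>" id is treated as a range of one block, while a five-part ContinuousShuffleBlockId carries an explicit numBlocks. The helper name below is illustrative, not part of the patch.

// Illustrative only: parse "shuffle_<shuffleId>_<mapId>_<reduceId>[_<numBlocks>]" into the
// (shuffleId, mapId, reduceId, numBlocks) tuple that getBlockData now expects.
def parseShuffleBlockId(blockId: String): (Int, Int, Int, Int) = {
  val parts = blockId.split("_")
  require((parts.length == 4 || parts.length == 5) && parts(0) == "shuffle",
    s"Unexpected shuffle block id format: $blockId")
  val numBlocks = if (parts.length == 4) 1 else parts(4).toInt // 4 parts => single block
  (parts(1).toInt, parts(2).toInt, parts(3).toInt, numBlocks)
}
// parseShuffleBlockId("shuffle_0_5_2")   == (0, 5, 2, 1)
// parseShuffleBlockId("shuffle_0_5_2_3") == (0, 5, 2, 3)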
*/ public ManagedBuffer getBlockData( - String appId, - String execId, - int shuffleId, - int mapId, - int reduceId) { + String appId, + String execId, + int shuffleId, + int mapId, + int reduceId, + int numBlocks) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { throw new RuntimeException( - String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); + String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } - return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId, numBlocks); } /** @@ -280,19 +281,19 @@ public boolean accept(File dir, String name) { * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. */ private ManagedBuffer getSortBasedShuffleBlockData( - ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { + ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId, int numBlocks) { File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.index"); + "shuffle_" + shuffleId + "_" + mapId + "_0.index"); try { ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile); - ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId); + ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId, numBlocks); return new FileSegmentManagedBuffer( - conf, - getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.data"), - shuffleIndexRecord.getOffset(), - shuffleIndexRecord.getLength()); + conf, + getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.data"), + shuffleIndexRecord.getOffset(), + shuffleIndexRecord.getLength()); } catch (ExecutionException e) { throw new RuntimeException("Failed to open file: " + indexFile, e); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index 386738ece51a6..470e1040e97e5 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -59,9 +59,9 @@ public int getSize() { /** * Get index offset for a particular reducer. 
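The getIndex(reduceId, numBlocks) change relies on the index file holding numPartitions + 1 cumulative offsets, so a run of numBlocks consecutive reduce partitions maps to one contiguous byte range of the data file. A small sketch under that assumption (names are illustrative):

// Illustrative sketch assuming the standard sort-shuffle index layout of
// (numPartitions + 1) cumulative offsets.
final case class IndexRange(offset: Long, length: Long)

def rangeFor(offsets: IndexedSeq[Long], reduceId: Int, numBlocks: Int): IndexRange = {
  val start = offsets(reduceId)
  val end = offsets(reduceId + numBlocks) // numBlocks == 1 reproduces the old single-block case
  IndexRange(start, end - start)
}
// offsets = Vector(0, 10, 25, 40): rangeFor(offsets, 0, 2) == IndexRange(0, 25)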
*/ - public ShuffleIndexRecord getIndex(int reduceId) { + public ShuffleIndexRecord getIndex(int reduceId, int numBlocks) { long offset = offsets.get(reduceId); - long nextOffset = offsets.get(reduceId + 1); + long nextOffset = offsets.get(reduceId + numBlocks); return new ShuffleIndexRecord(offset, nextOffset - offset); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 7846b71d5a8b1..baa7146604ef6 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -83,8 +83,8 @@ public void testOpenShuffleBlocks() { ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); - when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(block0Marker); - when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(block1Marker); + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0, 1)).thenReturn(block0Marker); + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1, 1)).thenReturn(block1Marker); ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "shuffle_0_0_0", "shuffle_0_0_1" }) .toByteBuffer(); @@ -106,8 +106,8 @@ public void testOpenShuffleBlocks() { assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0, 1); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1, 1); // Verify open block request latency metrics Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index d2072a54fa415..05ca1fab7c9e0 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -66,7 +66,7 @@ public void testBadRequests() throws IOException { ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); // Unregistered executor try { - resolver.getBlockData("app0", "exec1", 1, 1, 0); + resolver.getBlockData("app0", "exec1", 1, 1, 0, 1); fail("Should have failed"); } catch (RuntimeException e) { assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); @@ -75,7 +75,7 @@ public void testBadRequests() throws IOException { // Invalid shuffle manager try { resolver.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar")); - resolver.getBlockData("app0", "exec2", 1, 1, 0); + resolver.getBlockData("app0", "exec2", 1, 1, 0, 1); fail("Should have failed"); } catch (UnsupportedOperationException e) { // pass @@ -85,7 +85,7 @@ public void testBadRequests() throws IOException { resolver.registerExecutor("app0", 
"exec3", dataContext.createExecutorInfo(SORT_MANAGER)); try { - resolver.getBlockData("app0", "exec3", 1, 1, 0); + resolver.getBlockData("app0", "exec3", 1, 1, 0, 1); fail("Should have failed"); } catch (Exception e) { // pass @@ -99,18 +99,25 @@ public void testSortShuffleBlocks() throws IOException { dataContext.createExecutorInfo(SORT_MANAGER)); InputStream block0Stream = - resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream(); + resolver.getBlockData("app0", "exec0", 0, 0, 0, 1).createInputStream(); String block0 = CharStreams.toString( new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); block0Stream.close(); assertEquals(sortBlock0, block0); InputStream block1Stream = - resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream(); + resolver.getBlockData("app0", "exec0", 0, 0, 1, 1).createInputStream(); String block1 = CharStreams.toString( new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); block1Stream.close(); assertEquals(sortBlock1, block1); + + InputStream block01Stream = + resolver.getBlockData("app0", "exec0", 0, 0, 0, 2).createInputStream(); + String block01 = CharStreams.toString( + new InputStreamReader(block01Stream, StandardCharsets.UTF_8)); + block01Stream.close(); + assertEquals(sortBlock0 + sortBlock1, block01); } @Test diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index fa0293207f8a4..612b38e34e1a1 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index c5341af740707..a5482b6c9f766 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index cb9a1cb8600f6..52bfe2fd9b3e5 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 8afac700b4e78..6d44060436be6 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 508d393ebc4e8..ea1748f5412b8 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../pom.xml diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 323a5d3c52831..cf63088151be4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -125,7 +125,7 @@ public void write(Iterator> records) throws IOException { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, new long[numPartitions]); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -159,15 +159,18 @@ public void write(Iterator> records) 
throws IOException { File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); File tmp = Utils.tempFileWith(output); + MapInfo mapInfo; try { - partitionLengths = writePartitionedFile(tmp); + mapInfo = writePartitionedFile(tmp); + partitionLengths = mapInfo.lengths; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); } finally { if (tmp.exists() && !tmp.delete()) { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), mapInfo.lengths, mapInfo.records); } @VisibleForTesting @@ -180,12 +183,13 @@ long[] getPartitionLengths() { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). */ - private long[] writePartitionedFile(File outputFile) throws IOException { + private MapInfo writePartitionedFile(File outputFile) throws IOException { // Track location of the partition starts in the output file final long[] lengths = new long[numPartitions]; + final long[] records = new long[numPartitions]; if (partitionWriters == null) { // We were passed an empty iterator - return lengths; + return new MapInfo(lengths, records); } final FileOutputStream out = new FileOutputStream(outputFile, true); @@ -194,6 +198,7 @@ private long[] writePartitionedFile(File outputFile) throws IOException { try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); + records[i] = partitionWriterSegments[i].record(); if (file.exists()) { final FileInputStream in = new FileInputStream(file); boolean copyThrewException = true; @@ -214,7 +219,7 @@ private long[] writePartitionedFile(File outputFile) throws IOException { writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; - return lengths; + return new MapInfo(lengths, records); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/MapInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/MapInfo.java new file mode 100644 index 0000000000000..87bbba8a48268 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/MapInfo.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort; + +public final class MapInfo { + final long[] lengths; + final long[] records; + + public MapInfo(long[] lengths, long[] records) { + this.lengths = lengths; + this.records = records; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index c7d2db4217d96..59abefa1db78d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -195,6 +195,7 @@ private void writeSortedFile(boolean isLastFile) { if (currentPartition != -1) { final FileSegment fileSegment = writer.commitAndGet(); spillInfo.partitionLengths[currentPartition] = fileSegment.length(); + spillInfo.partitionRecords[currentPartition] = fileSegment.record(); } currentPartition = partition; } @@ -222,6 +223,7 @@ private void writeSortedFile(boolean isLastFile) { // writeSortedFile() in that case. if (currentPartition != -1) { spillInfo.partitionLengths[currentPartition] = committedSegment.length(); + spillInfo.partitionRecords[currentPartition] = committedSegment.record(); spills.add(spillInfo); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java index 865def6b83c53..5f0875fe38f66 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java @@ -26,11 +26,13 @@ */ final class SpillInfo { final long[] partitionLengths; + final long[] partitionRecords; final File file; final TempShuffleBlockId blockId; SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { this.partitionLengths = new long[numPartitions]; + this.partitionRecords = new long[numPartitions]; this.file = file; this.blockId = blockId; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4839d04522f10..21a2c2ad42ad9 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -229,12 +229,12 @@ void closeAndWriteOutput() throws IOException { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; - final long[] partitionLengths; + final MapInfo mapInfo; final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); final File tmp = Utils.tempFileWith(output); try { try { - partitionLengths = mergeSpills(spills, tmp); + mapInfo = mergeSpills(spills, tmp); } finally { for (SpillInfo spill : spills) { if (spill.file.exists() && ! 
spill.file.delete()) { @@ -242,13 +242,13 @@ void closeAndWriteOutput() throws IOException { } } } - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, mapInfo.lengths, tmp); } finally { if (tmp.exists() && !tmp.delete()) { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), mapInfo.lengths, mapInfo.records); } @VisibleForTesting @@ -280,25 +280,25 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. */ - private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { + private MapInfo mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = - sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); + sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); final boolean fastMergeIsSupported = !compressionEnabled || - CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file - return new long[partitioner.numPartitions()]; + return new MapInfo(new long[partitioner.numPartitions()], new long[partitioner.numPartitions()]); } else if (spills.length == 1) { // Here, we don't need to perform any metrics updates because the bytes written to this // output file would have already been counted as shuffle bytes written. Files.move(spills[0].file, outputFile); - return spills[0].partitionLengths; + return new MapInfo(spills[0].partitionLengths, spills[0].partitionRecords); } else { - final long[] partitionLengths; + final MapInfo mapInfo; // There are multiple spills to merge, so none of these spill files' lengths were counted // towards our shuffle write count or shuffle write time. If we use the slow merge path, // then the final output file's size won't necessarily be equal to the sum of the spill @@ -315,14 +315,14 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // that doesn't need to interpret the spilled bytes. 
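As a reading aid, the merge-strategy choice made just below can be summarised like this (a simplified model, not the writer's code):

// Simplified model of the branch in mergeSpills: fast merges concatenate spill bytes
// verbatim, while the slow path re-encodes them through the compression codec.
def mergeStrategy(
    fastMergeEnabled: Boolean,
    fastMergeSupported: Boolean,
    transferToEnabled: Boolean,
    encryptionEnabled: Boolean): String = {
  if (fastMergeEnabled && fastMergeSupported) {
    if (transferToEnabled && !encryptionEnabled) "transferTo-based fast merge"
    else "fileStream-based fast merge"
  } else {
    "slow merge via the compression codec"
  }
}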
if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); - partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + mapInfo = mergeSpillsWithTransferTo(spills, outputFile); } else { logger.debug("Using fileStream-based fast merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + mapInfo = mergeSpillsWithFileStream(spills, outputFile, null); } } else { logger.debug("Using slow merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + mapInfo = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); } // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has // in-memory records, we write out the in-memory records to a file but do not count that @@ -331,7 +331,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // SpillInfo's bytes. writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); writeMetrics.incBytesWritten(outputFile.length()); - return partitionLengths; + return mapInfo; } } catch (IOException e) { if (outputFile.exists() && !outputFile.delete()) { @@ -357,13 +357,14 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. * @return the partition lengths in the merged file. */ - private long[] mergeSpillsWithFileStream( - SpillInfo[] spills, - File outputFile, - @Nullable CompressionCodec compressionCodec) throws IOException { + private MapInfo mergeSpillsWithFileStream( + SpillInfo[] spills, + File outputFile, + @Nullable CompressionCodec compressionCodec) throws IOException { assert (spills.length >= 2); final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; + final long[] partitionRecords = new long[numPartitions]; final InputStream[] spillInputStreams = new InputStream[spills.length]; final OutputStream bos = new BufferedOutputStream( @@ -377,8 +378,8 @@ private long[] mergeSpillsWithFileStream( try { for (int i = 0; i < spills.length; i++) { spillInputStreams[i] = new NioBufferedFileInputStream( - spills[i].file, - inputBufferSizeInBytes); + spills[i].file, + inputBufferSizeInBytes); } for (int partition = 0; partition < numPartitions; partition++) { final long initialFileLength = mergedFileOutputStream.getByteCount(); @@ -386,19 +387,20 @@ private long[] mergeSpillsWithFileStream( // the higher level streams to make sure all data is really flushed and internal state is // cleaned. 
OutputStream partitionOutput = new CloseAndFlushShieldOutputStream( - new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); + new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); } + long records = 0; for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; if (partitionLengthInSpill > 0) { InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], - partitionLengthInSpill, false); + partitionLengthInSpill, false); try { partitionInputStream = blockManager.serializerManager().wrapForEncryption( - partitionInputStream); + partitionInputStream); if (compressionCodec != null) { partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); } @@ -407,10 +409,12 @@ private long[] mergeSpillsWithFileStream( partitionInputStream.close(); } } + records += spills[i].partitionRecords[partition]; } partitionOutput.flush(); partitionOutput.close(); partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); + partitionRecords[partition] = records; } threwException = false; } finally { @@ -421,7 +425,7 @@ private long[] mergeSpillsWithFileStream( } Closeables.close(mergedFileOutputStream, threwException); } - return partitionLengths; + return new MapInfo(partitionLengths, partitionRecords); } /** @@ -431,10 +435,11 @@ private long[] mergeSpillsWithFileStream( * * @return the partition lengths in the merged file. */ - private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { + private MapInfo mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { assert (spills.length >= 2); final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; + final long[] partitionRecords = new long[numPartitions]; final FileChannel[] spillInputChannels = new FileChannel[spills.length]; final long[] spillInputChannelPositions = new long[spills.length]; FileChannel mergedFileOutputChannel = null; @@ -452,17 +457,19 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th for (int partition = 0; partition < numPartitions; partition++) { for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + final long partitionRecordInSpill = spills[i].partitionRecords[partition]; final FileChannel spillInputChannel = spillInputChannels[i]; final long writeStartTime = System.nanoTime(); Utils.copyFileStreamNIO( - spillInputChannel, - mergedFileOutputChannel, - spillInputChannelPositions[i], - partitionLengthInSpill); + spillInputChannel, + mergedFileOutputChannel, + spillInputChannelPositions[i], + partitionLengthInSpill); spillInputChannelPositions[i] += partitionLengthInSpill; writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); bytesWrittenToMergedFile += partitionLengthInSpill; partitionLengths[partition] += partitionLengthInSpill; + partitionRecords[partition] += partitionRecordInSpill; } } // Check the position after transferTo loop to see if it is in the right position and raise an @@ -471,11 +478,11 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. 
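The extra bookkeeping the two merge paths now carry reduces to summing per-partition byte lengths and record counts across spills; a simplified model in Scala:

// Simplified model of the per-partition accumulation performed by
// mergeSpillsWithFileStream / mergeSpillsWithTransferTo (illustrative types).
final case class SpillStats(partitionLengths: Array[Long], partitionRecords: Array[Long])

def mergeStats(spills: Seq[SpillStats], numPartitions: Int): (Array[Long], Array[Long]) = {
  val lengths = new Array[Long](numPartitions)
  val records = new Array[Long](numPartitions)
  for (spill <- spills; p <- 0 until numPartitions) {
    lengths(p) += spill.partitionLengths(p) // bytes this spill contributed to partition p
    records(p) += spill.partitionRecords(p) // records this spill contributed to partition p
  }
  (lengths, records)
}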
if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { throw new IOException( - "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + - "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + - " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + - "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + - "to disable this NIO feature." + "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + + " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + + "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + + "to disable this NIO feature." ); } threwException = false; @@ -488,7 +495,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th } Closeables.close(mergedFileOutputChannel, threwException); } - return partitionLengths; + return new MapInfo(partitionLengths, partitionRecords); } @Override diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index f8a6f1d0d8cbb..b20b1cf7cc26e 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -24,4 +24,8 @@ package org.apache.spark * @param bytesByPartitionId approximate number of output bytes for each map output partition * (may be inexact due to use of compressed map statuses) */ -private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) +private[spark] class MapOutputStatistics( + val shuffleId: Int, + val bytesByPartitionId: Array[Long], + val recordsByPartitionId: Array[Long] = Array[Long]()) + extends Serializable diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 681ab27d0d368..d1dcce5807d4d 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ -import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.reflect.ClassTag @@ -31,10 +31,12 @@ import scala.util.control.NonFatal import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.io.CompressionCodec import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.scheduler.MapStatus +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, MapStatus} +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.MetadataFetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManagerId, ContinuousShuffleBlockId, ShuffleBlockId} import org.apache.spark.util._ /** @@ -280,10 +282,23 @@ private[spark] abstract class 
MapOutputTracker(conf: SparkConf) extends Logging } } + protected def supportsContinuousBlockBatchFetch(serializerRelocatable: Boolean): Boolean = { + if (!serializerRelocatable) { + false + } else { + if (!conf.getBoolean("spark.shuffle.compress", true)) { + true + } else { + val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) + CompressionCodec.supportsConcatenationOfSerializedStreams(compressionCodec) + } + } + } + // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { - getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1, serializerRelocatable = false) } /** @@ -295,8 +310,28 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * and the second item is a sequence of (shuffle block id, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] + def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + serializerRelocatable: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] + + /** + * Called from executors to get the server URIs and output sizes for each shuffle block that + * needs to be read from a given range of map output partitions (startPartition is included but + * endPartition is excluded from the range) and a given start map Id and end map Id. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ + def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + startMapId: Int, + endMapId: Int, + serializerRelocatable: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. 
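The condition encoded by supportsContinuousBlockBatchFetch boils down to: contiguous blocks may be fetched as one range only if their concatenated on-disk bytes are still decodable. A compact restatement of the same logic (sketch only):

// Batch fetch is safe only when the serializer allows relocation of serialized
// objects and the stored bytes can be concatenated (no compression, or a codec
// that supports concatenation of serialized streams).
def canBatchFetch(
    serializerRelocatable: Boolean,
    shuffleCompress: Boolean,
    codecSupportsConcatenation: Boolean): Boolean = {
  serializerRelocatable && (!shuffleCompress || codecSupportsConcatenation)
}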
@@ -521,29 +556,57 @@ private[spark] class MapOutputTrackerMaster( val parallelism = math.min( Runtime.getRuntime.availableProcessors(), statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt + + var totalRecords = new Array[Long](0) + val records = statuses(0).getRecordForBlock(0) + if (parallelism <= 1) { for (s <- statuses) { for (i <- 0 until totalSizes.length) { totalSizes(i) += s.getSizeForBlock(i) } } + // records == -1 means no records number info + if (records != -1) { + totalRecords = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- totalRecords.indices) { + totalRecords(i) += s.getRecordForBlock(i) + } + } + } } else { + val (sizeParallelism, recordParallelism) = if (records != -1) { + (parallelism / 2, parallelism - parallelism / 2) + } else { + (parallelism, 0) + } val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate") try { implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map { + var mapStatusSubmitTasks = equallyDivide(totalSizes.length, sizeParallelism).map { reduceIds => Future { for (s <- statuses; i <- reduceIds) { totalSizes(i) += s.getSizeForBlock(i) } } } + if (records != -1) { + totalRecords = new Array[Long](dep.partitioner.numPartitions) + mapStatusSubmitTasks ++= equallyDivide(totalRecords.length, recordParallelism).map { + reduceIds => Future { + for (s <- statuses; i <- reduceIds) { + totalRecords(i) += s.getRecordForBlock(i) + } + } + } + } ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf) } finally { threadPool.shutdown() } } - new MapOutputStatistics(dep.shuffleId, totalSizes) + new MapOutputStatistics(dep.shuffleId, totalSizes, totalRecords) } } @@ -624,6 +687,35 @@ private[spark] class MapOutputTrackerMaster( None } + /** + * Return the locations where the Mapper(s) ran. The locations each includes both a host and an + * executor id on that host. + * + * @param dep shuffle dependency object + * @param startMapId the start map id + * @param endMapId the end map id + * @return a sequence of locations that each includes both a host and an executor id on that + * host. + */ + def getMapLocation(dep: ShuffleDependency[_, _, _], startMapId: Int, endMapId: Int): Seq[String] = + { + val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull + if (shuffleStatus != null) { + shuffleStatus.withMapStatuses { statuses => + if (startMapId >= 0 && endMapId <= statuses.length) { + val statusesPicked = statuses.slice(startMapId, endMapId).filter(_ != null) + statusesPicked.map { status => + ExecutorCacheTaskLocation(status.location.host, status.location.executorId).toString + } + } else { + Nil + } + } + } else { + Nil + } + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 @@ -638,18 +730,47 @@ private[spark] class MapOutputTrackerMaster( } } - // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. 
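With spark.shuffle.statistics.verbose enabled, MapOutputStatistics now also exposes recordsByPartitionId; a hypothetical consumer could use it to flag skewed reduce partitions. The threshold rule below is purely illustrative and not defined by this patch:

// Hypothetical skew check over the new recordsByPartitionId array; an empty array
// means verbose statistics were disabled or no record info was available.
def skewedPartitions(recordsByPartitionId: Array[Long], factor: Double = 5.0): Seq[Int] = {
  if (recordsByPartitionId.isEmpty) {
    Seq.empty
  } else {
    val avg = math.max(1L, recordsByPartitionId.sum / recordsByPartitionId.length)
    recordsByPartitionId.indices.filter(i => recordsByPartitionId(i) > factor * avg)
  }
}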
- def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + serializerRelocatable: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => shuffleStatus.withMapStatuses { statuses => - MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses, + supportsContinuousBlockBatchFetch(serializerRelocatable)) + } + case None => + Seq.empty + } + } + + override def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + startMapId: Int, + endMapId: Int, + serializerRelocatable: Boolean) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, startMapId $startMapId endMapId $endMapId" + + s"partitions $startPartition-$endPartition") + shuffleStatuses.get(shuffleId) match { + case Some (shuffleStatus) => + shuffleStatus.withMapStatuses { statuses => + MapOutputTracker.convertMapStatuses( + shuffleId, + startPartition, + endPartition, + statuses, + startMapId, + endMapId, + supportsContinuousBlockBatchFetch(serializerRelocatable)) } case None => - Iterator.empty + Seq.empty } } @@ -676,13 +797,37 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr /** Remembers which map output locations are currently being fetched on an executor. */ private val fetching = new HashSet[Int] - // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. 
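To make the new overloads concrete, this sketch mirrors the convertMapStatuses helper they delegate to, for a single map output: when more than one reduce partition is requested and batch fetch is supported, the range collapses into one ContinuousShuffleBlockId with the summed size; otherwise each non-empty partition keeps its own ShuffleBlockId. A simplified model, not the tracker's code:

import org.apache.spark.storage.{BlockId, ContinuousShuffleBlockId, ShuffleBlockId}

// Simplified mirror of convertMapStatuses for one map output, given its per-reduce
// block sizes. Zero-sized blocks are skipped either way.
def blocksForMap(
    shuffleId: Int,
    mapId: Int,
    sizeOf: Int => Long,
    startPartition: Int,
    endPartition: Int,
    batchFetch: Boolean): Seq[(BlockId, Long)] = {
  if (endPartition - startPartition > 1 && batchFetch) {
    val totalSize = (startPartition until endPartition).map(sizeOf).sum
    if (totalSize == 0) Seq.empty
    else Seq((ContinuousShuffleBlockId(shuffleId, mapId,
      startPartition, endPartition - startPartition), totalSize))
  } else {
    (startPartition until endPartition).flatMap { part =>
      val size = sizeOf(part)
      if (size == 0) None else Some((ShuffleBlockId(shuffleId, mapId, part), size))
    }
  }
}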
- override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + override def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + serializerRelocatable: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { - MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses, + supportsContinuousBlockBatchFetch(serializerRelocatable)) + } catch { + case e: MetadataFetchFailedException => + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + mapStatuses.clear() + throw e + } + } + + override def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + startMapId: Int, + endMapId: Int, + serializerRelocatable: Boolean) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, startMapId $startMapId endMapId $endMapId" + + s"partitions $startPartition-$endPartition") + val statuses = getStatuses(shuffleId) + try { + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses, + startMapId, endMapId, supportsContinuousBlockBatchFetch(serializerRelocatable)) } catch { case e: MetadataFetchFailedException => // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: @@ -868,28 +1013,95 @@ private[spark] object MapOutputTracker extends Logging { * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ + def convertMapStatuses( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + statuses: Array[MapStatus], + supportsContinuousBlockBatchFetch: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + assert (statuses != null) + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.zipWithIndex) { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) + } else { + if (endPartition - startPartition > 1 && supportsContinuousBlockBatchFetch) { + val totalSize: Long = (startPartition until endPartition).map(status.getSizeForBlock).sum + if (totalSize != 0) { + splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += + ((ContinuousShuffleBlockId(shuffleId, mapId, + startPartition, endPartition - startPartition), totalSize)) + } + } else { + for (part <- startPartition until endPartition) { + val size = status.getSizeForBlock(part) + if (size != 0) { + splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) + } + } + } + } + } + + splitsByAddress.toSeq + } + + /** + * Given an array of map statuses, the start map Id, end map Id and a range of map output + * partitions, returns a sequence that, lists the shuffle block IDs and corresponding shuffle + * block sizes stored at that block manager. + * + * If the status of the map is null (indicating a missing location due to a failed mapper), + * throws a FetchFailedException. 
+ * + * @param shuffleId Identifier for the shuffle + * @param startPartition Start of map output partition ID range (included in range) + * @param endPartition End of map output partition ID range (excluded from range) + * @param statuses List of map statuses, indexed by map ID. + * @param startMapId Start of map Id range (included in range) + * @param endMapId End of map Id (excluded from range) + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ def convertMapStatuses( shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { - assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] - for ((status, mapId) <- statuses.iterator.zipWithIndex) { + statuses: Array[MapStatus], + startMapId: Int, + endMapId: Int, + supportsContinuousBlockBatchFetch: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + assert (statuses != null && statuses.length >= endMapId && startMapId >= 0) + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] + for (mapId <- startMapId until endMapId) { + val status = statuses(mapId) if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) } else { - for (part <- startPartition until endPartition) { - val size = status.getSizeForBlock(part) - if (size != 0) { - splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, part), size)) + if (endPartition - startPartition > 1 && supportsContinuousBlockBatchFetch) { + val totalSize: Long = (startPartition until endPartition).map(status.getSizeForBlock).sum + if (totalSize != 0) { + splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += + ((ContinuousShuffleBlockId(shuffleId, mapId, + startPartition, endPartition - startPartition), totalSize)) + } + } else { + for (part <- startPartition until endPartition) { + val size = status.getSizeForBlock(part) + if (size != 0) { + splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) + } } } } } - splitsByAddress.iterator + splitsByAddress.toSeq } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0e39eb874aafd..fde656985dc49 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -418,14 +418,6 @@ package object config { .booleanConf .createWithDefault(false) - private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD = - ConfigBuilder("spark.shuffle.accurateBlockThreshold") - .doc("Threshold in bytes above which the size of shuffle blocks in " + - "HighlyCompressedMapStatus is accurately recorded. 
This helps to prevent OOM " + - "by avoiding underestimating shuffle block size when fetch shuffle blocks.") - .bytesConf(ByteUnit.BYTE) - .createWithDefault(100 * 1024 * 1024) - private[spark] val SHUFFLE_REGISTRATION_TIMEOUT = ConfigBuilder("spark.shuffle.registration.timeout") .doc("Timeout in milliseconds for registration to the external shuffle service.") @@ -664,4 +656,33 @@ package object config { .stringConf .toSequence .createWithDefault(Nil) + + private[spark] val SHUFFLE_HIGHLY_COMPRESSED_MAP_STATUS_THRESHOLD = + ConfigBuilder("spark.shuffle.highlyCompressedMapStatusThreshold") + .doc("HighlyCompressedMapStatus is used if shuffle partition number is larger than the " + + "threshold. Otherwise CompressedMapStatus is used.") + .intConf + .createWithDefault(2000) + + private[spark] val SHUFFLE_STATISTICS_VERBOSE = + ConfigBuilder("spark.shuffle.statistics.verbose") + .doc("Collect shuffle statistics in verbose mode, including row counts etc.") + .booleanConf + .createWithDefault(false) + + private[spark] val SHUFFLE_ACCURATE_BLOCK_SIZE_THRESHOLD = + ConfigBuilder("spark.shuffle.accurateBlockThreshold") + .doc("Threshold in bytes above which the size of shuffle blocks in " + + "HighlyCompressedMapStatus is accurately recorded. This helps to prevent OOM " + + "by avoiding underestimating shuffle block size when fetch shuffle blocks.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(100 * 1024 * 1024) + + private[spark] val SHUFFLE_ACCURATE_BLOCK_RECORD_THRESHOLD = + ConfigBuilder("spark.shuffle.accurateBlockRecordThreshold") + .doc("When we compress the records number of shuffle blocks in HighlyCompressedMapStatus, " + + "we will record the number accurately if it's above this config. The record number will " + + "be used for data skew judgement when spark.shuffle.statistics.verbose is set true.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(2 * 1024 * 1024) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 0e221edf3965a..7a9a38f05c01d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -44,24 +44,31 @@ private[spark] sealed trait MapStatus { * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. */ def getSizeForBlock(reduceId: Int): Long + + def getRecordForBlock(reduceId: Int): Long } private[spark] object MapStatus { - /** - * Min partition number to use [[HighlyCompressedMapStatus]]. A bit ugly here because in test - * code we can't assume SparkEnv.get exists. - */ - private lazy val minPartitionsToUseHighlyCompressMapStatus = Option(SparkEnv.get) - .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) - .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) + // we use Array[Long]() as uncompressedRecords's default value, + // main consideration is ser/deser do not accept null. 
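A hypothetical configuration that turns the new statistics on (config keys as added above; the values shown are just examples, and the last two match the defaults):

import org.apache.spark.SparkConf

// Example settings only; defaults are verbose=false, threshold=2000, record threshold=2m.
val conf = new SparkConf()
  .set("spark.shuffle.statistics.verbose", "true")                 // collect per-partition record counts
  .set("spark.shuffle.highlyCompressedMapStatusThreshold", "2000") // above this, use HighlyCompressedMapStatus
  .set("spark.shuffle.accurateBlockRecordThreshold", "2097152")    // record counts above this are stored exactly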
+ def apply( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + uncompressedRecords: Array[Long] = Array[Long]()): MapStatus = { + val verbose = Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_STATISTICS_VERBOSE)) + .getOrElse(config.SHUFFLE_STATISTICS_VERBOSE.defaultValue.get) + val threshold = Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_HIGHLY_COMPRESSED_MAP_STATUS_THRESHOLD)) + .getOrElse(config.SHUFFLE_HIGHLY_COMPRESSED_MAP_STATUS_THRESHOLD.defaultValue.get) - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { - if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + val newRecords = if (verbose) uncompressedRecords else Array[Long]() + if (uncompressedSizes.length > threshold) { + HighlyCompressedMapStatus(loc, uncompressedSizes, newRecords) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus(loc, uncompressedSizes, newRecords) } } @@ -104,13 +111,17 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte]) + private[this] var compressedSizes: Array[Byte], + private[this] var compressedRecords: Array[Byte]) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]], + null.asInstanceOf[Array[Byte]]) // For deserialization only - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], + uncompressedRecords: Array[Long] = Array[Long]()) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), + uncompressedRecords.map(MapStatus.compressSize)) } override def location: BlockManagerId = loc @@ -119,10 +130,20 @@ private[spark] class CompressedMapStatus( MapStatus.decompressSize(compressedSizes(reduceId)) } + override def getRecordForBlock(reduceId: Int): Long = { + if (compressedRecords.nonEmpty) { + MapStatus.decompressSize(compressedRecords(reduceId)) + } else { + -1 + } + } + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) + out.writeInt(compressedRecords.length) + out.write(compressedRecords) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -130,6 +151,9 @@ private[spark] class CompressedMapStatus( val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) + val recordsLen = in.readInt() + compressedRecords = new Array[Byte](recordsLen) + in.readFully(compressedRecords) } } @@ -145,18 +169,20 @@ private[spark] class CompressedMapStatus( * @param hugeBlockSizes sizes of huge blocks by their reduceId. 
*/ private[spark] class HighlyCompressedMapStatus private ( - private[this] var loc: BlockManagerId, - private[this] var numNonEmptyBlocks: Int, - private[this] var emptyBlocks: RoaringBitmap, - private[this] var avgSize: Long, - private var hugeBlockSizes: Map[Int, Byte]) + private[this] var loc: BlockManagerId, + private[this] var numNonEmptyBlocks: Int, + private[this] var emptyBlocks: RoaringBitmap, + private[this] var avgSize: Long, + private var hugeBlockSizes: Map[Int, Byte], + private[this] var avgRecord: Long, + private var hugeBlockRecords: Map[Int, Byte]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, -1, null, -1, null, -1, null) // For deserialization only override def location: BlockManagerId = loc @@ -172,6 +198,22 @@ private[spark] class HighlyCompressedMapStatus private ( } } + override def getRecordForBlock(reduceId: Int): Long = { + assert(hugeBlockSizes != null) + if (avgRecord != -1) { + if (emptyBlocks.contains(reduceId)) { + 0 + } else { + hugeBlockRecords.get(reduceId) match { + case Some(record) => MapStatus.decompressSize(record) + case None => avgRecord + } + } + } else { + -1 + } + } + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) emptyBlocks.writeExternal(out) @@ -181,6 +223,12 @@ private[spark] class HighlyCompressedMapStatus private ( out.writeInt(kv._1) out.writeByte(kv._2) } + out.writeLong(avgRecord) + out.writeInt(hugeBlockRecords.size) + hugeBlockRecords.foreach { kv => + out.writeInt(kv._1) + out.writeByte(kv._2) + } } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -196,11 +244,23 @@ private[spark] class HighlyCompressedMapStatus private ( hugeBlockSizesArray += Tuple2(block, size) } hugeBlockSizes = hugeBlockSizesArray.toMap + avgRecord = in.readLong() + val recordCount = in.readInt() + val hugeBlockRecordsArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]() + (0 until recordCount).foreach { _ => + val block = in.readInt() + val record = in.readByte() + hugeBlockRecordsArray += Tuple2(block, record) + } + hugeBlockRecords = hugeBlockRecordsArray.toMap } } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + uncompressedRecords: Array[Long] = Array[Long]()): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. 
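Record counts use -1 as a sentinel for "no statistics available" (for example when verbose mode is off), so consumers must treat any -1 as unknown rather than a real count. An illustrative helper:

// Illustrative: -1 from MapStatus.getRecordForBlock means record statistics are
// unavailable, so a single -1 poisons the aggregate.
def totalRecordsForReducer(recordCounts: Seq[Long]): Option[Long] = {
  if (recordCounts.contains(-1L)) None else Some(recordCounts.sum)
}
// totalRecordsForReducer(Seq(10L, 7L, 3L))  == Some(20L)
// totalRecordsForReducer(Seq(10L, -1L, 3L)) == None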
var i = 0 @@ -213,8 +273,8 @@ private[spark] object HighlyCompressedMapStatus { val emptyBlocks = new RoaringBitmap() val totalNumBlocks = uncompressedSizes.length val threshold = Option(SparkEnv.get) - .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD)) - .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get) + .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_SIZE_THRESHOLD)) + .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_SIZE_THRESHOLD.defaultValue.get) val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]() while (i < totalNumBlocks) { val size = uncompressedSizes(i) @@ -238,9 +298,38 @@ private[spark] object HighlyCompressedMapStatus { } else { 0 } + + var recordSmallBlocks: Long = 0 + numSmallBlocks = 0 + var avgRecord: Long = -1 + val recordThreshold = Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_RECORD_THRESHOLD)) + .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_RECORD_THRESHOLD.defaultValue.get) + val hugeBlockRecordsArray = ArrayBuffer[Tuple2[Int, Byte]]() + if (uncompressedRecords.nonEmpty) { + i = 0 + while (i < totalNumBlocks) { + val record = uncompressedRecords(i) + if (record > 0) { + if (record < recordThreshold) { + recordSmallBlocks += record + numSmallBlocks += 1 + } else { + hugeBlockRecordsArray += Tuple2(i, MapStatus.compressSize(uncompressedRecords(i))) + } + } + i += 1 + } + avgRecord = if (numSmallBlocks > 0) { + recordSmallBlocks / numSmallBlocks + } else { + 0 + } + } + emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizesArray.toMap) + hugeBlockSizesArray.toMap, avgRecord, hugeBlockRecordsArray.toMap) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 1d4b05caaa143..99ee035c375ae 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -109,6 +109,7 @@ private[spark] class SerializerManager( private def shouldCompress(blockId: BlockId): Boolean = { blockId match { case _: ShuffleBlockId => compressShuffle + case _: ContinuousShuffleBlockId => compressShuffle case _: BroadcastBlockId => compressBroadcast case _: RDDBlockId => compressRdds case _: TempLocalBlockId => compressShuffleSpill diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 74b0e0b3a741a..ce89eee31f376 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -35,18 +35,37 @@ private[spark] class BlockStoreShuffleReader[K, C]( context: TaskContext, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + startMapId: Option[Int] = None, + endMapId: Option[Int] = None) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val blocksByAddress = (startMapId, endMapId) match { + case (Some(startId), Some(endId)) => mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, + startPartition, + 
endPartition, + startId, + endId, + dep.serializer.supportsRelocationOfSerializedObjects) + case (None, None) => mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, + startPartition, + endPartition, + dep.serializer.supportsRelocationOfSerializedObjects) + case (_, _) => throw new IllegalArgumentException( + "startMapId and endMapId should be both set or unset") + } + val wrappedStreams = new ShuffleBlockFetcherIterator( context, blockManager.shuffleClient, blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), + blocksByAddress, serializerManager.wrapStream, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d3f1c7ec1bbee..696662c1945a3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -190,7 +190,7 @@ private[spark] class IndexShuffleBlockResolver( } } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { + override def getBlockData(blockId: ShuffleBlockIdBase): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) @@ -205,10 +205,19 @@ private[spark] class IndexShuffleBlockResolver( channel.position(blockId.reduceId * 8L) val in = new DataInputStream(Channels.newInputStream(channel)) try { + channel.position(blockId.reduceId * 8) val offset = in.readLong() + var expectedPosition = 0 + blockId match { + case bid: ContinuousShuffleBlockId => + val tempId = blockId.reduceId + bid.numBlocks + channel.position(tempId * 8) + expectedPosition = tempId * 8 + 8 + case _ => + expectedPosition = blockId.reduceId * 8 + 16 + } val nextOffset = in.readLong() val actualPosition = channel.position() - val expectedPosition = blockId.reduceId * 8L + 16 if (actualPosition != expectedPosition) { throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + s"expected $expectedPosition but actual position was $actualPosition.") diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index d1ecbc1bf0178..8b62c00bc6b9c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -18,7 +18,7 @@ package org.apache.spark.shuffle import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.storage.ShuffleBlockIdBase private[spark] /** @@ -34,7 +34,7 @@ trait ShuffleBlockResolver { * Retrieve the data for the specified block. If the data for that block is not available, * throws an unspecified exception. 
*/ - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer + def getBlockData(blockId: ShuffleBlockIdBase): ManagedBuffer def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 4ea8a7120a9cc..f61461fede671 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -41,7 +41,8 @@ private[spark] trait ShuffleManager { def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] /** - * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to + * read from map output (startMapId to endMapId - 1, inclusive). * Called on executors by reduce tasks. */ def getReader[K, C]( @@ -50,6 +51,19 @@ private[spark] trait ShuffleManager { endPartition: Int, context: TaskContext): ShuffleReader[K, C] + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to + * read from map output (startMapId to endMapId - 1, inclusive). + * Called on executors by reduce tasks. + */ + def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + startMapId: Int, + endMapId: Int): ShuffleReader[K, C] + /** * Remove a shuffle's metadata from the ShuffleManager. * @return true if the metadata removed successfully, otherwise false. diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 0caf84c6050a8..cc8e615d5da28 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -119,6 +119,27 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to + * read from map output (startMapId to endMapId - 1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + startMapId: Int, + endMapId: Int): ShuffleReader[K, C] = { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startPartition, + endPartition, + context, + startMapId = Some(startMapId), + endMapId = Some(endMapId)) + } + /** Get a writer for a given partition. Called on executors by map tasks. 
*/ override def getWriter[K, V]( handle: ShuffleHandle, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 274399b9cc1f3..e4b9bd3caea7f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -68,9 +68,9 @@ private[spark] class SortShuffleWriter[K, V, C]( val tmp = Utils.tempFileWith(output) try { val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, tmp) - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + val mapInfo = sorter.writePartitionedFile(blockId, tmp) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, mapInfo.lengths, tmp) + mapStatus = MapStatus(blockManager.shuffleServerId, mapInfo.lengths, mapInfo.records) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 7ac2c71c18eb3..14a4df53931fc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -38,7 +38,7 @@ sealed abstract class BlockId { // convenience methods def asRDDId: Option[RDDBlockId] = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None def isRDD: Boolean = isInstanceOf[RDDBlockId] - def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] + def isShuffle: Boolean = isInstanceOf[ShuffleBlockIdBase] def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] override def toString: String = name @@ -51,11 +51,25 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { // Format of the shuffle block ids (including data and index) should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getBlockData(). 
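For orientation, the two shuffle block id forms introduced below differ only in a trailing block count: "shuffle_&lt;shuffleId&gt;_&lt;mapId&gt;_&lt;reduceId&gt;" versus "shuffle_&lt;shuffleId&gt;_&lt;mapId&gt;_&lt;reduceId&gt;_&lt;numBlocks&gt;". A minimal parsing sketch (standalone, using the same regex shapes the patch adds to BlockId):

val Shuffle = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
val ContinuousShuffle = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r

def describe(name: String): String = name match {
  case ContinuousShuffle(s, m, r, n) => s"shuffle $s, map $m, reduces $r to ${r.toInt + n.toInt - 1}"
  case Shuffle(s, m, r) => s"shuffle $s, map $m, reduce $r"
  case other => s"not a shuffle block: $other"
}

// describe("shuffle_10_0_0_2") == "shuffle 10, map 0, reduces 0 to 1"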
+trait ShuffleBlockIdBase extends BlockId { + def shuffleId: Int + def mapId: Int + def reduceId: Int +} + @DeveloperApi -case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { +case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) + extends ShuffleBlockIdBase { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } +@DeveloperApi +case class ContinuousShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int, numBlocks: Int) + extends ShuffleBlockIdBase { + override def name: String = + "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + "_" + numBlocks +} + @DeveloperApi case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" @@ -104,6 +118,7 @@ class UnrecognizedBlockId(name: String) object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r + val CONTINUE_SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r @@ -118,6 +133,8 @@ object BlockId { RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + case CONTINUE_SHUFFLE(shuffleId, mapId, reduceId, length) => + ContinuousShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt, length.toInt) case SHUFFLE_DATA(shuffleId, mapId, reduceId) => ShuffleDataBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 8912e971e50d0..cd2a14fec417a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -378,7 +378,7 @@ private[spark] class BlockManager( */ override def getBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockIdBase]) } else { getLocalBytes(blockId) match { case Some(blockData) => @@ -638,7 +638,7 @@ private[spark] class BlockManager( // TODO: This should gracefully handle case where local block is not available. Currently // downstream code will throw an exception. 
val buf = new ChunkedByteBuffer( - shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) + shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockIdBase]).nioByteBuffer()) Some(new ByteBufferBlockData(buf, true)) } else { blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index a024c83d8d8b7..d15fe396e4033 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -181,7 +181,8 @@ private[spark] class DiskBlockObjectWriter( } val pos = channel.position() - val fileSegment = new FileSegment(file, committedPosition, pos - committedPosition) + val fileSegment = + new FileSegment(file, committedPosition, pos - committedPosition, numRecordsWritten) committedPosition = pos // In certain compression codecs, more bytes are written after streams are closed writeMetrics.incBytesWritten(committedPosition - reportedPosition) @@ -189,7 +190,7 @@ private[spark] class DiskBlockObjectWriter( numRecordsWritten = 0 fileSegment } else { - new FileSegment(file, committedPosition, 0) + new FileSegment(file, committedPosition, 0, 0) } } diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 021a9facfb0b2..063c1e26a1e53 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -21,12 +21,17 @@ import java.io.File /** * References a particular segment of a file (potentially the entire file), - * based off an offset and a length. + * based off offset, length and record number. */ -private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { +private[spark] class FileSegment( + val file: File, + val offset: Long, + val length: Long, + val record: Long) { require(offset >= 0, s"File segment offset cannot be negative (got $offset)") require(length >= 0, s"File segment length cannot be negative (got $length)") + require(record >= 0, s"File segment record cannot be negative (got $record)") override def toString: String = { - "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + "(name=%s, offset=%d, length=%d, record=%d)".format(file.getName, offset, length, record) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index aecc2284a9588..fcf215026673c 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -46,13 +46,11 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * * @param context [[TaskContext]], used for metrics update * @param shuffleClient [[ShuffleClient]] for fetching remote blocks - * @param blockManager [[BlockManager]] for reading local blocks + * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in - * order to throttle the memory usage. 
Note that zero-sized blocks are - * already excluded, which happened in - * [[MapOutputTracker.convertMapStatuses]]. - * @param streamWrapper A function to wrap the returned input stream. + * order to throttle the memory usage. + * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point @@ -65,7 +63,7 @@ final class ShuffleBlockFetcherIterator( context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, - blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, @@ -550,8 +548,9 @@ final class ShuffleBlockFetcherIterator( private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { blockId match { - case ShuffleBlockId(shufId, mapId, reduceId) => - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) + case blockId: ShuffleBlockIdBase => + throw new FetchFailedException( + address, blockId.shuffleId, blockId.mapId, blockId.reduceId, e) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block", e) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 547a862467c88..c4fe743d3adaf 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -29,6 +29,7 @@ import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.serializer._ +import org.apache.spark.shuffle.sort.MapInfo import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** @@ -682,10 +683,11 @@ private[spark] class ExternalSorter[K, V, C]( */ def writePartitionedFile( blockId: BlockId, - outputFile: File): Array[Long] = { + outputFile: File): MapInfo = { // Track location of each range in the output file val lengths = new Array[Long](numPartitions) + val records = new Array[Long](numPartitions) val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, context.taskMetrics().shuffleWriteMetrics) @@ -700,6 +702,7 @@ private[spark] class ExternalSorter[K, V, C]( } val segment = writer.commitAndGet() lengths(partitionId) = segment.length + records(partitionId) = segment.record } } else { // We must perform merge-sort; get an iterator by partition and write everything directly. 
@@ -710,6 +713,7 @@ private[spark] class ExternalSorter[K, V, C]( } val segment = writer.commitAndGet() lengths(id) = segment.length + records(id) = segment.record } } } @@ -719,7 +723,7 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) - lengths + new MapInfo(lengths, records) } def stop(): Unit = { diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 0d5c5ea7903e9..929289dbfa26f 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.util.*; +import org.apache.spark.*; import scala.Option; import scala.Product2; import scala.Tuple2; @@ -35,10 +36,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.apache.spark.HashPartitioner; -import org.apache.spark.ShuffleDependency; -import org.apache.spark.SparkConf; -import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.io.CompressionCodec$; @@ -81,6 +78,8 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; + @Mock(answer = RETURNS_SMART_NULLS) SparkEnv env; + @After public void tearDown() { @@ -95,17 +94,20 @@ public void tearDown() { @SuppressWarnings("unchecked") public void setUp() throws IOException { MockitoAnnotations.initMocks(this); - tempDir = Utils.createTempDir(null, "test"); + tempDir = Utils.createTempDir("test", "test"); mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); conf = new SparkConf() - .set("spark.buffer.pageSize", "1m") - .set("spark.memory.offHeap.enabled", "false"); + .set("spark.buffer.pageSize", "1m") + .set("spark.memory.offHeap.enabled", "false"); taskMetrics = new TaskMetrics(); memoryManager = new TestMemoryManager(conf); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + when(env.conf()).thenReturn(conf); + SparkEnv.set(env); + // Some tests will override this manager because they change the configuration. This is a // default for tests that don't need a specific one. 
SerializerManager manager = new SerializerManager(serializer, conf); @@ -113,22 +115,22 @@ public void setUp() throws IOException { when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( - any(BlockId.class), - any(File.class), - any(SerializerInstance.class), - anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> { - Object[] args = invocationOnMock.getArguments(); - return new DiskBlockObjectWriter( - (File) args[1], - blockManager.serializerManager(), - (SerializerInstance) args[2], - (Integer) args[3], - false, - (ShuffleWriteMetrics) args[4], - (BlockId) args[0] - ); - }); + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> { + Object[] args = invocationOnMock.getArguments(); + return new DiskBlockObjectWriter( + (File) args[1], + blockManager.serializerManager(), + (SerializerInstance) args[2], + (Integer) args[3], + false, + (ShuffleWriteMetrics) args[4], + (BlockId) args[0] + ); + }); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); doAnswer(invocationOnMock -> { @@ -138,7 +140,7 @@ public void setUp() throws IOException { tmp.renameTo(mergedOutputFile); return null; }).when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); + .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); @@ -153,23 +155,23 @@ public void setUp() throws IOException { } private UnsafeShuffleWriter createWriter( - boolean transferToEnabled) throws IOException { + boolean transferToEnabled) throws IOException { conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); return new UnsafeShuffleWriter<>( - blockManager, - shuffleBlockResolver, - taskMemoryManager, - new SerializedShuffleHandle<>(0, 1, shuffleDep), - 0, // map id - taskContext, - conf + blockManager, + shuffleBlockResolver, + taskMemoryManager, + new SerializedShuffleHandle<>(0, 1, shuffleDep), + 0, // map id + taskContext, + conf ); } private void assertSpillFilesWereCleanedUp() { for (File spillFile : spillFilesCreated) { assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", - spillFile.exists()); + spillFile.exists()); } } @@ -262,8 +264,8 @@ public void writeWithoutSpilling() throws Exception { } assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); assertEquals( - HashMultiset.create(dataToWrite), - HashMultiset.create(readRecordsFromFile())); + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); @@ -273,9 +275,9 @@ public void writeWithoutSpilling() throws Exception { } private void testMergingSpills( - final boolean transferToEnabled, - String compressionCodecName, - boolean encrypt) throws Exception { + final boolean transferToEnabled, + String compressionCodecName, + boolean encrypt) throws Exception { if (compressionCodecName != null) { conf.set("spark.shuffle.compress", "true"); conf.set("spark.io.compression.codec", compressionCodecName); @@ -287,7 +289,7 @@ private void testMergingSpills( SerializerManager manager; if (encrypt) { manager = new 
SerializerManager(serializer, conf, - Option.apply(CryptoStreamUtils.createKey(conf))); + Option.apply(CryptoStreamUtils.createKey(conf))); } else { manager = new SerializerManager(serializer, conf); } @@ -297,8 +299,8 @@ private void testMergingSpills( } private void testMergingSpills( - boolean transferToEnabled, - boolean encrypted) throws IOException { + boolean transferToEnabled, + boolean encrypted) throws IOException { final UnsafeShuffleWriter writer = createWriter(transferToEnabled); final ArrayList> dataToWrite = new ArrayList<>(); for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { @@ -404,6 +406,51 @@ public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exce testMergingSpills(false, null, true); } + private void testMergingSpillsRecordStatistics( + boolean transferToEnabled) throws IOException { + conf.set("spark.shuffle.statistics.verbose", "true"); + + final UnsafeShuffleWriter writer = createWriter(transferToEnabled); + final ArrayList> dataToWrite = new ArrayList<>(); + for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { + dataToWrite.add(new Tuple2<>(i, i)); + } + writer.insertRecordIntoSorter(dataToWrite.get(0)); + writer.insertRecordIntoSorter(dataToWrite.get(1)); + writer.insertRecordIntoSorter(dataToWrite.get(2)); + writer.insertRecordIntoSorter(dataToWrite.get(3)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(dataToWrite.get(4)); + writer.insertRecordIntoSorter(dataToWrite.get(5)); + writer.closeAndWriteOutput(); + final Option mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + assertEquals(2, spillFilesCreated.size()); + + long sumOfPartitionSizes = 0; + for (long size: partitionSizesInMergedFile) { + sumOfPartitionSizes += size; + } + assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); + + long sumOfPartitionRows = 0; + for (int i = 0; i < NUM_PARTITITONS; i++) { + sumOfPartitionRows += mapStatus.get().getRecordForBlock(i); + } + assertEquals(sumOfPartitionRows, 6); + } + + @Test + public void mergeSpillsRecordStatisticsWithTransferTo() throws Exception { + testMergingSpillsRecordStatistics(true); + } + + @Test + public void mergeSpillsRecordStatisticsWithFileStream() throws Exception { + testMergingSpillsRecordStatistics(false); + } + @Test public void writeEnoughDataToTriggerSpill() throws Exception { memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES); @@ -469,8 +516,8 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception writer.write(dataToWrite.iterator()); writer.stop(true); assertEquals( - HashMultiset.create(dataToWrite), - HashMultiset.create(readRecordsFromFile())); + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); } @@ -490,8 +537,8 @@ public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { writer.write(dataToWrite.iterator()); writer.stop(true); assertEquals( - HashMultiset.create(dataToWrite), - HashMultiset.create(readRecordsFromFile())); + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); } @@ -514,14 +561,14 @@ public void testPeakMemoryUsed() throws Exception { taskMemoryManager = spy(taskMemoryManager); when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); final UnsafeShuffleWriter writer = - new UnsafeShuffleWriter<>( - blockManager, - shuffleBlockResolver, - taskMemoryManager, - new SerializedShuffleHandle<>(0, 1, shuffleDep), - 0, // map id - 
taskContext, - conf); + new UnsafeShuffleWriter<>( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + new SerializedShuffleHandle<>(0, 1, shuffleDep), + 0, // map id + taskContext, + conf); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index de36aefa410cd..06e498fe75512 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -26,8 +26,9 @@ import org.apache.spark.LocalSparkContext._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} +import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockManagerId, ContinuousShuffleBlockId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf @@ -304,7 +305,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + assert(tracker.getMapSizesByExecutorId(10, 0, 4, false).toSeq === Seq( (BlockManagerId("a", "hostA", 1000), Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), @@ -318,4 +319,38 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } + test("fetch contiguous partitions") { + val rpcEnv = createRpcEnv("test") + val serializer = new KryoSerializer(conf) + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 2) + assert(tracker.containsShuffle(10)) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + val size2000 = MapStatus.decompressSize(MapStatus.compressSize(2000L)) + val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(1000L, 10000L, 2000L))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(10000L, 2000L, 1000L))) + val statuses1 = tracker.getMapSizesByExecutorId(10, 0, 2, serializerRelocatable = true) + assert(statuses1.toSet === + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ContinuousShuffleBlockId(10, 0, 0, 2), size1000 + size10000))), + (BlockManagerId("b", "hostB", 1000), + ArrayBuffer((ContinuousShuffleBlockId(10, 1, 0, 2), size10000 + size2000)))) + .toSet) + val statuses2 = tracker.getMapSizesByExecutorId(10, 2, 3, serializerRelocatable = true) + assert(statuses2.toSet === + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ShuffleBlockId(10, 0, 2), size2000))), + (BlockManagerId("b", "hostB", 1000), + ArrayBuffer((ShuffleBlockId(10, 1, 2), size1000)))) + .toSet) + assert(0 == tracker.getNumCachedSerializedBroadcast) + tracker.stop() + rpcEnv.shutdown() + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 2155a0f2b6c21..4f52d617ec346 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -100,7 +100,7 @@ class MapStatusSuite extends SparkFunSuite { test("SPARK-22540: ensure HighlyCompressedMapStatus calculates correct avgSize") { val threshold = 1000 - val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, threshold.toString) + val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_SIZE_THRESHOLD.key, threshold.toString) val env = mock(classOf[SparkEnv]) doReturn(conf).when(env).conf SparkEnv.set(env) @@ -158,7 +158,7 @@ class MapStatusSuite extends SparkFunSuite { test("Blocks which are bigger than SHUFFLE_ACCURATE_BLOCK_THRESHOLD should not be " + "underestimated.") { - val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "1000") + val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_SIZE_THRESHOLD.key, "1000") val env = mock(classOf[SparkEnv]) doReturn(conf).when(env).conf SparkEnv.set(env) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 2d8a83c6fabed..3a567a1aed364 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -101,14 +101,16 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn { + when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1, + serializer.supportsRelocationOfSerializedObjects)) + .thenReturn { // Test a scenario where all data is local, to avoid creating a bunch of additional mocks // for the code to read data over the network. val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong) } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) } // Create a mocked shuffle handle to pass into HashShuffleReader. 
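The "fetch contiguous partitions" test above exercises the core idea: when the serializer supports relocation of serialized objects, adjacent reduce partitions of one map output can be fetched as a single ContinuousShuffleBlockId instead of several ShuffleBlockIds. A hedged sketch of that merging (hypothetical helper; the real logic lives in MapOutputTracker.getMapSizesByExecutorId):

import org.apache.spark.storage.{BlockId, ContinuousShuffleBlockId, ShuffleBlockId}

// sizes(r) is the estimated size of reduce partition r in this map output.
def blocksToFetch(
    shuffleId: Int,
    mapId: Int,
    startPartition: Int,
    endPartition: Int,
    sizes: Int => Long,
    relocatable: Boolean): Seq[(BlockId, Long)] = {
  val numBlocks = endPartition - startPartition
  if (relocatable && numBlocks > 1) {
    // One merged block covering reduce partitions [startPartition, endPartition - 1].
    Seq((ContinuousShuffleBlockId(shuffleId, mapId, startPartition, numBlocks),
      (startPartition until endPartition).map(sizes).sum))
  } else {
    (startPartition until endPartition).map(r => (ShuffleBlockId(shuffleId, mapId, r), sizes(r)))
  }
}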
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index ff4755833a916..4ff4c3862fb61 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -150,6 +150,20 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("continuous shuffle block") { + val id = ContinuousShuffleBlockId(1, 2, 3, 4) + assertSame(id, ContinuousShuffleBlockId(1, 2, 3, 4)) + assertDifferent(id, ContinuousShuffleBlockId(3, 2, 3, 4)) + assert(id.name === "shuffle_1_2_3_4") + assert(id.asRDDId === None) + assert(id.shuffleId === 1) + assert(id.mapId === 2) + assert(id.reduceId === 3) + assert(id.numBlocks === 4) + assert(id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + test("test") { val id = TestBlockId("abc") assertSame(id, TestBlockId("abc")) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index b268195e09a5b..9242afb6ac24e 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -100,7 +100,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) - ).toIterator + ) val iterator = new ShuffleBlockFetcherIterator( TaskContext.empty(), @@ -177,7 +177,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -245,7 +245,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -315,7 +315,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -379,7 +379,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlockLengths), (remoteBmId, remoteBlockLengths) - ).toIterator + ) val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -433,7 +433,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = 
Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -492,7 +492,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) def fetchShuffleBlock( - blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. @@ -510,14 +510,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. @@ -546,7 +546,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext, transfer, blockManager, - blocksByAddress.toIterator, + blocksByAddress, (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, diff --git a/examples/pom.xml b/examples/pom.xml index 143dbcb8d4678..1196d58967cf1 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../pom.xml diff --git a/external/avro/pom.xml b/external/avro/pom.xml index 23243ffa61ddf..863cd648ccd71 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 8ced663b2b4d1..2683d7fcc14b5 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 1a0e6e308d09a..35fcd42cd0c16 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index bb62e264671c0..88e14e18404bb 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index d544ab0c21bb8..1488734fea2ac 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 
+21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 05b2872d36c7b..ef0e0158bea7c 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 6ec3f14e2eefa..d59640e8b225f 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 5c586aa0fa825..c2b1fd5098ce9 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 9684ab24eeffc..ba5fb659d174f 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 51f1c62f418eb..94df9413bc5c6 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index fcd63bcd5f38b..bd7d400a5ec81 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index c4a2aafe5b001..2e2cd11d168bd 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 87c6eda24bc39..850f0bd16a66f 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index efc8619203ee1..ce3ce8c7e5aea 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 466f2c3023da7..ad5fdc96d3a74 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index e47835714f837..86465b98b013d 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 2e8ada927ef36..d60815aafadee 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index 
fc799ea4bfd89..bfc30a52619dc 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../pom.xml diff --git a/pom.xml b/pom.xml index 5be677db7333f..e2b7026cc91af 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 pom Spark Project Parent POM http://spark.apache.org/ diff --git a/repl/pom.xml b/repl/pom.xml index 4f3fee556984b..c964f51c39674 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 212ceaf093ce0..868dab29e5f56 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../../pom.xml diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 924008dc1e6ca..13c18b099eff3 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 2b1d428a1c6a4..41f523d9d0bd8 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 24780b454569a..819821d437564 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 1154c583f0eee..ffbac54fd0df1 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9bd3f20a1496c..773e071b05e0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -271,6 +271,71 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ADAPTIVE_EXECUTION_JOIN_ENABLED = buildConf("spark.sql.adaptive.join.enabled") + .doc("When true and adaptive execution is enabled, a better join strategy is determined at " + + "runtime.") + .booleanConf + .createWithDefault(true) + + val ADAPTIVE_BROADCASTJOIN_THRESHOLD = buildConf("spark.sql.adaptiveBroadcastJoinThreshold") + .doc("Configures the maximum size in bytes for a table that will be broadcast to all worker " + + "nodes when performing a join in adaptive execution mode. 
If not set, it defaults to " + + "spark.sql.autoBroadcastJoinThreshold.") + .longConf + .createOptional + + val ADAPTIVE_EXECUTION_ALLOW_ADDITIONAL_SHUFFLE = + buildConf("spark.sql.adaptive.allowAdditionalShuffle") + .doc("When true, additional shuffles are allowed during plan optimizations in adaptive " + + "execution.") + .booleanConf + .createWithDefault(false) + + val ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED = buildConf("spark.sql.adaptive.skewedJoin.enabled") + .doc("When true and adaptive execution is enabled, a skewed join is automatically handled at " + + "runtime.") + .booleanConf + .createWithDefault(false) + + val ADAPTIVE_EXECUTION_SKEWED_PARTITION_FACTOR = + buildConf("spark.sql.adaptive.skewedPartitionFactor") + .doc("A partition is considered as a skewed partition " + + "if its size is larger than this factor " + + "multiplied by the median partition size and also larger than " + + "spark.sql.adaptive.skewedPartitionSizeThreshold, " + + "or if its row count is larger than this " + + "factor multiplied by the median row count and also larger than " + + "spark.sql.adaptive.skewedPartitionRowCountThreshold.") + .intConf + .createWithDefault(10) + + val ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD = + buildConf("spark.sql.adaptive.skewedPartitionSizeThreshold") + .doc("Configures the minimum size in bytes for a partition that is considered as a skewed " + + "partition in adaptive skewed join.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(64 * 1024 * 1024) + + val ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD = + buildConf("spark.sql.adaptive.skewedPartitionRowCountThreshold") + .doc("Configures the minimum row count for a partition that is considered as a skewed " + + "partition in adaptive skewed join.") + .longConf + .createWithDefault(10L * 1000 * 1000) + + val ADAPTIVE_EXECUTION_SKEWED_PARTITION_MAX_SPLITS = + buildConf("spark.sql.adaptive.skewedPartitionMaxSplits") + .doc("Configures the maximum number of tasks to handle a skewed partition in adaptive skewed " + + "join.") + .intConf + .createWithDefault(5) + + val ADAPTIVE_EXECUTION_TARGET_POSTSHUFFLE_ROW_COUNT = + buildConf("spark.sql.adaptive.shuffle.targetPostShuffleRowCount") + .doc("The target post-shuffle row count of a task.") + .longConf + .createWithDefault(20L * 1000 * 1000) + val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = buildConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize") .doc("The target post-shuffle input size in bytes of a task.") @@ -293,6 +358,14 @@ object SQLConf { .intConf .createWithDefault(-1) + val SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS = + buildConf("spark.sql.adaptive.maxNumPostShufflePartitions") + .doc("The advisory maximum number of post-shuffle partitions used in adaptive execution.") + .intConf + .checkValue(numPartitions => numPartitions > 0, "The maximum shuffle partition number " + + "must be a positive integer.") + .createWithDefault(500) + val SUBEXPRESSION_ELIMINATION_ENABLED = buildConf("spark.sql.subexpressionElimination.enabled") .internal() @@ -1688,11 +1761,35 @@ class SQLConf extends Serializable with Logging { def allowAEwhenRepartition: Boolean = getConf(ALLOW_ADAPTIVE_WHEN_REPARTITION) + def adaptiveJoinEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_JOIN_ENABLED) + + def adaptiveBroadcastJoinThreshold: Long = + getConf(ADAPTIVE_BROADCASTJOIN_THRESHOLD).getOrElse(autoBroadcastJoinThreshold) + + def adaptiveAllowAdditionShuffle: Boolean = getConf(ADAPTIVE_EXECUTION_ALLOW_ADDITIONAL_SHUFFLE) + def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) 
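Taken together, the skew settings above mean a post-shuffle partition is treated as skewed only when it exceeds both the factor-times-median test and the corresponding absolute threshold, for either bytes or rows. A minimal sketch of that predicate under assumed names (not the patch's handler):

// medianSize / medianRowCount are computed over all post-shuffle partitions of the stage.
def isSkewed(
    size: Long,
    rowCount: Long,
    medianSize: Long,
    medianRowCount: Long,
    factor: Int,            // spark.sql.adaptive.skewedPartitionFactor
    sizeThreshold: Long,    // spark.sql.adaptive.skewedPartitionSizeThreshold
    rowCountThreshold: Long // spark.sql.adaptive.skewedPartitionRowCountThreshold
  ): Boolean = {
  (size > medianSize * factor && size > sizeThreshold) ||
    (rowCount > medianRowCount * factor && rowCount > rowCountThreshold)
}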
def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) + def maxNumPostShufflePartitions: Int = getConf(SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS) + + def adaptiveSkewedJoinEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED) + + def adaptiveSkewedFactor: Int = getConf(ADAPTIVE_EXECUTION_SKEWED_PARTITION_FACTOR) + + def adaptiveSkewedSizeThreshold: Long = + getConf(ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD) + + def adaptiveSkewedRowCountThreshold: Long = + getConf(ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD) + + def adaptiveSkewedMaxSplits: Int = getConf(ADAPTIVE_EXECUTION_SKEWED_PARTITION_MAX_SPLITS) + + def adaptiveTargetPostShuffleRowCount: Long = + getConf(ADAPTIVE_EXECUTION_TARGET_POSTSHUFFLE_ROW_COUNT) + def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN) def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 6e95bc300f86b..9c93dc5e6a62c 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 129a12d14adf0..963a9810b9819 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.statsEstimation.Statistics import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{TaskCompletionListener, Utils} @@ -74,6 +75,10 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { private def redact(text: String): String = { Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text) } + + override def computeStats(): Statistics = { + Statistics(sizeInBytes = relation.sizeInBytes) + } } /** Physical plan node for scanning data from a relation. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 2962becb64e88..3f391d572ca15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -23,8 +23,10 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{Statistics => LogicalPlanStatistics} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.statsEstimation.{Statistics => PhsicalPlanStatistics} import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -88,7 +90,7 @@ case class ExternalRDD[T]( override protected def stringArgs: Iterator[Any] = Iterator(output) - override def computeStats(): Statistics = Statistics( + override def computeStats(): LogicalPlanStatistics = LogicalPlanStatistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) @@ -122,6 +124,13 @@ case class ExternalRDDScanExec[T]( override def simpleString: String = { s"$nodeName${output.mkString("[", ",", "]")}" } + + override def computeStats(): PhsicalPlanStatistics = + PhsicalPlanStatistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. + sizeInBytes = BigInt(conf.defaultSizeInBytes) + ) } /** Logical plan node for scanning data from an RDD of InternalRow. */ @@ -162,7 +171,7 @@ case class LogicalRDD( override protected def stringArgs: Iterator[Any] = Iterator(output, isStreaming) - override def computeStats(): Statistics = Statistics( + override def computeStats(): LogicalPlanStatistics = LogicalPlanStatistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) @@ -199,4 +208,11 @@ case class RDDScanExec( override def simpleString: String = { s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}" } + + override def computeStats(): PhsicalPlanStatistics = + PhsicalPlanStatistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. 
+ sizeInBytes = BigInt(conf.defaultSizeInBytes) + ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 448eb703eacde..2e8f7cb9603ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.statsEstimation.Statistics /** @@ -76,4 +77,10 @@ case class LocalTableScanExec( longMetric("numOutputRows").add(taken.size) taken } + + override def computeStats(): Statistics = { + val rowSize = 8 + output.map(_.dataType.defaultSize).sum + val rowCount = rows.size + Statistics(rowSize * rowCount) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 64f49e2d0d4e6..c957285b2a315 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.adaptive.PlanQueryStage import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} @@ -84,7 +85,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { * row format conversions as needed. */ protected def prepareForExecution(plan: SparkPlan): SparkPlan = { - preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) { + adaptivePreparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp)} + } else { + preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp)} + } } /** A sequence of rules that will be applied in order to the physical plan before execution. */ @@ -95,6 +100,15 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { ReuseExchange(sparkSession.sessionState.conf), ReuseSubquery(sparkSession.sessionState.conf)) + protected def adaptivePreparations: Seq[Rule[SparkPlan]] = Seq( + PlanSubqueries(sparkSession), + EnsureRequirements(sparkSession.sessionState.conf), + ReuseSubquery(sparkSession.sessionState.conf), + // PlanQueryStage needs to be the last rule because it divides the plan into multiple sub-trees + // by inserting leaf node QueryStageInput. Transforming the plan after applying this rule will + // only transform node in a sub-tree. 
+ PlanQueryStage(sparkSession.sessionState.conf)) + protected def stringOrError[A](f: => A): String = try f.toString catch { case e: AnalysisException => e.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 6cb3de4f32c87..6d990af1d4c3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.ui.{PostQueryExecutionForKylin, SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} @@ -139,4 +139,22 @@ object SQLExecution extends Logging{ } } } + + + def withExecutionIdAndJobDesc[T]( + sc: SparkContext, + executionId: String, + jobDesc: String)(body: => T): T = { + val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + val oldJobDesc = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + sc.setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, jobDesc) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + sc.setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, oldJobDesc) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 862ee05392f37..acc80bbbad577 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.catalyst.InternalRow * (`startPreShufflePartitionIndex` to `endPreShufflePartitionIndex - 1`, inclusive). 
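SQLExecution.withExecutionIdAndJobDesc above saves the caller's execution id and job description, overrides both local properties while the body runs, and restores the old values in a finally block, so jobs submitted from another thread are attributed to the right SQL execution in the UI. A hedged usage sketch; the id, the description and the work done inside the body are illustrative only:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.SQLExecution

def runTaggedJob(sc: SparkContext): Array[Long] =
  SQLExecution.withExecutionIdAndJobDesc(sc, executionId = "42", jobDesc = "child query stage") {
    // Any Spark job triggered here is reported under execution 42 with the given description.
    sc.parallelize(1 to 10).map(_.toLong).collect()
  }
```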
*/ private final class ShuffledRowRDDPartition( - val postShufflePartitionIndex: Int, - val startPreShufflePartitionIndex: Int, - val endPreShufflePartitionIndex: Int) extends Partition { + val postShufflePartitionIndex: Int, + val startPreShufflePartitionIndex: Int, + val endPreShufflePartitionIndex: Int) extends Partition { override val index: Int = postShufflePartitionIndex } @@ -112,7 +112,8 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A */ class ShuffledRowRDD( var dependency: ShuffleDependency[Int, InternalRow, InternalRow], - specifiedPartitionStartIndices: Option[Array[Int]] = None) + specifiedPartitionStartIndices: Option[Array[Int]] = None, + specifiedPartitionEndIndices: Option[Array[Int]] = None) extends RDD[InternalRow](dependency.rdd.context, Nil) { private[this] val numPreShufflePartitions = dependency.partitioner.numPartitions @@ -125,23 +126,24 @@ class ShuffledRowRDD( (0 until numPreShufflePartitions).toArray } - private[this] val part: Partitioner = - new CoalescedPartitioner(dependency.partitioner, partitionStartIndices) - override def getDependencies: Seq[Dependency[_]] = List(dependency) - override val partitioner: Option[Partitioner] = Some(part) + override val partitioner: Option[Partitioner] = specifiedPartitionEndIndices match { + case Some(indices) => None + case None => Some(new CoalescedPartitioner(dependency.partitioner, partitionStartIndices)) + } override def getPartitions: Array[Partition] = { - assert(partitionStartIndices.length == part.numPartitions) Array.tabulate[Partition](partitionStartIndices.length) { i => val startIndex = partitionStartIndices(i) - val endIndex = - if (i < partitionStartIndices.length - 1) { + val endIndex = specifiedPartitionEndIndices match { + case Some(indices) => indices(i) + case None => if (i < partitionStartIndices.length - 1) { partitionStartIndices(i + 1) } else { numPreShufflePartitions } + } new ShuffledRowRDDPartition(i, startIndex, endIndex) } } @@ -157,11 +159,11 @@ class ShuffledRowRDD( // The range of pre-shuffle partitions that we are fetching at here is // [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1]. 
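With the new optional end indices above, a post-shuffle partition no longer has to run up to the next start index, so the set of ranges may leave gaps; that is also why the RDD stops reporting a CoalescedPartitioner when end indices are supplied. A small sketch of how start/end pairs translate into reader ranges (the concrete numbers are assumed):

```scala
// Assumed: 10 pre-shuffle partitions.
val startIndices = Array(0, 3, 7)
val endIndices   = Array(3, 5, 10)
// Post-shuffle partition i reads pre-shuffle partitions [startIndices(i), endIndices(i)):
val ranges = startIndices.zip(endIndices)  // Array((0,3), (3,5), (7,10)); partitions 5 and 6 are skipped
```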
val reader = - SparkEnv.get.shuffleManager.getReader( - dependency.shuffleHandle, - shuffledRowPartition.startPreShufflePartitionIndex, - shuffledRowPartition.endPreShufflePartitionIndex, - context) + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + shuffledRowPartition.startPreShufflePartitionIndex, + shuffledRowPartition.endPreShufflePartitionIndex, + context) reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 89d1164d7b96d..6864beec8326b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredic import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.statsEstimation.{SparkPlanStats, Statistics} import org.apache.spark.sql.types.DataType import org.apache.spark.util.ThreadUtils @@ -44,7 +45,11 @@ import org.apache.spark.util.ThreadUtils * * The naming convention is that physical operators end with "Exec" suffix, e.g. [[ProjectExec]]. */ -abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { +abstract class SparkPlan + extends QueryPlan[SparkPlan] + with SparkPlanStats + with Logging + with Serializable { /** * A handle to the SQL Context that was used to create this plan. Since many operators need @@ -452,6 +457,10 @@ object SparkPlan { trait LeafExecNode extends SparkPlan { override final def children: Seq[SparkPlan] = Nil override def producedAttributes: AttributeSet = outputSet + + + /** LeafExec nodes that can survive analysis must define their own statistics. */ + def computeStats(): Statistics = throw new UnsupportedOperationException } object UnaryExecNode { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 59ffd16381116..863babdbb080a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.adaptive.QueryStageInput import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo @@ -51,6 +52,7 @@ private[execution] object SparkPlanInfo { def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { val children = plan match { case ReusedExchangeExec(_, child) => child :: Nil + case i: QueryStageInput => i.childStage :: Nil case _ => plan.children ++ plan.subqueries } val metrics = plan.metrics.toSeq.map { case (key, metric) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanVisitor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanVisitor.scala new file mode 100644 index 0000000000000..2ae7604261c97 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanVisitor.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
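Mixing SparkPlanStats into SparkPlan and giving LeafExecNode a throwing default mirrors what LogicalPlan already does on the logical side: any leaf physical operator that can show up in an adaptively planned query must now supply its own estimate. A hypothetical leaf node showing the minimal override (the node itself is invented for illustration; the default-size fallback copies the pattern used by RDDScanExec above):

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.execution.statsEstimation.Statistics

case class DummyScanExec(output: Seq[Attribute]) extends LeafExecNode {
  override protected def doExecute(): RDD[InternalRow] = ???

  // Without this override the inherited computeStats() throws UnsupportedOperationException.
  override def computeStats(): Statistics =
    Statistics(sizeInBytes = BigInt(conf.defaultSizeInBytes))
}
```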
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStage +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.joins.{HashJoin, SortMergeJoinExec} + +/** + * A visitor pattern for traversing a [[SparkPlan]] tree and compute some properties. + */ +trait SparkPlanVisitor[T] { + + def visit(p: SparkPlan): T = p match { + case p: FilterExec => visitFilterExec(p) + case p: HashAggregateExec => visitHashAggregateExec(p) + case p: HashJoin => visitHashJoin(p) + case p: ProjectExec => visitProjectExec(p) + case p: ShuffleExchangeExec => visitShuffleExchangeExec(p) + case p: SortAggregateExec => visitSortAggregateExec(p) + case p: SortMergeJoinExec => visitSortMergeJoinExec(p) + case p: ShuffleQueryStage => visitShuffleQueryStage(p) + case p: SparkPlan => default(p) + } + + def default(p: SparkPlan): T + + def visitFilterExec(p: FilterExec): T + + def visitHashAggregateExec(p: HashAggregateExec): T + + def visitHashJoin(p: HashJoin): T + + def visitProjectExec(p: ProjectExec): T + + def visitShuffleExchangeExec(p: ShuffleExchangeExec): T + + def visitSortAggregateExec(p: SortAggregateExec): T + + def visitSortMergeJoinExec(p: SortMergeJoinExec): T + + def visitShuffleQueryStage(p: ShuffleQueryStage): T +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4b6e54d7c4780..6674ac549ad00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -360,7 +360,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) - aggregate.AggUtils.planStreamingAggregation( + org.apache.spark.sql.execution.aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, @@ -445,13 +445,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val aggregateOperator = if (functionsWithDistinct.isEmpty) { - aggregate.AggUtils.planAggregateWithoutDistinct( + org.apache.spark.sql.execution.aggregate.AggUtils.planAggregateWithoutDistinct( groupingExpressions, aggregateExpressions, resultExpressions, planLater(child)) } else { - aggregate.AggUtils.planAggregateWithOneDistinct( + org.apache.spark.sql.execution.aggregate.AggUtils.planAggregateWithOneDistinct( groupingExpressions, functionsWithDistinct, functionsWithoutDistinct, @@ -662,7 +662,7 @@ abstract class SparkStrategies extends 
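The SparkPlanVisitor trait above dispatches on a fixed set of operators and routes everything else through default. A hypothetical implementation that counts shuffle exchanges in a physical plan; HashJoin is a trait rather than a concrete node, so the sketch recurses through its left and right members (assumed to be declared on the trait, as in Spark 2.4) instead of children:

```scala
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan, SparkPlanVisitor}
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStage
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{HashJoin, SortMergeJoinExec}

// Counts ShuffleExchangeExec nodes; every other operator just recurses into its children.
// Usage: new ShuffleCounter().visit(plan)
class ShuffleCounter extends SparkPlanVisitor[Int] {
  override def default(p: SparkPlan): Int = p.children.map(visit).sum
  override def visitShuffleExchangeExec(p: ShuffleExchangeExec): Int = 1 + default(p)
  override def visitFilterExec(p: FilterExec): Int = default(p)
  override def visitHashAggregateExec(p: HashAggregateExec): Int = default(p)
  override def visitHashJoin(p: HashJoin): Int = Seq(p.left, p.right).map(visit).sum
  override def visitProjectExec(p: ProjectExec): Int = default(p)
  override def visitSortAggregateExec(p: SortAggregateExec): Int = default(p)
  override def visitSortMergeJoinExec(p: SortMergeJoinExec): Int = default(p)
  override def visitShuffleQueryStage(p: ShuffleQueryStage): Int = default(p)
}
```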
QueryPlanner[SparkPlan] { case r: logical.Range => execution.RangeExec(r) :: Nil case r: logical.RepartitionByExpression => - exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), None, Some(true)) :: Nil + exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child), Some(true)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveShuffledRowRDD.scala new file mode 100644 index 0000000000000..8c6e4b7780cd5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveShuffledRowRDD.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow + +/** + * The [[Partition]] used by [[AdaptiveShuffledRowRDD]]. A post-shuffle partition + * (identified by `postShufflePartitionIndex`) contains a range of pre-shuffle partitions + * (`preShufflePartitionIndex` from `startMapId` to `endMapId - 1`, inclusive). + */ +private final class AdaptiveShuffledRowRDDPartition( + val postShufflePartitionIndex: Int, + val preShufflePartitionIndex: Int, + val startMapId: Int, + val endMapId: Int) extends Partition { + override val index: Int = postShufflePartitionIndex +} + +/** + * This is a specialized version of [[org.apache.spark.sql.execution.ShuffledRowRDD]]. This is used + * in Spark SQL adaptive execution to solve data skew issues. This RDD includes rearranged + * partitions from mappers. + * + * This RDD takes a [[ShuffleDependency]] (`dependency`), a partitionIndex + * and an array of map Id start indices as input arguments + * (`specifiedMapIdStartIndices`). 
+ * + */ +class AdaptiveShuffledRowRDD( + var dependency: ShuffleDependency[Int, InternalRow, InternalRow], + partitionIndex: Int, + specifiedMapIdStartIndices: Option[Array[Int]] = None, + specifiedMapIdEndIndices: Option[Array[Int]] = None) + extends RDD[InternalRow](dependency.rdd.context, Nil) { + + private[this] val numPostShufflePartitions = dependency.rdd.partitions.length + + private[this] val mapIdStartIndices: Array[Int] = specifiedMapIdStartIndices match { + case Some(indices) => indices + case None => (0 until numPostShufflePartitions).toArray + } + + override def getDependencies: Seq[Dependency[_]] = List(dependency) + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](mapIdStartIndices.length) { i => + val startIndex = mapIdStartIndices(i) + val endIndex = specifiedMapIdEndIndices match { + case Some(indices) => indices(i) + case None => + if (i < mapIdStartIndices.length - 1) { + mapIdStartIndices(i + 1) + } else { + numPostShufflePartitions + } + } + new AdaptiveShuffledRowRDDPartition(i, partitionIndex, startIndex, endIndex) + } + } + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val shuffledRowRDDPartition = partition.asInstanceOf[AdaptiveShuffledRowRDDPartition] + val dep = dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + tracker.getMapLocation( + dep, shuffledRowRDDPartition.startMapId, shuffledRowRDDPartition.endMapId) + } + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val shuffledRowPartition = split.asInstanceOf[AdaptiveShuffledRowRDDPartition] + val index = shuffledRowPartition.preShufflePartitionIndex + val reader = SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + index, + index + 1, + context, + shuffledRowPartition.startMapId, + shuffledRowPartition.endMapId) + reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) + } + + override def clearDependencies() { + super.clearDependencies() + dependency = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/HandleSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/HandleSkewedJoin.scala new file mode 100644 index 0000000000000..f0d0bbb1823fa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/HandleSkewedJoin.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
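A hedged sketch of how AdaptiveShuffledRowRDD is instantiated for one skewed reduce partition; the dependency value, the partition number and the split boundaries are assumed here (the real call site is SkewedShuffleQueryStageInput later in this patch, which passes a single start/end pair per input):

```scala
// Assumed: reduce partition 7 is skewed and its 12 map outputs are split into 3 ranges.
val skewedRdd = new AdaptiveShuffledRowRDD(
  dependency,                 // a ShuffleDependency[Int, InternalRow, InternalRow] in scope
  partitionIndex = 7,
  specifiedMapIdStartIndices = Some(Array(0, 4, 8)),
  specifiedMapIdEndIndices   = Some(Array(4, 8, 12)))
// Yields 3 RDD partitions that all read reduce partition 7, from mappers [0,4), [4,8) and [8,12).
```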
+ */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.immutable.Nil +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{SortExec, SparkPlan, UnionExec} +import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.execution.statsEstimation.PartitionStatistics +import org.apache.spark.sql.internal.SQLConf + +case class HandleSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { + + private val supportedJoinTypes = + Inner :: Cross :: LeftSemi :: LeftAnti :: LeftOuter :: RightOuter :: Nil + + private def isSizeSkewed(size: Long, medianSize: Long): Boolean = { + size > medianSize * conf.adaptiveSkewedFactor && + size > conf.adaptiveSkewedSizeThreshold + } + + private def isRowCountSkewed(rowCount: Long, medianRowCount: Long): Boolean = { + rowCount > medianRowCount * conf.adaptiveSkewedFactor && + rowCount > conf.adaptiveSkewedRowCountThreshold + } + + /** + * A partition is considered as a skewed partition if its size is larger than the median + * partition size * spark.sql.adaptive.skewedPartitionFactor and also larger than + * spark.sql.adaptive.skewedPartitionSizeThreshold, or if its row count is larger than + * the median row count * spark.sql.adaptive.skewedPartitionFactor and also larger than + * spark.sql.adaptive.skewedPartitionRowCountThreshold. + */ + private def isSkewed( + stats: PartitionStatistics, + partitionId: Int, + medianSize: Long, + medianRowCount: Long): Boolean = { + isSizeSkewed(stats.bytesByPartitionId(partitionId), medianSize) || + isRowCountSkewed(stats.recordsByPartitionId(partitionId), medianRowCount) + } + + private def medianSizeAndRowCount(stats: PartitionStatistics): (Long, Long) = { + val bytesLen = stats.bytesByPartitionId.length + val rowCountsLen = stats.recordsByPartitionId.length + val bytes = stats.bytesByPartitionId.sorted + val rowCounts = stats.recordsByPartitionId.sorted + val medSize = if (bytes(bytesLen / 2) > 0) bytes(bytesLen / 2) else 1 + val medRowCount = if (rowCounts(rowCountsLen / 2) > 0) rowCounts(rowCountsLen / 2) else 1 + (medSize, medRowCount) + } + + /** + * To equally divide n elements into m buckets, basically each bucket should have n/m elements, + * for the remaining n%m elements, add one more element to the first n%m buckets each. Returns + * a sequence with length numBuckets and each value represents the start index of each bucket. + */ + def equallyDivide(numElements: Int, numBuckets: Int): Seq[Int] = { + val elementsPerBucket = numElements / numBuckets + val remaining = numElements % numBuckets + val splitPoint = (elementsPerBucket + 1) * remaining + (0 until remaining).map(_ * (elementsPerBucket + 1)) ++ + (remaining until numBuckets).map(i => splitPoint + (i - remaining) * elementsPerBucket) + } + + /** + * We split the partition into several splits. Each split reads the data from several map outputs + * ranging from startMapId to endMapId(exclusive). This method calculates the split number and + * the startMapId for all splits. 
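To make the bucketing concrete, here is equallyDivide traced by hand for assumed inputs (10 map outputs, 3 splits); estimateMapIdStartIndices, defined next, derives the split count from how many times the partition exceeds the median, caps it by adaptiveSkewedMaxSplits and the mapper count, and then calls equallyDivide:

```scala
val numElements = 10  // map outputs feeding the skewed partition (assumed)
val numBuckets  = 3   // desired number of splits (assumed)
val elementsPerBucket = numElements / numBuckets             // 3
val remaining         = numElements % numBuckets             // 1
val splitPoint        = (elementsPerBucket + 1) * remaining  // 4
val startIndices = (0 until remaining).map(_ * (elementsPerBucket + 1)) ++
  (remaining until numBuckets).map(i => splitPoint + (i - remaining) * elementsPerBucket)
// startIndices == Vector(0, 4, 7), i.e. mapper ranges [0,4), [4,7) and [7,10).
```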
+ */ + private def estimateMapIdStartIndices( + queryStageInput: ShuffleQueryStageInput, + partitionId: Int, + medianSize: Long, + medianRowCount: Long): Array[Int] = { + val stats = queryStageInput.childStage.stats + val size = stats.bytesByPartitionId.get(partitionId) + val rowCount = stats.recordStatistics.get.recordsByPartitionId(partitionId) + val factor = Math.max(size / medianSize, rowCount / medianRowCount) + val numSplits = Math.min(conf.adaptiveSkewedMaxSplits, + Math.min(factor.toInt, queryStageInput.numMapper)) + equallyDivide(queryStageInput.numMapper, numSplits).toArray + } + + /** + * Base optimization support check: the join type is supported and plan statistics is available. + * Note that for some join types(like left outer), whether a certain partition can be optimized + * also depends on the filed isSkewAndSupportsSplit. + */ + private def supportOptimization( + joinType: JoinType, + left: QueryStageInput, + right: QueryStageInput): Boolean = { + supportedJoinTypes.contains(joinType) && + left.childStage.stats.getPartitionStatistics.isDefined && + right.childStage.stats.getPartitionStatistics.isDefined + } + + private def supportSplitOnLeftPartition(joinType: JoinType) = joinType != RightOuter + + private def supportSplitOnRightPartition(joinType: JoinType) = { + joinType != LeftOuter && joinType != LeftSemi && joinType != LeftAnti + } + + private def handleSkewedJoin( + operator: SparkPlan, + queryStage: QueryStage): SparkPlan = operator.transformUp { + case smj @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, + SortExec(_, _, left: ShuffleQueryStageInput, _), + SortExec(_, _, right: ShuffleQueryStageInput, _)) + if supportOptimization(joinType, left, right) => + + val leftStats = left.childStage.stats.getPartitionStatistics.get + val rightStats = right.childStage.stats.getPartitionStatistics.get + val numPartitions = leftStats.bytesByPartitionId.length + val (leftMedSize, leftMedRowCount) = medianSizeAndRowCount(leftStats) + val (rightMedSize, rightMedRowCount) = medianSizeAndRowCount(rightStats) + logInfo(s"HandlingSkewedJoin left medSize/rowCounts: ($leftMedSize, $leftMedRowCount)" + + s" right medSize/rowCounts ($rightMedSize, $rightMedRowCount)") + + logInfo(s"left bytes Max : ${leftStats.bytesByPartitionId.max}") + logInfo(s"left row counts Max : ${leftStats.recordsByPartitionId.max}") + logInfo(s"right bytes Max : ${rightStats.bytesByPartitionId.max}") + logInfo(s"right row counts Max : ${rightStats.recordsByPartitionId.max}") + + val skewedPartitions = mutable.HashSet[Int]() + val subJoins = mutable.ArrayBuffer[SparkPlan](smj) + for (partitionId <- 0 until numPartitions) { + val isLeftSkew = isSkewed(leftStats, partitionId, leftMedSize, leftMedRowCount) + val isRightSkew = isSkewed(rightStats, partitionId, rightMedSize, rightMedRowCount) + val isSkewAndSupportsSplit = + (isLeftSkew && supportSplitOnLeftPartition(joinType)) || + (isRightSkew && supportSplitOnRightPartition(joinType)) + + if (isSkewAndSupportsSplit) { + skewedPartitions += partitionId + val leftMapIdStartIndices = if (isLeftSkew && supportSplitOnLeftPartition(joinType)) { + estimateMapIdStartIndices(left, partitionId, leftMedSize, leftMedRowCount) + } else { + Array(0) + } + val rightMapIdStartIndices = if (isRightSkew && supportSplitOnRightPartition(joinType)) { + estimateMapIdStartIndices(right, partitionId, rightMedSize, rightMedRowCount) + } else { + Array(0) + } + + for (i <- 0 until leftMapIdStartIndices.length; + j <- 0 until rightMapIdStartIndices.length) { + val leftEndMapId 
= if (i == leftMapIdStartIndices.length - 1) { + left.numMapper + } else { + leftMapIdStartIndices(i + 1) + } + val rightEndMapId = if (j == rightMapIdStartIndices.length - 1) { + right.numMapper + } else { + rightMapIdStartIndices(j + 1) + } + + val leftInput = + SkewedShuffleQueryStageInput( + left.childStage, left.output, partitionId, leftMapIdStartIndices(i), leftEndMapId) + val rightInput = + SkewedShuffleQueryStageInput( + right.childStage, right.output, partitionId, + rightMapIdStartIndices(j), rightEndMapId) + + subJoins += + SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, leftInput, rightInput) + } + } + } + logInfo(s"skewed partition number is ${skewedPartitions.size}") + if (skewedPartitions.size > 0) { + left.skewedPartitions = Some(skewedPartitions) + right.skewedPartitions = Some(skewedPartitions) + UnionExec(subJoins.toList) + } else { + smj + } + } + + def apply(plan: SparkPlan): SparkPlan = { + if (!conf.adaptiveSkewedJoinEnabled) { + plan + } else { + plan match { + case queryStage: QueryStage => + val queryStageInputs: Seq[ShuffleQueryStageInput] = queryStage.collect { + case input: ShuffleQueryStageInput => input + } + if (queryStageInputs.length == 2) { + // Currently we only support handling skewed join for 2 table join. + val optimizedPlan = handleSkewedJoin(queryStage.child, queryStage) + queryStage.child = optimizedPlan + queryStage + } else { + queryStage + } + case _ => plan + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala new file mode 100644 index 0000000000000..7a7ac31875a09 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark._ +import org.apache.spark.rdd.{RDD, ShuffledRDDPartition} +import org.apache.spark.sql.catalyst.InternalRow + +/** + * This is a specialized version of [[org.apache.spark.sql.execution.ShuffledRowRDD]]. This is used + * in Spark SQL adaptive execution when a shuffle join is converted to broadcast join at runtime + * because the map output of one input table is small enough for broadcast. This RDD represents the + * data of another input table of the join that reads from shuffle. Each partition of the RDD reads + * the whole data from just one mapper output locally. So actually there is no data transferred + * from the network. + * + * This RDD takes a [[ShuffleDependency]] (`dependency`). + * + * The `dependency` has the parent RDD of this RDD, which represents the dataset before shuffle + * (i.e. map output). 
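The rewrite above keeps the original sort-merge join for the non-skewed partitions and adds one sub-join per (left split, right split) pair of every skewed partition, combining everything with a UnionExec. A small sketch of the fan-out with assumed split counts:

```scala
// Assumed: partition 5 is skewed on both sides; the left side is divided into 2 mapper ranges
// and the right side into 3.
val leftMapIdStartIndices  = Array(0, 6)
val rightMapIdStartIndices = Array(0, 4, 8)
val extraSubJoins = leftMapIdStartIndices.length * rightMapIdStartIndices.length  // 6 sub-joins
// The result is UnionExec(original join plus the 6 sub-joins), and partition 5 is recorded in
// skewedPartitions so the regular shuffle readers can skip it later.
```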
Elements of this RDD are (partitionId, Row) pairs. + * Partition ids should be in the range [0, numPartitions - 1]. + * `dependency.partitioner.numPartitions` is the number of pre-shuffle partitions. (i.e. the number + * of partitions of the map output). The post-shuffle partition number is the same to the parent + * RDD's partition number. + */ +class LocalShuffledRowRDD( + var dependency: ShuffleDependency[Int, InternalRow, InternalRow], + specifiedPartitionStartIndices: Option[Array[Int]] = None, + specifiedPartitionEndIndices: Option[Array[Int]] = None) + extends RDD[InternalRow](dependency.rdd.context, Nil) { + + private[this] val numPreShufflePartitions = dependency.partitioner.numPartitions + private[this] val numPostShufflePartitions = dependency.rdd.partitions.length + + private[this] val partitionStartIndices: Array[Int] = specifiedPartitionStartIndices match { + case Some(indices) => indices + case None => Array(0) + } + + private[this] val partitionEndIndices: Array[Int] = specifiedPartitionEndIndices match { + case Some(indices) => indices + case None if specifiedPartitionStartIndices.isEmpty => Array(numPreShufflePartitions) + case _ => specifiedPartitionStartIndices.get.drop(1) :+ numPreShufflePartitions + } + + override def getDependencies: Seq[Dependency[_]] = List(dependency) + + override def getPartitions: Array[Partition] = { + assert(partitionStartIndices.length == partitionEndIndices.length) + Array.tabulate[Partition](numPostShufflePartitions) { i => + new ShuffledRDDPartition(i) + } + } + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val dep = dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + tracker.getMapLocation(dep, partition.index, partition.index + 1) + } + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val shuffledRowPartition = split.asInstanceOf[ShuffledRDDPartition] + val mapId = shuffledRowPartition.index + // Connect the the InternalRows read by each ShuffleReader + new Iterator[InternalRow] { + val readers = partitionStartIndices.zip(partitionEndIndices).map { case (start, end) => + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + start, + end, + context, + mapId, + mapId + 1) + } + + var i = 0 + var iter = readers(i).read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) + + override def hasNext = { + while (iter.hasNext == false && i + 1 <= readers.length - 1) { + i += 1 + iter = readers(i).read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) + } + iter.hasNext + } + + override def next() = iter.next() + } + } + + override def clearDependencies() { + super.clearDependencies() + dependency = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeJoin.scala new file mode 100644 index 0000000000000..65b00653d42d0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeJoin.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
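The defaults above mean that when only start indices are supplied, each local-read range simply runs to the next start index, with the last range ending at the number of pre-shuffle partitions. A sketch with assumed numbers:

```scala
val numPreShufflePartitions = 10   // assumed
val startIndices = Array(0, 3, 7)  // assumed specifiedPartitionStartIndices
val endIndices   = startIndices.drop(1) :+ numPreShufflePartitions  // Array(3, 7, 10)
// Every mapper's output is then read as the reduce-partition ranges [0,3), [3,7) and [7,10),
// one chained ShuffleReader per range.
```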
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{SortExec, SparkPlan} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, SortMergeJoinExec} +import org.apache.spark.sql.internal.SQLConf + +case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] { + + private def canBuildRight(joinType: JoinType): Boolean = joinType match { + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => true + case j: ExistenceJoin => true + case _ => false + } + + private def canBuildLeft(joinType: JoinType): Boolean = joinType match { + case _: InnerLike | RightOuter => true + case _ => false + } + + private def canBroadcast(plan: SparkPlan): Boolean = { + plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.adaptiveBroadcastJoinThreshold + } + + private def removeSort(plan: SparkPlan): SparkPlan = { + plan match { + case s: SortExec => s.child + case p: SparkPlan => p + } + } + + private[adaptive] def calculatePartitionStartEndIndices( + bytesByPartitionId: Array[Long]): (Array[Int], Array[Int]) = { + val partitionStartIndices = ArrayBuffer[Int]() + val partitionEndIndices = ArrayBuffer[Int]() + var continuousZeroFlag = false + var i = 0 + for (bytes <- bytesByPartitionId) { + if (bytes != 0 && !continuousZeroFlag) { + partitionStartIndices += i + continuousZeroFlag = true + } else if (bytes == 0 && continuousZeroFlag) { + partitionEndIndices += i + continuousZeroFlag = false + } + i += 1 + } + if (continuousZeroFlag) { + partitionEndIndices += i + } + if (partitionStartIndices.length == 0) { + (Array(0), Array(0)) + } else { + (partitionStartIndices.toArray, partitionEndIndices.toArray) + } + } + + // After transforming to BroadcastJoin from SortMergeJoin, local shuffle read should be used and + // there's opportunity to read less partitions based on previous shuffle write results. + private def optimizeForLocalShuffleReadLessPartitions( + broadcastSidePlan: SparkPlan, + childrenPlans: Seq[SparkPlan]) = { + // If there's shuffle write on broadcast side, then find the partitions with 0 size and ignore + // reading them in local shuffle read. + broadcastSidePlan match { + case broadcast: ShuffleQueryStageInput + if broadcast.childStage.stats.bytesByPartitionId.isDefined => + val (startIndicies, endIndicies) = calculatePartitionStartEndIndices(broadcast.childStage + .stats.bytesByPartitionId.get) + childrenPlans.foreach { + case input: ShuffleQueryStageInput => + input.partitionStartIndices = Some(startIndicies) + input.partitionEndIndices = Some(endIndicies) + case _ => + } + case _ => + } + } + + // While the changes in optimizeForLocalShuffleReadLessPartitions has additional exchanges, + // we need to revert this changes. 
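calculatePartitionStartEndIndices above walks the per-partition byte sizes of the broadcast side and emits one [start, end) range per run of non-empty partitions, so the local shuffle read can skip partitions for which the broadcast side produced no data (for an inner join those partitions cannot produce output). Tracing it on an assumed size array:

```scala
val bytesByPartitionId = Array[Long](0, 5, 3, 0, 0, 7)  // assumed map output sizes per partition
// Non-empty runs are partitions 1..2 and partition 5, so the method returns
//   partitionStartIndices == Array(1, 5)
//   partitionEndIndices   == Array(3, 6)
```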
+ private def revertShuffleReadChanges( + childrenPlans: Seq[SparkPlan]) = { + childrenPlans.foreach { + case input: ShuffleQueryStageInput => + input.isLocalShuffle = false + input.partitionEndIndices = None + input.partitionStartIndices = None + case _ => + } + } + + private def optimizeSortMergeJoin( + smj: SortMergeJoinExec, + queryStage: QueryStage): SparkPlan = { + smj match { + case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => + val broadcastSide = if (canBuildRight(joinType) && canBroadcast(right)) { + Some(BuildRight) + } else if (canBuildLeft(joinType) && canBroadcast(left)) { + Some(BuildLeft) + } else { + None + } + broadcastSide.map { buildSide => + val broadcastJoin = BroadcastHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + removeSort(left), + removeSort(right)) + // All shuffle read should be local instead of remote + broadcastJoin.children.foreach { + case input: ShuffleQueryStageInput => + input.isLocalShuffle = true + case _ => + } + + val newChild = queryStage.child.transformDown { + case s: SortMergeJoinExec if s.fastEquals(smj) => broadcastJoin + } + + val broadcastSidePlan = buildSide match { + case BuildLeft => removeSort(left) + case BuildRight => removeSort(right) + } + // Local shuffle read less partitions based on broadcastSide's row statistics + joinType match { + case _: InnerLike => + optimizeForLocalShuffleReadLessPartitions(broadcastSidePlan, broadcastJoin.children) + case _ => + } + + // Apply EnsureRequirement rule to check if any new Exchange will be added. If the added + // Exchange number less than spark.sql.adaptive.maxAdditionalShuffleNum, we convert the + // sortMergeJoin to BroadcastHashJoin. Otherwise we don't convert it because it causes + // additional Shuffle. + val afterEnsureRequirements = EnsureRequirements(conf).apply(newChild) + val numExchanges = afterEnsureRequirements.collect { + case e: ShuffleExchangeExec => e + }.length + + val topShuffleCheck = queryStage match { + case _: ShuffleQueryStage => afterEnsureRequirements.isInstanceOf[ShuffleExchangeExec] + case _ => true + } + val allowAdditionalShuffle = conf.adaptiveAllowAdditionShuffle + val noAdditionalShuffle = (numExchanges == 0) || + (queryStage.isInstanceOf[ShuffleQueryStage] && numExchanges <= 1) + if (topShuffleCheck && (allowAdditionalShuffle || noAdditionalShuffle)) { + // Update the plan in queryStage + queryStage.child = newChild + broadcastJoin + } else { + logWarning("Join optimization is not applied due to additional shuffles will be " + + "introduced. 
Enable spark.sql.adaptive.allowAdditionalShuffle to allow it.") + revertShuffleReadChanges(broadcastJoin.children) + smj + } + }.getOrElse(smj) + } + } + + private def optimizeJoin( + operator: SparkPlan, + queryStage: QueryStage): SparkPlan = { + operator match { + case smj: SortMergeJoinExec => + val op = optimizeSortMergeJoin(smj, queryStage) + val optimizedChildren = op.children.map(optimizeJoin(_, queryStage)) + op.withNewChildren(optimizedChildren) + case op => + val optimizedChildren = op.children.map(optimizeJoin(_, queryStage)) + op.withNewChildren(optimizedChildren) + } + } + + def apply(plan: SparkPlan): SparkPlan = { + if (!conf.adaptiveJoinEnabled) { + plan + } else { + plan match { + case queryStage: QueryStage => + val optimizedPlan = optimizeJoin(queryStage.child, queryStage) + queryStage.child = optimizedPlan + queryStage + case _ => plan + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala new file mode 100644 index 0000000000000..9aac60b8697fa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.ExecutedCommandExec +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ShuffleExchangeExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +/** + * Divide the spark plan into multiple QueryStages. For each Exchange in the plan, it adds a + * QueryStage and a QueryStageInput. If reusing Exchange is enabled, it finds duplicated exchanges + * and uses the same QueryStage for all the references. + */ +case class PlanQueryStage(conf: SQLConf) extends Rule[SparkPlan] { + + def apply(plan: SparkPlan): SparkPlan = { + + val newPlan = if (!conf.exchangeReuseEnabled) { + plan.transformUp { + case e: ShuffleExchangeExec => + ShuffleQueryStageInput(ShuffleQueryStage(e), e.output, isRepartition = false) + case e: BroadcastExchangeExec => + BroadcastQueryStageInput(BroadcastQueryStage(e), e.output) + } + } else { + // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. 
+ val stages = mutable.HashMap[StructType, ArrayBuffer[QueryStage]]() + + plan.transformUp { + case exchange: Exchange => + val sameSchema = stages.getOrElseUpdate(exchange.schema, ArrayBuffer[QueryStage]()) + val samePlan = sameSchema.find { s => + exchange.sameResult(s.child) + } + if (samePlan.isDefined) { + // Keep the output of this exchange, the following plans require that to resolve + // attributes. + exchange match { + case e: ShuffleExchangeExec => + ShuffleQueryStageInput(samePlan.get.asInstanceOf[ShuffleQueryStage], + exchange.output, + isRepartition = e.isRepartition.get) + case e: BroadcastExchangeExec => BroadcastQueryStageInput( + samePlan.get.asInstanceOf[BroadcastQueryStage], exchange.output) + } + } else { + val queryStageInput = exchange match { + case e: ShuffleExchangeExec => + ShuffleQueryStageInput(ShuffleQueryStage(e), + e.output, + isRepartition = e.isRepartition.get) + case e: BroadcastExchangeExec => + BroadcastQueryStageInput(BroadcastQueryStage(e), e.output) + } + sameSchema += queryStageInput.childStage + queryStageInput + } + } + } + ResultQueryStage(newPlan) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala new file mode 100644 index 0000000000000..4b9ab49b4458b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration + +import org.apache.spark.{broadcast, MapOutputStatistics, SparkContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.exchange._ +import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate +import org.apache.spark.util.ThreadUtils + +/** + * In adaptive execution mode, an execution plan is divided into multiple QueryStages. Each + * QueryStage is a sub-tree that runs in a single stage. + */ +abstract class QueryStage extends UnaryExecNode { + + var child: SparkPlan + + // Ignore this wrapper for canonicalizing. 
+ override def doCanonicalize(): SparkPlan = child.canonicalized + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + /** + * Execute childStages and wait until all stages are completed. Use a thread pool to avoid + * blocking on one child stage. + */ + def executeChildStages(): Unit = { + val executionId = sqlContext.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + val jobDesc = sqlContext.sparkContext.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + + // Handle broadcast stages + val broadcastQueryStages: Seq[BroadcastQueryStage] = child.collect { + case bqs: BroadcastQueryStageInput => bqs.childStage + } + val broadcastFutures = broadcastQueryStages.map { queryStage => + Future { + SQLExecution.withExecutionIdAndJobDesc(sqlContext.sparkContext, executionId, jobDesc) { + queryStage.prepareBroadcast() + } + }(QueryStage.executionContext) + } + + // Submit shuffle stages + val shuffleQueryStages: Seq[ShuffleQueryStage] = child.collect { + case sqs: ShuffleQueryStageInput => sqs.childStage + } + val shuffleStageFutures = shuffleQueryStages.map { queryStage => + Future { + SQLExecution.withExecutionIdAndJobDesc(sqlContext.sparkContext, executionId, jobDesc) { + queryStage.execute() + } + }(QueryStage.executionContext) + } + + ThreadUtils.awaitResult( + Future.sequence(broadcastFutures)(implicitly, QueryStage.executionContext), Duration.Inf) + ThreadUtils.awaitResult( + Future.sequence(shuffleStageFutures)(implicitly, QueryStage.executionContext), Duration.Inf) + } + + private var prepared = false + + /** + * Before executing the plan in this query stage, we execute all child stages, optimize the plan + * in this stage and determine the reducer number based on the child stages' statistics. Finally + * we do a codegen for this query stage and update the UI with the new plan. + */ + def prepareExecuteStage(): Unit = synchronized { + // Ensure the prepareExecuteStage method only be executed once. + if (prepared) { + return + } + // 1. Execute childStages + executeChildStages() + + // It is possible to optimize this stage's plan here based on the child stages' statistics. + val oldChild = child + OptimizeJoin(conf).apply(this) + HandleSkewedJoin(conf).apply(this) + // If the Joins are changed, we need apply EnsureRequirements rule to add BroadcastExchange. + if (!oldChild.fastEquals(child)) { + child = EnsureRequirements(conf).apply(child) + } + + // 2. 
Determine reducer number + val queryStageInputs: Seq[ShuffleQueryStageInput] = child.collect { + case input: ShuffleQueryStageInput if !input.isLocalShuffle => input + } + val childMapOutputStatistics = queryStageInputs.map(_.childStage.mapOutputStatistics) + .filter(_ != null).toArray + // Right now, Adaptive execution only support HashPartitionings and the same number of + // Check partitionings + val partitioningsCheck = queryStageInputs.forall { + _.outputPartitioning match { + case hash: HashPartitioning => true + case collection: PartitioningCollection => + collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) + case _ => false + } + } + + val isRepartition = queryStageInputs.forall { + _.isRepartition + } + + // Check pre-shuffle partitions num + val numPreShufflePartitionsCheck = + childMapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct.length == 1 + + if (childMapOutputStatistics.length > 0 + && partitioningsCheck + && !isRepartition + && numPreShufflePartitionsCheck) { + val exchangeCoordinator = new ExchangeCoordinator( + conf.targetPostShuffleInputSize, + conf.adaptiveTargetPostShuffleRowCount, + conf.minNumPostShufflePartitions) + + if (queryStageInputs.length == 2 && queryStageInputs.forall(_.skewedPartitions.isDefined)) { + // If a skewed join is detected and optimized, we will omit the skewed partitions when + // estimate the partition start and end indices. + val (partitionStartIndices, partitionEndIndices) = + exchangeCoordinator.estimatePartitionStartEndIndices( + childMapOutputStatistics, queryStageInputs(0).skewedPartitions.get) + queryStageInputs.foreach { i => + i.partitionStartIndices = Some(partitionStartIndices) + i.partitionEndIndices = Some(partitionEndIndices) + } + } else { + val partitionStartIndices = + exchangeCoordinator.estimatePartitionStartIndices(childMapOutputStatistics) + queryStageInputs.foreach(_.partitionStartIndices = Some(partitionStartIndices)) + } + } + + // 3. Codegen and update the UI + child = CollapseCodegenStages(sqlContext.conf).apply(child) + val executionId = sqlContext.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + if (executionId != null && executionId.nonEmpty) { + val queryExecution = SQLExecution.getQueryExecution(executionId.toLong) + sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate( + executionId.toLong, + queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan))) + } + prepared = true + } + + // Caches the created ShuffleRowRDD so we can reuse that. + private var cachedRDD: RDD[InternalRow] = null + + def executeStage(): RDD[InternalRow] = child.execute() + + /** + * A QueryStage can be reused like Exchange. It is possible that multiple threads try to submit + * the same QueryStage. Use synchronized to make sure it is executed only once. 
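Step 2 above feeds the completed child stages' MapOutputStatistics into an ExchangeCoordinator to choose the reducer ranges. In upstream Spark 2.4 the coordinator packs consecutive pre-shuffle partitions greedily until adding the next one would exceed the target post-shuffle input size; this fork also passes a row-count target, so treat the trace below as an illustration of the size-based packing only (all sizes assumed):

```scala
val targetPostShuffleInputSize = 64L << 20                            // 64 MB target (assumed)
val bytesByPartitionId = Array(40L, 30L, 20L, 70L, 10L).map(_ << 20)  // per-partition MB
// Greedy packing: [40], [30 + 20], [70], [10]  =>  partitionStartIndices == Array(0, 1, 3, 4)
```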
+ */ + override def doExecute(): RDD[InternalRow] = synchronized { + if (cachedRDD == null) { + prepareExecuteStage() + cachedRDD = executeStage() + } + cachedRDD + } + + override def executeCollect(): Array[InternalRow] = { + prepareExecuteStage() + child.executeCollect() + } + + override def executeToIterator(): Iterator[InternalRow] = { + prepareExecuteStage() + child.executeToIterator() + } + + override def executeTake(n: Int): Array[InternalRow] = { + prepareExecuteStage() + child.executeTake(n) + } + + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + builder: StringBuilder, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false): StringBuilder = { + child.generateTreeString(depth, lastChildren, builder, verbose, "*") + } +} + +/** + * The last QueryStage of an execution plan. + */ +case class ResultQueryStage(var child: SparkPlan) extends QueryStage + +/** + * A shuffle QueryStage whose child is a ShuffleExchange. + */ +case class ShuffleQueryStage(var child: SparkPlan) extends QueryStage { + + protected var _mapOutputStatistics: MapOutputStatistics = null + + def mapOutputStatistics: MapOutputStatistics = _mapOutputStatistics + + override def executeStage(): RDD[InternalRow] = { + child match { + case e: ShuffleExchangeExec => + val result = e.eagerExecute() + _mapOutputStatistics = e.mapOutputStatistics + result + case _ => throw new IllegalArgumentException( + "The child of ShuffleQueryStage must be a ShuffleExchange.") + } + } +} + +/** + * A broadcast QueryStage whose child is a BroadcastExchangeExec. + */ +case class BroadcastQueryStage(var child: SparkPlan) extends QueryStage { + override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + child.executeBroadcast() + } + + private var prepared = false + + def prepareBroadcast(): Unit = synchronized { + if (!prepared) { + executeChildStages() + child = CollapseCodegenStages(sqlContext.conf).apply(child) + // After child stages are completed, prepare() triggers the broadcast. + prepare() + prepared = true + } + } + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException( + "BroadcastExchange does not support the execute() code path.") + } +} + +object QueryStage { + private[execution] val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("adaptive-query-stage")) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala new file mode 100644 index 0000000000000..938bcba1abafe --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.mutable + +import org.apache.spark.TaskContext +import org.apache.spark.broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.statsEstimation.Statistics +import org.apache.spark.util.TaskCompletionListener + +/** + * QueryStageInput is the leaf node of a QueryStage and is used to hide its child stage. It gets + * the result of its child stage and serves it as the input of the QueryStage. A QueryStage knows + * its child stages by collecting all the QueryStageInputs. + */ +abstract class QueryStageInput extends LeafExecNode { + + def childStage: QueryStage + + // Ignore this wrapper for canonicalizing. + override def doCanonicalize(): SparkPlan = childStage.canonicalized + + // Similar to ReusedExchangeExec, two QueryStageInputs can reference to the same childStage. + // QueryStageInput can have distinct set of output attribute ids from its childStage, we need + // to update the attribute ids in outputPartitioning and outputOrdering. + private lazy val updateAttr: Expression => Expression = { + val originalAttrToNewAttr = AttributeMap(childStage.output.zip(output)) + e => + e.transform { + case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr) + } + } + + override def outputPartitioning: Partitioning = childStage.outputPartitioning match { + case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case other => other + } + + override def outputOrdering: Seq[SortOrder] = { + childStage.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder]) + } + + override def computeStats(): Statistics = { + childStage.stats + } + + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + builder: StringBuilder, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false): StringBuilder = { + childStage.generateTreeString(depth, lastChildren, builder, verbose, "*") + } +} + +/** + * A QueryStageInput whose child stage is a ShuffleQueryStage. It returns a new ShuffledRowRDD + * based on the the child stage's result RDD and the specified partitionStartIndices. If the + * child stage is reused by another ShuffleQueryStageInput, they can return RDDs with different + * partitionStartIndices. 
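Because a single child stage can be shared by several QueryStageInputs whose outputs carry different attribute ids, the partitioning and ordering reported above are rewritten in terms of this input's own attributes; otherwise EnsureRequirements would see a mismatch and stack a redundant exchange on top of the reused stage. A short sketch with assumed expression ids (childStage and output as in the class above, and hashPartitioning standing for the child stage's reported HashPartitioning):

```scala
// Assumed: the shared stage outputs a#1 and b#2; this input re-exposes them as a#10 and b#11,
// so the stage's HashPartitioning(a#1, 200) must be reported as HashPartitioning(a#10, 200).
val originalAttrToNewAttr = AttributeMap(childStage.output.zip(output))
val fixedPartitioning = hashPartitioning.copy(
  expressions = hashPartitioning.expressions.map(_.transform {
    case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr)
  }))
```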
+ */ +case class ShuffleQueryStageInput(childStage: ShuffleQueryStage, + override val output: Seq[Attribute], + var isLocalShuffle: Boolean = false, + var skewedPartitions: Option[mutable.HashSet[Int]] = None, + var partitionStartIndices: Option[Array[Int]] = None, + var partitionEndIndices: Option[Array[Int]] = None, + isRepartition: Boolean) + extends QueryStageInput { + + override def outputPartitioning: Partitioning = partitionStartIndices.map { + indices => UnknownPartitioning(indices.length) + }.getOrElse(super.outputPartitioning) + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "readBytes" -> SQLMetrics.createMetric(sparkContext, "number of read bytes") + ) + + override def doExecute(): RDD[InternalRow] = { + val childRDD = childStage.execute().asInstanceOf[ShuffledRowRDD] + val shuffleRDD: RDD[InternalRow] = if (isLocalShuffle) { + new LocalShuffledRowRDD(childRDD.dependency, partitionStartIndices, partitionEndIndices) + } else { + new ShuffledRowRDD(childRDD.dependency, partitionStartIndices, partitionEndIndices) + } + + val numOutputRows = longMetric("numOutputRows") + // Avoid to serialize MetastoreRelation because schema is lazy. (see SPARK-15649) + val outputSchema = schema + shuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => + TaskContext.get().addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = { + longMetric("readBytes").add(context.taskMetrics().inputMetrics.bytesRead) + } + }) + val proj = UnsafeProjection.create(outputSchema) + proj.initialize(index) + iter.map { r => + numOutputRows += 1 + proj(r) + } + } + } + + def numMapper(): Int = { + val childRDD = childStage.execute().asInstanceOf[ShuffledRowRDD] + childRDD.dependency.rdd.partitions.length + } +} + +/** + * A QueryStageInput that reads part of a single partition.The partition is divided into several + * splits and it only reads one of the splits ranging from startMapId to endMapId (exclusive). + */ +case class SkewedShuffleQueryStageInput( + childStage: ShuffleQueryStage, + override val output: Seq[Attribute], + partitionId: Int, + startMapId: Int, + endMapId: Int) + extends QueryStageInput { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "readBytes" -> SQLMetrics.createMetric(sparkContext, "number of read bytes") + ) + + override def doExecute(): RDD[InternalRow] = { + val childRDD = childStage.execute().asInstanceOf[ShuffledRowRDD] + val shuffleRDD: RDD[InternalRow] = new AdaptiveShuffledRowRDD( + childRDD.dependency, + partitionId, + Some(Array(startMapId)), + Some(Array(endMapId))) + + val numOutputRows = longMetric("numOutputRows") + // Avoid to serialize MetastoreRelation because schema is lazy. (see SPARK-15649) + val outputSchema = schema + shuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => + TaskContext.get().addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = { + longMetric("readBytes").add(context.taskMetrics().inputMetrics.bytesRead) + } + }) + val proj = UnsafeProjection.create(outputSchema) + proj.initialize(index) + iter.map { r => + numOutputRows += 1 + proj(r) + } + } + } +} + +/** A QueryStageInput whose child stage is a BroadcastQueryStage. 
*/ +case class BroadcastQueryStageInput( + childStage: BroadcastQueryStage, + override val output: Seq[Attribute]) + extends QueryStageInput { + + override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + childStage.executeBroadcast() + } + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException( + "BroadcastStageInput does not support the execute() code path.") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 0a3729b384c00..ee296494d3f20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.statsEstimation.Statistics import org.apache.spark.sql.types.LongType import org.apache.spark.util.ThreadUtils import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} @@ -567,6 +568,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) } override def simpleString: String = s"Range ($start, $end, step=$step, splits=$numSlices)" + + override def computeStats: Statistics = { + Statistics(LongType.defaultSize * numElements) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 196d057c2de1b..69012bc74497e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,13 +23,13 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.statsEstimation.Statistics import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} - case class InMemoryTableScanExec( attributes: Seq[Attribute], predicates: Seq[Expression], @@ -180,8 +180,8 @@ case class InMemoryTableScanExec( relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) // Keeps relation's partition statistics because we don't serialize relation. - private val stats = relation.partitionStatistics - private def statsFor(a: Attribute) = stats.forAttribute(a) + private val partitionStats = relation.partitionStatistics + private def statsFor(a: Attribute) = partitionStats.forAttribute(a) // Currently, only use statistics from atomic types except binary type only. 
private object ExtractableLiteral { @@ -248,7 +248,7 @@ case class InMemoryTableScanExec( filter.map( BindReferences.bindReference( _, - stats.schema, + partitionStats.schema, allowFailures = true)) boundFilter.foreach(_ => @@ -271,7 +271,7 @@ case class InMemoryTableScanExec( private def filteredCachedBatches(): RDD[CachedBatch] = { // Using these variables here to avoid serialization of entire objects (if referenced directly) // within the map Partitions closure. - val schema = stats.schema + val schema = partitionStats.schema val schemaIndex = schema.zipWithIndex val buffers = relation.cacheBuilder.cachedColumnBuffers @@ -310,4 +310,9 @@ case class InMemoryTableScanExec( inputRDD } } + + override def computeStats(): Statistics = { + val stats = relation.computeStats() + Statistics(stats.sizeInBytes) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 9652f4881327a..3e6af24b84c3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -36,115 +36,12 @@ import org.apache.spark.sql.internal.SQLConf * the input partition ordering requirements are met. */ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { - private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions - - private def targetPostShuffleInputSize: Long = conf.targetPostShuffleInputSize - - private def maxTargetPostShuffleInputSize: Long = conf.maxTargetPostShuffleInputSize - - private def adaptiveExecutionEnabled: Boolean = conf.adaptiveExecutionEnabled - - private def minNumPostShufflePartitions: Option[Int] = { - val minNumPostShufflePartitions = conf.minNumPostShufflePartitions - if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None - } - - /** - * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled - * and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]]. - */ - private def withExchangeCoordinator( - children: Seq[SparkPlan], - requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { - val supportsCoordinator = - if (children.exists(_.isInstanceOf[ShuffleExchangeExec])) { - // Right now, ExchangeCoordinator only support HashPartitionings. - children.forall { - case e @ ShuffleExchangeExec(hash: HashPartitioning, _, _, _) => - if (e.isRepartition.get && !conf.allowAEwhenRepartition) { - false - } else { - true - } - case child => - child.outputPartitioning match { - case hash: HashPartitioning => true - case collection: PartitioningCollection => - collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) - case _ => false - } - } - } else { - // In this case, although we do not have Exchange operators, we may still need to - // shuffle data when we have more than one children because data generated by - // these children may not be partitioned in the same way. - // Please see the comment in withCoordinator for more details. 
- val supportsDistribution = requiredChildDistributions.forall { dist => - dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution] - } - children.length > 1 && supportsDistribution - } - - val withCoordinator = - if (adaptiveExecutionEnabled && supportsCoordinator) { - val coordinator = - new ExchangeCoordinator( - targetPostShuffleInputSize, - minNumPostShufflePartitions, - maxTargetPostShuffleInputSize) - children.zip(requiredChildDistributions).map { - case (e @ ShuffleExchangeExec(_, _, _, _), _) => - // This child is an Exchange, we need to add the coordinator. - e.copy(coordinator = Some(coordinator), isRepartition = e.isRepartition) - case (child, distribution) => - // If this child is not an Exchange, we need to add an Exchange for now. - // Ideally, we can try to avoid this Exchange. However, when we reach here, - // there are at least two children operators (because if there is a single child - // and we can avoid Exchange, supportsCoordinator will be false and we - // will not reach here.). Although we can make two children have the same number of - // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different. - // For example, let's say we have the following plan - // Join - // / \ - // Agg Exchange - // / \ - // Exchange t2 - // / - // t1 - // In this case, because a post-shuffle partition can include multiple pre-shuffle - // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes - // after shuffle. So, even we can use the child Exchange operator of the Join to - // have a number of post-shuffle partitions that matches the number of partitions of - // Agg, we cannot say these two children are partitioned in the same way. - // Here is another case - // Join - // / \ - // Agg1 Agg2 - // / \ - // Exchange1 Exchange2 - // / \ - // t1 t2 - // In this case, two Aggs shuffle data with the same column of the join condition. - // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same - // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2 - // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle - // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its - // pre-shuffle partitions by using another partitionStartIndices [0, 4]. - // So, Agg1 and Agg2 are actually not co-partitioned. - // - // It will be great to introduce a new Partitioning to represent the post-shuffle - // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. - val targetPartitioning = distribution.createPartitioning(defaultNumPreShufflePartitions) - assert(targetPartitioning.isInstanceOf[HashPartitioning]) - ShuffleExchangeExec(targetPartitioning, child, Some(coordinator), Some(false)) - } - } else { - // If we do not need ExchangeCoordinator, the original children are returned. 
- children - } - - withCoordinator - } + private def defaultNumPreShufflePartitions: Int = + if (conf.adaptiveExecutionEnabled) { + conf.maxNumPostShufflePartitions + } else { + conf.numShufflePartitions + } private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution @@ -153,9 +50,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { assert(requiredChildDistributions.length == children.length) assert(requiredChildOrderings.length == children.length) - logInfo(s"EnsureRequirements in operator ${operator.getClass.getSimpleName}" + - s", requiredChildDistributions:$requiredChildDistributions" + - s", requiredChildOrderings:$requiredChildOrderings.") // Ensure that the operator's children satisfy their output distribution requirements. children = children.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => @@ -200,8 +94,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val defaultPartitioning = distribution.createPartitioning(targetNumPartitions) child match { // If child is an exchange, we replace it with a new one having defaultPartitioning. - case ShuffleExchangeExec(_, c, _, Some(false)) => + case ShuffleExchangeExec(_, c, Some(false)) => ShuffleExchangeExec(defaultPartitioning, c) + case ShuffleExchangeExec(_, c, Some(true)) => + ShuffleExchangeExec(defaultPartitioning, c, Some(true)) case _ => ShuffleExchangeExec(defaultPartitioning, child) } } @@ -210,15 +106,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } } - // Now, we need to add ExchangeCoordinator if necessary. - // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges. - // However, with the way that we plan the query, we do not have a place where we have a - // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator - // at here for now. - // Once we finish https://issues.apache.org/jira/browse/SPARK-10665, - // we can first add Exchanges and then add coordinator once we have a DAG of query fragments. - children = withExchangeCoordinator(children, requiredChildDistributions) - // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => // If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort. 
@@ -233,10 +120,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } private def reorder( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - expectedOrderOfKeys: Seq[Expression], - currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + expectedOrderOfKeys: Seq[Expression], + currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { val leftKeysBuffer = ArrayBuffer[Expression]() val rightKeysBuffer = ArrayBuffer[Expression]() val pickedIndexes = mutable.Set[Int]() @@ -256,10 +143,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } private def reorderJoinKeys( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - leftPartitioning: Partitioning, - rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftPartitioning: Partitioning, + rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { leftPartitioning match { case HashPartitioning(leftExpressions, _) @@ -290,6 +177,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { */ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { plan match { + case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, + right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, + left, right) + case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) @@ -307,7 +201,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = plan.transformUp { // TODO: remove this after we create a physical operator for `RepartitionByExpression`. 
- case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _, _) => + case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) => child.outputPartitioning match { case lower: HashPartitioning if upper.semanticEquals(lower) => child case _ => operator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 1a5b7599bb7d9..d46d17ed43d37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expre import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.statsEstimation.Statistics import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -77,6 +78,10 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan override def outputOrdering: Seq[SortOrder] = { child.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder]) } + + override def computeStats(): Statistics = { + child.stats + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 61a521dabdd4d..d5f2f6a6325d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -17,65 +17,40 @@ package org.apache.spark.sql.execution.exchange -import java.util.{HashMap => JHashMap, Map => JMap} -import javax.annotation.concurrent.GuardedBy - +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{MapOutputStatistics, ShuffleDependency, SimpleFutureAction} +import org.apache.spark.MapOutputStatistics import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} /** * A coordinator used to determines how we shuffle data between stages generated by Spark SQL. * Right now, the work of this coordinator is to determine the number of post-shuffle partitions * for a stage that needs to fetch shuffle data from one or multiple stages. * - * A coordinator is constructed with three parameters, `numExchanges`, - * `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`. - * - `numExchanges` is used to indicated that how many [[ShuffleExchangeExec]]s that will be - * registered to this coordinator. So, when we start to do any actual work, we have a way to - * make sure that we have got expected number of [[ShuffleExchangeExec]]s. + * A coordinator is constructed with three parameters, `targetPostShuffleInputSize`, + * `targetPostShuffleRowCount` and `minNumPostShufflePartitions`. * - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's * input data size. With this parameter, we can estimate the number of post-shuffle partitions. * This parameter is configured through * `spark.sql.adaptive.shuffle.targetPostShuffleInputSize`. 
- * - `minNumPostShufflePartitions` is an optional parameter. If it is defined, this coordinator - * will try to make sure that there are at least `minNumPostShufflePartitions` post-shuffle - * partitions. - * - * The workflow of this coordinator is described as follows: - * - Before the execution of a [[SparkPlan]], for a [[ShuffleExchangeExec]] operator, - * if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator. - * This happens in the `doPrepare` method. - * - Once we start to execute a physical plan, a [[ShuffleExchangeExec]] registered to this - * coordinator will call `postShuffleRDD` to get its corresponding post-shuffle - * [[ShuffledRowRDD]]. - * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchangeExec]] - * will immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. - * - If this coordinator has not made the decision on how to shuffle data, it will ask those - * registered [[ShuffleExchangeExec]]s to submit their pre-shuffle stages. Then, based on the - * size statistics of pre-shuffle partitions, this coordinator will determine the number of - * post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices - * to a single post-shuffle partition whenever necessary. - * - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered - * [[ShuffleExchangeExec]]s. So, when a [[ShuffleExchangeExec]] calls `postShuffleRDD`, this - * coordinator can lookup the corresponding [[RDD]]. + * - `targetPostShuffleRowCount` is the targeted row count of a post-shuffle partition's + * input row count. This is set through + * `spark.sql.adaptive.shuffle.adaptiveTargetPostShuffleRowCount`. + * - `minNumPostShufflePartitions` is used to make sure that there are at least + * `minNumPostShufflePartitions` post-shuffle partitions. * * The strategy used to determine the number of post-shuffle partitions is described as follows. - * To determine the number of post-shuffle partitions, we have a target input size for a - * post-shuffle partition. Once we have size statistics of pre-shuffle partitions from stages - * corresponding to the registered [[ShuffleExchangeExec]]s, we will do a pass of those statistics - * and pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until - * adding another pre-shuffle partition would cause the size of a post-shuffle partition to be - * greater than the target size. + * To determine the number of post-shuffle partitions, we have a target input size and row count + * for a post-shuffle partition. Once we have size and row count statistics of all pre-shuffle + * partitions, we will do a pass of those statistics and pack pre-shuffle partitions with + * continuous indices to a single post-shuffle partition until adding another pre-shuffle partition + * would cause the size or row count of a post-shuffle partition to be greater than the target. 
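// A minimal sketch of the packing strategy described above, assuming per-partition sizes and
// row counts have already been summed across all pre-shuffle stages and that both targets are
// positive; the names here are illustrative only.
object PackingSketch {
  def packPartitionStartIndices(
      bytesByPartition: Array[Long],
      rowsByPartition: Array[Long],
      targetSize: Long,
      targetRows: Long): Array[Int] = {
    val startIndices = scala.collection.mutable.ArrayBuffer(0)
    var currentSize = bytesByPartition(0)
    var currentRows = rowsByPartition(0)
    for (i <- 1 until bytesByPartition.length) {
      // Start a new post-shuffle partition whenever adding the next pre-shuffle partition
      // would push the running size or row count over its target.
      if (currentSize + bytesByPartition(i) > targetSize ||
          currentRows + rowsByPartition(i) > targetRows) {
        startIndices += i
        currentSize = bytesByPartition(i)
        currentRows = rowsByPartition(i)
      } else {
        currentSize += bytesByPartition(i)
        currentRows += rowsByPartition(i)
      }
    }
    startIndices.toArray
  }
  // For the combined sizes in the example below (110, 30, 170, 15, 35 MB), a 128 MB size target
  // and a generous row-count target, this returns Array(0, 1, 2, 3), i.e. four post-shuffle
  // partitions.
}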
 *
 * For example, we have two stages with the following pre-shuffle partition size statistics:
 * stage 1: [100 MB, 20 MB, 100 MB, 10MB, 30 MB]
 * stage 2: [10 MB, 10 MB, 70 MB, 5 MB, 5 MB]
 * assuming the target input size is 128 MB, we will have four post-shuffle partitions,
 * which are:
 * - post-shuffle partition 0: pre-shuffle partition 0 (size 110 MB)
 * - post-shuffle partition 1: pre-shuffle partition 1 (size 30 MB)
@@ -83,67 +58,49 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan}
 * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MB)
 */
class ExchangeCoordinator(
-    advisoryTargetPostShuffleInputSize: Long,
-    minNumPostShufflePartitions: Option[Int] = None,
-    maxTargetPostShuffleInputSize: Long = -1L,
-    encounterRepartition: Boolean = false)
+    advisoryTargetPostShuffleInputSize: Long,
+    targetPostShuffleRowCount: Long,
+    minNumPostShufflePartitions: Int = 1,
+    maxTargetPostShuffleInputSize: Long = -1L,
+    encounterRepartition: Boolean = false)
   extends Logging {
 
-  // The registered Exchange operators.
-  private[this] val exchanges = ArrayBuffer[ShuffleExchangeExec]()
-
-  // `lazy val` is used here so that we could notice the wrong use of this class, e.g., all the
-  // exchanges should be registered before `postShuffleRDD` called first time. If a new exchange is
-  // registered after the `postShuffleRDD` call, `assert(exchanges.length == numExchanges)` fails
-  // in `doEstimationIfNecessary`.
-  private[this] lazy val numExchanges = exchanges.size
-
-  // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator.
-  private[this] lazy val postShuffleRDDs: JMap[ShuffleExchangeExec, ShuffledRowRDD] =
-    new JHashMap[ShuffleExchangeExec, ShuffledRowRDD](numExchanges)
-
-  // A boolean that indicates if this coordinator has made decision on how to shuffle data.
-  // This variable will only be updated by doEstimationIfNecessary, which is protected by
-  // synchronized.
-  @volatile private[this] var estimated: Boolean = false
-
   /**
-   * Registers a [[ShuffleExchangeExec]] operator to this coordinator. This method is only allowed
-   * to be called in the `doPrepare` method of a [[ShuffleExchangeExec]] operator.
+   * Estimates partition start indices for post-shuffle partitions based on
+   * mapOutputStatistics provided by all pre-shuffle stages.
    */
-  @GuardedBy("this")
-  def registerExchange(exchange: ShuffleExchangeExec): Unit = synchronized {
-    exchanges += exchange
+  def estimatePartitionStartIndices(mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = {
+    estimatePartitionStartEndIndices(mapOutputStatistics, mutable.HashSet.empty)._1
  }
 
-  def isEstimated: Boolean = estimated
-
  /**
-   * Estimates partition start indices for post-shuffle partitions based on
-   * mapOutputStatistics provided by all pre-shuffle stages.
+   * Estimates partition start and end indices for post-shuffle partitions based on
+   * mapOutputStatistics provided by all pre-shuffle stages, omitting skewed partitions that have
+   * already been taken care of in HandleSkewedJoin.
*/ - def estimatePartitionStartIndices( - mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { + def estimatePartitionStartEndIndices( + mapOutputStatistics: Array[MapOutputStatistics], + omittedPartitions: mutable.HashSet[Int]): (Array[Int], Array[Int]) = { + + assert(omittedPartitions.size < mapOutputStatistics(0).bytesByPartitionId.length, + "All partitions are skewed.") + // If minNumPostShufflePartitions is defined, it is possible that we need to use a // value less than advisoryTargetPostShuffleInputSize as the target input size of // a post shuffle task. - val targetPostShuffleInputSize = minNumPostShufflePartitions match { - case Some(numPartitions) => - val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum - // The max at here is to make sure that when we have an empty table, we - // only have a single post-shuffle partition. - // There is no particular reason that we pick 16. We just need a number to - // prevent maxPostShuffleInputSize from being set to 0. - val maxPostShuffleInputSize = - math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16) - math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) - - case None => advisoryTargetPostShuffleInputSize - } + val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum + // The max at here is to make sure that when we have an empty table, we + // only have a single post-shuffle partition. + // There is no particular reason that we pick 16. We just need a number to + // prevent maxPostShuffleInputSize from being set to 0. + val maxPostShuffleInputSize = math.max( + math.ceil(totalPostShuffleInputSize / minNumPostShufflePartitions.toDouble).toLong, 16) + val targetPostShuffleInputSize = + math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) logInfo( s"advisoryTargetPostShuffleInputSize: $advisoryTargetPostShuffleInputSize, " + - s"targetPostShuffleInputSize $targetPostShuffleInputSize.") + s"targetPostShuffleInputSize $targetPostShuffleInputSize. ") // Make sure we do get the same number of pre-shuffle partitions for those stages. val distinctNumPreShufflePartitions = @@ -161,130 +118,64 @@ class ExchangeCoordinator( val numPreShufflePartitions = distinctNumPreShufflePartitions.head val partitionStartIndices = ArrayBuffer[Int]() - // The first element of partitionStartIndices is always 0. - partitionStartIndices += 0 + val partitionEndIndices = ArrayBuffer[Int]() - var postShuffleInputSize = 0L + def nextStartIndex(i: Int): Int = { + var index = i + while (index < numPreShufflePartitions && omittedPartitions.contains(index)) { + index = index + 1 + } + index + } - var i = 0 - while (i < numPreShufflePartitions) { - // We calculate the total size of ith pre-shuffle partitions from all pre-shuffle stages. - // Then, we add the total size to postShuffleInputSize. - var nextShuffleInputSize = 0L + def partitionSizeAndRowCount(partitionId: Int): (Long, Long) = { + var size = 0L + var rowCount = 0L var j = 0 while (j < mapOutputStatistics.length) { - nextShuffleInputSize += mapOutputStatistics(j).bytesByPartitionId(i) + val statistics = mapOutputStatistics(j) + size += statistics.bytesByPartitionId(partitionId) + if (statistics.recordsByPartitionId.nonEmpty) { + rowCount += statistics.recordsByPartitionId(partitionId) + } j += 1 } - - // If including the nextShuffleInputSize would exceed the target partition size, then start a - // new partition. 
- if (i > 0 && postShuffleInputSize + nextShuffleInputSize > targetPostShuffleInputSize) { - partitionStartIndices += i - // reset postShuffleInputSize. - postShuffleInputSize = nextShuffleInputSize - } else postShuffleInputSize += nextShuffleInputSize - - i += 1 + (size, rowCount) } - partitionStartIndices.toArray - } - - @GuardedBy("this") - private def doEstimationIfNecessary(): Unit = synchronized { - // It is unlikely that this method will be called from multiple threads - // (when multiple threads trigger the execution of THIS physical) - // because in common use cases, we will create new physical plan after - // users apply operations (e.g. projection) to an existing DataFrame. - // However, if it happens, we have synchronized to make sure only one - // thread will trigger the job submission. - if (!estimated) { - // Make sure we have the expected number of registered Exchange operators. - assert(exchanges.length == numExchanges) - - val newPostShuffleRDDs = new JHashMap[ShuffleExchangeExec, ShuffledRowRDD](numExchanges) - - // Submit all map stages - val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, InternalRow, InternalRow]]() - val submittedStageFutures = ArrayBuffer[SimpleFutureAction[MapOutputStatistics]]() - var i = 0 - while (i < numExchanges) { - val exchange = exchanges(i) - val shuffleDependency = exchange.prepareShuffleDependency() - shuffleDependencies += shuffleDependency - if (shuffleDependency.rdd.partitions.length != 0) { - // submitMapStage does not accept RDD with 0 partition. - // So, we will not submit this dependency. - submittedStageFutures += - exchange.sqlContext.sparkContext.submitMapStage(shuffleDependency) - } + val firstStartIndex = nextStartIndex(0) + partitionStartIndices += firstStartIndex + var (postShuffleInputSize, postShuffleInputRowCount) = partitionSizeAndRowCount(firstStartIndex) + + var i = firstStartIndex + var nextIndex = nextStartIndex(i + 1) + while (nextIndex < numPreShufflePartitions) { + val (nextShuffleInputSize, nextShuffleInputRowCount) = partitionSizeAndRowCount(nextIndex) + // If the next partition is omitted, or including the nextShuffleInputSize would exceed the + // target partition size, then start a new partition. + if (nextIndex != i + 1 + || postShuffleInputSize + nextShuffleInputSize > targetPostShuffleInputSize + || postShuffleInputRowCount + nextShuffleInputRowCount > targetPostShuffleRowCount) { + partitionEndIndices += i + 1 + partitionStartIndices += nextIndex + postShuffleInputSize = nextShuffleInputSize + postShuffleInputRowCount = nextShuffleInputRowCount + i = nextIndex + } else { + postShuffleInputSize += nextShuffleInputSize + postShuffleInputRowCount += nextShuffleInputRowCount i += 1 } - - // Wait for the finishes of those submitted map stages. - val mapOutputStatistics = new Array[MapOutputStatistics](submittedStageFutures.length) - var j = 0 - while (j < submittedStageFutures.length) { - // This call is a blocking call. If the stage has not finished, we will wait at here. - mapOutputStatistics(j) = submittedStageFutures(j).get() - j += 1 - } - - // If we have mapOutputStatistics.length < numExchange, it is because we do not submit - // a stage when the number of partitions of this dependency is 0. - assert(mapOutputStatistics.length <= numExchanges) - - // Now, we estimate partitionStartIndices. partitionStartIndices.length will be the - // number of post-shuffle partitions. 
- val partitionStartIndices = - if (mapOutputStatistics.length == 0) { - Array(0) - } else { - estimatePartitionStartIndices(mapOutputStatistics) - } - if (maxTargetPostShuffleInputSize > 0 && exchanges.forall(_.isRepartition.get)) { - val moreThan = mapOutputStatistics.apply(0) - .bytesByPartitionId - .zipWithIndex - .filter(_._1 > maxTargetPostShuffleInputSize) - .map(_._2) - - if (moreThan.nonEmpty) { - throw new DetectDataSkewException(moreThan) - } - } - var k = 0 - while (k < numExchanges) { - val exchange = exchanges(k) - val rdd = - exchange.preparePostShuffleRDD(shuffleDependencies(k), Some(partitionStartIndices)) - newPostShuffleRDDs.put(exchange, rdd) - - k += 1 - } - - // Finally, we set postShuffleRDDs and estimated. - assert(postShuffleRDDs.isEmpty) - assert(newPostShuffleRDDs.size() == numExchanges) - postShuffleRDDs.putAll(newPostShuffleRDDs) - estimated = true - } - } - - def postShuffleRDD(exchange: ShuffleExchangeExec): ShuffledRowRDD = { - doEstimationIfNecessary() - - if (!postShuffleRDDs.containsKey(exchange)) { - throw new IllegalStateException( - s"The given $exchange is not registered in this coordinator.") + nextIndex = nextStartIndex(nextIndex + 1) } + partitionEndIndices += i + 1 - postShuffleRDDs.get(exchange) + (partitionStartIndices.toArray, partitionEndIndices.toArray) } override def toString: String = { - s"coordinator[target post-shuffle partition size: $advisoryTargetPostShuffleInputSize]" + s"coordinator[target post-shuffle partition size: $advisoryTargetPostShuffleInputSize]" + + s"coordinator[target post-shuffle row count: $targetPostShuffleRowCount]" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 57e7c80f2c352..fb438c7773914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -41,7 +41,6 @@ import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordCo */ case class ShuffleExchangeExec(var newPartitioning: Partitioning, child: SparkPlan, - @transient coordinator: Option[ExchangeCoordinator], isRepartition: Option[Boolean] = Some(false)) extends Exchange { // NOTE: coordinator can be null after serialization/deserialization, @@ -51,14 +50,7 @@ case class ShuffleExchangeExec(var newPartitioning: Partitioning, "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")) override def nodeName: String = { - val extraInfo = coordinator match { - case Some(exchangeCoordinator) => - s"(coordinator id: ${System.identityHashCode(exchangeCoordinator)})" - case _ => "" - } - - val simpleNodeName = "Exchange" - s"$simpleNodeName$extraInfo" + "Exchange" } override def outputPartitioning: Partitioning = newPartitioning @@ -66,21 +58,6 @@ case class ShuffleExchangeExec(var newPartitioning: Partitioning, private val serializer: Serializer = new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) - override protected def doPrepare(): Unit = { - // If an ExchangeCoordinator is needed, we register this Exchange operator - // to the coordinator when we do prepare. 
It is important to make sure - // we register this operator right before the execution instead of register it - // in the constructor because it is possible that we create new instances of - // Exchange operators when we transform the physical plan - // (then the ExchangeCoordinator will hold references of unneeded Exchanges). - // So, we should only call registerExchange just before we start to execute - // the plan. - coordinator match { - case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this) - case _ => - } - } - /** * Returns a [[ShuffleDependency]] that will partition rows of its child based on * the partitioning scheme defined in `newPartitioning`. Those partitions of @@ -119,15 +96,26 @@ case class ShuffleExchangeExec(var newPartitioning: Partitioning, protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { // Returns the same ShuffleRowRDD if this plan is used by multiple plans. if (cachedShuffleRDD == null) { - cachedShuffleRDD = coordinator match { - case Some(exchangeCoordinator) => - val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) - assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) - shuffleRDD - case _ => - val shuffleDependency = prepareShuffleDependency() - preparePostShuffleRDD(shuffleDependency) + val shuffleDependency = prepareShuffleDependency() + cachedShuffleRDD = preparePostShuffleRDD(shuffleDependency) + } + cachedShuffleRDD + } + + private var _mapOutputStatistics: MapOutputStatistics = null + + def mapOutputStatistics: MapOutputStatistics = _mapOutputStatistics + + def eagerExecute(): RDD[InternalRow] = { + if (cachedShuffleRDD == null) { + val shuffleDependency = prepareShuffleDependency() + if (shuffleDependency.rdd.partitions.length != 0) { + // submitMapStage does not accept RDD with 0 partition. + // So, we will not submit this dependency. + val submittedStageFuture = sqlContext.sparkContext.submitMapStage(shuffleDependency) + _mapOutputStatistics = submittedStageFuture.get() } + cachedShuffleRDD = preparePostShuffleRDD(shuffleDependency) } cachedShuffleRDD } @@ -137,7 +125,6 @@ object ShuffleExchangeExec { def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchangeExec = { ShuffleExchangeExec(newPartitioning, child, - coordinator = Option.empty[ExchangeCoordinator], isRepartition = Some(false)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala new file mode 100644 index 0000000000000..88d8e97a9da30 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.statsEstimation + +import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStage +import org.apache.spark.sql.execution.aggregate._ +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.joins.{HashJoin, SortMergeJoinExec} + +object SizeInBytesOnlyStatsPlanVisitor extends SparkPlanVisitor[Statistics] { + + private def visitUnaryExecNode(p: UnaryExecNode): Statistics = { + // There should be some overhead in Row object, the size should not be zero when there is + // no columns, this help to prevent divide-by-zero error. + val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8 + val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8 + // Assume there will be the same number of rows as child has. + var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize + if (sizeInBytes == 0) { + // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero + // (product of children). + sizeInBytes = 1 + } + + // Don't propagate rowCount and attributeStats, since they are not estimated here. + Statistics(sizeInBytes = sizeInBytes) + } + + override def default(p: SparkPlan): Statistics = p match { + case p: LeafExecNode => p.computeStats() + case _: SparkPlan => Statistics(sizeInBytes = p.children.map(_.stats.sizeInBytes).product) + } + + override def visitFilterExec(p: FilterExec): Statistics = visitUnaryExecNode(p) + + override def visitProjectExec(p: ProjectExec): Statistics = visitUnaryExecNode(p) + + override def visitHashAggregateExec(p: HashAggregateExec): Statistics = { + if (p.groupingExpressions.isEmpty) { + val sizeInBytes = 8 + p.output.map(_.dataType.defaultSize).sum + Statistics(sizeInBytes) + } else { + visitUnaryExecNode(p) + } + } + + override def visitHashJoin(p: HashJoin): Statistics = { + p.joinType match { + case LeftAnti | LeftSemi => + // LeftSemi and LeftAnti won't ever be bigger than left + p.left.stats + case _ => + Statistics(sizeInBytes = p.left.stats.sizeInBytes * p.right.stats.sizeInBytes) + } + } + + override def visitShuffleExchangeExec(p: ShuffleExchangeExec): Statistics = { + if (p.mapOutputStatistics != null) { + val sizeInBytes = p.mapOutputStatistics.bytesByPartitionId.sum + val bytesByPartitionId = p.mapOutputStatistics.bytesByPartitionId + if (p.mapOutputStatistics.recordsByPartitionId.nonEmpty) { + val record = p.mapOutputStatistics.recordsByPartitionId.sum + val recordsByPartitionId = p.mapOutputStatistics.recordsByPartitionId + Statistics(sizeInBytes = sizeInBytes, + bytesByPartitionId = Some(bytesByPartitionId), + recordStatistics = Some(RecordStatistics(record, recordsByPartitionId))) + } else { + Statistics(sizeInBytes = sizeInBytes, bytesByPartitionId = Some(bytesByPartitionId)) + } + } else { + visitUnaryExecNode(p) + } + } + + override def visitSortAggregateExec(p: SortAggregateExec): Statistics = { + if (p.groupingExpressions.isEmpty) { + val sizeInBytes = 8 + p.output.map(_.dataType.defaultSize).sum + Statistics(sizeInBytes) + } else { + visitUnaryExecNode(p) + } + } + + override def visitSortMergeJoinExec(p: SortMergeJoinExec): Statistics = { + p.joinType match { + case LeftAnti | LeftSemi => + // LeftSemi and LeftAnti won't ever be bigger than left + p.left.stats + case _ => + default(p) + 
} + } + + override def visitShuffleQueryStage(p: ShuffleQueryStage): Statistics = { + if (p.mapOutputStatistics != null) { + val childDataSize = p.child.metrics.get("dataSize").map(_.value).getOrElse(0L) + val sizeInBytes = p.mapOutputStatistics.bytesByPartitionId.sum.max(childDataSize) + val bytesByPartitionId = p.mapOutputStatistics.bytesByPartitionId + if (p.mapOutputStatistics.recordsByPartitionId.nonEmpty) { + val record = p.mapOutputStatistics.recordsByPartitionId.sum + val recordsByPartitionId = p.mapOutputStatistics.recordsByPartitionId + Statistics(sizeInBytes = sizeInBytes, + bytesByPartitionId = Some(bytesByPartitionId), + recordStatistics = Some(RecordStatistics(record, recordsByPartitionId))) + } else { + Statistics(sizeInBytes = sizeInBytes, bytesByPartitionId = Some(bytesByPartitionId)) + } + } else { + visitUnaryExecNode(p) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala new file mode 100644 index 0000000000000..4ae78fc900efa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.statsEstimation + +import org.apache.spark.sql.execution.SparkPlan + +/** + * A trait to add statistics propagation to [[SparkPlan]]. + */ +trait SparkPlanStats { self: SparkPlan => + + /** + * Returns the estimated statistics for the current logical plan node. Under the hood, this + * method caches the return value, which is computed based on the configuration passed in the + * first time. If the configuration changes, the cache can be invalidated by calling + * [[invalidateStatsCache()]]. + */ + def stats: Statistics = statsCache.getOrElse { + statsCache = Option(SizeInBytesOnlyStatsPlanVisitor.visit(self)) + statsCache.get + } + + /** A cache for the estimated statistics, such that it will only be computed once. */ + protected var statsCache: Option[Statistics] = None + + /** Invalidates the stats cache. See [[stats]] for more information. 
*/ + final def invalidateStatsCache(): Unit = { + statsCache = None + children.foreach(_.invalidateStatsCache()) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/Statistics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/Statistics.scala new file mode 100644 index 0000000000000..45c2fac484159 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/Statistics.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.statsEstimation + +import java.math.{MathContext, RoundingMode} + +import org.apache.spark.util.Utils + + +/** + * Estimates of various statistics. The default estimation logic simply lazily multiplies the + * corresponding statistic produced by the children. To override this behavior, override + * `statistics` and assign it an overridden version of `Statistics`. + * + * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the + * performance of the implementations. The reason is that estimations might get triggered in + * performance-critical processes, such as query plan planning. + * + * Note that we are using a BigInt here since it is easy to overflow a 64-bit integer in + * cardinality estimation (e.g. cartesian joins). + * + * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it + * defaults to the product of children's `sizeInBytes`. + * @param rowCount Estimated number of rows. + */ + +case class PartitionStatistics( + bytesByPartitionId: Array[Long], + recordsByPartitionId: Array[Long]) + +case class RecordStatistics( + record: BigInt, + recordsByPartitionId: Array[Long]) + +case class Statistics( + sizeInBytes: BigInt, + bytesByPartitionId: Option[Array[Long]] = None, + recordStatistics: Option[RecordStatistics] = None) { + + override def toString: String = "Statistics(" + simpleString + ")" + + /** Readable string representation for the Statistics. */ + def simpleString: String = { + Seq(s"sizeInBytes=${Utils.bytesToString(sizeInBytes)}", + if (recordStatistics.isDefined) { + // Show row count in scientific notation. 
+ s"record=${BigDecimal(recordStatistics.get.record, + new MathContext(3, RoundingMode.HALF_UP)).toString()}" + } else { + "" + } + ).filter(_.nonEmpty).mkString(", ") + } + + def getPartitionStatistics : Option[PartitionStatistics] = { + if (bytesByPartitionId.isDefined && recordStatistics.isDefined) { + Some(PartitionStatistics(bytesByPartitionId.get, recordStatistics.get.recordsByPartitionId)) + } else { + None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index f13d512b2e558..53dd961554424 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -275,25 +275,25 @@ class SQLAppStatusListener( } } + private def toStoredNodes(nodes: Seq[SparkPlanGraphNode]): Seq[SparkPlanGraphNodeWrapper] = { + nodes.map { + case cluster: SparkPlanGraphCluster => + val storedCluster = new SparkPlanGraphClusterWrapper( + cluster.id, + cluster.name, + cluster.desc, + toStoredNodes(cluster.nodes), + cluster.metrics) + new SparkPlanGraphNodeWrapper(null, storedCluster) + + case node => + new SparkPlanGraphNodeWrapper(node, null) + } + } + private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = { val SparkListenerSQLExecutionStart(executionId, description, details, - physicalPlanDescription, sparkPlanInfo, time) = event - - def toStoredNodes(nodes: Seq[SparkPlanGraphNode]): Seq[SparkPlanGraphNodeWrapper] = { - nodes.map { - case cluster: SparkPlanGraphCluster => - val storedCluster = new SparkPlanGraphClusterWrapper( - cluster.id, - cluster.name, - cluster.desc, - toStoredNodes(cluster.nodes), - cluster.metrics) - new SparkPlanGraphNodeWrapper(null, storedCluster) - - case node => - new SparkPlanGraphNodeWrapper(node, null) - } - } + physicalPlanDescription, sparkPlanInfo, time) = event val planGraph = SparkPlanGraph(sparkPlanInfo) val sqlPlanMetrics = planGraph.allNodes.flatMap { node => @@ -321,6 +321,27 @@ class SQLAppStatusListener( } } + private def onAdaptiveExecutionUpdate(event: SparkListenerSQLAdaptiveExecutionUpdate): Unit = { + val SparkListenerSQLAdaptiveExecutionUpdate(executionId, + physicalPlanDescription, sparkPlanInfo) = event + + val planGraph = SparkPlanGraph(sparkPlanInfo) + val sqlPlanMetrics = planGraph.allNodes.flatMap { node => + node.metrics.map { metric => (metric.accumulatorId, metric) } + }.toMap.values.toList + + val graphToStore = new SparkPlanGraphWrapper( + executionId, + toStoredNodes(planGraph.nodes), + planGraph.edges) + kvstore.write(graphToStore) + + val exec = getOrCreateExecution(executionId) + exec.physicalPlanDescription = physicalPlanDescription + exec.metrics = sqlPlanMetrics + update(exec) + } + private def onExecutionEnd(event: SparkListenerSQLExecutionEnd): Unit = { val SparkListenerSQLExecutionEnd(executionId, time) = event Option(liveExecutions.get(executionId)).foreach { exec => @@ -349,6 +370,7 @@ class SQLAppStatusListener( override def onOtherEvent(event: SparkListenerEvent): Unit = event match { case e: SparkListenerSQLExecutionStart => onExecutionStart(e) + case e: SparkListenerSQLAdaptiveExecutionUpdate => onAdaptiveExecutionUpdate(e) case e: SparkListenerSQLExecutionEnd => onExecutionEnd(e) case e: SparkListenerDriverAccumUpdates => onDriverAccumUpdates(e) case _ => // Ignore diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index dd0b4b6002f19..99568bcad2989 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -38,6 +38,13 @@ case class SparkListenerSQLExecutionStart( time: Long) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerSQLAdaptiveExecutionUpdate( + executionId: Long, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo) + extends SparkListenerEvent + @DeveloperApi case class PostQueryExecutionForKylin( localProperties: Properties, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index e57d080dadf78..15b4acfb662b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -96,6 +96,18 @@ object SparkPlanGraph { case "InputAdapter" => buildSparkPlanGraphNode( planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges) + case "QueryStage" | "BroadcastQueryStage" | "ResultQueryStage" | "ShuffleQueryStage" => + if (exchanges.contains(planInfo.children.head)) { + // Point to the re-used exchange + val node = exchanges(planInfo.children.head) + edges += SparkPlanGraphEdge(node.id, parent.id) + } else { + buildSparkPlanGraphNode( + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges) + } + case "QueryStageInput" | "ShuffleQueryStageInput" | "BroadcastQueryStageInput" => + buildSparkPlanGraphNode( + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges) case "Subquery" if subgraph != null => // Subquery should not be included in WholeStageCodegen buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges) diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index 2e5cac12952db..e1c3076855331 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -22,7 +22,7 @@ log4j.rootLogger=INFO, CA, FA log4j.appender.CA=org.apache.log4j.ConsoleAppender log4j.appender.CA.layout=org.apache.log4j.PatternLayout log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n -log4j.appender.CA.Threshold = WARN +log4j.appender.CA.Threshold = INFO log4j.appender.CA.follow = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d40f756f04753..4e593ff046a53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1270,7 +1270,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val agg = cp.groupBy('id % 2).agg(count('id)) agg.queryExecution.executedPlan.collectFirst { - case ShuffleExchangeExec(_, _: RDDScanExec, _, _) => + case ShuffleExchangeExec(_, _: RDDScanExec, _) => case BroadcastExchangeExec(_, _: RDDScanExec) => }.foreach { _ => fail( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index b656c9b0f4f8e..d18e72226462f 100644 
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution +import scala.collection.mutable + import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.sql._ +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageInput import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -45,9 +48,9 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } private def checkEstimation( - coordinator: ExchangeCoordinator, - bytesByPartitionIdArray: Array[Array[Long]], - expectedPartitionStartIndices: Array[Int]): Unit = { + coordinator: ExchangeCoordinator, + bytesByPartitionIdArray: Array[Array[Long]], + expectedPartitionStartIndices: Array[Int]): Unit = { val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { case (bytesByPartitionId, index) => new MapOutputStatistics(index, bytesByPartitionId) @@ -57,8 +60,39 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { assert(estimatedPartitionStartIndices === expectedPartitionStartIndices) } + private def checkEstimation( + coordinator: ExchangeCoordinator, + bytesByPartitionIdArray: Array[Array[Long]], + rowCountsByPartitionIdArray: Array[Array[Long]], + expectedPartitionStartIndices: Array[Int]): Unit = { + val mapOutputStatistics = bytesByPartitionIdArray.zip(rowCountsByPartitionIdArray).zipWithIndex + .map { + case ((bytesByPartitionId, rowCountByPartitionId), index) => + new MapOutputStatistics(index, bytesByPartitionId, rowCountByPartitionId) + } + val estimatedPartitionStartIndices = + coordinator.estimatePartitionStartIndices(mapOutputStatistics) + assert(estimatedPartitionStartIndices === expectedPartitionStartIndices) + } + + private def checkStartEndEstimation( + coordinator: ExchangeCoordinator, + bytesByPartitionIdArray: Array[Array[Long]], + omittedPartitions: mutable.HashSet[Int], + expectedPartitionStartIndices: Array[Int], + expectedPartitionEndIndices: Array[Int]): Unit = { + val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { + case (bytesByPartitionId, index) => + new MapOutputStatistics(index, bytesByPartitionId) + } + val (estimatedPartitionStartIndices, estimatedPartitionEndIndices) = + coordinator.estimatePartitionStartEndIndices(mapOutputStatistics, omittedPartitions) + assert(estimatedPartitionStartIndices === expectedPartitionStartIndices) + assert(estimatedPartitionEndIndices === expectedPartitionEndIndices) + } + test("test estimatePartitionStartIndices - 1 Exchange") { - val coordinator = new ExchangeCoordinator(100L) + val coordinator = new ExchangeCoordinator(100L, 100L) { // All bytes per partition are 0. 
@@ -105,7 +139,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test("test estimatePartitionStartIndices - 2 Exchanges") { - val coordinator = new ExchangeCoordinator(100L) + val coordinator = new ExchangeCoordinator(100L, 100L) { // If there are multiple values of the number of pre-shuffle partitions, @@ -199,7 +233,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test("test estimatePartitionStartIndices and enforce minimal number of reducers") { - val coordinator = new ExchangeCoordinator(100L, Some(2)) + val coordinator = new ExchangeCoordinator(100L, 100L, 2) { // The minimal number of post-shuffle partitions is not enforced because @@ -236,6 +270,103 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("test estimatePartitionStartIndices and let row count exceed the threshold") { + val coordinator = new ExchangeCoordinator(100L, 100L) + + val rowCountsByPartitionIdArray = Array(Array(120L, 20, 90, 1, 20)) + + { + // Total bytes is less than the target size, but the sum of row count will exceed the + // threshold. + // 3 post-shuffle partition is needed. + val bytesByPartitionId = Array[Long](1, 1, 1, 1, 1) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4) + checkEstimation(coordinator, + Array(bytesByPartitionId), + rowCountsByPartitionIdArray, + expectedPartitionStartIndices) + } + } + + test("test estimatePartitionStartEndIndices") { + val coordinator = new ExchangeCoordinator(100L, 100L) + + { + // All bytes per partition are 0. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) + val omittedPartitions = mutable.HashSet[Int](0, 4) + val expectedPartitionStartIndices = Array[Int](1) + val expectedPartitionEndIndices = Array[Int](4) + checkStartEndEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + omittedPartitions, + expectedPartitionStartIndices, + expectedPartitionEndIndices) + } + + { + // 1 post-shuffle partition is needed. + val bytesByPartitionId1 = Array[Long](0, 30, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20) + val omittedPartitions = mutable.HashSet[Int](0, 1) + val expectedPartitionStartIndices = Array[Int](2) + val expectedPartitionEndIndices = Array[Int](5) + checkStartEndEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + omittedPartitions, + expectedPartitionStartIndices, + expectedPartitionEndIndices) + } + + { + // 3 post-shuffle partition are needed. + val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val omittedPartitions = mutable.HashSet[Int](3) + val expectedPartitionStartIndices = Array[Int](0, 2, 4) + val expectedPartitionEndIndices = Array[Int](2, 3, 5) + checkStartEndEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + omittedPartitions, + expectedPartitionStartIndices, + expectedPartitionEndIndices) + } + + { + // 2 post-shuffle partition are needed. 
+ val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val omittedPartitions = mutable.HashSet[Int](1, 2, 3) + val expectedPartitionStartIndices = Array[Int](0, 4) + val expectedPartitionEndIndices = Array[Int](1, 5) + checkStartEndEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + omittedPartitions, + expectedPartitionStartIndices, + expectedPartitionEndIndices) + } + + { + // There are a few large pre-shuffle partitions. + val bytesByPartitionId1 = Array[Long](0, 120, 40, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110) + val omittedPartitions = mutable.HashSet[Int](1, 4) + val expectedPartitionStartIndices = Array[Int](0, 2, 3) + val expectedPartitionEndIndices = Array[Int](1, 3, 4) + checkStartEndEstimation( + coordinator, + Array(bytesByPartitionId1, bytesByPartitionId2), + omittedPartitions, + expectedPartitionStartIndices, + expectedPartitionEndIndices) + } + } + /////////////////////////////////////////////////////////////////////////// // Query tests /////////////////////////////////////////////////////////////////////////// @@ -250,16 +381,16 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } def withSparkSession( - f: SparkSession => Unit, - targetNumPostShufflePartitions: Int, - minNumPostShufflePartitions: Option[Int]): Unit = { + f: SparkSession => Unit, + targetNumPostShufflePartitions: Int, + minNumPostShufflePartitions: Option[Int]): Unit = { val sparkConf = new SparkConf(false) .setMaster("local[*]") .setAppName("test") - .set("spark.ui.enabled", "false") + .set("spark.ui.enabled", "true") .set("spark.driver.allowMultipleContexts", "true") - .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") + .set(SQLConf.SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS.key, "5") .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") .set( @@ -269,7 +400,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case Some(numPartitions) => sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, numPartitions.toString) case None => - sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, "-1") + sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, "1") } val spark = SparkSession.builder() @@ -299,25 +430,21 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. 
- val exchanges = agg.queryExecution.executedPlan.collect { - case e: ShuffleExchangeExec => e + val queryStageInputs = agg.queryExecution.executedPlan.collect { + case q: ShuffleQueryStageInput => q } - assert(exchanges.length === 1) + assert(queryStageInputs.length === 1) minNumPostShufflePartitions match { case Some(numPartitions) => - exchanges.foreach { - case e: ShuffleExchangeExec => - assert(e.coordinator.isDefined) - assert(e.outputPartitioning.numPartitions === 5) - case o => + queryStageInputs.foreach { q => + assert(q.partitionStartIndices.isDefined) + assert(q.outputPartitioning.numPartitions === 5) } case None => - exchanges.foreach { - case e: ShuffleExchangeExec => - assert(e.coordinator.isDefined) - assert(e.outputPartitioning.numPartitions === 3) - case o => + queryStageInputs.foreach { q => + assert(q.partitionStartIndices.isDefined) + assert(q.outputPartitioning.numPartitions === 3) } } } @@ -350,25 +477,21 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. - val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchangeExec => e + val queryStageInputs = join.queryExecution.executedPlan.collect { + case q: ShuffleQueryStageInput => q } - assert(exchanges.length === 2) + assert(queryStageInputs.length === 2) minNumPostShufflePartitions match { case Some(numPartitions) => - exchanges.foreach { - case e: ShuffleExchangeExec => - assert(e.coordinator.isDefined) - assert(e.outputPartitioning.numPartitions === 5) - case o => + queryStageInputs.foreach { q => + assert(q.partitionStartIndices.isDefined) + assert(q.outputPartitioning.numPartitions === 5) } case None => - exchanges.foreach { - case e: ShuffleExchangeExec => - assert(e.coordinator.isDefined) - assert(e.outputPartitioning.numPartitions === 2) - case o => + queryStageInputs.foreach { q => + assert(q.partitionStartIndices.isDefined) + assert(q.outputPartitioning.numPartitions === 2) } } } @@ -406,26 +529,26 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. 
- val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchangeExec => e + val queryStageInputs = join.queryExecution.executedPlan.collect { + case q: ShuffleQueryStageInput => q } - assert(exchanges.length === 4) + assert(queryStageInputs.length === 2) minNumPostShufflePartitions match { case Some(numPartitions) => - exchanges.foreach { - case e: ShuffleExchangeExec => - assert(e.coordinator.isDefined) - assert(e.outputPartitioning.numPartitions === 5) - case o => + queryStageInputs.foreach { q => + assert(q.partitionStartIndices.isDefined) + assert(q.outputPartitioning.numPartitions === 5) } case None => - assert(exchanges.forall(_.coordinator.isDefined)) - assert(exchanges.map(_.outputPartitioning.numPartitions).toSet === Set(2, 3)) + queryStageInputs.foreach { q => + assert(q.partitionStartIndices.isDefined) + assert(q.outputPartitioning.numPartitions === 2) + } } } - withSparkSession(test, 6644, minNumPostShufflePartitions) + withSparkSession(test, 16384, minNumPostShufflePartitions) } test(s"determining the number of reducers: complex query 2$testNameNote") { @@ -458,84 +581,26 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. - val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchangeExec => e + val queryStageInputs = join.queryExecution.executedPlan.collect { + case q: ShuffleQueryStageInput => q } - assert(exchanges.length === 3) + assert(queryStageInputs.length === 2) minNumPostShufflePartitions match { case Some(numPartitions) => - exchanges.foreach { - case e: ShuffleExchangeExec => - assert(e.coordinator.isDefined) - assert(e.outputPartitioning.numPartitions === 5) - case o => + queryStageInputs.foreach { q => + assert(q.partitionStartIndices.isDefined) + assert(q.outputPartitioning.numPartitions === 5) } case None => - assert(exchanges.forall(_.coordinator.isDefined)) - assert(exchanges.map(_.outputPartitioning.numPartitions).toSet === Set(5, 3)) + queryStageInputs.foreach { q => + assert(q.partitionStartIndices.isDefined) + assert(q.outputPartitioning.numPartitions === 3) + } } } - withSparkSession(test, 6144, minNumPostShufflePartitions) - } - } - - test("SPARK-28231 adaptive execution should ignore RepartitionByExpression") { - val test = { spark: SparkSession => - val df = - spark - .range(0, 1000, 1, numInputPartitions) - .repartition(20, col("id")) - .selectExpr("id % 20 as key", "id as value") - val agg = df.groupBy("key").count() - - // Check the answer first. - checkAnswer( - agg, - spark.range(0, 20).selectExpr("id", "50 as cnt").collect()) - - // Then, let's look at the number of post-shuffle partitions estimated - // by the ExchangeCoordinator. 
- val exchanges = agg.queryExecution.executedPlan.collect { - case e: ShuffleExchangeExec => e - } - assert(exchanges.length === 2) - exchanges.foreach { - case e @ ShuffleExchangeExec(_, _, _, Some(false)) => - assert(e.coordinator.isDefined) - assert(e.outputPartitioning.numPartitions === 5) - case e @ ShuffleExchangeExec(_, _, _, Some(true)) => - assert(e.coordinator.isEmpty) - assert(e.outputPartitioning.numPartitions === 20) - case o => - } - } - withSparkSession(test, 4, None) - } - - test("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") { - val test = { spark: SparkSession => - spark.sql("SET spark.sql.exchange.reuse=true") - val df = spark.range(1).selectExpr("id AS key", "id AS value") - val resultDf = df.join(df, "key").join(df, "key") - val sparkPlan = resultDf.queryExecution.executedPlan - assert(sparkPlan.collect { case p: ReusedExchangeExec => p }.length == 1) - assert(sparkPlan.collect { - case p @ ShuffleExchangeExec(_, _, Some(c), _) => p }.length == 3) - checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil) - } - withSparkSession(test, 4, None) - } - - test("SPARK-29284 adaptive query execution works correctly " + - "when first stage partitions size is 0") { - val test = { spark: SparkSession => - spark.sql("SET spark.sql.adaptive.enabled=true") - spark.sql("SET spark.sql.shuffle.partitions=1") - val resultDf = spark.range(0).distinct().groupBy().count() - checkAnswer(resultDf, Row(0) :: Nil) + withSparkSession(test, 12000, minNumPostShufflePartitions) } - withSparkSession(test, 4, None) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index fa617d5eda847..e4e224df7607f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -412,7 +412,6 @@ class PlannerSuite extends SharedSQLContext { val inputPlan = ShuffleExchangeExec( partitioning, DummySparkPlan(outputPartitioning = partitioning), - None, None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) @@ -429,7 +428,6 @@ class PlannerSuite extends SharedSQLContext { val inputPlan = ShuffleExchangeExec( partitioning, DummySparkPlan(outputPartitioning = partitioning), - None, None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala new file mode 100644 index 0000000000000..8b0339d2013da --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala @@ -0,0 +1,880 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.config +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf + +class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var originalActiveSparkSession: Option[SparkSession] = _ + private var originalInstantiatedSparkSession: Option[SparkSession] = _ + + override protected def beforeAll(): Unit = { + originalActiveSparkSession = SparkSession.getActiveSession + originalInstantiatedSparkSession = SparkSession.getDefaultSession + + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + + override protected def afterAll(): Unit = { + // Set these states back. + originalActiveSparkSession.foreach(ctx => SparkSession.setActiveSession(ctx)) + originalInstantiatedSparkSession.foreach(ctx => SparkSession.setDefaultSession(ctx)) + } + + def defaultSparkSession(): SparkSession = { + val spark = SparkSession.builder() + .master("local[*]") + .appName("test") + .config("spark.ui.enabled", "true") + .config("spark.driver.allowMultipleContexts", "true") + .config(SQLConf.SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS.key, "5") + .config(config.SHUFFLE_STATISTICS_VERBOSE.key, "true") + .config(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .config(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") + .config(SQLConf.ADAPTIVE_BROADCASTJOIN_THRESHOLD.key, "16000") + .getOrCreate() + spark + } + + def withSparkSession(spark: SparkSession)(f: SparkSession => Unit): Unit = { + try f(spark) finally spark.stop() + } + + val numInputPartitions: Int = 10 + + def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + QueryTest.checkAnswer(actual, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + def checkJoin(join: DataFrame, spark: SparkSession): Unit = { + // Before Execution, there is one SortMergeJoin + val smjBeforeExecution = join.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecution.length === 1) + + // Check the answer. 
+ val expectedAnswer = + spark + .range(0, 1000) + .selectExpr("id % 500 as key", "id as value") + .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value")) + checkAnswer( + join, + expectedAnswer.collect()) + + // During execution, the SortMergeJoin is changed to BroadcastHashJoinExec + val smjAfterExecution = join.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecution.length === 0) + + val numBhjAfterExecution = join.queryExecution.executedPlan.collect { + case smj: BroadcastHashJoinExec => smj + }.length + assert(numBhjAfterExecution === 1) + + // Both shuffle should be local shuffle + val queryStageInputs = join.queryExecution.executedPlan.collect { + case q: ShuffleQueryStageInput => q + } + assert(queryStageInputs.length === 2) + assert(queryStageInputs.forall(_.isLocalShuffle) === true) + } + + test("1 sort merge join to broadcast join") { + withSparkSession(defaultSparkSession) { spark: SparkSession => + val df1 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + + val innerJoin = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2")) + checkJoin(innerJoin, spark) + + val leftJoin = + df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value1")) + checkJoin(leftJoin, spark) + } + } + + test("2 sort merge joins to broadcast joins") { + // t1 and t3 are smaller than the spark.sql.adaptiveBroadcastJoinThreshold + // t2 is greater than spark.sql.adaptiveBroadcastJoinThreshold + // Both Join1 and Join2 are changed to broadcast join. + // + // Join2 + // / \ + // Join1 Ex (Exchange) + // / \ \ + // Ex Ex t3 + // / \ + // t1 t2 + withSparkSession(defaultSparkSession) { spark: SparkSession => + val df1 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + val df3 = + spark + .range(0, 500, 1, numInputPartitions) + .selectExpr("id % 500 as key3", "id as value3") + + val join = + df1 + .join(df2, col("key1") === col("key2")) + .join(df3, col("key2") === col("key3")) + .select(col("key3"), col("value1")) + + // Before Execution, there is two SortMergeJoins + val smjBeforeExecution = join.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecution.length === 2) + + // Check the answer. 
+ val expectedAnswer = + spark + .range(0, 1000) + .selectExpr("id % 500 as key", "id as value") + .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value")) + checkAnswer( + join, + expectedAnswer.collect()) + + // During execution, 2 SortMergeJoin are changed to BroadcastHashJoin + val smjAfterExecution = join.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecution.length === 0) + + val numBhjAfterExecution = join.queryExecution.executedPlan.collect { + case smj: BroadcastHashJoinExec => smj + }.length + assert(numBhjAfterExecution === 2) + + val queryStageInputs = join.queryExecution.executedPlan.collect { + case q: QueryStageInput => q + } + assert(queryStageInputs.length === 3) + } + } + + test("Do not change sort merge join if it adds additional Exchanges") { + // t1 is smaller than spark.sql.adaptiveBroadcastJoinThreshold + // t2 and t3 are greater than spark.sql.adaptiveBroadcastJoinThreshold + // Both Join1 and Join2 are not changed to broadcast join. + // + // Join2 + // / \ + // Join1 Ex (Exchange) + // / \ \ + // Ex Ex t3 + // / \ + // t1 t2 + withSparkSession(defaultSparkSession) { spark: SparkSession => + val df1 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + val df3 = + spark + .range(0, 1500, 1, numInputPartitions) + .selectExpr("id % 500 as key3", "id as value3") + + val join = + df1 + .join(df2, col("key1") === col("key2")) + .join(df3, col("key2") === col("key3")) + .select(col("key3"), col("value1")) + + // Before Execution, there is two SortMergeJoins + val smjBeforeExecution = join.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecution.length === 2) + + // Check the answer. 
+ val partResult = + spark + .range(0, 1000) + .selectExpr("id % 500 as key", "id as value") + .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value")) + val expectedAnswer = partResult.union(partResult).union(partResult) + checkAnswer( + join, + expectedAnswer.collect()) + + // During execution, no SortMergeJoin is changed to BroadcastHashJoin + val smjAfterExecution = join.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecution.length === 2) + + val numBhjAfterExecution = join.queryExecution.executedPlan.collect { + case smj: BroadcastHashJoinExec => smj + }.length + assert(numBhjAfterExecution === 0) + + val queryStageInputs = join.queryExecution.executedPlan.collect { + case q: QueryStageInput => q + } + assert(queryStageInputs.length === 3) + } + } + + test("Reuse QueryStage in adaptive execution") { + withSparkSession(defaultSparkSession) { spark: SparkSession => + val df = spark.range(0, 1000, 1, numInputPartitions).toDF() + val join = df.join(df, "id") + + // Before Execution, there is one SortMergeJoin + val smjBeforeExecution = join.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecution.length === 1) + + checkAnswer(join, df.collect()) + + // During execution, the SortMergeJoin is changed to BroadcastHashJoinExec + val smjAfterExecution = join.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecution.length === 0) + + val numBhjAfterExecution = join.queryExecution.executedPlan.collect { + case smj: BroadcastHashJoinExec => smj + }.length + assert(numBhjAfterExecution === 1) + + val queryStageInputs = join.queryExecution.executedPlan.collect { + case q: QueryStageInput => q + } + assert(queryStageInputs.length === 2) + + assert(queryStageInputs(0).childStage === queryStageInputs(1).childStage) + } + } + + test("adaptive skewed join") { + val spark = defaultSparkSession + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10) + withSparkSession(spark) { spark: SparkSession => + val df1 = + spark + .range(0, 10, 1, 2) + .selectExpr("id % 5 as key1", "id as value1") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 1 as key2", "id as value2") + + val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2")) + + // Before Execution, there is one SortMergeJoin + val smjBeforeExecution = join.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecution.length === 1) + + // Check the answer. + val expectedAnswer = + spark + .range(0, 1000) + .selectExpr("0 as key", "id as value") + .union(spark.range(0, 1000).selectExpr("0 as key", "id as value")) + checkAnswer( + join, + expectedAnswer.collect()) + + // During execution, the SMJ is changed to Union of SMJ + 5 SMJ of the skewed partition. 
+ val smjAfterExecution = join.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecution.length === 6) + + val queryStageInputs = join.queryExecution.executedPlan.collect { + case q: ShuffleQueryStageInput => q + } + assert(queryStageInputs.length === 2) + assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions) + assert(queryStageInputs(0).skewedPartitions === Some(Set(0))) + } + } + + test("adaptive skewed join: left/right outer join and skewed on right side") { + val spark = defaultSparkSession + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10) + withSparkSession(spark) { spark: SparkSession => + val df1 = + spark + .range(0, 10, 1, 2) + .selectExpr("id % 5 as key1", "id as value1") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 1 as key2", "id as value2") + + val leftOuterJoin = + df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value2")) + val rightOuterJoin = + df1.join(df2, col("key1") === col("key2"), "right").select(col("key1"), col("value2")) + + // Before Execution, there is one SortMergeJoin + val smjBeforeExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForLeftOuter.length === 1) + + val smjBeforeExecutionForRightOuter = leftOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForRightOuter.length === 1) + + // Check the answer. + val expectedAnswerForLeftOuter = + spark + .range(0, 1000) + .selectExpr("0 as key", "id as value") + .union(spark.range(0, 1000).selectExpr("0 as key", "id as value")) + .union(spark.range(0, 10, 1).filter(_ % 5 != 0).selectExpr("id % 5 as key1", "null")) + checkAnswer( + leftOuterJoin, + expectedAnswerForLeftOuter.collect()) + + val expectedAnswerForRightOuter = + spark + .range(0, 1000) + .selectExpr("0 as key", "id as value") + .union(spark.range(0, 1000).selectExpr("0 as key", "id as value")) + checkAnswer( + rightOuterJoin, + expectedAnswerForRightOuter.collect()) + + // For the left outer join case: during execution, the SMJ can not be translated to any sub + // joins due to the skewed side is on the right but the join type is left outer + // (not correspond with each other) + val smjAfterExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecutionForLeftOuter.length === 1) + + // For the right outer join case: during execution, the SMJ is changed to Union of SMJ + 5 SMJ + // joins due to the skewed side is on the right and the join type is right + // outer (correspond with each other) + val smjAfterExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + + assert(smjAfterExecutionForRightOuter.length === 6) + val queryStageInputs = rightOuterJoin.queryExecution.executedPlan.collect { + case q: ShuffleQueryStageInput => q + } + assert(queryStageInputs.length === 2) + assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions) + assert(queryStageInputs(0).skewedPartitions === Some(Set(0))) + + } + } + + test("adaptive skewed join: left/right outer join and skewed on left side") { + val spark = 
defaultSparkSession + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10) + withSparkSession(spark) { spark: SparkSession => + val df1 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 1 as key1", "id as value1") + val df2 = + spark + .range(0, 10, 1, 2) + .selectExpr("id % 5 as key2", "id as value2") + + val leftOuterJoin = + df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value1")) + val rightOuterJoin = + df1.join(df2, col("key1") === col("key2"), "right").select(col("key1"), col("value1")) + + // Before Execution, there is one SortMergeJoin + val smjBeforeExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForLeftOuter.length === 1) + + val smjBeforeExecutionForRightOuter = leftOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForRightOuter.length === 1) + + // Check the answer. + val expectedAnswerForLeftOuter = + spark + .range(0, 1000) + .selectExpr("0 as key", "id as value") + .union(spark.range(0, 1000).selectExpr("0 as key", "id as value")) + checkAnswer( + leftOuterJoin, + expectedAnswerForLeftOuter.collect()) + + val expectedAnswerForRightOuter = + spark + .range(0, 1000) + .selectExpr("0 as key", "id as value") + .union(spark.range(0, 1000).selectExpr("0 as key", "id as value")) + .union(spark.range(0, 10, 1).filter(_ % 5 != 0).selectExpr("null", "null")) + + checkAnswer( + rightOuterJoin, + expectedAnswerForRightOuter.collect()) + + // For the left outer join case: during execution, the SMJ is changed to Union of SMJ + 5 SMJ + // joins due to the skewed side is on the left and the join type is left outer + // (correspond with each other) + val smjAfterExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecutionForLeftOuter.length === 6) + + // For the right outer join case: during execution, the SMJ can not be translated to any sub + // joins due to the skewed side is on the left but the join type is right outer + // (not correspond with each other) + val smjAfterExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + + assert(smjAfterExecutionForRightOuter.length === 1) + val queryStageInputs = leftOuterJoin.queryExecution.executedPlan.collect { + case q: ShuffleQueryStageInput => q + } + assert(queryStageInputs.length === 2) + assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions) + assert(queryStageInputs(0).skewedPartitions === Some(Set(0))) + + } + } + + test("adaptive skewed join: left/right outer join and skewed on both sides") { + val spark = defaultSparkSession + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10) + withSparkSession(spark) { spark: SparkSession => + import spark.implicits._ + val df1 = + spark + .range(0, 100, 1, numInputPartitions) + .selectExpr("id % 1 as key1", "id as value1") + val df2 = + spark + .range(0, 100, 1, numInputPartitions) + .selectExpr("id % 1 as key2", "id as value2") + + val leftOuterJoin = + 
df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value2")) + val rightOuterJoin = + df1.join(df2, col("key1") === col("key2"), "right").select(col("key1"), col("value2")) + + // Before Execution, there is one SortMergeJoin + val smjBeforeExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForLeftOuter.length === 1) + + val smjBeforeExecutionForRightOuter = leftOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForRightOuter.length === 1) + + // Check the answer. + val expectedAnswerForLeftOuter = + spark + .range(0, 100) + .flatMap(i => Seq.fill(100)(i)) + .selectExpr("0 as key", "value") + + checkAnswer( + leftOuterJoin, + expectedAnswerForLeftOuter.collect()) + + val expectedAnswerForRightOuter = + spark + .range(0, 100) + .flatMap(i => Seq.fill(100)(i)) + .selectExpr("0 as key", "value") + checkAnswer( + rightOuterJoin, + expectedAnswerForRightOuter.collect()) + + // For the left outer join case: during execution, although the skewed sides include the + // right, the SMJ is still changed to Union of SMJ + 5 SMJ joins due to the skewed sides + // also include the left, so we split the left skewed partition + // (correspondence exists) + val smjAfterExecutionForLeftOuter = leftOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecutionForLeftOuter.length === 6) + + // For the right outer join case: during execution, although the skewed sides include the + // left, the SMJ is still changed to Union of SMJ + 5 SMJ joins due to the skewed sides + // also include the right, so we split the right skewed partition + // (correspondence exists) + val smjAfterExecutionForRightOuter = rightOuterJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + + assert(smjAfterExecutionForRightOuter.length === 6) + val queryStageInputs = rightOuterJoin.queryExecution.executedPlan.collect { + case q: ShuffleQueryStageInput => q + } + assert(queryStageInputs.length === 2) + assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions) + assert(queryStageInputs(0).skewedPartitions === Some(Set(0))) + + } + } + + test("adaptive skewed join: left semi/anti join and skewed on right side") { + val spark = defaultSparkSession + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10) + withSparkSession(spark) { spark: SparkSession => + val df1 = + spark + .range(0, 10, 1, 2) + .selectExpr("id % 5 as key1", "id as value1") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 1 as key2", "id as value2") + + val leftSemiJoin = + df1.join(df2, col("key1") === col("key2"), "left_semi").select(col("key1"), col("value1")) + val leftAntiJoin = + df1.join(df2, col("key1") === col("key2"), "left_anti").select(col("key1"), col("value1")) + + // Before Execution, there is one SortMergeJoin + val smjBeforeExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForLeftSemi.length === 1) + + val smjBeforeExecutionForLeftAnti = leftSemiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + 
assert(smjBeforeExecutionForLeftAnti.length === 1) + + // Check the answer. + val expectedAnswerForLeftSemi = + spark + .range(0, 10) + .filter(_ % 5 == 0) + .selectExpr("id % 5 as key", "id as value") + checkAnswer( + leftSemiJoin, + expectedAnswerForLeftSemi.collect()) + + val expectedAnswerForLeftAnti = + spark + .range(0, 10) + .filter(_ % 5 != 0) + .selectExpr("id % 5 as key", "id as value") + checkAnswer( + leftAntiJoin, + expectedAnswerForLeftAnti.collect()) + + // For the left outer join case: during execution, the SMJ can not be translated to any sub + // joins due to the skewed side is on the right but the join type is left semi + // (not correspond with each other) + val smjAfterExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecutionForLeftSemi.length === 1) + + // For the right outer join case: during execution, the SMJ can not be translated to any sub + // joins due to the skewed side is on the right but the join type is left anti + // (not correspond with each other) + val smjAfterExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecutionForLeftAnti.length === 1) + + } + } + + test("adaptive skewed join: left semi/anti join and skewed on left side") { + val spark = defaultSparkSession + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_JOIN_ENABLED.key, "false") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED.key, "true") + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_ROW_COUNT_THRESHOLD.key, 10) + val MAX_SPLIT = 5 + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_MAX_SPLITS.key, MAX_SPLIT) + withSparkSession(spark) { spark: SparkSession => + val df1 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 1 as key1", "id as value1") + val df2 = + spark + .range(0, 10, 1, 2) + .selectExpr("id % 5 as key2", "id as value2") + + val leftSemiJoin = + df1.join(df2, col("key1") === col("key2"), "left_semi").select(col("key1"), col("value1")) + val leftAntiJoin = + df1.join(df2, col("key1") === col("key2"), "left_anti").select(col("key1"), col("value1")) + + // Before Execution, there is one SortMergeJoin + val smjBeforeExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForLeftSemi.length === 1) + + val smjBeforeExecutionForLeftAnti = leftSemiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjBeforeExecutionForLeftAnti.length === 1) + + // Check the answer. 
+ val expectedAnswerForLeftSemi = + spark + .range(0, 1000) + .selectExpr("id % 1 as key", "id as value") + checkAnswer( + leftSemiJoin, + expectedAnswerForLeftSemi.collect()) + + val expectedAnswerForLeftAnti = Seq.empty + checkAnswer( + leftAntiJoin, + expectedAnswerForLeftAnti) + + // For the left outer join case: during execution, the SMJ is changed to Union of SMJ + 5 SMJ + // joins due to the skewed side is on the left and the join type is left semi + // (correspond with each other) + val smjAfterExecutionForLeftSemi = leftSemiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecutionForLeftSemi.length === MAX_SPLIT + 1) + + // For the right outer join case: during execution, the SMJ is changed to Union of SMJ + 5 SMJ + // joins due to the skewed side is on the left and the join type is left anti + // (correspond with each other) + val smjAfterExecutionForLeftAnti = leftAntiJoin.queryExecution.executedPlan.collect { + case smj: SortMergeJoinExec => smj + } + assert(smjAfterExecutionForLeftAnti.length === MAX_SPLIT + 1) + + val queryStageInputs = leftSemiJoin.queryExecution.executedPlan.collect { + case q: ShuffleQueryStageInput => q + } + assert(queryStageInputs.length === 2) + assert(queryStageInputs(0).skewedPartitions === queryStageInputs(1).skewedPartitions) + assert(queryStageInputs(0).skewedPartitions === Some(Set(0))) + + val skewedQueryStageInputs = leftSemiJoin.queryExecution.executedPlan.collect { + case q: SkewedShuffleQueryStageInput => q + } + assert(skewedQueryStageInputs.length === MAX_SPLIT * 2) + + } + } + + test("row count statistics, compressed") { + val spark = defaultSparkSession + withSparkSession(spark) { spark: SparkSession => + spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "200") + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1") + + val df1 = + spark + .range(0, 105, 1, 1) + .select(when(col("id") < 100, 1).otherwise(col("id")).as("id")) + val df2 = df1.repartition(col("id")) + assert(df2.collect().length == 105) + + val siAfterExecution = df2.queryExecution.executedPlan.collect { + case si: ShuffleQueryStageInput => si + } + assert(siAfterExecution.length === 1) + + // MapStatus uses log base 1.1 on records to compress, + // after decompressing, it becomes to 106 + val stats = siAfterExecution.head.childStage.mapOutputStatistics + assert(stats.recordsByPartitionId.count(_ == 106) == 1) + } + } + + test("row count statistics, highly compressed") { + val spark = defaultSparkSession + withSparkSession(spark) { spark: SparkSession => + spark.sparkContext.conf.set(config.SHUFFLE_ACCURATE_BLOCK_RECORD_THRESHOLD.key, "20") + spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "2002") + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1") + + val df1 = + spark + .range(0, 105, 1, 1) + .select(when(col("id") < 100, 1).otherwise(col("id")).as("id")) + val df2 = df1.repartition(col("id")) + assert(df2.collect().length == 105) + + val siAfterExecution = df2.queryExecution.executedPlan.collect { + case si: ShuffleQueryStageInput => si + } + assert(siAfterExecution.length === 1) + + // MapStatus uses log base 1.1 on records to compress, + // after decompressing, it becomes to 106 + val stats = siAfterExecution.head.childStage.mapOutputStatistics + assert(stats.recordsByPartitionId.count(_ == 106) == 1) + } + } + + test("row count statistics, verbose is false") { + val spark = defaultSparkSession + withSparkSession(spark) { spark: SparkSession => + 
spark.sparkContext.conf.set(config.SHUFFLE_STATISTICS_VERBOSE.key, "false") + + val df1 = + spark + .range(0, 105, 1, 1) + .select(when(col("id") < 100, 1).otherwise(col("id")).as("id")) + val df2 = df1.repartition(col("id")) + assert(df2.collect().length == 105) + + val siAfterExecution = df2.queryExecution.executedPlan.collect { + case si: ShuffleQueryStageInput => si + } + assert(siAfterExecution.length === 1) + + val stats = siAfterExecution.head.childStage.mapOutputStatistics + assert(stats.recordsByPartitionId.isEmpty) + } + } + + test("Calculate local shuffle read partition ranges") { + val testArrays = Array( + Array(0L, 0, 1, 2, 0, 1, 2, 0), + Array(1L, 1, 0), + Array(0L, 1, 0), + Array(0L, 0), + Array(1L, 2, 3), + Array[Long]() + ) + val anserStart = Array( + Array(2, 5), + Array(0), + Array(1), + Array(0), + Array(0), + Array(0) + ) + val anserEnd = Array( + Array(4, 7), + Array(2), + Array(2), + Array(0), + Array(3), + Array(0) + ) + val func = OptimizeJoin(new SQLConf).calculatePartitionStartEndIndices _ + testArrays.zip(anserStart).zip(anserEnd).foreach { + case ((parameter, expectStart), expectEnd) => + val (resultStart, resultEnd) = func(parameter) + assert(resultStart.deep == expectStart.deep) + assert(resultEnd.deep == expectEnd.deep) + case _ => + } + } + + test("equally divide mappers in skewed partition") { + val handleSkewedJoin = HandleSkewedJoin(defaultSparkSession().sqlContext.conf) + val cases = Seq((0, 5), (4, 5), (15, 5), (16, 5), (17, 5), (18, 5), (19, 5), (20, 5)) + val expects = Seq( + Seq(0, 0, 0, 0, 0), + Seq(0, 1, 2, 3, 4), + Seq(0, 3, 6, 9, 12), + Seq(0, 4, 7, 10, 13), + Seq(0, 4, 8, 11, 14), + Seq(0, 4, 8, 12, 15), + Seq(0, 4, 8, 12, 16), + Seq(0, 4, 8, 12, 16)) + cases.zip(expects).foreach { case ((numElements, numBuckets), expect) => + val answer = handleSkewedJoin.equallyDivide(numElements, numBuckets) + assert(answer === expect) + } + } + + test("different pre-shuffle partition number") { + val spark = defaultSparkSession + import spark.implicits._ + val tName = "test" + scala.util.Random.nextInt(1000) + spark.sql(s"""CREATE table $tName (age INT, name STRING) + | USING parquet""".stripMargin) + val data: Seq[(Int, String)] = (1 to 2).map { i => (i, s"this is test $i") } + data.toDF("key", "value").createOrReplaceTempView("t") + spark.sql(s"insert overwrite table $tName select * from t") + + checkAnswer(spark.sql(s"select count($tName.age) from $tName group by $tName.name" + + s" union all select count($tName.age) from $tName"), + Row(1) :: Row(1) :: Row(2) :: Nil) + } + + test("different pre-shuffle partition number of datasets to union with adaptive") { + val sparkSession = defaultSparkSession + val dataset1 = sparkSession.range(1000) + val dataset2 = sparkSession.range(1001) + + val compute = dataset1.repartition(505, dataset1.col("id")) + .union(dataset2.repartition(105, dataset2.col("id"))) + + assert(compute.orderBy("id").toDF("id").takeAsList(10).toArray + === Seq((0), (0), (1), (1), (2), (2), (3), (3), (4), (4)).map(i => Row(i)).toArray) + } +} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index d9102cb846fda..0a51082def301 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index a4aa7f9601a92..05cdc5a3c5056 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 
2.4.1-kylin-r40 ../../pom.xml diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index e0ddaba7bea26..30450d976a8b8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.statsEstimation.Statistics import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.internal.SQLConf @@ -223,4 +224,9 @@ case class HiveTableScanExec( } override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) + + override def computeStats(): Statistics = { + val stats = relation.computeStats() + Statistics(stats.sizeInBytes) + } } diff --git a/streaming/pom.xml b/streaming/pom.xml index d99e91d0fe648..c624014c5c07d 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 166b0ad3006e6..13dc40057c250 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.4.1-kylin-r39 + 2.4.1-kylin-r40 ../pom.xml
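[Note] On the two-argument ExchangeCoordinator(targetSize, targetRowCount) constructor used throughout ExchangeCoordinatorSuite: the "let row count exceed the threshold" test suggests that coalescing now also stops when the accumulated post-shuffle row count would exceed the second argument, using the per-partition record counts that MapOutputStatistics carries when SHUFFLE_STATISTICS_VERBOSE is enabled. A hedged sketch for a single map stage (the function name and greedy rule are assumptions, not the code in ExchangeCoordinator):

def estimateStartIndices(
    bytesByPartitionId: Array[Long],
    rowsByPartitionId: Array[Long],
    targetSize: Long,
    targetRowCount: Long): Array[Int] = {
  require(bytesByPartitionId.nonEmpty && bytesByPartitionId.length == rowsByPartitionId.length)
  val starts = scala.collection.mutable.ArrayBuffer(0)
  var size = bytesByPartitionId(0)
  var rows = rowsByPartitionId(0)
  for (i <- 1 until bytesByPartitionId.length) {
    // Start a new post-shuffle partition as soon as either budget would be exceeded.
    if (size + bytesByPartitionId(i) > targetSize ||
        rows + rowsByPartitionId(i) > targetRowCount) {
      starts += i
      size = bytesByPartitionId(i)
      rows = rowsByPartitionId(i)
    } else {
      size += bytesByPartitionId(i)
      rows += rowsByPartitionId(i)
    }
  }
  starts.toArray
}

// Matches the "row count exceed the threshold" case: bytes = (1,1,1,1,1),
// rows = (120,20,90,1,20), both thresholds 100  =>  Array(0, 1, 2, 4).
estimateStartIndices(Array(1L, 1, 1, 1, 1), Array(120L, 20, 90, 1, 20), 100L, 100L)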
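[Note] The "equally divide mappers in skewed partition" test pins down how HandleSkewedJoin.equallyDivide is expected to split the mappers of a skewed partition into N roughly equal groups: the first numElements % numBuckets groups take one extra element, and the result holds each group's starting offset. A sketch consistent with the expected values in that test (an illustration, not the code in HandleSkewedJoin):

def equallyDivide(numElements: Int, numBuckets: Int): Seq[Int] = {
  val elementsPerBucket = numElements / numBuckets
  val remaining = numElements % numBuckets
  val splitPoint = (elementsPerBucket + 1) * remaining
  // The first `remaining` buckets hold (elementsPerBucket + 1) elements, the rest hold
  // elementsPerBucket; return the start offset of every bucket.
  (0 until remaining).map(_ * (elementsPerBucket + 1)) ++
    (remaining until numBuckets).map(i => splitPoint + (i - remaining) * elementsPerBucket)
}

// equallyDivide(17, 5) == Seq(0, 4, 8, 11, 14), matching the expected values in the test.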
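[Note] Similarly, the "Calculate local shuffle read partition ranges" test implies that OptimizeJoin.calculatePartitionStartEndIndices turns a per-partition row-count array into the [start, end) ranges of the non-empty stretches, falling back to a single degenerate (0, 0) range when everything is empty. A sketch that reproduces the expected arrays in the test (an assumption about the rule, not the actual implementation):

def calculatePartitionStartEndIndices(rowCounts: Array[Long]): (Array[Int], Array[Int]) = {
  val starts = scala.collection.mutable.ArrayBuffer.empty[Int]
  val ends = scala.collection.mutable.ArrayBuffer.empty[Int]
  var i = 0
  while (i < rowCounts.length) {
    if (rowCounts(i) == 0) {
      i += 1 // skip empty partitions entirely
    } else {
      val start = i
      while (i < rowCounts.length && rowCounts(i) != 0) i += 1
      starts += start // a contiguous run of non-empty partitions becomes one read range
      ends += i       // exclusive
    }
  }
  if (starts.isEmpty) (Array(0), Array(0)) else (starts.toArray, ends.toArray)
}

// calculatePartitionStartEndIndices(Array(0L, 0, 1, 2, 0, 1, 2, 0))
//   == (Array(2, 5), Array(4, 7)), as expected by the test.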
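[Note] The new SparkListenerSQLAdaptiveExecutionUpdate event in SQLListener.scala carries the re-planned physical plan of a running adaptive execution. Outside the SQL UI it can be observed on the regular listener bus; a minimal sketch of such a consumer (the listener class and its logging are illustrative, only the event and its fields come from this diff):

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate

class AdaptivePlanUpdateLogger extends SparkListener {
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case SparkListenerSQLAdaptiveExecutionUpdate(executionId, planDescription, _) =>
      // Log every plan update posted for this execution id.
      println(s"Adaptive execution $executionId updated plan:\n$planDescription")
    case _ => // not interested in other events
  }
}

// spark.sparkContext.addSparkListener(new AdaptivePlanUpdateLogger)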