From 3dade7782a88a04caf34d8ff556b0fa08ceeed3d Mon Sep 17 00:00:00 2001 From: Guo Chenzhao Date: Tue, 30 Jun 2020 09:13:04 +0800 Subject: [PATCH] [remote-shuffle]Remote shuffle manager for spark3.0 (#1356) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add RemoteShuffle codebase to OAP (#1156) * Initial commit * Add pom * Update ignore * Add basic components for remote shuffle writing * Add ExternalSorter for writing to HDFS * Update actual writing class RemoteBlockObjectWriter, update related interfaces to RemoteBOW * Update ShuffleResolver to write index file and commit * Spill to remote storage * Add RemoteExternalSorter test suite * Test RemoteExternalSorter writer to HDFS * Write as .index, .data * Fix minor bugs * Add tests for RemoteShuffleBlockResolver * General remote shuffle reader * Test getBlockData in Resolver * Test HadoopFileSegmentManagedBuffer * Refactor Resolver and test suite * Fix: check existence first * Test actual reading iterator * Fix appId early getting, add basic RDD shuffle operation test * Fix bug in the condition of empty map output data file, add tests to ensure this * Introduce classes for optimized shuffle writing * Optimized shuffle writer path & tests * Optimized path configurable and refactor * Introduce BypassMergeSortShuffleWriter * Implement bypass mergesort path & tests * Refactor: move HDFS connection from Utils to Resolver, add RemoteShuffleConf * Introduce RemoteAggregator and related classes, refactor RemoteSorter * Aggregator spill to remote storage, add tests for RemoteAppendOnlyMap * Fix: No closing after copying streams * Hardcode using Hadoop 2.7, truncate half-written content when an exception occurs, add tests for BlockObjectWriter * Fix test suite, test shuffle reader should read by block * Avoid overriding Spark classes to make default shuffle manager still work, and other refactors * Fix wrong importing, make more classes not override Spark code * Make storage master and root directory configurable * Properly get appId while running on distributed env * Lazy evaluation for getting SparkEnv vars * Add a remote bypass-merge threshold conf * Assemble Hadoop Configuration from SparkConf ++ else, instead of loading local default * Fix * Use SortShuffle's block iterator framework including shuffle blocks pre-fetch * Not loading any default config from files, and more objects reuse * Make replica configurable * Rename to ShuffleRemoteSorter * Fix: use RemoteSorter instead of ExternalSorter * Introduce DAGScheduler * With executors lost, no need to rerun map tasks thanks to remote shuffle * Require remote shuffle and external shuffle service not be enabled at the same time * When index cache enabled, fetch index files from executors who wrote them * Read index from Guava cache * UT doesn't rely on external systems * Add travis support * add code for read/write metrics (#5) * update read/write metrics * write/read metrics feature completed * Delete compile.sh * metrics pr * metrics pr * add code about read/write metrics * add codes about shuffle read/write * add codes about shuffle read/write * remove work file * Fix wrong offset and length (#6) * Fix NettyBlockRpcServer: only cast type when remote shuffle enabled * Add micro-benchmark for shuffle writers/reader (#3) * Add SortShuffleWriterBenchmark to compare SortShuffle and RemoteShuffle interfaces * Update travis * Fix * Add other 2 writers' benchmark * Add reader micro-benchmark * Multiple stages in Travis to avoid timeout * Post benchmark results as PR comments *
Fix * Debug * Debug * Fix * Beautify * Minor fix * Some renames for better understanding * Style * spark reads hadoop conf remotely (#8) ### What changes were proposed in this pull request? Originally RemoteShuffle loaded an empty Hadoop configuration by `val hadoopConf = new Configuration(false)`. However, Hadoop configuration needs to be loaded remotely. Some work is done in this pull request. ### How was this patch tested? By a new unit test in `org.apache.spark.shuffle.remote.RemoteShuffleManagerSuite` where a fake server is mocked to provide Hadoop configuration remotely. * Docs (#19) Add configuration and tuning guides. * remove remain/release in RemoteShuffleBlockIterator (#23) The concrete buffer implementation of ManagedBuffer might be managed outside the JVM garbage collector. If the buffer is going to be passed around to a different thread, retain/release should be called. But in RemoteShuffle, HadoopManagedBuffer is used, and it's definitely inside a JVM's lifecycle, so we don't need these operations. * Read DAOS conf from local * check style when compiling (#24) Add scala style check * Remove extra input stream layer, not needed because no buffer releasing (#25) Extra layer brings overhead. * All skip -> seek * More tests on ShuffleManager, UTs on read Iterator covering index cache enabled path * Data file asynchronous pre-fetch from multiple sources (#30) This PR resolves #16, improving shuffle read performance by asynchronously reading whole ShuffleBlock requests to memory (and then performing later operations) & constraining the number of reading requests in flight. In the reduce stage, we observed threads blocking for a long time waiting for remote I/O to be ready. An optimization resembling vanilla Spark's can be made: send multiple block reading requests asynchronously before we actually need the data for compute, put the fetched shuffle blocks in a queue, and let the subsequent compute take whichever block is ready first. Constrain the requests in flight by maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress (these 3 are identical to vanilla Spark) and maxConcurrentFetches (introduced, for the maximum number of data file reading threads) More tests with bigger datasets, different map side partition lengths, index cache enabled/disabled, and constraints set/unset. * Refactor & style * Put index information in cache in map stage to avoid loading from storage in reduce stage (#33) * Put index info in cache in map stage if index cache is enabled * Refactor * Fix * Fix: Failing to fetch remote HDFS configurations should not crash the app (#36) Minor fix to avoid exceptions originating from 2 causes under HDFS: 1) port unset, 2) connection failed. * Add corruption detect (#34) * Add corruption detect * Throw Exception only in task threads * Only retry the failed map tasks * Fix unsafe shuffle writer (#39) Part of #37 When memory is insufficient and spill happens, the outputs produced by the unsafe shuffle writer are wrong. It's due to a bug in mergeSpillsWithTransferTo: the length parameter was missed during stream copying. Actually this merge path doesn't apply in remote shuffle over Hadoop storage, because the NIO-based transferTo optimization may not exist. Added unit tests to ensure the correctness. * Add UTs for RemoteSorter (#40) Ensure RemoteSorter correctness.
* Shuffle read metrics update even after cleaning up (#42) * Shuffle read metrics update even after cleaning up * Style * Not overriding Spark source code for better compatibility (#44) * Not overriding Spark source code for better compatibility * Fix: RpcEnv is not set in Executor * Test fix * Implement close * Catch and log Exception during RemoteShuffleTransferService's closing * Remove benchmarker * Remove the logic that will never go through under the customized TransferService, throw Exception in those branches * Set numCores using reflection, get from Dispatcher * Move package * Adding back benchmark * Style and comments * Remove reflection, let a config determine the number of threads for the new transfer service * Not reading hdfs-site.xml when storage is DAOS * Move repository * Move repository Co-authored-by: Shuaiqi Ge <35885772+BestOreo@users.noreply.github.com> * Integrate remote-shuffle in CI & more docs (#1167) * CI * Remove subdir travis * Docs * More docs * Separate travis tests to different stages * Fix * Introduce new performance evaluation tool and deprecate the old micro-benchmark (#1172) * [remote-shuffle]Refactor (#1206) * Refactor * Docs * [remote-shuffle]Add docs for performance evaluation tool (#1233) * Allow producing a test jar with dependencies, refactor * Support -h help * Add docs * Disable hash-based shuffle writer by default (#1239) * Reuse file handle in reduce stage (#1234) * Remove perf evaluation tool * Update: scheduler in Spark 3.0 * Basic update for Spark3.0, updated ShuffleManager and related code * Upper level batch fetch support, full custom metrics support (by Reynold) * Update readme * Modify Travis, empty install * Modify docs Co-authored-by: Shuaiqi Ge <35885772+BestOreo@users.noreply.github.com> --- .travis.yml | 6 + oap-shuffle/remote-shuffle/.gitignore | 4 + oap-shuffle/remote-shuffle/LICENSE | 202 ++ oap-shuffle/remote-shuffle/README.md | 131 +- oap-shuffle/remote-shuffle/dev/checkstyle.xml | 189 ++ .../remote-shuffle/dev/post_results_to_PR.sh | 26 + oap-shuffle/remote-shuffle/pom.xml | 301 +++ .../remote-shuffle/scalastyle-config.xml | 387 +++ .../shuffle/MyOneForOneBlockFetcher.java | 113 + .../RemoteBypassMergeSortShuffleWriter.java | 271 ++ .../spark/shuffle/sort/RemoteSpillInfo.java | 37 + .../sort/RemoteUnsafeShuffleSorter.java | 442 ++++ .../sort/RemoteUnsafeShuffleWriter.java | 478 ++++ .../network/netty/MyNettyBlockRpcServer.scala | 92 + .../netty/RemoteShuffleTransferService.scala | 158 ++ .../apache/spark/scheduler/DAGScheduler.scala | 2239 +++++++++++++++++ .../HadoopFileSegmentManagedBuffer.scala | 205 ++ .../shuffle/remote/RemoteAggregator.scala | 66 + .../remote/RemoteBlockObjectWriter.scala | 295 +++ .../remote/RemoteShuffleBlockIterator.scala | 550 ++++ .../remote/RemoteShuffleBlockResolver.scala | 388 +++ .../shuffle/remote/RemoteShuffleConf.scala | 112 + .../shuffle/remote/RemoteShuffleManager.scala | 257 ++ .../shuffle/remote/RemoteShuffleReader.scala | 151 ++ .../shuffle/remote/RemoteShuffleUtils.scala | 88 + .../shuffle/remote/RemoteShuffleWriter.scala | 108 + .../RPartitionedAppendOnlyMap.scala | 45 + .../collection/RPartitionedPairBuffer.scala | 106 + .../RWritablePartitionedPairCollection.scala | 113 + .../util/collection/RemoteAppendOnlyMap.scala | 643 +++++ .../spark/util/collection/RemoteSorter.scala | 848 +++++++ .../sort/RemoteUnsafeShuffleWriterSuite.java | 572 +++++ .../RemoteShuffleBlockIteratorSuite.scala | 259 ++ .../RemoteShuffleBlockObjectWriterSuite.scala | 193 ++ .../RemoteShuffleBlockResolverSuite.scala | 246
++ .../remote/RemoteShuffleManagerSuite.scala | 183 ++ .../apache/spark/shuffle/remote/package.scala | 35 + .../collection/RemoteAppendOnlyMapSuite.scala | 575 +++++ .../util/collection/RemoteSorterSuite.scala | 696 +++++ .../spark/util/collection/package.scala | 32 + .../test-jar-with-dependencies.xml | 19 + 41 files changed, 11859 insertions(+), 2 deletions(-) create mode 100644 oap-shuffle/remote-shuffle/.gitignore create mode 100644 oap-shuffle/remote-shuffle/LICENSE create mode 100644 oap-shuffle/remote-shuffle/dev/checkstyle.xml create mode 100644 oap-shuffle/remote-shuffle/dev/post_results_to_PR.sh create mode 100644 oap-shuffle/remote-shuffle/pom.xml create mode 100644 oap-shuffle/remote-shuffle/scalastyle-config.xml create mode 100644 oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/network/shuffle/MyOneForOneBlockFetcher.java create mode 100644 oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteBypassMergeSortShuffleWriter.java create mode 100644 oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteSpillInfo.java create mode 100644 oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleSorter.java create mode 100644 oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriter.java create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/network/netty/MyNettyBlockRpcServer.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/network/netty/RemoteShuffleTransferService.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/HadoopFileSegmentManagedBuffer.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteAggregator.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteBlockObjectWriter.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockIterator.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockResolver.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleConf.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleManager.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleReader.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleUtils.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleWriter.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RPartitionedAppendOnlyMap.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RPartitionedPairBuffer.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RWritablePartitionedPairCollection.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RemoteAppendOnlyMap.scala create mode 100644 oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RemoteSorter.scala create mode 100644 
oap-shuffle/remote-shuffle/src/test/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriterSuite.java create mode 100644 oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockIteratorSuite.scala create mode 100644 oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockObjectWriterSuite.scala create mode 100644 oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockResolverSuite.scala create mode 100644 oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleManagerSuite.scala create mode 100644 oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/package.scala create mode 100644 oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/RemoteAppendOnlyMapSuite.scala create mode 100644 oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/RemoteSorterSuite.scala create mode 100644 oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/package.scala create mode 100644 oap-shuffle/remote-shuffle/test-jar-with-dependencies.xml diff --git a/.travis.yml b/.travis.yml index e03bc0a78..fa64f55b8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -55,3 +55,9 @@ jobs: - mvn clean -q package -DskipTests #skip core tests - cd ${TRAVIS_BUILD_DIR}/oap-data-source/arrow - mvn clean -q test + - name: oap-shuffle-remote-shuffle + install: + - #empty install step + script: + - cd ${TRAVIS_BUILD_DIR}/oap-shuffle/remote-shuffle/ + - mvn -q test diff --git a/oap-shuffle/remote-shuffle/.gitignore b/oap-shuffle/remote-shuffle/.gitignore new file mode 100644 index 000000000..4005c21c8 --- /dev/null +++ b/oap-shuffle/remote-shuffle/.gitignore @@ -0,0 +1,4 @@ +target/ +benchmarks/ +.idea/ +*.iml diff --git a/oap-shuffle/remote-shuffle/LICENSE b/oap-shuffle/remote-shuffle/LICENSE new file mode 100644 index 000000000..57bc88a15 --- /dev/null +++ b/oap-shuffle/remote-shuffle/LICENSE @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + diff --git a/oap-shuffle/remote-shuffle/README.md b/oap-shuffle/remote-shuffle/README.md index 1c64e955b..2dda86d26 100644 --- a/oap-shuffle/remote-shuffle/README.md +++ b/oap-shuffle/remote-shuffle/README.md @@ -1,3 +1,130 @@ -TODO: Introduction to this module. - +# Spark Remote Shuffle Plugin +Remote Shuffle is a Spark ShuffleManager plugin, shuffling data through a remote Hadoop-compatible file system, as opposed to vanilla Spark's local-disks. + +This is an essential part of enabling Spark on disaggregated compute and storage architecture. + +## Build and Deploy + +Build the project using the following command or download the pre-built jar: remote-shuffle-\.jar. This file needs to +be deployed on every compute node that runs Spark. Manually place it on all nodes or let resource manager do the work. 
+ +``` + mvn -DskipTests clean package + ``` + +## Enable Remote Shuffle + +Add the jar file to the classpath of the Spark driver and executors: put the +following configurations in spark-defaults.conf or in the Spark submit command line arguments. + +Note: For DAOS users, DAOS Hadoop/Java API jars should also be included in the classpath, as we leverage the DAOS Hadoop filesystem. + +``` + spark.executor.extraClassPath /path/to/remote-shuffle-dir/remote-shuffle-<version>.jar + spark.driver.extraClassPath /path/to/remote-shuffle-dir/remote-shuffle-<version>.jar + ``` + +Enable the remote shuffle manager and specify the Hadoop storage system URI holding shuffle data. + +``` + spark.shuffle.manager org.apache.spark.shuffle.remote.RemoteShuffleManager + spark.shuffle.remote.storageMasterUri daos://default:1 # Or hdfs://namenode:port, file:///my/shuffle/dir + ``` + +## Configurations + +Configurations and tuning parameters that change the behavior of remote shuffle. Most of them should work well with default values. + +### Shuffle Root Directory + +This is to configure the root directory holding remote shuffle files. For each Spark application, a +directory named after the application ID is created under this root directory. + +``` + spark.shuffle.remote.filesRootDirectory /shuffle + ``` + +### Index Cache Size + +This is to configure the cache size for shuffle index files per executor. Shuffle data includes data files and +index files. An index file is small but will be read many (the number of reducers) times. On a large scale, constantly +reading these small index files from the Hadoop filesystem implementation (e.g. HDFS) causes significant overhead and latency. In addition, the shuffle files’ +transfer relies completely on the network between compute nodes and storage nodes, while the network among compute nodes is +not fully utilized. The index cache can eliminate the overhead of reading index files from the storage cluster multiple times. With +the index file cache enabled, a reduce task fetches index files from the remote executors that wrote them instead of reading from +storage. If the remote executor doesn’t have the desired index file in its cache, it will read the file from storage and cache +it locally. The feature can also be disabled by setting the value to zero. + +``` + spark.shuffle.remote.index.cache.size 30m + ``` + +### Number of Threads Reading Data Files + +This is one of the parameters influencing shuffle read performance. It determines the number of threads per executor reading shuffle data files from storage. + +``` + spark.shuffle.remote.numReadThreads 5 + ``` + +### Number of Threads Transferring Index Files (when index cache is enabled) + +This is one of the parameters influencing shuffle read performance. It determines the number of client and server threads that transmit index information from another executor’s cache. It is only valid when the index cache feature is enabled. + +``` + spark.shuffle.remote.numIndexReadThreads 3 + ``` + +### Bypass-merge-sort Threshold + +This threshold is used to decide whether to use the bypass-merge (hash-based) shuffle writer. By default we disable it (by setting this threshold to -1) +in remote shuffle, because when memory is relatively sufficient, the sort-based shuffle writer is often more efficient than the hash-based one.
+The hash-based shuffle writer entails a merging process, performing 3x the I/O of the total shuffle size: 1x for reads and 2x for writes. This can be an even larger overhead under remote shuffle: +the 3x shuffle size goes through the network and arrives at a remote storage system. + +``` + spark.shuffle.remote.bypassMergeThreshold -1 + ``` + +### Configuration Fetching Port for HDFS + +When the backend storage is HDFS, we contact http://$host:$port/conf to fetch configurations. They are not loaded locally because we assume the absence of local storage. + +``` + spark.shuffle.remote.hdfs.storageMasterUIPort 50070 + ``` + +### Inherited Spark Shuffle Configurations + +These configurations are inherited from upstream Spark and are still supported in remote shuffle. More explanations can be found in [Spark core docs](https://spark.apache.org/docs/3.0.0/configuration.html#shuffle-behavior) and [Spark SQL docs](https://spark.apache.org/docs/3.0.0/sql-performance-tuning.html). +``` + spark.reducer.maxSizeInFlight + spark.reducer.maxReqsInFlight + spark.reducer.maxBlocksInFlightPerAddress + spark.shuffle.compress + spark.shuffle.file.buffer + spark.shuffle.io.maxRetries + spark.shuffle.io.numConnectionsPerPeer + spark.shuffle.io.preferDirectBufs + spark.shuffle.io.retryWait + spark.shuffle.io.backLog + spark.shuffle.spill.compress + spark.shuffle.accurateBlockThreshold + spark.sql.shuffle.partitions +``` + +### Deprecated Spark Shuffle Configurations + +These configurations are deprecated and will not take effect. +``` + spark.shuffle.sort.bypassMergeThreshold # Replaced by spark.shuffle.remote.bypassMergeThreshold + spark.maxRemoteBlockSizeFetchToMem # As we assume no local disks on compute nodes, shuffle blocks are all fetched to memory + + spark.shuffle.service.enabled # All following configurations are related to the External Shuffle Service. ESS & remote shuffle cannot be enabled at the same time, as this remote shuffle facility takes over almost all functionalities of ESS.
+ spark.shuffle.service.port + spark.shuffle.service.index.cache.size + spark.shuffle.maxChunksBeingTransferred + spark.shuffle.registration.timeout + spark.shuffle.registration.maxAttempts +``` diff --git a/oap-shuffle/remote-shuffle/dev/checkstyle.xml b/oap-shuffle/remote-shuffle/dev/checkstyle.xml new file mode 100644 index 000000000..127ab81ca --- /dev/null +++ b/oap-shuffle/remote-shuffle/dev/checkstyle.xml @@ -0,0 +1,189 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/oap-shuffle/remote-shuffle/dev/post_results_to_PR.sh b/oap-shuffle/remote-shuffle/dev/post_results_to_PR.sh new file mode 100644 index 000000000..7a2b43a37 --- /dev/null +++ b/oap-shuffle/remote-shuffle/dev/post_results_to_PR.sh @@ -0,0 +1,26 @@ +USERNAME=benchmarker-RemoteShuffle +PASSWORD=$BENCHMARKER_PASSWORD +PULL_REQUEST_NUM=$TRAVIS_PULL_REQUEST + +READ_OR_WRITE=$1 + +RESULTS="" +for benchmark_file in benchmarks/*${READ_OR_WRITE}*; do + echo $benchmark_file + RESULTS+=$(cat $benchmark_file) + RESULTS+=$'\n\n' +done + +echo "$RESULTS" + +message='{"body": "```' +message+='\n' +message+="$RESULTS" +message+='\n' +json_message=$(echo "$message" | awk '{printf "%s\\n", $0}') +json_message+='```", "event":"COMMENT"}' +echo "$json_message" > benchmark_results.json + +echo "Sending benchmark requests to PR $PULL_REQUEST_NUM" +curl -XPOST https://${USERNAME}:${PASSWORD}@api.github.com/repos/Intel-bigdata/RemoteShuffle/pulls/${PULL_REQUEST_NUM}/reviews -d @benchmark_results.json +rm benchmark_results.json diff --git a/oap-shuffle/remote-shuffle/pom.xml b/oap-shuffle/remote-shuffle/pom.xml new file mode 100644 index 000000000..0bdeff4f3 --- /dev/null +++ b/oap-shuffle/remote-shuffle/pom.xml @@ -0,0 +1,301 @@ + + + + + 4.0.0 + + org.apache.spark + remote-shuffle + 0.1-SNAPSHOT + Spark Remote Shuffle Plugin + jar + + + + Chenzhao Guo + chenzhao.guo@intel.com + + + + + 2.12.10 + 2.12 + 3.0.0 + 1.8 + UTF-8 + UTF-8 + + + + + org.scala-lang + scala-library + ${scala.version} + + + org.apache.spark + spark-core_2.12 + ${spark.version} + + + org.apache.spark + spark-core_2.12 + ${spark.version} + test-jar + test + + + org.apache.hadoop + hadoop-client + 2.7.4 + + + org.scalatest + scalatest_${scala.binary.version} + test + 3.0.3 + + + junit + junit + 4.12 + test + + + org.hamcrest + hamcrest-core + 1.3 + test + + + org.hamcrest + hamcrest-library + 1.3 + test + + + org.mockito + mockito-core + 2.23.4 + test + + + org.mock-server + mockserver-netty + 5.6.0 + + + org.mock-server + mockserver-client-java + 5.6.0 + + + org.eclipse.jetty + jetty-servlet + 9.4.12.v20180830 + test + + + commons-cli + commons-cli + 1.4 + test + + + + + + + net.alchim31.maven + scala-maven-plugin + 3.4.4 + + + scala-compile-first + process-resources + + add-source + compile + + + + scala-test-compile + process-test-resources + + testCompile + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 2.0.2 + + ${java.version} + ${java.version} + + + + compile + + compile + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.7 + + + **/Test*.java + **/*Test.java + **/*TestCase.java + **/*Suite.java + + + 1 + ${project.basedir} + + + 1 + ${scala.binary.version} + + + + + org.scalatest + scalatest-maven-plugin + 1.0 + + . 
+ WDF TestSuite.txt + + 1 + ${project.basedir} + + + 1 + ${scala.binary.version} + + + + + test + + test + + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.0.2 + + + + test-jar + + + + + + org.apache.maven.plugins + maven-assembly-plugin + 3.2.0 + + + test-jar-with-dependencies.xml + + + + + package + + single + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 3.0.0 + + false + true + + src/main/java + src/main/scala + + + src/test/java + src/test/scala + + dev/checkstyle.xml + ${basedir}/target/checkstyle-output.xml + ${project.build.sourceEncoding} + ${project.reporting.outputEncoding} + + + + compile + + check + + + + + + + + + diff --git a/oap-shuffle/remote-shuffle/scalastyle-config.xml b/oap-shuffle/remote-shuffle/scalastyle-config.xml new file mode 100644 index 000000000..d46d8acb1 --- /dev/null +++ b/oap-shuffle/remote-shuffle/scalastyle-config.xml @@ -0,0 +1,387 @@ + + + + + Scalastyle standard configuration + + + + + + + + + + + + + + + + + + + ^FunSuite[A-Za-z]*$ + Tests must extend org.apache.spark.SparkFunSuite instead. + + + + + ^println$ + + + + + spark(.sqlContext)?.sparkContext.hadoopConfiguration + + + + + @VisibleForTesting + + + + + Runtime\.getRuntime\.addShutdownHook + + + + + mutable\.SynchronizedBuffer + + + + + Class\.forName + + + + + Await\.result + + + + + Await\.ready + + + + + (\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\))) + + + + + throw new \w+Error\( + + + + + + JavaConversions + Instead of importing implicits in scala.collection.JavaConversions._, import + scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + + org\.apache\.commons\.lang\. + Use Commons Lang 3 classes (package org.apache.commons.lang3.*) instead + of Commons Lang 2 (package org.apache.commons.lang.*) + + + + extractOpt + Use jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter + is slower. + + + + + java,scala,3rdParty,spark + javax?\..* + scala\..* + (?!org\.apache\.spark\.).* + org\.apache\.spark\..* + + + + + + COMMA + + + + + + \)\{ + + + + + (?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*] + Use Javadoc style indentation for multiline comments + + + + case[^\n>]*=>\s*\{ + Omit braces in case clauses. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 800> + + + + + 30 + + + + + 10 + + + + + 50 + + + + + + + + + + + -1,0,1,2,3 + + + diff --git a/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/network/shuffle/MyOneForOneBlockFetcher.java b/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/network/shuffle/MyOneForOneBlockFetcher.java new file mode 100644 index 000000000..8aaa12f0b --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/network/shuffle/MyOneForOneBlockFetcher.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; + +import io.netty.buffer.Unpooled; +import org.apache.spark.shuffle.remote.HadoopFileSegmentManagedBuffer; +import org.apache.spark.shuffle.remote.MessageForHadoopManagedBuffers; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import scala.Tuple2; + +/** + * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and + * invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC + * handler, as long as there is a single "open blocks" message which returns a ShuffleStreamHandle, + * and Java serialization is used. + * + * Note that this typically corresponds to a + * {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side. + */ +public class MyOneForOneBlockFetcher { + private static final Logger logger = LoggerFactory.getLogger(MyOneForOneBlockFetcher.class); + + private final TransportClient client; + private final OpenBlocks openMessage; + private final String[] blockIds; + private final BlockFetchingListener listener; + + public MyOneForOneBlockFetcher( + TransportClient client, + String appId, + String execId, + String[] blockIds, + BlockFetchingListener listener) { + this.client = client; + this.openMessage = new OpenBlocks(appId, execId, blockIds); + this.blockIds = blockIds; + this.listener = listener; + } + + /** + * Begins the fetching process, calling the listener with every block fetched. + * The given message will be serialized with the Java serializer, and the RPC must return a + * {@link StreamHandle}. We will send all fetch requests immediately, without throttling. + */ + public void start() { + if (blockIds.length == 0) { + throw new IllegalArgumentException("Zero-sized blockIds array"); + } + + client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + try { + boolean isShuffleRequest = + (response.get(0) == MessageForHadoopManagedBuffers.MAGIC_CODE()); + if (isShuffleRequest) { + MessageForHadoopManagedBuffers message = + MessageForHadoopManagedBuffers.fromByteBuffer(Unpooled.wrappedBuffer(response)); + for (Tuple2 entry: message.buffers()) { + listener.onBlockFetchSuccess(entry._1, entry._2); + } + } else { + throw new UnsupportedOperationException("MyNettyOneForOneBlockFetcher got an " + + "unexpected response, which is not from RemoteShuffle dedicated TransferService"); + } + } catch (Exception e) { + logger.error("Failed while starting block fetches after success", e); + failRemainingBlocks(blockIds, e); + } + } + + @Override + public void onFailure(Throwable e) { + logger.error("Failed while starting block fetches", e); + failRemainingBlocks(blockIds, e); + } + }); + } + + /** Invokes the "onBlockFetchFailure" callback for every listed block id. 
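Exceptions thrown by an individual callback are caught and logged so that the remaining block ids are still notified.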
*/ + private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { + for (String blockId : failedBlockIds) { + try { + listener.onBlockFetchFailure(blockId, e); + } catch (Exception e2) { + logger.error("Error in block fetch failure callback", e2); + } + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteBypassMergeSortShuffleWriter.java b/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteBypassMergeSortShuffleWriter.java new file mode 100644 index 000000000..71934d106 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteBypassMergeSortShuffleWriter.java @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.IOException; + +import javax.annotation.Nullable; + +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import scala.None$; +import scala.Option; +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.Closeables; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.*; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.remote.*; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TempShuffleBlockId; +import org.apache.spark.util.Utils; + +/** + * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path + * writes incoming records to separate files, one file per reduce partition, then concatenates these + * per-partition files to form a single output file, regions of which are served to reducers. + * Records are not buffered in memory. It writes output in a format + * that can be served / consumed via + * {@link org.apache.spark.shuffle.remote.RemoteShuffleBlockResolver}. + *
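+ * In this remote variant, the per-partition files and the final combined output live on a Hadoop-compatible file system rather than on local disk.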

+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it + * simultaneously opens separate serializers and file streams for all partitions. As a result, + * {@link SortShuffleManager} only selects this write path when + *

+ *  - no Ordering is specified,
+ *  - no Aggregator is specified, and
+ *  - the number of partitions is less than spark.shuffle.sort.bypassMergeThreshold.
+ * + * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was + * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details. + *

+ * There have been proposals to completely remove this code path; see SPARK-6026 for details. + */ +public final class RemoteBypassMergeSortShuffleWriter extends ShuffleWriter { + + private final SerializerManager serializerManager = SparkEnv.get().serializerManager(); + + private static final Logger logger = LoggerFactory.getLogger( + RemoteBypassMergeSortShuffleWriter.class); + + static { + logger.warn("******** Bypass-Merge-Sort Remote Shuffle Writer is used ********"); + } + + private final int fileBufferSize; + private final boolean transferToEnabled; + private final int numPartitions; + private final BlockManager blockManager; + private final Partitioner partitioner; + private final ShuffleWriteMetricsReporter writeMetrics; + private final int shuffleId; + private final long mapId; + private final Serializer serializer; + private final RemoteShuffleBlockResolver shuffleBlockResolver; + + /** Array of file writers, one for each partition */ + private RemoteBlockObjectWriter[] partitionWriters; + private HadoopFileSegment[] partitionWriterSegments; + @Nullable private MapStatus mapStatus; + private long[] partitionLengths; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. + */ + private boolean stopping = false; + + public RemoteBypassMergeSortShuffleWriter( + BlockManager blockManager, + RemoteShuffleBlockResolver shuffleBlockResolver, + BypassMergeSortShuffleHandle handle, + long mapId, + TaskContext taskContext, + SparkConf conf, + ShuffleWriteMetricsReporter metrics) { + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); + this.blockManager = blockManager; + final ShuffleDependency dep = handle.dependency(); + this.mapId = mapId; + this.shuffleId = dep.shuffleId(); + this.partitioner = dep.partitioner(); + this.numPartitions = partitioner.numPartitions(); + this.writeMetrics = metrics; + this.serializer = dep.serializer(); + this.shuffleBlockResolver = shuffleBlockResolver; + } + + @Override + public void write(Iterator> records) throws IOException { + assert (partitionWriters == null); + if (!records.hasNext()) { + partitionLengths = new long[numPartitions]; + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); + mapStatus = MapStatus$.MODULE$.apply( + RemoteShuffleManager$.MODULE$.getResolver().shuffleServerId(), partitionLengths, mapId); + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new RemoteBlockObjectWriter[numPartitions]; + partitionWriterSegments = new HadoopFileSegment[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + shuffleBlockResolver.createTempShuffleBlock(); + final Path file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + RemoteShuffleUtils.getRemoteWriter( + blockId, file, serializerManager, serInstance, fileBufferSize, writeMetrics); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so 
should be + // included in the shuffle write time. + writeMetrics.incWriteTime(System.nanoTime() - openStartTime); + + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + } + + for (int i = 0; i < numPartitions; i++) { + final RemoteBlockObjectWriter writer = partitionWriters[i]; + partitionWriterSegments[i] = writer.commitAndGet(); + writer.close(); + } + + Path output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + Path tmp = RemoteShuffleUtils.tempPathWith(output); + try { + partitionLengths = writePartitionedFile(tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + } finally { + FileSystem fs = RemoteShuffleManager.getFileSystem(); + if (fs.exists(tmp) && !fs.delete(tmp, true)) { + logger.error("Error while deleting temp file {}", tmp.toString()); + } + } + mapStatus = MapStatus$.MODULE$.apply( + RemoteShuffleManager$.MODULE$.getResolver().shuffleServerId(), partitionLengths, mapId); + } + + @VisibleForTesting + long[] getPartitionLengths() { + return partitionLengths; + } + + /** + * Concatenate all of the per-partition files into a single combined file. + * + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). + */ + private long[] writePartitionedFile(Path outputFile) throws IOException { + final FileSystem fs = RemoteShuffleManager.getFileSystem(); + // Track location of the partition starts in the output file + final long[] lengths = new long[numPartitions]; + if (partitionWriters == null) { + // We were passed an empty iterator + return lengths; + } + + final FSDataOutputStream out = fs.create(outputFile); + final long writeStartTime = System.nanoTime(); + boolean threwException = true; + try { + for (int i = 0; i < numPartitions; i++) { + final Path file = partitionWriterSegments[i].file(); + if (fs.exists(file)) { + final FSDataInputStream in = fs.open(file); + boolean copyThrewException = true; + try { + lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + if (!fs.delete(file, true)) { + logger.error("Unable to delete file for partition {}", i); + } + } + } + threwException = false; + } finally { + Closeables.close(out, threwException); + writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); + } + partitionWriters = null; + return lengths; + } + + @Override + public Option stop(boolean success) { + if (stopping) { + return None$.empty(); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. 
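+ // Each writer reverts its partial writes, and its temporary file is then deleted from the remote file system.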
+ if (partitionWriters != null) { + try { + FileSystem fs = null; + for (RemoteBlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + Path file = writer.revertPartialWritesAndClose(); + fs = RemoteShuffleManager.getFileSystem(); + if (!fs.delete(file, true)) { + logger.error("Error while deleting file {}", file.toString()); + } + } + } catch (IOException e) { + e.printStackTrace(); + } finally { + partitionWriters = null; + } + } + return None$.empty(); + } + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteSpillInfo.java b/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteSpillInfo.java new file mode 100644 index 000000000..61e5084f1 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteSpillInfo.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import org.apache.hadoop.fs.Path; + +import org.apache.spark.storage.TempShuffleBlockId; + +/** + * Metadata for a block of data written by {@link RemoteUnsafeShuffleSorter}. + */ +final class RemoteSpillInfo { + final long[] partitionLengths; + final Path file; + final TempShuffleBlockId blockId; + + RemoteSpillInfo(int numPartitions, Path file, TempShuffleBlockId blockId) { + this.partitionLengths = new long[numPartitions]; + this.file = file; + this.blockId = blockId; + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleSorter.java b/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleSorter.java new file mode 100644 index 000000000..56ea7a80b --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleSorter.java @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort; + +import java.io.IOException; +import java.util.LinkedList; + +import javax.annotation.Nullable; + +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import scala.Tuple2; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.internal.config.package$; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TooLargePageException; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.shuffle.remote.*; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TempShuffleBlockId; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.util.Utils; + +/** + * An external sorter that is specialized for sort-based shuffle. + *
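+ * In this remote variant, spill files are Hadoop file system paths (see RemoteSpillInfo) rather than local files.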

+ * Incoming records are appended to data pages. When all records have been inserted (or when the + * current thread's shuffle memory limit is reached), the in-memory records are sorted according to + * their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then + * written to a single output file (or multiple files, if we've spilled). The format of the output + * files is the same as the format of the final output file written by + * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are + * written as a single serialized, compressed stream that can be read with a new decompression and + * deserialization stream. + *
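+ * A rough sketch of how a caller is expected to drive this sorter (simplified; {@code buf},
+ * {@code serializedSize} and {@code partitionId} below are stand-ins for whatever the caller,
+ * e.g. RemoteUnsafeShuffleWriter, has already produced by serializing a record):
+ *
+ * <pre>{@code
+ *   // per record: hand the serialized bytes and the record's partition id to the sorter
+ *   sorter.insertRecord(buf, Platform.BYTE_ARRAY_OFFSET, serializedSize, partitionId);
+ *   // ...the sorter spills sorted runs to remote storage under memory pressure...
+ *   RemoteSpillInfo[] spills = sorter.closeAndGetSpills(); // sorts and writes remaining records
+ * }</pre>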

+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its + * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a + * specialized merge procedure that avoids extra serialization/deserialization. + */ +final class RemoteUnsafeShuffleSorter extends MemoryConsumer { + + private SerializerManager serializerManager = SparkEnv.get().serializerManager(); + + private static final Logger logger = LoggerFactory.getLogger(RemoteUnsafeShuffleSorter.class); + + @VisibleForTesting + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + + private final int numPartitions; + private final TaskMemoryManager taskMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private final RemoteShuffleBlockResolver resolver; + private final ShuffleWriteMetricsReporter writeMetrics; + + /** + * Force this sorter to spill when there are this many elements in memory. + */ + private final int numElementsForSpillThreshold; + + /** The buffer size to use when writing spills using RemoteBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** The buffer size to use when writing the sorted records to an on-disk file */ + private final int diskWriteBufferSize; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList<>(); + + private final LinkedList spills = new LinkedList<>(); + + /** Peak memory used by this sorter so far, in bytes. **/ + private long peakMemoryUsedBytes; + + // These variables are reset after spilling: + @Nullable private ShuffleInMemorySorter inMemSorter; + @Nullable private MemoryBlock currentPage = null; + private long pageCursor = -1; + + RemoteUnsafeShuffleSorter( + TaskMemoryManager memoryManager, + BlockManager blockManager, + TaskContext taskContext, + RemoteShuffleBlockResolver resolver, + int initialSize, + int numPartitions, + SparkConf conf, + ShuffleWriteMetricsReporter writeMetrics) { + super(memoryManager, + (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()), + memoryManager.getTungstenMemoryMode()); + this.taskMemoryManager = memoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.resolver = resolver; + this.numPartitions = numPartitions; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSizeBytes = + (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; + this.numElementsForSpillThreshold = + (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()); + this.writeMetrics = writeMetrics; + this.inMemSorter = new ShuffleInMemorySorter( + this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true)); + this.peakMemoryUsedBytes = getMemoryUsage(); + this.diskWriteBufferSize = + (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); + } + + /** + * Sorts the in-memory records and writes the sorted records to an on-disk file. + * This method does not free the sort data structures. 
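+ *
+ * This method is called with {@code isLastFile = true} only from {@link #closeAndGetSpills()};
+ * calls made while spilling pass {@code false}, so those bytes are counted as spill rather
+ * than as shuffle write.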
+ * + * @param isLastFile if true, this indicates that we're writing the final output file and that the + * bytes written should be counted towards shuffle spill metrics rather than + * shuffle write metrics. + */ + private void writeSortedFile(boolean isLastFile) { + + final ShuffleWriteMetricsReporter writeMetricsToUse; + + if (isLastFile) { + // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. + writeMetricsToUse = writeMetrics; + } else { + // We're spilling, so bytes written should be counted towards spill rather than write. + // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count + // them towards shuffle bytes written. + writeMetricsToUse = new ShuffleWriteMetrics(); + } + + // This call performs the actual sort. + final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = + inMemSorter.getSortedIterator(); + + // Small writes to RemoteBlockObjectWriter will be fairly inefficient. Since there doesn't + // seem to be an API to directly transfer bytes from managed memory to the disk writer, + // we buffer data through a byte array. This array does not need to be large enough to + // hold a single record; + final byte[] writeBuffer = new byte[diskWriteBufferSize]; + + // Because this output will be read during shuffle, its compression codec must be controlled by + // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use + // createTempShuffleBlock here; see SPARK-3426 for more details. + final Tuple2 spilledFileInfo = resolver.createTempShuffleBlock(); + final Path file = spilledFileInfo._2(); + final TempShuffleBlockId blockId = spilledFileInfo._1(); + final RemoteSpillInfo spillInfo = new RemoteSpillInfo(numPartitions, file, blockId); + + // Unfortunately, we need a serializer instance in order to construct a RemoteBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but RemoteBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. 
+ final SerializerInstance ser = DummySerializerInstance.INSTANCE; + + final RemoteBlockObjectWriter writer = + RemoteShuffleUtils.getRemoteWriter( + blockId, file, serializerManager, ser, fileBufferSizeBytes, writeMetricsToUse); + + int currentPartition = -1; + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final int partition = sortedRecords.packedRecordPointer.getPartitionId(); + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + final HadoopFileSegment fileSegment = writer.commitAndGet(); + spillInfo.partitionLengths[currentPartition] = fileSegment.length(); + } + currentPartition = partition; + } + + final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); + final Object recordPage = taskMemoryManager.getPage(recordPointer); + final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer); + int dataRemaining = UnsafeAlignedOffset.getSize(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + uaoSize; // skip over record length + while (dataRemaining > 0) { + final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining); + Platform.copyMemory( + recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer); + writer.write(writeBuffer, 0, toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + } + writer.recordWritten(); + } + + final HadoopFileSegment committedSegment = writer.commitAndGet(); + writer.close(); + // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, + // then the file might be empty. Note that it might be better to avoid calling + // writeSortedFile() in that case. + if (currentPartition != -1) { + spillInfo.partitionLengths[currentPartition] = committedSegment.length(); + spills.add(spillInfo); + } + + if (!isLastFile) { // i.e. this is a spill file + // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records + // are written to disk, not when they enter the shuffle sorting code. RemoteBlockObjectWriter + // relies on its `recordWritten()` method being called in order to trigger periodic updates to + // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that + // counter at a higher-level, then the in-progress metrics for records written and bytes + // written would get out of sync. + // + // When writing the last file, we pass `writeMetrics` directly to the RemoteBlockObjectWriter; + // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those + // metrics to the true write metrics here. The reason for performing this copying is so that + // we can avoid reporting spilled bytes as shuffle write bytes. + // + // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. + // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. + // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. + writeMetrics.incRecordsWritten( + ((ShuffleWriteMetrics)writeMetricsToUse).recordsWritten()); + taskContext.taskMetrics().incDiskBytesSpilled( + ((ShuffleWriteMetrics)writeMetricsToUse).bytesWritten()); + } + } + + /** + * Sort and spill the current records in response to memory pressure. 
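+ *
+ * @return the number of bytes of memory freed by this spill, or 0 if there was nothing to
+ *         spill (e.g. when the spill was requested by a different memory consumer).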
+ */ + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) { + return 0L; + } + + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spills.size(), + spills.size() > 1 ? " times" : " time"); + + writeSortedFile(false); + final long spillSize = freeMemory(); + inMemSorter.reset(); + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + // records. Otherwise, if the task is over allocated memory, then without freeing the memory + // pages, we might not be able to get memory for the pointer array. + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + return spillSize; + } + + private long getMemoryUsage() { + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + private long freeMemory() { + updatePeakMemoryUsed(); + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + memoryFreed += block.size(); + freePage(block); + } + allocatedPages.clear(); + currentPage = null; + pageCursor = 0; + return memoryFreed; + } + + /** + * Force all memory and spill files to be deleted; called by shuffle error-handling code. + */ + public void cleanupResources() throws IOException { + freeMemory(); + if (inMemSorter != null) { + inMemSorter.free(); + inMemSorter = null; + } + FileSystem fs = null; + if (!spills.isEmpty()) { + fs = RemoteShuffleManager.getFileSystem(); + } + for (RemoteSpillInfo spill : spills) { + if (fs.exists(spill.file) && !fs.delete(spill.file, true)) { + logger.error("Unable to delete spill file {}", spill.file.toString()); + } + } + } + + /** + * Checks whether there is enough space to insert an additional record in to the sort pointer + * array and grows the array if additional space is required. If the required space cannot be + * obtained, then the in-memory data will be spilled to disk. + */ + private void growPointerArrayIfNecessary() throws IOException { + assert(inMemSorter != null); + if (!inMemSorter.hasSpaceForAnotherRecord()) { + long used = inMemSorter.getMemoryUsage(); + LongArray array; + try { + // could trigger spilling + array = allocateArray(used / 8 * 2); + } catch (TooLargePageException e) { + // The pointer array is too big to fix in a single page, spill. + spill(); + return; + } catch (SparkOutOfMemoryError e) { + // should have trigger spilling + if (!inMemSorter.hasSpaceForAnotherRecord()) { + logger.error("Unable to grow the pointer array"); + throw e; + } + return; + } + // check if spilling is triggered or not + if (inMemSorter.hasSpaceForAnotherRecord()) { + freeArray(array); + } else { + inMemSorter.expandPointerArray(array); + } + } + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the memory manager and spill if the requested memory can not be obtained. + * + * @param required the required space in the data page, in bytes, including space for storing + * the record size. 
This must be less than or equal to the page size (records + * that exceed the page size are handled via a different code path which uses + * special overflow pages). + */ + private void acquireNewPageIfNecessary(int required) { + if (currentPage == null || + pageCursor + required > currentPage.getBaseOffset() + currentPage.size() ) { + // TODO: try to find space in previous pages + currentPage = allocatePage(required); + pageCursor = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); + } + } + + /** + * Write a record to the shuffle sorter. + */ + public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId) + throws IOException { + + // for tests + assert(inMemSorter != null); + if (inMemSorter.numRecords() >= numElementsForSpillThreshold) { + logger.info("Spilling data because number of spilledRecords crossed the threshold " + + numElementsForSpillThreshold); + spill(); + } + + growPointerArrayIfNecessary(); + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + // Need 4 or 8 bytes to store the record length. + final int required = length + uaoSize; + acquireNewPageIfNecessary(required); + + assert(currentPage != null); + final Object base = currentPage.getBaseObject(); + final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); + UnsafeAlignedOffset.putSize(base, pageCursor, length); + pageCursor += uaoSize; + Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); + pageCursor += length; + inMemSorter.insertRecord(recordAddress, partitionId); + } + + /** + * Close the sorter, causing any buffered data to be sorted and written out to disk. + * + * @return metadata for the spill files written by this sorter. If no records were ever inserted + * into this sorter, then this will return an empty array. + * @throws IOException + */ + public RemoteSpillInfo[] closeAndGetSpills() throws IOException { + if (inMemSorter != null) { + // Do not count the final file towards the spill count. + writeSortedFile(true); + freeMemory(); + inMemSorter.free(); + inMemSorter = null; + } + return spills.toArray(new RemoteSpillInfo[spills.size()]); + } + +} diff --git a/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriter.java b/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriter.java new file mode 100644 index 000000000..f87d35d1a --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriter.java @@ -0,0 +1,478 @@ +package org.apache.spark.shuffle.sort; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +import java.io.*; +import java.util.Iterator; + +import javax.annotation.Nullable; + +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.remote.RemoteShuffleManager$; +import scala.Option; +import scala.Product2; +import scala.collection.JavaConverters; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.ByteStreams; +import com.google.common.io.Closeables; +import org.apache.commons.io.output.CloseShieldOutputStream; +import org.apache.commons.io.output.CountingOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.annotation.Private; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.internal.config.package$; +import org.apache.spark.io.CompressionCodec; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.remote.RemoteShuffleBlockResolver; +import org.apache.spark.shuffle.remote.RemoteShuffleManager; +import org.apache.spark.shuffle.remote.RemoteShuffleUtils; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.unsafe.Platform; + +@Private +public class RemoteUnsafeShuffleWriter extends ShuffleWriter { + + private static final Logger logger = + LoggerFactory.getLogger(RemoteUnsafeShuffleWriter.class); + + static { + logger.warn("******** Optimized Remote Shuffle Writer is used ********"); + } + + private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + + @VisibleForTesting + static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; + static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; + + private final BlockManager blockManager; + private final RemoteShuffleBlockResolver shuffleBlockResolver; + private final TaskMemoryManager memoryManager; + private final SerializerInstance serializer; + private final Partitioner partitioner; + private final int shuffleId; + private final long mapId; + private final TaskContext taskContext; + private final SparkConf sparkConf; + private final ShuffleWriteMetricsReporter writeMetrics; + + private final boolean transferToEnabled; + private final int initialSortBufferSize; + private final int inputBufferSizeInBytes; + private final int outputBufferSizeInBytes; + + @Nullable private MapStatus mapStatus; + @Nullable private RemoteUnsafeShuffleSorter sorter; + private long peakMemoryUsedBytes = 0; + + /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ + private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { + MyByteArrayOutputStream(int size) { super(size); } + public byte[] getBuf() { return buf; } + } + + private MyByteArrayOutputStream serBuffer; + private SerializationStream serOutputStream; + + /** + * Are we in the process of stopping? 
Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. + */ + private boolean stopping = false; + + private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream { + + CloseAndFlushShieldOutputStream(OutputStream outputStream) { + super(outputStream); + } + + @Override + public void flush() { + // do nothing + } + } + + public RemoteUnsafeShuffleWriter( + BlockManager blockManager, + RemoteShuffleBlockResolver shuffleBlockResolver, + TaskMemoryManager memoryManager, + SerializedShuffleHandle handle, + long mapId, + TaskContext taskContext, + SparkConf sparkConf, + ShuffleWriteMetricsReporter metrics) throws IOException { + final int numPartitions = handle.dependency().partitioner().numPartitions(); + if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { + throw new IllegalArgumentException( + "RemoteUnsafeShuffleWriter can only be used for shuffles with at most " + + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + + " reduce partitions"); + } + this.blockManager = blockManager; + this.shuffleBlockResolver = shuffleBlockResolver; + this.memoryManager = memoryManager; + this.mapId = mapId; + final ShuffleDependency dep = handle.dependency(); + this.shuffleId = dep.shuffleId(); + this.serializer = dep.serializer().newInstance(); + this.partitioner = dep.partitioner(); + this.taskContext = taskContext; + this.sparkConf = sparkConf; + this.writeMetrics = metrics; + this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); + this.initialSortBufferSize = sparkConf.getInt("spark.shuffle.sort.initialBufferSize", + DEFAULT_INITIAL_SORT_BUFFER_SIZE); + this.inputBufferSizeInBytes = + (int) (long) sparkConf + .get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; + this.outputBufferSizeInBytes = + (int) (long) sparkConf + .get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; + open(); + } + + private void updatePeakMemoryUsed() { + // sorter can be null if this writer is closed + if (sorter != null) { + long mem = sorter.getPeakMemoryUsedBytes(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + } + + /** + * Return the peak memory used so far, in bytes. + */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + /** + * This convenience method should only be called in test code. + */ + @VisibleForTesting + public void write(Iterator> records) throws IOException { + write(JavaConverters.asScalaIteratorConverter(records).asScala()); + } + + @Override + public void write(scala.collection.Iterator> records) throws IOException { + // Keep track of success so we know if we encountered an exception + // We do this rather than a standard try/catch/re-throw to handle + // generic throwables. + boolean success = false; + try { + while (records.hasNext()) { + insertRecordIntoSorter(records.next()); + } + closeAndWriteOutput(); + success = true; + } finally { + if (sorter != null) { + try { + sorter.cleanupResources(); + } catch (Exception e) { + // Only throw this error if we won't be masking another + // error. 
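+              // If the write itself succeeded, a cleanup failure is the primary error and
+              // should propagate; if the write already failed, only log the cleanup failure so
+              // the original exception is not hidden.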
+ if (success) { + throw e; + } else { + logger.error("In addition to a failure during writing, we failed during " + + "cleanup.", e); + } + } + } + } + } + + private void open() { + assert (sorter == null); + sorter = new RemoteUnsafeShuffleSorter( + memoryManager, + blockManager, + taskContext, + shuffleBlockResolver, + initialSortBufferSize, + partitioner.numPartitions(), + sparkConf, + writeMetrics); + serBuffer = new RemoteUnsafeShuffleWriter + .MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE); + serOutputStream = serializer.serializeStream(serBuffer); + } + + @VisibleForTesting + void closeAndWriteOutput() throws IOException { + assert(sorter != null); + updatePeakMemoryUsed(); + serBuffer = null; + serOutputStream = null; + final RemoteSpillInfo[] spills = sorter.closeAndGetSpills(); + sorter = null; + final long[] partitionLengths; + final Path output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final Path tmp = RemoteShuffleUtils.tempPathWith(output); + FileSystem fs = RemoteShuffleManager.getFileSystem(); + try { + try { + partitionLengths = mergeSpills(spills, tmp); + } finally { + + for (RemoteSpillInfo spill : spills) { + if (fs.exists(spill.file) && ! fs.delete(spill.file, true)) { + logger.error("Error while deleting spill file {}", spill.file.toString()); + } + } + } + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + } finally { + if (fs.exists(tmp) && !fs.delete(tmp, true)) { + logger.error("Error while deleting temp file {}", tmp.toString()); + } + } + mapStatus = MapStatus$.MODULE$.apply( + RemoteShuffleManager$.MODULE$.getResolver().shuffleServerId(), partitionLengths, mapId); + } + + @VisibleForTesting + void insertRecordIntoSorter(Product2 record) throws IOException { + assert(sorter != null); + final K key = record._1(); + final int partitionId = partitioner.getPartition(key); + serBuffer.reset(); + serOutputStream.writeKey(key, OBJECT_CLASS_TAG); + serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); + serOutputStream.flush(); + + final int serializedRecordSize = serBuffer.size(); + assert (serializedRecordSize > 0); + + sorter.insertRecord( + serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + } + + @VisibleForTesting + void forceSorterToSpill() throws IOException { + assert (sorter != null); + sorter.spill(); + } + + /** + * Merge zero or more spill files together, choosing the fastest merging strategy based on the + * number of spills and the IO compression codec. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpills(RemoteSpillInfo[] spills, Path outputFile) throws IOException { + final FileSystem fs = RemoteShuffleManager.getFileSystem(); + final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); + final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); + final boolean fastMergeEnabled = + sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); + final boolean fastMergeIsSupported = !compressionEnabled || + CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + try { + if (spills.length == 0) { + fs.create(outputFile).close(); // Create an empty file + return new long[partitioner.numPartitions()]; + } else if (spills.length == 1) { + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. 
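+        // Hadoop's FileSystem#rename signals most failures through its boolean return value
+        // rather than by throwing; a more defensive variant could check it, for example:
+        //   if (!fs.rename(spills[0].file, outputFile)) {
+        //     throw new IOException("Failed to rename " + spills[0].file + " to " + outputFile);
+        //   }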
+ fs.rename(spills[0].file, outputFile); + return spills[0].partitionLengths; + } else { + final long[] partitionLengths; + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. We count IO during merge towards the shuffle + // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" + // branch in ExternalSorter. + + // We do not perform a transferTo-optimized merge due to underground storage may not support + // this (NIO FileChannel.transferTo) + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. + logger.debug("Using fileStream-based fast merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + } else { + logger.debug("Using slow merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + } + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // RemoteSpillInfo's bytes. + writeMetrics.decBytesWritten(fs.getFileStatus(spills[spills.length - 1].file).getLen()); + writeMetrics.incBytesWritten(fs.getFileStatus(outputFile).getLen()); + return partitionLengths; + } + } catch (IOException e) { + if (fs.exists(outputFile) && !fs.delete(outputFile, true)) { + logger.error("Unable to delete output file {}", outputFile.toString()); + } + throw e; + } + } + + /** + * Merges spill files using Java FileStreams. This code path is typically slower than + * the NIO-based merge, + * {@link RemoteUnsafeShuffleWriter#mergeSpillsWithTransferTo(RemoteSpillInfo[], + * Path)}, and it's mostly used in cases where the IO compression codec does not support + * concatenation of compressed data, when encryption is enabled, or when users have + * explicitly disabled use of {@code transferTo} in order to work around kernel bugs. + * This code path might also be faster in cases where individual partition size in a spill + * is small and RemoteUnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small + * disk ios which is inefficient. In those case, Using large buffers for input and output + * files helps reducing the number of disk ios, making the file merging faster. + * + * @param spills the spills to merge. + * @param outputFile the file to write the merged data to. + * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. + * @return the partition lengths in the merged file. 
+ */ + private long[] mergeSpillsWithFileStream( + RemoteSpillInfo[] spills, + Path outputFile, + @Nullable CompressionCodec compressionCodec) throws IOException { + assert (spills.length >= 2); + final FileSystem fs = RemoteShuffleManager.getFileSystem(); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final InputStream[] spillInputStreams = new InputStream[spills.length]; + + final OutputStream bos = new BufferedOutputStream( + fs.create(outputFile), + outputBufferSizeInBytes); + // Use a counting output stream to avoid having to close the underlying file and ask + // the file system for its size after each partition is written. + final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos); + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + // Note by Chenzhao: Originally NioBufferedFileInputStream is used + spillInputStreams[i] = new BufferedInputStream( + fs.open(spills[i].file), + inputBufferSizeInBytes); + } + for (int partition = 0; partition < numPartitions; partition++) { + final long initialFileLength = mergedFileOutputStream.getByteCount(); + // Shield the underlying output stream from close() and flush() calls, so that + // we can close the higher level streams to make sure all data is really flushed + // and internal state is cleaned. + OutputStream partitionOutput = + new RemoteUnsafeShuffleWriter.CloseAndFlushShieldOutputStream( + new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); + if (compressionCodec != null) { + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); + } + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); + try { + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + } + ByteStreams.copy(partitionInputStream, partitionOutput); + } finally { + partitionInputStream.close(); + } + } + } + partitionOutput.flush(); + partitionOutput.close(); + partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. 
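+      // Guava's Closeables.close(closeable, swallowIOException) only swallows the IOException
+      // when its second argument is true, so close failures still surface on the success path.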
+ for (InputStream stream : spillInputStreams) { + Closeables.close(stream, threwException); + } + Closeables.close(mergedFileOutputStream, threwException); + } + return partitionLengths; + } + + @Override + public Option stop(boolean success) { + try { + taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes()); + + if (stopping) { + return Option.apply(null); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + return Option.apply(null); + } + } + } finally { + if (sorter != null) { + // If sorter is non-null, then this implies that we called stop() in response to an error, + // so we need to clean up memory and spill files created by the sorter + try { + sorter.cleanupResources(); + } catch (IOException e) { + e.printStackTrace(); + } + } + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/network/netty/MyNettyBlockRpcServer.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/network/netty/MyNettyBlockRpcServer.scala new file mode 100644 index 000000000..1bf78bfef --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/network/netty/MyNettyBlockRpcServer.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer + +import scala.language.existentials + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.network.BlockDataManager +import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithID, TransportClient} +import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} +import org.apache.spark.network.shuffle.protocol._ +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.remote.{HadoopFileSegmentManagedBuffer, MessageForHadoopManagedBuffers, RemoteShuffleManager} +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.storage.{BlockId, ShuffleBlockId} + +/** + * Serves requests to open blocks by simply registering one chunk per block requested. + * Handles opening and uploading arbitrary BlockManager blocks. + * + * Opened blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk + * is equivalent to one Spark-level shuffle block. 
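+ *
+ * In this remote-shuffle variant, OpenBlocks requests for shuffle blocks are answered directly:
+ * the blocks are resolved to [[HadoopFileSegmentManagedBuffer]]s and sent back in a single
+ * [[MessageForHadoopManagedBuffers]] response.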
+ */ +class MyNettyBlockRpcServer( + appId: String, + serializer: Serializer, + blockManager: BlockDataManager) + extends RpcHandler with Logging { + + private val streamManager = new OneForOneStreamManager() + + override def receive( + client: TransportClient, + rpcMessage: ByteBuffer, + responseContext: RpcResponseCallback): Unit = { + val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage) + logTrace(s"Received request: $message") + + message match { + case openBlocks: OpenBlocks => + val blocksNum = openBlocks.blockIds.length + val isShuffleRequest = (blocksNum > 0) && + BlockId.apply(openBlocks.blockIds(0)).isInstanceOf[ShuffleBlockId] && + (SparkEnv.get.conf.get("spark.shuffle.manager", classOf[SortShuffleManager].getName) + == classOf[RemoteShuffleManager].getName) + if (isShuffleRequest) { + val blockIdAndManagedBufferPair = + openBlocks.blockIds.map(block => (block, blockManager.getHostLocalShuffleData( + BlockId.apply(block), Array.empty).asInstanceOf[HadoopFileSegmentManagedBuffer])) + responseContext.onSuccess(new MessageForHadoopManagedBuffers( + blockIdAndManagedBufferPair).toByteBuffer.nioBuffer()) + } else { + // This customized Netty RPC server is only served for RemoteShuffle requests, + // Other RPC messages or data chunks transferring should go through + // NettyBlockTransferService' NettyBlockRpcServer + throw new UnsupportedOperationException("MyNettyBlockRpcServer only serves remote" + + " shuffle requests for OpenBlocks") + } + + case uploadBlock: UploadBlock => + throw new UnsupportedOperationException("MyNettyBlockRpcServer doesn't serve UploadBlock") + } + } + + override def receiveStream( + client: TransportClient, + messageHeader: ByteBuffer, + responseContext: RpcResponseCallback): StreamCallbackWithID = { + throw new UnsupportedOperationException("MyNettyBlockRpcServer doesn't support receiving" + + " stream") + } + + override def getStreamManager(): StreamManager = streamManager +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/network/netty/RemoteShuffleTransferService.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/network/netty/RemoteShuffleTransferService.scala new file mode 100644 index 000000000..c067bbe9a --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/network/netty/RemoteShuffleTransferService.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty + +import java.util.{HashMap => JHashMap, Map => JMap} + +import scala.collection.JavaConverters._ +import scala.concurrent.Future +import scala.reflect.ClassTag + +import com.codahale.metrics.{Metric, MetricSet} + +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.network.{BlockDataManager, BlockTransferService, TransportContext} +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.{TransportClientBootstrap, TransportClientFactory} +import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} +import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap} +import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, MyOneForOneBlockFetcher, RetryingBlockFetcher} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.storage.{BlockId, BlockManagerId, StorageLevel} +import org.apache.spark.util.Utils + +// This is to support index cache feature without overriding Spark source code, which may cause +// compatibility issue when upgrading Spark. +// Instead, RemoteShuffle self-maintain a Customized netty transfer service solely for shuffle index +// files transferring +private[spark] class RemoteShuffleTransferService( + conf: SparkConf, + securityManager: SecurityManager, + bindAddress: String, + override val hostName: String, + _port: Int, + numCores: Int) extends BlockTransferService { + + // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. + private val serializer = new JavaSerializer(conf) + private val authEnabled = securityManager.isAuthenticationEnabled() + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) + + private[this] var transportContext: TransportContext = _ + private[this] var server: TransportServer = _ + private[this] var clientFactory: TransportClientFactory = _ + private[this] var appId: String = _ + + init(SparkEnv.get.blockManager) + + override def init(blockDataManager: BlockDataManager): Unit = { + val rpcHandler = new MyNettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) + var serverBootstrap: Option[TransportServerBootstrap] = None + var clientBootstrap: Option[TransportClientBootstrap] = None + if (authEnabled) { + serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager)) + clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager)) + } + transportContext = new TransportContext(transportConf, rpcHandler) + clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) + server = createServer(serverBootstrap.toList) + appId = conf.getAppId + logInfo(s"Server created on ${hostName}:${server.getPort}") + } + + val getShuffleServerId: BlockManagerId = { + val id = SparkEnv.get.blockManager.blockManagerId + BlockManagerId(id.executorId, id.host, port) + } + + /** Creates and binds the TransportServer, possibly trying multiple ports. 
*/ + private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = { + def startService(port: Int): (TransportServer, Int) = { + val server = transportContext.createServer(bindAddress, port, bootstraps.asJava) + (server, server.getPort) + } + + Utils.startServiceOnPort(_port, startService, conf, getClass.getName)._1 + } + + override def shuffleMetrics(): MetricSet = { + require( + server != null && clientFactory != null, "RemoteShuffleTransferService is not initialized") + + new MetricSet { + val allMetrics = new JHashMap[String, Metric]() + override def getMetrics: JMap[String, Metric] = { + allMetrics.putAll(clientFactory.getAllMetrics.getMetrics) + allMetrics.putAll(server.getAllMetrics.getMetrics) + allMetrics + } + } + } + + override def fetchBlocks( + host: String, + port: Int, + execId: String, + blockIds: Array[String], + listener: BlockFetchingListener, + tempFileManager: DownloadFileManager): Unit = { + logTrace(s"Fetch blocks from $host:$port (executor id $execId)") + try { + val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { + override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { + val client = clientFactory.createClient(host, port) + new MyOneForOneBlockFetcher(client, appId, execId, blockIds, listener).start() + } + } + + val maxRetries = transportConf.maxIORetries() + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start() + } else { + blockFetchStarter.createAndStart(blockIds, listener) + } + } catch { + case e: Exception => + logError("Exception while beginning fetchBlocks", e) + blockIds.foreach(listener.onBlockFetchFailure(_, e)) + } + } + + override def port: Int = server.getPort + + // Make this an empty implementation, because remote shuffle only uses the `fetchBlocks` function + override def uploadBlock( + hostname: String, + port: Int, + execId: String, + blockId: BlockId, + blockData: ManagedBuffer, + level: StorageLevel, + classTag: ClassTag[_]): Future[Unit] = ??? + + override def close(): Unit = { + if (server != null) { + server.close() + } + if (clientFactory != null) { + clientFactory.close() + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala new file mode 100644 index 000000000..62d57655e --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -0,0 +1,2239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.NotSerializableException +import java.util.Properties +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger + +import scala.annotation.tailrec +import scala.collection.Map +import scala.collection.mutable +import scala.collection.mutable.{HashMap, HashSet, ListBuffer} +import scala.concurrent.duration._ +import scala.util.control.NonFatal + +import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config +import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} +import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData} +import org.apache.spark.rpc.RpcTimeout +import org.apache.spark.shuffle.remote.RemoteShuffleManager +import org.apache.spark.storage._ +import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat +import org.apache.spark.util._ + +/** + * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of + * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a + * minimal schedule to run the job. It then submits stages as TaskSets to an underlying + * TaskScheduler implementation that runs them on the cluster. A TaskSet contains fully independent + * tasks that can run right away based on the data that's already on the cluster (e.g. map output + * files from previous stages), though it may fail if this data becomes unavailable. + * + * Spark stages are created by breaking the RDD graph at shuffle boundaries. RDD operations with + * "narrow" dependencies, like map() and filter(), are pipelined together into one set of tasks + * in each stage, but operations with shuffle dependencies require multiple stages (one to write a + * set of map output files, and another to read those files after a barrier). In the end, every + * stage will have only shuffle dependencies on other stages, and may compute multiple operations + * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of + * various RDDs + * + * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred + * locations to run each task on, based on the current cache status, and passes these to the + * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being + * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are + * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task + * a small number of times before cancelling the whole stage. + * + * When looking through this code, there are several key concepts: + * + * - Jobs (represented by [[ActiveJob]]) are the top-level work items submitted to the scheduler. + * For example, when the user calls an action, like count(), a job will be submitted through + * submitJob. Each Job may require the execution of multiple stages to build intermediate data. 
+ * + * - Stages ([[Stage]]) are sets of tasks that compute intermediate results in jobs, where each + * task computes the same function on partitions of the same RDD. Stages are separated at shuffle + * boundaries, which introduce a barrier (where we must wait for the previous stage to finish to + * fetch outputs). There are two types of stages: [[ResultStage]], for the final stage that + * executes an action, and [[ShuffleMapStage]], which writes map output files for a shuffle. + * Stages are often shared across multiple jobs, if these jobs reuse the same RDDs. + * + * - Tasks are individual units of work, each sent to one machine. + * + * - Cache tracking: the DAGScheduler figures out which RDDs are cached to avoid recomputing them + * and likewise remembers which shuffle map stages have already produced output files to avoid + * redoing the map side of a shuffle. + * + * - Preferred locations: the DAGScheduler also computes where to run each task in a stage based + * on the preferred locations of its underlying RDDs, or the location of cached or shuffle data. + * + * - Cleanup: all data structures are cleared when the running jobs that depend on them finish, + * to prevent memory leaks in a long-running application. + * + * To recover from failures, the same stage might need to run multiple times, which are called + * "attempts". If the TaskScheduler reports that a task failed because a map output file from a + * previous stage was lost, the DAGScheduler resubmits that lost stage. This is detected through a + * CompletionEvent with FetchFailed, or an ExecutorLost event. The DAGScheduler will wait a small + * amount of time to see whether other nodes or tasks fail, then resubmit TaskSets for any lost + * stage(s) that compute the missing tasks. As part of this process, we might also have to create + * Stage objects for old (finished) stages where we previously cleaned up the Stage object. Since + * tasks from the old attempt of a stage could still be running, care must be taken to map any + * events received in the correct Stage object. + * + * Here's a checklist to use when making or reviewing changes to this class: + * + * - All data structures should be cleared when the jobs involving them end to avoid indefinite + * accumulation of state in long-running programs. + * + * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to + * include the new structure. This will help to catch memory leaks. 
+ */ +private[spark] class DAGScheduler( + private[scheduler] val sc: SparkContext, + private[scheduler] val taskScheduler: TaskScheduler, + listenerBus: LiveListenerBus, + mapOutputTracker: MapOutputTrackerMaster, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv, + clock: Clock = new SystemClock()) + extends Logging { + + def this(sc: SparkContext, taskScheduler: TaskScheduler) = { + this( + sc, + taskScheduler, + sc.listenerBus, + sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], + sc.env.blockManager.master, + sc.env) + } + + def this(sc: SparkContext) = this(sc, sc.taskScheduler) + + private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + + private[scheduler] val nextJobId = new AtomicInteger(0) + private[scheduler] def numTotalJobs: Int = nextJobId.get() + private val nextStageId = new AtomicInteger(0) + + private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] + private[scheduler] val stageIdToStage = new HashMap[Int, Stage] + /** + * Mapping from shuffle dependency ID to the ShuffleMapStage that will generate the data for + * that dependency. Only includes stages that are part of currently running job (when the job(s) + * that require the shuffle stage complete, the mapping will be removed, and the only record of + * the shuffle data will be in the MapOutputTracker). + */ + private[scheduler] val shuffleIdToMapStage = new HashMap[Int, ShuffleMapStage] + private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] + + // Stages we need to run whose parents aren't done + private[scheduler] val waitingStages = new HashSet[Stage] + + // Stages we are running right now + private[scheduler] val runningStages = new HashSet[Stage] + + // Stages that must be resubmitted due to fetch failures + private[scheduler] val failedStages = new HashSet[Stage] + + private[scheduler] val activeJobs = new HashSet[ActiveJob] + + /** + * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids + * and its values are arrays indexed by partition numbers. Each array value is the set of + * locations where that RDD partition is cached. + * + * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454). + */ + private val cacheLocs = new HashMap[Int, IndexedSeq[Seq[TaskLocation]]] + + // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with + // every task. When we detect a node failing, we note the current epoch number and failed + // executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask results. + // + // TODO: Garbage collect information about failure epochs when we know there are no more + // stray messages to detect. + private val failedEpoch = new HashMap[String, Long] + + private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator + + // A closure serializer that we reuse. + // This is only safe because DAGScheduler runs in a single thread. + private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + + /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ + private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY) + + /** + * Whether to unregister all the outputs on the host in condition that we receive a FetchFailure, + * this is set default to false, which means, we only unregister the outputs related to the exact + * executor(instead of the host) on a FetchFailure. 
+ */ + private[scheduler] val unRegisterOutputOnHostOnFetchFailure = + sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) + + /** + * Number of consecutive stage attempts allowed before a stage is aborted. + */ + private[scheduler] val maxConsecutiveStageAttempts = + sc.getConf.getInt("spark.stage.maxConsecutiveAttempts", + DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS) + + /** + * Number of max concurrent tasks check failures for each barrier job. + */ + private[scheduler] val barrierJobIdToNumTasksCheckFailures = new ConcurrentHashMap[Int, Int] + + /** + * Time in seconds to wait between a max concurrent tasks check failure and the next check. + */ + private val timeIntervalNumTasksCheck = sc.getConf + .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL) + + /** + * Max number of max concurrent tasks check failures allowed for a job before fail the job + * submission. + */ + private val maxFailureNumTasksCheck = sc.getConf + .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES) + + private val messageScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") + + private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) + taskScheduler.setDAGScheduler(this) + + /** + * Called by the TaskSetManager to report task's starting. + */ + def taskStarted(task: Task[_], taskInfo: TaskInfo): Unit = { + eventProcessLoop.post(BeginEvent(task, taskInfo)) + } + + /** + * Called by the TaskSetManager to report that a task has completed + * and results are being fetched remotely. + */ + def taskGettingResult(taskInfo: TaskInfo): Unit = { + eventProcessLoop.post(GettingResultEvent(taskInfo)) + } + + /** + * Called by the TaskSetManager to report task completions or failures. + */ + def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + metricPeaks: Array[Long], + taskInfo: TaskInfo): Unit = { + eventProcessLoop.post( + CompletionEvent(task, reason, result, accumUpdates, metricPeaks, taskInfo)) + } + + /** + * Update metrics for in-progress tasks and let the master know that the BlockManager is still + * alive. Return true if the driver knows about the given block manager. Otherwise, return false, + * indicating that the block manager should re-register. + */ + def executorHeartbeatReceived( + execId: String, + // (taskId, stageId, stageAttemptId, accumUpdates) + accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], + blockManagerId: BlockManagerId, + // (stageId, stageAttemptId) -> metrics + executorUpdates: mutable.Map[(Int, Int), ExecutorMetrics]): Boolean = { + listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates, + executorUpdates)) + blockManagerMaster.driverHeartbeatEndPoint.askSync[Boolean]( + BlockManagerHeartbeat(blockManagerId), new RpcTimeout(10.minutes, "BlockManagerHeartbeat")) + } + + /** + * Called by TaskScheduler implementation when an executor fails. + */ + def executorLost(execId: String, reason: ExecutorLossReason): Unit = { + eventProcessLoop.post(ExecutorLost(execId, reason)) + } + + /** + * Called by TaskScheduler implementation when a worker is removed. + */ + def workerRemoved(workerId: String, host: String, message: String): Unit = { + eventProcessLoop.post(WorkerRemoved(workerId, host, message)) + } + + /** + * Called by TaskScheduler implementation when a host is added. 
+ */ + def executorAdded(execId: String, host: String): Unit = { + eventProcessLoop.post(ExecutorAdded(execId, host)) + } + + /** + * Called by the TaskSetManager to cancel an entire TaskSet due to either repeated failures or + * cancellation of the job itself. + */ + def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { + eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) + } + + /** + * Called by the TaskSetManager when it decides a speculative task is needed. + */ + def speculativeTaskSubmitted(task: Task[_]): Unit = { + eventProcessLoop.post(SpeculativeTaskSubmitted(task)) + } + + private[scheduler] + def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized { + // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times + if (!cacheLocs.contains(rdd.id)) { + // Note: if the storage level is NONE, we don't need to get locations from block manager. + val locs: IndexedSeq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + IndexedSeq.fill(rdd.partitions.length)(Nil) + } else { + val blockIds = + rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] + blockManagerMaster.getLocations(blockIds).map { bms => + bms.map(bm => TaskLocation(bm.host, bm.executorId)) + } + } + cacheLocs(rdd.id) = locs + } + cacheLocs(rdd.id) + } + + private def clearCacheLocs(): Unit = cacheLocs.synchronized { + cacheLocs.clear() + } + + /** + * Gets a shuffle map stage if one exists in shuffleIdToMapStage. Otherwise, if the + * shuffle map stage doesn't already exist, this method will create the shuffle map stage in + * addition to any missing ancestor shuffle map stages. + */ + private def getOrCreateShuffleMapStage( + shuffleDep: ShuffleDependency[_, _, _], + firstJobId: Int): ShuffleMapStage = { + shuffleIdToMapStage.get(shuffleDep.shuffleId) match { + case Some(stage) => + stage + + case None => + // Create stages for all missing ancestor shuffle dependencies. + getMissingAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep => + // Even though getMissingAncestorShuffleDependencies only returns shuffle dependencies + // that were not already in shuffleIdToMapStage, it's possible that by the time we + // get to a particular dependency in the foreach loop, it's been added to + // shuffleIdToMapStage by the stage creation process for an earlier dependency. See + // SPARK-13902 for more information. + if (!shuffleIdToMapStage.contains(dep.shuffleId)) { + createShuffleMapStage(dep, firstJobId) + } + } + // Finally, create a stage for the given shuffle dependency. + createShuffleMapStage(shuffleDep, firstJobId) + } + } + + /** + * Check to make sure we don't launch a barrier stage with unsupported RDD chain pattern. The + * following patterns are not supported: + * 1. Ancestor RDDs that have different number of partitions from the resulting RDD (eg. + * union()/coalesce()/first()/take()/PartitionPruningRDD); + * 2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2)). + */ + private def checkBarrierStageWithRDDChainPattern(rdd: RDD[_], numTasksInStage: Int): Unit = { + if (rdd.isBarrier() && + !traverseParentRDDsWithinStage(rdd, (r: RDD[_]) => + r.getNumPartitions == numTasksInStage && + r.dependencies.count(_.rdd.isBarrier()) <= 1)) { + throw new BarrierJobUnsupportedRDDChainException + } + } + + /** + * Creates a ShuffleMapStage that generates the given shuffle dependency's partitions. 
If a
+ * previously run stage generated the same shuffle data, this function will copy the output
+ * locations that are still available from the previous shuffle to avoid unnecessarily
+ * regenerating data.
+ */
+ def createShuffleMapStage[K, V, C](
+ shuffleDep: ShuffleDependency[K, V, C], jobId: Int): ShuffleMapStage = {
+ val rdd = shuffleDep.rdd
+ checkBarrierStageWithDynamicAllocation(rdd)
+ checkBarrierStageWithNumSlots(rdd)
+ checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions)
+ val numTasks = rdd.partitions.length
+ val parents = getOrCreateParentStages(rdd, jobId)
+ val id = nextStageId.getAndIncrement()
+ val stage = new ShuffleMapStage(
+ id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker)
+
+ stageIdToStage(id) = stage
+ shuffleIdToMapStage(shuffleDep.shuffleId) = stage
+ updateJobIdStageIdMaps(jobId, stage)
+
+ if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
+ // Kind of ugly: need to register RDDs with the cache and map output tracker here
+ // since we can't do it in the RDD constructor because # of partitions is unknown
+ logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " +
+ s"shuffle ${shuffleDep.shuffleId}")
+ mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)
+ }
+ stage
+ }
+
+ /**
+ * We don't support running a barrier stage with dynamic resource allocation enabled, as it can
+ * lead to some confusing behaviors (e.g. with dynamic resource allocation enabled, it may happen
+ * that we acquire some executors (but not enough to launch all the tasks in a barrier stage) and
+ * later release them due to executor idle timeout, and then acquire them again).
+ *
+ * We perform the check on job submit and fail fast if running a barrier stage with dynamic
+ * resource allocation enabled.
+ *
+ * TODO SPARK-24942 Improve cluster resource management with jobs containing barrier stage
+ */
+ private def checkBarrierStageWithDynamicAllocation(rdd: RDD[_]): Unit = {
+ if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.getConf)) {
+ throw new BarrierJobRunWithDynamicAllocationException
+ }
+ }
+
+ /**
+ * Check whether the barrier stage requires more slots (to be able to launch all tasks in the
+ * barrier stage together) than the total number of active slots currently available. Fail the
+ * current check if trying to submit a barrier stage that requires more slots than the current
+ * total. If the check fails consecutively beyond a configured number of times for a job, then
+ * fail the current job submission.
+ */
+ private def checkBarrierStageWithNumSlots(rdd: RDD[_]): Unit = {
+ val numPartitions = rdd.getNumPartitions
+ val maxNumConcurrentTasks = sc.maxNumConcurrentTasks
+ if (rdd.isBarrier() && numPartitions > maxNumConcurrentTasks) {
+ throw new BarrierJobSlotsNumberCheckFailed(numPartitions, maxNumConcurrentTasks)
+ }
+ }
+
+ /**
+ * Create a ResultStage associated with the provided jobId.
+ */ + private def createResultStage( + rdd: RDD[_], + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], + jobId: Int, + callSite: CallSite): ResultStage = { + checkBarrierStageWithDynamicAllocation(rdd) + checkBarrierStageWithNumSlots(rdd) + checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size) + val parents = getOrCreateParentStages(rdd, jobId) + val id = nextStageId.getAndIncrement() + val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite) + stageIdToStage(id) = stage + updateJobIdStageIdMaps(jobId, stage) + stage + } + + /** + * Get or create the list of parent stages for a given RDD. The new Stages will be created with + * the provided firstJobId. + */ + private def getOrCreateParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = { + getShuffleDependencies(rdd).map { shuffleDep => + getOrCreateShuffleMapStage(shuffleDep, firstJobId) + }.toList + } + + /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ + private def getMissingAncestorShuffleDependencies( + rdd: RDD[_]): ListBuffer[ShuffleDependency[_, _, _]] = { + val ancestors = new ListBuffer[ShuffleDependency[_, _, _]] + val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += rdd + while (waitingForVisit.nonEmpty) { + val toVisit = waitingForVisit.remove(0) + if (!visited(toVisit)) { + visited += toVisit + getShuffleDependencies(toVisit).foreach { shuffleDep => + if (!shuffleIdToMapStage.contains(shuffleDep.shuffleId)) { + ancestors.prepend(shuffleDep) + waitingForVisit.prepend(shuffleDep.rdd) + } // Otherwise, the dependency and its ancestors have already been registered. + } + } + } + ancestors + } + + /** + * Returns shuffle dependencies that are immediate parents of the given RDD. + * + * This function will not return more distant ancestors. For example, if C has a shuffle + * dependency on B which has a shuffle dependency on A: + * + * A <-- B <-- C + * + * calling this function with rdd C will only return the B <-- C dependency. + * + * This function is scheduler-visible for the purpose of unit testing. + */ + private[scheduler] def getShuffleDependencies( + rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = { + val parents = new HashSet[ShuffleDependency[_, _, _]] + val visited = new HashSet[RDD[_]] + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += rdd + while (waitingForVisit.nonEmpty) { + val toVisit = waitingForVisit.remove(0) + if (!visited(toVisit)) { + visited += toVisit + toVisit.dependencies.foreach { + case shuffleDep: ShuffleDependency[_, _, _] => + parents += shuffleDep + case dependency => + waitingForVisit.prepend(dependency.rdd) + } + } + } + parents + } + + /** + * Traverses the given RDD and its ancestors within the same stage and checks whether all of the + * RDDs satisfy a given predicate. + */ + private def traverseParentRDDsWithinStage(rdd: RDD[_], predicate: RDD[_] => Boolean): Boolean = { + val visited = new HashSet[RDD[_]] + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += rdd + while (waitingForVisit.nonEmpty) { + val toVisit = waitingForVisit.remove(0) + if (!visited(toVisit)) { + if (!predicate(toVisit)) { + return false + } + visited += toVisit + toVisit.dependencies.foreach { + case _: ShuffleDependency[_, _, _] => + // Not within the same stage with current rdd, do nothing. 
+ case dependency => + waitingForVisit.prepend(dependency.rdd) + } + } + } + true + } + + private def getMissingParentStages(stage: Stage): List[Stage] = { + val missing = new HashSet[Stage] + val visited = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += stage.rdd + def visit(rdd: RDD[_]): Unit = { + if (!visited(rdd)) { + visited += rdd + val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil) + if (rddHasUncachedPartitions) { + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_, _, _] => + val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId) + if (!mapStage.isAvailable) { + missing += mapStage + } + case narrowDep: NarrowDependency[_] => + waitingForVisit.prepend(narrowDep.rdd) + } + } + } + } + } + while (waitingForVisit.nonEmpty) { + visit(waitingForVisit.remove(0)) + } + missing.toList + } + + /** + * Registers the given jobId among the jobs that need the given stage and + * all of that stage's ancestors. + */ + private def updateJobIdStageIdMaps(jobId: Int, stage: Stage): Unit = { + @tailrec + def updateJobIdStageIdMapsList(stages: List[Stage]): Unit = { + if (stages.nonEmpty) { + val s = stages.head + s.jobIds += jobId + jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id + val parentsWithoutThisJobId = s.parents.filter { ! _.jobIds.contains(jobId) } + updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail) + } + } + updateJobIdStageIdMapsList(List(stage)) + } + + /** + * Removes state for job and any stages that are not needed by any other job. Does not + * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks. + * + * @param job The job whose state to cleanup. 
+ */ + private def cleanupStateForJobAndIndependentStages(job: ActiveJob): Unit = { + val registeredStages = jobIdToStageIds.get(job.jobId) + if (registeredStages.isEmpty || registeredStages.get.isEmpty) { + logError("No stages registered for job " + job.jobId) + } else { + stageIdToStage.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach { + case (stageId, stage) => + val jobSet = stage.jobIds + if (!jobSet.contains(job.jobId)) { + logError( + "Job %d not registered for stage %d even though that stage was registered for the job" + .format(job.jobId, stageId)) + } else { + def removeStage(stageId: Int): Unit = { + // data structures based on Stage + for (stage <- stageIdToStage.get(stageId)) { + if (runningStages.contains(stage)) { + logDebug("Removing running stage %d".format(stageId)) + runningStages -= stage + } + for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) { + shuffleIdToMapStage.remove(k) + } + if (waitingStages.contains(stage)) { + logDebug("Removing stage %d from waiting set.".format(stageId)) + waitingStages -= stage + } + if (failedStages.contains(stage)) { + logDebug("Removing stage %d from failed set.".format(stageId)) + failedStages -= stage + } + } + // data structures based on StageId + stageIdToStage -= stageId + logDebug("After removal of stage %d, remaining stages = %d" + .format(stageId, stageIdToStage.size)) + } + + jobSet -= job.jobId + if (jobSet.isEmpty) { // no other job needs this stage + removeStage(stageId) + } + } + } + } + jobIdToStageIds -= job.jobId + jobIdToActiveJob -= job.jobId + activeJobs -= job + job.finalStage match { + case r: ResultStage => r.removeActiveJob() + case m: ShuffleMapStage => m.removeActiveJob(job) + } + } + + /** + * Submit an action job to the scheduler. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @return a JobWaiter object that can be used to block until the job finishes executing + * or can be used to cancel the job. + * + * @throws IllegalArgumentException when partitions ids are illegal + */ + def submitJob[T, U]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + callSite: CallSite, + resultHandler: (Int, U) => Unit, + properties: Properties): JobWaiter[U] = { + // Check to make sure we are not launching a task on a partition that does not exist. + val maxPartitions = rdd.partitions.length + partitions.find(p => p >= maxPartitions || p < 0).foreach { p => + throw new IllegalArgumentException( + "Attempting to access a non-existent partition: " + p + ". 
" + + "Total number of partitions: " + maxPartitions) + } + + val jobId = nextJobId.getAndIncrement() + if (partitions.isEmpty) { + val clonedProperties = Utils.cloneProperties(properties) + if (sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) == null) { + clonedProperties.setProperty(SparkContext.SPARK_JOB_DESCRIPTION, callSite.shortForm) + } + val time = clock.getTimeMillis() + listenerBus.post( + SparkListenerJobStart(jobId, time, Seq.empty, clonedProperties)) + listenerBus.post( + SparkListenerJobEnd(jobId, time, JobSucceeded)) + // Return immediately if the job is running 0 tasks + return new JobWaiter[U](this, jobId, 0, resultHandler) + } + + assert(partitions.nonEmpty) + val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] + val waiter = new JobWaiter[U](this, jobId, partitions.size, resultHandler) + eventProcessLoop.post(JobSubmitted( + jobId, rdd, func2, partitions.toArray, callSite, waiter, + Utils.cloneProperties(properties))) + waiter + } + + /** + * Run an action job on the given RDD and pass all the results to the resultHandler function as + * they arrive. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @note Throws `Exception` when the job fails + */ + def runJob[T, U]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + callSite: CallSite, + resultHandler: (Int, U) => Unit, + properties: Properties): Unit = { + val start = System.nanoTime + val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) + ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf) + waiter.completionFuture.value.get match { + case scala.util.Success(_) => + logInfo("Job %d finished: %s, took %f s".format + (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + case scala.util.Failure(exception) => + logInfo("Job %d failed: %s, took %f s".format + (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler. + val callerStackTrace = Thread.currentThread().getStackTrace.tail + exception.setStackTrace(exception.getStackTrace ++ callerStackTrace) + throw exception + } + } + + /** + * Run an approximate job on the given RDD and pass all the results to an ApproximateEvaluator + * as they arrive. Returns a partial result object from the evaluator. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param evaluator `ApproximateEvaluator` to receive the partial results + * @param callSite where in the user program this job was called + * @param timeout maximum time to wait for the job, in milliseconds + * @param properties scheduler properties to attach to this job, e.g. 
fair scheduler pool name + */ + def runApproximateJob[T, U, R]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + evaluator: ApproximateEvaluator[U, R], + callSite: CallSite, + timeout: Long, + properties: Properties): PartialResult[R] = { + val jobId = nextJobId.getAndIncrement() + if (rdd.partitions.isEmpty) { + // Return immediately if the job is running 0 tasks + val time = clock.getTimeMillis() + listenerBus.post(SparkListenerJobStart(jobId, time, Seq[StageInfo](), properties)) + listenerBus.post(SparkListenerJobEnd(jobId, time, JobSucceeded)) + return new PartialResult(evaluator.currentResult(), true) + } + val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) + val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] + eventProcessLoop.post(JobSubmitted( + jobId, rdd, func2, rdd.partitions.indices.toArray, callSite, listener, + Utils.cloneProperties(properties))) + listener.awaitResult() // Will throw an exception if the job fails + } + + /** + * Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter + * can be used to block until the job finishes executing or can be used to cancel the job. + * This method is used for adaptive query planning, to run map stages and look at statistics + * about their outputs before submitting downstream stages. + * + * @param dependency the ShuffleDependency to run a map stage for + * @param callback function called with the result of the job, which in this case will be a + * single MapOutputStatistics object showing how much data was produced for each partition + * @param callSite where in the user program this job was submitted + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ + def submitMapStage[K, V, C]( + dependency: ShuffleDependency[K, V, C], + callback: MapOutputStatistics => Unit, + callSite: CallSite, + properties: Properties): JobWaiter[MapOutputStatistics] = { + + val rdd = dependency.rdd + val jobId = nextJobId.getAndIncrement() + if (rdd.partitions.length == 0) { + throw new SparkException("Can't run submitMapStage on RDD with 0 partitions") + } + + // We create a JobWaiter with only one "task", which will be marked as complete when the whole + // map stage has completed, and will be passed the MapOutputStatistics for that stage. + // This makes it easier to avoid race conditions between the user code and the map output + // tracker that might result if we told the user the stage had finished, but then they queries + // the map output tracker and some node failures had caused the output statistics to be lost. + val waiter = new JobWaiter[MapOutputStatistics]( + this, jobId, 1, + (_: Int, r: MapOutputStatistics) => callback(r)) + eventProcessLoop.post(MapStageSubmitted( + jobId, dependency, callSite, waiter, Utils.cloneProperties(properties))) + waiter + } + + /** + * Cancel a job that is running or waiting in the queue. + */ + def cancelJob(jobId: Int, reason: Option[String]): Unit = { + logInfo("Asked to cancel job " + jobId) + eventProcessLoop.post(JobCancelled(jobId, reason)) + } + + /** + * Cancel all jobs in the given job group ID. + */ + def cancelJobGroup(groupId: String): Unit = { + logInfo("Asked to cancel job group " + groupId) + eventProcessLoop.post(JobGroupCancelled(groupId)) + } + + /** + * Cancel all jobs that are running or waiting in the queue. 
+ */ + def cancelAllJobs(): Unit = { + eventProcessLoop.post(AllJobsCancelled) + } + + private[scheduler] def doCancelAllJobs(): Unit = { + // Cancel all running jobs. + runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, + Option("as part of cancellation of all jobs"))) + activeJobs.clear() // These should already be empty by this point, + jobIdToActiveJob.clear() // but just in case we lost track of some jobs... + } + + /** + * Cancel all jobs associated with a running or scheduled stage. + */ + def cancelStage(stageId: Int, reason: Option[String]): Unit = { + eventProcessLoop.post(StageCancelled(stageId, reason)) + } + + /** + * Kill a given task. It will be retried. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + taskScheduler.killTaskAttempt(taskId, interruptThread, reason) + } + + /** + * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since + * the last fetch failure. + */ + private[scheduler] def resubmitFailedStages(): Unit = { + if (failedStages.nonEmpty) { + // Failed stages may be removed by job cancellation, so failed might be empty even if + // the ResubmitFailedStages event has been scheduled. + logInfo("Resubmitting failed stages") + clearCacheLocs() + val failedStagesCopy = failedStages.toArray + failedStages.clear() + for (stage <- failedStagesCopy.sortBy(_.firstJobId)) { + submitStage(stage) + } + } + } + + /** + * Check for waiting stages which are now eligible for resubmission. + * Submits stages that depend on the given parent stage. Called when the parent stage completes + * successfully. + */ + private def submitWaitingChildStages(parent: Stage): Unit = { + logTrace(s"Checking if any dependencies of $parent are now runnable") + logTrace("running: " + runningStages) + logTrace("waiting: " + waitingStages) + logTrace("failed: " + failedStages) + val childStages = waitingStages.filter(_.parents.contains(parent)).toArray + waitingStages --= childStages + for (stage <- childStages.sortBy(_.firstJobId)) { + submitStage(stage) + } + } + + /** Finds the earliest-created active job that needs the stage */ + // TODO: Probably should actually find among the active jobs that need this + // stage the one with the highest priority (highest-priority pool, earliest created). + // That should take care of at least part of the priority inversion problem with + // cross-job dependencies. + private def activeJobForStage(stage: Stage): Option[Int] = { + val jobsThatUseStage: Array[Int] = stage.jobIds.toArray.sorted + jobsThatUseStage.find(jobIdToActiveJob.contains) + } + + private[scheduler] def handleJobGroupCancelled(groupId: String): Unit = { + // Cancel all jobs belonging to this job group. + // First finds all active jobs with this group id, and then kill stages for them. + val activeInGroup = activeJobs.filter { activeJob => + Option(activeJob.properties).exists { + _.getProperty(SparkContext.SPARK_JOB_GROUP_ID) == groupId + } + } + val jobIds = activeInGroup.map(_.jobId) + jobIds.foreach(handleJobCancellation(_, + Option("part of cancelled job group %s".format(groupId)))) + } + + private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { + // Note that there is a chance that this task is launched after the stage is cancelled. + // In that case, we wouldn't have the stage anymore in stageIdToStage. 
+ val stageAttemptId = + stageIdToStage.get(task.stageId).map(_.latestInfo.attemptNumber).getOrElse(-1) + listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) + } + + private[scheduler] def handleSpeculativeTaskSubmitted(task: Task[_]): Unit = { + listenerBus.post(SparkListenerSpeculativeTaskSubmitted(task.stageId, task.stageAttemptId)) + } + + private[scheduler] def handleTaskSetFailed( + taskSet: TaskSet, + reason: String, + exception: Option[Throwable]): Unit = { + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) } + } + + private[scheduler] def cleanUpAfterSchedulerStop(): Unit = { + for (job <- activeJobs) { + val error = + new SparkException(s"Job ${job.jobId} cancelled because SparkContext was shut down") + job.listener.jobFailed(error) + // Tell the listeners that all of the running stages have ended. Don't bother + // cancelling the stages because if the DAG scheduler is stopped, the entire application + // is in the process of getting stopped. + val stageFailedMessage = "Stage cancelled because SparkContext was shut down" + // The `toArray` here is necessary so that we don't iterate over `runningStages` while + // mutating it. + runningStages.toArray.foreach { stage => + markStageAsFinished(stage, Some(stageFailedMessage)) + } + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) + } + } + + private[scheduler] def handleGetTaskResult(taskInfo: TaskInfo): Unit = { + listenerBus.post(SparkListenerTaskGettingResult(taskInfo)) + } + + private[scheduler] def handleJobSubmitted(jobId: Int, + finalRDD: RDD[_], + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], + callSite: CallSite, + listener: JobListener, + properties: Properties): Unit = { + var finalStage: ResultStage = null + try { + // New stage creation may throw an exception if, for example, jobs are run on a + // HadoopRDD whose underlying HDFS files have been deleted. + finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite) + } catch { + case e: BarrierJobSlotsNumberCheckFailed => + // If jobId doesn't exist in the map, Scala coverts its value null to 0: Int automatically. + val numCheckFailures = barrierJobIdToNumTasksCheckFailures.compute(jobId, + (_: Int, value: Int) => value + 1) + + logWarning(s"Barrier stage in job $jobId requires ${e.requiredConcurrentTasks} slots, " + + s"but only ${e.maxConcurrentTasks} are available. " + + s"Will retry up to ${maxFailureNumTasksCheck - numCheckFailures + 1} more times") + + if (numCheckFailures <= maxFailureNumTasksCheck) { + messageScheduler.schedule( + new Runnable { + override def run(): Unit = eventProcessLoop.post(JobSubmitted(jobId, finalRDD, func, + partitions, callSite, listener, properties)) + }, + timeIntervalNumTasksCheck, + TimeUnit.SECONDS + ) + return + } else { + // Job failed, clear internal data. + barrierJobIdToNumTasksCheckFailures.remove(jobId) + listener.jobFailed(e) + return + } + + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return + } + // Job submitted, clear internal data. 
+ barrierJobIdToNumTasksCheckFailures.remove(jobId) + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.setActiveJob(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + } + + private[scheduler] def handleMapStageSubmitted(jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties): Unit = { + // Submitting this map stage might still require the creation of some parent stages, so make + // sure that happens. + var finalStage: ShuffleMapStage = null + try { + // New stage creation may throw an exception if, for example, jobs are run on a + // HadoopRDD whose underlying HDFS files have been deleted. + finalStage = getOrCreateShuffleMapStage(dependency, jobId) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return + } + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got map stage job %s (%s) with %d output partitions".format( + jobId, callSite.shortForm, dependency.rdd.partitions.length)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.addActiveJob(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + // If the whole stage has already finished, tell the listener and remove it + if (finalStage.isAvailable) { + markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency)) + } + } + + /** Submits stage, but first recursively submits any missing parents. */ + private def submitStage(stage: Stage): Unit = { + val jobId = activeJobForStage(stage) + if (jobId.isDefined) { + logDebug(s"submitStage($stage (name=${stage.name};" + + s"jobs=${stage.jobIds.toSeq.sorted.mkString(",")}))") + if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) { + val missing = getMissingParentStages(stage).sortBy(_.id) + logDebug("missing: " + missing) + if (missing.isEmpty) { + logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") + submitMissingTasks(stage, jobId.get) + } else { + for (parent <- missing) { + submitStage(parent) + } + waitingStages += stage + } + } + } else { + abortStage(stage, "No active job for stage " + stage.id, None) + } + } + + /** Called when stage's parents are available and we can now do its task. 
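+ * For a ShuffleMapStage this serializes and broadcasts (rdd, shuffleDep), for a ResultStage
+ * (rdd, func); it then builds one task per missing partition and submits them to the
+ * TaskScheduler as a single TaskSet.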
*/
+ private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {
+ logDebug("submitMissingTasks(" + stage + ")")
+
+ // Before finding the missing partitions, do the intermediate-state cleanup first.
+ // This makes sure that, for a partially completed intermediate stage,
+ // `findMissingPartitions()` returns all partitions every time.
+ stage match {
+ case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
+ mapOutputTracker.unregisterAllMapOutput(sms.shuffleDep.shuffleId)
+ case _ =>
+ }
+
+ // Figure out the indexes of partition ids to compute.
+ val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
+
+ // Use the scheduling pool, job group, description, etc. from an ActiveJob associated
+ // with this Stage
+ val properties = jobIdToActiveJob(jobId).properties
+
+ runningStages += stage
+ // SparkListenerStageSubmitted should be posted before testing whether tasks are
+ // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
+ // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
+ // event.
+ stage match {
+ case s: ShuffleMapStage =>
+ outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1)
+ case s: ResultStage =>
+ outputCommitCoordinator.stageStart(
+ stage = s.id, maxPartitionId = s.rdd.partitions.length - 1)
+ }
+ val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
+ stage match {
+ case s: ShuffleMapStage =>
+ partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
+ case s: ResultStage =>
+ partitionsToCompute.map { id =>
+ val p = s.partitions(id)
+ (id, getPreferredLocs(stage.rdd, p))
+ }.toMap
+ }
+ } catch {
+ case NonFatal(e) =>
+ stage.makeNewStageAttempt(partitionsToCompute.size)
+ listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
+ abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
+ runningStages -= stage
+ return
+ }
+
+ stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
+
+ // If there are tasks to execute, record the submission time of the stage. Otherwise,
+ // post the event without the submission time, which indicates that this stage was
+ // skipped.
+ if (partitionsToCompute.nonEmpty) {
+ stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
+ }
+ listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
+
+ // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
+ // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
+ // the serialized copy of the RDD and for each task we will deserialize it, which means each
+ // task gets a different copy of the RDD. This provides stronger isolation between tasks that
+ // might modify state of objects referenced in their closures. This is necessary in Hadoop
+ // where the JobConf/Configuration object is not thread-safe.
+ var taskBinary: Broadcast[Array[Byte]] = null
+ var partitions: Array[Partition] = null
+ try {
+ // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
+ // For ResultTask, serialize and broadcast (rdd, func).
+ var taskBinaryBytes: Array[Byte] = null
+ // taskBinaryBytes and partitions are both affected by the checkpoint status. We need
+ // this synchronization in case another concurrent job is checkpointing this RDD, so we get a
+ // consistent view of both variables.
+ RDDCheckpointData.synchronized { + taskBinaryBytes = stage match { + case stage: ShuffleMapStage => + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) + case stage: ResultStage => + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + } + + partitions = stage.rdd.partitions + } + + if (taskBinaryBytes.length > TaskSetManager.TASK_SIZE_TO_WARN_KIB * 1024) { + logWarning(s"Broadcasting large task binary with size " + + s"${Utils.bytesToString(taskBinaryBytes.length)}") + } + taskBinary = sc.broadcast(taskBinaryBytes) + } catch { + // In the case of a failure during serialization, abort the stage. + case e: NotSerializableException => + abortStage(stage, "Task not serializable: " + e.toString, Some(e)) + runningStages -= stage + + // Abort execution + return + case e: Throwable => + abortStage(stage, s"Task serialization failed: $e\n${Utils.exceptionString(e)}", Some(e)) + runningStages -= stage + + // Abort execution + return + } + + val tasks: Seq[Task[_]] = try { + val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array() + stage match { + case stage: ShuffleMapStage => + stage.pendingPartitions.clear() + partitionsToCompute.map { id => + val locs = taskIdToLocations(id) + val part = partitions(id) + stage.pendingPartitions += id + new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, + taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), + Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier()) + } + + case stage: ResultStage => + partitionsToCompute.map { id => + val p: Int = stage.partitions(id) + val part = partitions(p) + val locs = taskIdToLocations(id) + new ResultTask(stage.id, stage.latestInfo.attemptNumber, + taskBinary, part, locs, id, properties, serializedTaskMetrics, + Option(jobId), Option(sc.applicationId), sc.applicationAttemptId, + stage.rdd.isBarrier()) + } + } + } catch { + case NonFatal(e) => + abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e)) + runningStages -= stage + return + } + + if (tasks.nonEmpty) { + logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") + taskScheduler.submitTasks(new TaskSet( + tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties)) + } else { + // Because we posted SparkListenerStageSubmitted earlier, we should mark + // the stage as completed here in case there are no tasks to run + markStageAsFinished(stage, None) + + stage match { + case stage: ShuffleMapStage => + logDebug(s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})") + markMapStageJobsAsFinished(stage) + case stage : ResultStage => + logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})") + } + submitWaitingChildStages(stage) + } + } + + /** + * Merge local values from a task into the corresponding accumulators previously registered + * here on the driver. + * + * Although accumulators themselves are not thread-safe, this method is called only from one + * thread, the one that runs the scheduling loop. This means we only handle one task + * completion event at a time so we don't need to worry about locking the accumulators. 
+ * This still doesn't stop the caller from updating the accumulator outside the scheduler, + * but that's not our problem since there's nothing we can do about that. + */ + private def updateAccumulators(event: CompletionEvent): Unit = { + val task = event.task + val stage = stageIdToStage(task.stageId) + + event.accumUpdates.foreach { updates => + val id = updates.id + try { + // Find the corresponding accumulator on the driver and update it + val acc: AccumulatorV2[Any, Any] = AccumulatorContext.get(id) match { + case Some(accum) => accum.asInstanceOf[AccumulatorV2[Any, Any]] + case None => + throw new SparkException(s"attempted to access non-existent accumulator $id") + } + acc.merge(updates.asInstanceOf[AccumulatorV2[Any, Any]]) + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && !updates.isZero) { + stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value)) + event.taskInfo.setAccumulables( + acc.toInfo(Some(updates.value), Some(acc.value)) +: event.taskInfo.accumulables) + } + } catch { + case NonFatal(e) => + // Log the class name to make it easy to find the bad implementation + val accumClassName = AccumulatorContext.get(id) match { + case Some(accum) => accum.getClass.getName + case None => "Unknown class" + } + logError( + s"Failed to update accumulator $id ($accumClassName) for task ${task.partitionId}", + e) + } + } + } + + private def postTaskEnd(event: CompletionEvent): Unit = { + val taskMetrics: TaskMetrics = + if (event.accumUpdates.nonEmpty) { + try { + TaskMetrics.fromAccumulators(event.accumUpdates) + } catch { + case NonFatal(e) => + val taskId = event.taskInfo.taskId + logError(s"Error when attempting to reconstruct metrics for task $taskId", e) + null + } + } else { + null + } + + listenerBus.post(SparkListenerTaskEnd(event.task.stageId, event.task.stageAttemptId, + Utils.getFormattedClassName(event.task), event.reason, event.taskInfo, + new ExecutorMetrics(event.metricPeaks), taskMetrics)) + } + + /** + * Check [[SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL]] in job properties to see if we should + * interrupt running tasks. Returns `false` if the property value is not a boolean value + */ + private def shouldInterruptTaskThread(job: ActiveJob): Boolean = { + if (job.properties == null) { + false + } else { + val shouldInterruptThread = + job.properties.getProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + try { + shouldInterruptThread.toBoolean + } catch { + case e: IllegalArgumentException => + logWarning(s"${SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL} in Job ${job.jobId} " + + s"is invalid: $shouldInterruptThread. Using 'false' instead", e) + false + } + } + } + + /** + * Responds to a task finishing. This is called inside the event loop so it assumes that it can + * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. + */ + private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = { + val task = event.task + val stageId = task.stageId + + outputCommitCoordinator.taskCompleted( + stageId, + task.stageAttemptId, + task.partitionId, + event.taskInfo.attemptNumber, // this is a task attempt number + event.reason) + + if (!stageIdToStage.contains(task.stageId)) { + // The stage may have already finished when we get this event -- eg. maybe it was a + // speculative task. It is important that we send the TaskEnd event in any case, so listeners + // are properly notified and can chose to handle it. 
For instance, some listeners are + // doing their own accounting and if they don't get the task end event they think + // tasks are still running when they really aren't. + postTaskEnd(event) + + // Skip all the actions if the stage has been cancelled. + return + } + + val stage = stageIdToStage(task.stageId) + + // Make sure the task's accumulators are updated before any other processing happens, so that + // we can post a task end event before any jobs or stages are updated. The accumulators are + // only updated in certain cases. + event.reason match { + case Success => + task match { + case rt: ResultTask[_, _] => + val resultStage = stage.asInstanceOf[ResultStage] + resultStage.activeJob match { + case Some(job) => + // Only update the accumulator once for each result task. + if (!job.finished(rt.outputId)) { + updateAccumulators(event) + } + case None => // Ignore update if task's job has finished. + } + case _ => + updateAccumulators(event) + } + case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event) + case _ => + } + postTaskEnd(event) + + event.reason match { + case Success => + // An earlier attempt of a stage (which is zombie) may still have running tasks. If these + // tasks complete, they still count and we can mark the corresponding partitions as + // finished. Here we notify the task scheduler to skip running tasks for the same partition, + // to save resource. + if (task.stageAttemptId < stage.latestInfo.attemptNumber()) { + taskScheduler.notifyPartitionCompletion(stageId, task.partitionId) + } + + task match { + case rt: ResultTask[_, _] => + // Cast to ResultStage here because it's part of the ResultTask + // TODO Refactor this out to a function that accepts a ResultStage + val resultStage = stage.asInstanceOf[ResultStage] + resultStage.activeJob match { + case Some(job) => + if (!job.finished(rt.outputId)) { + job.finished(rt.outputId) = true + job.numFinished += 1 + // If the whole job has finished, remove it + if (job.numFinished == job.numPartitions) { + markStageAsFinished(resultStage) + cancelRunningIndependentStages(job, s"Job ${job.jobId} is finished.") + cleanupStateForJobAndIndependentStages(job) + try { + // killAllTaskAttempts will fail if a SchedulerBackend does not implement + // killTask. + logInfo(s"Job ${job.jobId} is finished. Cancelling potential speculative " + + "or zombie tasks for this job") + // ResultStage is only used by this job. It's safe to kill speculative or + // zombie tasks in this stage. + taskScheduler.killAllTaskAttempts( + stageId, + shouldInterruptTaskThread(job), + reason = "Stage finished") + } catch { + case e: UnsupportedOperationException => + logWarning(s"Could not cancel tasks for stage $stageId", e) + } + listenerBus.post( + SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) + } + + // taskSucceeded runs some user code that might throw an exception. Make sure + // we are resilient against that. + try { + job.listener.taskSucceeded(rt.outputId, event.result) + } catch { + case e: Throwable if !Utils.isFatalError(e) => + // TODO: Perhaps we want to mark the resultStage as failed? 
+ job.listener.jobFailed(new SparkDriverExecutionException(e)) + } + } + case None => + logInfo("Ignoring result from " + rt + " because its job has finished") + } + + case smt: ShuffleMapTask => + val shuffleStage = stage.asInstanceOf[ShuffleMapStage] + shuffleStage.pendingPartitions -= task.partitionId + val status = event.result.asInstanceOf[MapStatus] + val execId = status.location.executorId + logDebug("ShuffleMapTask finished on " + execId) + if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { + logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") + } else { + // The epoch of the task is acceptable (i.e., the task was launched after the most + // recent failure we're aware of for the executor), so mark the task's output as + // available. + mapOutputTracker.registerMapOutput( + shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) + } + + if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { + markStageAsFinished(shuffleStage) + logInfo("looking for newly runnable stages") + logInfo("running: " + runningStages) + logInfo("waiting: " + waitingStages) + logInfo("failed: " + failedStages) + + // This call to increment the epoch may not be strictly necessary, but it is retained + // for now in order to minimize the changes in behavior from an earlier version of the + // code. This existing behavior of always incrementing the epoch following any + // successful shuffle map stage completion may have benefits by causing unneeded + // cached map outputs to be cleaned up earlier on executors. In the future we can + // consider removing this call, but this will require some extra investigation. + // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details. + mapOutputTracker.incrementEpoch() + + clearCacheLocs() + + if (!shuffleStage.isAvailable) { + // Some tasks had failed; let's resubmit this shuffleStage. + // TODO: Lower-level scheduler should also deal with this + logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + + ") because some of its tasks had failed: " + + shuffleStage.findMissingPartitions().mkString(", ")) + submitStage(shuffleStage) + } else { + markMapStageJobsAsFinished(shuffleStage) + submitWaitingChildStages(shuffleStage) + } + } + } + + case FetchFailed(bmAddress, shuffleId, _, mapIndex, _, failureMessage) => + val failedStage = stageIdToStage(task.stageId) + val mapStage = shuffleIdToMapStage(shuffleId) + + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") + } else { + failedStage.failedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + + // It is likely that we receive multiple FetchFailed for a single stage (because we have + // multiple tasks running concurrently on different executors). In that case, it is + // possible the fetch failure has already been handled by the scheduler. 
+ if (runningStages.contains(failedStage)) {
+ logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+ s"due to a fetch failure from $mapStage (${mapStage.name})")
+ markStageAsFinished(failedStage, errorMessage = Some(failureMessage),
+ willRetry = !shouldAbortStage)
+ } else {
+ logDebug(s"Received fetch failure from $task, but it's from $failedStage which is no " +
+ "longer running")
+ }
+
+ if (mapStage.rdd.isBarrier()) {
+ // Mark all the map outputs as broken in the map stage, to ensure all the tasks are retried
+ // on a resubmitted stage attempt.
+ mapOutputTracker.unregisterAllMapOutput(shuffleId)
+ } else if (mapIndex != -1) {
+ // Mark the map whose fetch failed as broken in the map stage
+ mapOutputTracker.unregisterMapOutput(shuffleId, mapIndex, bmAddress)
+ }
+
+ if (failedStage.rdd.isBarrier()) {
+ failedStage match {
+ case failedMapStage: ShuffleMapStage =>
+ // Mark all the map outputs as broken in the map stage, to ensure all the tasks are
+ // retried on a resubmitted stage attempt.
+ mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId)
+
+ case failedResultStage: ResultStage =>
+ // Abort the failed result stage since we may have committed output for some
+ // partitions.
+ val reason = "Could not recover from a failed barrier ResultStage. Most recent " +
+ s"failure reason: $failureMessage"
+ abortStage(failedResultStage, reason, None)
+ }
+ }
+
+ if (shouldAbortStage) {
+ val abortMessage = if (disallowStageRetryForTest) {
+ "Fetch failure will not retry stage due to testing config"
+ } else {
+ s"""$failedStage (${failedStage.name})
+ |has failed the maximum allowable number of
+ |times: $maxConsecutiveStageAttempts.
+ |Most recent failure reason: $failureMessage""".stripMargin.replaceAll("\n", " ")
+ }
+ abortStage(failedStage, abortMessage, None)
+ } else { // update failedStages and make sure a ResubmitFailedStages event is enqueued
+ // TODO: Cancel running tasks in the failed stage -- cf. SPARK-17064
+ val noResubmitEnqueued = !failedStages.contains(failedStage)
+ failedStages += failedStage
+ failedStages += mapStage
+ if (noResubmitEnqueued) {
+ // If the map stage is INDETERMINATE, which means the map tasks may return
+ // different results when re-tried, we need to re-try all the tasks of the failed
+ // stage and its succeeding stages, because the input data will be changed after the
+ // map tasks are re-tried.
+ // Note that, if the map stage is UNORDERED, we are fine. The shuffle partitioner is
+ // guaranteed to be determinate, so the input data of the reducers will not change
+ // even if the map tasks are re-tried.
+ if (mapStage.isIndeterminate) {
+ // It's a little tricky to find all the succeeding stages of `mapStage`, because
+ // each stage only knows its parents, not its children. Here we traverse the stages
+ // from the leaf nodes (the result stages of active jobs), and roll back all the
+ // stages in the stage chains that connect to the `mapStage`. To speed up the stage
+ // traversal, we collect the stages to roll back first. If a stage needs to
+ // roll back, all its succeeding stages need to roll back too.
+ val stagesToRollback = HashSet[Stage](mapStage) + + def collectStagesToRollback(stageChain: List[Stage]): Unit = { + if (stagesToRollback.contains(stageChain.head)) { + stageChain.drop(1).foreach(s => stagesToRollback += s) + } else { + stageChain.head.parents.foreach { s => + collectStagesToRollback(s :: stageChain) + } + } + } + + def generateErrorMessage(stage: Stage): String = { + "A shuffle map stage with indeterminate output was failed and retried. " + + s"However, Spark cannot rollback the $stage to re-process the input data, " + + "and has to fail this job. Please eliminate the indeterminacy by " + + "checkpointing the RDD before repartition and try again." + } + + activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil)) + + // The stages will be rolled back after checking + val rollingBackStages = HashSet[Stage](mapStage) + stagesToRollback.foreach { + case mapStage: ShuffleMapStage => + val numMissingPartitions = mapStage.findMissingPartitions().length + if (numMissingPartitions < mapStage.numTasks) { + if (sc.getConf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { + val reason = "A shuffle map stage with indeterminate output was failed " + + "and retried. However, Spark can only do this while using the new " + + "shuffle block fetching protocol. Please check the config " + + "'spark.shuffle.useOldFetchProtocol', see more detail in " + + "SPARK-27665 and SPARK-25341." + abortStage(mapStage, reason, None) + } else { + rollingBackStages += mapStage + } + } + + case resultStage: ResultStage if resultStage.activeJob.isDefined => + val numMissingPartitions = resultStage.findMissingPartitions().length + if (numMissingPartitions < resultStage.numTasks) { + // TODO: support to rollback result tasks. + abortStage(resultStage, generateErrorMessage(resultStage), None) + } + + case _ => + } + logInfo(s"The shuffle map stage $mapStage with indeterminate output was failed, " + + s"we will roll back and rerun below stages which include itself and all its " + + s"indeterminate child stages: $rollingBackStages") + } + + // We expect one executor failure to trigger many FetchFailures in rapid succession, + // but all of those task failures can typically be handled by a single resubmission of + // the failed stage. We avoid flooding the scheduler's event queue with resubmit + // messages by checking whether a resubmit is already in the event queue for the + // failed stage. If there is already a resubmit enqueued for a different failed + // stage, that event would also be sufficient to handle the current failed stage, but + // producing a resubmit for each failed stage makes debugging and logging a little + // simpler while not producing an overwhelming number of scheduler events. + logInfo( + s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure" + ) + messageScheduler.schedule( + new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, + DAGScheduler.RESUBMIT_TIMEOUT, + TimeUnit.MILLISECONDS + ) + } + } + + // TODO: mark the executor as failed only if there were lots of fetch failures on it + if (bmAddress != null) { + val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled && + unRegisterOutputOnHostOnFetchFailure) { + // We had a fetch failure with the external shuffle service, so we + // assume all shuffle data on the node is bad. 
+ Some(bmAddress.host) + } else { + // Unregister shuffle data just for one executor (we don't have any + // reason to believe shuffle data has been lost for the entire host). + None + } + removeExecutorAndUnregisterOutputs( + execId = bmAddress.executorId, + fileLost = true, + hostToUnregisterOutputs = hostToUnregisterOutputs, + maybeEpoch = Some(task.epoch)) + } + } + + case failure: TaskFailedReason if task.isBarrier => + // Also handle the task failed reasons here. + failure match { + case Resubmitted => + handleResubmittedFailure(task, stage) + + case _ => // Do nothing. + } + + // Always fail the current stage and retry all the tasks when a barrier task fail. + val failedStage = stageIdToStage(task.stageId) + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring task failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") + } else { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " + + "failed.") + val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" + + failure.toErrorString + try { + // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask. + val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) " + + "failed." + taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason) + } catch { + case e: UnsupportedOperationException => + // Cannot continue with barrier stage if failed to cancel zombie barrier tasks. + // TODO SPARK-24877 leave the zombie tasks and ignore their completion events. + logWarning(s"Could not kill all tasks for stage $stageId", e) + abortStage(failedStage, "Could not kill zombie barrier tasks for stage " + + s"$failedStage (${failedStage.name})", Some(e)) + } + markStageAsFinished(failedStage, Some(message)) + + failedStage.failedAttemptIds.add(task.stageAttemptId) + // TODO Refactor the failure handling logic to combine similar code with that of + // FetchFailed. + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + + if (shouldAbortStage) { + val abortMessage = if (disallowStageRetryForTest) { + "Barrier stage will not retry stage due to testing config. Most recent failure " + + s"reason: $message" + } else { + s"""$failedStage (${failedStage.name}) + |has failed the maximum allowable number of + |times: $maxConsecutiveStageAttempts. + |Most recent failure reason: $message + """.stripMargin.replaceAll("\n", " ") + } + abortStage(failedStage, abortMessage, None) + } else { + failedStage match { + case failedMapStage: ShuffleMapStage => + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. + mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) + + case failedResultStage: ResultStage => + // Abort the failed result stage since we may have committed output for some + // partitions. + val reason = "Could not recover from a failed barrier ResultStage. Most recent " + + s"failure reason: $message" + abortStage(failedResultStage, reason, None) + } + // In case multiple task failures triggered for a single stage attempt, ensure we only + // resubmit the failed stage once. 
+ val noResubmitEnqueued = !failedStages.contains(failedStage) + failedStages += failedStage + if (noResubmitEnqueued) { + logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " + + "failure.") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + } + } + + case Resubmitted => + handleResubmittedFailure(task, stage) + + case _: TaskCommitDenied => + // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits + + case _: ExceptionFailure | _: TaskKilled => + // Nothing left to do, already handled above for accumulator updates. + + case TaskResultLost => + // Do nothing here; the TaskScheduler handles these failures and resubmits the task. + + case _: ExecutorLostFailure | UnknownReason => + // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler + // will abort the job. + } + } + + private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = { + logInfo(s"Resubmitted $task, so marking it as still running.") + stage match { + case sms: ShuffleMapStage => + sms.pendingPartitions += task.partitionId + + case _ => + throw new SparkException("TaskSetManagers should only send Resubmitted task " + + "statuses for tasks in ShuffleMapStages.") + } + } + + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } + } + } + + /** + * Responds to an executor being lost. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. + * + * We will also assume that we've lost all shuffle blocks associated with the executor if the + * executor serves its own blocks (i.e., we're not using external shuffle), the entire slave + * is lost (likely including the shuffle service), or a FetchFailed occurred, in which case we + * presume all shuffle data related to this executor to be lost. + * + * Optionally the epoch during which the failure was caught can be passed to avoid allowing + * stray fetch failures from possibly retriggering the detection of a node as lost. + */ + private[scheduler] def handleExecutorLost( + execId: String, + workerLost: Boolean): Unit = { + // if the cluster manager explicitly tells us that the entire worker was lost, then + // we know to unregister shuffle output. (Note that "worker" specifically refers to the process + // from a Standalone cluster, where the shuffle service lives in the Worker.) + val remoteShuffleClass = classOf[RemoteShuffleManager].getName + val remoteShuffleEnabled = env.conf.get("spark.shuffle.manager") == remoteShuffleClass + // If remote shuffle is enabled, shuffle files will be taken care of by remote storage, the + // unregistering and rerun of certain tasks are not needed. 
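+    // A small sketch of the resulting decision, derived from the expression below:
+    //   remoteShuffleEnabled = true                                    -> fileLost = false
+    //   remoteShuffleEnabled = false, workerLost = true                -> fileLost = true
+    //   remoteShuffleEnabled = false, no external shuffle service      -> fileLost = true
+    //   remoteShuffleEnabled = false, worker alive, service available  -> fileLost = false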
+ val fileLost = + !remoteShuffleEnabled && (workerLost || !env.blockManager.externalShuffleServiceEnabled) + removeExecutorAndUnregisterOutputs( + execId = execId, + fileLost = fileLost, + hostToUnregisterOutputs = None, + maybeEpoch = None) + } + + private def removeExecutorAndUnregisterOutputs( + execId: String, + fileLost: Boolean, + hostToUnregisterOutputs: Option[String], + maybeEpoch: Option[Long] = None): Unit = { + val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) + if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { + failedEpoch(execId) = currentEpoch + logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) + blockManagerMaster.removeExecutor(execId) + if (fileLost) { + hostToUnregisterOutputs match { + case Some(host) => + logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch)) + mapOutputTracker.removeOutputsOnHost(host) + case None => + logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) + mapOutputTracker.removeOutputsOnExecutor(execId) + } + clearCacheLocs() + + } else { + logDebug("Additional executor lost message for %s (epoch %d)".format(execId, currentEpoch)) + } + } + } + + /** + * Responds to a worker being removed. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use workerRemoved() to post a loss event from outside. + * + * We will assume that we've lost all shuffle blocks associated with the host if a worker is + * removed, so we will remove them all from MapStatus. + * + * @param workerId identifier of the worker that is removed. + * @param host host of the worker that is removed. + * @param message the reason why the worker is removed. + */ + private[scheduler] def handleWorkerRemoved( + workerId: String, + host: String, + message: String): Unit = { + logInfo("Shuffle files lost for worker %s on host %s".format(workerId, host)) + mapOutputTracker.removeOutputsOnHost(host) + clearCacheLocs() + } + + private[scheduler] def handleExecutorAdded(execId: String, host: String): Unit = { + // remove from failedEpoch(execId) ? + if (failedEpoch.contains(execId)) { + logInfo("Host added was in lost list earlier: " + host) + failedEpoch -= execId + } + } + + private[scheduler] def handleStageCancellation(stageId: Int, reason: Option[String]): Unit = { + stageIdToStage.get(stageId) match { + case Some(stage) => + val jobsThatUseStage: Array[Int] = stage.jobIds.toArray + jobsThatUseStage.foreach { jobId => + val reasonStr = reason match { + case Some(originalReason) => + s"because $originalReason" + case None => + s"because Stage $stageId was cancelled" + } + handleJobCancellation(jobId, Option(reasonStr)) + } + case None => + logInfo("No active jobs to kill for Stage " + stageId) + } + } + + private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]): Unit = { + if (!jobIdToStageIds.contains(jobId)) { + logDebug("Trying to cancel unregistered job " + jobId) + } else { + failJobAndIndependentStages( + jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason.getOrElse(""))) + } + } + + /** + * Marks a stage as finished and removes it from the list of running stages. 
+ */ + private def markStageAsFinished( + stage: Stage, + errorMessage: Option[String] = None, + willRetry: Boolean = false): Unit = { + val serviceTime = stage.latestInfo.submissionTime match { + case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) + case _ => "Unknown" + } + if (errorMessage.isEmpty) { + logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) + stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + + // Clear failure count for this stage, now that it's succeeded. + // We only limit consecutive failures of stage attempts,so that if a stage is + // re-used many times in a long-running job, unrelated failures don't eventually cause the + // stage to be aborted. + stage.clearFailures() + } else { + stage.latestInfo.stageFailed(errorMessage.get) + logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}") + } + + if (!willRetry) { + outputCommitCoordinator.stageEnd(stage.id) + } + listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) + runningStages -= stage + } + + /** + * Aborts all jobs depending on a particular Stage. This is called in response to a task set + * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. + */ + private[scheduler] def abortStage( + failedStage: Stage, + reason: String, + exception: Option[Throwable]): Unit = { + if (!stageIdToStage.contains(failedStage.id)) { + // Skip all the actions if the stage has been removed. + return + } + val dependentJobs: Seq[ActiveJob] = + activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq + failedStage.latestInfo.completionTime = Some(clock.getTimeMillis()) + for (job <- dependentJobs) { + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception) + } + if (dependentJobs.isEmpty) { + logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") + } + } + + /** Cancel all independent, running stages that are only used by this job. */ + private def cancelRunningIndependentStages(job: ActiveJob, reason: String): Boolean = { + var ableToCancelStages = true + val stages = jobIdToStageIds(job.jobId) + if (stages.isEmpty) { + logError(s"No stages registered for job ${job.jobId}") + } + stages.foreach { stageId => + val jobsForStage: Option[HashSet[Int]] = stageIdToStage.get(stageId).map(_.jobIds) + if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) { + logError( + "Job %d not registered for stage %d even though that stage was registered for the job" + .format(job.jobId, stageId)) + } else if (jobsForStage.get.size == 1) { + if (!stageIdToStage.contains(stageId)) { + logError(s"Missing Stage for stage with id $stageId") + } else { + // This stage is only used by the job, so finish the stage if it is running. + val stage = stageIdToStage(stageId) + if (runningStages.contains(stage)) { + try { // cancelTasks will fail if a SchedulerBackend does not implement killTask + taskScheduler.cancelTasks(stageId, shouldInterruptTaskThread(job)) + markStageAsFinished(stage, Some(reason)) + } catch { + case e: UnsupportedOperationException => + logWarning(s"Could not cancel tasks for stage $stageId", e) + ableToCancelStages = false + } + } + } + } + } + ableToCancelStages + } + + /** Fails a job and all stages that are only used by that job, and cleans up relevant state. 
*/ + private def failJobAndIndependentStages( + job: ActiveJob, + failureReason: String, + exception: Option[Throwable] = None): Unit = { + if (cancelRunningIndependentStages(job, failureReason)) { + // SPARK-15783 important to cleanup state first, just for tests where we have some asserts + // against the state. Otherwise we have a *little* bit of flakiness in the tests. + cleanupStateForJobAndIndependentStages(job) + val error = new SparkException(failureReason, exception.orNull) + job.listener.jobFailed(error) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error))) + } + } + + /** Return true if one of stage's ancestors is target. */ + private def stageDependsOn(stage: Stage, target: Stage): Boolean = { + if (stage == target) { + return true + } + val visitedRdds = new HashSet[RDD[_]] + // We are manually maintaining a stack here to prevent StackOverflowError + // caused by recursively visiting + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += stage.rdd + def visit(rdd: RDD[_]): Unit = { + if (!visitedRdds(rdd)) { + visitedRdds += rdd + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_, _, _] => + val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId) + if (!mapStage.isAvailable) { + waitingForVisit.prepend(mapStage.rdd) + } // Otherwise there's no need to follow the dependency back + case narrowDep: NarrowDependency[_] => + waitingForVisit.prepend(narrowDep.rdd) + } + } + } + } + while (waitingForVisit.nonEmpty) { + visit(waitingForVisit.remove(0)) + } + visitedRdds.contains(target.rdd) + } + + /** + * Gets the locality information associated with a partition of a particular RDD. + * + * This method is thread-safe and is called from both DAGScheduler and SparkContext. + * + * @param rdd whose partitions are to be looked at + * @param partition to lookup locality information for + * @return list of machines that are preferred by the partition + */ + private[spark] + def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = { + getPreferredLocsInternal(rdd, partition, new HashSet) + } + + /** + * Recursive implementation for getPreferredLocs. + * + * This method is thread-safe because it only accesses DAGScheduler state through thread-safe + * methods (getCacheLocs()); please be careful when modifying this method, because any new + * DAGScheduler state accessed by it may require additional synchronization. + */ + private def getPreferredLocsInternal( + rdd: RDD[_], + partition: Int, + visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = { + // If the partition has already been visited, no need to re-visit. + // This avoids exponential path exploration. SPARK-695 + if (!visited.add((rdd, partition))) { + // Nil has already been returned for previously visited partitions. + return Nil + } + // If the partition is cached, return the cache locations + val cached = getCacheLocs(rdd)(partition) + if (cached.nonEmpty) { + return cached + } + // If the RDD has some placement preferences (as is the case for input RDDs), get those + val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList + if (rddPrefs.nonEmpty) { + return rddPrefs.map(TaskLocation(_)) + } + + // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. 
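+    // (Illustrative example: for rdd = parent.map(f), partition i has a one-to-one narrow
+    // dependency on parent partition i, so this recursion effectively returns the preferred
+    // locations of that parent partition.)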
+ rdd.dependencies.foreach { + case n: NarrowDependency[_] => + for (inPart <- n.getParents(partition)) { + val locs = getPreferredLocsInternal(n.rdd, inPart, visited) + if (locs != Nil) { + return locs + } + } + + case _ => + } + + Nil + } + + /** Mark a map stage job as finished with the given output stats, and report to its listener. */ + def markMapStageJobAsFinished(job: ActiveJob, stats: MapOutputStatistics): Unit = { + // In map stage jobs, we only create a single "task", which is to finish all of the stage + // (including reusing any previous map outputs, etc); so we just mark task 0 as done + job.finished(0) = true + job.numFinished += 1 + job.listener.taskSucceeded(0, stats) + cleanupStateForJobAndIndependentStages(job) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) + } + + def stop(): Unit = { + messageScheduler.shutdownNow() + eventProcessLoop.stop() + taskScheduler.stop() + } + + eventProcessLoop.start() +} + +private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler) + extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging { + + private[this] val timer = dagScheduler.metricsSource.messageProcessingTimer + + /** + * The main event loop of the DAG scheduler. + */ + override def onReceive(event: DAGSchedulerEvent): Unit = { + val timerContext = timer.time() + try { + doOnReceive(event) + } finally { + timerContext.stop() + } + } + + private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { + case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => + dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) + + case MapStageSubmitted(jobId, dependency, callSite, listener, properties) => + dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties) + + case StageCancelled(stageId, reason) => + dagScheduler.handleStageCancellation(stageId, reason) + + case JobCancelled(jobId, reason) => + dagScheduler.handleJobCancellation(jobId, reason) + + case JobGroupCancelled(groupId) => + dagScheduler.handleJobGroupCancelled(groupId) + + case AllJobsCancelled => + dagScheduler.doCancelAllJobs() + + case ExecutorAdded(execId, host) => + dagScheduler.handleExecutorAdded(execId, host) + + case ExecutorLost(execId, reason) => + val workerLost = reason match { + case SlaveLost(_, true) => true + case _ => false + } + dagScheduler.handleExecutorLost(execId, workerLost) + + case WorkerRemoved(workerId, host, message) => + dagScheduler.handleWorkerRemoved(workerId, host, message) + + case BeginEvent(task, taskInfo) => + dagScheduler.handleBeginEvent(task, taskInfo) + + case SpeculativeTaskSubmitted(task) => + dagScheduler.handleSpeculativeTaskSubmitted(task) + + case GettingResultEvent(taskInfo) => + dagScheduler.handleGetTaskResult(taskInfo) + + case completion: CompletionEvent => + dagScheduler.handleTaskCompletion(completion) + + case TaskSetFailed(taskSet, reason, exception) => + dagScheduler.handleTaskSetFailed(taskSet, reason, exception) + + case ResubmitFailedStages => + dagScheduler.resubmitFailedStages() + } + + override def onError(e: Throwable): Unit = { + logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e) + try { + dagScheduler.doCancelAllJobs() + } catch { + case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) + } + dagScheduler.sc.stopInNewThread() + } + + override def onStop(): Unit = { + // Cancel any active jobs in postStop hook + 
dagScheduler.cleanUpAfterSchedulerStop() + } +} + +private[spark] object DAGScheduler { + // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; + // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one + // as more failure events come in + val RESUBMIT_TIMEOUT = 200 + + // Number of consecutive stage attempts allowed before a stage is aborted + val DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4 +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/HadoopFileSegmentManagedBuffer.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/HadoopFileSegmentManagedBuffer.scala new file mode 100644 index 000000000..d1c1147cf --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/HadoopFileSegmentManagedBuffer.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +import java.io.{ByteArrayInputStream, InputStream, IOException} +import java.nio.ByteBuffer +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable + +import io.netty.buffer.{ByteBuf, Unpooled} +import org.apache.hadoop.fs.{FSDataInputStream, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.protocol.{Encodable, Encoders} +import org.apache.spark.network.util.{JavaUtils, LimitedInputStream} + +/** + * Something like [[org.apache.spark.network.buffer.FileSegmentManagedBuffer]], instead we only + * need createInputStream function, so we don't need a TransportConf field, which is intended to + * be used in other functions + */ +private[spark] class HadoopFileSegmentManagedBuffer( + val file: Path, val offset: Long, val length: Long, var eagerRequirement: Boolean = false) + extends ManagedBuffer with Logging { + + import HadoopFileSegmentManagedBuffer._ + + private lazy val dataStream: InputStream = { + if (length == 0) { + new ByteArrayInputStream(new Array[Byte](0)) + } else { + var is: FSDataInputStream = null + is = fs.open(file) + is.seek(offset) + new LimitedInputStream(is, length) + } + } + + private lazy val dataInByteArray: Array[Byte] = { + if (length == 0) { + Array.empty[Byte] + } else { + var is: FSDataInputStream = null + try { + is = { + if (reuseFileHandle) { + val pathToHandleMap = handleCache.get(Thread.currentThread().getId) + if (pathToHandleMap == null) { + val res = fs.open(file) + handleCache.put(Thread.currentThread().getId, + new mutable.HashMap[Path, FSDataInputStream]() += (file -> res)) + res + } else { + pathToHandleMap.getOrElseUpdate(file, fs.open(file)) + } + } else { + fs.open(file) + } + } + is.seek(offset) + 
val array = new Array[Byte](length.toInt) + is.readFully(array) + array + } catch { + case e: IOException => + var errorMessage = "Error in reading " + this + if (is != null) { + val size = fs.getFileStatus(file).getLen + errorMessage += " (actual file length " + size + ")" + } + throw new IOException(errorMessage, e) + } finally { + if (!reuseFileHandle) { + // Immediately close it if disabled file handle reuse + JavaUtils.closeQuietly(is) + } + } + } + } + + private[spark] def prepareData(eagerRequirement: Boolean): Unit = { + this.eagerRequirement = eagerRequirement + if (! eagerRequirement) { + dataInByteArray + } + } + + override def size(): Long = length + + override def createInputStream(): InputStream = if (eagerRequirement) { + logInfo("Eagerly requiring this data input stream") + dataStream + } else { + new ByteArrayInputStream(dataInByteArray) + } + + override def equals(obj: Any): Boolean = { + if (! obj.isInstanceOf[HadoopFileSegmentManagedBuffer]) { + false + } else { + val buffer = obj.asInstanceOf[HadoopFileSegmentManagedBuffer] + this.file == buffer.file && this.offset == buffer.offset && this.length == buffer.length + } + } + + override def hashCode(): Int = super.hashCode() + + override def retain(): ManagedBuffer = this + + override def release(): ManagedBuffer = this + + override def nioByteBuffer(): ByteBuffer = throw new UnsupportedOperationException + + override def convertToNetty(): AnyRef = throw new UnsupportedOperationException +} + +private[remote] object HadoopFileSegmentManagedBuffer { + private val fs = RemoteShuffleManager.getFileSystem + + private[remote] lazy val handleCache = + new ConcurrentHashMap[Long, mutable.HashMap[Path, FSDataInputStream]]() + private val reuseFileHandle = + RemoteShuffleManager.getConf.get(RemoteShuffleConf.REUSE_FILE_HANDLE) +} + +/** + * This is an RPC message encapsulating HadoopFileSegmentManagedBuffers. Slightly different with + * the OpenBlocks message, this doesn't transfer block stream between executors through netty, but + * only returns file segment ranges(offsets and lengths). Due to in remote shuffle, there is a + * globally-accessible remote storage, like HDFS or DAOS. + */ +class MessageForHadoopManagedBuffers( + val buffers: Array[(String, HadoopFileSegmentManagedBuffer)]) extends Encodable { + + override def encodedLength(): Int = { + var sum = 0 + // the length of count: Int + sum += 4 + for ((blockId, hadoopFileSegment) <- buffers) { + sum += Encoders.Strings.encodedLength(blockId) + sum += Encoders.Strings.encodedLength(hadoopFileSegment.file.toUri.toString) + sum += 8 + sum += 8 + } + sum + } + + override def encode(buf: ByteBuf): Unit = { + val count = buffers.length + // To differentiate from other BlockTransferMessage + buf.writeByte(MessageForHadoopManagedBuffers.MAGIC_CODE) + buf.writeInt(count) + for ((blockId, hadoopFileSegment) <- buffers) { + Encoders.Strings.encode(buf, blockId) + Encoders.Strings.encode(buf, hadoopFileSegment.file.toUri.toString) + buf.writeLong(hadoopFileSegment.offset) + buf.writeLong(hadoopFileSegment.length) + } + } + + // As opposed to fromByteBuffer + def toByteBuffer: ByteBuf = { + val buf = Unpooled.buffer(encodedLength) + encode(buf) + buf + } +} + +object MessageForHadoopManagedBuffers { + + // To differentiate from other BlockTransferMessage + val MAGIC_CODE = -99 + + // Decode + def fromByteBuffer(buf: ByteBuf): MessageForHadoopManagedBuffers = { + val magic = buf.readByte() + assert(magic == MAGIC_CODE, "This is not a MessageForHadoopManagedBuffers! 
: (") + val count = buf.readInt() + val buffers = for (i <- 0 until count) yield { + val blockId = Encoders.Strings.decode(buf) + val path = new Path(Encoders.Strings.decode(buf)) + val offset = buf.readLong() + val length = buf.readLong() + (blockId, new HadoopFileSegmentManagedBuffer(path, offset, length)) + } + new MessageForHadoopManagedBuffers(buffers.toArray) + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteAggregator.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteAggregator.scala new file mode 100644 index 000000000..8a5607454 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteAggregator.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +import org.apache.spark.{Aggregator, TaskContext} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.collection.RemoteAppendOnlyMap + +/** + * NOTE: + * + * :: DeveloperApi :: + * A set of functions used to aggregate data. + * + * @param createCombiner function to create the initial value of the aggregation. + * @param mergeValue function to merge a new value into the aggregation result. + * @param mergeCombiners function to merge outputs from multiple mergeValue function. + */ +@DeveloperApi +class RemoteAggregator[K, V, C](agg: Aggregator[K, V, C], resolver: RemoteShuffleBlockResolver) + extends Aggregator[K, V, C](agg.createCombiner, agg.mergeValue, agg.mergeCombiners) { + + override def combineValuesByKey( + iter: Iterator[_ <: Product2[K, V]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new RemoteAppendOnlyMap[K, V, C]( + createCombiner, mergeValue, mergeCombiners, resolver = resolver) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator + } + + override def combineCombinersByKey( + iter: Iterator[_ <: Product2[K, C]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new RemoteAppendOnlyMap[K, C, C]( + identity, mergeCombiners, mergeCombiners, resolver = resolver) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator + } + + /** Update task metrics after populating the external map. 
*/ + private def updateMetrics(context: TaskContext, map: RemoteAppendOnlyMap[_, _, _]): Unit = { + Option(context).foreach { c => + c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + c.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes) + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteBlockObjectWriter.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteBlockObjectWriter.scala new file mode 100644 index 000000000..a764b9d53 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteBlockObjectWriter.scala @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +import java.io.{BufferedOutputStream, OutputStream} + +import org.apache.hadoop.fs.{FSDataOutputStream, Path} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter +import org.apache.spark.storage.{BlockId, TimeTrackingOutputStream} +import org.apache.spark.util.Utils + +/** + * References a particular segment of a Hadoop file (potentially the entire file), + * based off an offset and a length. + */ +private[spark] class HadoopFileSegment(val file: Path, val offset: Long, val length: Long) { + require(offset >= 0, s"File segment offset cannot be negative (got $offset)") + require(length >= 0, s"File segment length cannot be negative (got $length)") + override def toString: String = { + "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) + } +} + +/** + * NOTE: Most of the code is copied from DiskBlockObjectWriter, as the only difference is that this + * class performs a block object writing to remote storage + * + * A class for writing JVM objects directly to a file on remote storage. This class allows data to + * be appended to an existing block. For efficiency, it retains the underlying file channel across + * multiple commits. This channel is kept open until close() is called. In case of faults, + * callers should instead close with revertPartialWritesAndClose() to atomically revert the + * uncommitted partial writes. + * + * This class does not support concurrent writes. Also, once the writer has been opened it + * cannot be reopened again. 
+ */ +private[spark] class RemoteBlockObjectWriter( + val file: Path, + serializerManager: SerializerManager, + serializerInstance: SerializerInstance, + bufferSize: Int, + syncWrites: Boolean, + // These write metrics concurrently shared with other active DiskBlockObjectWriters who + // are themselves performing writes. All updates must be relative. + writeMetrics: ShuffleWriteMetricsReporter, + val blockId: BlockId = null) + extends OutputStream + with Logging { + + private lazy val fs = RemoteShuffleManager.getFileSystem + + /** + * Guards against close calls, e.g. from a wrapping stream. + * Call manualClose to close the stream that was extended by this trait. + * Commit uses this trait to close object streams without paying the + * cost of closing and opening the underlying file. + */ + private trait ManualCloseOutputStream extends OutputStream { + abstract override def close(): Unit = { + flush() + } + + def manualClose(): Unit = { + super.close() + } + } + + // No need to use a channel, instead call FSDataOutputStream.getPos + private var mcs: ManualCloseOutputStream = null + private var bs: OutputStream = null + private var fsdos: FSDataOutputStream = null + private var ts: TimeTrackingOutputStream = null + private var objOut: SerializationStream = null + private var initialized = false + private var streamOpen = false + private var hasBeenClosed = false + + /** + * Cursors used to represent positions in the file. + * + * xxxxxxxxxx|----------|-----| + * ^ ^ ^ + * | | channel.position() + * | reportedPosition + * committedPosition + * + * reportedPosition: Position at the time of the last update to the write metrics. + * committedPosition: Offset after last committed write. + * -----: Current writes to the underlying file. + * xxxxx: Committed contents of the file. + */ + private var committedPosition = 0L + private var reportedPosition = committedPosition + + /** + * Keep track of number of records written and also use this to periodically + * output bytes written since the latter is expensive to do for each record. + * And we reset it after every commitAndGet called. + */ + private var numRecordsWritten = 0 + + private def initialize(): Unit = { + fsdos = fs.create(file) + ts = new TimeTrackingOutputStream(writeMetrics, fsdos) + class ManualCloseBufferedOutputStream + extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream + mcs = new ManualCloseBufferedOutputStream + } + + def open(): RemoteBlockObjectWriter = { + if (hasBeenClosed) { + throw new IllegalStateException("Writer already closed. Cannot be reopened.") + } + if (!initialized) { + initialize() + initialized = true + } + + bs = serializerManager.wrapStream(blockId, mcs) + objOut = serializerInstance.serializeStream(bs) + streamOpen = true + this + } + + /** + * Close and cleanup all resources. + * Should call after committing or reverting partial writes. + */ + private def closeResources(): Unit = { + if (initialized) { + Utils.tryWithSafeFinally { + mcs.manualClose() + } { + mcs = null + bs = null + fsdos = null + ts = null + objOut = null + initialized = false + streamOpen = false + hasBeenClosed = true + } + } + } + + /** + * Commits any remaining partial writes and closes resources. + */ + override def close() { + if (initialized) { + Utils.tryWithSafeFinally { + commitAndGet() + } { + closeResources() + } + } + } + + /** + * Flush the partial writes and commit them as a single atomic block. + * A commit may write additional bytes to frame the atomic block. 
+ * + * @return file segment with previous offset and length committed on this call. + */ + def commitAndGet(): HadoopFileSegment = { + if (streamOpen) { + // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the + // serializer stream and the lower level stream. + objOut.flush() + bs.flush() + objOut.close() + streamOpen = false + + /* NOTE by Chenzhao: It doesn't work for local file system */ + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + val start = System.nanoTime() + fsdos.hsync() + writeMetrics.incWriteTime(System.nanoTime() - start) + } + + val pos = fsdos.getPos + val fileSegment = new HadoopFileSegment(file, committedPosition, pos - committedPosition) + committedPosition = pos + // In certain compression codecs, more bytes are written after streams are closed + writeMetrics.incBytesWritten(committedPosition - reportedPosition) + reportedPosition = committedPosition + numRecordsWritten = 0 + fileSegment + } else { + new HadoopFileSegment(file, committedPosition, 0) + } + } + + + /** + * Reverts writes that haven't been committed yet. Callers should invoke this function + * when there are runtime exceptions. This method will not throw, though it may be + * unsuccessful in truncating written data. + * + * @return the file that this DiskBlockObjectWriter wrote to. + */ + def revertPartialWritesAndClose(): Path = { + // Discard current writes. We do this by flushing the outstanding writes and then + // truncating the file to its initial position. + Utils.tryWithSafeFinally { + if (initialized) { + writeMetrics.decBytesWritten(reportedPosition - committedPosition) + writeMetrics.decRecordsWritten(numRecordsWritten) + streamOpen = false + closeResources() + } + } { + try { + close() + fs.truncate(file, committedPosition) + } catch { + case _: UnsupportedOperationException => logInfo("This filesystem doesn't support" + + "truncate") + case e: Exception => + logError("Uncaught exception while reverting partial writes to file " + file, e) + } + } + file + } + + /** + * Writes a key-value pair. + */ + def write(key: Any, value: Any) { + if (!streamOpen) { + open() + } + + objOut.writeKey(key) + objOut.writeValue(value) + recordWritten() + } + + override def write(b: Int): Unit = throw new UnsupportedOperationException() + + override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = { + if (!streamOpen) { + open() + } + + bs.write(kvBytes, offs, len) + } + + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. + */ + def recordWritten(): Unit = { + numRecordsWritten += 1 + writeMetrics.incRecordsWritten(1) + + if (numRecordsWritten % 16384 == 0) { + updateBytesWritten() + } + } + + /** + * Report the number of bytes written in this writer's shuffle write metrics. + * Note that this is only valid before the underlying streams are closed. 
+ */ + private def updateBytesWritten() { + val pos = fsdos.getPos + writeMetrics.incBytesWritten(pos - reportedPosition) + reportedPosition = pos + } + + // For testing + private[spark] override def flush() { + objOut.flush() + bs.flush() + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockIterator.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockIterator.scala new file mode 100644 index 000000000..f8d8f8b20 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockIterator.scala @@ -0,0 +1,550 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +import java.io.{IOException, InputStream} +import java.nio.ByteBuffer +import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} +import java.{lang, util} + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS, REDUCER_MAX_REQS_IN_FLIGHT, REDUCER_MAX_SIZE_IN_FLIGHT} +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.shuffle._ +import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} +import org.apache.spark.storage.{BlockException, BlockId, BlockManagerId, ShuffleBlockId} +import org.apache.spark.util.io.ChunkedByteBufferOutputStream +import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. + * + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context [[TaskContext]], used for metrics update + * @param shuffleClient [[ShuffleClient]] for fetching remote blocks + * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. + * For each block we also require the size (in bytes as a long field) in + * order to throttle the memory usage. Note that zero-sized blocks are + * already excluded, which happened in + * [[org.apache.spark.MapOutputTracker]]. 
+ * @param streamWrapper A function to wrap the returned input stream. + * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point + * for a given remote host:port. + * @param detectCorrupt whether to detect any corruption in fetched blocks. + */ +private[spark] +final class RemoteShuffleBlockIterator( + context: TaskContext, + shuffleClient: BlockStoreClient, + resolver: RemoteShuffleBlockResolver, + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + detectCorrupt: Boolean, + readMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean) + + extends Iterator[(BlockId, InputStream)] with Logging { + + import RemoteShuffleBlockIterator._ + + private val indexCacheEnabled = resolver.indexCacheEnabled + + /** + * Total number of blocks to fetch. This should be equal to the total number of blocks + * in [[blocksByAddress]] because we already filter out zero-sized blocks in [[blocksByAddress]]. + * + * This should equal localBlocks.size + remoteBlocks.size. + */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTime = System.currentTimeMillis + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[RemoteFetchResult] + + /** + * Current [[RemoteFetchResult]] being processed. We track this so we can release + * the current buffer in case of a runtime exception when processing the current buffer. + */ + @volatile private[this] var currentResult: SuccessRemoteFetchResult = null + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + * the number of bytes in flight is limited to maxBytesInFlight. + */ + private[this] val fetchRequests = new Queue[RemoteFetchRequest] + + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[RemoteFetchRequest]]() + + /** Current bytes in flight from our requests */ + private[this] var bytesInFlight = 0L + + /** Current number of requests in flight */ + private[this] var reqsInFlight = 0 + + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry + * at most once for those corrupted blocks. + */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. 
+ */ + @GuardedBy("this") + private[this] var isZombie = false + + initialize() + + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + synchronized { + isZombie = true + } + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessRemoteFetchResult(_, _, _, _, buf, _) => + readMetrics.incRemoteBytesRead(buf.size) + readMetrics.incRemoteBlocksFetched(1) + case _ => + } + } + } + + private[this] def sendRequest(req: RemoteFetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + bytesInFlight += req.size + reqsInFlight += 1 + + // so we can look up the info of each blockID + val infoMap = req.blocks.map { case (blockId, size, mapIndex) => (blockId.toString, (size, mapIndex))}.toMap + val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val blockIds = req.blocks.map(_._1) + val address = req.address + + val blockFetchingListener = new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + val casted = buf.asInstanceOf[HadoopFileSegmentManagedBuffer] + val res = Future { + // Another possibility: casted.prepareData(results.size == 0) to only eagerly require a + // Shuffle Block when the results queue is empty + casted.prepareData(eagerRequirement = eagerRequirement) + } (RemoteShuffleBlockIterator.executionContext) + res.onComplete { + case Success(_) => + RemoteShuffleBlockIterator.this.synchronized { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + if (!isZombie) { + remainingBlocks -= blockId + results.put(SuccessRemoteFetchResult( + BlockId(blockId), + infoMap(blockId)._2, + address, + infoMap(blockId)._1, + buf, + remainingBlocks.isEmpty)) + logDebug("remainingBlocks: " + remainingBlocks) + } + } + case Failure(e) => + results.put(FailureRemoteFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e)) + + } (RemoteShuffleBlockIterator.executionContext) + } + + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { + logError(s"Failed to get block(s) ", e) + results.put(FailureRemoteFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e)) + } + } + if (indexCacheEnabled) { + shuffleClient.fetchBlocks( + address.host, address.port, address.executorId, blockIds.map(_.toString()).toArray, + blockFetchingListener, null) + } else { + fetchBlocks(blockIds.toArray, blockFetchingListener) + } + } + + private def fetchBlocks( + blockIds: Array[BlockId], + listener: BlockFetchingListener) = { + for (blockId <- blockIds) { + // Note by Chenzhao: Can be optimized by reading consecutive blocks + try { + val buf = resolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + listener.onBlockFetchSuccess(blockId.toString(), buf) + } catch { + case e: Exception => listener.onBlockFetchFailure(blockId.toString(), e) + } + } + } + + // For remote shuffling, all blocks are remote, so this actually resembles RemoteFetchRequests + private[this] def splitLocalRemoteBlocks(): ArrayBuffer[RemoteFetchRequest] = { + + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. 
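+    // (Worked example, assuming the default spark.reducer.maxSizeInFlight of 48m:
+    // targetRequestSize = 48 MB / 5 ≈ 9.6 MB, so roughly five such requests can be in
+    // flight before the maxBytesInFlight cap throttles further fetches.)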
+    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
+    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize +
+      ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)
+
+    // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+    // at most maxBytesInFlight in order to limit the amount of data in flight.
+    val remoteRequests = new ArrayBuffer[RemoteFetchRequest]
+
+    for ((address, blockInfos) <- blocksByAddress) {
+      val iterator = blockInfos.iterator
+      var curRequestSize = 0L
+      var curBlocks = new ArrayBuffer[(BlockId, Long, Int)]
+      while (iterator.hasNext) {
+        val (blockId, size, mapIndex) = iterator.next()
+        if (size < 0) {
+          throw new BlockException(blockId, "Negative block size " + size)
+        } else if (size == 0) {
+          throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
+        } else {
+          curBlocks += ((blockId, size, mapIndex))
+          numBlocksToFetch += 1
+          curRequestSize += size
+        }
+        // We only care about the number of requests, not the total content size of the blocks,
+        // because during this fetch process we only get a range (offset and length) for each
+        // block. The block content is not transferred through Netty; it is read from a
+        // globally-accessible Hadoop-compatible file system.
+        if (curRequestSize >= targetRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress) {
+          // Add this FetchRequest
+          remoteRequests += new RemoteFetchRequest(address, curBlocks)
+          logDebug(s"Creating fetch request of $curRequestSize at $address " +
+            s"with ${curBlocks.size} blocks")
+          curBlocks = new ArrayBuffer[(BlockId, Long, Int)]
+          curRequestSize = 0
+        }
+      }
+      // Add in the final request
+      if (curBlocks.nonEmpty) {
+        remoteRequests += new RemoteFetchRequest(address, curBlocks)
+      }
+    }
+    logInfo(s"Getting $numBlocksToFetch non-empty blocks, " +
+      s"number of remote requests: ${remoteRequests.size}")
+    remoteRequests
+  }
+
+  private[this] def initialize(): Unit = {
+    // Add a task completion callback (called in both success case and failure case) to cleanup.
+    context.addTaskCompletionListener[Unit](_ => cleanup())
+
+    // Split local and remote blocks. Actually it just assembles remote fetch requests, because
+    // under remote shuffle all blocks are remote.
+    val remoteRequests = splitLocalRemoteBlocks()
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(remoteRequests)
+    assert ((0 == reqsInFlight) == (0 == bytesInFlight),
+      "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
+      ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
+
+    // Send out initial requests for blocks, up to our maxBytesInFlight
+    fetchUpToMaxBytes()
+
+    val numFetches = remoteRequests.size - fetchRequests.size
+
+  }
+
+  override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
+
+  /**
+   * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
+   * underlying each InputStream will be freed by the cleanup() method registered with the
+   * TaskCompletionListener. However, callers should close() these InputStreams
+   * as soon as they are no longer needed, in order to release memory as early as possible.
+   *
+   * Throws a FetchFailedException if the next block could not be fetched.
+   */
+  override def next(): (BlockId, InputStream) = {
+    if (!hasNext) {
+      throw new NoSuchElementException
+    }
+
+    numBlocksProcessed += 1
+
+    var result: RemoteFetchResult = null
+    var input: InputStream = null
+    // Take the next fetched result and try to decompress it to detect data corruption,
+    // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
+    // is also corrupt, so the previous stage could be retried.
+    // For local shuffle block, throw FailureFetchResult for the first IOException.
+    while (result == null) {
+      val startFetchWait = System.nanoTime()
+      result = results.take()
+      val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait)
+      readMetrics.incFetchWaitTime(fetchWaitTime)
+
+      result match {
+        case r @SuccessRemoteFetchResult(
+            blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
+          numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+          readMetrics.incRemoteBytesRead(buf.size())
+          readMetrics.incRemoteBlocksFetched(1)
+          bytesInFlight -= size
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          if (buf.size == 0) {
+            // We will never legitimately receive a zero-size block. All blocks with zero records
+            // have zero size and all zero-size blocks have no records (and hence should never
+            // have been requested in the first place). This statement relies on behaviors of the
+            // shuffle writers, which are guaranteed by the following test cases:
+            //
+            // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions"
+            // - UnsafeShuffleWriterSuite: "writeEmptyIterator"
+            // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing"
+            //
+            // There is not an explicit test for SortShuffleWriter but the underlying APIs that it
+            // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter,
+            // which returns a zero-size segment from commitAndGet() in case no records were
+            // written since the last call).
+            val msg = s"Received a zero-size buffer for block $blockId from $address " +
+              s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
+            throwFetchFailedException(blockId, mapIndex, address, new IOException(msg))
+          }
+          val in = try {
+            buf.createInputStream()
+          } catch {
+            case e: IOException =>
+              // Actually here we know the buf is a HadoopFileSegmentManagedBuffer
+              logError("Failed to create input stream from block backed by Hadoop file segment", e)
+              throwFetchFailedException(blockId, mapIndex, address, e)
+          }
+          var isStreamCopied: Boolean = false
+          // Detect ShuffleBlock corruption
+          try {
+            input = streamWrapper(blockId, in)
+            // Only copy the stream if it's wrapped by compression or encryption, and the size of
+            // the block is small (the decompressed block is smaller than maxBytesInFlight)
+            if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+              isStreamCopied = true
+              val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+              // Decompress the whole block at once to detect any corruption, which could increase
+              // the memory usage and potentially increase the chance of OOM.
+              // TODO: manage the memory used here, and spill it into disk in case of OOM.
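+              // (Illustrative note: given the condition above, a wrapped block is only eagerly
+              // decompressed here when its reported size is below maxBytesInFlight / 3, e.g.
+              // below 16 MB for a 48 MB maxBytesInFlight; larger blocks are passed through
+              // and any corruption surfaces later, when the stream is actually consumed.)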
+ Utils.copyStream(input, out, closeStreams = true) + input = out.toChunkedByteBuffer.toInputStream(dispose = true) + } + } catch { + case e: IOException => + if (corruptedBlocks.contains(blockId)) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else { + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += RemoteFetchRequest(address, Array((blockId, size, mapIndex))) + result = null + } + } finally { + if (isStreamCopied) { + in.close() + } + } + case FailureRemoteFetchResult(blockId, mapIndex, address, e) => + throwFetchFailedException(blockId, mapIndex, address, e) + } + + // Send fetch requests up to maxBytesInFlight + fetchUpToMaxBytes() + } + + currentResult = result.asInstanceOf[SuccessRemoteFetchResult] + (currentResult.blockId, input) + } + + private def fetchUpToMaxBytes(): Unit = { + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host + // immediately, defer the request until the next time it can be processed. + + // Process any outstanding deferred fetch requests if possible. + if (deferredFetchRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { + while (isRemoteBlockFetchable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) { + val request = defReqQueue.dequeue() + logDebug(s"Processing deferred fetch request for $remoteAddress with " + + s"${request.blocks.length} blocks") + send(remoteAddress, request) + if (defReqQueue.isEmpty) { + deferredFetchRequests -= remoteAddress + } + } + } + } + + // Process any regular fetch requests if possible. + while (isRemoteBlockFetchable(fetchRequests)) { + val request = fetchRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") + val defReqQueue = + deferredFetchRequests.getOrElse(remoteAddress, new Queue[RemoteFetchRequest]()) + defReqQueue.enqueue(request) + deferredFetchRequests(remoteAddress) = defReqQueue + } else { + send(remoteAddress, request) + } + } + + } + + private def send(remoteAddress: BlockManagerId, request: RemoteFetchRequest): Unit = { + sendRequest(request) + numBlocksInFlightPerAddress(remoteAddress) = + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size + } + + private def isRemoteBlockFetchable(fetchReqQueue: Queue[RemoteFetchRequest]): Boolean = { + fetchReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a + // given remote address. + private def isRemoteAddressMaxedOut( + remoteAddress: BlockManagerId, request: RemoteFetchRequest): Boolean = { + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > + maxBlocksInFlightPerAddress + } + + + private def throwFetchFailedException( + blockId: BlockId, mapIndex: Int, address: BlockManagerId, e: Throwable) = { + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + // Suppress the BlockManagerId to only retry the failed map tasks, instead of all map tasks + // that shared the same executor with the failed map tasks. 
This is more reasonable in + // remote shuffle + throw new FetchFailedException( + null, shufId.toInt, mapId, mapIndex, reduceId, e) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block", e) + } + } +} + +private[remote] object RemoteShuffleBlockIterator { + + private val maxConcurrentFetches = + RemoteShuffleManager.getConf.get(RemoteShuffleConf.NUM_CONCURRENT_FETCH) + private val eagerRequirement = + RemoteShuffleManager.getConf.get(RemoteShuffleConf.DATA_FETCH_EAGER_REQUIREMENT) + + private val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("shuffle-data-fetching", maxConcurrentFetches)) + + /** + * A request to fetch blocks from a remote BlockManager. + * @param address remote BlockManager to fetch from. + * @param blocks Sequence of tuple, where the first element is the block id, + * and the second element is the estimated size, used to calculate bytesInFlight. + */ + case class RemoteFetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long, Int)]) { + val size = blocks.map(_._2).sum + } + + /** + * Result of a fetch from a remote block. + */ + private[remote] sealed trait RemoteFetchResult { + val blockId: BlockId + } + + /** + * Result of a fetch from a remote block successfully. + * @param blockId block id + * @param buf `ManagedBuffer` for the content. + */ + private[remote] case class SuccessRemoteFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer, + isNetworkReqDone: Boolean) extends RemoteFetchResult { + require(buf != null) + require(size >= 0) + } + + /** + * Result of a fetch from a remote block unsuccessfully. + * @param blockId block id + * @param e the failure exception + */ + private[remote] case class FailureRemoteFetchResult( + blockId: BlockId, + mapIndex: Int, + address: BlockManagerId, + e: Throwable) extends RemoteFetchResult +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockResolver.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockResolver.scala new file mode 100644 index 000000000..13d10c437 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockResolver.scala @@ -0,0 +1,388 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.remote + +import java.io._ +import java.nio.{ByteBuffer, LongBuffer} +import java.util.UUID +import java.util.function.Consumer + +import scala.collection.mutable +import com.google.common.cache.{CacheBuilder, CacheLoader, Weigher} +import org.apache.hadoop.fs.{FSDataInputStream, Path} +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.BLOCK_MANAGER_PORT +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.netty.RemoteShuffleTransferService +import org.apache.spark.network.shuffle.ShuffleIndexRecord +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.shuffle.ShuffleBlockResolver +import org.apache.spark.storage.{BlockId, ShuffleBlockId, TempLocalBlockId, TempShuffleBlockId} +import org.apache.spark.util.Utils + +/** + * Create and maintain the shuffle blocks' mapping between logic block and physical file location. + * It also manages the resource cleaning and temporary files creation, + * like a [[org.apache.spark.shuffle.IndexShuffleBlockResolver]] ++ + * [[org.apache.spark.storage.DiskBlockManager]] + * + */ +class RemoteShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver with Logging { + + private val master = conf.get(RemoteShuffleConf.STORAGE_MASTER_URI) + private val rootDir = conf.get(RemoteShuffleConf.SHUFFLE_FILES_ROOT_DIRECTORY) + // 1. Use lazy evaluation due to at the time this class(and its fields) is initialized, + // SparkEnv._conf is not yet set + // 2. conf.getAppId may not always work, because during unit tests we may just new a Resolver + // instead of getting one from the ShuffleManager referenced by SparkContext + private lazy val applicationId = + if (Utils.isTesting) s"test${UUID.randomUUID()}" else conf.getAppId + private def dirPrefix = s"$master/$rootDir/$applicationId" + + // This referenced is shared for all the I/Os with shuffling storage system + lazy val fs = new Path(dirPrefix).getFileSystem(RemoteShuffleManager.active.getHadoopConf) + + private[remote] lazy val remoteShuffleTransferService: BlockTransferService = { + val env = SparkEnv.get + new RemoteShuffleTransferService( + conf, + env.securityManager, + env.blockManager.blockManagerId.host, + env.blockManager.blockManagerId.host, + env.conf.get(BLOCK_MANAGER_PORT), + conf.get(RemoteShuffleConf.NUM_TRANSFER_SERVICE_THREADS)) + } + private[remote] lazy val shuffleServerId = { + if (indexCacheEnabled) { + remoteShuffleTransferService.asInstanceOf[RemoteShuffleTransferService].getShuffleServerId + } else { + SparkEnv.get.blockManager.blockManagerId + } + } + + private[remote] val indexCacheEnabled: Boolean = { + val size = JavaUtils.byteStringAsBytes(conf.get(RemoteShuffleConf.REMOTE_INDEX_CACHE_SIZE)) + val dynamicAllocationEnabled = + conf.getBoolean("spark.dynamicAllocation.enabled", false) + (size > 0) && { + if (dynamicAllocationEnabled) { + logWarning("Index cache is not enabled due to dynamic allocation is enabled, the" + + " cache in executors may get removed. 
") + } + !dynamicAllocationEnabled + } + } + + if (indexCacheEnabled) { + logWarning("Fetching index files from the cache of executors which wrote them") + } + + // These 3 fields will only be initialized when index cache enabled + lazy val indexCacheSize: String = + conf.get("spark.shuffle.remote.index.cache.size", "30m") + + lazy val indexCacheLoader: CacheLoader[Path, RemoteShuffleIndexInfo] = + new CacheLoader[Path, RemoteShuffleIndexInfo]() { + override def load(file: Path) = new RemoteShuffleIndexInfo(file) + } + + lazy val shuffleIndexCache = + CacheBuilder.newBuilder + .maximumWeight(JavaUtils.byteStringAsBytes(indexCacheSize)) + .weigher(new Weigher[Path, RemoteShuffleIndexInfo]() { + override def weigh(file: Path, indexInfo: RemoteShuffleIndexInfo): Int = + indexInfo.getSize + }) + .build(indexCacheLoader) + + def getDataFile(shuffleId: Int, mapId: Long): Path = { + new Path(s"${dirPrefix}/${shuffleId}_${mapId}.data") + } + + def getIndexFile(shuffleId: Int, mapId: Long): Path = { + new Path(s"${dirPrefix}/${shuffleId}_${mapId}.index") + } + + /** + * Write an index file with the offsets of each block, plus a final offset at the end for the + * end of the output file. This will be used by getBlockData to figure out where each block + * begins and ends. + * + * It will commit the data and index file as an atomic operation, use the existing ones, or + * replace them with new ones. + * + * Note: the `lengths` will be updated to match the existing index file if use the existing ones. + */ + def writeIndexFileAndCommit( + shuffleId: Int, + mapId: Long, + lengths: Array[Long], + dataTmp: Path): Unit = { + + val indexFile = getIndexFile(shuffleId, mapId) + val indexTmp = RemoteShuffleUtils.tempPathWith(indexFile) + try { + val dataFile = getDataFile(shuffleId, mapId) + // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure + // the following check and rename are atomic. + synchronized { + val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) + if (existingLengths != null) { + // Another attempt for the same task has already written our map outputs successfully, + // so just use the existing partition lengths and delete our temporary map outputs. + System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) + if (dataTmp != null && fs.exists(dataTmp)) { + fs.delete(dataTmp, true) + } + } else { + // This is the first successful attempt in writing the map outputs for this task, + // so override any existing index and data files with the ones we wrote. + val out = new DataOutputStream(new BufferedOutputStream(fs.create(indexTmp))) + val offsetsBuffer = new Array[Long](lengths.length + 1) + Utils.tryWithSafeFinally { + // We take in lengths of each block, need to convert it to offsets. 
+ var offset = 0L + offsetsBuffer(0) = 0 + out.writeLong(0) + var i = 1 + for (length <- lengths) { + offset += length + offsetsBuffer(i) = offset + out.writeLong(offset) + i += 1 + } + } { + out.close() + } + // Put index info in cache if enabled + if (indexCacheEnabled) { + shuffleIndexCache.put(indexFile, new RemoteShuffleIndexInfo(offsetsBuffer)) + } + if (fs.exists(indexFile)) { + fs.delete(indexFile, true) + } + if (fs.exists(dataFile)) { + fs.delete(dataFile, true) + } + if (!fs.rename(indexTmp, indexFile)) { + throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) + } + if (dataTmp != null && fs.exists(dataTmp) && !fs.rename(dataTmp, dataFile)) { + throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) + } + } + } + } finally { + if (fs.exists(indexTmp) && !fs.delete(indexTmp, true)) { + logError(s"Failed to delete temporary index file at ${indexTmp.getName}") + } + } + } + + /** + * Check whether the given index and data files match each other. + * If so, return the partition lengths in the data file. Otherwise return null. + */ + private def checkIndexAndDataFile(index: Path, data: Path, blocks: Int): Array[Long] = { + + // the index file should exist(of course) and have `block + 1` longs as offset. + if (!fs.exists(index) || fs.getFileStatus(index).getLen != (blocks + 1) * 8L) { + return null + } + val lengths = new Array[Long](blocks) + // Read the lengths of blocks + val in = try { + // Note by Chenzhao: originally [[NioBufferedFileInputStream]] is used + new DataInputStream(new BufferedInputStream(fs.open(index))) + } catch { + case e: IOException => + return null + } + try { + // Convert the offsets into lengths of each block + var offset = in.readLong() + if (offset != 0L) { + return null + } + var i = 0 + while (i < blocks) { + val off = in.readLong() + lengths(i) = off - offset + offset = off + i += 1 + } + } catch { + case e: IOException => + return null + } finally { + in.close() + } + + // the size of data file should match with index file + if (fs.exists(data) && fs.getFileStatus(data).getLen == lengths.sum) { + lengths + } else { + null + } + } + + def getBlockData(bId: BlockId, dirs: Option[Array[String]] = None): ManagedBuffer = { + val blockId = bId.asInstanceOf[ShuffleBlockId] + // The block is actually going to be a range of a single map output file for this map, so + // find out the consolidated file, then the offset within that from our index + val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) + + val (offset, length) = + if (indexCacheEnabled) { + val shuffleIndexInfo = shuffleIndexCache.get(indexFile) + val range = shuffleIndexInfo.getIndex(blockId.reduceId) + (range.getOffset, range.getLength) + } else { + // SPARK-22982: if this FileInputStream's position is seeked forward by another + // piece of code which is incorrectly using our file descriptor then this code + // will fetch the wrong offsets (which may cause a reducer to be sent a different + // reducer's data). The explicit position checks added here were a useful debugging + // aid during SPARK-22982 and may help prevent this class of issue from re-occurring + // in the future which is why they are left here even though SPARK-22982 is fixed. 
+ val in = fs.open(indexFile) + in.seek(blockId.reduceId * 8L) + try { + val offset = in.readLong() + val nextOffset = in.readLong() + val actualPosition = in.getPos() + val expectedPosition = blockId.reduceId * 8L + 16 + if (actualPosition != expectedPosition) { + throw new Exception(s"SPARK-22982: Incorrect channel position " + + s"after index file reads: expected $expectedPosition but actual" + + s" position was $actualPosition.") + } + (offset, nextOffset - offset) + } finally { + in.close() + } + } + new HadoopFileSegmentManagedBuffer( + getDataFile(blockId.shuffleId, blockId.mapId), + offset, + length) + } + + /** + * Remove data file and index file that contain the output data from one map. + */ + def removeDataByMap(shuffleId: Int, mapId: Long): Unit = { + var file = getDataFile(shuffleId, mapId) + if (fs.exists(file)) { + if (!fs.delete(file, true)) { + logWarning(s"Error deleting data ${file.toString}") + } + } + + file = getIndexFile(shuffleId, mapId) + if (fs.exists(file)) { + if (!fs.delete(file, true)) { + logWarning(s"Error deleting index ${file.getName()}") + } + } + } + + def createTempShuffleBlock(): (TempShuffleBlockId, Path) = { + RemoteShuffleUtils.createTempShuffleBlock(dirPrefix) + } + + def createTempLocalBlock(): (TempLocalBlockId, Path) = { + RemoteShuffleUtils.createTempLocalBlock(dirPrefix) + } + + // Mainly for tests, similar to [[DiskBlockManager.getAllFiles]] + def getAllFiles(): Seq[Path] = { + val dir = new Path(dirPrefix) + val internalIter = fs.listFiles(dir, true) + new Iterator[Path] { + override def hasNext: Boolean = internalIter.hasNext + + override def next(): Path = internalIter.next().getPath + }.toSeq + } + + override def stop(): Unit = { + val dir = new Path(dirPrefix) + fs.delete(dir, true) + try { + HadoopFileSegmentManagedBuffer.handleCache.values().forEach { + new Consumer[mutable.HashMap[Path, FSDataInputStream]] { + override def accept(t: mutable.HashMap[Path, FSDataInputStream]): Unit = { + t.values.foreach(JavaUtils.closeQuietly) + } + } + } + JavaUtils.closeQuietly(remoteShuffleTransferService) + } catch { + case e: Exception => logInfo(s"Exception thrown when closing " + + s"RemoteShuffleTransferService\n" + + s"Caused by: ${e.toString}\n${e.getStackTrace.mkString("\n")}") + } + } +} + +// For index cache feature, this is the data structure stored in Guava cache +private[remote] class RemoteShuffleIndexInfo extends Logging { + + private var offsets: LongBuffer = _ + private var size: Int = _ + + // Construction by reading index files from storage to memory, which happens in reduce stage + def this(indexFile: Path) { + this() + val fs = RemoteShuffleManager.getFileSystem + + size = fs.getFileStatus(indexFile).getLen.toInt + val rawBuffer = ByteBuffer.allocate(size) + offsets = rawBuffer.asLongBuffer + var input: FSDataInputStream = null + try { + logInfo("Loading index file from storage to Guava cache") + input = fs.open(indexFile) + input.readFully(rawBuffer.array) + } finally { + if (input != null) { + input.close() + } + } + } + + // Construction by directly putting the index offsets info in cache, which happens in map stage + def this(offsetsArray: Array[Long]) { + this() + size = offsetsArray.length * 8 + offsets = LongBuffer.wrap(offsetsArray) + } + + def getSize: Int = size + + /** + * Get index offset for a particular reducer. 
+ */ + def getIndex(reduceId: Int): ShuffleIndexRecord = { + val offset = offsets.get(reduceId) + val nextOffset = offsets.get(reduceId + 1) + new ShuffleIndexRecord(offset, nextOffset - offset) + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleConf.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleConf.scala new file mode 100644 index 000000000..3beadf00d --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleConf.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry} + +object RemoteShuffleConf { + + val STORAGE_MASTER_URI: ConfigEntry[String] = + ConfigBuilder("spark.shuffle.remote.storageMasterUri") + .doc("Contact this storage master while persisting shuffle files") + .stringConf + .createWithDefault("hdfs://localhost:9001") + + val STORAGE_HDFS_MASTER_UI_PORT: ConfigEntry[String] = + ConfigBuilder("spark.shuffle.remote.hdfs.storageMasterUIPort") + .doc("Contact this UI port to retrieve HDFS configurations") + .stringConf + .createWithDefault("50070") + + val SHUFFLE_FILES_ROOT_DIRECTORY: ConfigEntry[String] = + ConfigBuilder("spark.shuffle.remote.filesRootDirectory") + .doc("Use this as the root directory for shuffle files") + .stringConf + .createWithDefault("/shuffle") + + val DFS_REPLICATION: ConfigEntry[Int] = + ConfigBuilder("spark.shuffle.remote.hdfs.replication") + .doc("The default replication of remote storage system, will override dfs.replication" + + " when HDFS is used as shuffling storage") + .intConf + .createWithDefault(3) + + val REMOTE_OPTIMIZED_SHUFFLE_ENABLED: ConfigEntry[Boolean] = + ConfigBuilder("spark.shuffle.remote.optimizedPathEnabled") + .doc("Enable using unsafe-optimized shuffle writer") + .internal() + .booleanConf + .createWithDefault(true) + + val REMOTE_BYPASS_MERGE_THRESHOLD: ConfigEntry[Int] = + ConfigBuilder("spark.shuffle.remote.bypassMergeThreshold") + .doc("Remote shuffle manager uses this threshold to decide using bypass-merge(hash-based)" + + "shuffle or not, a new configuration is introduced(and it's -1 by default) because we" + + " want to explicitly make disabling hash-based shuffle writer as the default behavior." + + " When memory is relatively sufficient, using sort-based shuffle writer in remote shuffle" + + " is often more efficient than the hash-based one. 
Because the bypass-merge shuffle " + + "writer proceeds I/O of 3x total shuffle size: 1 time for read I/O and 2 times for write" + + " I/Os, and this can be an even larger overhead under remote shuffle, the 3x shuffle size" + + " is gone through network, arriving at remote storage system.") + .intConf + .createWithDefault(-1) + + val REMOTE_INDEX_CACHE_SIZE: ConfigEntry[String] = + ConfigBuilder("spark.shuffle.remote.index.cache.size") + .doc("This index file cache resides in each executor. If it's a positive value, index " + + "cache will be turned on: instead of reading index files directly from remote storage" + + ", a reducer will fetch the index files from the executors that write them through" + + " network. And those executors will return the index files kept in cache. (read them" + + "from storage if needed)") + .stringConf + .createWithDefault("0") + + val NUM_TRANSFER_SERVICE_THREADS: ConfigEntry[Int] = + ConfigBuilder("spark.shuffle.remote.numIndexReadThreads") + .doc("The maximum number of server/client threads used in RemoteShuffleTransferService for" + + "index files transferring") + .intConf + .createWithDefault(3) + + val NUM_CONCURRENT_FETCH: ConfigEntry[Int] = + ConfigBuilder("spark.shuffle.remote.numReadThreads") + .doc("The maximum number of concurrent reading threads fetching shuffle data blocks") + .intConf + .createWithDefault(Runtime.getRuntime.availableProcessors()) + + val REUSE_FILE_HANDLE: ConfigEntry[Boolean] = + ConfigBuilder("spark.shuffle.remote.reuseFileHandle") + .doc("By switching on this feature, the file handles returned by Filesystem open operations" + + " will be cached/reused inside an executor(across different rounds of reduce tasks)," + + " eliminating open overhead. This should improve the reduce stage performance only when" + + " file open operations occupy majority of the time, e.g. There is a large number of" + + " shuffle blocks, each reading a fairly small block of data, and there is no other" + + " compute in the reduce stage.") + .booleanConf + .createWithDefault(false) + + val DATA_FETCH_EAGER_REQUIREMENT: ConfigEntry[Boolean] = + ConfigBuilder("spark.shuffle.remote.eagerRequirementDataFetch") + .doc("With eager requirement = false, a shuffle block will be counted ready and served for" + + " compute until all content of the block is put in Spark's local memory. With eager " + + "requirement = true, a shuffle block will be served to later compute after the bytes " + + "required is fetched and put in memory") + .booleanConf + .createWithDefault(false) + +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleManager.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleManager.scala new file mode 100644 index 000000000..3b84e6b18 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleManager.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
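For reference, a minimal sketch (values are illustrative; the manager is assumed to be selected through the standard spark.shuffle.manager key) of combining these settings on a SparkConf:

    val conf = new SparkConf()
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.remote.RemoteShuffleManager")
      .set("spark.shuffle.remote.storageMasterUri", "hdfs://namenode:9001")
      .set("spark.shuffle.remote.filesRootDirectory", "/shuffle")
      .set("spark.shuffle.remote.index.cache.size", "30m")
      // remote shuffle requires the external shuffle service to stay disabled
      .set("spark.shuffle.service.enabled", "false")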
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +import java.io.IOException +import java.net.URL +import java.util.concurrent.ConcurrentHashMap + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.remote.RemoteShuffleManager.{active, appendRemoteStorageHadoopConfigurations} +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch +import org.apache.spark.shuffle.sort._ +import org.apache.spark.util.collection.OpenHashSet + +/** + * In remote shuffle, data is written to a remote Hadoop compatible file system instead of local + * disks. + */ +private[spark] class RemoteShuffleManager(private val conf: SparkConf) extends ShuffleManager with + Logging { + + require(conf.get( + config.SHUFFLE_SERVICE_ENABLED.key, config.SHUFFLE_SERVICE_ENABLED.defaultValueString) + == "false", "Remote shuffle and external shuffle service: they cannot be enabled at the" + + " same time") + + RemoteShuffleManager.setActive(this) + + logWarning("******** Remote Shuffle Manager is used ********") + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } + + /** + * A mapping from shuffle ids to the task ids of mappers producing output for those shuffles. + */ + private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() + + override val shuffleBlockResolver = new RemoteShuffleBlockResolver(conf) + + /** + * Obtains a [[ShuffleHandle]] to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (RemoteShuffleManager.shouldBypassMergeSort(conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleHandle[K, V]( + shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (RemoteShuffleManager.canUseSerializedShuffle(dependency, conf)) { + new SerializedShuffleHandle[K, V]( + shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, dependency) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. 
+ */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startPartition, endPartition) + + new RemoteShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + shuffleBlockResolver, + blocksByAddress, + context, + metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } + + override def getReaderForRange[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + + new RemoteShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + shuffleBlockResolver, + blocksByAddress, + context, + metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent( + handle.shuffleId, _ => new OpenHashSet[Long](16)) + mapTaskIds.synchronized { mapTaskIds.add(context.taskAttemptId()) } + val env = SparkEnv.get + handle match { + case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => + new RemoteUnsafeShuffleWriter( + env.blockManager, + shuffleBlockResolver, + context.taskMemoryManager(), + unsafeShuffleHandle, + mapId, + context, + env.conf, + metrics) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new RemoteBypassMergeSortShuffleWriter( + env.blockManager, + shuffleBlockResolver, + bypassMergeSortHandle, + mapId, + context, + env.conf, + metrics) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => + new RemoteShuffleWriter(shuffleBlockResolver, other, mapId, context) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. 
*/ + override def unregisterShuffle(shuffleId: Int): Boolean = { + Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { mapTaskIds => + mapTaskIds.iterator.foreach { mapTaskId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapTaskId) + } + } + true + } + + private[spark] val getHadoopConf = { + val storageMasterUri = active.conf.get("spark.shuffle.remote.storageMasterUri") + + // DAOS-Hadoop-compatible-filesystem configurations are loaded by DAOS Filesystem itself + val hadoopConf = new Configuration(false) + // Hadoop configuration will be loaded from a remote web URI if the shuffle storage + // system is HDFS + if (storageMasterUri.startsWith("hdfs")) { + val host = storageMasterUri.split("//")(1).split(":")(0) + val port = active.conf.get(RemoteShuffleConf.STORAGE_HDFS_MASTER_UI_PORT) + val address = s"http://$host:$port/conf" + try { + hadoopConf.addResource(new URL(address).openConnection.getInputStream) + } catch { + // Suppress this Exception and use the default one + case e: IOException => logWarning( + s"Exception occurs getting configurations from: $address, caused by ${e.getMessage}") + } + } + + (new SparkHadoopUtil).appendS3AndSparkHadoopHiveConfigurations(active.conf, hadoopConf) + appendRemoteStorageHadoopConfigurations(active.conf, hadoopConf) + hadoopConf + } + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + shuffleBlockResolver.stop() + } +} + + +private[spark] object RemoteShuffleManager extends Logging { + + var active: RemoteShuffleManager = _ + private[remote] def setActive(update: RemoteShuffleManager): Unit = active = update + + private def appendRemoteStorageHadoopConfigurations( + sparkConf: SparkConf, hadoopConf: Configuration) = { + hadoopConf.set("dfs.replication", sparkConf.get(RemoteShuffleConf.DFS_REPLICATION).toString) + } + + def getFileSystem : FileSystem = { + require(active != null, + "Active RemoteShuffleManager unassigned! It's probably never constructed") + active.shuffleBlockResolver.fs + } + + def getResolver: RemoteShuffleBlockResolver = { + require(active != null, + "Active RemoteShuffleManager unassigned! It's probably never constructed") + active.shuffleBlockResolver + } + + def getConf: SparkConf = { + require(active != null, + "Active RemoteShuffleManager unassigned! It's probably never constructed") + active.conf + } + + /** + * Make the decision also referring to a configuration + */ + def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _], conf: SparkConf): Boolean = { + val optimizedShuffleEnabled = conf.get(RemoteShuffleConf.REMOTE_OPTIMIZED_SHUFFLE_ENABLED) + optimizedShuffleEnabled && SortShuffleManager.canUseSerializedShuffle(dependency) + } + + // This is identical to [[SortShuffleWriter.shouldBypassMergeSort]], except reading from + // a modified configuration name, due to we need to change the default threshold to -1 in remote + // shuffle + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + + // We cannot bypass sorting if we need to do map-side aggregation. 
+ if (dep.mapSideCombine) { + false + } else { + val bypassMergeThreshold = conf.get(RemoteShuffleConf.REMOTE_BYPASS_MERGE_THRESHOLD) + dep.partitioner.numPartitions <= bypassMergeThreshold + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleReader.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleReader.scala new file mode 100644 index 000000000..d25a9bd3e --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleReader.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +import org.apache.spark._ +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter, ShuffleReader} +import org.apache.spark.storage.{BlockId, BlockManagerId} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.RemoteSorter + +/** + * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by + * requesting them from other nodes' block stores. + */ +private[spark] class RemoteShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + resolver: RemoteShuffleBlockResolver, + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + context: TaskContext, + readMetrics: ShuffleReadMetricsReporter, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + shouldBatchFetch: Boolean = false) + +extends ShuffleReader[K, C] with Logging { + + private val dep = handle.dependency + + private def fetchContinuousBlocksInBatch: Boolean = { + val conf = SparkEnv.get.conf + val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects + val compressed = conf.get(config.SHUFFLE_COMPRESS) + val codecConcatenation = if (compressed) { + CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf)) + } else { + true + } + val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL) + + val doBatchFetch = shouldBatchFetch && serializerRelocatable && + (!compressed || codecConcatenation) && !useOldFetchProtocol + if (shouldBatchFetch && !doBatchFetch) { + logDebug("The feature tag of continuous shuffle block fetching is set to true, but " + + "we can not enable the feature because other conditions are not satisfied. 
" + + s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " + + s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + + s"$useOldFetchProtocol.") + } + doBatchFetch + } + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val wrappedStreams = new RemoteShuffleBlockIterator( + context, + resolver.remoteShuffleTransferService, + resolver, + blocksByAddress, + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, + SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true), + readMetrics, + fetchContinuousBlocksInBatch) + + val serializerInstance = dep.serializer.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + val remoteAggregator = dep.aggregator.map(new RemoteAggregator(_, resolver)) + if (dep.mapSideCombine) { + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + remoteAggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + remoteAggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. + val sorter = new RemoteSorter[K, C, C]( + context, resolver, ordering = Some(keyOrd), serializer = dep.serializer) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. 
+ context.addTaskCompletionListener[Unit](_ => { + sorter.stop() + }) + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) + case None => + aggregatedIter + } + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleUtils.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleUtils.scala new file mode 100644 index 000000000..4b129851a --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleUtils.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +import java.util.UUID + +import org.apache.hadoop.fs.Path +import org.apache.spark.SparkEnv +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.serializer.{SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter +import org.apache.spark.storage.{BlockId, TempLocalBlockId, TempShuffleBlockId} + +object RemoteShuffleUtils { + + val env = SparkEnv.get + + /** + * Something like [[org.apache.spark.util.Utils.tempFileWith()]], instead returning Path + */ + def tempPathWith(path: Path): Path = { + new Path(path.toString + "." 
+ UUID.randomUUID()) + } + + private def getPath(blockId: BlockId, dirUri: String): Path = { + new Path(s"${dirUri}/${blockId.name}") + } + + /** + * Something like [[org.apache.spark.storage.DiskBlockManager.createTempShuffleBlock()]], instead + * returning Path + */ + private[remote] def createTempShuffleBlock(dirUri: String): (TempShuffleBlockId, Path) = { + var blockId = new TempShuffleBlockId(UUID.randomUUID()) + val tmpPath = getPath(blockId, dirUri) + val fs = RemoteShuffleManager.getFileSystem + while (fs.exists(tmpPath)) { + blockId = new TempShuffleBlockId(UUID.randomUUID()) + } + (blockId, tmpPath) + } + + /** + * Something like [[org.apache.spark.storage.DiskBlockManager.createTempLocalBlock()]], instead + * returning Path + */ + private[remote] def createTempLocalBlock(dirUri: String): (TempLocalBlockId, Path) = { + var blockId = new TempLocalBlockId(UUID.randomUUID()) + val tmpPath = getPath(blockId, dirUri) + val fs = RemoteShuffleManager.getFileSystem + while (fs.exists(tmpPath)) { + blockId = new TempLocalBlockId(UUID.randomUUID()) + } + (blockId, tmpPath) + } + + /** + * Something like [[org.apache.spark.storage.BlockManager.getDiskWriter()]], instead returning + * a RemoteBlockObjectWriter + */ + def getRemoteWriter( + blockId: BlockId, + file: Path, + serializerManager: SerializerManager, + serializerInstance: SerializerInstance, + bufferSize: Int, + writeMetrics: ShuffleWriteMetricsReporter): RemoteBlockObjectWriter = { + val syncWrites = false // env.blockManager.conf.getBoolean("spark.shuffle.sync", false) + new RemoteBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize, + syncWrites, writeMetrics, blockId) + } + +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleWriter.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleWriter.scala new file mode 100644 index 000000000..09d6c25ad --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/shuffle/remote/RemoteShuffleWriter.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
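A minimal sketch of how these helpers might be combined, assuming dirPrefix, serializerManager, serializerInstance and writeMetrics are supplied by the surrounding task, and that RemoteBlockObjectWriter mirrors DiskBlockObjectWriter's write(key, value) API:

    // allocate a temporary shuffle block under the remote root directory
    val (tempBlockId, tempPath) = RemoteShuffleUtils.createTempShuffleBlock(dirPrefix)
    // obtain a writer backed by the remote file system
    val writer = RemoteShuffleUtils.getRemoteWriter(
      tempBlockId, tempPath, serializerManager, serializerInstance, 32 * 1024, writeMetrics)
    records.foreach { case (k, v) => writer.write(k, v) }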
+ */ + +package org.apache.spark.shuffle.remote + +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.collection.RemoteSorter + +private[spark] class RemoteShuffleWriter[K, V, C]( + resolver: RemoteShuffleBlockResolver, + handle: BaseShuffleHandle[K, V, C], + mapId: Long, + context: TaskContext) + extends ShuffleWriter[K, V] with Logging { + + logWarning("******** General Remote Shuffle Writer is used ********") + + private lazy val fs = RemoteShuffleManager.getFileSystem + + private val blockManager = SparkEnv.get.blockManager + + private val dep = handle.dependency + + private var sorter: RemoteSorter[K, V, _] = null + + // Are we in the process of stopping? Because map tasks can call stop() with success = true + // and then call stop() with success = false if they get an exception, we want to make sure + // we don't try deleting files, etc twice. + private var stopping = false + + private var mapStatus: MapStatus = null + + private val writeMetrics = context.taskMetrics().shuffleWriteMetrics + + /** Write a bunch of records to this task's output */ + override def write(records: Iterator[Product2[K, V]]): Unit = { + sorter = if (dep.mapSideCombine) { + new RemoteSorter[K, V, C]( + context, resolver, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) + } else { + // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't + // care whether the keys get sorted in each partition; that will be done on the reduce side + // if the operation being run is sortByKey. + new RemoteSorter[K, V, V]( + context, resolver, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer + ) + } + sorter.insertAll(records) + + // Don't bother including the time to open the merged output file in the shuffle write time, + // because it just opens a single file, so is typically too fast to measure accurately + // (see SPARK-3570). 
+ val output = resolver.getDataFile(dep.shuffleId, mapId) + val tmp = RemoteShuffleUtils.tempPathWith(output) + try { + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) + val partitionLengths = sorter.writePartitionedFile(blockId, tmp) + resolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) + mapStatus = + MapStatus(RemoteShuffleManager.getResolver.shuffleServerId, partitionLengths, mapId) + } finally { + if (fs.exists(tmp) && !fs.delete(tmp, true)) { + logError(s"Error while deleting temp file ${tmp.getName}") + } + } + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + return Option(mapStatus) + } else { + return None + } + } finally { + // Clean up our sorter, which may have its own intermediate files + if (sorter != null) { + val startTime = System.nanoTime() + sorter.stop() + writeMetrics.incWriteTime(System.nanoTime - startTime) + sorter = null + } + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RPartitionedAppendOnlyMap.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RPartitionedAppendOnlyMap.scala new file mode 100644 index 000000000..87170db71 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RPartitionedAppendOnlyMap.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +import org.apache.spark.util.collection.RWritablePartitionedPairCollection._ + +/** + * Note: Only the class name is modified. 
We didn't just override + * [[WritablePartitionedPairCollection]], [[PartitionedPairBuffer]] and + * [[PartitionedAppendOnlyMap]] in order to let the default local sort shuffle manager still work + * with the remote shuffle package existed + * + * Implementation of WritablePartitionedPairCollection that wraps a map in which the keys + * are tuples of (partition ID, K) + */ +private[spark] class RPartitionedAppendOnlyMap[K, V] + extends SizeTrackingAppendOnlyMap[(Int, K), V] with RWritablePartitionedPairCollection[K, V] { + + def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) + : Iterator[((Int, K), V)] = { + val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator) + destructiveSortedIterator(comparator) + } + + def insert(partition: Int, key: K, value: V): Unit = { + update((partition, key), value) + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RPartitionedPairBuffer.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RPartitionedPairBuffer.scala new file mode 100644 index 000000000..c1afce2d8 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RPartitionedPairBuffer.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.util.collection.RWritablePartitionedPairCollection._ + +/** + * Note: Only the class name is modified. We didn't just override + * [[WritablePartitionedPairCollection]], [[PartitionedPairBuffer]] and + * [[PartitionedAppendOnlyMap]] in order to let the default local sort shuffle manager still work + * with the remote shuffle package existed + * + * Append-only buffer of key-value pairs, each with a corresponding partition ID, that keeps track + * of its estimated size in bytes. + * + * The buffer can support up to 1073741819 elements. + */ +private[spark] class RPartitionedPairBuffer[K, V](initialCapacity: Int = 64) + extends RWritablePartitionedPairCollection[K, V] with SizeTracker +{ + import RPartitionedPairBuffer._ + + require(initialCapacity <= MAXIMUM_CAPACITY, + s"Can't make capacity bigger than ${MAXIMUM_CAPACITY} elements") + require(initialCapacity >= 1, "Invalid initial capacity") + + // Basic growable array data structure. We use a single array of AnyRef to hold both the keys + // and the values, so that we can sort them efficiently with KVArraySortDataFormat. 
+ private var capacity = initialCapacity + private var curSize = 0 + private var data = new Array[AnyRef](2 * initialCapacity) + + /** Add an element into the buffer */ + def insert(partition: Int, key: K, value: V): Unit = { + if (curSize == capacity) { + growArray() + } + data(2 * curSize) = (partition, key.asInstanceOf[AnyRef]) + data(2 * curSize + 1) = value.asInstanceOf[AnyRef] + curSize += 1 + afterUpdate() + } + + /** Double the size of the array because we've reached capacity */ + private def growArray(): Unit = { + if (capacity >= MAXIMUM_CAPACITY) { + throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_CAPACITY} elements") + } + val newCapacity = + if (capacity * 2 > MAXIMUM_CAPACITY) { // Overflow + MAXIMUM_CAPACITY + } else { + capacity * 2 + } + val newArray = new Array[AnyRef](2 * newCapacity) + System.arraycopy(data, 0, newArray, 0, 2 * capacity) + data = newArray + capacity = newCapacity + resetSamples() + } + + /** Iterate through the data in a given order. For this class this is not really destructive. */ + override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) + : Iterator[((Int, K), V)] = { + val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator) + new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator) + iterator + } + + private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] { + var pos = 0 + + override def hasNext: Boolean = pos < curSize + + override def next(): ((Int, K), V) = { + if (!hasNext) { + throw new NoSuchElementException + } + val pair = (data(2 * pos).asInstanceOf[(Int, K)], data(2 * pos + 1).asInstanceOf[V]) + pos += 1 + pair + } + } +} + +private object RPartitionedPairBuffer { + val MAXIMUM_CAPACITY: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 2 +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RWritablePartitionedPairCollection.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RWritablePartitionedPairCollection.scala new file mode 100644 index 000000000..87e357067 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RWritablePartitionedPairCollection.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +import org.apache.spark.shuffle.remote.RemoteBlockObjectWriter + +/** + * NOTE: This is to override Spark 3.0.0's WritablePartitionedIterator: Changing the writeNext + * interface's args type to RemoteBlockObjectWriter + * + * Note: We made several places returning a [[RemoteBlockObjectWriter]]. 
And we didn't just + * override [[WritablePartitionedPairCollection]], [[PartitionedPairBuffer]] and + * [[PartitionedAppendOnlyMap]] in order to let the default local sort shuffle manager still work + * with the remote shuffle package existed + * + * A common interface for size-tracking collections of key-value pairs that + * + * - Have an associated partition for each key-value pair. + * - Support a memory-efficient sorted iterator + * - Support a WritablePartitionedIterator for writing the contents directly as bytes. + */ +private[spark] trait RWritablePartitionedPairCollection[K, V] { + /** + * Insert a key-value pair with a partition into the collection + */ + def insert(partition: Int, key: K, value: V): Unit + + /** + * Iterate through the data in order of partition ID and then the given comparator. This may + * destroy the underlying collection. + */ + def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]]) + : Iterator[((Int, K), V)] + + /** + * Iterate through the data and write out the elements instead of returning them. Records are + * returned in order of their partition ID and then the given comparator. + * This may destroy the underlying collection. + */ + def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]) + : RWritablePartitionedIterator = { + val it = partitionedDestructiveSortedIterator(keyComparator) + new RWritablePartitionedIterator { + private[this] var cur = if (it.hasNext) it.next() else null + + def writeNext(writer: RemoteBlockObjectWriter): Unit = { + writer.write(cur._1._2, cur._2) + cur = if (it.hasNext) it.next() else null + } + + def hasNext(): Boolean = cur != null + + def nextPartition(): Int = cur._1._1 + } + } +} + +private[spark] object RWritablePartitionedPairCollection { + /** + * A comparator for (Int, K) pairs that orders them by only their partition ID. + */ + def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + a._1 - b._1 + } + } + + /** + * A comparator for (Int, K) pairs that orders them both by their partition ID + * and a key ordering. + */ + def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = { + new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + val partitionDiff = a._1 - b._1 + if (partitionDiff != 0) { + partitionDiff + } else { + keyComparator.compare(a._2, b._2) + } + } + } + } +} + +/** + * Iterator that writes elements to a DiskBlockObjectWriter instead of returning them. Each element + * has an associated partition. + */ +private[spark] trait RWritablePartitionedIterator { + def writeNext(writer: RemoteBlockObjectWriter): Unit + + def hasNext(): Boolean + + def nextPartition(): Int +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RemoteAppendOnlyMap.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RemoteAppendOnlyMap.scala new file mode 100644 index 000000000..ce6ade041 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RemoteAppendOnlyMap.scala @@ -0,0 +1,643 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
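A minimal sketch of how this iterator is typically consumed, with writerFor standing in as a hypothetical per-partition RemoteBlockObjectWriter lookup:

    val it = collection.destructiveSortedWritablePartitionedIterator(keyComparator)
    while (it.hasNext()) {
      val pid = it.nextPartition()      // partition of the next pending element
      it.writeNext(writerFor(pid))      // write it through that partition's writer
    }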
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io._ +import java.util.Comparator + +import scala.collection.{mutable, BufferedIterator} +import scala.collection.mutable.ArrayBuffer + +import com.google.common.io.ByteStreams +import org.apache.hadoop.fs.{FSDataInputStream, Path} + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerManager} +import org.apache.spark.shuffle.remote.{RemoteBlockObjectWriter, RemoteShuffleBlockResolver, RemoteShuffleManager, RemoteShuffleUtils} +import org.apache.spark.storage.BlockId +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.RemoteAppendOnlyMap.HashComparator + +/** + * :: DeveloperApi :: + * An append-only map that spills sorted content to disk when there is insufficient space for it + * to grow. + * + * This map takes two passes over the data: + * + * (1) Values are merged into combiners, which are sorted and spilled to disk as necessary + * (2) Combiners are read from disk and merged together + * + * The setting of the spill threshold faces the following trade-off: If the spill threshold is + * too high, the in-memory map may occupy more memory than is available, resulting in OOM. + * However, if the spill threshold is too low, we spill frequently and incur unnecessary disk + * writes. This may lead to a performance regression compared to the normal case of using the + * non-spilling AppendOnlyMap. + */ +@DeveloperApi +class RemoteAppendOnlyMap[K, V, C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + serializer: Serializer = SparkEnv.get.serializer, + resolver: RemoteShuffleBlockResolver, + context: TaskContext = TaskContext.get(), + serializerManager: SerializerManager = SparkEnv.get.serializerManager) + extends Spillable[SizeTracker](context.taskMemoryManager()) + with Serializable + with Logging + with Iterable[(K, C)] { + + if (context == null) { + throw new IllegalStateException( + "Spillable collections should not be instantiated outside of tasks") + } + + private lazy val fs = RemoteShuffleManager.getFileSystem + + /** + * Exposed for testing + */ + @volatile private[collection] var currentMap = new SizeTrackingAppendOnlyMap[K, C] + private val spilledMaps = new ArrayBuffer[DiskMapIterator] + private val sparkConf = SparkEnv.get.conf + + /** + * Size of object batches when reading/writing from serializers. + * + * Objects are written in batches, with each batch using its own serialization stream. This + * cuts down on the size of reference-tracking maps constructed when deserializing a stream. 
+ * + * NOTE: Setting this too low can cause excessive copying when serializing, since some + * serializers grow internal data structures by growing + copying every time the number + * of objects doubles. + */ + private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) + + // Number of bytes spilled in total + private var _diskBytesSpilled = 0L + def diskBytesSpilled: Long = _diskBytesSpilled + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + private val fileBufferSize = + sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 + + // Write metrics + private val writeMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics() + + // Peak size of the in-memory map observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L + def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes + + private val keyComparator = new HashComparator[K] + private val ser = serializer.newInstance() + + @volatile private var readingIterator: SpillableIterator = null + + /** + * Number of files this map has spilled so far. + * Exposed for testing. + */ + private[collection] def numSpills: Int = spilledMaps.size + + /** + * Insert the given key and value into the map. + */ + def insert(key: K, value: V): Unit = { + insertAll(Iterator((key, value))) + } + + /** + * Insert the given iterator of keys and values into the map. + * + * When the underlying map needs to grow, check if the global pool of shuffle memory has + * enough room for this to happen. If so, allocate the memory required to grow the map; + * otherwise, spill the in-memory map to disk. + * + * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked. + */ + def insertAll(entries: Iterator[Product2[K, V]]): Unit = { + if (currentMap == null) { + throw new IllegalStateException( + "Cannot insert new elements into a map after calling iterator") + } + // An update function for the map that we reuse across entries to avoid allocating + // a new closure each time + var curEntry: Product2[K, V] = null + val update: (Boolean, C) => C = (hadVal, oldVal) => { + if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2) + } + + while (entries.hasNext) { + curEntry = entries.next() + val estimatedSize = currentMap.estimateSize() + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } + if (maybeSpill(currentMap, estimatedSize)) { + currentMap = new SizeTrackingAppendOnlyMap[K, C] + } + currentMap.changeValue(curEntry._1, update) + addElementsRead() + } + } + + /** + * Insert the given iterable of keys and values into the map. + * + * When the underlying map needs to grow, check if the global pool of shuffle memory has + * enough room for this to happen. If so, allocate the memory required to grow the map; + * otherwise, spill the in-memory map to disk. + * + * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked. + */ + def insertAll(entries: Iterable[Product2[K, V]]): Unit = { + insertAll(entries.iterator) + } + + /** + * Sort the existing contents of the in-memory map and spill them to a temporary file on disk. 
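+ * The sorted records are written out through a [[RemoteBlockObjectWriter]] and the resulting
+ * DiskMapIterator is appended to `spilledMaps`, so iterator() can merge it back in later.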
+ */ + override protected[this] def spill(collection: SizeTracker): Unit = { + val inMemoryIterator = currentMap.destructiveSortedIterator(keyComparator) + val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator) + spilledMaps += diskMapIterator + } + + /** + * Force to spilling the current in-memory collection to disk to release memory, + * It will be called by TaskMemoryManager when there is not enough memory for the task. + */ + override protected[this] def forceSpill(): Boolean = { + if (readingIterator != null) { + val isSpilled = readingIterator.spill() + if (isSpilled) { + currentMap = null + } + isSpilled + } else if (currentMap.size > 0) { + spill(currentMap) + currentMap = new SizeTrackingAppendOnlyMap[K, C] + true + } else { + false + } + } + + /** + * Spill the in-memory Iterator to a temporary file on disk. + */ + private[this] def spillMemoryIteratorToDisk(inMemoryIterator: Iterator[(K, C)]) + : DiskMapIterator = { + val (blockId, file) = resolver.createTempLocalBlock() + val writer: RemoteBlockObjectWriter = RemoteShuffleUtils.getRemoteWriter( + blockId, file, serializerManager, ser, fileBufferSize, writeMetrics) + var objectsWritten = 0 + + // List of batch sizes (bytes) in the order they are written to disk + val batchSizes = new ArrayBuffer[Long] + + // Flush the disk writer's contents to disk, and update relevant variables + def flush(): Unit = { + val segment = writer.commitAndGet() + batchSizes += segment.length + _diskBytesSpilled += segment.length + objectsWritten = 0 + } + + var success = false + try { + while (inMemoryIterator.hasNext) { + val kv = inMemoryIterator.next() + writer.write(kv._1, kv._2) + objectsWritten += 1 + + if (objectsWritten == serializerBatchSize) { + flush() + } + } + if (objectsWritten > 0) { + flush() + writer.close() + } else { + writer.revertPartialWritesAndClose() + } + success = true + } finally { + if (!success) { + // This code path only happens if an exception was thrown above before we set success; + // close our stuff and let the exception be thrown further + writer.revertPartialWritesAndClose() + if (fs.exists(file)) { + if (!fs.delete(file, true)) { + logWarning(s"Error deleting ${file}") + } + } + } + } + + new DiskMapIterator(file, blockId, batchSizes) + } + + /** + * Returns a destructive iterator for iterating over the entries of this map. + * If this iterator is forced spill to disk to release memory when there is not enough memory, + * it returns pairs from an on-disk map. + */ + def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = { + readingIterator = new SpillableIterator(inMemoryIterator) + readingIterator.toCompletionIterator + } + + /** + * Return a destructive iterator that merges the in-memory map with the spilled maps. + * If no spill has occurred, simply return the in-memory map's iterator. 
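+ * A minimal usage sketch (the value `records` and the type parameters are placeholders; the
+ * real callers are the remote shuffle aggregation paths):
+ * {{{
+ *   val map = new RemoteAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners,
+ *     serializer, resolver)
+ *   map.insertAll(records)        // may spill sorted batches to the remote file system
+ *   val combined = map.iterator   // merges in-memory data with any spilled maps
+ * }}}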
+ */ + override def iterator: Iterator[(K, C)] = { + if (currentMap == null) { + throw new IllegalStateException( + "RemoteAppendOnlyMap.iterator is destructive and should only be called once.") + } + if (spilledMaps.isEmpty) { + destructiveIterator(currentMap.iterator) + } else { + new ExternalIterator() + } + } + + private def freeCurrentMap(): Unit = { + if (currentMap != null) { + currentMap = null // So that the memory can be garbage-collected + releaseMemory() + } + } + + /** + * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps + */ + private class ExternalIterator extends Iterator[(K, C)] { + + // A queue that maintains a buffer for each stream we are currently merging + // This queue maintains the invariant that it only contains non-empty buffers + private val mergeHeap = new mutable.PriorityQueue[StreamBuffer] + + // Input streams are derived both from the in-memory map and spilled maps on disk + // The in-memory map is sorted in place, while the spilled maps are already in sorted order + private val sortedMap = destructiveIterator( + currentMap.destructiveSortedIterator(keyComparator)) + private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) + + inputStreams.foreach { it => + val kcPairs = new ArrayBuffer[(K, C)] + readNextHashCode(it, kcPairs) + if (kcPairs.length > 0) { + mergeHeap.enqueue(new StreamBuffer(it, kcPairs)) + } + } + + /** + * Fill a buffer with the next set of keys with the same hash code from a given iterator. We + * read streams one hash code at a time to ensure we don't miss elements when they are merged. + * + * Assumes the given iterator is in sorted order of hash code. + * + * @param it iterator to read from + * @param buf buffer to write the results into + */ + private def readNextHashCode(it: BufferedIterator[(K, C)], buf: ArrayBuffer[(K, C)]): Unit = { + if (it.hasNext) { + var kc = it.next() + buf += kc + val minHash = hashKey(kc) + while (it.hasNext && it.head._1.hashCode() == minHash) { + kc = it.next() + buf += kc + } + } + } + + /** + * If the given buffer contains a value for the given key, merge that value into + * baseCombiner and remove the corresponding (K, C) pair from the buffer. + */ + private def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = { + var i = 0 + while (i < buffer.pairs.length) { + val pair = buffer.pairs(i) + if (pair._1 == key) { + // Note that there's at most one pair in the buffer with a given key, since we always + // merge stuff in a map before spilling, so it's safe to return after the first we find + removeFromBuffer(buffer.pairs, i) + return mergeCombiners(baseCombiner, pair._2) + } + i += 1 + } + baseCombiner + } + + /** + * Remove the index'th element from an ArrayBuffer in constant time, swapping another element + * into its place. This is more efficient than the ArrayBuffer.remove method because it does + * not have to shift all the elements in the array over. It works for our array buffers because + * we don't care about the order of elements inside, we just want to search them for a key. + */ + private def removeFromBuffer[T](buffer: ArrayBuffer[T], index: Int): T = { + val elem = buffer(index) + buffer(index) = buffer(buffer.size - 1) // This also works if index == buffer.size - 1 + buffer.reduceToSize(buffer.size - 1) + elem + } + + /** + * Return true if there exists an input stream that still has unvisited pairs. 
+ */ + override def hasNext: Boolean = mergeHeap.nonEmpty + + /** + * Select a key with the minimum hash, then combine all values with the same key from all + * input streams. + */ + override def next(): (K, C) = { + if (mergeHeap.isEmpty) { + throw new NoSuchElementException + } + // Select a key from the StreamBuffer that holds the lowest key hash + val minBuffer = mergeHeap.dequeue() + val minPairs = minBuffer.pairs + val minHash = minBuffer.minKeyHash + val minPair = removeFromBuffer(minPairs, 0) + val minKey = minPair._1 + var minCombiner = minPair._2 + assert(hashKey(minPair) == minHash) + + // For all other streams that may have this key (i.e. have the same minimum key hash), + // merge in the corresponding value (if any) from that stream + val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer) + while (mergeHeap.nonEmpty && mergeHeap.head.minKeyHash == minHash) { + val newBuffer = mergeHeap.dequeue() + minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer) + mergedBuffers += newBuffer + } + + // Repopulate each visited stream buffer and add it back to the queue if it is non-empty + mergedBuffers.foreach { buffer => + if (buffer.isEmpty) { + readNextHashCode(buffer.iterator, buffer.pairs) + } + if (!buffer.isEmpty) { + mergeHeap.enqueue(buffer) + } + } + + (minKey, minCombiner) + } + + /** + * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash. + * Each buffer maintains all of the key-value pairs with what is currently the lowest hash + * code among keys in the stream. There may be multiple keys if there are hash collisions. + * Note that because when we spill data out, we only spill one value for each key, there is + * at most one element for each key. + * + * StreamBuffers are ordered by the minimum key hash currently available in their stream so + * that we can put them into a heap and sort that. + */ + private class StreamBuffer( + val iterator: BufferedIterator[(K, C)], + val pairs: ArrayBuffer[(K, C)]) + extends Comparable[StreamBuffer] { + + def isEmpty: Boolean = pairs.length == 0 + + // Invalid if there are no more pairs in this stream + def minKeyHash: Int = { + assert(pairs.length > 0) + hashKey(pairs.head) + } + + override def compareTo(other: StreamBuffer): Int = { + // descending order because mutable.PriorityQueue dequeues the max, not the min + if (other.minKeyHash < minKeyHash) -1 else if (other.minKeyHash == minKeyHash) 0 else 1 + } + } + } + + /** + * An iterator that returns (K, C) pairs in sorted order from an on-disk map + */ + private class DiskMapIterator(file: Path, blockId: BlockId, batchSizes: ArrayBuffer[Long]) + extends Iterator[(K, C)] + { + private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1 + assert(fs.getFileStatus(file).getLen == batchOffsets.last, + "File length is not equal to the last batch offset:\n" + + s" file length = ${fs.getFileStatus(file).getLen}\n" + + s" last batch offset = ${batchOffsets.last}\n" + + s" all batch offsets = ${batchOffsets.mkString(",")}" + ) + + private var batchIndex = 0 // Which batch we're in + private var fileStream: FSDataInputStream = null + + // An intermediate stream that reads from exactly one batch + // This guards against pre-fetching and other arbitrary behavior of higher level streams + private var deserializeStream: DeserializationStream = null + private var nextItem: (K, C) = null + private var objectsRead = 0 + + /** + * Construct a stream that reads only from the next batch. 
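+ * Returns null once every batch has been consumed, after closing the open streams and
+ * deleting the spill file.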
+ */ + private def nextBatchStream(): DeserializationStream = { + // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether + // we're still in a valid batch. + if (batchIndex < batchOffsets.length - 1) { + if (deserializeStream != null) { + deserializeStream.close() + fileStream.close() + deserializeStream = null + fileStream = null + } + + val start = batchOffsets(batchIndex) + fileStream = fs.open(file) + fileStream.seek(start) + batchIndex += 1 + + val end = batchOffsets(batchIndex) + + assert(end >= start, "start = " + start + ", end = " + end + + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) + + val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val wrappedStream = serializerManager.wrapStream(blockId, bufferedStream) + ser.deserializeStream(wrappedStream) + } else { + // No more batches left + cleanup() + null + } + } + + /** + * Return the next (K, C) pair from the deserialization stream. + * + * If the current batch is drained, construct a stream for the next batch and read from it. + * If no more pairs are left, return null. + */ + private def readNextItem(): (K, C) = { + try { + val k = deserializeStream.readKey().asInstanceOf[K] + val c = deserializeStream.readValue().asInstanceOf[C] + val item = (k, c) + objectsRead += 1 + if (objectsRead == serializerBatchSize) { + objectsRead = 0 + deserializeStream = nextBatchStream() + } + item + } catch { + case e: EOFException => + cleanup() + null + } + } + + override def hasNext: Boolean = { + if (nextItem == null) { + if (deserializeStream == null) { + // In case of deserializeStream has not been initialized + deserializeStream = nextBatchStream() + if (deserializeStream == null) { + return false + } + } + nextItem = readNextItem() + } + nextItem != null + } + + override def next(): (K, C) = { + if (!hasNext) { + throw new NoSuchElementException + } + val item = nextItem + nextItem = null + item + } + + private def cleanup() { + batchIndex = batchOffsets.length // Prevent reading any other batch + if (deserializeStream != null) { + deserializeStream.close() + deserializeStream = null + } + if (fileStream != null) { + fileStream.close() + fileStream = null + } + if (fs.exists(file)) { + if (!fs.delete(file, true)) { + logWarning(s"Error deleting ${file}") + } + } + } + + context.addTaskCompletionListener[Unit](context => cleanup()) + } + + private class SpillableIterator(var upstream: Iterator[(K, C)]) + extends Iterator[(K, C)] { + + private val SPILL_LOCK = new Object() + + private var cur: (K, C) = readNext() + + private var hasSpilled: Boolean = false + + def spill(): Boolean = SPILL_LOCK.synchronized { + if (hasSpilled) { + false + } else { + logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + + s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + val nextUpstream = spillMemoryIteratorToDisk(upstream) + assert(!upstream.hasNext) + hasSpilled = true + upstream = nextUpstream + true + } + } + + private def destroy(): Unit = { + freeCurrentMap() + upstream = Iterator.empty + } + + def toCompletionIterator: CompletionIterator[(K, C), SpillableIterator] = { + CompletionIterator[(K, C), SpillableIterator](this, this.destroy) + } + + def readNext(): (K, C) = SPILL_LOCK.synchronized { + if (upstream.hasNext) { + upstream.next() + } else { + null + } + } + + override def hasNext(): Boolean = cur != null + + override def next(): (K, C) = { + val r = cur + cur = readNext() + r + } + } + + 
/** Convenience function to hash the given (K, C) pair by the key. */ + private def hashKey(kc: (K, C)): Int = RemoteAppendOnlyMap.hash(kc._1) + + override def toString(): String = { + this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode()) + } +} + +private[spark] object RemoteAppendOnlyMap { + + /** + * Return the hash code of the given object. If the object is null, return a special hash code. + */ + private def hash[T](obj: T): Int = { + if (obj == null) 0 else obj.hashCode() + } + + /** + * A comparator which sorts arbitrary keys based on their hash codes. + */ + private class HashComparator[K] extends Comparator[K] { + def compare(key1: K, key2: K): Int = { + val hash1 = hash(key1) + val hash2 = hash(key2) + if (hash1 < hash2) -1 else if (hash1 == hash2) 0 else 1 + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RemoteSorter.scala b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RemoteSorter.scala new file mode 100644 index 000000000..40f4fb292 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/main/scala/org/apache/spark/util/collection/RemoteSorter.scala @@ -0,0 +1,848 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io._ +import java.util.Comparator + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.google.common.io.ByteStreams +import org.apache.hadoop.fs.{FSDataInputStream, Path} + +import org.apache.spark._ +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.serializer._ +import org.apache.spark.shuffle.remote.{RemoteBlockObjectWriter, RemoteShuffleBlockResolver, RemoteShuffleManager, RemoteShuffleUtils} +import org.apache.spark.storage.BlockId + +/** + * NOTE: This version of Spark 2.4.0 ExternalSorter is currently imported for the interface + * modification of function: writePartitionedFile, to support writing to Hadoop Filesystem + * + * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner + * pairs of type (K, C). Uses a Partitioner to first group the keys into partitions, and then + * optionally sorts keys within each partition using a custom Comparator. Can output a single + * partitioned file with a different byte range for each partition, suitable for shuffle fetches. + * + * If combining is disabled, the type C must equal V -- we'll cast the objects at the end. + * + * Note: Although ExternalSorter is a fairly generic sorter, some of its configuration is tied + * to its use in sort-based shuffle (for example, its block compression is controlled by + * `spark.shuffle.compress`). 
We may need to revisit this if ExternalSorter is used in other + * non-shuffle contexts where we might want to use different configuration settings. + * + * @param aggregator optional Aggregator with combine functions to use for merging data + * @param partitioner optional Partitioner; if given, sort by partition ID and then key + * @param ordering optional Ordering to sort keys within each partition; should be a total ordering + * @param serializer serializer to use when spilling to disk + * + * Note that if an Ordering is given, we'll always sort using it, so only provide it if you really + * want the output keys to be sorted. In a map task without map-side combine for example, you + * probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do + * want to do combining, having an Ordering is more efficient than not having it. + * + * Users interact with this class in the following way: + * + * 1. Instantiate an ExternalSorter. + * + * 2. Call insertAll() with a set of records. + * + * 3. Request an iterator() back to traverse sorted/aggregated records. + * - or - + * Invoke writePartitionedFile() to create a file containing sorted/aggregated outputs + * that can be used in Spark's sort shuffle. + * + * At a high level, this class works internally as follows: + * + * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if + * we want to combine by key, or a PartitionedPairBuffer if we don't. + * Inside these buffers, we sort elements by partition ID and then possibly also by key. + * To avoid calling the partitioner multiple times with each key, we store the partition ID + * alongside each record. + * + * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first + * by partition ID and possibly second by key or by hash code of the key, if we want to do + * aggregation. For each file, we track how many objects were in each partition in memory, so we + * don't have to write out the partition ID for every element. + * + * - When the user requests an iterator or file output, the spilled files are merged, along with + * any remaining in-memory data, using the same sort order defined above (unless both sorting + * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering + * from the ordering parameter, or read the keys with the same hash code and compare them with + * each other for equality to merge values. + * + * - Users are expected to call stop() at the end to delete all the intermediate files. 
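+ * An illustrative sketch of the map-side write path (`dep`, `records`, `blockId` and
+ * `outputFile` are placeholders; the real driver of this class is the remote shuffle writer):
+ * {{{
+ *   val sorter = new RemoteSorter[K, V, C](
+ *     context, resolver, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+ *   sorter.insertAll(records)
+ *   val partitionLengths = sorter.writePartitionedFile(blockId, outputFile)
+ *   sorter.stop()
+ * }}}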
+ */ +private[spark] class RemoteSorter[K, V, C]( + context: TaskContext, + resolver: RemoteShuffleBlockResolver, + aggregator: Option[Aggregator[K, V, C]] = None, + partitioner: Option[Partitioner] = None, + ordering: Option[Ordering[K]] = None, + serializer: Serializer = SparkEnv.get.serializer) + extends Spillable[RWritablePartitionedPairCollection[K, C]](context.taskMemoryManager()) + with Logging { + + private val conf = SparkEnv.get.conf + private lazy val fs = RemoteShuffleManager.getFileSystem + + private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) + private val shouldPartition = numPartitions > 1 + private def getPartition(key: K): Int = { + if (shouldPartition) partitioner.get.getPartition(key) else 0 + } + + private val blockManager = SparkEnv.get.blockManager + private val diskBlockManager = blockManager.diskBlockManager + private val serializerManager = SparkEnv.get.serializerManager + private val serInstance = serializer.newInstance() + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 + + // Size of object batches when reading/writing from serializers. + // + // Objects are written in batches, with each batch using its own serialization stream. This + // cuts down on the size of reference-tracking maps constructed when deserializing a stream. + // + // NOTE: Setting this too low can cause excessive copying when serializing, since some serializers + // grow internal data structures by growing + copying every time the number of objects doubles. + private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000) + + // Data structures to store in-memory objects before we spill. Depending on whether we have an + // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we + // store them in an array buffer. + @volatile private var map = new RPartitionedAppendOnlyMap[K, C] + @volatile private var buffer = new RPartitionedPairBuffer[K, C] + + // Total spilling statistics + private var _diskBytesSpilled = 0L + def diskBytesSpilled: Long = _diskBytesSpilled + + // Peak size of the in-memory data structure observed so far, in bytes + private var _peakMemoryUsedBytes: Long = 0L + def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes + + @volatile private var isShuffleSort: Boolean = true + private val forceSpillFiles = new ArrayBuffer[SpilledFile] + @volatile private var readingIterator: SpillableIterator = null + + // A comparator for keys K that orders them within a partition to allow aggregation or sorting. + // Can be a partial ordering by hash code if a total ordering is not provided through by the + // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some + // non-equal keys also have this, so we need to do a later pass to find truly equal keys). + // Note that we ignore this if no aggregator and no ordering are given. 
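+ // For example, "Aa" and "BB" share the same hash code (2112), so this comparator treats them as
+ // equal even though the keys differ; mergeWithAggregation performs the extra equality pass.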
+ private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] { + override def compare(a: K, b: K): Int = { + val h1 = if (a == null) 0 else a.hashCode() + val h2 = if (b == null) 0 else b.hashCode() + if (h1 < h2) -1 else if (h1 == h2) 0 else 1 + } + }) + + private def comparator: Option[Comparator[K]] = { + if (ordering.isDefined || aggregator.isDefined) { + Some(keyComparator) + } else { + None + } + } + + // NOTE: This is an Hadoop-Filesystem-version SpilledFile, comparing with + // [[ExternalSorter.SpilledFile]] Information about a spilled file. + // Includes sizes in bytes of "batches" written by the serializer as we + // periodically reset its stream, as well as number of elements in each + // partition, used to efficiently keep track of partitions when merging. + private[this] case class SpilledFile( + file: Path, + blockId: BlockId, + serializerBatchSizes: Array[Long], + elementsPerPartition: Array[Long]) + + private val spills = new ArrayBuffer[SpilledFile] + + /** + * Number of files this sorter has spilled so far. + * Exposed for testing. + */ + private[spark] def numSpills: Int = spills.size + + def insertAll(records: Iterator[Product2[K, V]]): Unit = { + // TODO: stop combining if we find that the reduction factor isn't high + val shouldCombine = aggregator.isDefined + + if (shouldCombine) { + // Combine values in-memory first using our AppendOnlyMap + val mergeValue = aggregator.get.mergeValue + val createCombiner = aggregator.get.createCombiner + var kv: Product2[K, V] = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) + } + while (records.hasNext) { + addElementsRead() + kv = records.next() + map.changeValue((getPartition(kv._1), kv._1), update) + maybeSpillCollection(usingMap = true) + } + } else { + // Stick values into our buffer + while (records.hasNext) { + addElementsRead() + val kv = records.next() + buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C]) + maybeSpillCollection(usingMap = false) + } + } + } + + /** + * Spill the current in-memory collection to disk if needed. + * + * @param usingMap whether we're using a map or buffer as our current in-memory collection + */ + private def maybeSpillCollection(usingMap: Boolean): Unit = { + var estimatedSize = 0L + if (usingMap) { + estimatedSize = map.estimateSize() + if (maybeSpill(map, estimatedSize)) { + map = new RPartitionedAppendOnlyMap[K, C] + } + } else { + estimatedSize = buffer.estimateSize() + if (maybeSpill(buffer, estimatedSize)) { + buffer = new RPartitionedPairBuffer[K, C] + } + } + + if (estimatedSize > _peakMemoryUsedBytes) { + _peakMemoryUsedBytes = estimatedSize + } + } + + /** + * Spill our in-memory collection to a sorted file that we can merge later. + * We add this file into `spilledFiles` to find it later. + * + * @param collection whichever collection we're using (map or buffer) + */ + override protected[this] def spill(collection: RWritablePartitionedPairCollection[K, C]): Unit = { + val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator) + val spillFile = spillMemoryIteratorToDisk(inMemoryIterator) + spills += spillFile + } + + /** + * Force to spilling the current in-memory collection to disk to release memory, + * It will be called by TaskMemoryManager when there is not enough memory for the task. 
+ */ + override protected[this] def forceSpill(): Boolean = { + if (isShuffleSort) { + false + } else { + assert(readingIterator != null) + val isSpilled = readingIterator.spill() + if (isSpilled) { + map = null + buffer = null + } + isSpilled + } + } + + /** + * Spill contents of in-memory iterator to a temporary file on disk. + */ + private[this] def spillMemoryIteratorToDisk(inMemoryIterator: RWritablePartitionedIterator) + : SpilledFile = { + // Because these files may be read during shuffle, their compression must be controlled by + // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use + // createTempShuffleBlock here; see SPARK-3426 for more context. + val (blockId, file) = resolver.createTempShuffleBlock() + + // These variables are reset after each flush + var objectsWritten: Long = 0 + val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics + val writer: RemoteBlockObjectWriter = + RemoteShuffleUtils.getRemoteWriter( + blockId, file, serializerManager, serInstance, fileBufferSize, spillMetrics) + + // List of batch sizes (bytes) in the order they are written to disk + val batchSizes = new ArrayBuffer[Long] + + // How many elements we have in each partition + val elementsPerPartition = new Array[Long](numPartitions) + + // Flush the disk writer's contents to disk, and update relevant variables. + // The writer is committed at the end of this process. + def flush(): Unit = { + val segment = writer.commitAndGet() + batchSizes += segment.length + _diskBytesSpilled += segment.length + objectsWritten = 0 + } + + var success = false + try { + while (inMemoryIterator.hasNext) { + val partitionId = inMemoryIterator.nextPartition() + require(partitionId >= 0 && partitionId < numPartitions, + s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})") + inMemoryIterator.writeNext(writer) + elementsPerPartition(partitionId) += 1 + objectsWritten += 1 + + if (objectsWritten == serializerBatchSize) { + flush() + } + } + if (objectsWritten > 0) { + flush() + } else { + writer.revertPartialWritesAndClose() + } + success = true + } finally { + if (success) { + writer.close() + } else { + // This code path only happens if an exception was thrown above before we set success; + // close our stuff and let the exception be thrown further + writer.revertPartialWritesAndClose() + if (fs.exists(file)) { + if (!fs.delete(file, true)) { + logWarning(s"Error deleting ${file}") + } + } + } + } + + SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition) + } + + /** + * Merge a sequence of sorted files, giving an iterator over partitions and then over elements + * inside each partition. This can be used to either write out a new file or return data to + * the user. + * + * Returns an iterator over all the data written to this object, grouped by partition. For each + * partition we then have an iterator over its contents, and these are expected to be accessed + * in order (you can't "skip ahead" to one partition without reading the previous one). + * Guaranteed to return a key-value pair for each partition, in order of partition ID. 
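+ * Each SpilledFile is wrapped in a SpillReader and consumed strictly partition by partition,
+ * in step with the buffered in-memory iterator.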
+ */ + private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)]) + : Iterator[(Int, Iterator[Product2[K, C]])] = { + val readers = spills.map(new SpillReader(_)) + val inMemBuffered = inMemory.buffered + (0 until numPartitions).iterator.map { p => + val inMemIterator = new IteratorForPartition(p, inMemBuffered) + val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator) + if (aggregator.isDefined) { + // Perform partial aggregation across partitions + (p, mergeWithAggregation( + iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined)) + } else if (ordering.isDefined) { + // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey); + // sort the elements without trying to merge them + (p, mergeSort(iterators, ordering.get)) + } else { + (p, iterators.iterator.flatten) + } + } + } + + /** + * Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys. + */ + private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K]) + : Iterator[Product2[K, C]] = + { + val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) + type Iter = BufferedIterator[Product2[K, C]] + val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] { + // Use the reverse order because PriorityQueue dequeues the max + override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1) + }) + heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true + new Iterator[Product2[K, C]] { + override def hasNext: Boolean = !heap.isEmpty + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val firstBuf = heap.dequeue() + val firstPair = firstBuf.next() + if (firstBuf.hasNext) { + heap.enqueue(firstBuf) + } + firstPair + } + } + } + + /** + * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each + * iterator is sorted by key with a given comparator. If the comparator is not a total ordering + * (e.g. when we sort objects by hash code and different keys may compare as equal although + * they're not), we still merge them by doing equality tests for all keys that compare as equal. + */ + private def mergeWithAggregation( + iterators: Seq[Iterator[Product2[K, C]]], + mergeCombiners: (C, C) => C, + comparator: Comparator[K], + totalOrder: Boolean) + : Iterator[Product2[K, C]] = + { + if (!totalOrder) { + // We only have a partial ordering, e.g. comparing the keys by hash code, which means that + // multiple distinct keys might be treated as equal by the ordering. To deal with this, we + // need to read all keys considered equal by the ordering at once and compare them. 
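+ // The iterator below emits one batch of (key, combiner) pairs per group of hash-equal keys:
+ // distinct keys and their combiners are kept in the parallel `keys`/`combiners` buffers, and a
+ // record whose key is truly equal to a buffered key is merged into that slot instead of appended.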
+ new Iterator[Iterator[Product2[K, C]]] { + val sorted = mergeSort(iterators, comparator).buffered + + // Buffers reused across elements to decrease memory allocation + val keys = new ArrayBuffer[K] + val combiners = new ArrayBuffer[C] + + override def hasNext: Boolean = sorted.hasNext + + override def next(): Iterator[Product2[K, C]] = { + if (!hasNext) { + throw new NoSuchElementException + } + keys.clear() + combiners.clear() + val firstPair = sorted.next() + keys += firstPair._1 + combiners += firstPair._2 + val key = firstPair._1 + while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) { + val pair = sorted.next() + var i = 0 + var foundKey = false + while (i < keys.size && !foundKey) { + if (keys(i) == pair._1) { + combiners(i) = mergeCombiners(combiners(i), pair._2) + foundKey = true + } + i += 1 + } + if (!foundKey) { + keys += pair._1 + combiners += pair._2 + } + } + + // Note that we return an iterator of elements since we could've had many keys marked + // equal by the partial order; we flatten this below to get a flat iterator of (K, C). + keys.iterator.zip(combiners.iterator) + } + }.flatMap(i => i) + } else { + // We have a total ordering, so the objects with the same key are sequential. + new Iterator[Product2[K, C]] { + val sorted = mergeSort(iterators, comparator).buffered + + override def hasNext: Boolean = sorted.hasNext + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val elem = sorted.next() + val k = elem._1 + var c = elem._2 + while (sorted.hasNext && sorted.head._1 == k) { + val pair = sorted.next() + c = mergeCombiners(c, pair._2) + } + (k, c) + } + } + } + } + + /** + * An internal class for reading a spilled file partition by partition. Expects all the + * partitions to be requested in order. + */ + private[this] class SpillReader(spill: SpilledFile) { + // Serializer batch offsets; size will be batchSize.length + 1 + val batchOffsets = spill.serializerBatchSizes.scanLeft(0L)(_ + _) + + // Track which partition and which batch stream we're in. These will be the indices of + // the next element we will read. We'll also store the last partition read so that + // readNextPartition() can figure out what partition that was from. + var partitionId = 0 + var indexInPartition = 0L + var batchId = 0 + var indexInBatch = 0 + var lastPartitionId = 0 + + skipToNextPartition() + + // Intermediate file and deserializer streams that read from exactly one batch + // This guards against pre-fetching and other arbitrary behavior of higher level streams + var fileStream: FSDataInputStream = null + var deserializeStream = nextBatchStream() // Also sets fileStream + + var nextItem: (K, C) = null + var finished = false + + /** Construct a stream that only reads from the next batch */ + def nextBatchStream(): DeserializationStream = { + // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether + // we're still in a valid batch. 
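+ // batchOffsets is the prefix sum of serializerBatchSizes, so batch i occupies the byte range
+ // [batchOffsets(i), batchOffsets(i + 1)) of the spill file: seek to its start and limit the
+ // stream to that range before handing it to the deserializer.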
+ if (batchId < batchOffsets.length - 1) { + if (deserializeStream != null) { + deserializeStream.close() + fileStream.close() + deserializeStream = null + fileStream = null + } + + val start = batchOffsets(batchId) + + fileStream = fs.open(spill.file) + fileStream.seek(start) + batchId += 1 + + val end = batchOffsets(batchId) + + assert(end >= start, "start = " + start + ", end = " + end + + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) + + val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + + val wrappedStream = serializerManager.wrapStream(spill.blockId, bufferedStream) + serInstance.deserializeStream(wrappedStream) + } else { + // No more batches left + cleanup() + null + } + } + + /** + * Update partitionId if we have reached the end of our current partition, possibly skipping + * empty partitions on the way. + */ + private def skipToNextPartition() { + while (partitionId < numPartitions && + indexInPartition == spill.elementsPerPartition(partitionId)) { + partitionId += 1 + indexInPartition = 0L + } + } + + /** + * Return the next (K, C) pair from the deserialization stream and update partitionId, + * indexInPartition, indexInBatch and such to match its location. + * + * If the current batch is drained, construct a stream for the next batch and read from it. + * If no more pairs are left, return null. + */ + private def readNextItem(): (K, C) = { + if (finished || deserializeStream == null) { + return null + } + val k = deserializeStream.readKey().asInstanceOf[K] + val c = deserializeStream.readValue().asInstanceOf[C] + lastPartitionId = partitionId + // Start reading the next batch if we're done with this one + indexInBatch += 1 + if (indexInBatch == serializerBatchSize) { + indexInBatch = 0 + deserializeStream = nextBatchStream() + } + // Update the partition location of the element we're reading + indexInPartition += 1 + skipToNextPartition() + // If we've finished reading the last partition, remember that we're done + if (partitionId == numPartitions) { + finished = true + if (deserializeStream != null) { + deserializeStream.close() + } + } + (k, c) + } + + var nextPartitionToRead = 0 + + def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] { + val myPartition = nextPartitionToRead + nextPartitionToRead += 1 + + override def hasNext: Boolean = { + if (nextItem == null) { + nextItem = readNextItem() + if (nextItem == null) { + return false + } + } + assert(lastPartitionId >= myPartition) + // Check that we're still in the right partition; note that readNextItem will have returned + // null at EOF above so we would've returned false there + lastPartitionId == myPartition + } + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val item = nextItem + nextItem = null + item + } + } + + // Clean up our open streams and put us in a state where we can't read any more data + def cleanup() { + batchId = batchOffsets.length // Prevent reading any other batch + val ds = deserializeStream + deserializeStream = null + fileStream = null + if (ds != null) { + ds.close() + } + // NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop(). + // This should also be fixed in ExternalAppendOnlyMap. + } + } + + /** + * NOTE This basically turns a memoryIterator to a spillable(to remote storage) iterator + * + * Returns a destructive iterator for iterating over the entries of this map. 
+ * If this iterator is forced spill to remote storage to release memory when there is not enough + * memory, it returns pairs from an on-disk map. + */ + def destructiveIterator(memoryIterator: Iterator[((Int, K), C)]): Iterator[((Int, K), C)] = { + if (isShuffleSort) { + memoryIterator + } else { + readingIterator = new SpillableIterator(memoryIterator) + readingIterator + } + } + + /** + * Return an iterator over all the data written to this object, grouped by partition and + * aggregated by the requested aggregator. For each partition we then have an iterator over its + * contents, and these are expected to be accessed in order (you can't "skip ahead" to one + * partition without reading the previous one). Guaranteed to return a key-value pair for each + * partition, in order of partition ID. + * + * For now, we just merge all the spilled files in once pass, but this can be modified to + * support hierarchical merging. + * Exposed for testing. + */ + def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { + val usingMap = aggregator.isDefined + val collection: RWritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer + if (spills.isEmpty) { + // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps + // we don't even need to sort by anything other than partition ID + if (!ordering.isDefined) { + // The user hasn't requested sorted keys, so only sort by partition ID, not key + groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None))) + } else { + // We do need to sort by both partition ID and key + groupByPartition(destructiveIterator( + collection.partitionedDestructiveSortedIterator(Some(keyComparator)))) + } + } else { + // Merge spilled and in-memory data + merge(spills, destructiveIterator( + collection.partitionedDestructiveSortedIterator(comparator))) + } + } + + /** + * Return an iterator over all the data written to this object, aggregated by our aggregator. + */ + def iterator: Iterator[Product2[K, C]] = { + isShuffleSort = false + partitionedIterator.flatMap(pair => pair._2) + } + + /** + * Write all the data added into this ExternalSorter into a file in the disk store. This is + * called by the SortShuffleWriter. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + def writePartitionedFile( + blockId: BlockId, + outputFile: Path): Array[Long] = { + + // Track location of each range in the output file + val lengths = new Array[Long](numPartitions) + val writer = RemoteShuffleUtils.getRemoteWriter( + blockId, outputFile, serializerManager, serInstance, fileBufferSize, + context.taskMetrics().shuffleWriteMetrics) + + if (spills.isEmpty) { + // Case where we only have in-memory data + val collection = if (aggregator.isDefined) map else buffer + val it = collection.destructiveSortedWritablePartitionedIterator(comparator) + while (it.hasNext) { + val partitionId = it.nextPartition() + while (it.hasNext && it.nextPartition() == partitionId) { + it.writeNext(writer) + } + val segment = writer.commitAndGet() + lengths(partitionId) = segment.length + } + } else { + // We must perform merge-sort; get an iterator by partition and write everything directly. 
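+ // Each partition's records are committed as one segment; the recorded segment lengths are
+ // returned so the caller can write the matching .index file and report sizes to the map
+ // output tracker.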
+ for ((id, elements) <- this.partitionedIterator) { + if (elements.hasNext) { + for (elem <- elements) { + writer.write(elem._1, elem._2) + } + val segment = writer.commitAndGet() + lengths(id) = segment.length + } + } + } + + writer.close() + context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) + + lengths + } + + def stop(): Unit = { + spills.foreach(s => fs.delete(s.file, true)) + spills.clear() + forceSpillFiles.foreach(s => fs.delete(s.file, true)) + forceSpillFiles.clear() + if (map != null || buffer != null) { + map = null // So that the memory can be garbage-collected + buffer = null // So that the memory can be garbage-collected + releaseMemory() + } + } + + /** + * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*, + * group together the pairs for each partition into a sub-iterator. + * + * @param data an iterator of elements, assumed to already be sorted by partition ID + */ + private def groupByPartition(data: Iterator[((Int, K), C)]) + : Iterator[(Int, Iterator[Product2[K, C]])] = + { + val buffered = data.buffered + (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered))) + } + + /** + * An iterator that reads only the elements for a given partition ID from an underlying buffered + * stream, assuming this partition is the next one to be read. Used to make it easier to return + * partitioned iterators from our in-memory collection. + */ + private[this] class IteratorForPartition(partitionId: Int, data: BufferedIterator[((Int, K), C)]) + extends Iterator[Product2[K, C]] + { + override def hasNext: Boolean = data.hasNext && data.head._1._1 == partitionId + + override def next(): Product2[K, C] = { + if (!hasNext) { + throw new NoSuchElementException + } + val elem = data.next() + (elem._1._2, elem._2) + } + } + + /* + * An iterator wrapping the collection of this ExternalSorter, it supports to spill to + * remote storage + */ + private[this] class SpillableIterator(var upstream: Iterator[((Int, K), C)]) + extends Iterator[((Int, K), C)] { + + private val SPILL_LOCK = new Object() + + private var nextUpstream: Iterator[((Int, K), C)] = null + + private var cur: ((Int, K), C) = readNext() + + private var hasSpilled: Boolean = false + + def spill(): Boolean = SPILL_LOCK.synchronized { + if (hasSpilled) { + false + } else { + val inMemoryIterator = new RWritablePartitionedIterator { + private[this] var cur = if (upstream.hasNext) upstream.next() else null + + def writeNext(writer: RemoteBlockObjectWriter): Unit = { + writer.write(cur._1._2, cur._2) + cur = if (upstream.hasNext) upstream.next() else null + } + + def hasNext(): Boolean = cur != null + + def nextPartition(): Int = cur._1._1 + } + logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " + + s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory") + val spillFile = spillMemoryIteratorToDisk(inMemoryIterator) + forceSpillFiles += spillFile + val spillReader = new SpillReader(spillFile) + nextUpstream = (0 until numPartitions).iterator.flatMap { p => + val iterator = spillReader.readNextPartition() + iterator.map(cur => ((p, cur._1), cur._2)) + } + hasSpilled = true + true + } + } + + def readNext(): ((Int, K), C) = SPILL_LOCK.synchronized { + if (nextUpstream != null) { + upstream = nextUpstream + nextUpstream = null + } + if (upstream.hasNext) { + 
upstream.next() + } else { + null + } + } + + override def hasNext(): Boolean = cur != null + + override def next(): ((Int, K), C) = { + val r = cur + cur = readNext() + r + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/test/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriterSuite.java b/oap-shuffle/remote-shuffle/src/test/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriterSuite.java new file mode 100644 index 000000000..5e4ce79a5 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/test/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriterSuite.java @@ -0,0 +1,572 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.*; +import java.nio.ByteBuffer; +import java.util.*; + +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.spark.*; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.remote.RemoteShuffleBlockResolver; +import scala.Option; +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import com.google.common.collect.HashMultiset; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.LZ4CompressionCodec; +import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.io.SnappyCompressionCodec; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.serializer.*; +import org.apache.spark.storage.*; +import org.apache.spark.shuffle.remote.*; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.*; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.*; + +public class RemoteUnsafeShuffleWriterSuite { + + static final int NUM_PARTITITONS = 4; + TestMemoryManager memoryManager; + TaskMemoryManager taskMemoryManager; + final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS); + Path mergedOutputFile; + long[] partitionSizesInMergedFile; + final LinkedList spillFilesCreated = new LinkedList<>(); + SparkConf conf; + final Serializer serializer = new KryoSerializer(new SparkConf()); + 
TaskMetrics taskMetrics; + + BlockManager blockManager; + + @Mock(answer = RETURNS_SMART_NULLS) + RemoteShuffleBlockResolver shuffleBlockResolver; + + RemoteShuffleManager shuffleManager; + SparkContext sc; + + @Mock(answer = RETURNS_SMART_NULLS) + TaskContext taskContext; + + @Mock(answer = RETURNS_SMART_NULLS) + ShuffleDependency shuffleDep; + + @Mock(answer = RETURNS_SMART_NULLS) + ShuffleWriteMetricsReporter metrics; + + FileSystem fs; + + @After + public void tearDown() throws IOException { + sc.stop(); + shuffleBlockResolver.stop(); + final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); + if (leakedMemory != 0) { + fail("Test leaked " + leakedMemory + " bytes of managed memory"); + } + } + + // We cannot incorporate this into the setUp function due to for most tests, SparkContext should + // be constructed 'not before' but inside the UT cases, to get necessary Spark configuration + private void setUpSparkContextPrivate() { + sc = new SparkContext(conf); + + blockManager = SparkEnv.get().blockManager(); + + shuffleManager = (RemoteShuffleManager) SparkEnv.get().shuffleManager(); + fs = shuffleManager.shuffleBlockResolver().fs(); + + when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); + doAnswer( + invocationOnMock -> { + partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; + Path tmp = (Path) invocationOnMock.getArguments()[3]; + fs.delete(mergedOutputFile, true); + fs.rename(tmp, mergedOutputFile); + return null; + }) + .when(shuffleBlockResolver) + .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(Path.class)); + + when(shuffleBlockResolver.createTempShuffleBlock()) + .thenAnswer( + invocationOnMock -> { + Tuple2 result = + shuffleManager.shuffleBlockResolver().createTempShuffleBlock(); + spillFilesCreated.add(result._2); + return result; + }); + } + + @Before + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + spillFilesCreated.clear(); + String rootDir = "/tmp"; + conf = + new SparkConf(true) + .setMaster("local[1]") + .setAppName("Friday") + .set("spark.shuffle.manager", RemoteShuffleManager.class.getCanonicalName()) + .set("spark.shuffle.remote.storageMasterUri", "file://") + .set("spark.shuffle.remote.filesRootDirectory", rootDir) + .set("spark.buffer.pageSize", "1m") + .set("spark.memory.offHeap.enabled", "false"); + + mergedOutputFile = new Path(rootDir + "/shuffle/someFileAndItsFridayLOL"); + taskMetrics = new TaskMetrics(); + memoryManager = new TestMemoryManager(conf); + taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + + when(taskContext.taskMetrics()).thenReturn(taskMetrics); + when(shuffleDep.serializer()).thenReturn(serializer); + when(shuffleDep.partitioner()).thenReturn(hashPartitioner); + } + + @Test(expected = IllegalStateException.class) + public void mustCallWriteBeforeSuccessfulStop() throws IOException { + setUpSparkContextPrivate(); + createWriter(false).stop(true); + } + + @Test + public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException { + setUpSparkContextPrivate(); + createWriter(false).stop(false); + } + + static class PandaException extends RuntimeException {} + + @Test(expected = PandaException.class) + public void writeFailurePropagates() throws Exception { + setUpSparkContextPrivate(); + class BadRecords extends scala.collection.AbstractIterator> { + @Override + public boolean hasNext() { + throw new PandaException(); + } + + @Override + public Product2 next() { + return null; + } + } + final 
RemoteUnsafeShuffleWriter writer = createWriter(true); + writer.write(new BadRecords()); + } + + @Test + public void writeEmptyIterator() throws Exception { + setUpSparkContextPrivate(); + final RemoteUnsafeShuffleWriter writer = createWriter(true); + writer.write(Collections.emptyIterator()); + final Option mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(fs.exists(mergedOutputFile)); + assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); + assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten()); + assertEquals(0, taskMetrics.shuffleWriteMetrics().bytesWritten()); + assertEquals(0, taskMetrics.diskBytesSpilled()); + assertEquals(0, taskMetrics.memoryBytesSpilled()); + } + + @Test + public void writeWithoutSpilling() throws Exception { + setUpSparkContextPrivate(); + // In this example, each partition should have exactly one record: + final ArrayList> dataToWrite = new ArrayList<>(); + for (int i = 0; i < NUM_PARTITITONS; i++) { + dataToWrite.add(new Tuple2<>(i, i)); + } + final RemoteUnsafeShuffleWriter writer = createWriter(true); + writer.write(dataToWrite.iterator()); + final Option mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(fs.exists(mergedOutputFile)); + + long sumOfPartitionSizes = 0; + for (long size : partitionSizesInMergedFile) { + // All partitions should be the same size: + assertEquals(partitionSizesInMergedFile[0], size); + sumOfPartitionSizes += size; + } + assertEquals(fs.getFileStatus(mergedOutputFile).getLen(), sumOfPartitionSizes); + assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); + assertEquals(0, taskMetrics.diskBytesSpilled()); + assertEquals(0, taskMetrics.memoryBytesSpilled()); + assertEquals(fs.getFileStatus(mergedOutputFile).getLen(), shuffleWriteMetrics.bytesWritten()); + } + + @Test + public void mergeSpillsWithTransferToAndLZF() throws Exception { + testMergingSpills(true, LZFCompressionCodec.class.getName(), false); + } + + @Test + public void mergeSpillsWithFileStreamAndLZF() throws Exception { + testMergingSpills(false, LZFCompressionCodec.class.getName(), false); + } + + @Test + public void mergeSpillsWithTransferToAndLZ4() throws Exception { + testMergingSpills(true, LZ4CompressionCodec.class.getName(), false); + } + + @Test + public void mergeSpillsWithFileStreamAndLZ4() throws Exception { + testMergingSpills(false, LZ4CompressionCodec.class.getName(), false); + } + + @Test + public void mergeSpillsWithTransferToAndSnappy() throws Exception { + testMergingSpills(true, SnappyCompressionCodec.class.getName(), false); + } + + @Test + public void mergeSpillsWithFileStreamAndSnappy() throws Exception { + testMergingSpills(false, SnappyCompressionCodec.class.getName(), false); + } + + @Test + public void mergeSpillsWithTransferToAndNoCompression() throws Exception { + testMergingSpills(true, null, false); + } + + @Test + public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { + testMergingSpills(false, null, false); + } + + @Test + public void mergeSpillsWithCompressionAndEncryption() throws Exception { + // This should actually be translated to a "file stream merge" internally, just have the + // test to make sure that it's the case. 
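+    // (As with vanilla Spark's UnsafeShuffleWriter, the transferTo fast merge is presumably
+    // unusable once encryption is on, since each spill has to be re-encrypted through streams.)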
+ testMergingSpills(true, LZ4CompressionCodec.class.getName(), true); + } + + @Test + public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception { + testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); + } + + @Test + public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception { + conf.set("spark.shuffle.unsafe.fastMergeEnabled", "false"); + testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); + } + + @Test + public void mergeSpillsWithEncryptionAndNoCompression() throws Exception { + // This should actually be translated to a "file stream merge" internally, just have the + // test to make sure that it's the case. + testMergingSpills(true, null, true); + } + + @Test + public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception { + testMergingSpills(false, null, true); + } + + @Test + public void writeEnoughDataToTriggerSpill() throws Exception { + setUpSparkContextPrivate(); + memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES); + final RemoteUnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = new ArrayList<>(); + final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 10]; + for (int i = 0; i < 10 + 1; i++) { + dataToWrite.add(new Tuple2<>(i, bigByteArray)); + } + writer.write(dataToWrite.iterator()); + assertEquals(2, spillFilesCreated.size()); + writer.stop(true); + readRecordsFromFile(); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat( + taskMetrics.diskBytesSpilled(), lessThan(fs.getFileStatus(mergedOutputFile).getLen())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(fs.getFileStatus(mergedOutputFile).getLen(), shuffleWriteMetrics.bytesWritten()); + } + + @Test + public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOff() throws Exception { + conf.set("spark.shuffle.sort.useRadixSort", "false"); + setUpSparkContextPrivate(); + writeEnoughRecordsToTriggerSortBufferExpansionAndSpill(); + assertEquals(2, spillFilesCreated.size()); + } + + @Test + public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() throws Exception { + conf.set("spark.shuffle.sort.useRadixSort", "true"); + setUpSparkContextPrivate(); + writeEnoughRecordsToTriggerSortBufferExpansionAndSpill(); + assertEquals(3, spillFilesCreated.size()); + } + + private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { + memoryManager.limit(UnsafeShuffleWriter.DEFAULT_INITIAL_SER_BUFFER_SIZE * 16); + final RemoteUnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = new ArrayList<>(); + for (int i = 0; i < UnsafeShuffleWriter.DEFAULT_INITIAL_SER_BUFFER_SIZE + 1; i++) { + dataToWrite.add(new Tuple2<>(i, i)); + } + writer.write(dataToWrite.iterator()); + writer.stop(true); + readRecordsFromFile(); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat( + taskMetrics.diskBytesSpilled(), lessThan(fs.getFileStatus(mergedOutputFile).getLen())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + 
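+    // Even though the records passed through spill files first, the bytes-written metric
+    // should still equal the length of the final merged output file.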
assertEquals(fs.getFileStatus(mergedOutputFile).getLen(), shuffleWriteMetrics.bytesWritten()); + } + + @Test + public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception { + setUpSparkContextPrivate(); + final RemoteUnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = new ArrayList<>(); + final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; + new Random(42).nextBytes(bytes); + dataToWrite.add(new Tuple2<>(1, ByteBuffer.wrap(bytes))); + writer.write(dataToWrite.iterator()); + writer.stop(true); + assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { + setUpSparkContextPrivate(); + final RemoteUnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = new ArrayList<>(); + dataToWrite.add(new Tuple2<>(1, ByteBuffer.wrap(new byte[1]))); + // We should be able to write a record that's right _at_ the max record size + final byte[] atMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes() - 4]; + new Random(42).nextBytes(atMaxRecordSize); + dataToWrite.add(new Tuple2<>(2, ByteBuffer.wrap(atMaxRecordSize))); + // Inserting a record that's larger than the max record size + final byte[] exceedsMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes()]; + new Random(42).nextBytes(exceedsMaxRecordSize); + dataToWrite.add(new Tuple2<>(3, ByteBuffer.wrap(exceedsMaxRecordSize))); + writer.write(dataToWrite.iterator()); + writer.stop(true); + assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { + setUpSparkContextPrivate(); + final RemoteUnsafeShuffleWriter writer = createWriter(false); + writer.insertRecordIntoSorter(new Tuple2<>(1, 1)); + writer.insertRecordIntoSorter(new Tuple2<>(2, 2)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(new Tuple2<>(2, 2)); + writer.stop(false); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void testPeakMemoryUsed() throws Exception { + setUpSparkContextPrivate(); + final long recordLengthBytes = 8; + final long pageSizeBytes = 256; + final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; + taskMemoryManager = spy(taskMemoryManager); + when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); + final RemoteUnsafeShuffleWriter writer = + new RemoteUnsafeShuffleWriter<>( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + new SerializedShuffleHandle<>(0, shuffleDep), + 0, // map id + taskContext, + conf, + metrics); + + // Peak memory should be monotonically increasing. More specifically, every time + // we allocate a new page it should increase by exactly the size of the page. + long previousPeakMemory = writer.getPeakMemoryUsedBytes(); + long newPeakMemory; + try { + for (int i = 0; i < numRecordsPerPage * 10; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + if (i % numRecordsPerPage == 0) { + // The first page is allocated in constructor, another page will be allocated after + // every numRecordsPerPage records (peak memory should change). 
+ assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); + } else { + assertEquals(previousPeakMemory, newPeakMemory); + } + previousPeakMemory = newPeakMemory; + } + + // Spilling should not change peak memory + writer.forceSorterToSpill(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + for (int i = 0; i < numRecordsPerPage; i++) { + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + } + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + + // Closing the writer should not change peak memory + writer.closeAndWriteOutput(); + newPeakMemory = writer.getPeakMemoryUsedBytes(); + assertEquals(previousPeakMemory, newPeakMemory); + } finally { + writer.stop(false); + } + } + + private RemoteUnsafeShuffleWriter createWriter(boolean transferToEnabled) + throws IOException { + conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); + return new RemoteUnsafeShuffleWriter<>( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + new SerializedShuffleHandle<>(0, shuffleDep), + 1, // map id + taskContext, + conf, + metrics); + } + + private List> readRecordsFromFile() throws IOException { + final ArrayList> recordsList = new ArrayList<>(); + long startOffset = 0; + for (int i = 0; i < NUM_PARTITITONS; i++) { + final long partitionSize = partitionSizesInMergedFile[i]; + if (partitionSize > 0) { + InputStream fin = fs.open(mergedOutputFile); + ((FSDataInputStream) fin).seek(startOffset); + InputStream in = new LimitedInputStream(fin, partitionSize); + in = SparkEnv.get().serializerManager().wrapForEncryption(in); + if (conf.getBoolean("spark.shuffle.compress", true)) { + in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); + } + DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in); + Iterator> records = recordsStream.asKeyValueIterator(); + while (records.hasNext()) { + Tuple2 record = records.next(); + assertEquals(i, hashPartitioner.getPartition(record._1())); + recordsList.add(record); + } + recordsStream.close(); + startOffset += partitionSize; + } + } + return recordsList; + } + + private void testMergingSpills( + final boolean transferToEnabled, String compressionCodecName, boolean encrypt) + throws Exception { + if (compressionCodecName != null) { + conf.set("spark.shuffle.compress", "true"); + conf.set("spark.io.compression.codec", compressionCodecName); + } else { + conf.set("spark.shuffle.compress", "false"); + } + conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt); + + setUpSparkContextPrivate(); + testMergingSpills(transferToEnabled, encrypt); + } + + private void testMergingSpills(boolean transferToEnabled, boolean encrypted) throws IOException { + final RemoteUnsafeShuffleWriter writer = createWriter(transferToEnabled); + final ArrayList> dataToWrite = new ArrayList<>(); + for (int i : new int[] {1, 2, 3, 4, 4, 2}) { + dataToWrite.add(new Tuple2<>(i, i)); + } + writer.insertRecordIntoSorter(dataToWrite.get(0)); + writer.insertRecordIntoSorter(dataToWrite.get(1)); + writer.insertRecordIntoSorter(dataToWrite.get(2)); + writer.insertRecordIntoSorter(dataToWrite.get(3)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(dataToWrite.get(4)); + writer.insertRecordIntoSorter(dataToWrite.get(5)); + writer.closeAndWriteOutput(); + final Option mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(fs.exists(mergedOutputFile)); + 
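+    // Two spill files are expected: one from the explicit forceSorterToSpill() call and one
+    // from flushing the remaining in-memory records during closeAndWriteOutput().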
assertEquals(2, spillFilesCreated.size()); + + long sumOfPartitionSizes = 0; + for (long size : partitionSizesInMergedFile) { + sumOfPartitionSizes += size; + } + + assertEquals(sumOfPartitionSizes, fs.getFileStatus(mergedOutputFile).getLen()); + + assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat( + taskMetrics.diskBytesSpilled(), lessThan(fs.getFileStatus(mergedOutputFile).getLen())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(fs.getFileStatus(mergedOutputFile).getLen(), shuffleWriteMetrics.bytesWritten()); + } + + private void assertSpillFilesWereCleanedUp() throws IOException { + for (Path spillFile : spillFilesCreated) { + assertFalse( + "Spill file " + spillFile.toString() + " was not cleaned up", fs.exists(spillFile)); + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockIteratorSuite.scala b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockIteratorSuite.scala new file mode 100644 index 000000000..cdd66cab5 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockIteratorSuite.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.remote + +import java.io.{IOException, InputStream} + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.{doNothing, mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark._ +import org.apache.spark.internal.config +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.netty.RemoteShuffleTransferService +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.util.LimitedInputStream +import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} +import org.apache.spark.storage._ +import org.apache.spark.util.Utils + +class RemoteShuffleBlockIteratorSuite extends SparkFunSuite with LocalSparkContext { + + val metrics = mock(classOf[ShuffleReadMetricsReporter]) + + // With/without index cache, configurations set/unset + testWithMultiplePath("basic read")(basicRead) + + test("retry corrupt blocks") { + // To set an active ShuffleManager + new RemoteShuffleManager(createDefaultConfWithIndexCacheEnabled(true)) + val blockResolver = mock(classOf[RemoteShuffleBlockResolver]) + when(blockResolver.indexCacheEnabled).thenReturn(true) + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer() + ) + + val corruptLocalBuffer = mock(classOf[HadoopFileSegmentManagedBuffer]) + doNothing().when(corruptLocalBuffer).prepareData(any()) + when(corruptLocalBuffer.createInputStream()).thenThrow(new RuntimeException("oops")) + + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + // Return the first block, and then fail. 
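+          // The second block is backed by a corrupt buffer and the third one throws on
+          // createInputStream(), so both reads are expected to surface FetchFailedException.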
+ listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) + } + }) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (remoteBmId, blocks.keys.zipWithIndex.map { + case (blockId, mapIndex) => (blockId, 1.asInstanceOf[Long], mapIndex) + }.toSeq)).toIterator + + val taskContext = TaskContext.empty() + val iterator = new RemoteShuffleBlockIterator( + taskContext, + transfer, + blockResolver, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 100), + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + true, + metrics, + false) + + // The first block should be returned without an exception + val (id1, _) = iterator.next() + assert(id1 === ShuffleBlockId(0, 0, 0)) + + // The next block is corrupt local block (the second one is corrupt and retried) + intercept[FetchFailedException] { iterator.next() } + + intercept[FetchFailedException] { iterator.next() } + } + + // Create a mock managed buffer for testing + private def createMockManagedBuffer(size: Int = 1): ManagedBuffer = { + val mockManagedBuffer = mock(classOf[HadoopFileSegmentManagedBuffer]) + val in = mock(classOf[InputStream]) + when(in.read(any[Array[Byte]])).thenReturn(1) + when(in.read(any(), any(), any())).thenReturn(1) + doNothing().when(mockManagedBuffer).prepareData(any()) + when(mockManagedBuffer.createInputStream()).thenReturn(in) + when(mockManagedBuffer.size()).thenReturn(size) + mockManagedBuffer + } + + private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[HadoopFileSegmentManagedBuffer]) + when(corruptBuffer.size()).thenReturn(size) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + corruptBuffer + } + + private def testWithMultiplePath(name: String, loadDefaults: Boolean = true) + (body: (SparkConf => Unit)): Unit = { + val indexCacheDisabledConf = createDefaultConf(loadDefaults) + val indexCacheEnabledConf = createDefaultConfWithIndexCacheEnabled(loadDefaults) + + test(name + " w/o index cache") { + body(indexCacheDisabledConf) + } + test(name + " w/ index cache") { + body(indexCacheEnabledConf) + } + test(name + " w/o index cache, constraining maxBlocksInFlightPerAddress") { + body(indexCacheDisabledConf.set(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS.key, "1")) + } + test(name + " w index cache, constraining maxBlocksInFlightPerAddress") { + body(indexCacheEnabledConf.set(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS.key, "1")) + } + val default = RemoteShuffleConf.DATA_FETCH_EAGER_REQUIREMENT.defaultValue.get + val testWith = (true ^ default) + test(name + s" with eager requirement = ${testWith}") { + body(indexCacheEnabledConf.set( + RemoteShuffleConf.DATA_FETCH_EAGER_REQUIREMENT.key, testWith.toString)) + } + } + + private def prepareMapOutput( + resolver: RemoteShuffleBlockResolver, shuffleId: Int, mapId: Int, blocks: Array[Byte]*) { + val dataTmp = RemoteShuffleUtils.tempPathWith(resolver.getDataFile(shuffleId, mapId)) + val fs = resolver.fs + val out = fs.create(dataTmp) + val lengths = new ArrayBuffer[Long] + Utils.tryWithSafeFinally { + for (block <- blocks) { + lengths += block.length + out.write(block) + } + } { + out.close() + } + // Actually this UT 
relies on this outside function's fine working + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths.toArray, dataTmp) + } + + private def basicRead(conf: SparkConf): Unit = { + + sc = new SparkContext("local[1]", "Shuffle Iterator read", conf) + val shuffleId = 1 + + val env = SparkEnv.get + val resolver = env.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + // There are two transferServices, use the one exclusively for RemoteShuffle + val transferService = env.shuffleManager.shuffleBlockResolver + .asInstanceOf[RemoteShuffleBlockResolver].remoteShuffleTransferService + val shuffleServerId = + transferService.asInstanceOf[RemoteShuffleTransferService].getShuffleServerId + + val numMaps = 3 + + val expectPart0 = Array[Byte](1) + val expectPart1 = Array[Byte](6, 4) + val expectPart2 = Array[Byte](0, 2) + val expectPart3 = Array[Byte](28) + val expectPart4 = Array[Byte](96, 97) + val expectPart5 = Array[Byte](95) + + prepareMapOutput( + resolver, shuffleId, 0, Array[Byte](3, 6, 9), expectPart0, expectPart1) + prepareMapOutput( + resolver, shuffleId, 1, Array[Byte](19, 94), expectPart2, expectPart3) + prepareMapOutput( + resolver, shuffleId, 2, Array[Byte](99, 98), expectPart4, expectPart5) + + val startPartition = 1 + val endPartition = 3 + + val blockInfos = for (i <- 0 until numMaps; j <- startPartition until endPartition) yield { + (ShuffleBlockId(shuffleId, i, j), 1L, 1) + } + + val blocksByAddress = Seq((shuffleServerId, blockInfos)) + + val iter = new RemoteShuffleBlockIterator( + TaskContext.empty(), + transferService, + resolver, + blocksByAddress.toIterator, + (_: BlockId, input: InputStream) => input, + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + true, + metrics, + false) + + val expected = + expectPart0 ++ expectPart1 ++ expectPart2 ++ expectPart3 ++ expectPart4 ++ expectPart5 + + val answer = new ArrayBuffer[Byte]() + iter.map(_._2).foreach { case input => + var current: Int = -1 + while ({current = input.read(); current != -1}) { + answer += current.toByte + } + } + // Shuffle doesn't guarantee that the blocks are returned as ordered in blockInfos, + // so the answer and expected should be sorted before compared + assert(answer.map(_.toInt).sorted.zip(expected.map(_.toInt).sorted) + .forall{case (byteAns, byteExp) => byteAns === byteExp}) + } + + private def cleanAll(files: Path*): Unit = { + for (file <- files) { + deleteFileAndTempWithPrefix(file) + } + } + + private def deleteFileAndTempWithPrefix(prefixPath: Path): Unit = { + val fs = prefixPath.getFileSystem(new Configuration(false)) + val parentDir = prefixPath.getParent + val iter = fs.listFiles(parentDir, false) + while (iter.hasNext) { + val file = iter.next() + if (file.getPath.toString.contains(prefixPath.getName)) { + fs.delete(file.getPath, true) + } + } + } +} diff --git a/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockObjectWriterSuite.scala b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockObjectWriterSuite.scala new file mode 100644 index 000000000..b4ea0fd45 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockObjectWriterSuite.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.remote + +import org.apache.hadoop.fs.Path +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} + +class RemoteBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + var shuffleManager: RemoteShuffleManager = _ + + private lazy val fs = shuffleManager.shuffleBlockResolver.fs + + override def beforeEach(): Unit = { + super.beforeEach() + } + + override def afterEach(): Unit = { + try { + if (shuffleManager != null) { + shuffleManager.stop() + } + } finally { + super.afterEach() + } + } + + private def createWriter(): (RemoteBlockObjectWriter, Path, ShuffleWriteMetrics) = { + val conf = createDefaultConf() + shuffleManager = new RemoteShuffleManager(conf) + val resolver = shuffleManager.shuffleBlockResolver + val file = resolver.createTempLocalBlock()._2 + val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) + val writeMetrics = new ShuffleWriteMetrics() + val writer = new RemoteBlockObjectWriter( + file, serializerManager, new KryoSerializer(createDefaultConf()).newInstance(), 1024, + true, writeMetrics) + (writer, file, writeMetrics) + } + + test("verify write metrics") { + val (writer, file, writeMetrics) = createWriter() + + writer.write(Long.box(20), Long.box(30)) + // Record metrics update on every write + assert(writeMetrics.recordsWritten === 1) + // Metrics don't update on every write + assert(writeMetrics.bytesWritten == 0) + // After 16384 writes, metrics should update + for (i <- 0 until 16384) { + writer.flush() + writer.write(Long.box(i), Long.box(i)) + } + assert(writeMetrics.bytesWritten > 0) + assert(writeMetrics.recordsWritten === 16385) + writer.commitAndGet() + writer.close() + assert(fs.getFileStatus(file).getLen() == writeMetrics.bytesWritten) + } + + test("verify write metrics on revert") { + val (writer, _, writeMetrics) = createWriter() + + writer.write(Long.box(20), Long.box(30)) + // Record metrics update on every write + assert(writeMetrics.recordsWritten === 1) + // Metrics don't update on every write + assert(writeMetrics.bytesWritten == 0) + // After 16384 writes, metrics should update + for (i <- 0 until 16384) { + writer.flush() + writer.write(Long.box(i), Long.box(i)) + } + assert(writeMetrics.bytesWritten > 0) + assert(writeMetrics.recordsWritten === 16385) + writer.revertPartialWritesAndClose() + assert(writeMetrics.bytesWritten == 0) + assert(writeMetrics.recordsWritten == 0) + } + + test("Reopening a closed block writer") { + val (writer, _, _) = createWriter() + + writer.open() + writer.close() + intercept[IllegalStateException] { + writer.open() + } + } + + // 1. When the underlying filesystem is local file system, the closeAndGet doesn't immediately + // sync with the device unless BLockObjectWriter.close is called 2. 
Local file system doesn't + // support truncate + ignore("calling revertPartialWritesAndClose() on a partial write should truncate up to commit") { + val (writer, file, writeMetrics) = createWriter() + + writer.write(Long.box(20), Long.box(30)) + val firstSegment = writer.commitAndGet() + assert(firstSegment.length === fs.getFileStatus(file).getLen()) + assert(writeMetrics.bytesWritten === fs.getFileStatus(file).getLen()) + + writer.write(Long.box(40), Long.box(50)) + + writer.revertPartialWritesAndClose() + assert(firstSegment.length === fs.getFileStatus(file).getLen()) + assert(writeMetrics.bytesWritten === fs.getFileStatus(file).getLen()) + assert(writeMetrics.recordsWritten == 1) + } + + ignore("calling revertPartialWritesAndClose() after commit() should have no effect") { + val (writer, file, writeMetrics) = createWriter() + + writer.write(Long.box(20), Long.box(30)) + val firstSegment = writer.commitAndGet() + assert(firstSegment.length === fs.getFileStatus(file).getLen()) + assert(writeMetrics.bytesWritten === fs.getFileStatus(file).getLen()) + + writer.revertPartialWritesAndClose() + assert(firstSegment.length === fs.getFileStatus(file).getLen()) + assert(writeMetrics.bytesWritten === fs.getFileStatus(file).getLen()) + } + + test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") { + val (writer, file, writeMetrics) = createWriter() + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.commitAndGet() + writer.close() + val bytesWritten = writeMetrics.bytesWritten + assert(writeMetrics.recordsWritten === 1000) + writer.revertPartialWritesAndClose() + assert(writeMetrics.recordsWritten === 1000) + assert(writeMetrics.bytesWritten === bytesWritten) + } + + test("commit() and close() should be idempotent") { + val (writer, file, writeMetrics) = createWriter() + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.commitAndGet() + writer.close() + val bytesWritten = writeMetrics.bytesWritten + val writeTime = writeMetrics.writeTime + assert(writeMetrics.recordsWritten === 1000) + writer.commitAndGet() + writer.close() + assert(writeMetrics.recordsWritten === 1000) + assert(writeMetrics.bytesWritten === bytesWritten) + assert(writeMetrics.writeTime === writeTime) + } + + test("revertPartialWritesAndClose() should be idempotent") { + val (writer, file, writeMetrics) = createWriter() + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.revertPartialWritesAndClose() + val bytesWritten = writeMetrics.bytesWritten + val writeTime = writeMetrics.writeTime + assert(writeMetrics.recordsWritten === 0) + writer.revertPartialWritesAndClose() + assert(writeMetrics.recordsWritten === 0) + assert(writeMetrics.bytesWritten === bytesWritten) + assert(writeMetrics.writeTime === writeTime) + } + + test("commit() and close() without ever opening or writing") { + val (writer, _, _) = createWriter() + val segment = writer.commitAndGet() + writer.close() + assert(segment.length === 0) + } +} diff --git a/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockResolverSuite.scala b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockResolverSuite.scala new file mode 100644 index 000000000..3638fa7fc --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleBlockResolverSuite.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.Utils + +class RemoteShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEach { + + val conf = createDefaultConf() + + var dataFile: Path = _ + var indexFile: Path = _ + var dataTmp: Path = _ + var shuffleManager: RemoteShuffleManager = _ + val shuffleId = 1 + val mapId = 2 + + test("Commit shuffle files multiple times") { + + shuffleManager = new RemoteShuffleManager(conf) + val resolver = shuffleManager.shuffleBlockResolver + + indexFile = resolver.getIndexFile(shuffleId, mapId) + dataFile = resolver.getDataFile(shuffleId, mapId) + val fs = resolver.fs + + dataTmp = RemoteShuffleUtils.tempPathWith(dataFile) + + val lengths = Array[Long](10, 0, 20) + val out = fs.create(dataTmp) + Utils.tryWithSafeFinally { + out.write(new Array[Byte](30)) + } { + out.close() + } + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) + + assert(fs.exists(indexFile)) + assert(fs.getFileStatus(indexFile).getLen() === (lengths.length + 1) * 8) + assert(fs.exists(dataFile)) + assert(fs.getFileStatus(dataFile).getLen() === 30) + assert(!fs.exists(dataTmp)) + + val lengths2 = new Array[Long](3) + val dataTmp2 = RemoteShuffleUtils.tempPathWith(dataFile) + val out2 = fs.create(dataTmp2) + Utils.tryWithSafeFinally { + out2.write(Array[Byte](1)) + out2.write(new Array[Byte](29)) + } { + out2.close() + } + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2) + + assert(fs.getFileStatus(indexFile).getLen() === (lengths.length + 1) * 8) + assert(lengths2.toSeq === lengths.toSeq) + assert(fs.exists(dataFile)) + assert(fs.getFileStatus(dataFile).getLen() === 30) + assert(!fs.exists(dataTmp2)) + + // The dataFile should be the previous one + val firstByte = new Array[Byte](1) + val dataIn = fs.open(dataFile) + Utils.tryWithSafeFinally { + dataIn.read(firstByte) + } { + dataIn.close() + } + assert(firstByte(0) === 0) + + // The index file should not change + val indexIn = fs.open(indexFile) + Utils.tryWithSafeFinally { + indexIn.readLong() // the first offset is always 0 + assert(indexIn.readLong() === 10, "The index file should not change") + } { + indexIn.close() + } + + // remove data file + fs.delete(dataFile, true) + + val lengths3 = Array[Long](7, 10, 15, 3) + val dataTmp3 = RemoteShuffleUtils.tempPathWith(dataFile) + val out3 = fs.create(dataTmp3) + Utils.tryWithSafeFinally { + out3.write(Array[Byte](2)) + out3.write(new Array[Byte](34)) + } { + out3.close() + } + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3) + assert(fs.getFileStatus(indexFile).getLen() === (lengths3.length + 
1) * 8) + assert(lengths3.toSeq != lengths.toSeq) + assert(fs.exists(dataFile)) + assert(fs.getFileStatus(dataFile).getLen() === 35) + assert(!fs.exists(dataTmp3)) + + // The dataFile should be the new one, since we deleted the dataFile from the first attempt + val dataIn2 = fs.open(dataFile) + Utils.tryWithSafeFinally { + dataIn2.read(firstByte) + } { + dataIn2.close() + } + assert(firstByte(0) === 2) + + // The index file should be updated, since we deleted the dataFile from the first attempt + val indexIn2 = fs.open(indexFile) + Utils.tryWithSafeFinally { + indexIn2.readLong() // the first offset is always 0 + assert(indexIn2.readLong() === 7, "The index file should be updated") + } { + indexIn2.close() + } + } + + test("get block data") { + + shuffleManager = new RemoteShuffleManager(conf) + val resolver = shuffleManager.shuffleBlockResolver + + indexFile = resolver.getIndexFile(shuffleId, mapId) + dataFile = resolver.getDataFile(shuffleId, mapId) + val fs = resolver.fs + + val partitionId = 3 + val expected = Array[Byte](8, 7, 6, 5) + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, partitionId) + + val lengths = Array[Long](3, 1, 2, 4) + dataTmp = RemoteShuffleUtils.tempPathWith(dataFile) + val out = fs.create(dataTmp) + Utils.tryWithSafeFinally { + out.write(Array[Byte](3, 6, 9)) + out.write(Array[Byte](1)) + out.write(Array[Byte](2, 4)) + out.write(expected) + } { + out.close() + } + // Actually this UT relies on this outside function's fine working + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) + + val answerBuffer = + resolver.getBlockData(shuffleBlockId).asInstanceOf[HadoopFileSegmentManagedBuffer] + val expectedBuffer = new HadoopFileSegmentManagedBuffer(dataFile, 6, 4) + assert(expectedBuffer.equals(answerBuffer)) + } + + test("createInputStream of HadoopFileSegmentManagedBuffer") { + + shuffleManager = new RemoteShuffleManager(conf) + val resolver = shuffleManager.shuffleBlockResolver + + dataFile = resolver.getDataFile(shuffleId, mapId) + val fs = resolver.fs + + val out = fs.create(dataFile) + val expected = Array[Byte](2, 4) + Utils.tryWithSafeFinally { + out.write(Array[Byte](3, 6, 9)) + out.write(Array[Byte](1)) + out.write(expected) + out.write(Array[Byte](8, 7, 6, 5)) + } { + out.close() + } + + val answer = new Array[Byte](2) + val buf = new HadoopFileSegmentManagedBuffer(dataFile, 4, 2) + val inputStream = buf.createInputStream() + inputStream.read(answer) + assert(expected === answer) + assert(inputStream.available() == 0) + } + + test("createInputStream of HadoopFileSegmentManagedBuffer, with no data") { + + shuffleManager = new RemoteShuffleManager(conf) + val resolver = shuffleManager.shuffleBlockResolver + + dataFile = resolver.getDataFile(shuffleId, mapId) + val fs = resolver.fs + + val answer = new Array[Byte](0) + val expected = new Array[Byte](0) + val buf = new HadoopFileSegmentManagedBuffer(dataFile, 4, 0) + val inputStream = buf.createInputStream() + inputStream.read(answer) + assert(expected === answer) + assert(inputStream.available() == 0) + } + + private def deleteFilesWithPrefix(prefixPath: Path): Unit = { + val fs = prefixPath.getFileSystem(new Configuration(false)) + val parentDir = prefixPath.getParent + if (fs.exists(parentDir)) { + val iter = fs.listFiles(parentDir, false) + while (iter.hasNext) { + val file = iter.next() + if (file.getPath.toString.contains(prefixPath.getName)) { + fs.delete(file.getPath, true) + } + } + } + } + + override def afterEach() { + super.afterEach() + if (dataFile != null) { + // Also 
delete tmp files if needed + deleteFilesWithPrefix(dataFile) + } + + if (indexFile != null) { + // Also delete tmp files if needed + deleteFilesWithPrefix(indexFile) + } + if (shuffleManager != null) { + shuffleManager.stop() + } + } + +} diff --git a/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleManagerSuite.scala b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleManagerSuite.scala new file mode 100644 index 000000000..839ca4d06 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/RemoteShuffleManagerSuite.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +import org.mockserver.integration.ClientAndServer.startClientAndServer +import org.mockserver.model.{HttpRequest, HttpResponse} + +import org.apache.spark._ +import org.apache.spark.internal.config +import org.apache.spark.util.Utils + +class RemoteShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { + + testWithMultiplePath("repartition")(repartition(100, 10, 20)) + testWithMultiplePath("re-large-partition")(repartition(1000000, 3, 2)) + + testWithMultiplePath( + "repartition with some map output empty")(repartitionWithEmptyMapOutput) + + testWithMultiplePath("sort")(sort(500, 13, true)) + testWithMultiplePath("sort large partition")(sort(500000, 2)) + + test("disable bypass-merge-sort shuffle writer by default") { + sc = new SparkContext("local", "test", new SparkConf(true)) + val partitioner = new HashPartitioner(100) + val rdd = sc.parallelize((1 to 10).map(x => (x, x + 1)), 10) + val dependency = new ShuffleDependency[Int, Int, Int](rdd, partitioner) + assert(RemoteShuffleManager.shouldBypassMergeSort(new SparkConf(true), dependency) + == false) + } + + test("Remote shuffle and external shuffle service cannot be enabled at the same time") { + intercept[Exception] { + sc = new SparkContext( + "local", + "test", + new SparkConf(true) + .set("spark.shuffle.manager", "org.apache.spark.shuffle.remote.RemoteShuffleManager") + .set("spark.shuffle.service.enabled", "true")) + } + } + + test("request HDFS configuration from remote storage master") { + val expectKey = "whatever" + val expectVal = "55555" + val mockHadoopConf: String = s"$expectKey" + + s"$expectVal" + val port = 56789 + + val mockServer = startClientAndServer(port) + mockServer.when(HttpRequest.request.withPath("/conf")) + .respond(HttpResponse.response().withBody(mockHadoopConf)) + + try { + val conf = new SparkConf(false) + .set("spark.shuffle.manager", "org.apache.spark.shuffle.remote.RemoteShuffleManager") + .set(RemoteShuffleConf.STORAGE_HDFS_MASTER_UI_PORT, port.toString) + .set("spark.shuffle.remote.storageMasterUri", 
"hdfs://localhost:9001") + val manager = new RemoteShuffleManager(conf) + assert(manager.getHadoopConf.get(expectKey) == expectVal) + } + finally { + mockServer.stop() + } + } + + test("request HDFS configuration from remote storage master:" + + " unset port or no connection cause no exception") { + val conf = new SparkConf(false) + .set("spark.shuffle.manager", "org.apache.spark.shuffle.remote.RemoteShuffleManager") + .set("spark.shuffle.remote.storageMasterUri", "hdfs://localhost:9001") + val manager = new RemoteShuffleManager(conf) + manager.getHadoopConf + } + + // Optimized shuffle writer & non-optimized shuffle writer + private def testWithMultiplePath(name: String, loadDefaults: Boolean = true) + (body: (SparkConf => Unit)): Unit = { + test(name + " with general shuffle path") { + body(createSparkConf(loadDefaults, bypassMergeSort = false, unsafeOptimized = false)) + } + test(name + " with optimized shuffle path") { + body(createSparkConf(loadDefaults, bypassMergeSort = false, unsafeOptimized = true)) + } + test(name + " with bypass-merge-sort shuffle path") { + body(createSparkConf(loadDefaults, bypassMergeSort = true, unsafeOptimized = false)) + } + test(name + " with bypass-merge-sort shuffle path + index cache") { + body(createSparkConf(loadDefaults, + bypassMergeSort = true, unsafeOptimized = false, indexCache = true)) + } + test(name + " with optimized shuffle path + index cache") { + body(createSparkConf(loadDefaults, + bypassMergeSort = false, unsafeOptimized = true, indexCache = true)) + } + test(name + " with whatever shuffle write path + constraining maxBlocksPerAdress") { + body(createSparkConf(loadDefaults, indexCache = false, setMaxBlocksPerAdress = true)) + } + test(name + " with whatever shuffle write path + index cache + constraining maxBlocksPerAdress") + { + body(createSparkConf(loadDefaults, indexCache = true, setMaxBlocksPerAdress = true)) + } + val default = RemoteShuffleConf.DATA_FETCH_EAGER_REQUIREMENT.defaultValue.get + val testWith = (true ^ default) + test(name + s" with eager requirement = ${testWith}") + { + body(createSparkConf(loadDefaults, indexCache = true) + .set(RemoteShuffleConf.DATA_FETCH_EAGER_REQUIREMENT.key, testWith.toString)) + } + } + + private def repartition( + dataSize: Int, preShuffleNumPartitions: Int, postShuffleNumPartitions: Int) + (conf: SparkConf): Unit = { + sc = new SparkContext("local", "test_repartition", conf) + val data = 0 until dataSize + val rdd = sc.parallelize(data, preShuffleNumPartitions) + val newRdd = rdd.repartition(postShuffleNumPartitions) + assert(newRdd.collect().sorted === data) + } + + private def repartitionWithEmptyMapOutput(conf: SparkConf): Unit = { + sc = new SparkContext("local", "test_repartition_empty", conf) + val data = 0 until 20 + val rdd = sc.parallelize(data, 30) + val newRdd = rdd.repartition(40) + assert(newRdd.collect().sorted === data) + } + + private def sort( + dataSize: Int, numPartitions: Int, differentMapSidePartitionLength: Boolean = false) + (conf: SparkConf): Unit = { + sc = new SparkContext("local", "sort", conf) + val data = if (differentMapSidePartitionLength) { + List.fill(dataSize/2)(0) ++ (dataSize / 2 until dataSize) + } else { + 0 until dataSize + } + val rdd = sc.parallelize(Utils.randomize(data), numPartitions) + + val newRdd = rdd.sortBy((x: Int) => x.toLong) + assert(newRdd.collect() === data) + } + + private def createSparkConf( + loadDefaults: Boolean, bypassMergeSort: Boolean = false, unsafeOptimized: Boolean = true, + indexCache: Boolean = false, 
setMaxBlocksPerAdress: Boolean = false): SparkConf = { + val smallThreshold = 1 + val largeThreshold = 50 + val conf = createDefaultConf(loadDefaults) + .set("spark.shuffle.optimizedPathEnabled", unsafeOptimized.toString) + .set("spark.shuffle.manager", "org.apache.spark.shuffle.remote.RemoteShuffleManager") + // Use a strict threshold as default so that Bypass-Merge-Sort shuffle writer won't be used + .set("spark.shuffle.sort.bypassMergeThreshold", smallThreshold.toString) + if (bypassMergeSort) { + // Use a loose threshold + conf.set("spark.shuffle.sort.bypassMergeThreshold", largeThreshold.toString) + } + if (indexCache) { + conf.set("spark.shuffle.remote.index.cache.size", "3m") + } + if (setMaxBlocksPerAdress) { + conf.set(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS.key, "1") + } + conf + } + +} diff --git a/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/package.scala b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/package.scala new file mode 100644 index 000000000..efc4a1016 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/shuffle/remote/package.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.SparkConf + +package object remote { + def createDefaultConf(loadDefaults: Boolean = true): SparkConf = { + new SparkConf(loadDefaults) + .set("spark.shuffle.manager", classOf[RemoteShuffleManager].getCanonicalName) + // Unit tests should not rely on external systems, using local file system as storage + .set("spark.shuffle.remote.storageMasterUri", "file://") + .set("spark.shuffle.remote.filesRootDirectory", "/tmp") + .set("spark.shuffle.sync", "true") + } + def createDefaultConfWithIndexCacheEnabled(loadDefaults: Boolean = true): SparkConf = { + createDefaultConf(loadDefaults) + .set("spark.shuffle.remote.index.cache.size", "5m") + } +} diff --git a/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/RemoteAppendOnlyMapSuite.scala b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/RemoteAppendOnlyMapSuite.scala new file mode 100644 index 000000000..6879d4352 --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/RemoteAppendOnlyMapSuite.scala @@ -0,0 +1,575 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.ref.WeakReference + +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually + +import org.apache.spark._ +import org.apache.spark.internal.config._ +import org.apache.spark.io.CompressionCodec +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.shuffle.remote.RemoteShuffleBlockResolver +import org.apache.spark.util.CompletionIterator + + +/** + * TODO: Why will this UT leave uncleaned files after testing? + * [[org.apache.spark.shuffle.remote.RemoteAggregator]] && [[RemoteAppendOnlyMap]] are both tested + */ +class RemoteAppendOnlyMapSuite extends SparkFunSuite + with LocalSparkContext + with Eventually + with Matchers{ + import TestUtils.{assertNotSpilled, assertSpilled} + + private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS + private def createCombiner[T](i: T) = ArrayBuffer[T](i) + private def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i + private def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] = + buf1 ++= buf2 + + private val resolver = new RemoteShuffleBlockResolver(createDefaultConf()) + + private def createExternalMap[T] = { + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + new RemoteAppendOnlyMap[T, T, ArrayBuffer[T]]( + createCombiner[T], mergeValue[T], mergeCombiners[T], resolver = resolver, context = context) + } + + private def createSparkConf(loadDefaults: Boolean, codec: Option[String] = None): SparkConf = { + val conf = createDefaultConf(loadDefaults) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + conf.set("spark.shuffle.spill.compress", codec.isDefined.toString) + conf.set("spark.shuffle.compress", codec.isDefined.toString) + codec.foreach { c => conf.set("spark.io.compression.codec", c) } + // Ensure that we actually have multiple batches per spill file + conf.set("spark.shuffle.spill.batchSize", "10") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.remote.RemoteShuffleManager") + conf + } + + test("single insert") { + val conf = createSparkConf(loadDefaults = false) + sc = new SparkContext("local", "test", conf) + val map = createExternalMap[Int] + map.insert(1, 10) + val it = map.iterator + assert(it.hasNext) + val kv = it.next() + assert(kv._1 === 1 && kv._2 === ArrayBuffer[Int](10)) + assert(!it.hasNext) + sc.stop() + } + + test("multiple insert") { + val conf = createSparkConf(loadDefaults = false) + sc = new SparkContext("local", "test", conf) + val map = createExternalMap[Int] + map.insert(1, 10) + map.insert(2, 20) + map.insert(3, 30) + val it = map.iterator + assert(it.hasNext) + assert(it.toSet === Set[(Int, ArrayBuffer[Int])]( + (1, ArrayBuffer[Int](10)), + (2, ArrayBuffer[Int](20)), + (3, 
ArrayBuffer[Int](30)))) + sc.stop() + } + + test("insert with collision") { + val conf = createSparkConf(loadDefaults = false) + sc = new SparkContext("local", "test", conf) + val map = createExternalMap[Int] + + map.insertAll(Seq( + (1, 10), + (2, 20), + (3, 30), + (1, 100), + (2, 200), + (1, 1000))) + val it = map.iterator + assert(it.hasNext) + val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) + assert(result === Set[(Int, Set[Int])]( + (1, Set[Int](10, 100, 1000)), + (2, Set[Int](20, 200)), + (3, Set[Int](30)))) + sc.stop() + } + + test("ordering") { + val conf = createSparkConf(loadDefaults = false) + sc = new SparkContext("local", "test", conf) + + val map1 = createExternalMap[Int] + map1.insert(1, 10) + map1.insert(2, 20) + map1.insert(3, 30) + + val map2 = createExternalMap[Int] + map2.insert(2, 20) + map2.insert(3, 30) + map2.insert(1, 10) + + val map3 = createExternalMap[Int] + map3.insert(3, 30) + map3.insert(1, 10) + map3.insert(2, 20) + + val it1 = map1.iterator + val it2 = map2.iterator + val it3 = map3.iterator + + var kv1 = it1.next() + var kv2 = it2.next() + var kv3 = it3.next() + assert(kv1._1 === kv2._1 && kv2._1 === kv3._1) + assert(kv1._2 === kv2._2 && kv2._2 === kv3._2) + + kv1 = it1.next() + kv2 = it2.next() + kv3 = it3.next() + assert(kv1._1 === kv2._1 && kv2._1 === kv3._1) + assert(kv1._2 === kv2._2 && kv2._2 === kv3._2) + + kv1 = it1.next() + kv2 = it2.next() + kv3 = it3.next() + assert(kv1._1 === kv2._1 && kv2._1 === kv3._1) + assert(kv1._2 === kv2._2 && kv2._2 === kv3._2) + sc.stop() + } + + test("null keys and values") { + val conf = createSparkConf(loadDefaults = false) + sc = new SparkContext("local", "test", conf) + + val map = createExternalMap[Int] + val nullInt = null.asInstanceOf[Int] + map.insert(1, 5) + map.insert(2, 6) + map.insert(3, 7) + map.insert(4, nullInt) + map.insert(nullInt, 8) + map.insert(nullInt, nullInt) + val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.sorted)) + assert(result === Set[(Int, Seq[Int])]( + (1, Seq[Int](5)), + (2, Seq[Int](6)), + (3, Seq[Int](7)), + (4, Seq[Int](nullInt)), + (nullInt, Seq[Int](nullInt, 8)) + )) + + sc.stop() + } + + test("simple aggregator") { + val conf = createSparkConf(loadDefaults = false) + sc = new SparkContext("local", "test", conf) + + // reduceByKey + val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1)) + val result1 = rdd.reduceByKey(_ + _).collect() + assert(result1.toSet === Set[(Int, Int)]((0, 5), (1, 5))) + + // groupByKey + val result2 = rdd.groupByKey().collect().map(x => (x._1, x._2.toList)).toSet + assert(result2.toSet === Set[(Int, Seq[Int])] + ((0, List[Int](1, 1, 1, 1, 1)), (1, List[Int](1, 1, 1, 1, 1)))) + sc.stop() + } + + test("simple cogroup") { + val conf = createSparkConf(loadDefaults = false) + sc = new SparkContext("local", "test", conf) + val rdd1 = sc.parallelize(1 to 4).map(i => (i, i)) + val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i)) + val result = rdd1.cogroup(rdd2).collect() + + result.foreach { case (i, (seq1, seq2)) => + i match { + case 0 => assert(seq1.toSet === Set[Int]() && seq2.toSet === Set[Int](2, 4)) + case 1 => assert(seq1.toSet === Set[Int](1) && seq2.toSet === Set[Int](1, 3)) + case 2 => assert(seq1.toSet === Set[Int](2) && seq2.toSet === Set[Int]()) + case 3 => assert(seq1.toSet === Set[Int](3) && seq2.toSet === Set[Int]()) + case 4 => assert(seq1.toSet === Set[Int](4) && seq2.toSet === Set[Int]()) + } + } + sc.stop() + } + + test("spilling") { + testSimpleSpilling() + } + + test("spilling 
with compression") { + // Keep track of which compression codec we're using to report in test failure messages + var lastCompressionCodec: Option[String] = None + try { + allCompressionCodecs.foreach { c => + lastCompressionCodec = Some(c) + testSimpleSpilling(Some(c)) + } + } catch { + // Include compression codec used in test failure message + // We need to catch Throwable here because assertion failures are not covered by Exceptions + case t: Throwable => + val compressionMessage = lastCompressionCodec + .map { c => "with compression using codec " + c } + .getOrElse("without compression") + val newException = new Exception(s"Test failed $compressionMessage:\n\n${t.getMessage}") + newException.setStackTrace(t.getStackTrace) + throw newException + } + } + + test("spilling with compression and encryption") { + testSimpleSpilling(Some(CompressionCodec.DEFAULT_COMPRESSION_CODEC), encrypt = true) + } + + /** + * Test spilling through simple aggregations and cogroups. + * If a compression codec is provided, use it. Otherwise, do not compress spills. + */ + private def testSimpleSpilling(codec: Option[String] = None, encrypt: Boolean = false): Unit = { + val size = 1000 + val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + conf.set(IO_ENCRYPTION_ENABLED, encrypt) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + + assertSpilled(sc, "reduceByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) }.reduceByKey(math.max).collect() + assert(result.length === size / 2) + result.foreach { case (k, v) => + val expected = k * 2 + 1 + assert(v === expected, s"Value for $k was wrong: expected $expected, got $v") + } + } + + assertSpilled(sc, "groupByKey") { + val result = sc.parallelize(0 until size).map { i => (i / 2, i) }.groupByKey().collect() + assert(result.length == size / 2) + result.foreach { case (i, seq) => + val actual = seq.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual === expected, s"Value for $i was wrong: expected $expected, got $actual") + } + } + + assertSpilled(sc, "cogroup") { + val rdd1 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val rdd2 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val result = rdd1.cogroup(rdd2).collect() + assert(result.length === size / 2) + result.foreach { case (i, (seq1, seq2)) => + val actual1 = seq1.toSet + val actual2 = seq2.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual1 === expected, s"Value 1 for $i was wrong: expected $expected, got $actual1") + assert(actual2 === expected, s"Value 2 for $i was wrong: expected $expected, got $actual2") + } + } + + sc.stop() + } + + test("RemoteAppendOnlyMap shouldn't fail when forced to spill before calling its iterator") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[String] + val consumer = createExternalMap[String] + map.insertAll((1 to size).iterator.map(_.toString).map(i => (i, i))) + assert(map.spill(10000, consumer) == 0L) + } + + test("spilling with hash collisions") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = 
createExternalMap[String] + + val collisionPairs = Seq( + ("Aa", "BB"), // 2112 + ("to", "v1"), // 3707 + ("variants", "gelato"), // -1249574770 + ("Teheran", "Siblings"), // 231609873 + ("misused", "horsemints"), // 1069518484 + ("isohel", "epistolaries"), // -1179291542 + ("righto", "buzzards"), // -931102253 + ("hierarch", "crinolines"), // -1732884796 + ("inwork", "hypercatalexes"), // -1183663690 + ("wainages", "presentencing"), // 240183619 + ("trichothecenes", "locular"), // 339006536 + ("pomatoes", "eructation") // 568647356 + ) + + collisionPairs.foreach { case (w1, w2) => + // String.hashCode is documented to use a specific algorithm, but check just in case + assert(w1.hashCode === w2.hashCode) + } + + map.insertAll((1 to size).iterator.map(_.toString).map(i => (i, i))) + collisionPairs.foreach { case (w1, w2) => + map.insert(w1, w2) + map.insert(w2, w1) + } + assert(map.numSpills > 0, "map did not spill") + + // A map of collision pairs in both directions + val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap + + // Avoid map.size or map.iterator.length because this destructively sorts the underlying map + var count = 0 + + val it = map.iterator + while (it.hasNext) { + val kv = it.next() + val expectedValue = ArrayBuffer[String](collisionPairsMap.getOrElse(kv._1, kv._1)) + assert(kv._2.equals(expectedValue)) + count += 1 + } + assert(count === size + collisionPairs.size * 2) + sc.stop() + } + + test("spilling with many hash collisions") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val map = + new RemoteAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _, + resolver = resolver, context = context) + + // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes + // problems if the map fails to group together the objects with the same code (SPARK-2043). 
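+    // (Assumption, for readers of this test: FixedHashObject mirrors Spark's usual test helper,
+    // a case class whose hashCode() simply returns its second constructor argument, so every key
+    // inserted below hashes to either 0 or 1.)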
+ for (i <- 1 to 10) { + for (j <- 1 to size) { + map.insert(FixedHashObject(j, j % 2), 1) + } + } + assert(map.numSpills > 0, "map did not spill") + + val it = map.iterator + var count = 0 + while (it.hasNext) { + val kv = it.next() + assert(kv._2 === 10) + count += 1 + } + assert(count === size) + sc.stop() + } + + test("spilling with hash collisions using the Int.MaxValue key") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + (1 to size).foreach { i => map.insert(i, i) } + map.insert(Int.MaxValue, Int.MaxValue) + assert(map.numSpills > 0, "map did not spill") + + val it = map.iterator + while (it.hasNext) { + // Should not throw NoSuchElementException + it.next() + } + sc.stop() + } + + test("spilling with null keys and values") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + map.insertAll((1 to size).iterator.map(i => (i, i))) + map.insert(null.asInstanceOf[Int], 1) + map.insert(1, null.asInstanceOf[Int]) + map.insert(null.asInstanceOf[Int], null.asInstanceOf[Int]) + assert(map.numSpills > 0, "map did not spill") + + val it = map.iterator + while (it.hasNext) { + // Should not throw NullPointerException + it.next() + } + sc.stop() + } + + test("SPARK-22713 spill during iteration leaks internal map") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + map.insertAll((0 until size).iterator.map(i => (i / 10, i))) + assert(map.numSpills == 0, "map was not supposed to spill") + + val it = map.iterator + assert(it.isInstanceOf[CompletionIterator[_, _]]) + // org.apache.spark.util.collection.AppendOnlyMap.destructiveSortedIterator returns + // an instance of an anonymous Iterator class. + + val underlyingMapRef = WeakReference(map.currentMap) + + { + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(!tmpIsNull) + } + + val first50Keys = for (_ <- 0 until 50) yield { + val (k, vs) = it.next + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + assert(map.numSpills == 0) + map.spill(Long.MaxValue, null) + // these asserts try to show that we're no longer holding references to the underlying map.
+ // it'd be nice to use something like + // https://github.com/scala/scala/blob/2.13.x/test/junit/scala/tools/testing/AssertUtil.scala + // (lines 69-89) + // assert(map.currentMap == null) + eventually(timeout(5 seconds), interval(200 milliseconds)) { + System.gc() + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(tmpIsNull) + } + + + val next50Keys = for ( _ <- 0 until 50) yield { + val (k, vs) = it.next + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + assert(!it.hasNext) + val keys = (first50Keys ++ next50Keys).sorted + assert(keys == (0 until 100)) + } + + test("drop all references to the underlying map once the iterator is exhausted") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val map = createExternalMap[Int] + + map.insertAll((0 until size).iterator.map(i => (i / 10, i))) + assert(map.numSpills == 0, "map was not supposed to spill") + + val underlyingMapRef = WeakReference(map.currentMap) + + { + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(!tmpIsNull) + } + + val it = map.iterator + assert( it.isInstanceOf[CompletionIterator[_, _]]) + + + val keys = it.map{ + case (k, vs) => + val sortedVs = vs.sorted + assert(sortedVs.seq == (0 until 10).map(10 * k + _)) + k + } + .toList + .sorted + + assert(it.isEmpty) + assert(keys == (0 until 100).toList) + + assert(map.numSpills == 0) + // these asserts try to show that we're no longer holding references to the underlying map. + // it'd be nice to use something like + // https://github.com/scala/scala/blob/2.13.x/test/junit/scala/tools/testing/AssertUtil.scala + // (lines 69-89) + assert(map.currentMap == null) + + eventually { + Thread.sleep(500) + System.gc() + // direct asserts introduced some macro generated code that held a reference to the map + val tmpIsNull = null == underlyingMapRef.get.orNull + assert(tmpIsNull) + } + + assert(it.toList.isEmpty) + } + + test("SPARK-22713 external aggregation updates peak execution memory") { + val spillThreshold = 1000 + val conf = createSparkConf(loadDefaults = false) + .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString) + sc = new SparkContext("local", "test", conf) + // No spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") { + assertNotSpilled(sc, "verify peak memory") { + sc.parallelize(1 to spillThreshold / 2, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + } + // With spilling + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") { + assertSpilled(sc, "verify peak memory") { + sc.parallelize(1 to spillThreshold * 3, 2).map { i => (i, i) }.reduceByKey(_ + _).count() + } + } + } + + test("force to spill for external aggregation") { + val conf = createSparkConf(loadDefaults = false) + .set("spark.shuffle.memoryFraction", "0.01") + .set("spark.memory.useLegacyMode", "true") + .set("spark.testing.memory", "500000000") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + sc = new SparkContext("local", "test", conf) + val N = 2e5.toInt + sc.parallelize(1 to N, 2) + .map { i => (i, i) } + .groupByKey() + .reduceByKey(_ ++ _) + .count() + } + +} diff --git 
a/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/RemoteSorterSuite.scala b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/RemoteSorterSuite.scala new file mode 100644 index 000000000..f7c07927c --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/RemoteSorterSuite.scala @@ -0,0 +1,696 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.apache.hadoop.fs.Path + +import org.apache.spark._ +import org.apache.spark.TestUtils.assertSpilled +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.shuffle.remote.{RemoteAggregator, RemoteShuffleBlockResolver, RemoteShuffleUtils} +import org.apache.spark.storage.ShuffleBlockId + +class RemoteSorterSuite extends SparkFunSuite with LocalSparkContext { + + var sorter: RemoteSorter[Int, Int, Int] = _ + var resolver: RemoteShuffleBlockResolver = _ + + testWithMultipleSer("empty data stream")(emptyDataStream) + + testWithMultipleSer("few elements per partition")(fewElementsPerPartition) + + testWithMultipleSer("empty partitions with spilling")(emptyPartitionsWithSpilling) + + // Load defaults, otherwise SPARK_HOME is not found + testWithMultipleSer("spilling in local cluster", loadDefaults = true) { + (conf: SparkConf) => testSpillingInLocalCluster(conf, 2) + } + + testWithMultipleSer("spilling in local cluster with many reduce tasks", loadDefaults = true) { + (conf: SparkConf) => testSpillingInLocalCluster(conf, 100) + } + + test("cleanup of intermediate files in sorter") { + cleanupIntermediateFilesInSorter(withFailures = false) + } + + test("cleanup of intermediate files in sorter with failures") { + cleanupIntermediateFilesInSorter(withFailures = true) + } + + test("cleanup of intermediate files in shuffle") { + cleanupIntermediateFilesInShuffle(withFailures = false) + } + + test("cleanup of intermediate files in shuffle with failures") { + cleanupIntermediateFilesInShuffle(withFailures = true) + } + + testWithMultipleSer("no sorting or partial aggregation") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, withOrdering = false, withSpilling = false) + } + + testWithMultipleSer("no sorting or partial aggregation with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, withOrdering = false, withSpilling = true) + } + + testWithMultipleSer("sorting, no partial aggregation") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, 
withOrdering = true, withSpilling = false) + } + + testWithMultipleSer("sorting, no partial aggregation with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = false, withOrdering = true, withSpilling = true) + } + + testWithMultipleSer("partial aggregation, no sorting") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = false, withSpilling = false) + } + + testWithMultipleSer("partial aggregation, no sorting with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = false, withSpilling = true) + } + + testWithMultipleSer("partial aggregation and sorting") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = true, withSpilling = false) + } + + testWithMultipleSer("partial aggregation and sorting with spilling") { (conf: SparkConf) => + basicSorterTest(conf, withPartialAgg = true, withOrdering = true, withSpilling = true) + } + + testWithMultipleSer("sort without breaking sorting contracts", loadDefaults = true)( + sortWithoutBreakingSortingContracts) + + test("spilling with hash collisions") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + resolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i + def mergeCombiners(buffer1: ArrayBuffer[String], + buffer2: ArrayBuffer[String]): ArrayBuffer[String] = buffer1 ++= buffer2 + + val agg = new Aggregator[String, String, ArrayBuffer[String]]( + createCombiner _, mergeValue _, mergeCombiners _) + + val sorter = new RemoteSorter[String, String, ArrayBuffer[String]]( + context, resolver, Some(agg), None, None) + + val collisionPairs = Seq( + ("Aa", "BB"), // 2112 + ("to", "v1"), // 3707 + ("variants", "gelato"), // -1249574770 + ("Teheran", "Siblings"), // 231609873 + ("misused", "horsemints"), // 1069518484 + ("isohel", "epistolaries"), // -1179291542 + ("righto", "buzzards"), // -931102253 + ("hierarch", "crinolines"), // -1732884796 + ("inwork", "hypercatalexes"), // -1183663690 + ("wainages", "presentencing"), // 240183619 + ("trichothecenes", "locular"), // 339006536 + ("pomatoes", "eructation") // 568647356 + ) + + collisionPairs.foreach { case (w1, w2) => + // String.hashCode is documented to use a specific algorithm, but check just in case + assert(w1.hashCode === w2.hashCode) + } + + val toInsert = (1 to size).iterator.map(_.toString).map(s => (s, s)) ++ + collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap) + + sorter.insertAll(toInsert) + assert(sorter.numSpills > 0, "sorter did not spill") + + // A map of collision pairs in both directions + val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap + + // Avoid map.size or map.iterator.length because this destructively sorts the underlying map + var count = 0 + + val it = sorter.iterator + while (it.hasNext) { + val kv = it.next() + val expectedValue = ArrayBuffer[String](collisionPairsMap.getOrElse(kv._1, kv._1)) + assert(kv._2.equals(expectedValue)) + count += 1 + } + assert(count === size + collisionPairs.size * 2) + } + + test("spilling with many hash collisions") { + 
val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + resolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) + val sorter = + new RemoteSorter[FixedHashObject, Int, Int](context, resolver, Some(agg), None, None) + // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes + // problems if the map fails to group together the objects with the same code (SPARK-2043). + val toInsert = for (i <- 1 to 10; j <- 1 to size) yield (FixedHashObject(j, j % 2), 1) + sorter.insertAll(toInsert.iterator) + assert(sorter.numSpills > 0, "sorter did not spill") + val it = sorter.iterator + var count = 0 + while (it.hasNext) { + val kv = it.next() + assert(kv._2 === 10) + count += 1 + } + assert(count === size) + } + + test("spilling with hash collisions using the Int.MaxValue key") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + resolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + + def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) + def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i + def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]): ArrayBuffer[Int] = { + buf1 ++= buf2 + } + + val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) + val sorter = + new RemoteSorter[Int, Int, ArrayBuffer[Int]](context, resolver, Some(agg), None, None) + sorter.insertAll( + (1 to size).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + assert(sorter.numSpills > 0, "sorter did not spill") + val it = sorter.iterator + while (it.hasNext) { + // Should not throw NoSuchElementException + it.next() + } + } + + test("spilling with null keys and values") { + val size = 1000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + resolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + + def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) + def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i + def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]): ArrayBuffer[String] = + buf1 ++= buf2 + + val agg = new Aggregator[String, String, ArrayBuffer[String]]( + createCombiner, mergeValue, mergeCombiners) + + val sorter = new RemoteSorter[String, String, ArrayBuffer[String]]( + context, resolver, Some(agg), None, None) + + sorter.insertAll((1 to size).iterator.map(i => (i.toString, i.toString)) ++ Iterator( + (null.asInstanceOf[String], "1"), + ("1", null.asInstanceOf[String]), + (null.asInstanceOf[String], null.asInstanceOf[String]) + )) + assert(sorter.numSpills > 0, "sorter did not spill") + 
val it = sorter.iterator + while (it.hasNext) { + // Should not throw NullPointerException + it.next() + } + } + + /* ============================= * + | Helper test utility methods | + * ============================= */ + + private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = { + val conf = createDefaultConf(loadDefaults) + if (kryo) { + conf.set("spark.serializer", classOf[KryoSerializer].getName) + } else { + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + .set("spark.serializer", classOf[JavaSerializer].getName) + } + conf.set("spark.shuffle.sort.bypassMergeThreshold", "0") + // Ensure that we actually have multiple batches per spill file + .set("spark.shuffle.spill.batchSize", "10") + .set("spark.shuffle.spill.initialMemoryThreshold", "512") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.remote.RemoteShuffleManager") + } + + /** + * Run a test multiple times, each time with a different serializer. + */ + private def testWithMultipleSer( + name: String, + loadDefaults: Boolean = false)(body: (SparkConf => Unit)): Unit = { + test(name + " with kryo ser") { + body(createSparkConf(loadDefaults, kryo = true)) + } + test(name + " with java ser") { + body(createSparkConf(loadDefaults, kryo = false)) + } + } + + /* =========================================== * + | Helper methods that contain the test body | + * =========================================== */ + + private def emptyDataStream(conf: SparkConf) { + sc = new SparkContext("local", "test", conf) + resolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + + // Both aggregator and ordering + val sorter = new RemoteSorter[Int, Int, Int]( + context, resolver, Some(agg), Some(new HashPartitioner(3)), Some(ord)) + assert(sorter.iterator.toSeq === Seq()) + sorter.stop() + + // Only aggregator + val sorter2 = new RemoteSorter[Int, Int, Int]( + context, resolver, Some(agg), Some(new HashPartitioner(3)), None) + assert(sorter2.iterator.toSeq === Seq()) + sorter2.stop() + + // Only ordering + val sorter3 = new RemoteSorter[Int, Int, Int]( + context, resolver, None, Some(new HashPartitioner(3)), Some(ord)) + assert(sorter3.iterator.toSeq === Seq()) + sorter3.stop() + + // Neither aggregator nor ordering + val sorter4 = new RemoteSorter[Int, Int, Int]( + context, resolver, None, Some(new HashPartitioner(3)), None) + assert(sorter4.iterator.toSeq === Seq()) + sorter4.stop() + } + + private def fewElementsPerPartition(conf: SparkConf) { + sc = new SparkContext("local", "test", conf) + resolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + val elements = Set((1, 1), (2, 2), (5, 5)) + val expected = Set( + (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()), + (5, Set((5, 5))), (6, Set())) + + // Both aggregator and ordering + val sorter = new RemoteSorter[Int, Int, Int]( + context, resolver, Some(agg), Some(new HashPartitioner(7)), Some(ord)) + 
sorter.insertAll(elements.iterator) + assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter.stop() + + // Only aggregator + val sorter2 = new RemoteSorter[Int, Int, Int]( + context, resolver, Some(agg), Some(new HashPartitioner(7)), None) + sorter2.insertAll(elements.iterator) + assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter2.stop() + + // Only ordering + val sorter3 = new RemoteSorter[Int, Int, Int]( + context, resolver, None, Some(new HashPartitioner(7)), Some(ord)) + sorter3.insertAll(elements.iterator) + assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter3.stop() + + // Neither aggregator nor ordering + val sorter4 = new RemoteSorter[Int, Int, Int]( + context, resolver, None, Some(new HashPartitioner(7)), None) + sorter4.insertAll(elements.iterator) + assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) + sorter4.stop() + } + + private def emptyPartitionsWithSpilling(conf: SparkConf) { + val size = 1000 + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local", "test", conf) + resolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + + val ord = implicitly[Ordering[Int]] + val elements = Iterator((1, 1), (5, 5)) ++ (0 until size).iterator.map(x => (2, 2)) + + val sorter = new RemoteSorter[Int, Int, Int]( + context, resolver, None, Some(new HashPartitioner(7)), Some(ord)) + sorter.insertAll(elements) + assert(sorter.numSpills > 0, "sorter did not spill") + val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) + assert(iter.next() === ((0, Nil))) + assert(iter.next() === ((1, List((1, 1))))) + assert(iter.next() === ((2, (0 until 1000).map(x => (2, 2)).toList))) + assert(iter.next() === ((3, Nil))) + assert(iter.next() === ((4, Nil))) + assert(iter.next() === ((5, List((5, 5))))) + assert(iter.next() === ((6, Nil))) + sorter.stop() + } + + private def testSpillingInLocalCluster(conf: SparkConf, numReduceTasks: Int) { + val size = 5000 + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + + assertSpilled(sc, "reduceByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) } + .reduceByKey(math.max _, numReduceTasks) + .collect() + assert(result.length === size / 2) + result.foreach { case (k, v) => + val expected = k * 2 + 1 + assert(v === expected, s"Value for $k was wrong: expected $expected, got $v") + } + } + + assertSpilled(sc, "groupByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) } + .groupByKey(numReduceTasks) + .collect() + assert(result.length == size / 2) + result.foreach { case (i, seq) => + val actual = seq.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual === expected, s"Value for $i was wrong: expected $expected, got $actual") + } + } + + assertSpilled(sc, "cogroup") { + val rdd1 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val rdd2 = sc.parallelize(0 until size).map { i => (i / 2, i) } + val result = rdd1.cogroup(rdd2, numReduceTasks).collect() + assert(result.length === size / 2) + result.foreach { case (i, (seq1, seq2)) => + val actual1 = seq1.toSet + val actual2 = seq2.toSet + val expected = Set(i * 2, i * 2 + 1) + assert(actual1 === expected, s"Value 1 for $i was wrong: 
expected $expected, got $actual1") + assert(actual2 === expected, s"Value 2 for $i was wrong: expected $expected, got $actual2") + } + } + + assertSpilled(sc, "sortByKey") { + val result = sc.parallelize(0 until size) + .map { i => (i / 2, i) } + .sortByKey(numPartitions = numReduceTasks) + .collect() + val expected = (0 until size).map { i => (i / 2, i) }.toArray + assert(result.length === size) + result.zipWithIndex.foreach { case ((k, _), i) => + val (expectedKey, _) = expected(i) + assert(k === expectedKey, s"Value for $i was wrong: expected $expectedKey, got $k") + } + } + } + + + private def cleanupIntermediateFilesInSorter(withFailures: Boolean): Unit = { + val size = 1200 + val conf = createSparkConf(loadDefaults = false, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + sc = new SparkContext("local", "test", conf) + resolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + val diskBlockManager = sc.env.blockManager.diskBlockManager + val ord = implicitly[Ordering[Int]] + val expectedSize = if (withFailures) size - 1 else size + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val sorter = new RemoteSorter[Int, Int, Int]( + context, resolver, None, Some(new HashPartitioner(3)), Some(ord)) + if (withFailures) { + intercept[SparkException] { + sorter.insertAll((0 until size).iterator.map { i => + if (i == size - 1) { throw new SparkException("intentional failure") } + (i, i) + }) + } + } else { + sorter.insertAll((0 until size).iterator.map(i => (i, i))) + } + assert(sorter.iterator.toSet === (0 until expectedSize).map(i => (i, i)).toSet) + assert(sorter.numSpills > 0, "sorter did not spill") + assert(resolver.getAllFiles().nonEmpty, "sorter did not spill") + sorter.stop() + assert(resolver.getAllFiles().isEmpty, "spilled files were not cleaned up") + } + + private def cleanupIntermediateFilesInShuffle(withFailures: Boolean): Unit = { + val size = 1200 + val conf = createSparkConf(loadDefaults = false, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + sc = new SparkContext("local", "test", conf) + resolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + val data = sc.parallelize(0 until size, 2).map { i => + if (withFailures && i == size - 1) { + throw new SparkException("intentional failure") + } + (i, i) + } + + assertSpilled(sc, "test shuffle cleanup") { + if (withFailures) { + intercept[SparkException] { + data.reduceByKey(_ + _).count() + } + // After the shuffle, there should be only 2 files on disk: the output of task 1 and + // its index. All other files (map 2's output and intermediate merge files) should + // have been deleted. + assert(resolver.getAllFiles().length === 2) + } else { + assert(data.reduceByKey(_ + _).count() === size) + // After the shuffle, there should be only 4 files on disk: the output of both tasks + // and their indices. All intermediate merge files should have been deleted. 
+ assert(resolver.getAllFiles().length === 4) + } + } + } + + /* =========================================== * + | Helper methods that contain the test body | + * =========================================== */ + private def basicSorterTest( + conf: SparkConf, + withPartialAgg: Boolean, + withOrdering: Boolean, + withSpilling: Boolean) { + val size = 1000 + if (withSpilling) { + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + } + sc = new SparkContext("local", "test", conf) + resolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + val agg = + if (withPartialAgg) { + Some(new RemoteAggregator( + new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j), resolver)) + } else { + None + } + val ord = if (withOrdering) Some(implicitly[Ordering[Int]]) else None + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + sorter = new RemoteSorter[Int, Int, Int]( + context, resolver, agg, Some(new HashPartitioner(3)), ord) + sorter.insertAll((0 until size).iterator.map { i => (i / 4, i) }) + if (withSpilling) { + assert(sorter.numSpills > 0, "sorter did not spill") + } else { + assert(sorter.numSpills === 0, "sorter spilled") + } + val results = sorter.partitionedIterator.map { case (p, vs) => (p, vs.toSet) }.toSet + val expected = (0 until 3).map { p => + var v = (0 until size).map { i => (i / 4, i) }.filter { case (k, _) => k % 3 == p }.toSet + if (withPartialAgg) { + v = v.groupBy(_._1).mapValues { s => s.map(_._2).sum }.toSet + } + (p, v.toSet) + }.toSet + assert(results === expected) + sorter.stop() + } + + private def sortWithoutBreakingSortingContracts(conf: SparkConf) { + val size = 100000 + val conf = createSparkConf(loadDefaults = true, kryo = false) + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + resolver = + SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + + // Using wrongOrdering to show integer overflow introduced exception. + val rand = new Random(100L) + val wrongOrdering = new Ordering[String] { + override def compare(a: String, b: String): Int = { + val h1 = if (a == null) 0 else a.hashCode() + val h2 = if (b == null) 0 else b.hashCode() + h1 - h2 + } + } + + val testData = Array.tabulate(size) { _ => rand.nextInt().toString } + + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val sorter1 = new RemoteSorter[String, String, String]( + context, resolver, None, None, Some(wrongOrdering)) + val thrown = intercept[IllegalArgumentException] { + sorter1.insertAll(testData.iterator.map(i => (i, i))) + assert(sorter1.numSpills > 0, "sorter did not spill") + sorter1.iterator + } + + assert(thrown.getClass === classOf[IllegalArgumentException]) + assert(thrown.getMessage.contains("Comparison method violates its general contract")) + sorter1.stop() + + // Using aggregation and external spill to make sure RemoteSorter using + // partitionKeyComparator. 
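+    // (Assumption: with an aggregator supplied but no explicit ordering, keys within a partition
+    // are expected to come out in non-decreasing hashCode order; the minKey loop below checks
+    // exactly that.)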
+ def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer(i) + def mergeValue(c: ArrayBuffer[String], i: String): ArrayBuffer[String] = c += i + def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]): ArrayBuffer[String] = + c1 ++= c2 + + val agg = new Aggregator[String, String, ArrayBuffer[String]]( + createCombiner, mergeValue, mergeCombiners) + + val sorter2 = new RemoteSorter[String, String, ArrayBuffer[String]]( + context, resolver, Some(agg), None, None) + sorter2.insertAll(testData.iterator.map(i => (i, i))) + assert(sorter2.numSpills > 0, "sorter did not spill") + + // To validate the hash ordering of key + var minKey = Int.MinValue + sorter2.iterator.foreach { case (k, v) => + val h = k.hashCode() + assert(h >= minKey) + minKey = h + } + + sorter2.stop() + } + + private class SimpleRemoteBlockObjectReader[K, V]( + serializerManager: SerializerManager, serializerInstance: SerializerInstance) { + + def read(mapperInfo: ShuffleBlockId, + startPartition: Int, + endPartition: Int, + resolver: RemoteShuffleBlockResolver, + file: Path) + : Iterator[Product2[K, V]] = { + val fs = resolver.fs + val inputStream = fs.open(file) + (startPartition until endPartition).flatMap { i => + val blockId = ShuffleBlockId(mapperInfo.shuffleId, mapperInfo.mapId, i) + val buf = resolver.getBlockData(blockId) + + val rawStream = buf.createInputStream() + serializerInstance.deserializeStream( + serializerManager.wrapStream(blockId, rawStream)) + .asKeyValueIterator.asInstanceOf[Iterator[Product2[K, V]]] + }.toIterator + } + } + + private def basicSorterWrite( + conf: SparkConf, + withPartialAgg: Boolean, + withOrdering: Boolean, + withSpilling: Boolean) { + val size = 1000 + if (withSpilling) { + conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) + } + sc = new SparkContext("local", "test", conf) + val shuffleManager = SparkEnv.get.shuffleManager + val resolver = shuffleManager.shuffleBlockResolver.asInstanceOf[RemoteShuffleBlockResolver] + val agg = + if (withPartialAgg) { + Some(new RemoteAggregator( + new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j), resolver)) + } else { + None + } + val ord = if (withOrdering) Some(implicitly[Ordering[Int]]) else None + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + sorter = new RemoteSorter[Int, Int, Int]( + context, resolver, agg, Some(new HashPartitioner(3)), ord) + sorter.insertAll((0 until size).iterator.map { i => (i / 4, i) }) + if (withSpilling) { + assert(sorter.numSpills > 0, "sorter did not spill") + } else { + assert(sorter.numSpills === 0, "sorter spilled") + } + + val (shuffleId, mapId) = (66, 666) + val testShuffleBlockId = ShuffleBlockId( + shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) + val path = resolver.getDataFile(shuffleId, mapId) + val tmp = RemoteShuffleUtils.tempPathWith(path) + val lengths = sorter.writePartitionedFile(testShuffleBlockId, tmp) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, tmp) + + val results = + new SimpleRemoteBlockObjectReader[Int, Int]( + sc.env.serializerManager, sc.env.serializer.newInstance()).read( + testShuffleBlockId, 0, lengths.length, resolver, path).toSet + val expected = (0 until size).map { i => (i / 4, i)}.toSet + + assert(results === expected) + + } + + override def afterEach(): Unit = { + super.afterEach() + if (sorter != null) { + sorter.stop() + } + } + +} diff --git a/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/package.scala 
b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/package.scala new file mode 100644 index 000000000..445198cbb --- /dev/null +++ b/oap-shuffle/remote-shuffle/src/test/scala/org/apache/spark/util/collection/package.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.SparkConf +import org.apache.spark.shuffle.remote.RemoteShuffleManager + +package object collection { + def createDefaultConf(loadDefaults: Boolean = true): SparkConf = { + new SparkConf(loadDefaults) + .set("spark.shuffle.manager", classOf[RemoteShuffleManager].getCanonicalName) + // Unit tests should not rely on external systems, using local file system as storage + .set("spark.shuffle.remote.storageMasterUri", "file://") + .set("spark.shuffle.remote.filesRootDirectory", "/tmp") + .set("spark.shuffle.sync", "true") + } +} diff --git a/oap-shuffle/remote-shuffle/test-jar-with-dependencies.xml b/oap-shuffle/remote-shuffle/test-jar-with-dependencies.xml new file mode 100644 index 000000000..277b87efe --- /dev/null +++ b/oap-shuffle/remote-shuffle/test-jar-with-dependencies.xml @@ -0,0 +1,19 @@ + + test-jar-with-dependencies + + jar + + false + + + / + true + + true + true + test + + + \ No newline at end of file