From 2c4999ea60b28ed84de60fefebc9aee4a74da8c4 Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Tue, 28 Mar 2023 19:02:47 +0800 Subject: [PATCH] [#706] feat(spark): support spill to avoid memory deadlock (#714) ### What changes were proposed in this pull request? 1. Introduce the `DataPusher` to replace the `eventLoop`, this could be as general part for spark2 and spark3. 2. Implement the `spill` method in `WriterBufferManager` to avoid memory deadlock. ### Why are the changes needed? In current codebase, if having several `WriterBufferManagers`, when each other is acquiring memory, the deadlock will happen. To solve this, we should implement spill function to break this deadlock condition. Fix: #706 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 1. Existing UTs 2. Newly added UTs --- .../apache/spark/shuffle/RssSparkConfig.java | 20 ++- .../spark/shuffle/writer/AddBlockEvent.java | 21 +++ .../spark/shuffle/writer/DataPusher.java | 137 ++++++++++++++++++ .../shuffle/writer/WriteBufferManager.java | 115 ++++++++++++++- .../spark/shuffle/writer/DataPusherTest.java | 119 +++++++++++++++ .../writer/WriteBufferManagerTest.java | 79 ++++++++++ .../spark/shuffle/RssShuffleManager.java | 128 ++++++---------- .../shuffle/writer/RssShuffleWriter.java | 25 +--- .../shuffle/writer/RssShuffleWriterTest.java | 100 ++++++++----- .../spark/shuffle/RssShuffleManager.java | 125 +++++----------- .../shuffle/writer/RssShuffleWriter.java | 25 +--- .../org/apache/spark/shuffle/TestUtils.java | 7 +- .../shuffle/writer/RssShuffleWriterTest.java | 125 ++++++++++------ .../uniffle/common/util/ThreadUtils.java | 23 ++- .../uniffle/common/util/ThreadUtilsTest.java | 47 ++++++ 15 files changed, 782 insertions(+), 314 deletions(-) create mode 100644 client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java create mode 100644 client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java create mode 100644 common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java index c8e3478ccf..6aa90e1af1 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java @@ -28,11 +28,26 @@ import scala.runtime.AbstractFunction1; import org.apache.uniffle.client.util.RssClientConfig; +import org.apache.uniffle.common.config.ConfigOption; +import org.apache.uniffle.common.config.ConfigOptions; import org.apache.uniffle.common.config.ConfigUtils; import org.apache.uniffle.common.config.RssConf; public class RssSparkConfig { + public static final ConfigOption RSS_CLIENT_SEND_SIZE_LIMITATION = ConfigOptions + .key("rss.client.send.size.limit") + .longType() + .defaultValue(1024 * 1024 * 16L) + .withDescription("The max data size sent to shuffle server"); + + public static final ConfigOption RSS_MEMORY_SPILL_TIMEOUT = ConfigOptions + .key("rss.client.memory.spill.timeout.sec") + .intType() + .defaultValue(1) + .withDescription("The timeout of spilling data to remote shuffle server, " + + "which will be triggered by Spark TaskMemoryManager. 
Unit is sec, default value is 1"); + public static final String SPARK_RSS_CONFIG_PREFIX = "spark."; public static final ConfigEntry RSS_PARTITION_NUM_PER_RANGE = createIntegerBuilder( @@ -115,11 +130,6 @@ public class RssSparkConfig { new ConfigBuilder("spark.rss.client.heartBeat.threadNum")) .createWithDefault(4); - public static final ConfigEntry RSS_CLIENT_SEND_SIZE_LIMIT = createStringBuilder( - new ConfigBuilder("spark.rss.client.send.size.limit") - .doc("The max data size sent to shuffle server")) - .createWithDefault("16m"); - public static final ConfigEntry RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE = createIntegerBuilder( new ConfigBuilder("spark.rss.client.unregister.thread.pool.size")) .createWithDefault(10); diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java index a888975436..7dab0725cc 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.writer; +import java.util.ArrayList; import java.util.List; import org.apache.uniffle.common.ShuffleBlockInfo; @@ -25,10 +26,26 @@ public class AddBlockEvent { private String taskId; private List shuffleDataInfoList; + private List processedCallbackChain; public AddBlockEvent(String taskId, List shuffleDataInfoList) { this.taskId = taskId; this.shuffleDataInfoList = shuffleDataInfoList; + this.processedCallbackChain = new ArrayList<>(); + } + + public AddBlockEvent(String taskId, List shuffleBlockInfoList, Runnable callback) { + this.taskId = taskId; + this.shuffleDataInfoList = shuffleBlockInfoList; + this.processedCallbackChain = new ArrayList<>(); + addCallback(callback); + } + + /** + * @param callback, should not throw any exception and execute fast. + */ + public void addCallback(Runnable callback) { + processedCallbackChain.add(callback); } public String getTaskId() { @@ -39,6 +56,10 @@ public List getShuffleDataInfoList() { return shuffleDataInfoList; } + public List getProcessedCallbackChain() { + return processedCallbackChain; + } + @Override public String toString() { return "AddBlockEvent: TaskId[" + taskId + "], " + shuffleDataInfoList; diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java new file mode 100644 index 0000000000..ca03a784bb --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.writer; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Queues; +import com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.client.api.ShuffleWriteClient; +import org.apache.uniffle.client.response.SendShuffleDataResult; +import org.apache.uniffle.common.ShuffleBlockInfo; +import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.util.ThreadUtils; + +/** + * A {@link DataPusher} that is responsible for sending data to remote + * shuffle servers asynchronously. + */ +public class DataPusher implements Closeable { + private static final Logger LOGGER = LoggerFactory.getLogger(DataPusher.class); + + private final ExecutorService executorService; + + private final ShuffleWriteClient shuffleWriteClient; + // Must be thread safe + private final Map> taskToSuccessBlockIds; + // Must be thread safe + private final Map> taskToFailedBlockIds; + private String rssAppId; + // Must be thread safe + private final Set failedTaskIds; + + public DataPusher(ShuffleWriteClient shuffleWriteClient, + Map> taskToSuccessBlockIds, + Map> taskToFailedBlockIds, + Set failedTaskIds, + int threadPoolSize, + int threadKeepAliveTime) { + this.shuffleWriteClient = shuffleWriteClient; + this.taskToSuccessBlockIds = taskToSuccessBlockIds; + this.taskToFailedBlockIds = taskToFailedBlockIds; + this.failedTaskIds = failedTaskIds; + this.executorService = new ThreadPoolExecutor( + threadPoolSize, + threadPoolSize * 2, + threadKeepAliveTime, + TimeUnit.SECONDS, + Queues.newLinkedBlockingQueue(Integer.MAX_VALUE), + ThreadUtils.getThreadFactory(this.getClass().getName()) + ); + } + + public CompletableFuture send(AddBlockEvent event) { + if (rssAppId == null) { + throw new RssException("RssAppId should be set."); + } + return CompletableFuture.supplyAsync(() -> { + String taskId = event.getTaskId(); + List shuffleBlockInfoList = event.getShuffleDataInfoList(); + try { + SendShuffleDataResult result = shuffleWriteClient.sendShuffleData( + rssAppId, + shuffleBlockInfoList, + () -> !isValidTask(taskId) + ); + putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds()); + putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds()); + } finally { + List callbackChain = Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST); + for (Runnable runnable : callbackChain) { + runnable.run(); + } + } + return shuffleBlockInfoList.stream() + .map(x -> x.getFreeMemory()) + .reduce((a, b) -> a + b) + .get(); + }, executorService); + } + + private synchronized void putBlockId( + Map> taskToBlockIds, + String taskAttemptId, + Set blockIds) { + if (blockIds == null || blockIds.isEmpty()) { + return; + } + taskToBlockIds.computeIfAbsent(taskAttemptId, x -> Sets.newConcurrentHashSet()).addAll(blockIds); + } + + public boolean isValidTask(String taskId) { + return !failedTaskIds.contains(taskId); + } + + public void setRssAppId(String rssAppId) { + this.rssAppId = rssAppId; + } + + @Override + public void close() throws IOException { + if (executorService != null) { + try { + 
ThreadUtils.shutdownThreadPool(executorService, 5); + } catch (InterruptedException interruptedException) { + LOGGER.error("Errors on shutdown thread pool of [{}].", this.getClass().getSimpleName()); + } + } + } +} diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java index 9f33be388f..580fce699d 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java @@ -17,10 +17,16 @@ package org.apache.spark.shuffle.writer; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import java.util.stream.Collectors; import com.clearspring.analytics.util.Lists; import com.google.common.annotations.VisibleForTesting; @@ -61,6 +67,7 @@ public class WriteBufferManager extends MemoryConsumer { private Map partitionToSeqNo = Maps.newHashMap(); private long askExecutorMemory; private int shuffleId; + private String taskId; private long taskAttemptId; private SerializerInstance instance; private ShuffleWriteMetrics shuffleWriteMetrics; @@ -81,6 +88,9 @@ public class WriteBufferManager extends MemoryConsumer { private long requireMemoryInterval; private int requireMemoryRetryMax; private Codec codec; + private Function> spillFunc; + private long sendSizeLimit; + private int memorySpillTimeoutSec; public WriteBufferManager( int shuffleId, @@ -91,12 +101,38 @@ public WriteBufferManager( TaskMemoryManager taskMemoryManager, ShuffleWriteMetrics shuffleWriteMetrics, RssConf rssConf) { + this( + shuffleId, + null, + taskAttemptId, + bufferManagerOptions, + serializer, + partitionToServers, + taskMemoryManager, + shuffleWriteMetrics, + rssConf, + null + ); + } + + public WriteBufferManager( + int shuffleId, + String taskId, + long taskAttemptId, + BufferManagerOptions bufferManagerOptions, + Serializer serializer, + Map> partitionToServers, + TaskMemoryManager taskMemoryManager, + ShuffleWriteMetrics shuffleWriteMetrics, + RssConf rssConf, + Function> spillFunc) { super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP); this.bufferSize = bufferManagerOptions.getBufferSize(); this.spillSize = bufferManagerOptions.getBufferSpillThreshold(); this.instance = serializer.newInstance(); this.buffers = Maps.newHashMap(); this.shuffleId = shuffleId; + this.taskId = taskId; this.taskAttemptId = taskAttemptId; this.partitionToServers = partitionToServers; this.shuffleWriteMetrics = shuffleWriteMetrics; @@ -111,6 +147,9 @@ public WriteBufferManager( .substring(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()), RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT); this.codec = compress ? 
Codec.newInstance(rssConf) : null; + this.spillFunc = spillFunc; + this.sendSizeLimit = rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION); + this.memorySpillTimeoutSec = rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT); } public List addRecord(int partitionId, Object key, Object value) { @@ -165,7 +204,7 @@ public List addRecord(int partitionId, Object key, Object valu } // transform all [partition, records] to [partition, ShuffleBlockInfo] and clear cache - public List clear() { + public synchronized List clear() { List result = Lists.newArrayList(); long dataSize = 0; long memoryUsed = 0; @@ -247,10 +286,64 @@ private void requestExecutorMemory(long leastMem) { } } + public List buildBlockEvents(List shuffleBlockInfoList) { + long totalSize = 0; + long memoryUsed = 0; + List events = new ArrayList<>(); + List shuffleBlockInfosPerEvent = Lists.newArrayList(); + for (ShuffleBlockInfo sbi : shuffleBlockInfoList) { + totalSize += sbi.getSize(); + memoryUsed += sbi.getFreeMemory(); + shuffleBlockInfosPerEvent.add(sbi); + // split shuffle data according to the size + if (totalSize > sendSizeLimit) { + LOG.info("Build event with " + shuffleBlockInfosPerEvent.size() + + " blocks and " + totalSize + " bytes"); + // Use final temporary variables for closures + final long _memoryUsed = memoryUsed; + events.add( + new AddBlockEvent(taskId, shuffleBlockInfosPerEvent, () -> freeAllocatedMemory(_memoryUsed)) + ); + shuffleBlockInfosPerEvent = Lists.newArrayList(); + totalSize = 0; + memoryUsed = 0; + } + } + if (!shuffleBlockInfosPerEvent.isEmpty()) { + LOG.info("Build event with " + shuffleBlockInfosPerEvent.size() + + " blocks and " + totalSize + " bytes"); + // Use final temporary variables for closures + final long _memoryUsed = memoryUsed; + events.add( + new AddBlockEvent(taskId, shuffleBlockInfosPerEvent, () -> freeAllocatedMemory(_memoryUsed)) + ); + } + return events; + } + @Override public long spill(long size, MemoryConsumer trigger) { - // there is no spill for such situation - return 0; + List events = buildBlockEvents(clear()); + List> futures = events.stream().map(x -> spillFunc.apply(x)).collect(Collectors.toList()); + CompletableFuture allOfFutures = + CompletableFuture.allOf(futures.toArray(new CompletableFuture[futures.size()])); + try { + allOfFutures.get(memorySpillTimeoutSec, TimeUnit.SECONDS); + } catch (TimeoutException timeoutException) { + // A best effort strategy to wait. + // If timeout exception occurs, the underlying tasks won't be cancelled. 
+ } finally { + long releasedSize = futures.stream().filter(x -> x.isDone()).mapToLong(x -> { + try { + return x.get(); + } catch (Exception e) { + return 0; + } + }).sum(); + LOG.info("[taskId: {}] Spill triggered by memory consumer of {}, released memory size: {}", + taskId, trigger.getClass().getSimpleName(), releasedSize); + return releasedSize; + } } @VisibleForTesting @@ -307,4 +400,20 @@ public String getManagerCostInfo() { + estimateTime + "], requireMemoryTime[" + requireMemoryTime + "], uncompressedDataLen[" + uncompressedDataLen + "]"; } + + @VisibleForTesting + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + @VisibleForTesting + public void setSpillFunc( + Function> spillFunc) { + this.spillFunc = spillFunc; + } + + @VisibleForTesting + public void setSendSizeLimit(long sendSizeLimit) { + this.sendSizeLimit = sendSizeLimit; + } } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java new file mode 100644 index 0000000000..20711dc053 --- /dev/null +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.writer; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.function.Supplier; + +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.junit.jupiter.api.Test; + +import org.apache.uniffle.client.impl.ShuffleWriteClientImpl; +import org.apache.uniffle.client.response.SendShuffleDataResult; +import org.apache.uniffle.common.ShuffleBlockInfo; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DataPusherTest { + + static class FakedShuffleWriteClient extends ShuffleWriteClientImpl { + private SendShuffleDataResult fakedShuffleDataResult; + + FakedShuffleWriteClient() { + this( + "GRPC", + 1, + 1, + 10, + 1, + 1, + 1, + false, + 1, + 1, + 1, + 1 + ); + } + + private FakedShuffleWriteClient(String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum, + int replica, int replicaWrite, int replicaRead, boolean replicaSkipEnabled, int dataTransferPoolSize, + int dataCommitPoolSize, int unregisterThreadPoolSize, int unregisterRequestTimeSec) { + super(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, replica, replicaWrite, replicaRead, + replicaSkipEnabled, dataTransferPoolSize, dataCommitPoolSize, unregisterThreadPoolSize, + unregisterRequestTimeSec); + } + + @Override + public SendShuffleDataResult sendShuffleData(String appId, List shuffleBlockInfoList, + Supplier needCancelRequest) { + return fakedShuffleDataResult; + } + + public void setFakedShuffleDataResult(SendShuffleDataResult fakedShuffleDataResult) { + this.fakedShuffleDataResult = fakedShuffleDataResult; + } + } + + @Test + public void testSendData() throws ExecutionException, InterruptedException { + FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient(); + + Map> taskToSuccessBlockIds = Maps.newConcurrentMap(); + Map> taskToFailedBlockIds = Maps.newConcurrentMap(); + Set failedTaskIds = new HashSet<>(); + + DataPusher dataPusher = new DataPusher( + shuffleWriteClient, + taskToSuccessBlockIds, + taskToFailedBlockIds, + failedTaskIds, + 1, + 2 + ); + dataPusher.setRssAppId("testSendData_appId"); + + // sync send + AddBlockEvent event = new AddBlockEvent("taskId", Arrays.asList( + new ShuffleBlockInfo( + 1, 1, 1, 1, 1, new byte[1], null, 1, 100, 1 + )) + ); + shuffleWriteClient.setFakedShuffleDataResult( + new SendShuffleDataResult( + Sets.newHashSet(1L, 2L), + Sets.newHashSet(3L, 4L) + ) + ); + CompletableFuture future = dataPusher.send(event); + long memoryFree = future.get(); + assertEquals(100, memoryFree); + assertTrue(taskToSuccessBlockIds.get("taskId").contains(1L)); + assertTrue(taskToSuccessBlockIds.get("taskId").contains(2L)); + assertTrue(taskToFailedBlockIds.get("taskId").contains(3L)); + assertTrue(taskToFailedBlockIds.get("taskId").contains(4L)); + } +} diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java index 4f57e265f3..8758054142 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java @@ -17,7 +17,11 @@ package 
org.apache.spark.shuffle.writer; +import java.util.Arrays; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; import com.google.common.collect.Maps; import org.apache.commons.lang.reflect.FieldUtils; @@ -27,6 +31,7 @@ import org.apache.spark.serializer.KryoSerializer; import org.apache.spark.serializer.Serializer; import org.apache.spark.shuffle.RssSparkConfig; +import org.awaitility.Awaitility; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -195,4 +200,78 @@ public void createBlockIdTest() { sbi = wbm.createShuffleBlock(1, mockWriterBuffer); assertEquals(35184374185984L, sbi.getBlockId()); } + + @Test + public void buildBlockEventsTest() { + SparkConf conf = getConf(); + conf.set("spark.rss.client.send.size.limit", "30"); + + TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class); + + BufferManagerOptions bufferOptions = new BufferManagerOptions(conf); + WriteBufferManager wbm = new WriteBufferManager( + 0, 0, bufferOptions, new KryoSerializer(conf), + Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), RssSparkConfig.toRssConf(conf)); + + // every block: length=4, memoryUsed=12 + ShuffleBlockInfo info1 = new ShuffleBlockInfo(1, 1, 1, 4, 1, new byte[1], null, 1, 12, 1); + ShuffleBlockInfo info2 = new ShuffleBlockInfo(1, 1, 1, 4, 1, new byte[1], null, 1, 12, 1); + ShuffleBlockInfo info3 = new ShuffleBlockInfo(1, 1, 1, 4, 1, new byte[1], null, 1, 12, 1); + List events = wbm.buildBlockEvents(Arrays.asList(info1, info2, info3)); + assertEquals(3, events.size()); + } + + @Test + public void spillTest() { + SparkConf conf = getConf(); + conf.set("spark.rss.client.send.size.limit", "1000"); + TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class); + BufferManagerOptions bufferOptions = new BufferManagerOptions(conf); + + Function> spillFunc = event -> { + event.getProcessedCallbackChain().stream().forEach(x -> x.run()); + return CompletableFuture.completedFuture( + event.getShuffleDataInfoList().stream().mapToLong(x -> x.getFreeMemory()).sum() + ); + }; + + WriteBufferManager wbm = new WriteBufferManager( + 0, "taskId_spillTest", 0, bufferOptions, new KryoSerializer(conf), + Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), + RssSparkConfig.toRssConf(conf), spillFunc); + WriteBufferManager spyManager = spy(wbm); + doReturn(512L).when(spyManager).acquireMemory(anyLong()); + + String testKey = "Key"; + String testValue = "Value"; + spyManager.addRecord(0, testKey, testValue); + spyManager.addRecord(1, testKey, testValue); + + // case1. all events are flushed within normal time. + long releasedSize = spyManager.spill(1000, mock(WriteBufferManager.class)); + assertEquals(64, releasedSize); + + // case2. partial events are not flushed within normal time. + // when calling spill func, 2 events should be spilled. But + // only event will be finished in the expected time. + spyManager.setSendSizeLimit(30); + spyManager.addRecord(0, testKey, testValue); + spyManager.addRecord(1, testKey, testValue); + spyManager.setSpillFunc(event -> CompletableFuture.supplyAsync(() -> { + int partitionId = event.getShuffleDataInfoList().get(0).getPartitionId(); + if (partitionId == 1) { + try { + Thread.sleep(2000); + } catch (InterruptedException interruptedException) { + // ignore. 
+ } + } + event.getProcessedCallbackChain().stream().forEach(x -> x.run()); + return event.getShuffleDataInfoList().stream().mapToLong(x -> x.getFreeMemory()).sum(); + })); + releasedSize = spyManager.spill(1000, mock(WriteBufferManager.class)); + assertEquals(32, releasedSize); + assertEquals(32, spyManager.getUsedBytes()); + Awaitility.await().timeout(3, TimeUnit.SECONDS).until(() -> spyManager.getUsedBytes() == 0); + } } diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 0b301154e3..77b9ef0dd2 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -17,17 +17,17 @@ package org.apache.spark.shuffle; +import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.function.Function; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Queues; import com.google.common.collect.Sets; import org.apache.hadoop.conf.Configuration; import org.apache.spark.ShuffleDependency; @@ -39,11 +39,11 @@ import org.apache.spark.shuffle.reader.RssShuffleReader; import org.apache.spark.shuffle.writer.AddBlockEvent; import org.apache.spark.shuffle.writer.BufferManagerOptions; +import org.apache.spark.shuffle.writer.DataPusher; import org.apache.spark.shuffle.writer.RssShuffleWriter; import org.apache.spark.shuffle.writer.WriteBufferManager; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManagerId; -import org.apache.spark.util.EventLoop; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,12 +54,10 @@ import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; -import org.apache.uniffle.client.response.SendShuffleDataResult; import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.PartitionRange; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleAssignmentsInfo; -import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.exception.RssException; @@ -80,7 +78,6 @@ public class RssShuffleManager implements ShuffleManager { private ShuffleWriteClient shuffleWriteClient; private Map> taskToSuccessBlockIds = JavaUtils.newConcurrentMap(); private Map> taskToFailedBlockIds = JavaUtils.newConcurrentMap(); - private Map taskToBufferManager = JavaUtils.newConcurrentMap(); private final int dataReplica; private final int dataReplicaWrite; private final int dataReplicaRead; @@ -92,58 +89,7 @@ public class RssShuffleManager implements ShuffleManager { private boolean dynamicConfEnabled = false; private final String user; private final String uuid; - private ThreadPoolExecutor threadPoolExecutor; - private EventLoop eventLoop = new EventLoop("ShuffleDataQueue") { - - @Override - public void onReceive(AddBlockEvent event) { - threadPoolExecutor.execute(() -> sendShuffleData(event.getTaskId(), 
event.getShuffleDataInfoList())); - } - - private void sendShuffleData(String taskId, List shuffleDataInfoList) { - try { - SendShuffleDataResult result = shuffleWriteClient.sendShuffleData( - appId, - shuffleDataInfoList, - () -> !isValidTask(taskId) - ); - putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds()); - putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds()); - } finally { - // data is already send, release the memory to executor - long releaseSize = 0; - for (ShuffleBlockInfo sbi : shuffleDataInfoList) { - releaseSize += sbi.getFreeMemory(); - } - WriteBufferManager bufferManager = taskToBufferManager.get(taskId); - if (bufferManager != null) { - bufferManager.freeAllocatedMemory(releaseSize); - } - LOG.debug("Finish send data and release " + releaseSize + " bytes"); - } - } - - private synchronized void putBlockId( - Map> taskToBlockIds, - String taskAttemptId, - Set blockIds) { - if (blockIds == null) { - return; - } - if (taskToBlockIds.get(taskAttemptId) == null) { - taskToBlockIds.put(taskAttemptId, Sets.newConcurrentHashSet()); - } - taskToBlockIds.get(taskAttemptId).addAll(blockIds); - } - - @Override - public void onError(Throwable throwable) { - } - - @Override - public void onStart() { - } - }; + private DataPusher dataPusher; public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) { @@ -193,19 +139,22 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) { sparkConf.set("spark.shuffle.reduceLocality.enabled", "false"); LOG.info("Disable shuffle data locality in RssShuffleManager."); if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) { - // for non-driver executor, start a thread for sending shuffle data to shuffle server - LOG.info("RSS data send thread is starting"); - eventLoop.start(); - int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE); - int keepAliveTime = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE); - threadPoolExecutor = new ThreadPoolExecutor(poolSize, poolSize * 2, keepAliveTime, TimeUnit.SECONDS, - Queues.newLinkedBlockingQueue(Integer.MAX_VALUE), - ThreadUtils.getThreadFactory("SendData")); - if (isDriver) { heartBeatScheduledExecutorService = ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat"); } + // for non-driver executor, start a thread for sending shuffle data to shuffle server + LOG.info("RSS data pusher is starting..."); + int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE); + int keepAliveTime = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE); + this.dataPusher = new DataPusher( + shuffleWriteClient, + taskToSuccessBlockIds, + taskToFailedBlockIds, + failedTaskIds, + poolSize, + keepAliveTime + ); } } @@ -236,6 +185,7 @@ public ShuffleHandle registerShuffle(int shuffleId, int numMaps, Shuff // will be called many times depend on how many shuffle stage if ("".equals(appId)) { appId = SparkEnv.get().conf().getAppId() + "_" + uuid; + dataPusher.setRssAppId(appId); LOG.info("Generate application id used in rss: " + appId); } @@ -344,6 +294,13 @@ protected void registerCoordinator() { shuffleWriteClient.registerCoordinators(coordinators); } + public CompletableFuture sendData(AddBlockEvent event) { + if (dataPusher != null && event != null) { + return dataPusher.send(event); + } + return new CompletableFuture<>(); + } + // This method is called in Spark executor, // getting information from Spark driver via 
the ShuffleHandle. @Override @@ -352,6 +309,7 @@ public ShuffleWriter getWriter(ShuffleHandle handle, int mapId, if (handle instanceof RssShuffleHandle) { RssShuffleHandle rssHandle = (RssShuffleHandle) handle; appId = rssHandle.getAppId(); + dataPusher.setRssAppId(appId); int shuffleId = rssHandle.getShuffleId(); String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber(); @@ -359,15 +317,16 @@ public ShuffleWriter getWriter(ShuffleHandle handle, int mapId, ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics(); WriteBufferManager bufferManager = new WriteBufferManager( shuffleId, + taskId, context.taskAttemptId(), bufferOptions, rssHandle.getDependency().serializer(), rssHandle.getPartitionToServers(), context.taskMemoryManager(), writeMetrics, - RssSparkConfig.toRssConf(sparkConf) + RssSparkConfig.toRssConf(sparkConf), + this::sendData ); - taskToBufferManager.put(taskId, bufferManager); return new RssShuffleWriter<>(rssHandle.getAppId(), shuffleId, taskId, context.taskAttemptId(), bufferManager, writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle, @@ -448,7 +407,13 @@ public void stop() { if (heartBeatScheduledExecutorService != null) { heartBeatScheduledExecutorService.shutdownNow(); } - threadPoolExecutor.shutdownNow(); + if (dataPusher != null) { + try { + dataPusher.close(); + } catch (IOException e) { + LOG.warn("Errors on closing data pusher", e); + } + } shuffleWriteClient.close(); } @@ -457,15 +422,6 @@ public ShuffleBlockResolver shuffleBlockResolver() { throw new RssException("RssShuffleManager.shuffleBlockResolver is not implemented"); } - public EventLoop getEventLoop() { - return eventLoop; - } - - @VisibleForTesting - public void setEventLoop(EventLoop eventLoop) { - this.eventLoop = eventLoop; - } - // when speculation enable, duplicate data will be sent and reported to shuffle server, // get the actual tasks and filter the duplicate data caused by speculation task private Roaring64NavigableMap getExpectedTasks(int shuffleId, int startPartition, int endPartition) { @@ -520,15 +476,9 @@ public void addSuccessBlockIds(String taskId, Set blockIds) { taskToSuccessBlockIds.get(taskId).addAll(blockIds); } - @VisibleForTesting - public Map getTaskToBufferManager() { - return taskToBufferManager; - } - public void clearTaskMeta(String taskId) { taskToSuccessBlockIds.remove(taskId); taskToFailedBlockIds.remove(taskId); - taskToBufferManager.remove(taskId); } @VisibleForTesting @@ -550,4 +500,12 @@ public boolean markFailedTask(String taskId) { public boolean isValidTask(String taskId) { return !failedTaskIds.contains(taskId); } + + public DataPusher getDataPusher() { + return dataPusher; + } + + public void setDataPusher(DataPusher dataPusher) { + this.dataPusher = dataPusher; + } } diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 81445f4414..b0b7653f8d 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -81,7 +81,6 @@ public class RssShuffleWriter extends ShuffleWriter { private RssShuffleManager shuffleManager; private long sendCheckTimeout; private long sendCheckInterval; - private long sendSizeLimit; private boolean isMemoryShuffleEnabled; private final Function taskFailureCallback; @@ -136,8 +135,6 @@ public RssShuffleWriter( 
this.shouldPartition = partitioner.numPartitions() > 1; this.sendCheckTimeout = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS); this.sendCheckInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS); - this.sendSizeLimit = sparkConf.getSizeAsBytes(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(), - RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.defaultValue().get()); this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM); this.partitionToBlockIds = Maps.newHashMap(); this.shuffleWriteClient = shuffleWriteClient; @@ -233,26 +230,8 @@ private void processShuffleBlockInfos(List shuffleBlockInfoLis // don't send huge block to shuffle server, or there will be OOM if shuffle sever receives data more than expected protected void postBlockEvent(List shuffleBlockInfoList) { - long totalSize = 0; - List shuffleBlockInfosPerEvent = Lists.newArrayList(); - for (ShuffleBlockInfo sbi : shuffleBlockInfoList) { - totalSize += sbi.getSize(); - shuffleBlockInfosPerEvent.add(sbi); - // split shuffle data according to the size - if (totalSize > sendSizeLimit) { - LOG.debug("Post event to queue with " + shuffleBlockInfosPerEvent.size() - + " blocks and " + totalSize + " bytes"); - shuffleManager.getEventLoop().post( - new AddBlockEvent(taskId, shuffleBlockInfosPerEvent)); - shuffleBlockInfosPerEvent = Lists.newArrayList(); - totalSize = 0; - } - } - if (!shuffleBlockInfosPerEvent.isEmpty()) { - LOG.debug("Post event to queue with " + shuffleBlockInfosPerEvent.size() - + " blocks and " + totalSize + " bytes"); - shuffleManager.getEventLoop().post( - new AddBlockEvent(taskId, shuffleBlockInfosPerEvent)); + for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) { + shuffleManager.sendData(event); } } diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java index 81d1bd2303..61ff6325e4 100644 --- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java +++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java @@ -21,6 +21,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; import java.util.stream.Collectors; import com.google.common.collect.Lists; @@ -38,7 +40,6 @@ import org.apache.spark.shuffle.RssShuffleHandle; import org.apache.spark.shuffle.RssShuffleManager; import org.apache.spark.shuffle.RssSparkConfig; -import org.apache.spark.util.EventLoop; import org.junit.jupiter.api.Test; import scala.Product2; import scala.Tuple2; @@ -51,7 +52,6 @@ import org.apache.uniffle.storage.util.StorageType; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.anyLong; @@ -126,11 +126,36 @@ public void checkBlockSendResultTest() { manager.clearTaskMeta(taskId); assertTrue(manager.getSuccessBlockIds(taskId).isEmpty()); assertTrue(manager.getFailedBlockIds(taskId).isEmpty()); - assertNull(manager.getTaskToBufferManager().get(taskId)); sc.stop(); } + static class FakedDataPusher extends DataPusher { + private final Function> sendFunc; + + FakedDataPusher(Function> sendFunc) { + this(null, null, null, null, 
1, 1, sendFunc); + } + + private FakedDataPusher( + ShuffleWriteClient shuffleWriteClient, + Map> taskToSuccessBlockIds, + Map> taskToFailedBlockIds, + Set failedTaskIds, + int threadPoolSize, + int threadKeepAliveTime, + Function> sendFunc) { + super(shuffleWriteClient, taskToSuccessBlockIds, taskToFailedBlockIds, failedTaskIds, threadPoolSize, + threadKeepAliveTime); + this.sendFunc = sendFunc; + } + + @Override + public CompletableFuture send(AddBlockEvent event) { + return sendFunc.apply(event); + } + } + @Test public void writeTest() throws Exception { SparkConf conf = new SparkConf(); @@ -149,21 +174,15 @@ public void writeTest() throws Exception { RssShuffleManager manager = new RssShuffleManager(conf, false); List shuffleBlockInfos = Lists.newArrayList(); - manager.setEventLoop(new EventLoop("test") { - @Override - public void onReceive(AddBlockEvent event) { - assertEquals("taskId", event.getTaskId()); - shuffleBlockInfos.addAll(event.getShuffleDataInfoList()); - Set blockIds = event.getShuffleDataInfoList().parallelStream() - .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet()); - manager.addSuccessBlockIds(event.getTaskId(), blockIds); - } - - @Override - public void onError(Throwable e) { - } + DataPusher dataPusher = new FakedDataPusher(event -> { + assertEquals("taskId", event.getTaskId()); + shuffleBlockInfos.addAll(event.getShuffleDataInfoList()); + Set blockIds = event.getShuffleDataInfoList().parallelStream() + .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet()); + manager.addSuccessBlockIds(event.getTaskId(), blockIds); + return CompletableFuture.completedFuture(0L); }); - manager.getEventLoop().start(); + manager.setDataPusher(dataPusher); Partitioner mockPartitioner = mock(Partitioner.class); ShuffleDependency mockDependency = mock(ShuffleDependency.class); @@ -200,8 +219,8 @@ public void onError(Throwable e) { ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics(); BufferManagerOptions bufferOptions = new BufferManagerOptions(conf); WriteBufferManager bufferManager = new WriteBufferManager( - 0, 0, bufferOptions, kryoSerializer, - partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new RssConf()); + 0, "taskId", 0, bufferOptions, kryoSerializer, + partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new RssConf(), null); WriteBufferManager bufferManagerSpy = spy(bufferManager); doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong()); @@ -253,37 +272,46 @@ public void onError(Throwable e) { @Test public void postBlockEventTest() throws Exception { - final WriteBufferManager mockBufferManager = mock(WriteBufferManager.class); final ShuffleWriteMetrics mockMetrics = mock(ShuffleWriteMetrics.class); ShuffleDependency mockDependency = mock(ShuffleDependency.class); Partitioner mockPartitioner = mock(Partitioner.class); - final RssShuffleManager mockShuffleManager = mock(RssShuffleManager.class); when(mockDependency.partitioner()).thenReturn(mockPartitioner); when(mockPartitioner.numPartitions()).thenReturn(2); List events = Lists.newArrayList(); - EventLoop eventLoop = new EventLoop("test") { - @Override - public void onReceive(AddBlockEvent event) { - events.add(event); - } + SparkConf conf = new SparkConf(); + conf.setAppName("postBlockEventTest").setMaster("local[2]") + .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true") + .set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true") + .set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32") + .set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32") + 
.set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "64") + .set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128") + .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000") + .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name()) + .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), "127.0.0.1:12345,127.0.0.1:12346") + .set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION.key(), "64") + .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name()); - @Override - public void onError(Throwable e) { - } - }; - eventLoop.start(); + TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class); + BufferManagerOptions bufferOptions = new BufferManagerOptions(conf); + WriteBufferManager bufferManager = new WriteBufferManager( + 0, 0, bufferOptions, new KryoSerializer(conf), + Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), RssSparkConfig.toRssConf(conf)); + + RssShuffleManager manager = new RssShuffleManager(conf, false); + DataPusher dataPusher = new FakedDataPusher(event -> { + events.add(event); + return CompletableFuture.completedFuture(0L); + }); + manager.setDataPusher(dataPusher); - when(mockShuffleManager.getEventLoop()).thenReturn(eventLoop); RssShuffleHandle mockHandle = mock(RssShuffleHandle.class); when(mockHandle.getDependency()).thenReturn(mockDependency); ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class); - SparkConf conf = new SparkConf(); - conf.set(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(), "64") - .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.LOCALFILE.name()); RssShuffleWriter writer = new RssShuffleWriter<>("appId", 0, "taskId", 1L, - mockBufferManager, mockMetrics, mockShuffleManager, conf, mockWriteClient, mockHandle); + bufferManager, mockMetrics, manager, conf, mockWriteClient, mockHandle); List shuffleBlockInfoList = createShuffleBlockList(1, 31); writer.postBlockEvent(shuffleBlockInfoList); Thread.sleep(500); diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 04b57136d4..e875132d96 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -17,21 +17,20 @@ package org.apache.spark.shuffle; +import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Queues; import com.google.common.collect.Sets; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import org.apache.hadoop.conf.Configuration; @@ -46,12 +45,12 @@ import org.apache.spark.shuffle.reader.RssShuffleReader; import org.apache.spark.shuffle.writer.AddBlockEvent; import org.apache.spark.shuffle.writer.BufferManagerOptions; +import org.apache.spark.shuffle.writer.DataPusher; import org.apache.spark.shuffle.writer.RssShuffleWriter; import 
org.apache.spark.shuffle.writer.WriteBufferManager; import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManagerId; -import org.apache.spark.util.EventLoop; import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -62,12 +61,10 @@ import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; -import org.apache.uniffle.client.response.SendShuffleDataResult; import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.PartitionRange; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleAssignmentsInfo; -import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; @@ -84,7 +81,6 @@ public class RssShuffleManager implements ShuffleManager { private final String clientType; private final long heartbeatInterval; private final long heartbeatTimeout; - private final ThreadPoolExecutor threadPoolExecutor; private AtomicReference id = new AtomicReference<>(); private SparkConf sparkConf; private final int dataReplica; @@ -96,7 +92,6 @@ public class RssShuffleManager implements ShuffleManager { private ShuffleWriteClient shuffleWriteClient; private final Map> taskToSuccessBlockIds; private final Map> taskToFailedBlockIds; - private Map taskToBufferManager = JavaUtils.newConcurrentMap(); private ScheduledExecutorService heartBeatScheduledExecutorService; private boolean heartbeatStarted = false; private boolean dynamicConfEnabled = false; @@ -104,55 +99,7 @@ public class RssShuffleManager implements ShuffleManager { private String user; private String uuid; private Set failedTaskIds = Sets.newConcurrentHashSet(); - private final EventLoop eventLoop; - private final EventLoop defaultEventLoop = new EventLoop("ShuffleDataQueue") { - - @Override - public void onReceive(AddBlockEvent event) { - threadPoolExecutor.execute(() -> sendShuffleData(event.getTaskId(), event.getShuffleDataInfoList())); - } - - @Override - public void onError(Throwable throwable) { - LOG.info("Shuffle event loop error...", throwable); - } - - @Override - public void onStart() { - LOG.info("Shuffle event loop start..."); - } - - private void sendShuffleData(String taskId, List shuffleDataInfoList) { - try { - SendShuffleDataResult result = shuffleWriteClient.sendShuffleData( - id.get(), - shuffleDataInfoList, - () -> !isValidTask(taskId) - ); - putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds()); - putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds()); - } finally { - final AtomicLong releaseSize = new AtomicLong(0); - shuffleDataInfoList.forEach((sbi) -> releaseSize.addAndGet(sbi.getFreeMemory())); - WriteBufferManager bufferManager = taskToBufferManager.get(taskId); - if (bufferManager != null) { - bufferManager.freeAllocatedMemory(releaseSize.get()); - } - LOG.debug("Spark 3.0 finish send data and release " + releaseSize + " bytes"); - } - } - - private synchronized void putBlockId( - Map> taskToBlockIds, - String taskAttemptId, - Set blockIds) { - if (blockIds == null || blockIds.isEmpty()) { - return; - } - taskToBlockIds.putIfAbsent(taskAttemptId, Sets.newConcurrentHashSet()); - taskToBlockIds.get(taskAttemptId).addAll(blockIds); - } - }; + private DataPusher dataPusher; public 
RssShuffleManager(SparkConf conf, boolean isDriver) { this.sparkConf = conf; @@ -211,19 +158,28 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) { LOG.info("Disable shuffle data locality in RssShuffleManager."); taskToSuccessBlockIds = JavaUtils.newConcurrentMap(); taskToFailedBlockIds = JavaUtils.newConcurrentMap(); - // for non-driver executor, start a thread for sending shuffle data to shuffle server - LOG.info("RSS data send thread is starting"); - eventLoop = defaultEventLoop; - eventLoop.start(); - int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE); - int keepAliveTime = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE); - threadPoolExecutor = new ThreadPoolExecutor(poolSize, poolSize * 2, keepAliveTime, TimeUnit.SECONDS, - Queues.newLinkedBlockingQueue(Integer.MAX_VALUE), - ThreadUtils.getThreadFactory("SendData")); if (isDriver) { heartBeatScheduledExecutorService = ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat"); } + LOG.info("Rss data pusher is starting..."); + int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE); + int keepAliveTime = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE); + this.dataPusher = new DataPusher( + shuffleWriteClient, + taskToSuccessBlockIds, + taskToFailedBlockIds, + failedTaskIds, + poolSize, + keepAliveTime + ); + } + + public CompletableFuture sendData(AddBlockEvent event) { + if (dataPusher != null && event != null) { + return dataPusher.send(event); + } + return new CompletableFuture<>(); } @VisibleForTesting @@ -242,7 +198,7 @@ protected static ShuffleDataDistributionType getDataDistributionType(SparkConf s RssShuffleManager( SparkConf conf, boolean isDriver, - EventLoop loop, + DataPusher dataPusher, Map> taskToSuccessBlockIds, Map> taskToFailedBlockIds) { this.sparkConf = conf; @@ -283,14 +239,8 @@ protected static ShuffleDataDistributionType getDataDistributionType(SparkConf s ); this.taskToSuccessBlockIds = taskToSuccessBlockIds; this.taskToFailedBlockIds = taskToFailedBlockIds; - if (loop != null) { - eventLoop = loop; - } else { - eventLoop = defaultEventLoop; - } - eventLoop.start(); - threadPoolExecutor = null; - heartBeatScheduledExecutorService = null; + this.heartBeatScheduledExecutorService = null; + this.dataPusher = dataPusher; } // This method is called in Spark driver side, @@ -315,6 +265,7 @@ public ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency< if (id.get() == null) { id.compareAndSet(null, SparkEnv.get().conf().getAppId() + "_" + uuid); + dataPusher.setRssAppId(id.get()); } LOG.info("Generate application id used in rss: " + id.get()); @@ -390,6 +341,7 @@ public ShuffleWriter getWriter( // todo: this implement is tricky, we should refactor it if (id.get() == null) { id.compareAndSet(null, rssHandle.getAppId()); + dataPusher.setRssAppId(id.get()); } int shuffleId = rssHandle.getShuffleId(); String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber(); @@ -401,10 +353,9 @@ public ShuffleWriter getWriter( writeMetrics = context.taskMetrics().shuffleWriteMetrics(); } WriteBufferManager bufferManager = new WriteBufferManager( - shuffleId, context.taskAttemptId(), bufferOptions, rssHandle.getDependency().serializer(), + shuffleId, taskId, context.taskAttemptId(), bufferOptions, rssHandle.getDependency().serializer(), rssHandle.getPartitionToServers(), context.taskMemoryManager(), - writeMetrics, RssSparkConfig.toRssConf(sparkConf)); - taskToBufferManager.put(taskId, bufferManager); + 
writeMetrics, RssSparkConfig.toRssConf(sparkConf), this::sendData); LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), rssHandle.getShuffleId()); return new RssShuffleWriter<>(rssHandle.getAppId(), shuffleId, taskId, context.taskAttemptId(), bufferManager, writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle, @@ -660,21 +611,21 @@ public void stop() { if (heartBeatScheduledExecutorService != null) { heartBeatScheduledExecutorService.shutdownNow(); } - if (threadPoolExecutor != null) { - threadPoolExecutor.shutdownNow(); - } if (shuffleWriteClient != null) { shuffleWriteClient.close(); } - if (eventLoop != null) { - eventLoop.stop(); + if (dataPusher != null) { + try { + dataPusher.close(); + } catch (IOException e) { + LOG.warn("Errors on closing data pusher", e); + } } } public void clearTaskMeta(String taskId) { taskToSuccessBlockIds.remove(taskId); taskToFailedBlockIds.remove(taskId); - taskToBufferManager.remove(taskId); } @VisibleForTesting @@ -736,12 +687,6 @@ private synchronized void startHeartbeat() { } } - public void postEvent(AddBlockEvent addBlockEvent) { - if (eventLoop != null) { - eventLoop.post(addBlockEvent); - } - } - public Set getFailedBlockIds(String taskId) { Set result = taskToFailedBlockIds.get(taskId); if (result == null) { diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 4784c3908e..fc1e32f5d1 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -75,7 +75,6 @@ public class RssShuffleWriter extends ShuffleWriter { private final boolean shouldPartition; private final long sendCheckTimeout; private final long sendCheckInterval; - private final long sendSizeLimit; private final int bitmapSplitNum; private final Map> partitionToBlockIds; private final ShuffleWriteClient shuffleWriteClient; @@ -137,8 +136,6 @@ public RssShuffleWriter( this.shouldPartition = partitioner.numPartitions() > 1; this.sendCheckTimeout = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS); this.sendCheckInterval = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS); - this.sendSizeLimit = sparkConf.getSizeAsBytes(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(), - RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.defaultValue().get()); this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM); this.partitionToBlockIds = Maps.newHashMap(); this.shuffleWriteClient = shuffleWriteClient; @@ -231,26 +228,8 @@ private void processShuffleBlockInfos(List shuffleBlockInfoLis } protected void postBlockEvent(List shuffleBlockInfoList) { - long totalSize = 0; - List shuffleBlockInfosPerEvent = Lists.newArrayList(); - for (ShuffleBlockInfo sbi : shuffleBlockInfoList) { - totalSize += sbi.getSize(); - shuffleBlockInfosPerEvent.add(sbi); - // split shuffle data according to the size - if (totalSize > sendSizeLimit) { - LOG.debug("Post event to queue with " + shuffleBlockInfosPerEvent.size() - + " blocks and " + totalSize + " bytes"); - shuffleManager.postEvent( - new AddBlockEvent(taskId, shuffleBlockInfosPerEvent)); - shuffleBlockInfosPerEvent = Lists.newArrayList(); - totalSize = 0; - } - } - if (!shuffleBlockInfosPerEvent.isEmpty()) { - LOG.debug("Post event to queue with " + shuffleBlockInfosPerEvent.size() - + " blocks and " + totalSize + " bytes"); - 
shuffleManager.postEvent( - new AddBlockEvent(taskId, shuffleBlockInfosPerEvent)); + for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) { + shuffleManager.sendData(event); } } diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java index 033966a7a3..1c7f988aa0 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java @@ -22,8 +22,7 @@ import org.apache.commons.lang3.SystemUtils; import org.apache.spark.SparkConf; -import org.apache.spark.shuffle.writer.AddBlockEvent; -import org.apache.spark.util.EventLoop; +import org.apache.spark.shuffle.writer.DataPusher; public class TestUtils { @@ -33,10 +32,10 @@ private TestUtils() { public static RssShuffleManager createShuffleManager( SparkConf conf, Boolean isDriver, - EventLoop loop, + DataPusher dataPusher, Map> successBlockIds, Map> failBlockIds) { - return new RssShuffleManager(conf, isDriver, loop, successBlockIds, failBlockIds); + return new RssShuffleManager(conf, isDriver, dataPusher, successBlockIds, failBlockIds); } public static boolean isMacOnAppleSilicon() { diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java index f446cf3aa7..25c76497ea 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java @@ -17,11 +17,13 @@ package org.apache.spark.shuffle.writer; - +import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; import java.util.stream.Collectors; import com.google.common.collect.Lists; @@ -40,7 +42,7 @@ import org.apache.spark.shuffle.RssShuffleManager; import org.apache.spark.shuffle.RssSparkConfig; import org.apache.spark.shuffle.TestUtils; -import org.apache.spark.util.EventLoop; +import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; import scala.Product2; import scala.Tuple2; @@ -49,7 +51,6 @@ import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; -import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.storage.util.StorageType; @@ -102,7 +103,7 @@ public void checkBlockSendResultTest() { BufferManagerOptions bufferOptions = new BufferManagerOptions(conf); WriteBufferManager bufferManager = new WriteBufferManager( 0, 0, bufferOptions, kryoSerializer, - Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), new RssConf()); + Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), RssSparkConfig.toRssConf(conf)); WriteBufferManager bufferManagerSpy = spy(bufferManager); RssShuffleWriter rssShuffleWriter = new RssShuffleWriter<>("appId", 0, "taskId", 1L, @@ -135,6 +136,32 @@ public void checkBlockSendResultTest() { sc.stop(); } + static class FakedDataPusher extends DataPusher { + private final Function> sendFunc; + + FakedDataPusher(Function> sendFunc) { + this(null, null, null, null, 1, 1, sendFunc); + } + + private FakedDataPusher( + 
+        ShuffleWriteClient shuffleWriteClient,
+        Map> taskToSuccessBlockIds,
+        Map> taskToFailedBlockIds,
+        Set failedTaskIds,
+        int threadPoolSize,
+        int threadKeepAliveTime,
+        Function> sendFunc) {
+      super(shuffleWriteClient, taskToSuccessBlockIds, taskToFailedBlockIds, failedTaskIds, threadPoolSize,
+          threadKeepAliveTime);
+      this.sendFunc = sendFunc;
+    }
+
+    @Override
+    public CompletableFuture send(AddBlockEvent event) {
+      return sendFunc.apply(event);
+    }
+  }
+
   @Test
   public void writeTest() throws Exception {
     SparkConf conf = new SparkConf();
@@ -151,27 +178,24 @@ public void writeTest() throws Exception {
     // init SparkContext
     List shuffleBlockInfos = Lists.newArrayList();
     final SparkContext sc = SparkContext.getOrCreate(conf);
-    Map> successBlockIds = JavaUtils.newConcurrentMap();
-    EventLoop testLoop = new EventLoop("test") {
-      @Override
-      public void onReceive(AddBlockEvent event) {
-        assertEquals("taskId", event.getTaskId());
-        shuffleBlockInfos.addAll(event.getShuffleDataInfoList());
-        Set blockIds = event.getShuffleDataInfoList().parallelStream()
-            .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet());
-        successBlockIds.putIfAbsent(event.getTaskId(), Sets.newConcurrentHashSet());
-        successBlockIds.get(event.getTaskId()).addAll(blockIds);
-      }
-
-      @Override
-      public void onError(Throwable e) {
-      }
-    };
+    Map> successBlockIds = Maps.newConcurrentMap();
+
+    FakedDataPusher dataPusher = new FakedDataPusher(
+        event -> {
+          assertEquals("taskId", event.getTaskId());
+          shuffleBlockInfos.addAll(event.getShuffleDataInfoList());
+          Set blockIds = event.getShuffleDataInfoList().parallelStream()
+              .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet());
+          successBlockIds.putIfAbsent(event.getTaskId(), Sets.newConcurrentHashSet());
+          successBlockIds.get(event.getTaskId()).addAll(blockIds);
+          return new CompletableFuture<>();
+        }
+    );
 
     final RssShuffleManager manager = TestUtils.createShuffleManager(
         conf,
         false,
-        testLoop,
+        dataPusher,
         successBlockIds,
         JavaUtils.newConcurrentMap());
     Serializer kryoSerializer = new KryoSerializer(conf);
@@ -210,7 +234,11 @@ public void onError(Throwable e) {
     ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
     WriteBufferManager bufferManager = new WriteBufferManager(
         0, 0, bufferOptions, kryoSerializer,
-        partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new RssConf());
+        partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics,
+        RssSparkConfig.toRssConf(conf)
+    );
+    bufferManager.setTaskId("taskId");
+
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
     RssShuffleWriter rssShuffleWriter = new RssShuffleWriter<>("appId", 0, "taskId", 1L,
         bufferManagerSpy, shuffleWriteMetrics, manager, conf, mockShuffleWriteClient, mockHandle);
@@ -265,7 +293,16 @@ public void onError(Throwable e) {
 
   @Test
   public void postBlockEventTest() throws Exception {
-    WriteBufferManager mockBufferManager = mock(WriteBufferManager.class);
+    SparkConf conf = new SparkConf();
+    conf.set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION.key(), "64")
+        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name());
+
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    WriteBufferManager bufferManager = new WriteBufferManager(
+        0, 0, bufferOptions, new KryoSerializer(conf),
+        Maps.newHashMap(), mock(TaskMemoryManager.class), new ShuffleWriteMetrics(), RssSparkConfig.toRssConf(conf));
+    WriteBufferManager bufferManagerSpy = spy(bufferManager);
+
     ShuffleDependency mockDependency = mock(ShuffleDependency.class);
     ShuffleWriteMetrics mockMetrics = mock(ShuffleWriteMetrics.class);
     Partitioner mockPartitioner = mock(Partitioner.class);
@@ -274,35 +311,39 @@ public void postBlockEventTest() throws Exception {
     when(mockPartitioner.numPartitions()).thenReturn(2);
 
     List events = Lists.newArrayList();
-    EventLoop eventLoop = new EventLoop("test") {
-      @Override
-      public void onReceive(AddBlockEvent event) {
-        events.add(event);
-      }
+    FakedDataPusher dataPusher = new FakedDataPusher(
+        event -> {
+          events.add(event);
+          return new CompletableFuture<>();
+        }
+    );
 
-      @Override
-      public void onError(Throwable e) {
-      }
-    };
     RssShuffleManager mockShuffleManager = spy(TestUtils.createShuffleManager(
         sparkConf,
         false,
-        eventLoop,
-        JavaUtils.newConcurrentMap(),
-        JavaUtils.newConcurrentMap()));
+        dataPusher,
+        Maps.newConcurrentMap(),
+        Maps.newConcurrentMap()));
 
     RssShuffleHandle mockHandle = mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
     ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
-    SparkConf conf = new SparkConf();
-    conf.set(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(), "64")
-        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE.name());
+
     List shuffleBlockInfoList = createShuffleBlockList(1, 31);
-    RssShuffleWriter writer = new RssShuffleWriter<>("appId", 0, "taskId", 1L,
-        mockBufferManager, mockMetrics, mockShuffleManager, conf, mockWriteClient, mockHandle);
+    RssShuffleWriter writer = new RssShuffleWriter<>(
+        "appId",
+        0,
+        "taskId",
+        1L,
+        bufferManagerSpy,
+        mockMetrics,
+        mockShuffleManager,
+        conf,
+        mockWriteClient,
+        mockHandle
+    );
     writer.postBlockEvent(shuffleBlockInfoList);
-    Thread.sleep(500);
-    assertEquals(1, events.size());
+    Awaitility.await().timeout(Duration.ofSeconds(1)).until(() -> events.size() == 1);
     assertEquals(1, events.get(0).getShuffleDataInfoList().size());
     events.clear();
diff --git a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
index b0ed870e5a..eb9173babf 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
@@ -22,15 +22,19 @@
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
 
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
 import io.netty.util.concurrent.DefaultThreadFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
-/**
- * Provide a general method to create a thread factory to make the code more standardized
- */
 public class ThreadUtils {
+  private static final Logger LOGGER = LoggerFactory.getLogger(ThreadUtils.class);
+
+  /**
+   * Provide a general method to create a thread factory to make the code more standardized
+   */
   public static ThreadFactory getThreadFactory(String factoryName) {
     return new ThreadFactoryBuilder().setDaemon(true).setNameFormat(factoryName + "-%d").build();
   }
@@ -74,4 +78,17 @@ public static ExecutorService getDaemonSingleThreadExecutor(String factoryName)
   public static ExecutorService getDaemonCachedThreadPool(String factoryName) {
     return Executors.newCachedThreadPool(getThreadFactory(factoryName));
   }
+
+  public static void shutdownThreadPool(ExecutorService threadPool, int waitSec) throws InterruptedException {
+    if (threadPool == null) {
+      return;
+    }
+    threadPool.shutdown();
+    if (!threadPool.awaitTermination(waitSec, TimeUnit.SECONDS)) {
+      threadPool.shutdownNow();
+      if (!threadPool.awaitTermination(waitSec, TimeUnit.SECONDS)) {
+        LOGGER.warn("Thread pool didn't stop gracefully.");
+      }
+    }
+  }
 }
diff --git a/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
new file mode 100644
index 0000000000..7fafa44219
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ThreadUtilsTest {
+
+  @Test
+  public void shutdownThreadPoolTest() throws InterruptedException {
+    ExecutorService executorService = Executors.newFixedThreadPool(2);
+    AtomicBoolean finished = new AtomicBoolean(false);
+    executorService.submit(() -> {
+      try {
+        Thread.sleep(100000);
+      } catch (InterruptedException interruptedException) {
+        // ignore
+      } finally {
+        finished.set(true);
+      }
+    });
+    ThreadUtils.shutdownThreadPool(executorService, 1);
+    assertTrue(finished.get());
+    assertTrue(executorService.isShutdown());
+  }
+}
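
For reference, a minimal sketch of how a client-side component can drain its sending pool with the new ThreadUtils.shutdownThreadPool helper when it is closed. The class name PusherLikeComponent, the "DataPusher" pool name and the 10-second grace period are illustrative assumptions, not values taken from this patch; the pool is created with the existing getDaemonCachedThreadPool factory shown above, and the DataPusher added by this patch presumably follows the same close() pattern.

    import java.io.Closeable;
    import java.io.IOException;
    import java.util.concurrent.ExecutorService;

    import org.apache.uniffle.common.util.ThreadUtils;

    // Hypothetical consumer of the new shutdown helper.
    public class PusherLikeComponent implements Closeable {
      private final ExecutorService executorService =
          ThreadUtils.getDaemonCachedThreadPool("DataPusher");

      @Override
      public void close() throws IOException {
        try {
          // Graceful stop: shutdown(), wait up to 10s, then shutdownNow() and wait once more.
          ThreadUtils.shutdownThreadPool(executorService, 10);
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          throw new IOException("Interrupted while stopping the push threads", e);
        }
      }
    }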