diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
index a888975436..3910775aa9 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.shuffle.writer;
 
+import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.uniffle.common.ShuffleBlockInfo;
@@ -25,10 +26,23 @@ public class AddBlockEvent {
 
   private String taskId;
   private List<ShuffleBlockInfo> shuffleDataInfoList;
+  private List<Runnable> processedCallbackChain;
 
   public AddBlockEvent(String taskId, List<ShuffleBlockInfo> shuffleDataInfoList) {
     this.taskId = taskId;
     this.shuffleDataInfoList = shuffleDataInfoList;
+    this.processedCallbackChain = new ArrayList<>();
+  }
+
+  public AddBlockEvent(String taskId, List<ShuffleBlockInfo> shuffleBlockInfoList, Runnable callback) {
+    this.taskId = taskId;
+    this.shuffleDataInfoList = shuffleBlockInfoList;
+    this.processedCallbackChain = new ArrayList<>();
+    addCallback(callback);
+  }
+
+  public void addCallback(Runnable callback) {
+    processedCallbackChain.add(callback);
   }
 
   public String getTaskId() {
@@ -39,6 +53,10 @@ public List<ShuffleBlockInfo> getShuffleDataInfoList() {
     return shuffleDataInfoList;
   }
 
+  public List<Runnable> getProcessedCallbackChain() {
+    return processedCallbackChain;
+  }
+
   @Override
   public String toString() {
     return "AddBlockEvent: TaskId[" + taskId + "], " + shuffleDataInfoList;
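The callback chain added above is what lets a sender release writer memory only after the event has actually been flushed. A minimal usage sketch; `blocks`, `bufferManager`, and `LOG` are illustrative stand-ins, not names from this diff:

    // Free the buffered memory once this batch has been sent.
    long memoryUsed = blocks.stream().mapToLong(ShuffleBlockInfo::getFreeMemory).sum();
    AddBlockEvent event =
        new AddBlockEvent(taskId, blocks, () -> bufferManager.freeAllocatedMemory(memoryUsed));
    // Further callbacks can be chained, e.g. for bookkeeping.
    event.addCallback(() -> LOG.debug("flushed " + blocks.size() + " blocks"));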
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
new file mode 100644
index 0000000000..44cd9d1362
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Queues;
+import com.google.common.collect.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.util.ThreadUtils;
+
+public class DataPusher implements Closeable {
+  private static final Logger LOGGER = LoggerFactory.getLogger(DataPusher.class);
+
+  private final ExecutorService executorService;
+
+  private final ShuffleWriteClient shuffleWriteClient;
+  private final Map<String, Set<Long>> taskToSuccessBlockIds;
+  private final Map<String, Set<Long>> taskToFailedBlockIds;
+  private final String appId;
+  private final Set<String> failedTaskIds;
+
+  public DataPusher(ShuffleWriteClient shuffleWriteClient,
+                    Map<String, Set<Long>> taskToSuccessBlockIds,
+                    Map<String, Set<Long>> taskToFailedBlockIds,
+                    Set<String> failedTaskIds,
+                    String appId,
+                    int threadPoolSize,
+                    int threadKeepAliveTime) {
+    this.shuffleWriteClient = shuffleWriteClient;
+    this.taskToSuccessBlockIds = taskToSuccessBlockIds;
+    this.taskToFailedBlockIds = taskToFailedBlockIds;
+    this.failedTaskIds = failedTaskIds;
+    this.appId = appId;
+    this.executorService = new ThreadPoolExecutor(
+        threadPoolSize,
+        threadPoolSize * 2,
+        threadKeepAliveTime,
+        TimeUnit.SECONDS,
+        Queues.newLinkedBlockingQueue(Integer.MAX_VALUE),
+        ThreadUtils.getThreadFactory(this.getClass().getName())
+    );
+  }
+
+  public CompletableFuture<Long> send(AddBlockEvent event) {
+    return CompletableFuture.supplyAsync(() -> {
+      String taskId = event.getTaskId();
+      List<ShuffleBlockInfo> shuffleBlockInfoList = event.getShuffleDataInfoList();
+      try {
+        SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
+            appId,
+            shuffleBlockInfoList,
+            () -> !isValidTask(taskId)
+        );
+        putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
+        putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
+      } finally {
+        // Run the callbacks whether the send succeeded or not, so buffer memory is always released.
+        List<Runnable> callbackChain =
+            Optional.ofNullable(event.getProcessedCallbackChain()).orElse(Collections.emptyList());
+        for (Runnable runnable : callbackChain) {
+          runnable.run();
+        }
+      }
+      // The future completes with the number of bytes this event was holding.
+      return shuffleBlockInfoList.stream()
+          .mapToLong(ShuffleBlockInfo::getFreeMemory)
+          .sum();
+    }, executorService);
+  }
+
+  private synchronized void putBlockId(
+      Map<String, Set<Long>> taskToBlockIds,
+      String taskAttemptId,
+      Set<Long> blockIds) {
+    if (blockIds == null || blockIds.isEmpty()) {
+      return;
+    }
+    taskToBlockIds.putIfAbsent(taskAttemptId, Sets.newConcurrentHashSet());
+    taskToBlockIds.get(taskAttemptId).addAll(blockIds);
+  }
+
+  public boolean isValidTask(String taskId) {
+    return !failedTaskIds.contains(taskId);
+  }
+
+  @Override
+  public void close() throws IOException {
+    if (executorService != null) {
+      executorService.shutdown();
+    }
+  }
+}
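A sketch of driving the pusher directly; the returned future completes with the bytes the event held, so a writer can wait for in-flight sends before committing. The pool sizes here are arbitrary:

    DataPusher pusher = new DataPusher(shuffleWriteClient, taskToSuccessBlockIds,
        taskToFailedBlockIds, failedTaskIds, appId, 10, 60);
    CompletableFuture<Long> future = pusher.send(event);
    long freedBytes = future.join();  // blocks until this batch is flushed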
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 9f33be388f..08f0218208 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -17,10 +17,13 @@
 
 package org.apache.spark.shuffle.writer;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Function;
 
 import com.clearspring.analytics.util.Lists;
 import com.google.common.annotations.VisibleForTesting;
@@ -42,6 +45,7 @@
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.ChecksumUtils;
@@ -61,6 +65,7 @@ public class WriteBufferManager extends MemoryConsumer {
   private Map<Integer, Integer> partitionToSeqNo = Maps.newHashMap();
   private long askExecutorMemory;
   private int shuffleId;
+  private String taskId;
   private long taskAttemptId;
   private SerializerInstance instance;
   private ShuffleWriteMetrics shuffleWriteMetrics;
@@ -81,6 +86,8 @@ public class WriteBufferManager extends MemoryConsumer {
   private long requireMemoryInterval;
   private int requireMemoryRetryMax;
   private Codec codec;
+  private Function<AddBlockEvent, CompletableFuture<Long>> spillAsyncFunc;
+  private long sendSizeLimit;
 
   public WriteBufferManager(
       int shuffleId,
@@ -91,12 +98,38 @@ public WriteBufferManager(
       TaskMemoryManager taskMemoryManager,
       ShuffleWriteMetrics shuffleWriteMetrics,
       RssConf rssConf) {
+    this(
+        shuffleId,
+        null,
+        taskAttemptId,
+        bufferManagerOptions,
+        serializer,
+        partitionToServers,
+        taskMemoryManager,
+        shuffleWriteMetrics,
+        rssConf,
+        null
+    );
+  }
+
+  public WriteBufferManager(
+      int shuffleId,
+      String taskId,
+      long taskAttemptId,
+      BufferManagerOptions bufferManagerOptions,
+      Serializer serializer,
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+      TaskMemoryManager taskMemoryManager,
+      ShuffleWriteMetrics shuffleWriteMetrics,
+      RssConf rssConf,
+      Function<AddBlockEvent, CompletableFuture<Long>> spillFunc) {
     super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
     this.bufferSize = bufferManagerOptions.getBufferSize();
     this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
     this.instance = serializer.newInstance();
     this.buffers = Maps.newHashMap();
     this.shuffleId = shuffleId;
+    this.taskId = taskId;
     this.taskAttemptId = taskAttemptId;
     this.partitionToServers = partitionToServers;
     this.shuffleWriteMetrics = shuffleWriteMetrics;
@@ -111,6 +144,8 @@ public WriteBufferManager(
         .substring(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()), RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
     this.codec = compress ? Codec.newInstance(rssConf) : null;
+    this.spillAsyncFunc = spillFunc;
+    this.sendSizeLimit = rssConf.get(RssClientConf.RSS_CLIENT_SEND_SIZE_LIMIT);
   }
 
   public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
@@ -247,10 +282,44 @@ private void requestExecutorMemory(long leastMem) {
     }
   }
 
+  public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockInfoList) {
+    long totalSize = 0;
+    long memoryUsed = 0;
+    List<AddBlockEvent> events = new ArrayList<>();
+    List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = com.google.common.collect.Lists.newArrayList();
+    for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
+      totalSize += sbi.getSize();
+      memoryUsed += sbi.getFreeMemory();
+      shuffleBlockInfosPerEvent.add(sbi);
+      // split shuffle data according to the size
+      if (totalSize > sendSizeLimit) {
+        LOG.info("Post event to queue with " + shuffleBlockInfosPerEvent.size()
+            + " blocks and " + totalSize + " bytes");
+        // capture the memory held by this batch; the event's callback releases it after the send
+        final long memoryUsedPerEvent = memoryUsed;
+        events.add(
+            new AddBlockEvent(taskId, shuffleBlockInfosPerEvent, () -> freeAllocatedMemory(memoryUsedPerEvent))
+        );
+        shuffleBlockInfosPerEvent = com.google.common.collect.Lists.newArrayList();
+        totalSize = 0;
+        memoryUsed = 0;
+      }
+    }
+    if (!shuffleBlockInfosPerEvent.isEmpty()) {
+      LOG.info("Post event to queue with " + shuffleBlockInfosPerEvent.size()
+          + " blocks and " + totalSize + " bytes");
+      final long memoryUsedPerEvent = memoryUsed;
+      events.add(
+          new AddBlockEvent(taskId, shuffleBlockInfosPerEvent, () -> freeAllocatedMemory(memoryUsedPerEvent))
+      );
+    }
+    return events;
+  }
+
   @Override
   public long spill(long size, MemoryConsumer trigger) {
-    // there is no spill for such situation
-    return 0;
+    // hand all buffered blocks to the async sender; memory is freed by each event's
+    // callback once it has been sent, so nothing is reported as released here
+    List<AddBlockEvent> events = buildBlockEvents(clear());
+    events.forEach(event -> spillAsyncFunc.apply(event));
+    return 0L;
   }
 
   @VisibleForTesting
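With this change the buffer manager no longer needs a back-reference to the shuffle manager: whatever asynchronous sender is injected as `spillFunc` receives the events. In the Spark3 `getWriter` path below, that function is just a method reference; a sketch of the wiring, with the surrounding variables assumed in scope:

    Function<AddBlockEvent, CompletableFuture<Long>> spillFunc = shuffleManager::sendData;
    WriteBufferManager bufferManager = new WriteBufferManager(
        shuffleId, taskId, taskAttemptId, bufferOptions, serializer,
        partitionToServers, taskMemoryManager, writeMetrics, rssConf, spillFunc);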
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index e70026ac1a..3a2680c8b2 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -17,16 +17,17 @@
 
 package org.apache.spark.shuffle;
 
+import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -48,6 +49,7 @@
 import org.apache.spark.shuffle.reader.RssShuffleReader;
 import org.apache.spark.shuffle.writer.AddBlockEvent;
 import org.apache.spark.shuffle.writer.BufferManagerOptions;
+import org.apache.spark.shuffle.writer.DataPusher;
 import org.apache.spark.shuffle.writer.RssShuffleWriter;
 import org.apache.spark.shuffle.writer.WriteBufferManager;
 import org.apache.spark.sql.internal.SQLConf;
@@ -64,12 +66,10 @@
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
-import org.apache.uniffle.client.response.SendShuffleDataResult;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
-import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
@@ -97,7 +97,6 @@ public class RssShuffleManager implements ShuffleManager {
   private ShuffleWriteClient shuffleWriteClient;
   private final Map<String, Set<Long>> taskToSuccessBlockIds;
   private final Map<String, Set<Long>> taskToFailedBlockIds;
-  private Map<String, WriteBufferManager> taskToBufferManager = Maps.newConcurrentMap();
   private ScheduledExecutorService heartBeatScheduledExecutorService;
   private boolean heartbeatStarted = false;
   private boolean dynamicConfEnabled = false;
@@ -105,55 +104,7 @@ public class RssShuffleManager implements ShuffleManager {
   private String user;
   private String uuid;
   private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
-  private final EventLoop<AddBlockEvent> eventLoop;
-  private final EventLoop<AddBlockEvent> defaultEventLoop = new EventLoop<AddBlockEvent>("ShuffleDataQueue") {
-
-    @Override
-    public void onReceive(AddBlockEvent event) {
-      threadPoolExecutor.execute(() -> sendShuffleData(event.getTaskId(), event.getShuffleDataInfoList()));
-    }
-
-    @Override
-    public void onError(Throwable throwable) {
-      LOG.info("Shuffle event loop error...", throwable);
-    }
-
-    @Override
-    public void onStart() {
-      LOG.info("Shuffle event loop start...");
-    }
-
-    private void sendShuffleData(String taskId, List<ShuffleBlockInfo> shuffleDataInfoList) {
-      try {
-        SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
-            id.get(),
-            shuffleDataInfoList,
-            () -> !isValidTask(taskId)
-        );
-        putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
-        putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
-      } finally {
-        final AtomicLong releaseSize = new AtomicLong(0);
-        shuffleDataInfoList.forEach((sbi) -> releaseSize.addAndGet(sbi.getFreeMemory()));
-        WriteBufferManager bufferManager = taskToBufferManager.get(taskId);
-        if (bufferManager != null) {
-          bufferManager.freeAllocatedMemory(releaseSize.get());
-        }
-        LOG.debug("Spark 3.0 finish send data and release " + releaseSize + " bytes");
-      }
-    }
-
-    private synchronized void putBlockId(
-        Map<String, Set<Long>> taskToBlockIds,
-        String taskAttemptId,
-        Set<Long> blockIds) {
-      if (blockIds == null || blockIds.isEmpty()) {
-        return;
-      }
-      taskToBlockIds.putIfAbsent(taskAttemptId, Sets.newConcurrentHashSet());
-      taskToBlockIds.get(taskAttemptId).addAll(blockIds);
-    }
-  };
+  private DataPusher dataPusher;
 
   public RssShuffleManager(SparkConf conf, boolean isDriver) {
     this.sparkConf = conf;
@@ -213,9 +164,6 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) {
     taskToSuccessBlockIds = Maps.newConcurrentMap();
     taskToFailedBlockIds = Maps.newConcurrentMap();
     // for non-driver executor, start a thread for sending shuffle data to shuffle server
-    LOG.info("RSS data send thread is starting");
-    eventLoop = defaultEventLoop;
-    eventLoop.start();
     int poolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
     int keepAliveTime = sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
     threadPoolExecutor = new ThreadPoolExecutor(poolSize, poolSize * 2, keepAliveTime, TimeUnit.SECONDS,
@@ -225,6 +173,22 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) {
       heartBeatScheduledExecutorService = Executors.newSingleThreadScheduledExecutor(
           ThreadUtils.getThreadFactory("rss-heartbeat-%d"));
     }
+    this.dataPusher = new DataPusher(
+        shuffleWriteClient,
+        taskToSuccessBlockIds,
+        taskToFailedBlockIds,
+        failedTaskIds,
+        id.get(),
+        poolSize,
+        keepAliveTime
+    );
+  }
+
+  public CompletableFuture<Long> sendData(AddBlockEvent event) {
+    if (dataPusher != null && event != null) {
+      return dataPusher.send(event);
+    }
+    // nothing to send: complete immediately so callers can safely join on the result
+    return CompletableFuture.completedFuture(0L);
+  }
 
   @VisibleForTesting
@@ -284,12 +248,6 @@ protected static ShuffleDataDistributionType getDataDistributionType(SparkConf s
     );
     this.taskToSuccessBlockIds = taskToSuccessBlockIds;
     this.taskToFailedBlockIds = taskToFailedBlockIds;
-    if (loop != null) {
-      eventLoop = loop;
-    } else {
-      eventLoop = defaultEventLoop;
-    }
-    eventLoop.start();
     threadPoolExecutor = null;
     heartBeatScheduledExecutorService = null;
   }
@@ -400,10 +358,9 @@ public ShuffleWriter getWriter(
       writeMetrics = context.taskMetrics().shuffleWriteMetrics();
     }
     WriteBufferManager bufferManager = new WriteBufferManager(
-        shuffleId, context.taskAttemptId(), bufferOptions, rssHandle.getDependency().serializer(),
+        shuffleId, taskId, context.taskAttemptId(), bufferOptions, rssHandle.getDependency().serializer(),
         rssHandle.getPartitionToServers(), context.taskMemoryManager(),
-        writeMetrics, RssSparkConfig.toRssConf(sparkConf));
-    taskToBufferManager.put(taskId, bufferManager);
+        writeMetrics, RssSparkConfig.toRssConf(sparkConf), this::sendData);
     LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), rssHandle.getShuffleId());
     return new RssShuffleWriter<>(rssHandle.getAppId(), shuffleId, taskId, context.taskAttemptId(),
         bufferManager, writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle,
@@ -665,15 +622,18 @@ public void stop() {
     if (shuffleWriteClient != null) {
       shuffleWriteClient.close();
     }
-    if (eventLoop != null) {
-      eventLoop.stop();
+    if (dataPusher != null) {
+      try {
+        dataPusher.close();
+      } catch (IOException e) {
+        // ignore
+      }
     }
   }
 
   public void clearTaskMeta(String taskId) {
     taskToSuccessBlockIds.remove(taskId);
     taskToFailedBlockIds.remove(taskId);
-    taskToBufferManager.remove(taskId);
   }
 
   @VisibleForTesting
@@ -735,12 +695,6 @@ private synchronized void startHeartbeat() {
     }
   }
 
-  public void postEvent(AddBlockEvent addBlockEvent) {
-    if (eventLoop != null) {
-      eventLoop.post(addBlockEvent);
-    }
-  }
-
   public Set<Long> getFailedBlockIds(String taskId) {
     Set<Long> result = taskToFailedBlockIds.get(taskId);
     if (result == null) {
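One design note: DataPusher.close() only calls shutdown(), so stop() returns without draining sends that are still queued. A stricter close variant (not part of this patch) could wait for in-flight work, for example:

    executorService.shutdown();
    try {
      if (!executorService.awaitTermination(10, TimeUnit.SECONDS)) {
        executorService.shutdownNow();  // give up on sends that never finished
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      executorService.shutdownNow();
    }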
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index d0a6ea718d..223f18d1bb 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -231,26 +231,8 @@ private void processShuffleBlockInfos(List<ShuffleBlockInfo> shuffleBlockInfoList) {
   }
 
   protected void postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
-    long totalSize = 0;
-    List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
-    for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
-      totalSize += sbi.getSize();
-      shuffleBlockInfosPerEvent.add(sbi);
-      // split shuffle data according to the size
-      if (totalSize > sendSizeLimit) {
-        LOG.debug("Post event to queue with " + shuffleBlockInfosPerEvent.size()
-            + " blocks and " + totalSize + " bytes");
-        shuffleManager.postEvent(
-            new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
-        shuffleBlockInfosPerEvent = Lists.newArrayList();
-        totalSize = 0;
-      }
-    }
-    if (!shuffleBlockInfosPerEvent.isEmpty()) {
-      LOG.debug("Post event to queue with " + shuffleBlockInfosPerEvent.size()
-          + " blocks and " + totalSize + " bytes");
-      shuffleManager.postEvent(
-          new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
+    // batching now lives in WriteBufferManager#buildBlockEvents; each event carries
+    // its own memory-release callback
+    for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
+      shuffleManager.sendData(event);
     }
   }
diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index 119ab4e6b6..eb180f516d 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -43,4 +43,10 @@ public class RssClientConf {
       .defaultValue(ShuffleDataDistributionType.NORMAL)
       .withDescription("The type of partition shuffle data distribution, including normal and local_order. "
           + "The default value is normal. This config is only valid in Spark3.x");
+
+  public static final ConfigOption<Long> RSS_CLIENT_SEND_SIZE_LIMIT = ConfigOptions
+      .key("rss.client.send.size.limit")
+      .longType()
+      .defaultValue(1024 * 1024 * 16L)
+      .withDescription("The max data size sent to shuffle server");
 }
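The new option defaults to 16 MB. In a Spark job it would be set with the client config prefix, a sketch assuming the usual spark. prefix that RssSparkConfig strips before handing keys to RssConf:

    SparkConf conf = new SparkConf();
    // split flush events at 32 MB instead of the default 16 MB
    conf.set("spark.rss.client.send.size.limit", String.valueOf(32 * 1024 * 1024L));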