Skip to content

Commit

Permalink
[apache#706] Implement spill method to avoid memory preemption
Browse files Browse the repository at this point in the history
  • Loading branch information
zuston committed Mar 20, 2023
1 parent d60d675 commit 9b614d1
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.shuffle.writer;

import java.util.ArrayList;
import java.util.List;

import org.apache.uniffle.common.ShuffleBlockInfo;
Expand All @@ -25,10 +26,23 @@ public class AddBlockEvent {

private String taskId;
private List<ShuffleBlockInfo> shuffleDataInfoList;
private List<Runnable> processedCallbackChain;

/**
 * Creates an event for the given task with no callbacks registered yet.
 *
 * @param taskId identifier of the task that produced the blocks
 * @param shuffleDataInfoList blocks carried by this event
 */
public AddBlockEvent(String taskId, List<ShuffleBlockInfo> shuffleDataInfoList) {
  // Start with an empty chain; callers may register callbacks later via addCallback.
  this.processedCallbackChain = new ArrayList<Runnable>();
  this.shuffleDataInfoList = shuffleDataInfoList;
  this.taskId = taskId;
}

/**
 * Creates an event for the given task and registers an initial callback.
 *
 * @param taskId identifier of the task that produced the blocks
 * @param shuffleBlockInfoList blocks carried by this event
 * @param callback action to append to the processed-callback chain; a {@code null} value is ignored
 */
public AddBlockEvent(String taskId, List<ShuffleBlockInfo> shuffleBlockInfoList, Runnable callback) {
  this.taskId = taskId;
  this.shuffleDataInfoList = shuffleBlockInfoList;
  this.processedCallbackChain = new ArrayList<>();
  // Append directly instead of calling the overridable addCallback(...) from a constructor,
  // and never store null: whoever runs the chain would fail with a NullPointerException.
  if (callback != null) {
    this.processedCallbackChain.add(callback);
  }
}

/**
 * Appends a callback to the end of the processed-callback chain.
 *
 * @param callback action to register; a {@code null} value is ignored so that running the
 *     chain can never fail on a missing entry
 */
public void addCallback(Runnable callback) {
  if (callback == null) {
    return;
  }
  processedCallbackChain.add(callback);
}

public String getTaskId() {
Expand All @@ -39,6 +53,10 @@ public List<ShuffleBlockInfo> getShuffleDataInfoList() {
return shuffleDataInfoList;
}

/**
 * Returns the callbacks registered on this event, in registration order.
 *
 * @return the live (not copied) callback list held by this event
 */
public List<Runnable> getProcessedCallbackChain() {
  return this.processedCallbackChain;
}

@Override
public String toString() {
return "AddBlockEvent: TaskId[" + taskId + "], " + shuffleDataInfoList;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.writer;

import java.io.Closeable;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import com.google.common.collect.Queues;
import com.google.common.collect.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.util.ThreadUtils;

/**
 * Pushes shuffle blocks to the shuffle servers asynchronously on a dedicated thread pool and
 * records, per task, which block ids were sent successfully and which failed.
 */
public class DataPusher implements Closeable {
  private static final Logger LOGGER = LoggerFactory.getLogger(DataPusher.class);

  private final ExecutorService executorService;

  private final ShuffleWriteClient shuffleWriteClient;
  // Both maps are supplied by the caller and updated in place by putBlockId.
  private final Map<String, Set<Long>> taskToSuccessBlockIds;
  private final Map<String, Set<Long>> taskToFailedBlockIds;
  private final String appId;
  // Tasks listed here are reported as "no longer valid" to the write client while sending.
  private final Set<String> failedTaskIds;

  /**
   * Creates a pusher backed by its own thread pool.
   *
   * @param shuffleWriteClient client used to send the blocks
   * @param taskToSuccessBlockIds map updated with the ids of successfully sent blocks per task
   * @param taskToFailedBlockIds map updated with the ids of blocks that failed to send per task
   * @param failedTaskIds ids of tasks considered failed; consulted by {@link #isValidTask(String)}
   * @param appId application id passed to the write client on every send
   * @param threadPoolSize core size of the sending thread pool
   * @param threadKeepAliveTime keep-alive time of idle pool threads, in seconds
   */
  public DataPusher(ShuffleWriteClient shuffleWriteClient,
      Map<String, Set<Long>> taskToSuccessBlockIds,
      Map<String, Set<Long>> taskToFailedBlockIds,
      Set<String> failedTaskIds,
      String appId,
      int threadPoolSize,
      int threadKeepAliveTime) {
    this.shuffleWriteClient = shuffleWriteClient;
    this.taskToSuccessBlockIds = taskToSuccessBlockIds;
    this.taskToFailedBlockIds = taskToFailedBlockIds;
    this.failedTaskIds = failedTaskIds;
    this.appId = appId;

    // NOTE(review): ThreadPoolExecutor only grows beyond the core size once the queue is full.
    // With a queue capacity of Integer.MAX_VALUE the maximum of threadPoolSize * 2 is
    // effectively never reached - confirm whether that is intended.
    this.executorService = new ThreadPoolExecutor(
        threadPoolSize,
        threadPoolSize * 2,
        threadKeepAliveTime,
        TimeUnit.SECONDS,
        Queues.newLinkedBlockingQueue(Integer.MAX_VALUE),
        ThreadUtils.getThreadFactory(this.getClass().getName())
    );
  }

  /**
   * Sends the blocks of the given event on the pusher's thread pool.
   *
   * <p>The event's callbacks are run after the send attempt whether it succeeded or threw.
   *
   * @param event event carrying the task id, the blocks to send and the callbacks to run
   * @return a future completing with the sum of {@code getFreeMemory()} over the event's
   *     blocks ({@code 0} when the event carries no blocks)
   */
  public CompletableFuture<Long> send(AddBlockEvent event) {
    return CompletableFuture.supplyAsync(() -> {
      String taskId = event.getTaskId();
      List<ShuffleBlockInfo> shuffleBlockInfoList = event.getShuffleDataInfoList();
      try {
        SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
            appId,
            shuffleBlockInfoList,
            () -> !isValidTask(taskId)
        );
        putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
        putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
      } finally {
        // ofNullable (not of): a missing chain must fall back to the empty list instead of
        // throwing a NullPointerException out of the finally block.
        List<Runnable> callbackChain =
            Optional.ofNullable(event.getProcessedCallbackChain()).orElse(Collections.emptyList());
        for (Runnable runnable : callbackChain) {
          runnable.run();
        }
      }
      // sum() yields 0 for an empty list, where reduce(...).get() would throw.
      return shuffleBlockInfoList.stream()
          .mapToLong(ShuffleBlockInfo::getFreeMemory)
          .sum();
    }, executorService);
  }

  /**
   * Adds the given block ids to the set registered for the task, creating the set on first use.
   * Does nothing when {@code blockIds} is {@code null} or empty.
   */
  private synchronized void putBlockId(
      Map<String, Set<Long>> taskToBlockIds,
      String taskAttemptId,
      Set<Long> blockIds) {
    if (blockIds == null || blockIds.isEmpty()) {
      return;
    }
    // computeIfAbsent only allocates a new set when the task has none yet.
    taskToBlockIds
        .computeIfAbsent(taskAttemptId, id -> Sets.newConcurrentHashSet())
        .addAll(blockIds);
  }

  /**
   * Tells whether the task is still valid, i.e. it is not listed among the failed task ids.
   *
   * @param taskId id of the task to check
   * @return {@code true} unless {@code taskId} is contained in the failed task id set
   */
  public boolean isValidTask(String taskId) {
    return !failedTaskIds.contains(taskId);
  }

  /**
   * Initiates an orderly shutdown of the sending thread pool; it does not wait for
   * already-submitted sends to finish.
   */
  @Override
  public void close() throws IOException {
    if (executorService != null) {
      executorService.shutdown();
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

package org.apache.spark.shuffle.writer;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;

import com.clearspring.analytics.util.Lists;
import com.google.common.annotations.VisibleForTesting;
Expand All @@ -42,6 +45,7 @@
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ChecksumUtils;
Expand All @@ -61,6 +65,7 @@ public class WriteBufferManager extends MemoryConsumer {
private Map<Integer, Integer> partitionToSeqNo = Maps.newHashMap();
private long askExecutorMemory;
private int shuffleId;
private String taskId;
private long taskAttemptId;
private SerializerInstance instance;
private ShuffleWriteMetrics shuffleWriteMetrics;
Expand All @@ -81,6 +86,8 @@ public class WriteBufferManager extends MemoryConsumer {
private long requireMemoryInterval;
private int requireMemoryRetryMax;
private Codec codec;
private Function<AddBlockEvent, CompletableFuture<Long>> spillAsyncFunc;
private long sendSizeLimit;

public WriteBufferManager(
int shuffleId,
Expand All @@ -91,12 +98,38 @@ public WriteBufferManager(
TaskMemoryManager taskMemoryManager,
ShuffleWriteMetrics shuffleWriteMetrics,
RssConf rssConf) {
this(
shuffleId,
null,
taskAttemptId,
bufferManagerOptions,
serializer,
partitionToServers,
taskMemoryManager,
shuffleWriteMetrics,
rssConf,
null
);
}

public WriteBufferManager(
int shuffleId,
String taskId,
long taskAttemptId,
BufferManagerOptions bufferManagerOptions,
Serializer serializer,
Map<Integer, List<ShuffleServerInfo>> partitionToServers,
TaskMemoryManager taskMemoryManager,
ShuffleWriteMetrics shuffleWriteMetrics,
RssConf rssConf,
Function<AddBlockEvent, CompletableFuture<Long>> spillFunc) {
super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
this.bufferSize = bufferManagerOptions.getBufferSize();
this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
this.instance = serializer.newInstance();
this.buffers = Maps.newHashMap();
this.shuffleId = shuffleId;
this.taskId = taskId;
this.taskAttemptId = taskAttemptId;
this.partitionToServers = partitionToServers;
this.shuffleWriteMetrics = shuffleWriteMetrics;
Expand All @@ -111,6 +144,8 @@ public WriteBufferManager(
.substring(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
this.codec = compress ? Codec.newInstance(rssConf) : null;
this.spillAsyncFunc = spillFunc;
this.sendSizeLimit = rssConf.get(RssClientConf.RSS_CLIENT_SEND_SIZE_LIMIT);
}

public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
Expand Down Expand Up @@ -247,10 +282,44 @@ private void requestExecutorMemory(long leastMem) {
}
}

/**
 * Splits the given blocks, in order, into {@link AddBlockEvent}s for this task.
 *
 * <p>A batch is closed as soon as its accumulated block size exceeds {@code sendSizeLimit};
 * whatever is left after the last block forms the final event. Each event carries a callback
 * that calls {@code freeAllocatedMemory} with the summed {@code getFreeMemory()} of its blocks.
 *
 * @param shuffleBlockInfoList blocks to distribute over events
 * @return the events in creation order; empty when no blocks were given
 */
public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockInfoList) {
  List<AddBlockEvent> events = new ArrayList<>();
  List<ShuffleBlockInfo> batch = new ArrayList<>();
  long batchBytes = 0;
  long batchMemory = 0;
  // Counts the blocks still to come so the trailing partial batch is flushed in the loop.
  int remaining = shuffleBlockInfoList.size();
  for (ShuffleBlockInfo block : shuffleBlockInfoList) {
    remaining--;
    batchBytes += block.getSize();
    batchMemory += block.getFreeMemory();
    batch.add(block);

    // Keep filling the batch until it exceeds the size limit or the input is exhausted.
    boolean overLimit = batchBytes > sendSizeLimit;
    if (!overLimit && remaining > 0) {
      continue;
    }

    LOG.info("Post event to queue with " + batch.size()
        + " blocks and " + batchBytes + " bytes");
    // The callback needs an effectively final copy of the running total.
    final long releasable = batchMemory;
    events.add(new AddBlockEvent(taskId, batch, () -> freeAllocatedMemory(releasable)));

    batch = new ArrayList<>();
    batchBytes = 0;
    batchMemory = 0;
  }
  return events;
}

/**
 * Hands all currently buffered blocks to the asynchronous spill function.
 *
 * <p>The buffers are drained with {@code clear()}, grouped into events by
 * {@code buildBlockEvents} and each event is passed to {@code spillAsyncFunc}. When no spill
 * function was supplied, nothing is drained and nothing is sent.
 *
 * @param size amount of memory requested to be released (not used by this implementation)
 * @param trigger consumer that triggered the spill (not used by this implementation)
 * @return always {@code 0}; memory is released through the callbacks attached to each event
 *     rather than reported here
 */
@Override
public long spill(long size, MemoryConsumer trigger) {
  // Check before clear(): draining the buffers without anywhere to send them would lose data.
  if (spillAsyncFunc == null) {
    return 0L;
  }
  List<AddBlockEvent> events = buildBlockEvents(clear());
  events.forEach(spillAsyncFunc::apply);
  return 0L;
}

@VisibleForTesting
Expand Down
Loading

0 comments on commit 9b614d1

Please sign in to comment.