diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/PooledHdfsShuffleWriteHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/PooledHdfsShuffleWriteHandler.java index bb891a88fe..1cf85d8e8b 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/PooledHdfsShuffleWriteHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/PooledHdfsShuffleWriteHandler.java @@ -19,6 +19,7 @@ import java.util.List; import java.util.concurrent.LinkedBlockingDeque; +import java.util.function.Function; import com.google.common.annotations.VisibleForTesting; import org.apache.commons.lang.StringUtils; @@ -45,6 +46,8 @@ public class PooledHdfsShuffleWriteHandler implements ShuffleWriteHandler { private final LinkedBlockingDeque queue; private final int maxConcurrency; private final String basePath; + private Function createWriterFunc; + private volatile int initializedHandlerCnt = 0; // Only for tests @VisibleForTesting @@ -54,6 +57,17 @@ public PooledHdfsShuffleWriteHandler(LinkedBlockingDeque qu this.basePath = StringUtils.EMPTY; } + @VisibleForTesting + public PooledHdfsShuffleWriteHandler( + LinkedBlockingDeque queue, + int maxConcurrency, + Function createWriterFunc) { + this.queue = queue; + this.maxConcurrency = maxConcurrency; + this.basePath = StringUtils.EMPTY; + this.createWriterFunc = createWriterFunc; + } + public PooledHdfsShuffleWriteHandler( String appId, int shuffleId, @@ -70,31 +84,34 @@ public PooledHdfsShuffleWriteHandler( this.basePath = ShuffleStorageUtils.getFullShuffleDataFolder(storageBasePath, ShuffleStorageUtils.getShuffleDataPath(appId, shuffleId, startPartition, endPartition)); - // todo: support init lazily - try { - for (int i = 0; i < maxConcurrency; i++) { - // Use add() here because we are sure the capacity will not be exceeded. - // Note: add() throws IllegalStateException when queue is full. - queue.add( - new HdfsShuffleWriteHandler( - appId, - shuffleId, - startPartition, - endPartition, - storageBasePath, - fileNamePrefix + "_" + i, - hadoopConf, - user - ) + this.createWriterFunc = index -> { + try { + return new HdfsShuffleWriteHandler( + appId, + shuffleId, + startPartition, + endPartition, + storageBasePath, + fileNamePrefix + "_" + index, + hadoopConf, + user ); + } catch (Exception e) { + throw new RssException("Errors on initializing Hdfs writer handler.", e); } - } catch (Exception e) { - throw new RssException("Errors on initializing Hdfs writer handler.", e); - } + }; } @Override public void write(List shuffleBlocks) throws Exception { + if (queue.isEmpty() && initializedHandlerCnt < maxConcurrency) { + synchronized (this) { + if (initializedHandlerCnt < maxConcurrency) { + queue.add(createWriterFunc.apply(initializedHandlerCnt++)); + } + } + } + if (queue.isEmpty()) { LOGGER.warn("No free hdfs writer handler, it will wait. storage path: {}", basePath); } @@ -107,4 +124,9 @@ public void write(List shuffleBlocks) throws Exception queue.addFirst(writeHandler); } } + + @VisibleForTesting + protected int getInitializedHandlerCnt() { + return initializedHandlerCnt; + } } diff --git a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/PooledHdfsShuffleWriteHandlerTest.java b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/PooledHdfsShuffleWriteHandlerTest.java index ca62b8352f..e3d655b245 100644 --- a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/PooledHdfsShuffleWriteHandlerTest.java +++ b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/PooledHdfsShuffleWriteHandlerTest.java @@ -20,6 +20,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingDeque; @@ -46,6 +47,13 @@ static class FakedShuffleWriteHandler implements ShuffleWriteHandler { this.execution = runnable; } + FakedShuffleWriteHandler(List initializedList, List invokedList, int index, Runnable runnable) { + initializedList.add(index); + this.invokedList = invokedList; + this.index = index; + this.execution = runnable; + } + @Override public void write(List shuffleBlocks) throws Exception { execution.run(); @@ -53,6 +61,52 @@ public void write(List shuffleBlocks) throws Exception } } + @Test + public void lazyInitializeWriterHandlerTest() throws Exception { + int maxConcurrency = 5; + LinkedBlockingDeque deque = new LinkedBlockingDeque(maxConcurrency); + + CopyOnWriteArrayList invokedList = new CopyOnWriteArrayList<>(); + CopyOnWriteArrayList initializedList = new CopyOnWriteArrayList<>(); + + PooledHdfsShuffleWriteHandler handler = new PooledHdfsShuffleWriteHandler( + deque, + maxConcurrency, + index -> new FakedShuffleWriteHandler(initializedList, invokedList, index, () -> { + try { + Thread.sleep(10); + } catch (Exception e) { + // ignore + } + }) + ); + + // case1: no race condition + for (int i = 0; i < 10; i++) { + handler.write(Collections.emptyList()); + assertEquals(1, initializedList.size()); + } + + // case2: initialized by multi threads + invokedList.clear(); + CountDownLatch latch = new CountDownLatch(100); + for (int i = 0; i < 100; i++) { + new Thread(() -> { + try { + handler.write(Collections.emptyList()); + } catch (Exception e) { + // ignore + } finally { + latch.countDown(); + } + }).start(); + } + latch.await(); + assertEquals(100, invokedList.size()); + assertEquals(5, initializedList.size()); + assertEquals(5, handler.getInitializedHandlerCnt()); + } + @Test public void writeSameFileWhenNoRaceCondition() throws Exception { int concurrency = 5;