diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/CongestionController.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/CongestionController.java index 0a5cc3eaee9..d05ddd58291 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/CongestionController.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/CongestionController.java @@ -113,29 +113,6 @@ public static CongestionController instance() { return _INSTANCE; } - private static class UserBufferInfo { - long timestamp; - final BufferStatusHub bufferStatusHub; - - public UserBufferInfo(long timestamp, BufferStatusHub bufferStatusHub) { - this.timestamp = timestamp; - this.bufferStatusHub = bufferStatusHub; - } - - synchronized void updateInfo(long timestamp, BufferStatusHub.BufferStatusNode node) { - this.timestamp = timestamp; - this.bufferStatusHub.add(timestamp, node); - } - - public long getTimestamp() { - return timestamp; - } - - public BufferStatusHub getBufferStatusHub() { - return bufferStatusHub; - } - } - /** * 1. If the total pending bytes is over high watermark, will congest users who produce speed is * higher than the potential average consume speed. 
@@ -166,24 +143,25 @@ public boolean isUserCongested(UserIdentifier userIdentifier) {
     return false;
   }
 
-  public void produceBytes(UserIdentifier userIdentifier, int numBytes) {
-    long currentTimeMillis = System.currentTimeMillis();
-    UserBufferInfo userBufferInfo =
-        userBufferStatuses.computeIfAbsent(
-            userIdentifier,
-            user -> {
-              logger.info("New user {} comes, initializing its rate status", user);
-              BufferStatusHub bufferStatusHub = new BufferStatusHub(sampleTimeWindowSeconds);
-              UserBufferInfo userInfo = new UserBufferInfo(currentTimeMillis, bufferStatusHub);
-              workerSource.addGauge(
-                  WorkerSource.USER_PRODUCE_SPEED(),
-                  userIdentifier.toJMap(),
-                  () -> getUserProduceSpeed(userInfo));
-              return userInfo;
-            });
-
-    BufferStatusHub.BufferStatusNode node = new BufferStatusHub.BufferStatusNode(numBytes);
-    userBufferInfo.updateInfo(currentTimeMillis, node);
+  /**
+   * Returns the rate-tracking state kept for {@code userIdentifier}, creating it and registering
+   * its produce-speed gauge the first time the user is seen.
+   *
+   * @param userIdentifier user whose buffer info is requested
+   * @return the {@link UserBufferInfo} held in {@code userBufferStatuses} for that user
+   */
+  public UserBufferInfo getUserBuffer(UserIdentifier userIdentifier) {
+    return userBufferStatuses.computeIfAbsent(userIdentifier, this::createUserBuffer);
+  }
+
+  /** Builds the state for a user with no entry yet and exposes its produce speed as a gauge. */
+  private UserBufferInfo createUserBuffer(UserIdentifier user) {
+    logger.info("New user {} comes, initializing its rate status", user);
+    BufferStatusHub hub = new BufferStatusHub(sampleTimeWindowSeconds);
+    UserBufferInfo created = new UserBufferInfo(System.currentTimeMillis(), hub);
+    workerSource.addGauge(
+        WorkerSource.USER_PRODUCE_SPEED(), user.toJMap(), () -> getUserProduceSpeed(created));
+    return created;
   }
 
   public void consumeBytes(int numBytes) {
diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/UserBufferInfo.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/UserBufferInfo.java
new file mode 100644
index 00000000000..4d12f59081b
--- /dev/null
+++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/UserBufferInfo.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.congestcontrol;
+
+public class UserBufferInfo { // produce-rate state of one user: last update time + sample hub
+  long timestamp; // guarded by this: set in updateInfo, read through getTimestamp
+  final BufferStatusHub bufferStatusHub;
+
+  public UserBufferInfo(long timestamp, BufferStatusHub bufferStatusHub) {
+    this.timestamp = timestamp;
+    this.bufferStatusHub = bufferStatusHub;
+  }
+
+  public synchronized void updateInfo(long timestamp, BufferStatusHub.BufferStatusNode node) {
+    this.timestamp = timestamp; // timestamp and sample change together under the monitor
+    this.bufferStatusHub.add(timestamp, node);
+  }
+
+  public synchronized long getTimestamp() { // lock pairs with updateInfo: no stale or torn read
+    return timestamp;
+  }
+
+  public BufferStatusHub getBufferStatusHub() {
+    return bufferStatusHub;
+  }
+}
diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionDataWriter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionDataWriter.java
index a39a4f9d7cb..589832dbe69 100644
--- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionDataWriter.java
+++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionDataWriter.java
@@ -21,7 +21,6 @@
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.channels.FileChannel;
-import java.util.Optional;
 import java.util.concurrent.TimeUnit;
 import
java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -50,7 +49,9 @@ import org.apache.celeborn.common.unsafe.Platform; import org.apache.celeborn.common.util.FileChannelUtils; import org.apache.celeborn.service.deploy.worker.WorkerSource; +import org.apache.celeborn.service.deploy.worker.congestcontrol.BufferStatusHub; import org.apache.celeborn.service.deploy.worker.congestcontrol.CongestionController; +import org.apache.celeborn.service.deploy.worker.congestcontrol.UserBufferInfo; import org.apache.celeborn.service.deploy.worker.memory.MemoryManager; /* @@ -104,6 +105,8 @@ public abstract class PartitionDataWriter implements DeviceObserver { private boolean metricsCollectCriticalEnabled; private long chunkSize; + private UserBufferInfo userBufferInfo = null; + public PartitionDataWriter( StorageManager storageManager, AbstractSource workerSource, @@ -155,6 +158,10 @@ public PartitionDataWriter( this.mapIdBitMap = new RoaringBitmap(); } takeBuffer(); + CongestionController congestionController = CongestionController.instance(); + if (!isMemoryShuffleFile.get() && congestionController != null) { + userBufferInfo = congestionController.getUserBuffer(getDiskFileInfo().getUserIdentifier()); + } } public void initFileChannelsForDiskFile() throws IOException { @@ -294,10 +301,10 @@ public void write(ByteBuf data) throws IOException { MemoryManager.instance().increaseMemoryFileStorage(numBytes); } else { MemoryManager.instance().incrementDiskBuffer(numBytes); - Optional.ofNullable(CongestionController.instance()) - .ifPresent( - congestionController -> - congestionController.produceBytes(diskFileInfo.getUserIdentifier(), numBytes)); + if (userBufferInfo != null) { + userBufferInfo.updateInfo( + System.currentTimeMillis(), new BufferStatusHub.BufferStatusNode(numBytes)); + } } synchronized (flushLock) { diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestCongestionController.java 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestCongestionController.java index b8fb5cadc67..deb533fa43c 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestCongestionController.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestCongestionController.java @@ -69,7 +69,7 @@ public void testSingleUser() { Assert.assertFalse(controller.isUserCongested(userIdentifier)); - controller.produceBytes(userIdentifier, 1001); + produceBytes(userIdentifier, 1001); pendingBytes = 1001; controller.checkCongestion(); Assert.assertTrue(controller.isUserCongested(userIdentifier)); @@ -91,8 +91,8 @@ public void testMultipleUsers() { // If pendingBytes exceed the high watermark, user1 produce speed > avg consume speed // While user2 produce speed < avg consume speed - controller.produceBytes(user1, 800); - controller.produceBytes(user2, 201); + produceBytes(user1, 800); + produceBytes(user2, 201); controller.consumeBytes(500); pendingBytes = 1001; controller.checkCongestion(); @@ -100,8 +100,8 @@ public void testMultipleUsers() { Assert.assertFalse(controller.isUserCongested(user2)); // If both users higher than the avg consume speed, should congest them all. 
-    controller.produceBytes(user1, 800);
-    controller.produceBytes(user2, 800);
+    produceBytes(user1, 800);
+    produceBytes(user2, 800);
     controller.consumeBytes(500);
     pendingBytes = 1600;
     controller.checkCongestion();
@@ -119,7 +119,7 @@ public void testMultipleUsers() {
   public void testUserMetrics() throws InterruptedException {
     UserIdentifier user = new UserIdentifier("test", "celeborn");
     Assert.assertFalse(controller.isUserCongested(user));
-    controller.produceBytes(user, 800);
+    produceBytes(user, 800);
 
     Assert.assertTrue(
         isGaugeExist(
@@ -159,4 +159,14 @@ private boolean isGaugeExist(String name, Map labels) {
             .count()
         == 1;
   }
+
+  /**
+   * Feeds {@code numBytes} of produced data into the controller's buffer info for the given
+   * user, stamped with the current time.
+   */
+  private void produceBytes(UserIdentifier userIdentifier, long numBytes) {
+    UserBufferInfo userBuffer = controller.getUserBuffer(userIdentifier);
+    long now = System.currentTimeMillis();
+    userBuffer.updateInfo(now, new BufferStatusHub.BufferStatusNode(numBytes));
+  }
 }