From 55e85055b80111b6f3662474c8fd500bf6589f2a Mon Sep 17 00:00:00 2001 From: SteNicholas Date: Fri, 22 Sep 2023 11:36:45 +0800 Subject: [PATCH] [CELEBORN-771][FLINK] Convert PushDataHandShake, RegionFinish, RegionStart to PB ### What changes were proposed in this pull request? `PushDataHandShake`, `RegionFinish`, and `RegionStart` should merge to transport messages to enhance celeborn's compatibility. ### Why are the changes needed? 1. Improves celeborn's transport flexibility to change RPC. 2. Makes Compatible with 0.2 client. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - `RemoteShuffleOutputGateSuiteJ` Closes #1910 from SteNicholas/CELEBORN-771. Authored-by: SteNicholas Signed-off-by: mingji --- .../readclient/FlinkShuffleClientImpl.java | 70 ++++++--- .../network/protocol/PushDataHandShake.java | 1 + .../common/network/protocol/RegionFinish.java | 1 + .../common/network/protocol/RegionStart.java | 1 + .../network/protocol/TransportMessage.java | 12 ++ common/src/main/proto/TransportMessages.proto | 30 +++- .../deploy/worker/PushDataHandler.scala | 145 ++++++++++++++---- 7 files changed, 207 insertions(+), 53 deletions(-) diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java index 99271689422..c7ed33b3302 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java @@ -44,13 +44,16 @@ import org.apache.celeborn.common.network.client.TransportClient; import org.apache.celeborn.common.network.client.TransportClientFactory; import org.apache.celeborn.common.network.protocol.PushData; -import org.apache.celeborn.common.network.protocol.PushDataHandShake; -import org.apache.celeborn.common.network.protocol.RegionFinish; -import org.apache.celeborn.common.network.protocol.RegionStart; +import org.apache.celeborn.common.network.protocol.TransportMessage; import org.apache.celeborn.common.network.util.TransportConf; +import org.apache.celeborn.common.protocol.MessageType; import org.apache.celeborn.common.protocol.PartitionLocation; import org.apache.celeborn.common.protocol.PbChangeLocationPartitionInfo; import org.apache.celeborn.common.protocol.PbChangeLocationResponse; +import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode; +import org.apache.celeborn.common.protocol.PbPushDataHandShake; +import org.apache.celeborn.common.protocol.PbRegionFinish; +import org.apache.celeborn.common.protocol.PbRegionStart; import org.apache.celeborn.common.protocol.ReviveRequest; import org.apache.celeborn.common.protocol.TransportModuleConstants; import org.apache.celeborn.common.protocol.message.ControlMessages; @@ -332,18 +335,23 @@ public Optional pushDataHandShake( location.getUniqueId()); logger.debug("PushDataHandShake location {}", location); TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); - PushDataHandShake handShake = - new PushDataHandShake( - PRIMARY_MODE, - shuffleKey, - location.getUniqueId(), - attemptId, - numPartitions, - bufferSize); ByteBuffer pushDataHandShakeResponse; try { pushDataHandShakeResponse = - client.sendRpcSync(handShake.toByteBuffer(), conf.pushDataTimeoutMs()); + client.sendRpcSync( + new TransportMessage( + MessageType.PUSH_DATA_HAND_SHAKE, + PbPushDataHandShake.newBuilder() + .setMode(Mode.forNumber(PRIMARY_MODE)) + .setShuffleKey(shuffleKey) + .setPartitionUniqueId(location.getUniqueId()) + .setAttemptId(attemptId) + .setNumPartitions(numPartitions) + .setBufferSize(bufferSize) + .build() + .toByteArray()) + .toByteBuffer(), + conf.pushDataTimeoutMs()); } catch (IOException e) { // ioexeption revive return revive(shuffleId, mapId, attemptId, location); @@ -378,18 +386,23 @@ public Optional regionStart( location.getUniqueId()); logger.debug("RegionStart for location {}.", location.toString()); TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); - RegionStart regionStart = - new RegionStart( - PRIMARY_MODE, - shuffleKey, - location.getUniqueId(), - attemptId, - currentRegionIdx, - isBroadcast); ByteBuffer regionStartResponse; try { regionStartResponse = - client.sendRpcSync(regionStart.toByteBuffer(), conf.pushDataTimeoutMs()); + client.sendRpcSync( + new TransportMessage( + MessageType.REGION_START, + PbRegionStart.newBuilder() + .setMode(Mode.forNumber(PRIMARY_MODE)) + .setShuffleKey(shuffleKey) + .setPartitionUniqueId(location.getUniqueId()) + .setAttemptId(attemptId) + .setCurrentRegionIndex(currentRegionIdx) + .setIsBroadcast(isBroadcast) + .build() + .toByteArray()) + .toByteBuffer(), + conf.pushDataTimeoutMs()); } catch (IOException e) { // ioexeption revive return revive(shuffleId, mapId, attemptId, location); @@ -459,9 +472,18 @@ public void regionFinish(int shuffleId, int mapId, int attemptId, PartitionLocat location.getUniqueId()); logger.debug("RegionFinish for location {}.", location); TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); - RegionFinish regionFinish = - new RegionFinish(PRIMARY_MODE, shuffleKey, location.getUniqueId(), attemptId); - client.sendRpcSync(regionFinish.toByteBuffer(), conf.pushDataTimeoutMs()); + client.sendRpcSync( + new TransportMessage( + MessageType.REGION_FINISH, + PbRegionFinish.newBuilder() + .setMode(Mode.forNumber(PRIMARY_MODE)) + .setShuffleKey(shuffleKey) + .setPartitionUniqueId(location.getUniqueId()) + .setAttemptId(attemptId) + .build() + .toByteArray()) + .toByteBuffer(), + conf.pushDataTimeoutMs()); return null; }); } diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java index 163fcaeb9ee..dc8c0481680 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/PushDataHandShake.java @@ -20,6 +20,7 @@ import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +@Deprecated public final class PushDataHandShake extends RequestMessage { // 0 for primary, 1 for replica, see PartitionLocation.Mode public final byte mode; diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java index 62fe18e531b..c7f804d84c7 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionFinish.java @@ -20,6 +20,7 @@ import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +@Deprecated public final class RegionFinish extends RequestMessage { // 0 for primary, 1 for replica, see PartitionLocation.Mode diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java index 322029d28ba..4081c5cf46f 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/RegionStart.java @@ -20,6 +20,7 @@ import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +@Deprecated public final class RegionStart extends RequestMessage { // 0 for primary, 1 for replica, see PartitionLocation.Mode diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java index d72bfec1b14..d59bcaab579 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java @@ -18,6 +18,9 @@ package org.apache.celeborn.common.network.protocol; import static org.apache.celeborn.common.protocol.MessageType.OPEN_STREAM_VALUE; +import static org.apache.celeborn.common.protocol.MessageType.PUSH_DATA_HAND_SHAKE_VALUE; +import static org.apache.celeborn.common.protocol.MessageType.REGION_FINISH_VALUE; +import static org.apache.celeborn.common.protocol.MessageType.REGION_START_VALUE; import static org.apache.celeborn.common.protocol.MessageType.STREAM_HANDLER_VALUE; import java.io.Serializable; @@ -31,6 +34,9 @@ import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.protocol.MessageType; import org.apache.celeborn.common.protocol.PbOpenStream; +import org.apache.celeborn.common.protocol.PbPushDataHandShake; +import org.apache.celeborn.common.protocol.PbRegionFinish; +import org.apache.celeborn.common.protocol.PbRegionStart; import org.apache.celeborn.common.protocol.PbStreamHandler; public class TransportMessage implements Serializable { @@ -64,6 +70,12 @@ public T getParsedPayload() throws InvalidProtoco return (T) PbOpenStream.parseFrom(payload); case STREAM_HANDLER_VALUE: return (T) PbStreamHandler.parseFrom(payload); + case PUSH_DATA_HAND_SHAKE_VALUE: + return (T) PbPushDataHandShake.parseFrom(payload); + case REGION_START_VALUE: + return (T) PbRegionStart.parseFrom(payload); + case REGION_FINISH_VALUE: + return (T) PbRegionFinish.parseFrom(payload); default: logger.error("Unexpected type {}", type); } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index 2e55c73b1ee..67db3e4c18b 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -76,6 +76,9 @@ enum MessageType { CHECK_WORKERS_AVAILABLE = 53; CHECK_WORKERS_AVAILABLE_RESPONSE = 54; REMOVE_WORKERS_UNAVAILABLE_INFO = 55; + PUSH_DATA_HAND_SHAKE = 56; + REGION_START = 57; + REGION_FINISH = 58; } message PbStorageInfo { @@ -499,4 +502,29 @@ message PbStreamHandler { int32 numChunks = 2; repeated int64 chunkOffsets = 3 ; string fullPath = 4; -} \ No newline at end of file +} + +message PbPushDataHandShake { + PbPartitionLocation.Mode mode = 1; + string shuffleKey = 2; + string partitionUniqueId = 3; + int32 attemptId = 4; + int32 numPartitions = 5; + int32 bufferSize = 6; +} + +message PbRegionStart { + PbPartitionLocation.Mode mode = 1; + string shuffleKey = 2; + string partitionUniqueId = 3; + int32 attemptId = 4; + int32 currentRegionIndex = 5; + bool isBroadcast = 6; +} + +message PbRegionFinish { + PbPartitionLocation.Mode mode = 1; + string shuffleKey = 2; + string partitionUniqueId = 3; + int32 attemptId = 4; +} diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala index 68facb60184..0a201439a0e 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala @@ -22,6 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor} import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray} import com.google.common.base.Throwables +import com.google.protobuf.GeneratedMessageV3 import io.netty.buffer.ByteBuf import org.apache.celeborn.common.exception.{AlreadyClosedException, CelebornIOException} @@ -30,10 +31,11 @@ import org.apache.celeborn.common.meta.{DiskStatus, WorkerInfo, WorkerPartitionL import org.apache.celeborn.common.metrics.source.Source import org.apache.celeborn.common.network.buffer.{NettyManagedBuffer, NioManagedBuffer} import org.apache.celeborn.common.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory} -import org.apache.celeborn.common.network.protocol.{Message, PushData, PushDataHandShake, PushMergedData, RegionFinish, RegionStart, RequestMessage, RpcFailure, RpcRequest, RpcResponse} +import org.apache.celeborn.common.network.protocol.{Message, PushData, PushDataHandShake, PushMergedData, RegionFinish, RegionStart, RequestMessage, RpcFailure, RpcRequest, RpcResponse, TransportMessage} import org.apache.celeborn.common.network.protocol.Message.Type import org.apache.celeborn.common.network.server.BaseMessageHandler -import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType} +import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, PbPushDataHandShake, PbRegionFinish, PbRegionStart} +import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.unsafe.Platform import org.apache.celeborn.common.util.Utils @@ -798,10 +800,8 @@ class PushDataHandler extends BaseMessageHandler with Logging { pushData.`type`(), shuffleKey, pushData.partitionUniqueId, - null, location, - callback, - wrappedCallback)) return + callback)) return val fileWriter = getFileWriterAndCheck(pushData.`type`(), location, isPrimary, callback) match { @@ -849,45 +849,120 @@ class PushDataHandler extends BaseMessageHandler with Logging { } private def handleRpcRequest(client: TransportClient, rpcRequest: RpcRequest): Unit = { - val msg = Message.decode(rpcRequest.body().nioByteBuffer()) val requestId = rpcRequest.requestId - val (mode, shuffleKey, partitionUniqueId, checkSplit) = msg match { - case p: PushDataHandShake => (p.mode, p.shuffleKey, p.partitionUniqueId, true) - case rs: RegionStart => (rs.mode, rs.shuffleKey, rs.partitionUniqueId, true) - case rf: RegionFinish => (rf.mode, rf.shuffleKey, rf.partitionUniqueId, false) - } + val (pbMsg, msg, isLegacy, messageType, mode, shuffleKey, partitionUniqueId, checkSplit) = + mapPartitionRpcRequest(rpcRequest) handleCore( client, rpcRequest, requestId, () => handleMapPartitionRpcRequestCore( - mode, + requestId, + pbMsg, msg, + isLegacy, + messageType, + mode, shuffleKey, partitionUniqueId, - requestId, checkSplit, new SimpleRpcResponseCallback( client, requestId, shuffleKey))) + } + private def mapPartitionRpcRequest(rpcRequest: RpcRequest) + : Tuple8[GeneratedMessageV3, Message, Boolean, Type, Mode, String, String, Boolean] = { + try { + val msg = TransportMessage.fromByteBuffer( + rpcRequest.body().nioByteBuffer()).getParsedPayload.asInstanceOf[GeneratedMessageV3] + msg match { + case p: PbPushDataHandShake => + ( + msg, + null, + false, + Type.PUSH_DATA_HAND_SHAKE, + p.getMode, + p.getShuffleKey, + p.getPartitionUniqueId, + true) + case rs: PbRegionStart => + ( + msg, + null, + false, + Type.REGION_START, + rs.getMode, + rs.getShuffleKey, + rs.getPartitionUniqueId, + true) + case rf: PbRegionFinish => + ( + msg, + null, + false, + Type.REGION_FINISH, + rf.getMode, + rf.getShuffleKey, + rf.getPartitionUniqueId, + false) + } + } catch { + case _: Exception => + val msg = Message.decode(rpcRequest.body().nioByteBuffer()) + msg match { + case p: PushDataHandShake => + ( + null, + msg, + true, + Type.PUSH_DATA_HAND_SHAKE, + Mode.forNumber(p.mode), + p.shuffleKey, + p.partitionUniqueId, + true) + case rs: RegionStart => + ( + null, + msg, + true, + Type.REGION_START, + Mode.forNumber(rs.mode), + rs.shuffleKey, + rs.partitionUniqueId, + true) + case rf: RegionFinish => + ( + null, + msg, + true, + Type.REGION_FINISH, + Mode.forNumber(rf.mode), + rf.shuffleKey, + rf.partitionUniqueId, + false) + } + } } private def handleMapPartitionRpcRequestCore( - mode: Byte, - message: Message, + requestId: Long, + pbMsg: GeneratedMessageV3, + msg: Message, + isLegacy: Boolean, + messageType: Message.Type, + mode: Mode, shuffleKey: String, partitionUniqueId: String, - requestId: Long, checkSplit: Boolean, callback: RpcResponseCallback): Unit = { - val isPrimary = PartitionLocation.getMode(mode) == PartitionLocation.Mode.PRIMARY - val messageType = message.`type`() log.debug( s"requestId:$requestId, pushdata rpc:$messageType, mode:$mode, shuffleKey:$shuffleKey, " + s"partitionUniqueId:$partitionUniqueId") + val isPrimary = mode == Mode.Primary val (workerSourcePrimary, workerSourceReplica) = messageType match { case Type.PUSH_DATA_HAND_SHAKE => @@ -924,10 +999,8 @@ class PushDataHandler extends BaseMessageHandler with Logging { messageType, shuffleKey, partitionUniqueId, - null, location, - callback, - wrappedCallback)) return + callback)) return val fileWriter = getFileWriterAndCheck(messageType, location, isPrimary, callback) match { @@ -957,13 +1030,31 @@ class PushDataHandler extends BaseMessageHandler with Logging { try { messageType match { case Type.PUSH_DATA_HAND_SHAKE => + val (numPartitions, bufferSize) = + if (isLegacy) + ( + msg.asInstanceOf[PushDataHandShake].numPartitions, + msg.asInstanceOf[PushDataHandShake].bufferSize) + else + ( + pbMsg.asInstanceOf[PbPushDataHandShake].getNumPartitions, + pbMsg.asInstanceOf[PbPushDataHandShake].getBufferSize) fileWriter.asInstanceOf[MapPartitionFileWriter].pushDataHandShake( - message.asInstanceOf[PushDataHandShake].numPartitions, - message.asInstanceOf[PushDataHandShake].bufferSize) + numPartitions, + bufferSize) case Type.REGION_START => + val (currentRegionIndex, isBroadcast) = + if (isLegacy) + ( + msg.asInstanceOf[RegionStart].currentRegionIndex, + msg.asInstanceOf[RegionStart].isBroadcast) + else + ( + pbMsg.asInstanceOf[PbRegionStart].getCurrentRegionIndex, + Boolean.box(pbMsg.asInstanceOf[PbRegionStart].getIsBroadcast)) fileWriter.asInstanceOf[MapPartitionFileWriter].regionStart( - message.asInstanceOf[RegionStart].currentRegionIndex, - message.asInstanceOf[RegionStart].isBroadcast) + currentRegionIndex, + isBroadcast) case Type.REGION_FINISH => fileWriter.asInstanceOf[MapPartitionFileWriter].regionFinish() case _ => throw new IllegalArgumentException(s"Not support $messageType yet") @@ -1039,10 +1130,8 @@ class PushDataHandler extends BaseMessageHandler with Logging { messageType: Message.Type, shuffleKey: String, partitionUniqueId: String, - body: ByteBuf, location: PartitionLocation, - callback: RpcResponseCallback, - wrappedCallback: RpcResponseCallback): Boolean = { + callback: RpcResponseCallback): Boolean = { if (location == null) { val msg = s"Partition location wasn't found for task(shuffle $shuffleKey, uniqueId $partitionUniqueId)."