Skip to content

Commit

Permalink
[CELEBORN-771][FLINK] Convert PushDataHandShake, RegionFinish, Region…
Browse files Browse the repository at this point in the history
…Start to PB

### What changes were proposed in this pull request?

`PushDataHandShake`, `RegionFinish`, and `RegionStart` should be converted to protobuf-based transport messages to improve Celeborn's compatibility.

### Why are the changes needed?

1. Improves the flexibility of Celeborn's transport layer, making RPC changes easier.
2. Makes it compatible with the 0.2 client.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- `RemoteShuffleOutputGateSuiteJ`

Closes apache#1910 from SteNicholas/CELEBORN-771.

Authored-by: SteNicholas <[email protected]>
Signed-off-by: mingji <[email protected]>
  • Loading branch information
SteNicholas authored and FMX committed Sep 22, 2023
1 parent f1713da commit 55e8505
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,16 @@
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.protocol.PushData;
import org.apache.celeborn.common.network.protocol.PushDataHandShake;
import org.apache.celeborn.common.network.protocol.RegionFinish;
import org.apache.celeborn.common.network.protocol.RegionStart;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbChangeLocationPartitionInfo;
import org.apache.celeborn.common.protocol.PbChangeLocationResponse;
import org.apache.celeborn.common.protocol.PbPartitionLocation.Mode;
import org.apache.celeborn.common.protocol.PbPushDataHandShake;
import org.apache.celeborn.common.protocol.PbRegionFinish;
import org.apache.celeborn.common.protocol.PbRegionStart;
import org.apache.celeborn.common.protocol.ReviveRequest;
import org.apache.celeborn.common.protocol.TransportModuleConstants;
import org.apache.celeborn.common.protocol.message.ControlMessages;
Expand Down Expand Up @@ -332,18 +335,23 @@ public Optional<PartitionLocation> pushDataHandShake(
location.getUniqueId());
logger.debug("PushDataHandShake location {}", location);
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
PushDataHandShake handShake =
new PushDataHandShake(
PRIMARY_MODE,
shuffleKey,
location.getUniqueId(),
attemptId,
numPartitions,
bufferSize);
ByteBuffer pushDataHandShakeResponse;
try {
pushDataHandShakeResponse =
client.sendRpcSync(handShake.toByteBuffer(), conf.pushDataTimeoutMs());
client.sendRpcSync(
new TransportMessage(
MessageType.PUSH_DATA_HAND_SHAKE,
PbPushDataHandShake.newBuilder()
.setMode(Mode.forNumber(PRIMARY_MODE))
.setShuffleKey(shuffleKey)
.setPartitionUniqueId(location.getUniqueId())
.setAttemptId(attemptId)
.setNumPartitions(numPartitions)
.setBufferSize(bufferSize)
.build()
.toByteArray())
.toByteBuffer(),
conf.pushDataTimeoutMs());
} catch (IOException e) {
// On IOException, fall back to revive.
return revive(shuffleId, mapId, attemptId, location);
Expand Down Expand Up @@ -378,18 +386,23 @@ public Optional<PartitionLocation> regionStart(
location.getUniqueId());
logger.debug("RegionStart for location {}.", location.toString());
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
RegionStart regionStart =
new RegionStart(
PRIMARY_MODE,
shuffleKey,
location.getUniqueId(),
attemptId,
currentRegionIdx,
isBroadcast);
ByteBuffer regionStartResponse;
try {
regionStartResponse =
client.sendRpcSync(regionStart.toByteBuffer(), conf.pushDataTimeoutMs());
client.sendRpcSync(
new TransportMessage(
MessageType.REGION_START,
PbRegionStart.newBuilder()
.setMode(Mode.forNumber(PRIMARY_MODE))
.setShuffleKey(shuffleKey)
.setPartitionUniqueId(location.getUniqueId())
.setAttemptId(attemptId)
.setCurrentRegionIndex(currentRegionIdx)
.setIsBroadcast(isBroadcast)
.build()
.toByteArray())
.toByteBuffer(),
conf.pushDataTimeoutMs());
} catch (IOException e) {
// On IOException, fall back to revive.
return revive(shuffleId, mapId, attemptId, location);
Expand Down Expand Up @@ -459,9 +472,18 @@ public void regionFinish(int shuffleId, int mapId, int attemptId, PartitionLocat
location.getUniqueId());
logger.debug("RegionFinish for location {}.", location);
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
RegionFinish regionFinish =
new RegionFinish(PRIMARY_MODE, shuffleKey, location.getUniqueId(), attemptId);
client.sendRpcSync(regionFinish.toByteBuffer(), conf.pushDataTimeoutMs());
client.sendRpcSync(
new TransportMessage(
MessageType.REGION_FINISH,
PbRegionFinish.newBuilder()
.setMode(Mode.forNumber(PRIMARY_MODE))
.setShuffleKey(shuffleKey)
.setPartitionUniqueId(location.getUniqueId())
.setAttemptId(attemptId)
.build()
.toByteArray())
.toByteBuffer(),
conf.pushDataTimeoutMs());
return null;
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;

@Deprecated
public final class PushDataHandShake extends RequestMessage {
// 0 for primary, 1 for replica, see PartitionLocation.Mode
public final byte mode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;

@Deprecated
public final class RegionFinish extends RequestMessage {

// 0 for primary, 1 for replica, see PartitionLocation.Mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;

@Deprecated
public final class RegionStart extends RequestMessage {

// 0 for primary, 1 for replica, see PartitionLocation.Mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
package org.apache.celeborn.common.network.protocol;

import static org.apache.celeborn.common.protocol.MessageType.OPEN_STREAM_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.PUSH_DATA_HAND_SHAKE_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.REGION_FINISH_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.REGION_START_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.STREAM_HANDLER_VALUE;

import java.io.Serializable;
Expand All @@ -31,6 +34,9 @@
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbPushDataHandShake;
import org.apache.celeborn.common.protocol.PbRegionFinish;
import org.apache.celeborn.common.protocol.PbRegionStart;
import org.apache.celeborn.common.protocol.PbStreamHandler;

public class TransportMessage implements Serializable {
Expand Down Expand Up @@ -64,6 +70,12 @@ public <T extends GeneratedMessageV3> T getParsedPayload() throws InvalidProtoco
return (T) PbOpenStream.parseFrom(payload);
case STREAM_HANDLER_VALUE:
return (T) PbStreamHandler.parseFrom(payload);
case PUSH_DATA_HAND_SHAKE_VALUE:
return (T) PbPushDataHandShake.parseFrom(payload);
case REGION_START_VALUE:
return (T) PbRegionStart.parseFrom(payload);
case REGION_FINISH_VALUE:
return (T) PbRegionFinish.parseFrom(payload);
default:
logger.error("Unexpected type {}", type);
}
Expand Down
30 changes: 29 additions & 1 deletion common/src/main/proto/TransportMessages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ enum MessageType {
CHECK_WORKERS_AVAILABLE = 53;
CHECK_WORKERS_AVAILABLE_RESPONSE = 54;
REMOVE_WORKERS_UNAVAILABLE_INFO = 55;
PUSH_DATA_HAND_SHAKE = 56;
REGION_START = 57;
REGION_FINISH = 58;
}

message PbStorageInfo {
Expand Down Expand Up @@ -499,4 +502,29 @@ message PbStreamHandler {
int32 numChunks = 2;
repeated int64 chunkOffsets = 3 ;
string fullPath = 4;
}
}

// Protobuf form of the push-data handshake request. It is carried as the payload of a
// TransportMessage whose type is PUSH_DATA_HAND_SHAKE and replaces the deprecated
// PushDataHandShake network message.
// Field numbers are part of the wire format: do not renumber or reuse them.
message PbPushDataHandShake {
// Mode of the target partition location.
// NOTE(review): the deprecated PushDataHandShake class documents its mode byte as
// "0 for primary, 1 for replica" - confirm PbPartitionLocation.Mode uses the same numbering.
PbPartitionLocation.Mode mode = 1;
// Shuffle key the request refers to.
string shuffleKey = 2;
// Unique id of the partition location (the client fills this from location.getUniqueId()).
string partitionUniqueId = 3;
// Attempt id supplied by the caller of the handshake.
int32 attemptId = 4;
// Partition count supplied by the caller.
// NOTE(review): exact semantics are not visible in this file - confirm with the handler.
int32 numPartitions = 5;
// Buffer size supplied by the caller.
// NOTE(review): units are not visible in this file - confirm with the handler.
int32 bufferSize = 6;
}

// Protobuf form of the region-start request. It is carried as the payload of a
// TransportMessage whose type is REGION_START and replaces the deprecated RegionStart
// network message.
// Field numbers are part of the wire format: do not renumber or reuse them.
message PbRegionStart {
// Mode of the target partition location.
// NOTE(review): the deprecated RegionStart class documents its mode as
// "0 for primary, 1 for replica" - confirm PbPartitionLocation.Mode uses the same numbering.
PbPartitionLocation.Mode mode = 1;
// Shuffle key the request refers to.
string shuffleKey = 2;
// Unique id of the partition location (the client fills this from location.getUniqueId()).
string partitionUniqueId = 3;
// Attempt id supplied by the caller.
int32 attemptId = 4;
// Index of the region being started (the client passes its currentRegionIdx).
int32 currentRegionIndex = 5;
// Broadcast flag for the region, passed through from the caller's isBroadcast argument.
bool isBroadcast = 6;
}

// Protobuf form of the region-finish request. It is carried as the payload of a
// TransportMessage whose type is REGION_FINISH and replaces the deprecated RegionFinish
// network message.
// Field numbers are part of the wire format: do not renumber or reuse them.
message PbRegionFinish {
// Mode of the target partition location.
// NOTE(review): the deprecated RegionFinish class documents its mode as
// "0 for primary, 1 for replica" - confirm PbPartitionLocation.Mode uses the same numbering.
PbPartitionLocation.Mode mode = 1;
// Shuffle key the request refers to.
string shuffleKey = 2;
// Unique id of the partition location (the client fills this from location.getUniqueId()).
string partitionUniqueId = 3;
// Attempt id supplied by the caller.
int32 attemptId = 4;
}
Loading

0 comments on commit 55e8505

Please sign in to comment.