Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamed fetch chunk #1

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ public TransportChannelHandler initializePipeline(
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
channel.pipeline()
.addLast("encoder", ENCODER)
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
.addLast(TransportFrameDecoder.HANDLER_NAME,
NettyUtils.createFrameDecoder(conf.maxRemoteBlockSizeFetchToMem(), false))
.addLast("decoder", DECODER)
.addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
// NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.network.client;

import io.netty.buffer.ByteBuf;

import java.io.IOException;
import java.nio.ByteBuffer;

Expand All @@ -28,13 +30,13 @@
* The network library guarantees that a single thread will call these methods at a time, but
* different call may be made by different threads.
*/
public interface StreamCallback {
public interface StreamCallback<T> {
/** Called upon receipt of stream data. */
void onData(String streamId, ByteBuffer buf) throws IOException;
void onData(T streamId, ByteBuffer buf) throws IOException;

/** Called when all data from the stream has been received. */
void onComplete(String streamId) throws IOException;
void onComplete(T streamId) throws IOException;

/** Called if there's an error reading data from the stream. */
void onFailure(String streamId, Throwable cause) throws IOException;
void onFailure(T streamId, Throwable cause) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@
* An interceptor that is registered with the frame decoder to feed stream data to a
* callback.
*/
class StreamInterceptor implements TransportFrameDecoder.Interceptor {
class StreamInterceptor<T> implements TransportFrameDecoder.Interceptor {

private final TransportResponseHandler handler;
private final String streamId;
private final T streamId;
private final long byteCount;
private final StreamCallback callback;
private final StreamCallback<T> callback;
private long bytesRead;

StreamInterceptor(
TransportResponseHandler handler,
String streamId,
T streamId,
long byteCount,
StreamCallback callback) {
this.handler = handler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import javax.annotation.Nullable;

import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -132,14 +133,15 @@ public void setClientId(String id) {
public void fetchChunk(
long streamId,
int chunkIndex,
ChunkReceivedCallback callback) {
ChunkReceivedCallback callback,
Supplier<StreamCallback<StreamChunkId>> streamCallbackFactory) {
long startTime = System.currentTimeMillis();
if (logger.isDebugEnabled()) {
logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel));
}

StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
handler.addFetchRequest(streamChunkId, callback);
handler.addFetchRequest(streamChunkId, callback, streamCallbackFactory);

channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> {
if (future.isSuccess()) {
Expand Down Expand Up @@ -169,7 +171,7 @@ public void fetchChunk(
* @param streamId The stream to fetch.
* @param callback Object to call with the stream data.
*/
public void stream(String streamId, StreamCallback callback) {
public void stream(String streamId, StreamCallback<String> callback) {
long startTime = System.currentTimeMillis();
if (logger.isDebugEnabled()) {
logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;

import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
Expand Down Expand Up @@ -55,10 +56,12 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private final Channel channel;

private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
private final Map<StreamChunkId, Supplier<StreamCallback<StreamChunkId>>>
outstandingStreamFetches;

private final Map<Long, RpcResponseCallback> outstandingRpcs;

private final Queue<Pair<String, StreamCallback>> streamCallbacks;
private final Queue<Pair<String, StreamCallback<String>>> streamCallbacks;
private volatile boolean streamActive;

/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
Expand All @@ -67,18 +70,24 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
public TransportResponseHandler(Channel channel) {
this.channel = channel;
this.outstandingFetches = new ConcurrentHashMap<>();
this.outstandingStreamFetches = new ConcurrentHashMap<>();
this.outstandingRpcs = new ConcurrentHashMap<>();
this.streamCallbacks = new ConcurrentLinkedQueue<>();
this.timeOfLastRequestNs = new AtomicLong(0);
}

public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
public void addFetchRequest(
StreamChunkId streamChunkId,
ChunkReceivedCallback callback,
Supplier<StreamCallback<StreamChunkId>> streamCallbackFactory) {
updateTimeOfLastRequest();
outstandingFetches.put(streamChunkId, callback);
outstandingStreamFetches.put(streamChunkId, streamCallbackFactory);
}

public void removeFetchRequest(StreamChunkId streamChunkId) {
outstandingFetches.remove(streamChunkId);
outstandingStreamFetches.remove(streamChunkId);
}

public void addRpcRequest(long requestId, RpcResponseCallback callback) {
Expand All @@ -90,7 +99,7 @@ public void removeRpcRequest(long requestId) {
outstandingRpcs.remove(requestId);
}

public void addStreamCallback(String streamId, StreamCallback callback) {
public void addStreamCallback(String streamId, StreamCallback<String> callback) {
timeOfLastRequestNs.set(System.nanoTime());
streamCallbacks.offer(ImmutablePair.of(streamId, callback));
}
Expand Down Expand Up @@ -119,7 +128,7 @@ private void failOutstandingRequests(Throwable cause) {
logger.warn("RpcResponseCallback.onFailure throws exception", e);
}
}
for (Pair<String, StreamCallback> entry : streamCallbacks) {
for (Pair<String, StreamCallback<String>> entry : streamCallbacks) {
try {
entry.getValue().onFailure(entry.getKey(), cause);
} catch (Exception e) {
Expand All @@ -131,6 +140,7 @@ private void failOutstandingRequests(Throwable cause) {
outstandingFetches.clear();
outstandingRpcs.clear();
streamCallbacks.clear();
outstandingStreamFetches.clear();
}

@Override
Expand Down Expand Up @@ -165,10 +175,37 @@ public void handle(ResponseMessage message) throws Exception {
if (listener == null) {
logger.warn("Ignoring response for block {} from {} since it is not outstanding",
resp.streamChunkId, getRemoteAddress(channel));
resp.body().release();
} else {
outstandingFetches.remove(resp.streamChunkId);
listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
if (resp.isBodyInFrame()) {
outstandingFetches.remove(resp.streamChunkId);
listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
} else {
StreamCallback<StreamChunkId> streamCallback =
outstandingStreamFetches.get(resp.streamChunkId).get();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the whole point of the supplier is to delay this call to get(), right? you want to avoid creating the output file until here?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes and to avoid creating unnecessary temp files for the small chunks (under maxRemoteBlockSizeFetchToMem) which will be loaded into the memory.

outstandingFetches.remove(resp.streamChunkId);
outstandingStreamFetches.remove(resp.streamChunkId);
if (resp.remainingFrameSize > 0) {
StreamInterceptor interceptor = new StreamInterceptor<StreamChunkId>(this,
resp.streamChunkId, resp.remainingFrameSize, streamCallback);
try {
TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
frameDecoder.setInterceptor(interceptor);
streamActive = true;
} catch (Exception e) {
logger.error("Error installing stream handler.", e);
deactivateStream();
}
} else {
try {
streamCallback.onComplete(resp.streamChunkId);
} catch (Exception e) {
logger.warn("Error in stream handler onComplete().", e);
}
}
}
}
if (resp.isBodyInFrame()) {
resp.body().release();
}
} else if (message instanceof ChunkFetchFailure) {
Expand Down Expand Up @@ -208,12 +245,12 @@ public void handle(ResponseMessage message) throws Exception {
}
} else if (message instanceof StreamResponse) {
StreamResponse resp = (StreamResponse) message;
Pair<String, StreamCallback> entry = streamCallbacks.poll();
Pair<String, StreamCallback<String>> entry = streamCallbacks.poll();
if (entry != null) {
StreamCallback callback = entry.getValue();
if (resp.byteCount > 0) {
StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
callback);
StreamInterceptor interceptor = new StreamInterceptor<String>(this, resp.streamId,
resp.byteCount, callback);
try {
TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
Expand All @@ -235,7 +272,7 @@ public void handle(ResponseMessage message) throws Exception {
}
} else if (message instanceof StreamFailure) {
StreamFailure resp = (StreamFailure) message;
Pair<String, StreamCallback> entry = streamCallbacks.poll();
Pair<String, StreamCallback<String>> entry = streamCallbacks.poll();
if (entry != null) {
StreamCallback callback = entry.getValue();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,23 @@
* Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
*/
public final class ChunkFetchSuccess extends AbstractResponseMessage {
public static final int ENCODED_LENGTH = StreamChunkId.ENCODED_LENGTH;
public final StreamChunkId streamChunkId;
public final long remainingFrameSize;

public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
super(buffer, true);
this.streamChunkId = streamChunkId;
this.remainingFrameSize = 0;
}

public ChunkFetchSuccess(StreamChunkId streamChunkId,
ManagedBuffer buffer,
boolean isBodyInFrame,
long remainingFrameSize) {
super(buffer, isBodyInFrame);
this.streamChunkId = streamChunkId;
this.remainingFrameSize = remainingFrameSize;
}

@Override
Expand All @@ -58,11 +70,16 @@ public ResponseMessage createFailureResponse(String error) {
}

/** Decoding uses the given ByteBuf as our data, and will retain() it. */
public static ChunkFetchSuccess decode(ByteBuf buf) {
public static ChunkFetchSuccess decode(ByteBuf buf, long remainingFrameSize) {
StreamChunkId streamChunkId = StreamChunkId.decode(buf);
buf.retain();
NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate());
return new ChunkFetchSuccess(streamChunkId, managedBuf);
NettyManagedBuffer managedBuf = null;
final boolean isFullFrameProcessed =
remainingFrameSize == 0;
if (isFullFrameProcessed) {
buf.retain();
managedBuf = new NettyManagedBuffer(buf.duplicate());
}
return new ChunkFetchSuccess(streamChunkId, managedBuf, isFullFrameProcessed, remainingFrameSize);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ enum Type implements Encodable {
StreamRequest(6), StreamResponse(7), StreamFailure(8),
OneWayMessage(9), User(-1);

/** Encoded length in bytes. */
public static final int LENGTH = 1;

private final byte id;

Type(int id) {
Expand All @@ -48,7 +51,7 @@ enum Type implements Encodable {

public byte id() { return id; }

@Override public int encodedLength() { return 1; }
@Override public int encodedLength() { return LENGTH; }

@Override public void encode(ByteBuf buf) { buf.writeByte(id); }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
* This encoder is stateless so it is safe to be shared by multiple threads.
*/
@ChannelHandler.Sharable
public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
public final class MessageDecoder extends MessageToMessageDecoder<ParsedFrame> {

private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);

Expand All @@ -40,21 +40,20 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
private MessageDecoder() {}

@Override
public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
Message.Type msgType = Message.Type.decode(in);
Message decoded = decode(msgType, in);
assert decoded.type() == msgType;
logger.trace("Received message {}: {}", msgType, decoded);
public void decode(ChannelHandlerContext ctx, ParsedFrame in, List<Object> out) {
Message decoded = decode(in.messageType, in.byteBuf, in.remainingFrameSize);
assert decoded.type() == in.messageType;
logger.trace("Received message {}: {}", in.messageType, decoded);
out.add(decoded);
}

private Message decode(Message.Type msgType, ByteBuf in) {
private Message decode(Message.Type msgType, ByteBuf in, long remainingFrameSize) {
switch (msgType) {
case ChunkFetchRequest:
return ChunkFetchRequest.decode(in);

case ChunkFetchSuccess:
return ChunkFetchSuccess.decode(in);
return ChunkFetchSuccess.decode(in, remainingFrameSize);

case ChunkFetchFailure:
return ChunkFetchFailure.decode(in);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.protocol;

import io.netty.buffer.ByteBuf;

public class ParsedFrame {

public final Message.Type messageType;

public final ByteBuf byteBuf;

public final long remainingFrameSize;


public ParsedFrame(Message.Type messageType, ByteBuf byteBuf, long remainingFrameSize) {
this.messageType = messageType;
this.byteBuf = byteBuf;
this.remainingFrameSize = remainingFrameSize;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
* Encapsulates a request for a particular chunk of a stream.
*/
public final class StreamChunkId implements Encodable {

public static final int ENCODED_LENGTH = 8 + 4;

public final long streamId;
public final int chunkIndex;

Expand All @@ -34,7 +37,7 @@ public StreamChunkId(long streamId, int chunkIndex) {

@Override
public int encodedLength() {
return 8 + 4;
return ENCODED_LENGTH;
}

public void encode(ByteBuf buffer) {
Expand All @@ -43,7 +46,7 @@ public void encode(ByteBuf buffer) {
}

public static StreamChunkId decode(ByteBuf buffer) {
assert buffer.readableBytes() >= 8 + 4;
assert buffer.readableBytes() >= ENCODED_LENGTH;
long streamId = buffer.readLong();
int chunkIndex = buffer.readInt();
return new StreamChunkId(streamId, chunkIndex);
Expand Down
Loading