From 1706fdeaae9213bf63ad8f88c4d1c1d4ba926974 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 29 May 2018 17:57:12 +0200 Subject: [PATCH 1/8] Initial version --- .../spark/network/TransportContext.java | 3 +- .../ChunkReceivedWithStreamCallback.java | 24 +++++++ .../spark/network/client/StreamCallback.java | 10 +-- .../network/client/StreamInterceptor.java | 8 +-- .../spark/network/client/TransportClient.java | 4 +- .../client/TransportResponseHandler.java | 45 ++++++++++--- .../network/protocol/ChunkFetchSuccess.java | 25 +++++-- .../spark/network/protocol/Message.java | 5 +- .../network/protocol/MessageDecoder.java | 15 ++--- .../spark/network/protocol/ParsedFrame.java | 36 ++++++++++ .../spark/network/protocol/StreamChunkId.java | 7 +- .../spark/network/sasl/SaslEncryption.java | 2 +- .../apache/spark/network/util/NettyUtils.java | 4 +- .../spark/network/util/TransportConf.java | 4 ++ .../network/util/TransportFrameDecoder.java | 58 +++++++++++++++-- .../network/ChunkFetchIntegrationSuite.java | 24 +++++-- .../apache/spark/network/ProtocolSuite.java | 4 +- .../RequestTimeoutIntegrationSuite.java | 23 +++++-- .../org/apache/spark/network/StreamSuite.java | 2 +- .../TransportResponseHandlerSuite.java | 20 +++--- .../spark/network/sasl/SparkSaslSuite.java | 7 +- .../util/TransportFrameDecoderSuite.java | 51 +++++++++------ .../shuffle/ExternalShuffleClient.java | 15 +++-- .../shuffle/OneForOneBlockFetcher.java | 65 ++++++++++++++++--- .../spark/network/shuffle/ShuffleClient.java | 20 +++--- .../network/sasl/SaslIntegrationSuite.java | 23 +++++-- .../ExternalShuffleIntegrationSuite.java | 2 +- .../spark/internal/config/package.scala | 12 +++- .../spark/network/BlockTransferService.scala | 10 ++- .../netty/NettyBlockTransferService.scala | 8 +-- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 2 +- .../shuffle/BlockStoreShuffleReader.scala | 1 + .../apache/spark/storage/BlockManager.scala | 19 +++--- .../storage/ShuffleBlockFetcherIterator.scala | 16 +++-- .../org/apache/spark/DistributedSuite.scala | 2 +- .../NettyBlockTransferSecuritySuite.scala | 4 +- .../spark/storage/BlockManagerSuite.scala | 8 ++- .../ShuffleBlockFetcherIteratorSuite.scala | 24 ++++--- 38 files changed, 443 insertions(+), 169 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedWithStreamCallback.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/ParsedFrame.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index ae91bc9cfdd08..af68728ce5204 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -146,7 +146,8 @@ public TransportChannelHandler initializePipeline( TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); channel.pipeline() .addLast("encoder", ENCODER) - .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder()) + .addLast(TransportFrameDecoder.HANDLER_NAME, + NettyUtils.createFrameDecoder(conf.maxRemoteBlockSizeFetchToMem())) .addLast("decoder", DECODER) .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedWithStreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedWithStreamCallback.java
new file mode 100644
index 0000000000000..ad36fa7b5353c
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedWithStreamCallback.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import org.apache.spark.network.protocol.StreamChunkId;
+
+public interface ChunkReceivedWithStreamCallback extends
+  ChunkReceivedCallback, StreamCallback<StreamChunkId> {
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java
index d322aec28793e..bf3b60648b3ca 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java
@@ -17,6 +17,8 @@
 package org.apache.spark.network.client;
 
+import io.netty.buffer.ByteBuf;
+
 import java.io.IOException;
 import java.nio.ByteBuffer;
 
@@ -28,13 +30,13 @@
  * The network library guarantees that a single thread will call these methods at a time, but
  * different call may be made by different threads.
  */
-public interface StreamCallback {
+public interface StreamCallback<T> {
   /** Called upon receipt of stream data. */
-  void onData(String streamId, ByteBuffer buf) throws IOException;
+  void onData(T streamId, ByteBuffer buf) throws IOException;
 
   /** Called when all data from the stream has been received. */
-  void onComplete(String streamId) throws IOException;
+  void onComplete(T streamId) throws IOException;
 
   /** Called if there's an error reading data from the stream. */
-  void onFailure(String streamId, Throwable cause) throws IOException;
+  void onFailure(T streamId, Throwable cause) throws IOException;
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
index b0e85bae7c309..19faacf67d00e 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
@@ -28,17 +28,17 @@
 * An interceptor that is registered with the frame decoder to feed stream data to a
 * callback.
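+ * The stream identifier is now a type parameter: String for whole-stream requests and
+ * StreamChunkId for chunk fetches whose bodies are streamed instead of buffered in memory.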
*/ -class StreamInterceptor implements TransportFrameDecoder.Interceptor { +class StreamInterceptor implements TransportFrameDecoder.Interceptor { private final TransportResponseHandler handler; - private final String streamId; + private final T streamId; private final long byteCount; - private final StreamCallback callback; + private final StreamCallback callback; private long bytesRead; StreamInterceptor( TransportResponseHandler handler, - String streamId, + T streamId, long byteCount, StreamCallback callback) { this.handler = handler; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8f354ad78bbaa..1fb037bc12f74 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -132,7 +132,7 @@ public void setClientId(String id) { public void fetchChunk( long streamId, int chunkIndex, - ChunkReceivedCallback callback) { + ChunkReceivedWithStreamCallback callback) { long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); @@ -169,7 +169,7 @@ public void fetchChunk( * @param streamId The stream to fetch. * @param callback Object to call with the stream data. */ - public void stream(String streamId, StreamCallback callback) { + public void stream(String streamId, StreamCallback callback) { long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel)); diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 7a3d96ceaef0c..c00a82e76547d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -54,7 +54,7 @@ public class TransportResponseHandler extends MessageHandler { private final Channel channel; - private final Map outstandingFetches; + private final Map outstandingFetches; private final Map outstandingRpcs; @@ -72,7 +72,9 @@ public TransportResponseHandler(Channel channel) { this.timeOfLastRequestNs = new AtomicLong(0); } - public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { + public void addFetchRequest( + StreamChunkId streamChunkId, + ChunkReceivedWithStreamCallback callback) { updateTimeOfLastRequest(); outstandingFetches.put(streamChunkId, callback); } @@ -90,7 +92,7 @@ public void removeRpcRequest(long requestId) { outstandingRpcs.remove(requestId); } - public void addStreamCallback(String streamId, StreamCallback callback) { + public void addStreamCallback(String streamId, StreamCallback callback) { timeOfLastRequestNs.set(System.nanoTime()); streamCallbacks.offer(ImmutablePair.of(streamId, callback)); } @@ -105,7 +107,7 @@ public void deactivateStream() { * uncaught exception or pre-mature connection termination. 
*/ private void failOutstandingRequests(Throwable cause) { - for (Map.Entry entry : outstandingFetches.entrySet()) { + for (Map.Entry entry : outstandingFetches.entrySet()) { try { entry.getValue().onFailure(entry.getKey().chunkIndex, cause); } catch (Exception e) { @@ -161,14 +163,37 @@ public void exceptionCaught(Throwable cause) { public void handle(ResponseMessage message) throws Exception { if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; - ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + ChunkReceivedWithStreamCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", resp.streamChunkId, getRemoteAddress(channel)); - resp.body().release(); } else { - outstandingFetches.remove(resp.streamChunkId); - listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); + if (resp.isBodyInFrame()) { + outstandingFetches.remove(resp.streamChunkId); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); + } else { + if (resp.remainingFrameSize > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, + resp.streamChunkId, resp.remainingFrameSize, listener); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + listener.onComplete(resp.streamChunkId); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } + } + } + } + if (resp.isBodyInFrame()) { resp.body().release(); } } else if (message instanceof ChunkFetchFailure) { @@ -212,8 +237,8 @@ public void handle(ResponseMessage message) throws Exception { if (entry != null) { StreamCallback callback = entry.getValue(); if (resp.byteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback); + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, + resp.byteCount, callback); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index 94c2ac9b20e43..7a315f9db2fd2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -31,11 +31,23 @@ * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. 
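+ * When the chunk body is larger than maxRemoteBlockSizeFetchToMem, the body is not carried in
+ * this frame; remainingFrameSize then records how many body bytes are still pending on the
+ * channel, and the response handler streams them to the callback instead of buffering them.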
*/ public final class ChunkFetchSuccess extends AbstractResponseMessage { + public static final int ENCODED_LENGTH = StreamChunkId.ENCODED_LENGTH; public final StreamChunkId streamChunkId; + public final long remainingFrameSize; public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { super(buffer, true); this.streamChunkId = streamChunkId; + this.remainingFrameSize = 0; + } + + public ChunkFetchSuccess(StreamChunkId streamChunkId, + ManagedBuffer buffer, + boolean isBodyInFrame, + long remainingFrameSize) { + super(buffer, isBodyInFrame); + this.streamChunkId = streamChunkId; + this.remainingFrameSize = remainingFrameSize; } @Override @@ -58,11 +70,16 @@ public ResponseMessage createFailureResponse(String error) { } /** Decoding uses the given ByteBuf as our data, and will retain() it. */ - public static ChunkFetchSuccess decode(ByteBuf buf) { + public static ChunkFetchSuccess decode(ByteBuf buf, long remainingFrameSize) { StreamChunkId streamChunkId = StreamChunkId.decode(buf); - buf.retain(); - NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); - return new ChunkFetchSuccess(streamChunkId, managedBuf); + NettyManagedBuffer managedBuf = null; + final boolean isFullFrameProcessed = + remainingFrameSize == 0; + if (isFullFrameProcessed) { + buf.retain(); + managedBuf = new NettyManagedBuffer(buf.duplicate()); + } + return new ChunkFetchSuccess(streamChunkId, managedBuf, isFullFrameProcessed, remainingFrameSize); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 434935a8ef2ad..8360d7d0ac21b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -39,6 +39,9 @@ enum Type implements Encodable { StreamRequest(6), StreamResponse(7), StreamFailure(8), OneWayMessage(9), User(-1); + /** Encoded length in bytes. */ + public static final int LENGTH = 1; + private final byte id; Type(int id) { @@ -48,7 +51,7 @@ enum Type implements Encodable { public byte id() { return id; } - @Override public int encodedLength() { return 1; } + @Override public int encodedLength() { return LENGTH; } @Override public void encode(ByteBuf buf) { buf.writeByte(id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 39a7495828a8a..fc7151d55a3e2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -31,7 +31,7 @@ * This encoder is stateless so it is safe to be shared by multiple threads. 
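+ * With this change the decoder consumes ParsedFrame objects produced by TransportFrameDecoder,
+ * so the message type has already been read off the frame and, for large ChunkFetchSuccess
+ * messages, the body may be left out of the frame and streamed separately.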
*/ @ChannelHandler.Sharable -public final class MessageDecoder extends MessageToMessageDecoder { +public final class MessageDecoder extends MessageToMessageDecoder { private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); @@ -40,21 +40,20 @@ public final class MessageDecoder extends MessageToMessageDecoder { private MessageDecoder() {} @Override - public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { - Message.Type msgType = Message.Type.decode(in); - Message decoded = decode(msgType, in); - assert decoded.type() == msgType; - logger.trace("Received message {}: {}", msgType, decoded); + public void decode(ChannelHandlerContext ctx, ParsedFrame in, List out) { + Message decoded = decode(in.messageType, in.byteBuf, in.remainingFrameSize); + assert decoded.type() == in.messageType; + logger.trace("Received message {}: {}", in.messageType, decoded); out.add(decoded); } - private Message decode(Message.Type msgType, ByteBuf in) { + private Message decode(Message.Type msgType, ByteBuf in, long remainingFrameSize) { switch (msgType) { case ChunkFetchRequest: return ChunkFetchRequest.decode(in); case ChunkFetchSuccess: - return ChunkFetchSuccess.decode(in); + return ChunkFetchSuccess.decode(in, remainingFrameSize); case ChunkFetchFailure: return ChunkFetchFailure.decode(in); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ParsedFrame.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ParsedFrame.java new file mode 100644 index 0000000000000..b29fe7626eefa --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ParsedFrame.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; + +public class ParsedFrame { + + public final Message.Type messageType; + + public final ByteBuf byteBuf; + + public final long remainingFrameSize; + + + public ParsedFrame(Message.Type messageType, ByteBuf byteBuf, long remainingFrameSize) { + this.messageType = messageType; + this.byteBuf = byteBuf; + this.remainingFrameSize = remainingFrameSize; + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java index d46a263884807..0807cb127c9d1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java @@ -24,6 +24,9 @@ * Encapsulates a request for a particular chunk of a stream. 
*/ public final class StreamChunkId implements Encodable { + + public static final int ENCODED_LENGTH = 8 + 4; + public final long streamId; public final int chunkIndex; @@ -34,7 +37,7 @@ public StreamChunkId(long streamId, int chunkIndex) { @Override public int encodedLength() { - return 8 + 4; + return ENCODED_LENGTH; } public void encode(ByteBuf buffer) { @@ -43,7 +46,7 @@ public void encode(ByteBuf buffer) { } public static StreamChunkId decode(ByteBuf buffer) { - assert buffer.readableBytes() >= 8 + 4; + assert buffer.readableBytes() >= ENCODED_LENGTH; long streamId = buffer.readLong(); int chunkIndex = buffer.readInt(); return new StreamChunkId(streamId, chunkIndex); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 3ac9081d78a75..e29a1e0f2c800 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -61,7 +61,7 @@ static void addToChannel( channel.pipeline() .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize)) .addFirst("saslDecryption", new DecryptionHandler(backend)) - .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder()); + .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder(Integer.MAX_VALUE)); } private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 5e85180bd6f9f..2e34b9d69a522 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -84,8 +84,8 @@ public static Class getServerChannelClass(IOMode mode) * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. * This is used before all decoders. */ - public static TransportFrameDecoder createFrameDecoder() { - return new TransportFrameDecoder(); + public static TransportFrameDecoder createFrameDecoder(long maxRemoteBlockSizeFetchToMem) { + return new TransportFrameDecoder(maxRemoteBlockSizeFetchToMem); } /** Returns the remote address on the channel or "<unknown remote>" if none exists. 
*/ diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 91497b9492219..a2a0939a15586 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -281,4 +281,8 @@ public Properties cryptoConf() { public long maxChunksBeingTransferred() { return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE); } + + public long maxRemoteBlockSizeFetchToMem() { + return conf.getLong("spark.maxRemoteBlockSizeFetchToMem", Long.MAX_VALUE); + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 8e73ab077a5c1..197b14e9a5888 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -25,6 +25,9 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.ParsedFrame; /** * A customized frame decoder that allows intercepting raw data. @@ -55,6 +58,13 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { private long totalSize = 0; private long nextFrameSize = UNKNOWN_FRAME_SIZE; private volatile Interceptor interceptor; + private Message.Type msgType = null; + + private final long maxRemoteBlockSizeFetchToMem; + + public TransportFrameDecoder(long maxRemoteBlockSizeFetchToMem) { + this.maxRemoteBlockSizeFetchToMem = maxRemoteBlockSizeFetchToMem; + } @Override public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { @@ -78,11 +88,25 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception totalSize -= read; } else { // Interceptor is not active, so try to decode one frame. 
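+        // Peek at the message type first. If the frame is a ChunkFetchSuccess whose body is
+        // larger than maxRemoteBlockSizeFetchToMem, shrink the in-memory frame to just the
+        // StreamChunkId header and remember the rest as remainingFrameSize, so the body bytes
+        // can later be consumed by a StreamInterceptor instead of being buffered on the heap.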
+ decodeNextMsgType(); + if (msgType == null) { + break; + } + long remainingFrameSize = 0; + if (msgType == Message.Type.ChunkFetchSuccess && + nextFrameSize - ChunkFetchSuccess.ENCODED_LENGTH > maxRemoteBlockSizeFetchToMem) { + remainingFrameSize = nextFrameSize - ChunkFetchSuccess.ENCODED_LENGTH; + nextFrameSize = ChunkFetchSuccess.ENCODED_LENGTH; + } + ByteBuf frame = decodeNext(); if (frame == null) { break; } - ctx.fireChannelRead(frame); + ParsedFrame parsedFrame = + new ParsedFrame(msgType, frame, remainingFrameSize); + msgType = null; + ctx.fireChannelRead(parsedFrame); } } } @@ -121,18 +145,40 @@ private long decodeFrameSize() { return nextFrameSize; } - private ByteBuf decodeNext() { + private void decodeNextMsgType() { + if (msgType != null) { + return; + } long frameSize = decodeFrameSize(); - if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) { + if (frameSize == UNKNOWN_FRAME_SIZE) { + return; + } + + Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); + Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); + + if(totalSize < Message.Type.LENGTH) { + return; + } + + ByteBuf first = buffers.getFirst(); + msgType = Message.Type.decode(first); + totalSize -= Message.Type.LENGTH; + nextFrameSize -= Message.Type.LENGTH; + if (!first.isReadable()) { + buffers.removeFirst().release(); + } + } + + private ByteBuf decodeNext() { + long frameSize = nextFrameSize; + if (totalSize < frameSize) { return null; } // Reset size for next frame. nextFrameSize = UNKNOWN_FRAME_SIZE; - Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); - Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); - // If the first buffer holds the entire frame, return it. 
int remaining = (int) frameSize; if (buffers.getFirst().readableBytes() >= remaining) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 824482af08dd4..bb281142e7cf4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network; import java.io.File; +import java.io.IOException; import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.util.Arrays; @@ -32,6 +33,8 @@ import com.google.common.collect.Sets; import com.google.common.io.Closeables; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.StreamChunkId; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -41,10 +44,6 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; @@ -151,7 +150,22 @@ private FetchResult fetchChunks(List chunkIndices) throws Exception { res.failedChunks = Collections.synchronizedSet(new HashSet()); res.buffers = Collections.synchronizedList(new LinkedList()); - ChunkReceivedCallback callback = new ChunkReceivedCallback() { + ChunkReceivedWithStreamCallback callback = new ChunkReceivedWithStreamCallback() { + @Override + public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { + + } + + @Override + public void onComplete(StreamChunkId streamId) throws IOException { + + } + + @Override + public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { + + } + @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { buffer.retain(); diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index bc94f7ca63a96..e677a26242d4e 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -53,7 +53,7 @@ private void testServerToClient(Message msg) { serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); + NettyUtils.createFrameDecoder(Integer.MAX_VALUE), MessageDecoder.INSTANCE); while (!serverChannel.outboundMessages().isEmpty()) { clientChannel.writeOneInbound(serverChannel.readOutbound()); @@ -69,7 +69,7 @@ private void testClientToServer(Message msg) { clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); + NettyUtils.createFrameDecoder(Integer.MAX_VALUE), MessageDecoder.INSTANCE); while (!clientChannel.outboundMessages().isEmpty()) { serverChannel.writeOneInbound(clientChannel.readOutbound()); 
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index c0724e018263f..25d5edec53543 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -20,10 +20,8 @@ import com.google.common.util.concurrent.Uninterruptibles; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; @@ -235,7 +233,7 @@ public StreamManager getStreamManager() { * Callback which sets 'success' or 'failure' on completion. * Additionally notifies all waiters on this callback when invoked. */ - static class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { + static class TestCallback implements RpcResponseCallback, ChunkReceivedWithStreamCallback { int successLength = -1; Throwable failure; @@ -269,5 +267,20 @@ public void onFailure(int chunkIndex, Throwable e) { failure = e; latch.countDown(); } + + @Override + public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { + + } + + @Override + public void onComplete(StreamChunkId streamId) throws IOException { + + } + + @Override + public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { + + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index f253a07e64be1..04ae42efc0abb 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -308,7 +308,7 @@ private void waitForCompletion(TestCallback callback) throws Exception { } - private static class TestCallback implements StreamCallback { + private static class TestCallback implements StreamCallback { private final OutputStream out; public volatile boolean completed; diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index b4032c4c3f031..9856ae42b6798 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -22,16 +22,13 @@ import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; +import org.apache.spark.network.client.*; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import 
org.apache.spark.network.client.StreamCallback; -import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; import org.apache.spark.network.protocol.RpcFailure; @@ -47,7 +44,7 @@ public void handleSuccessfulFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + ChunkReceivedWithStreamCallback callback = mock(ChunkReceivedWithStreamCallback.class); handler.addFetchRequest(streamChunkId, callback); assertEquals(1, handler.numOutstandingRequests()); @@ -60,7 +57,7 @@ public void handleSuccessfulFetch() throws Exception { public void handleFailedFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + ChunkReceivedWithStreamCallback callback = mock(ChunkReceivedWithStreamCallback.class); handler.addFetchRequest(streamChunkId, callback); assertEquals(1, handler.numOutstandingRequests()); @@ -72,7 +69,7 @@ public void handleFailedFetch() throws Exception { @Test public void clearAllOutstandingRequests() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + ChunkReceivedWithStreamCallback callback = mock(ChunkReceivedWithStreamCallback.class); handler.addFetchRequest(new StreamChunkId(1, 0), callback); handler.addFetchRequest(new StreamChunkId(1, 1), callback); handler.addFetchRequest(new StreamChunkId(1, 2), callback); @@ -123,7 +120,8 @@ public void handleFailedRPC() throws Exception { @Test public void testActiveStreams() throws Exception { Channel c = new LocalChannel(); - c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + c.pipeline() + .addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder(Integer.MAX_VALUE)); TransportResponseHandler handler = new TransportResponseHandler(c); StreamResponse response = new StreamResponse("stream", 1234L, null); @@ -145,7 +143,8 @@ public void testActiveStreams() throws Exception { @Test public void failOutstandingStreamCallbackOnClose() throws Exception { Channel c = new LocalChannel(); - c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + c.pipeline() + .addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder(Integer.MAX_VALUE)); TransportResponseHandler handler = new TransportResponseHandler(c); StreamCallback cb = mock(StreamCallback.class); @@ -158,7 +157,8 @@ public void failOutstandingStreamCallbackOnClose() throws Exception { @Test public void failOutstandingStreamCallbackOnException() throws Exception { Channel c = new LocalChannel(); - c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + c.pipeline() + .addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder(Integer.MAX_VALUE)); TransportResponseHandler handler = new TransportResponseHandler(c); StreamCallback cb = mock(StreamCallback.class); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 
6f15718bd8705..7e0beb9d7c87a 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -44,16 +44,13 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; +import org.apache.spark.network.client.*; import org.junit.Test; import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; @@ -273,7 +270,7 @@ public void testFileRegionEncryption() throws Exception { CountDownLatch lock = new CountDownLatch(1); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + ChunkReceivedWithStreamCallback callback = mock(ChunkReceivedWithStreamCallback.class); doAnswer(invocation -> { response.set((ManagedBuffer) invocation.getArguments()[1]); response.get().retain(); diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index b53e41303751c..f590d67000c04 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -19,6 +19,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Random; import java.util.concurrent.atomic.AtomicInteger; @@ -26,6 +27,8 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.ParsedFrame; import org.junit.AfterClass; import org.junit.Test; import static org.junit.Assert.*; @@ -42,7 +45,7 @@ public static void cleanup() { @Test public void testFrameDecoding() throws Exception { - TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE); ChannelHandlerContext ctx = mockChannelHandlerContext(); ByteBuf data = createAndFeedFrames(100, decoder, ctx); verifyAndCloseDecoder(decoder, ctx, data); @@ -51,7 +54,7 @@ public void testFrameDecoding() throws Exception { @Test public void testInterception() throws Exception { int interceptedReads = 3; - TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE); TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); ChannelHandlerContext ctx = mockChannelHandlerContext(); @@ -69,7 +72,7 @@ public void testInterception() throws Exception { decoder.channelRead(ctx, len); decoder.channelRead(ctx, dataBuf); verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); - verify(ctx).fireChannelRead(any(ByteBuffer.class)); + 
verify(ctx).fireChannelRead(any(ParsedFrame.class)); assertEquals(0, len.refCnt()); assertEquals(0, dataBuf.refCnt()); } finally { @@ -80,19 +83,19 @@ public void testInterception() throws Exception { @Test public void testRetainedFrames() throws Exception { - TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE); AtomicInteger count = new AtomicInteger(); - List retained = new ArrayList<>(); + List retained = new ArrayList<>(); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); when(ctx.fireChannelRead(any())).thenAnswer(in -> { // Retain a few frames but not others. - ByteBuf buf = (ByteBuf) in.getArguments()[0]; + ParsedFrame parsedFrame = (ParsedFrame) in.getArguments()[0]; if (count.incrementAndGet() % 2 == 0) { - retained.add(buf); + retained.add(parsedFrame); } else { - buf.release(); + parsedFrame.byteBuf.release(); } return null; }); @@ -100,15 +103,15 @@ public void testRetainedFrames() throws Exception { ByteBuf data = createAndFeedFrames(100, decoder, ctx); try { // Verify all retained buffers are readable. - for (ByteBuf b : retained) { - byte[] tmp = new byte[b.readableBytes()]; - b.readBytes(tmp); - b.release(); + for (ParsedFrame b : retained) { + byte[] tmp = new byte[b.byteBuf.readableBytes()]; + b.byteBuf.readBytes(tmp); + b.byteBuf.release(); } verifyAndCloseDecoder(decoder, ctx, data); } finally { - for (ByteBuf b : retained) { - release(b); + for (ParsedFrame b : retained) { + release(b.byteBuf); } } } @@ -120,13 +123,13 @@ public void testSplitLengthField() throws Exception { buf.writeLong(frame.length + 8); buf.writeBytes(frame); - TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE); ChannelHandlerContext ctx = mockChannelHandlerContext(); try { decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain()); - verify(ctx, never()).fireChannelRead(any(ByteBuf.class)); + verify(ctx, never()).fireChannelRead(any(ParsedFrame.class)); decoder.channelRead(ctx, buf); - verify(ctx).fireChannelRead(any(ByteBuf.class)); + verify(ctx).fireChannelRead(any(ParsedFrame.class)); assertEquals(0, buf.refCnt()); } finally { decoder.channelInactive(ctx); @@ -154,8 +157,14 @@ private ByteBuf createAndFeedFrames( TransportFrameDecoder decoder, ChannelHandlerContext ctx) throws Exception { ByteBuf data = Unpooled.buffer(); + Message.Type msgTypes[] = Arrays.stream(Message.Type.values()) + .filter(t -> t != Message.Type.User) + .toArray(Message.Type[]::new); + for (int i = 0; i < frameCount; i++) { byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + Message.Type randomMsgType = msgTypes[RND.nextInt(msgTypes.length)]; + frame[0] = randomMsgType.id(); data.writeLong(frame.length + 8); data.writeBytes(frame); } @@ -166,7 +175,7 @@ private ByteBuf createAndFeedFrames( decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); } - verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + verify(ctx, times(frameCount)).fireChannelRead(any(ParsedFrame.class)); } catch (Exception e) { release(data); throw e; @@ -187,7 +196,7 @@ private void verifyAndCloseDecoder( } private void testInvalidFrame(long size) throws Exception { - TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); ByteBuf frame = Unpooled.copyLong(size); 
try { @@ -200,8 +209,8 @@ private void testInvalidFrame(long size) throws Exception { private ChannelHandlerContext mockChannelHandlerContext() { ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); when(ctx.fireChannelRead(any())).thenAnswer(in -> { - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - buf.release(); + ParsedFrame parsedFrame = (ParsedFrame) in.getArguments()[0]; + parsedFrame.byteBuf.release(); return null; }); return ctx; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 7ed0b6e93a7a8..f2f445da748fb 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -86,12 +86,13 @@ public void init(String appId) { @Override public void fetchBlocks( - String host, - int port, - String execId, - String[] blockIds, - BlockFetchingListener listener, - TempFileManager tempFileManager) { + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener, + TempFileManager tempFileManager, + boolean useStreamRequestMessage) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { @@ -99,7 +100,7 @@ public void fetchBlocks( (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); new OneForOneBlockFetcher(client, appId, execId, - blockIds1, listener1, conf, tempFileManager).start(); + blockIds1, listener1, conf, tempFileManager, useStreamRequestMessage).start(); }; int maxRetries = conf.maxIORetries(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 0bc571874f07c..80a8d8f83d3cb 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -25,15 +25,16 @@ import java.nio.channels.WritableByteChannel; import java.util.Arrays; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.StreamChunkId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallback; -import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.OpenBlocks; @@ -56,9 +57,10 @@ public class OneForOneBlockFetcher { private final OpenBlocks openMessage; private final String[] blockIds; private final BlockFetchingListener listener; - private final ChunkReceivedCallback chunkCallback; + private final ChunkReceivedWithStreamCallback chunkCallback; private final TransportConf transportConf; private final TempFileManager tempFileManager; + private 
final boolean useStreamRequestMessage; private StreamHandle streamHandle = null; @@ -69,7 +71,7 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf) { - this(client, appId, execId, blockIds, listener, transportConf, null); + this(client, appId, execId, blockIds, listener, transportConf, null, false); } public OneForOneBlockFetcher( @@ -79,18 +81,61 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf, - TempFileManager tempFileManager) { + TempFileManager tempFileManager, + boolean useStreamRequestMessage) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); this.transportConf = transportConf; + // TODO extend tests to pass a valid tempFileManager and use: + // this.tempFileManager = Preconditions.checkNotNull(tempFileManager); this.tempFileManager = tempFileManager; + this.useStreamRequestMessage = useStreamRequestMessage; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ - private class ChunkCallback implements ChunkReceivedCallback { + private class ChunkCallback implements ChunkReceivedWithStreamCallback { + private WritableByteChannel channel = null; + private File targetFile = null; + + ChunkCallback() { + this.targetFile = tempFileManager.createTempFile(); + try { + this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + + @Override + public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { + while (buf.hasRemaining()) { + channel.write(buf); + } + } + + @Override + public void onComplete(StreamChunkId streamId) throws IOException { + channel.close(); + ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, + targetFile.length()); + listener.onBlockFetchSuccess(blockIds[streamId.chunkIndex], buffer); + if (!tempFileManager.registerTempFileToClean(targetFile)) { + targetFile.delete(); + } + } + + @Override + public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { + channel.close(); + // On receipt of a failure, fail every block from chunkIndex onwards. + String[] remainingBlockIds = Arrays.copyOfRange(blockIds, streamId.chunkIndex, blockIds.length); + failRemainingBlocks(remainingBlockIds, cause); + targetFile.delete(); + } + @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { // On receipt of a chunk, pass it upwards as a block. @@ -125,7 +170,7 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. 
for (int i = 0; i < streamHandle.numChunks; i++) { - if (tempFileManager != null) { + if (useStreamRequestMessage) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), new DownloadCallback(i)); } else { @@ -157,7 +202,7 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { } } - private class DownloadCallback implements StreamCallback { + private class DownloadCallback implements StreamCallback { private WritableByteChannel channel = null; private File targetFile = null; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 18b04fedcac5b..eeb5c85b967ba 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -37,24 +37,24 @@ public void init(String appId) { } * Note that this API takes a sequence so the implementation can batch requests, and does not * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. - * * @param host the host of the remote node. * @param port the port of the remote node. * @param execId the executor id. * @param blockIds block ids to fetch. * @param listener the listener to receive block fetching status. * @param tempFileManager TempFileManager to create and clean temp files. - * If it's not null, the remote blocks will be streamed - * into temp shuffle files to reduce the memory usage, otherwise, - * they will be kept in memory. +* If it's not null, the remote blocks will be streamed +* into temp shuffle files to reduce the memory usage, otherwise, + * @param useStreamRequestMessage flags whether to fetch to disk as the request is too large */ public abstract void fetchBlocks( - String host, - int port, - String execId, - String[] blockIds, - BlockFetchingListener listener, - TempFileManager tempFileManager); + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener, + TempFileManager tempFileManager, + boolean useStreamRequestMessage); /** * Get the shuffle MetricsSet from ShuffleClient, this will be used in MetricsSystem to diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 02e6eb3a4467e..69861acc59edd 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -24,6 +24,8 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.StreamChunkId; import org.junit.After; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -35,10 +37,6 @@ import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; import 
org.apache.spark.network.server.OneForOneStreamManager;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
@@ -231,7 +229,22 @@ public void onBlockFetchFailure(String blockId, Throwable t) {
         blockServer.getPort());
 
       CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
-      ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+      ChunkReceivedWithStreamCallback callback = new ChunkReceivedWithStreamCallback() {
+        @Override
+        public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException {
+
+        }
+
+        @Override
+        public void onComplete(StreamChunkId streamId) throws IOException {
+
+        }
+
+        @Override
+        public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException {
+
+        }
+
         @Override
         public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
           chunkReceivedLatch.countDown();
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index a6a1b8d0ac3f1..e3a85ea67741a 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) {
           }
         }
       }
-    }, null);
+    }, null, false);
 
     if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
       fail("Timeout getting response from the server");
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index a54b091a64d50..ad7a7d8af1a74 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -425,12 +425,18 @@ package object config {
     .doc("Remote block will be fetched to disk when size of the block is above this threshold " +
       "in bytes. This is to avoid a giant request takes too much memory. We can enable this " +
       "config by setting a specific value(e.g. 200m). Note this configuration will affect " +
-      "both shuffle fetch and block manager remote block fetch. For users who enabled " +
-      "external shuffle service, this feature can only be worked when external shuffle" +
-      "service is newer than Spark 2.2.")
+      "both shuffle fetch and block manager remote block fetch.")
     .bytesConf(ByteUnit.BYTE)
     .createWithDefault(Long.MaxValue)
 
+  private[spark] val STREAM_REQUEST_MESSAGE_ENABLED =
+    ConfigBuilder("spark.streamRequestMessageEnabled")
+      .doc("Remote blocks will be requested to be fetched to disk using stream request " +
+        "messages. For users who enabled the external shuffle service, this feature only " +
+        "works when the external shuffle service is newer than Spark 2.2.")
+      .booleanConf
+      .createWithDefault(false)
+
   private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES =
     ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses")
       .doc("Enable tracking of updatedBlockStatuses in the TaskMetrics. Off by default since " +
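A minimal, hypothetical usage sketch (not part of this patch) of how the new flag could be combined with the existing fetch-to-memory threshold; the 200m value is only the example already given in the documentation above:

    import org.apache.spark.SparkConf

    // Blocks above ~200 MB are requested with stream request messages and written to
    // temp files on the fetching executor instead of being buffered in memory.
    val conf = new SparkConf()
      .set("spark.maxRemoteBlockSizeFetchToMem", "200m")
      .set("spark.streamRequestMessageEnabled", "true")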
Off by default since " + diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 1d8a266d0079c..3800c80c56a81 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -68,7 +68,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempFileManager: TempFileManager): Unit + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. @@ -92,7 +93,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockId: String, - tempFileManager: TempFileManager): ManagedBuffer = { + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): ManagedBuffer = { // A monitor for the thread to wait on. val result = Promise[ManagedBuffer]() fetchBlocks(host, port, execId, Array(blockId), @@ -111,7 +113,9 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo result.success(new NioManagedBuffer(ret)) } } - }, tempFileManager) + }, + tempFileManager, + useStreamRequestMessage) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b7d8c35032763..5f0e475a12f90 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -100,19 +100,19 @@ private[spark] class NettyBlockTransferService( } override def fetchBlocks( - host: String, - port: Int, + host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempFileManager: TempFileManager): Unit = { + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, - transportConf, tempFileManager).start() + transportConf, tempFileManager, useStreamRequestMessage).start() } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index a2936d6ad539c..0d8cf2f135d78 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -407,7 +407,7 @@ private[netty] class NettyRpcEnv( private class FileDownloadCallback( sink: WritableByteChannel, source: FileDownloadChannel, - client: TransportClient) extends StreamCallback { + client: TransportClient) extends StreamCallback[String] { override def onData(streamId: String, buf: ByteBuffer): Unit = { while (buf.remaining() > 0) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4103dfb10175e..44884395984a7 
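The new spark.streamRequestMessageEnabled flag is designed to be used together with the existing spark.maxRemoteBlockSizeFetchToMem threshold. A minimal sketch of enabling both from application code; the class name and the 200m value are illustrative assumptions, not part of this patch:

import org.apache.spark.SparkConf;

public class StreamRequestConfExample {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
        // blocks above this size are considered too large to buffer in memory (example value)
        .set("spark.maxRemoteBlockSizeFetchToMem", "200m")
        // opt in to the stream request message so such blocks are streamed to disk while fetched
        .set("spark.streamRequestMessageEnabled", "true");
  }
}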
100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -52,6 +52,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.STREAM_REQUEST_MESSAGE_ENABLED), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0276a4dc4224..197154814e169 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -212,6 +212,7 @@ private[spark] class BlockManager( private[storage] val remoteBlockTempFileManager = new BlockManager.RemoteBlockTempFileManager(this) private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + private val streamRequestMessageEnabled = conf.get(config.STREAM_REQUEST_MESSAGE_ENABLED) /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as @@ -671,16 +672,7 @@ private[spark] class BlockManager( b.status.diskSize.max(b.status.memSize) }.getOrElse(0L) val blockLocations = locationsAndStatus.map(_.locations).getOrElse(Seq.empty) - - // If the block size is above the threshold, we should pass our FileManger to - // BlockTransferService, which will leverage it to spill the block; if not, then passed-in - // null value means the block will be persisted in memory. - val tempFileManager = if (blockSize > maxRemoteBlockToMem) { - remoteBlockTempFileManager - } else { - null - } - + val useStreamRequestMessage = streamRequestMessageEnabled && blockSize > maxRemoteBlockToMem val locations = sortLocations(blockLocations) val maxFetchFailures = locations.size var locationIterator = locations.iterator @@ -689,7 +681,12 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer() + loc.host, + loc.port, + loc.executorId, + blockId.toString, + remoteBlockTempFileManager, + useStreamRequestMessage).nioByteBuffer() } catch { case NonFatal(e) => runningFailureCount += 1 diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b31862323a895..b7203e01b013d 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -69,6 +69,7 @@ final class ShuffleBlockFetcherIterator( maxBytesInFlight: Long, maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, + streamRequestMessageEnabled: Boolean, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging { @@ -248,13 +249,14 @@ final class ShuffleBlockFetcherIterator( // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is // already encrypted and compressed over the wire(w.r.t. 
the related configs), we can just fetch // the data and write it to file directly. - if (req.size > maxReqSizeShuffleToMem) { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, this) - } else { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, null) - } + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + this, + req.size > maxReqSizeShuffleToMem && streamRequestMessageEnabled) } private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 28ea0c6f0bdba..0af967f39612e 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -171,7 +171,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val serializerManager = SparkEnv.get.serializerManager blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, - blockId.toString, null) + blockId.toString, null, false) val deserialized = serializerManager.dataDeserializeStream(blockId, new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 21138bd4a16ba..b574e07cc7b50 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -165,7 +165,9 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { promise.success(data.retain()) } - }, null) + }, + null, + false) ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b19d8ebf72c61..ec6a5cbcb280d 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1434,7 +1434,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempFileManager: TempFileManager): Unit = { + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } @@ -1461,13 +1462,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockId: String, - tempFileManager: TempFileManager): ManagedBuffer = { + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): ManagedBuffer = { numCalls += 1 this.tempFileManager = tempFileManager if (numCalls <= maxFailures) { throw new RuntimeException("Failing block fetch in the mock block transfer service") } - super.fetchBlockSync(host, port, 
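BlockManager and ShuffleBlockFetcherIterator both derive the new boolean the same way before calling the seven-argument fetchBlocks. A hedged fragment of that caller-side shape, with identifiers taken from the patch and the surrounding values (block size, limit, client, listener, temp file manager) assumed to be in scope:

// stream to disk only when the feature is enabled and the block exceeds the in-memory limit
boolean useStreamRequestMessage = streamRequestMessageEnabled && blockSize > maxRemoteBlockToMem;
shuffleClient.fetchBlocks(host, port, execId, blockIds, listener, tempFileManager,
    useStreamRequestMessage);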
execId, blockId, tempFileManager) + super.fetchBlockSync(host, port, execId, blockId, tempFileManager, useStreamRequestMessage) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index a2997dbd1b1ac..81b245a3f5bcf 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -46,7 +46,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. */ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] @@ -111,6 +111,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -140,7 +141,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), false) } test("release current unexhausted buffer in case the task completes early") { @@ -159,7 +160,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -189,6 +190,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -227,7 +229,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -257,6 +259,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -297,7 +300,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), 
any(), any(), any(), any(), any(), false)) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -327,6 +330,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -337,7 +341,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -391,6 +395,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 2048, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) // Blocks should be returned without exceptions. @@ -415,7 +420,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -445,6 +450,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, false) @@ -479,7 +485,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) var tempFileManager: TempFileManager = null - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -505,6 +511,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxBytesInFlight = Int.MaxValue, maxReqsInFlight = Int.MaxValue, maxBlocksInFlightPerAddress = Int.MaxValue, + streamRequestMessageEnabled = false, maxReqSizeShuffleToMem = 200, detectCorrupt = true) } @@ -551,6 +558,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) From d2753a66da0bd19b8aeb0c1bc232f75afb363974 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 29 May 2018 18:43:42 +0200 Subject: [PATCH 2/8] introduce factory --- .../ChunkReceivedWithStreamCallback.java | 24 ----- .../spark/network/client/TransportClient.java | 6 +- .../client/TransportResponseHandler.java | 28 +++--- .../network/ChunkFetchIntegrationSuite.java | 22 ++--- .../RequestTimeoutIntegrationSuite.java | 24 ++--- .../TransportResponseHandlerSuite.java | 21 +++-- .../spark/network/sasl/SparkSaslSuite.java | 7 +- .../shuffle/OneForOneBlockFetcher.java | 89 ++++++++++--------- .../network/sasl/SaslIntegrationSuite.java | 20 +---- .../shuffle/OneForOneBlockFetcherSuite.java | 2 +- 
10 files changed, 104 insertions(+), 139 deletions(-) delete mode 100644 common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedWithStreamCallback.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedWithStreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedWithStreamCallback.java deleted file mode 100644 index ad36fa7b5353c..0000000000000 --- a/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedWithStreamCallback.java +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.client; - -import org.apache.spark.network.protocol.StreamChunkId; - -public interface ChunkReceivedWithStreamCallback extends - ChunkReceivedCallback, StreamCallback { -} diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 1fb037bc12f74..c8a8c83f93b9c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -24,6 +24,7 @@ import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; @@ -132,14 +133,15 @@ public void setClientId(String id) { public void fetchChunk( long streamId, int chunkIndex, - ChunkReceivedWithStreamCallback callback) { + ChunkReceivedCallback callback, + Supplier> streamCallbackFactory) { long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); } StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); - handler.addFetchRequest(streamChunkId, callback); + handler.addFetchRequest(streamChunkId, callback, streamCallbackFactory); channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> { if (future.isSuccess()) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index c00a82e76547d..3630a0ad6c395 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap; import 
java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; @@ -54,11 +55,13 @@ public class TransportResponseHandler extends MessageHandler { private final Channel channel; - private final Map outstandingFetches; + private final Map outstandingFetches; + private final Map>> + outstandingStreamFetches; private final Map outstandingRpcs; - private final Queue> streamCallbacks; + private final Queue>> streamCallbacks; private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ @@ -67,6 +70,7 @@ public class TransportResponseHandler extends MessageHandler { public TransportResponseHandler(Channel channel) { this.channel = channel; this.outstandingFetches = new ConcurrentHashMap<>(); + this.outstandingStreamFetches = new ConcurrentHashMap<>(); this.outstandingRpcs = new ConcurrentHashMap<>(); this.streamCallbacks = new ConcurrentLinkedQueue<>(); this.timeOfLastRequestNs = new AtomicLong(0); @@ -74,9 +78,11 @@ public TransportResponseHandler(Channel channel) { public void addFetchRequest( StreamChunkId streamChunkId, - ChunkReceivedWithStreamCallback callback) { + ChunkReceivedCallback callback, + Supplier> streamCallbackFactory) { updateTimeOfLastRequest(); outstandingFetches.put(streamChunkId, callback); + outstandingStreamFetches.put(streamChunkId, streamCallbackFactory); } public void removeFetchRequest(StreamChunkId streamChunkId) { @@ -107,7 +113,7 @@ public void deactivateStream() { * uncaught exception or pre-mature connection termination. */ private void failOutstandingRequests(Throwable cause) { - for (Map.Entry entry : outstandingFetches.entrySet()) { + for (Map.Entry entry : outstandingFetches.entrySet()) { try { entry.getValue().onFailure(entry.getKey().chunkIndex, cause); } catch (Exception e) { @@ -121,7 +127,7 @@ private void failOutstandingRequests(Throwable cause) { logger.warn("RpcResponseCallback.onFailure throws exception", e); } } - for (Pair entry : streamCallbacks) { + for (Pair> entry : streamCallbacks) { try { entry.getValue().onFailure(entry.getKey(), cause); } catch (Exception e) { @@ -163,7 +169,7 @@ public void exceptionCaught(Throwable cause) { public void handle(ResponseMessage message) throws Exception { if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; - ChunkReceivedWithStreamCallback listener = outstandingFetches.get(resp.streamChunkId); + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", resp.streamChunkId, getRemoteAddress(channel)); @@ -172,9 +178,11 @@ public void handle(ResponseMessage message) throws Exception { outstandingFetches.remove(resp.streamChunkId); listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); } else { + StreamCallback streamCallback = + outstandingStreamFetches.get(resp.streamChunkId).get(); if (resp.remainingFrameSize > 0) { StreamInterceptor interceptor = new StreamInterceptor(this, - resp.streamChunkId, resp.remainingFrameSize, listener); + resp.streamChunkId, resp.remainingFrameSize, streamCallback); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); @@ -186,7 +194,7 @@ public void handle(ResponseMessage message) throws Exception { } } else { 
try { - listener.onComplete(resp.streamChunkId); + streamCallback.onComplete(resp.streamChunkId); } catch (Exception e) { logger.warn("Error in stream handler onComplete().", e); } @@ -233,7 +241,7 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof StreamResponse) { StreamResponse resp = (StreamResponse) message; - Pair entry = streamCallbacks.poll(); + Pair> entry = streamCallbacks.poll(); if (entry != null) { StreamCallback callback = entry.getValue(); if (resp.byteCount > 0) { @@ -260,7 +268,7 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof StreamFailure) { StreamFailure resp = (StreamFailure) message; - Pair entry = streamCallbacks.poll(); + Pair> entry = streamCallbacks.poll(); if (entry != null) { StreamCallback callback = entry.getValue(); try { diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index bb281142e7cf4..62c3b7e4c61c7 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -30,6 +30,7 @@ import java.util.Set; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; import com.google.common.collect.Sets; import com.google.common.io.Closeables; @@ -40,6 +41,7 @@ import org.junit.Test; import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; @@ -150,22 +152,7 @@ private FetchResult fetchChunks(List chunkIndices) throws Exception { res.failedChunks = Collections.synchronizedSet(new HashSet()); res.buffers = Collections.synchronizedList(new LinkedList()); - ChunkReceivedWithStreamCallback callback = new ChunkReceivedWithStreamCallback() { - @Override - public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { - - } - - @Override - public void onComplete(StreamChunkId streamId) throws IOException { - - } - - @Override - public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { - - } - + ChunkReceivedCallback callback = new ChunkReceivedCallback() { @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { buffer.retain(); @@ -181,8 +168,9 @@ public void onFailure(int chunkIndex, Throwable e) { } }; + Supplier> streamCallbackFactory = mock(Supplier.class); for (int chunkIndex : chunkIndices) { - client.fetchChunk(STREAM_ID, chunkIndex, callback); + client.fetchChunk(STREAM_ID, chunkIndex, callback, streamCallbackFactory); } if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 25d5edec53543..1d6dcf4a11280 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -29,6 +29,7 @@ import org.apache.spark.network.util.TransportConf; import org.junit.*; import static org.junit.Assert.*; +import static 
org.mockito.Mockito.mock; import java.io.IOException; import java.nio.ByteBuffer; @@ -36,6 +37,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; /** * Suite which ensures that requests that go without a response for the network timeout period are @@ -209,12 +211,13 @@ public StreamManager getStreamManager() { // Send one request, which will eventually fail. TestCallback callback0 = new TestCallback(); - client.fetchChunk(0, 0, callback0); + Supplier> streamCallbackFactory = mock(Supplier.class); + client.fetchChunk(0, 0, callback0, streamCallbackFactory); Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); // Send a second request before the first has failed. TestCallback callback1 = new TestCallback(); - client.fetchChunk(0, 1, callback1); + client.fetchChunk(0, 1, callback1, streamCallbackFactory); Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); // not complete yet, but should complete soon @@ -233,7 +236,7 @@ public StreamManager getStreamManager() { * Callback which sets 'success' or 'failure' on completion. * Additionally notifies all waiters on this callback when invoked. */ - static class TestCallback implements RpcResponseCallback, ChunkReceivedWithStreamCallback { + static class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { int successLength = -1; Throwable failure; @@ -267,20 +270,5 @@ public void onFailure(int chunkIndex, Throwable e) { failure = e; latch.countDown(); } - - @Override - public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { - - } - - @Override - public void onComplete(StreamChunkId streamId) throws IOException { - - } - - @Override - public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { - - } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 9856ae42b6798..f5935e4c404c1 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.function.Supplier; import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; @@ -44,8 +45,9 @@ public void handleSuccessfulFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedWithStreamCallback callback = mock(ChunkReceivedWithStreamCallback.class); - handler.addFetchRequest(streamChunkId, callback); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + Supplier> streamCallbackFactory = mock(Supplier.class); + handler.addFetchRequest(streamChunkId, callback, streamCallbackFactory); assertEquals(1, handler.numOutstandingRequests()); handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); @@ -57,8 +59,9 @@ public void handleSuccessfulFetch() throws Exception { public void handleFailedFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedWithStreamCallback callback = 
mock(ChunkReceivedWithStreamCallback.class); - handler.addFetchRequest(streamChunkId, callback); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + Supplier> streamCallbackFactory = mock(Supplier.class); + handler.addFetchRequest(streamChunkId, callback, streamCallbackFactory); assertEquals(1, handler.numOutstandingRequests()); handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); @@ -69,10 +72,12 @@ public void handleFailedFetch() throws Exception { @Test public void clearAllOutstandingRequests() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - ChunkReceivedWithStreamCallback callback = mock(ChunkReceivedWithStreamCallback.class); - handler.addFetchRequest(new StreamChunkId(1, 0), callback); - handler.addFetchRequest(new StreamChunkId(1, 1), callback); - handler.addFetchRequest(new StreamChunkId(1, 2), callback); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + Supplier> streamCallback = mock(Supplier.class); + + handler.addFetchRequest(new StreamChunkId(1, 0), callback, streamCallback); + handler.addFetchRequest(new StreamChunkId(1, 1), callback, streamCallback); + handler.addFetchRequest(new StreamChunkId(1, 2), callback, streamCallback); assertEquals(3, handler.numOutstandingRequests()); handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 7e0beb9d7c87a..397e4f63f608f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -33,6 +33,7 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; import javax.security.sasl.SaslException; import com.google.common.collect.ImmutableMap; @@ -45,6 +46,7 @@ import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.StreamChunkId; import org.junit.Test; import org.apache.spark.network.TestUtils; @@ -270,7 +272,7 @@ public void testFileRegionEncryption() throws Exception { CountDownLatch lock = new CountDownLatch(1); - ChunkReceivedWithStreamCallback callback = mock(ChunkReceivedWithStreamCallback.class); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); doAnswer(invocation -> { response.set((ManagedBuffer) invocation.getArguments()[1]); response.get().retain(); @@ -278,7 +280,8 @@ public void testFileRegionEncryption() throws Exception { return null; }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); - ctx.client.fetchChunk(0, 0, callback); + Supplier> streamCallbackFactory = mock(Supplier.class); + ctx.client.fetchChunk(0, 0, callback, streamCallbackFactory); lock.await(10, TimeUnit.SECONDS); verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 80a8d8f83d3cb..5d1ec48ddaf0c 100644 --- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -24,6 +24,7 @@ import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; import java.util.Arrays; +import java.util.function.Supplier; import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; @@ -57,7 +58,8 @@ public class OneForOneBlockFetcher { private final OpenBlocks openMessage; private final String[] blockIds; private final BlockFetchingListener listener; - private final ChunkReceivedWithStreamCallback chunkCallback; + private final ChunkReceivedCallback chunkCallback; + private final Supplier> fetchChunkDownloadCallbackFactory; private final TransportConf transportConf; private final TempFileManager tempFileManager; private final boolean useStreamRequestMessage; @@ -91,50 +93,13 @@ public OneForOneBlockFetcher( this.transportConf = transportConf; // TODO extend tests to pass a valid tempFileManager and use: // this.tempFileManager = Preconditions.checkNotNull(tempFileManager); + fetchChunkDownloadCallbackFactory = () -> new FetchChunkDownloadCallback(); this.tempFileManager = tempFileManager; this.useStreamRequestMessage = useStreamRequestMessage; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ - private class ChunkCallback implements ChunkReceivedWithStreamCallback { - private WritableByteChannel channel = null; - private File targetFile = null; - - ChunkCallback() { - this.targetFile = tempFileManager.createTempFile(); - try { - this.channel = Channels.newChannel(new FileOutputStream(targetFile)); - } catch (IOException e) { - throw new IllegalStateException(e); - } - } - - @Override - public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { - while (buf.hasRemaining()) { - channel.write(buf); - } - } - - @Override - public void onComplete(StreamChunkId streamId) throws IOException { - channel.close(); - ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, - targetFile.length()); - listener.onBlockFetchSuccess(blockIds[streamId.chunkIndex], buffer); - if (!tempFileManager.registerTempFileToClean(targetFile)) { - targetFile.delete(); - } - } - - @Override - public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { - channel.close(); - // On receipt of a failure, fail every block from chunkIndex onwards. 
- String[] remainingBlockIds = Arrays.copyOfRange(blockIds, streamId.chunkIndex, blockIds.length); - failRemainingBlocks(remainingBlockIds, cause); - targetFile.delete(); - } + private class ChunkCallback implements ChunkReceivedCallback { @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { @@ -174,7 +139,8 @@ public void onSuccess(ByteBuffer response) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), new DownloadCallback(i)); } else { - client.fetchChunk(streamHandle.streamId, i, chunkCallback); + client.fetchChunk(streamHandle.streamId, i, + chunkCallback, fetchChunkDownloadCallbackFactory); } } } catch (Exception e) { @@ -241,4 +207,45 @@ public void onFailure(String streamId, Throwable cause) throws IOException { targetFile.delete(); } } + + private class FetchChunkDownloadCallback implements StreamCallback { + private WritableByteChannel channel = null; + private File targetFile = null; + + FetchChunkDownloadCallback() { + this.targetFile = tempFileManager.createTempFile(); + try { + this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + + @Override + public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { + while (buf.hasRemaining()) { + channel.write(buf); + } + } + + @Override + public void onComplete(StreamChunkId streamId) throws IOException { + channel.close(); + ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, + targetFile.length()); + listener.onBlockFetchSuccess(blockIds[streamId.chunkIndex], buffer); + if (!tempFileManager.registerTempFileToClean(targetFile)) { + targetFile.delete(); + } + } + + @Override + public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { + channel.close(); + // On receipt of a failure, fail every block from chunkIndex onwards. 
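With the factory in place, the download callback (and its temp file) is only created if the response handler actually has to stream the body; the in-memory path still goes through chunkCallback. A rough sketch of the call site, reusing the names introduced in this patch:

// the Supplier is only invoked when the ChunkFetchSuccess body is too large to keep in memory
Supplier<StreamCallback<StreamChunkId>> fetchChunkDownloadCallbackFactory =
    () -> new FetchChunkDownloadCallback();
client.fetchChunk(streamHandle.streamId, chunkIndex, chunkCallback,
    fetchChunkDownloadCallbackFactory);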
+ String[] remainingBlockIds = Arrays.copyOfRange(blockIds, streamId.chunkIndex, blockIds.length); + failRemainingBlocks(remainingBlockIds, cause); + targetFile.delete(); + } + } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 69861acc59edd..416af7150b67f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; import org.apache.spark.network.client.*; import org.apache.spark.network.protocol.StreamChunkId; @@ -229,21 +230,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) { blockServer.getPort()); CountDownLatch chunkReceivedLatch = new CountDownLatch(1); - ChunkReceivedWithStreamCallback callback = new ChunkReceivedWithStreamCallback() { - @Override - public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { - - } - - @Override - public void onComplete(StreamChunkId streamId) throws IOException { - - } - - @Override - public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { - - } + ChunkReceivedCallback callback = new ChunkReceivedCallback() { @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { @@ -257,7 +244,8 @@ public void onFailure(int chunkIndex, Throwable t) { }; exception.set(null); - client2.fetchChunk(streamId, 0, callback); + Supplier> streamCallbackFactory = mock(Supplier.class); + client2.fetchChunk(streamId, 0, callback, streamCallbackFactory); chunkReceivedLatch.await(); checkSecurityException(exception.get()); } finally { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index dc947a619bf02..ab8a37cdb3154 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -165,7 +165,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap Date: Wed, 30 May 2018 15:07:54 +0200 Subject: [PATCH 3/8] Extend ProtocolSuite --- .../apache/spark/network/ProtocolSuite.java | 44 ++++++++++++++++--- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index e677a26242d4e..8c8c5ef2c80ae 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -28,6 +28,8 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchRequest; @@ -47,20 +49,25 @@ import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { - private void testServerToClient(Message msg) { + private Message 
decodedClientMessageFromChannel(Message msg, long maxRemoteBlockSizeFetchToMem) { EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), - MessageEncoder.INSTANCE); + MessageEncoder.INSTANCE); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(Integer.MAX_VALUE), MessageDecoder.INSTANCE); + NettyUtils.createFrameDecoder(maxRemoteBlockSizeFetchToMem), MessageDecoder.INSTANCE); while (!serverChannel.outboundMessages().isEmpty()) { clientChannel.writeOneInbound(serverChannel.readOutbound()); } assertEquals(1, clientChannel.inboundMessages().size()); - assertEquals(msg, clientChannel.readInbound()); + return clientChannel.readInbound(); + } + + private void testServerToClient(Message msg) { + Message clientMessage = decodedClientMessageFromChannel(msg, Integer.MAX_VALUE); + assertEquals(msg, clientMessage); } private void testClientToServer(Message msg) { @@ -79,6 +86,31 @@ private void testClientToServer(Message msg) { assertEquals(msg, serverChannel.readInbound()); } + private ChunkFetchSuccess chunkFetchSuccessWith100Bytes() { + return new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(100)); + } + + private void testChunkFetchSuccess() { + // test without fetch to disk, maxRemoteBlockSizeFetchToMem is Integer.MAX_VALUE + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); + + // test with fetch to disk + // under (and at) the fetch to mem limit + Message chunkFetchSuccessUnderFetchToMemBlockSize = + decodedClientMessageFromChannel(chunkFetchSuccessWith100Bytes(), 101); + assertEquals(chunkFetchSuccessWith100Bytes(), chunkFetchSuccessUnderFetchToMemBlockSize); + chunkFetchSuccessUnderFetchToMemBlockSize = + decodedClientMessageFromChannel(chunkFetchSuccessWith100Bytes(), 100); + assertEquals(chunkFetchSuccessWith100Bytes(), chunkFetchSuccessUnderFetchToMemBlockSize); + + // above the fetch to mem limit + Message chunkFetchSuccessAboveFetchToMemBlockSize = + decodedClientMessageFromChannel(chunkFetchSuccessWith100Bytes(), 99); + assertNull("message body must be not included", chunkFetchSuccessAboveFetchToMemBlockSize.body()); + assertFalse("message body must be not included", chunkFetchSuccessAboveFetchToMemBlockSize.isBodyInFrame()); + } + @Test public void requests() { testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); @@ -88,10 +120,10 @@ public void requests() { testClientToServer(new OneWayMessage(new TestManagedBuffer(10))); } + @Test public void responses() { - testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); - testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); + testChunkFetchSuccess(); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0))); From acc1e20fe025a238a4a3b4a86b373ca1219d4a9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 31 May 2018 18:14:32 +0200 Subject: [PATCH 4/8] add test for fetch to disk --- .../network/ChunkFetchIntegrationSuite.java | 132 ++++++++++++++---- 1 file changed, 106 insertions(+), 26 deletions(-) diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 62c3b7e4c61c7..2d770db4e34c0 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -18,19 +18,15 @@ package org.apache.spark.network; import java.io.File; +import java.io.FileOutputStream; import java.io.IOException; import java.io.RandomAccessFile; import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Random; -import java.util.Set; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; +import java.util.*; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; import com.google.common.collect.Sets; import com.google.common.io.Closeables; @@ -56,6 +52,14 @@ public class ChunkFetchIntegrationSuite { static final long STREAM_ID = 1; static final int BUFFER_CHUNK_INDEX = 0; static final int FILE_CHUNK_INDEX = 1; + static final int BUFFER_FETCH_TO_DISK_CHUNK_INDEX = 2; + static final int MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = 100000; + + static final TransportConf transportConf = + new TransportConf("shuffle", + new MapConfigProvider( + Collections.singletonMap( + "spark.maxRemoteBlockSizeFetchToMem", Integer.toString(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)))); static TransportServer server; static TransportClientFactory clientFactory; @@ -63,17 +67,68 @@ public class ChunkFetchIntegrationSuite { static File testFile; static ManagedBuffer bufferChunk; + static ManagedBuffer bufferToDiskChunk; + static ManagedBuffer fileChunk; + private class FetchChunkDownloadTestCallback implements StreamCallback { + private WritableByteChannel channel = null; + private File targetFile = null; + private FetchResult fetchResult; + private Semaphore semaphore; + + FetchChunkDownloadTestCallback(FetchResult fetchResult, Semaphore semaphore) { + this.fetchResult = fetchResult; + this.semaphore = semaphore; + try { + this.targetFile = File.createTempFile("shuffle-test-file-download-", "txt"); + this.targetFile.deleteOnExit(); + this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + + @Override + public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { + while (buf.hasRemaining()) { + channel.write(buf); + } + } + + @Override + public void onComplete(StreamChunkId streamId) throws IOException { + channel.close(); + ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, + targetFile.length()); + fetchResult.successChunks.add(streamId.chunkIndex); + fetchResult.buffers.add(buffer); + semaphore.release(); + } + + @Override + public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { + channel.close(); + this.fetchResult.failedChunks.add(streamId.chunkIndex); + semaphore.release(); + } + } + @BeforeClass public static void setUp() throws Exception { - int bufSize = 100000; - final ByteBuffer buf = ByteBuffer.allocate(bufSize); - for (int i = 0; i < bufSize; i ++) { - buf.put((byte) i); + int bufSize = MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM + 100; + final ByteBuffer hugeBuf = ByteBuffer.allocate(bufSize); + for 
(int i = 0; i < MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM; i ++) { + hugeBuf.put((byte) i); + } + ByteBuffer smallBuff = hugeBuf.duplicate(); + smallBuff.flip(); + bufferChunk = new NioManagedBuffer(smallBuff); + for (int i = MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM; i < bufSize; i ++) { + hugeBuf.put((byte) i); } - buf.flip(); - bufferChunk = new NioManagedBuffer(buf); + hugeBuf.flip(); + bufferToDiskChunk = new NioManagedBuffer(hugeBuf); testFile = File.createTempFile("shuffle-test-file", "txt"); testFile.deleteOnExit(); @@ -88,19 +143,21 @@ public static void setUp() throws Exception { Closeables.close(fp, shouldSuppressIOException); } - final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); - fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); + fileChunk = new FileSegmentManagedBuffer(transportConf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { assertEquals(STREAM_ID, streamId); - if (chunkIndex == BUFFER_CHUNK_INDEX) { - return new NioManagedBuffer(buf); - } else if (chunkIndex == FILE_CHUNK_INDEX) { - return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); - } else { - throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); + switch (chunkIndex) { + case BUFFER_CHUNK_INDEX: + return new NioManagedBuffer(smallBuff); + case FILE_CHUNK_INDEX: + return new FileSegmentManagedBuffer(transportConf, testFile, 10, testFile.length() - 25); + case BUFFER_FETCH_TO_DISK_CHUNK_INDEX: + return new NioManagedBuffer(hugeBuf); + default: + throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); } } }; @@ -118,7 +175,7 @@ public StreamManager getStreamManager() { return streamManager; } }; - TransportContext context = new TransportContext(conf, handler); + TransportContext context = new TransportContext(transportConf, handler); server = context.createServer(); clientFactory = context.createClientFactory(); } @@ -126,6 +183,7 @@ public StreamManager getStreamManager() { @AfterClass public static void tearDown() { bufferChunk.release(); + bufferToDiskChunk.release(); server.close(); clientFactory.close(); testFile.delete(); @@ -168,11 +226,11 @@ public void onFailure(int chunkIndex, Throwable e) { } }; - Supplier> streamCallbackFactory = mock(Supplier.class); for (int chunkIndex : chunkIndices) { - client.fetchChunk(STREAM_ID, chunkIndex, callback, streamCallbackFactory); + client.fetchChunk(STREAM_ID, chunkIndex, callback, + () -> new FetchChunkDownloadTestCallback(res, sem)); } - if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { + if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.MINUTES)) { fail("Timeout getting response from the server"); } client.close(); @@ -184,6 +242,7 @@ public void fetchBufferChunk() throws Exception { FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX)); assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); + assertNumFileSegments(0, res.buffers); assertBufferListsEqual(Arrays.asList(bufferChunk), res.buffers); res.releaseBuffers(); } @@ -193,6 +252,7 @@ public void fetchFileChunk() throws Exception { FetchResult res = fetchChunks(Arrays.asList(FILE_CHUNK_INDEX)); assertEquals(Sets.newHashSet(FILE_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); + assertNumFileSegments(0, res.buffers); assertBufferListsEqual(Arrays.asList(fileChunk), res.buffers); 
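The fetch-to-disk test sits exactly on the boundary the decoder enforces. A sketch of that check as this patch implements it in TransportFrameDecoder (a fragment, using the patch's field names); the comparison is strictly greater-than, which is why a body exactly at the limit, like the small buffer here or the 100-byte message in ProtocolSuite, is still decoded in memory:

// only a ChunkFetchSuccess whose body would exceed the limit keeps its body in the stream,
// to be consumed later by the download callback
long bodySize = nextFrameSize - ChunkFetchSuccess.ENCODED_LENGTH;
boolean streamBody = msgType == Message.Type.ChunkFetchSuccess
    && bodySize > maxRemoteBlockSizeFetchToMem;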
res.releaseBuffers(); } @@ -210,10 +270,25 @@ public void fetchBothChunks() throws Exception { FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); + assertNumFileSegments(0, res.buffers); assertBufferListsEqual(Arrays.asList(bufferChunk, fileChunk), res.buffers); res.releaseBuffers(); } + + @Test + public void fetchSomeChunksToDisk() throws Exception { + FetchResult res = fetchChunks( + Arrays.asList(BUFFER_CHUNK_INDEX, BUFFER_FETCH_TO_DISK_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertEquals( + Sets.newHashSet(BUFFER_CHUNK_INDEX, BUFFER_FETCH_TO_DISK_CHUNK_INDEX, FILE_CHUNK_INDEX), + res.successChunks); + assertTrue(res.failedChunks.isEmpty()); + assertNumFileSegments(1, res.buffers); + assertBufferListsEqual(Arrays.asList(bufferChunk, bufferToDiskChunk, fileChunk), res.buffers); + res.releaseBuffers(); + } + @Test public void fetchChunkAndNonExistent() throws Exception { FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, 12345)); @@ -223,6 +298,11 @@ public void fetchChunkAndNonExistent() throws Exception { res.releaseBuffers(); } + private void assertNumFileSegments(int expected, List buffers) { + assertEquals(expected, + buffers.stream().filter(b -> b instanceof FileSegmentManagedBuffer).count()); + } + private static void assertBufferListsEqual(List list0, List list1) throws Exception { assertEquals(list0.size(), list1.size()); From 797f558bc9d52010a6e350cd3fde9702a5873f40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Fri, 1 Jun 2018 16:15:27 +0200 Subject: [PATCH 5/8] tiny fix --- .../apache/spark/network/ChunkFetchIntegrationSuite.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 2d770db4e34c0..1717f74a97a03 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -72,8 +72,8 @@ public class ChunkFetchIntegrationSuite { static ManagedBuffer fileChunk; private class FetchChunkDownloadTestCallback implements StreamCallback { - private WritableByteChannel channel = null; - private File targetFile = null; + private WritableByteChannel channel; + private File targetFile; private FetchResult fetchResult; private Semaphore semaphore; @@ -230,7 +230,7 @@ public void onFailure(int chunkIndex, Throwable e) { client.fetchChunk(STREAM_ID, chunkIndex, callback, () -> new FetchChunkDownloadTestCallback(res, sem)); } - if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.MINUTES)) { + if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); } client.close(); From 76f23cbec96ca746f1cf5fd324c671f918300f28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 5 Jun 2018 16:25:06 +0200 Subject: [PATCH 6/8] Add SASL support for FrameDecoder --- .../spark/network/TransportContext.java | 2 +- .../spark/network/sasl/SaslEncryption.java | 2 +- .../apache/spark/network/util/NettyUtils.java | 6 +- .../network/util/TransportFrameDecoder.java | 75 ++++++++++++------- .../apache/spark/network/ProtocolSuite.java | 6 +- .../util/TransportFrameDecoderSuite.java | 1 
- 6 files changed, 56 insertions(+), 36 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index af68728ce5204..5a0f575b0dac1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -147,7 +147,7 @@ public TransportChannelHandler initializePipeline( channel.pipeline() .addLast("encoder", ENCODER) .addLast(TransportFrameDecoder.HANDLER_NAME, - NettyUtils.createFrameDecoder(conf.maxRemoteBlockSizeFetchToMem())) + NettyUtils.createFrameDecoder(conf.maxRemoteBlockSizeFetchToMem(), false)) .addLast("decoder", DECODER) .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index e29a1e0f2c800..5128f90160a93 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -61,7 +61,7 @@ static void addToChannel( channel.pipeline() .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize)) .addFirst("saslDecryption", new DecryptionHandler(backend)) - .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder(Integer.MAX_VALUE)); + .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder(Integer.MAX_VALUE, true)); } private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 2e34b9d69a522..eaa02e78610e2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -84,8 +84,10 @@ public static Class getServerChannelClass(IOMode mode) * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. * This is used before all decoders. */ - public static TransportFrameDecoder createFrameDecoder(long maxRemoteBlockSizeFetchToMem) { - return new TransportFrameDecoder(maxRemoteBlockSizeFetchToMem); + public static TransportFrameDecoder createFrameDecoder( + long maxRemoteBlockSizeFetchToMem, + boolean isSasl) { + return new TransportFrameDecoder(maxRemoteBlockSizeFetchToMem, isSasl); } /** Returns the remote address on the channel or "<unknown remote>" if none exists. 
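With this change createFrameDecoder takes the fetch-to-memory limit plus an isSasl flag. In SASL mode the decoder only re-frames the encrypted byte stream and fires whole frames at the saslDecryption handler, so no message-type parsing happens there; in normal mode it emits ParsedFrame objects carrying the message type and any remainingFrameSize left to stream, which is what the TransportFrameDecoder hunks below implement. A rough sketch of the two wirings, modeled on the EmbeddedChannel setup used in ProtocolSuite later in this patch; the limit value and class name here are illustrative only.

import io.netty.channel.embedded.EmbeddedChannel;

import org.apache.spark.network.protocol.MessageDecoder;
import org.apache.spark.network.util.NettyUtils;

public class FrameDecoderWiringSketch {
  public static void main(String[] args) {
    // Normal RPC/fetch pipeline: a ChunkFetchSuccess frame larger than the limit is
    // split so its header is decoded eagerly and the remaining body can be streamed.
    long maxRemoteBlockSizeFetchToMem = 1024 * 1024;  // illustrative limit
    EmbeddedChannel messageChannel = new EmbeddedChannel(
        NettyUtils.createFrameDecoder(maxRemoteBlockSizeFetchToMem, false),
        MessageDecoder.INSTANCE);

    // SASL pipeline: raw frames are passed downstream for the decryption handler
    // (not attached in this sketch) to consume; no ParsedFrame is produced.
    EmbeddedChannel saslChannel = new EmbeddedChannel(
        NettyUtils.createFrameDecoder(Integer.MAX_VALUE, true));

    messageChannel.close();
    saslChannel.close();
  }
}

This mirrors the SaslEncryption.addToChannel change above, where the SASL frame decoder is installed with Integer.MAX_VALUE because fetch-to-disk splitting does not apply to the encrypted stream.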
*/ diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 197b14e9a5888..4e6b617e53996 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -59,11 +59,17 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { private long nextFrameSize = UNKNOWN_FRAME_SIZE; private volatile Interceptor interceptor; private Message.Type msgType = null; + private final boolean isSasl; private final long maxRemoteBlockSizeFetchToMem; public TransportFrameDecoder(long maxRemoteBlockSizeFetchToMem) { + this(maxRemoteBlockSizeFetchToMem, false); + } + + public TransportFrameDecoder(long maxRemoteBlockSizeFetchToMem, boolean isSasl) { this.maxRemoteBlockSizeFetchToMem = maxRemoteBlockSizeFetchToMem; + this.isSasl = isSasl; } @Override @@ -87,26 +93,37 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception } totalSize -= read; } else { - // Interceptor is not active, so try to decode one frame. - decodeNextMsgType(); - if (msgType == null) { - break; - } - long remainingFrameSize = 0; - if (msgType == Message.Type.ChunkFetchSuccess && - nextFrameSize - ChunkFetchSuccess.ENCODED_LENGTH > maxRemoteBlockSizeFetchToMem) { + if (isSasl) { + if (!isFrameSizeAvailable()) { + break; + } + ByteBuf frame = decodeNext(); + if (frame == null) { + break; + } + ctx.fireChannelRead(frame); + } else { + // Interceptor is not active, so try to decode one frame. + decodeNextMsgType(); + if (msgType == null) { + break; + } + long remainingFrameSize = 0; + if (msgType == Message.Type.ChunkFetchSuccess && + nextFrameSize - ChunkFetchSuccess.ENCODED_LENGTH > maxRemoteBlockSizeFetchToMem) { remainingFrameSize = nextFrameSize - ChunkFetchSuccess.ENCODED_LENGTH; nextFrameSize = ChunkFetchSuccess.ENCODED_LENGTH; + } + + ByteBuf frame = decodeNext(); + if (frame == null) { + break; + } + ParsedFrame parsedFrame = + new ParsedFrame(msgType, frame, remainingFrameSize); + msgType = null; + ctx.fireChannelRead(parsedFrame); } - - ByteBuf frame = decodeNext(); - if (frame == null) { - break; - } - ParsedFrame parsedFrame = - new ParsedFrame(msgType, frame, remainingFrameSize); - msgType = null; - ctx.fireChannelRead(parsedFrame); } } } @@ -146,18 +163,7 @@ private long decodeFrameSize() { } private void decodeNextMsgType() { - if (msgType != null) { - return; - } - long frameSize = decodeFrameSize(); - if (frameSize == UNKNOWN_FRAME_SIZE) { - return; - } - - Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); - Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); - - if(totalSize < Message.Type.LENGTH) { + if (msgType != null || !isFrameSizeAvailable() || totalSize < Message.Type.LENGTH) { return; } @@ -170,6 +176,17 @@ private void decodeNextMsgType() { } } + private boolean isFrameSizeAvailable() { + long frameSize = decodeFrameSize(); + if (frameSize == UNKNOWN_FRAME_SIZE) { + return false; + } + + Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); + Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); + return true; + } + private ByteBuf decodeNext() { long frameSize = nextFrameSize; if (totalSize < frameSize) { diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 8c8c5ef2c80ae..fd318b632919e 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -55,7 +55,8 @@ private Message decodedClientMessageFromChannel(Message msg, long maxRemoteBlock serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(maxRemoteBlockSizeFetchToMem), MessageDecoder.INSTANCE); + NettyUtils.createFrameDecoder(maxRemoteBlockSizeFetchToMem, false), + MessageDecoder.INSTANCE); while (!serverChannel.outboundMessages().isEmpty()) { clientChannel.writeOneInbound(serverChannel.readOutbound()); @@ -76,7 +77,8 @@ private void testClientToServer(Message msg) { clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(Integer.MAX_VALUE), MessageDecoder.INSTANCE); + NettyUtils.createFrameDecoder(Integer.MAX_VALUE, false), + MessageDecoder.INSTANCE); while (!clientChannel.outboundMessages().isEmpty()) { serverChannel.writeOneInbound(clientChannel.readOutbound()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index f590d67000c04..388bb41a57a5e 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.network.util; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.List; From 365e673868760cd57a789faba255d3dfd4f30f79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 5 Jun 2018 17:22:17 +0200 Subject: [PATCH 7/8] fix --- .../apache/spark/network/client/TransportResponseHandler.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 3630a0ad6c395..41c2c7cc6b271 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -87,6 +87,7 @@ public void addFetchRequest( public void removeFetchRequest(StreamChunkId streamChunkId) { outstandingFetches.remove(streamChunkId); + outstandingStreamFetches.remove(streamChunkId); } public void addRpcRequest(long requestId, RpcResponseCallback callback) { @@ -139,6 +140,7 @@ private void failOutstandingRequests(Throwable cause) { outstandingFetches.clear(); outstandingRpcs.clear(); streamCallbacks.clear(); + outstandingStreamFetches.clear(); } @Override @@ -180,6 +182,8 @@ public void handle(ResponseMessage message) throws Exception { } else { StreamCallback streamCallback = outstandingStreamFetches.get(resp.streamChunkId).get(); + outstandingFetches.remove(resp.streamChunkId); + outstandingStreamFetches.remove(resp.streamChunkId); if (resp.remainingFrameSize > 0) { StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamChunkId, 
resp.remainingFrameSize, streamCallback); From 5899663d75235d00c8526dafa406f0d6be5a0a58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 6 Jun 2018 13:18:17 +0200 Subject: [PATCH 8/8] fix --- .../ShuffleBlockFetcherIteratorSuite.scala | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 81b245a3f5bcf..1e41a35aa4bdd 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -46,7 +46,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. */ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] @@ -141,7 +141,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), false) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -160,7 +160,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -229,7 +229,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -300,7 +300,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -341,7 +341,7 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -420,7 +420,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -484,12 +484,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var tempFileManager: TempFileManager = null - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), false)) + var useStreamRequestMessage: Boolean = false + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - tempFileManager = invocation.getArguments()(5).asInstanceOf[TempFileManager] + useStreamRequestMessage = invocation.getArguments()(6).asInstanceOf[Boolean] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) @@ -511,7 +511,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxBytesInFlight = Int.MaxValue, maxReqsInFlight = Int.MaxValue, maxBlocksInFlightPerAddress = Int.MaxValue, - streamRequestMessageEnabled = false, + streamRequestMessageEnabled = true, maxReqSizeShuffleToMem = 200, detectCorrupt = true) } @@ -521,14 +521,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. - assert(tempFileManager == null) + assert(!useStreamRequestMessage) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. - assert(tempFileManager != null) + assert(useStreamRequestMessage) } test("fail zero-size blocks") {