diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
index ae91bc9cfdd08..5a0f575b0dac1 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -146,7 +146,8 @@ public TransportChannelHandler initializePipeline(
       TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       channel.pipeline()
         .addLast("encoder", ENCODER)
-        .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
+        .addLast(TransportFrameDecoder.HANDLER_NAME,
+          NettyUtils.createFrameDecoder(conf.maxRemoteBlockSizeFetchToMem(), false))
         .addLast("decoder", DECODER)
         .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
         // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java
index d322aec28793e..bf3b60648b3ca 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.client;
 
+import io.netty.buffer.ByteBuf;
+
 import java.io.IOException;
 import java.nio.ByteBuffer;
 
@@ -28,13 +30,13 @@
  * The network library guarantees that a single thread will call these methods at a time, but
  * different call may be made by different threads.
  */
-public interface StreamCallback {
+public interface StreamCallback<T> {
   /** Called upon receipt of stream data. */
-  void onData(String streamId, ByteBuffer buf) throws IOException;
+  void onData(T streamId, ByteBuffer buf) throws IOException;
 
   /** Called when all data from the stream has been received. */
-  void onComplete(String streamId) throws IOException;
+  void onComplete(T streamId) throws IOException;
 
   /** Called if there's an error reading data from the stream. */
-  void onFailure(String streamId, Throwable cause) throws IOException;
+  void onFailure(T streamId, Throwable cause) throws IOException;
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
index b0e85bae7c309..19faacf67d00e 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
@@ -28,17 +28,17 @@
  * An interceptor that is registered with the frame decoder to feed stream data to a
  * callback.
  */
-class StreamInterceptor implements TransportFrameDecoder.Interceptor {
+class StreamInterceptor<T> implements TransportFrameDecoder.Interceptor {
 
   private final TransportResponseHandler handler;
-  private final String streamId;
+  private final T streamId;
   private final long byteCount;
-  private final StreamCallback callback;
+  private final StreamCallback<T> callback;
 
   private long bytesRead;
 
   StreamInterceptor(
       TransportResponseHandler handler,
-      String streamId,
+      T streamId,
       long byteCount,
       StreamCallback callback) {
     this.handler = handler;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 8f354ad78bbaa..c8a8c83f93b9c 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -24,6 +24,7 @@
 import java.util.UUID;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
 import javax.annotation.Nullable;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -132,14 +133,15 @@ public void setClientId(String id) {
   public void fetchChunk(
       long streamId,
       int chunkIndex,
-      ChunkReceivedCallback callback) {
+      ChunkReceivedCallback callback,
+      Supplier<StreamCallback<StreamChunkId>> streamCallbackFactory) {
     long startTime = System.currentTimeMillis();
     if (logger.isDebugEnabled()) {
       logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel));
     }
 
     StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
-    handler.addFetchRequest(streamChunkId, callback);
+    handler.addFetchRequest(streamChunkId, callback, streamCallbackFactory);
 
     channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> {
       if (future.isSuccess()) {
@@ -169,7 +171,7 @@ public void fetchChunk(
    * @param streamId The stream to fetch.
    * @param callback Object to call with the stream data.
    */
-  public void stream(String streamId, StreamCallback callback) {
+  public void stream(String streamId, StreamCallback<String> callback) {
     long startTime = System.currentTimeMillis();
     if (logger.isDebugEnabled()) {
       logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel));
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 7a3d96ceaef0c..41c2c7cc6b271 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -23,6 +23,7 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Supplier;
 
 import com.google.common.annotations.VisibleForTesting;
 import io.netty.channel.Channel;
@@ -55,10 +56,12 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
   private final Channel channel;
 
   private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
+  private final Map<StreamChunkId, Supplier<StreamCallback<StreamChunkId>>>
+    outstandingStreamFetches;
 
   private final Map<Long, RpcResponseCallback> outstandingRpcs;
 
-  private final Queue<Pair<String, StreamCallback>> streamCallbacks;
+  private final Queue<Pair<String, StreamCallback<String>>> streamCallbacks;
   private volatile boolean streamActive;
 
   /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
@@ -67,18 +70,24 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
   public TransportResponseHandler(Channel channel) {
     this.channel = channel;
     this.outstandingFetches = new ConcurrentHashMap<>();
+    this.outstandingStreamFetches = new ConcurrentHashMap<>();
     this.outstandingRpcs = new ConcurrentHashMap<>();
     this.streamCallbacks = new ConcurrentLinkedQueue<>();
     this.timeOfLastRequestNs = new AtomicLong(0);
   }
 
-  public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+  public void addFetchRequest(
+      StreamChunkId streamChunkId,
+      ChunkReceivedCallback callback,
+      Supplier<StreamCallback<StreamChunkId>> streamCallbackFactory) {
     updateTimeOfLastRequest();
     outstandingFetches.put(streamChunkId, callback);
+    outstandingStreamFetches.put(streamChunkId, streamCallbackFactory);
   }
 
   public void removeFetchRequest(StreamChunkId streamChunkId) {
     outstandingFetches.remove(streamChunkId);
+    outstandingStreamFetches.remove(streamChunkId);
   }
 
   public void addRpcRequest(long requestId, RpcResponseCallback callback) {
@@ -90,7 +99,7 @@ public void removeRpcRequest(long requestId) {
     outstandingRpcs.remove(requestId);
   }
 
-  public void addStreamCallback(String streamId, StreamCallback callback) {
+  public void addStreamCallback(String streamId, StreamCallback<String> callback) {
     timeOfLastRequestNs.set(System.nanoTime());
     streamCallbacks.offer(ImmutablePair.of(streamId, callback));
   }
@@ -119,7 +128,7 @@ private void failOutstandingRequests(Throwable cause) {
         logger.warn("RpcResponseCallback.onFailure throws exception", e);
       }
     }
-    for (Pair<String, StreamCallback> entry : streamCallbacks) {
+    for (Pair<String, StreamCallback<String>> entry : streamCallbacks) {
       try {
         entry.getValue().onFailure(entry.getKey(), cause);
       } catch (Exception e) {
@@ -131,6 +140,7 @@ private void failOutstandingRequests(Throwable cause) {
     outstandingFetches.clear();
     outstandingRpcs.clear();
     streamCallbacks.clear();
+    outstandingStreamFetches.clear();
   }
 
   @Override
@@ -165,10 +175,37 @@ public void handle(ResponseMessage message) throws Exception {
       if (listener == null) {
         logger.warn("Ignoring response for block {} from {} since it is not outstanding",
          resp.streamChunkId, getRemoteAddress(channel));
-        resp.body().release();
       } else {
-        outstandingFetches.remove(resp.streamChunkId);
-        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
+        if (resp.isBodyInFrame()) {
+          outstandingFetches.remove(resp.streamChunkId);
+          listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
+        } else {
+          StreamCallback streamCallback =
+            outstandingStreamFetches.get(resp.streamChunkId).get();
+          outstandingFetches.remove(resp.streamChunkId);
+          outstandingStreamFetches.remove(resp.streamChunkId);
+          if (resp.remainingFrameSize > 0) {
+            StreamInterceptor interceptor = new StreamInterceptor(this,
+              resp.streamChunkId, resp.remainingFrameSize, streamCallback);
+            try {
+              TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
+                channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
+              frameDecoder.setInterceptor(interceptor);
+              streamActive = true;
+            } catch (Exception e) {
+              logger.error("Error installing stream handler.", e);
+              deactivateStream();
+            }
+          } else {
+            try {
+              streamCallback.onComplete(resp.streamChunkId);
+            } catch (Exception e) {
+              logger.warn("Error in stream handler onComplete().", e);
+            }
+          }
+        }
+      }
+      if (resp.isBodyInFrame()) {
         resp.body().release();
       }
     } else if (message instanceof ChunkFetchFailure) {
@@ -208,12 +245,12 @@ public void handle(ResponseMessage message) throws Exception {
       }
     } else if (message instanceof StreamResponse) {
       StreamResponse resp = (StreamResponse) message;
-      Pair<String, StreamCallback> entry = streamCallbacks.poll();
+      Pair<String, StreamCallback<String>> entry = streamCallbacks.poll();
       if (entry != null) {
         StreamCallback callback = entry.getValue();
         if (resp.byteCount > 0) {
-          StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
-            callback);
+          StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId,
+            resp.byteCount, callback);
           try {
             TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
               channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
@@ -235,7 +272,7 @@ public void handle(ResponseMessage message) throws Exception {
       }
     } else if (message instanceof StreamFailure) {
       StreamFailure resp = (StreamFailure) message;
-      Pair<String, StreamCallback> entry = streamCallbacks.poll();
+      Pair<String, StreamCallback<String>> entry = streamCallbacks.poll();
       if (entry != null) {
         StreamCallback callback = entry.getValue();
         try {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
index 94c2ac9b20e43..7a315f9db2fd2 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -31,11 +31,23 @@
  * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
  */
 public final class ChunkFetchSuccess extends AbstractResponseMessage {
+  public static final int ENCODED_LENGTH = StreamChunkId.ENCODED_LENGTH;
   public final StreamChunkId streamChunkId;
+  public final long remainingFrameSize;
 
   public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
     super(buffer, true);
     this.streamChunkId = streamChunkId;
+    this.remainingFrameSize = 0;
+  }
+
+  public ChunkFetchSuccess(StreamChunkId streamChunkId,
+      ManagedBuffer buffer,
+      boolean isBodyInFrame,
+      long remainingFrameSize) {
+    super(buffer, isBodyInFrame);
+    this.streamChunkId = streamChunkId;
+    this.remainingFrameSize = remainingFrameSize;
   }
 
   @Override
@@ -58,11 +70,16 @@ public ResponseMessage createFailureResponse(String error) {
   }
 
   /** Decoding uses the given ByteBuf as our data, and will retain() it. */
-  public static ChunkFetchSuccess decode(ByteBuf buf) {
+  public static ChunkFetchSuccess decode(ByteBuf buf, long remainingFrameSize) {
     StreamChunkId streamChunkId = StreamChunkId.decode(buf);
-    buf.retain();
-    NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate());
-    return new ChunkFetchSuccess(streamChunkId, managedBuf);
+    NettyManagedBuffer managedBuf = null;
+    final boolean isFullFrameProcessed =
+      remainingFrameSize == 0;
+    if (isFullFrameProcessed) {
+      buf.retain();
+      managedBuf = new NettyManagedBuffer(buf.duplicate());
+    }
+    return new ChunkFetchSuccess(streamChunkId, managedBuf, isFullFrameProcessed, remainingFrameSize);
   }
 
   @Override
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
index 434935a8ef2ad..8360d7d0ac21b 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -39,6 +39,9 @@ enum Type implements Encodable {
     StreamRequest(6), StreamResponse(7), StreamFailure(8),
     OneWayMessage(9), User(-1);
 
+    /** Encoded length in bytes. */
+    public static final int LENGTH = 1;
+
     private final byte id;
 
     Type(int id) {
@@ -48,7 +51,7 @@ enum Type implements Encodable {
 
     public byte id() { return id; }
 
-    @Override public int encodedLength() { return 1; }
+    @Override public int encodedLength() { return LENGTH; }
 
     @Override public void encode(ByteBuf buf) { buf.writeByte(id); }
 
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
index 39a7495828a8a..fc7151d55a3e2 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -31,7 +31,7 @@
  * This encoder is stateless so it is safe to be shared by multiple threads.
  */
 @ChannelHandler.Sharable
-public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
+public final class MessageDecoder extends MessageToMessageDecoder<ParsedFrame> {
 
   private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
 
@@ -40,21 +40,20 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
   private MessageDecoder() {}
 
   @Override
-  public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
-    Message.Type msgType = Message.Type.decode(in);
-    Message decoded = decode(msgType, in);
-    assert decoded.type() == msgType;
-    logger.trace("Received message {}: {}", msgType, decoded);
+  public void decode(ChannelHandlerContext ctx, ParsedFrame in, List<Object> out) {
+    Message decoded = decode(in.messageType, in.byteBuf, in.remainingFrameSize);
+    assert decoded.type() == in.messageType;
+    logger.trace("Received message {}: {}", in.messageType, decoded);
     out.add(decoded);
   }
 
-  private Message decode(Message.Type msgType, ByteBuf in) {
+  private Message decode(Message.Type msgType, ByteBuf in, long remainingFrameSize) {
     switch (msgType) {
       case ChunkFetchRequest:
         return ChunkFetchRequest.decode(in);
 
       case ChunkFetchSuccess:
-        return ChunkFetchSuccess.decode(in);
+        return ChunkFetchSuccess.decode(in, remainingFrameSize);
 
       case ChunkFetchFailure:
         return ChunkFetchFailure.decode(in);
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ParsedFrame.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ParsedFrame.java
new file mode 100644
index 0000000000000..b29fe7626eefa
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ParsedFrame.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+public class ParsedFrame {
+
+  public final Message.Type messageType;
+
+  public final ByteBuf byteBuf;
+
+  public final long remainingFrameSize;
+
+
+  public ParsedFrame(Message.Type messageType, ByteBuf byteBuf, long remainingFrameSize) {
+    this.messageType = messageType;
+    this.byteBuf = byteBuf;
+    this.remainingFrameSize = remainingFrameSize;
+  }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
index d46a263884807..0807cb127c9d1 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
@@ -24,6 +24,9 @@
  * Encapsulates a request for a particular chunk of a stream.
  */
 public final class StreamChunkId implements Encodable {
+
+  public static final int ENCODED_LENGTH = 8 + 4;
+
   public final long streamId;
   public final int chunkIndex;
 
@@ -34,7 +37,7 @@ public StreamChunkId(long streamId, int chunkIndex) {
 
   @Override
   public int encodedLength() {
-    return 8 + 4;
+    return ENCODED_LENGTH;
   }
 
   public void encode(ByteBuf buffer) {
@@ -43,7 +46,7 @@ public void encode(ByteBuf buffer) {
   }
 
   public static StreamChunkId decode(ByteBuf buffer) {
-    assert buffer.readableBytes() >= 8 + 4;
+    assert buffer.readableBytes() >= ENCODED_LENGTH;
     long streamId = buffer.readLong();
     int chunkIndex = buffer.readInt();
     return new StreamChunkId(streamId, chunkIndex);
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
index 3ac9081d78a75..5128f90160a93 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java
@@ -61,7 +61,7 @@ static void addToChannel(
     channel.pipeline()
       .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize))
       .addFirst("saslDecryption", new DecryptionHandler(backend))
-      .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder());
+      .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder(Integer.MAX_VALUE, true));
   }
 
   private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java
index 5e85180bd6f9f..eaa02e78610e2 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -84,8 +84,10 @@ public static Class<? extends ServerChannel> getServerChannelClass(IOMode mode)
    * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame.
    * This is used before all decoders.
    */
-  public static TransportFrameDecoder createFrameDecoder() {
-    return new TransportFrameDecoder();
+  public static TransportFrameDecoder createFrameDecoder(
+      long maxRemoteBlockSizeFetchToMem,
+      boolean isSasl) {
+    return new TransportFrameDecoder(maxRemoteBlockSizeFetchToMem, isSasl);
   }
 
   /** Returns the remote address on the channel or "<unknown remote>" if none exists. */
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 91497b9492219..a2a0939a15586 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -281,4 +281,8 @@ public Properties cryptoConf() {
   public long maxChunksBeingTransferred() {
     return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE);
   }
+
+  public long maxRemoteBlockSizeFetchToMem() {
+    return conf.getLong("spark.maxRemoteBlockSizeFetchToMem", Long.MAX_VALUE);
+  }
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
index 8e73ab077a5c1..4e6b617e53996 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -25,6 +25,9 @@
 import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundHandlerAdapter;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.ParsedFrame;
 
 /**
  * A customized frame decoder that allows intercepting raw data.
@@ -55,6 +58,19 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
   private long totalSize = 0;
   private long nextFrameSize = UNKNOWN_FRAME_SIZE;
   private volatile Interceptor interceptor;
+  private Message.Type msgType = null;
+  private final boolean isSasl;
+
+  private final long maxRemoteBlockSizeFetchToMem;
+
+  public TransportFrameDecoder(long maxRemoteBlockSizeFetchToMem) {
+    this(maxRemoteBlockSizeFetchToMem, false);
+  }
+
+  public TransportFrameDecoder(long maxRemoteBlockSizeFetchToMem, boolean isSasl) {
+    this.maxRemoteBlockSizeFetchToMem = maxRemoteBlockSizeFetchToMem;
+    this.isSasl = isSasl;
+  }
 
   @Override
   public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
@@ -77,12 +93,37 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
         }
         totalSize -= read;
       } else {
-        // Interceptor is not active, so try to decode one frame.
-        ByteBuf frame = decodeNext();
-        if (frame == null) {
-          break;
+        if (isSasl) {
+          if (!isFrameSizeAvailable()) {
+            break;
+          }
+          ByteBuf frame = decodeNext();
+          if (frame == null) {
+            break;
+          }
+          ctx.fireChannelRead(frame);
+        } else {
+          // Interceptor is not active, so try to decode one frame.
+          decodeNextMsgType();
+          if (msgType == null) {
+            break;
+          }
+          long remainingFrameSize = 0;
+          if (msgType == Message.Type.ChunkFetchSuccess &&
+              nextFrameSize - ChunkFetchSuccess.ENCODED_LENGTH > maxRemoteBlockSizeFetchToMem) {
+            remainingFrameSize = nextFrameSize - ChunkFetchSuccess.ENCODED_LENGTH;
+            nextFrameSize = ChunkFetchSuccess.ENCODED_LENGTH;
+          }
+
+          ByteBuf frame = decodeNext();
+          if (frame == null) {
+            break;
+          }
+          ParsedFrame parsedFrame =
+            new ParsedFrame(msgType, frame, remainingFrameSize);
+          msgType = null;
+          ctx.fireChannelRead(parsedFrame);
         }
-        ctx.fireChannelRead(frame);
       }
     }
   }
@@ -121,18 +162,40 @@ private long decodeFrameSize() {
     return nextFrameSize;
   }
 
-  private ByteBuf decodeNext() {
+  private void decodeNextMsgType() {
+    if (msgType != null || !isFrameSizeAvailable() || totalSize < Message.Type.LENGTH) {
+      return;
+    }
+
+    ByteBuf first = buffers.getFirst();
+    msgType = Message.Type.decode(first);
+    totalSize -= Message.Type.LENGTH;
+    nextFrameSize -= Message.Type.LENGTH;
+    if (!first.isReadable()) {
+      buffers.removeFirst().release();
+    }
+  }
+
+  private boolean isFrameSizeAvailable() {
     long frameSize = decodeFrameSize();
-    if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) {
+    if (frameSize == UNKNOWN_FRAME_SIZE) {
+      return false;
+    }
+
+    Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize);
+    Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize);
+    return true;
+  }
+
+  private ByteBuf decodeNext() {
+    long frameSize = nextFrameSize;
+    if (totalSize < frameSize) {
       return null;
     }
 
     // Reset size for next frame.
     nextFrameSize = UNKNOWN_FRAME_SIZE;
 
-    Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize);
-    Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize);
-
     // If the first buffer holds the entire frame, return it.
     int remaining = (int) frameSize;
     if (buffers.getFirst().readableBytes() >= remaining) {
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
index 824482af08dd4..1717f74a97a03 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -18,33 +18,30 @@
 package org.apache.spark.network;
 
 import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
 import java.io.RandomAccessFile;
 import java.nio.ByteBuffer;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Random;
-import java.util.Set;
+import java.nio.channels.Channels;
+import java.nio.channels.WritableByteChannel;
+import java.util.*;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 
 import com.google.common.collect.Sets;
 import com.google.common.io.Closeables;
+import org.apache.spark.network.client.*;
+import org.apache.spark.network.protocol.StreamChunkId;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
 import static org.junit.Assert.*;
+import static org.mockito.Mockito.mock;
 
 import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.buffer.NioManagedBuffer;
-import org.apache.spark.network.client.ChunkReceivedCallback;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.client.TransportClientFactory;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.TransportServer;
 import org.apache.spark.network.server.StreamManager;
@@ -55,6 +52,14 @@ public class ChunkFetchIntegrationSuite {
   static final long STREAM_ID = 1;
   static final int BUFFER_CHUNK_INDEX = 0;
   static final int FILE_CHUNK_INDEX = 1;
+  static final int BUFFER_FETCH_TO_DISK_CHUNK_INDEX = 2;
+  static final int MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = 100000;
+
+  static final TransportConf transportConf =
+    new TransportConf("shuffle",
+      new MapConfigProvider(
+        Collections.singletonMap(
+          "spark.maxRemoteBlockSizeFetchToMem", Integer.toString(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM))));
 
   static TransportServer server;
   static TransportClientFactory clientFactory;
@@ -62,17 +67,68 @@ public class ChunkFetchIntegrationSuite {
   static File testFile;
 
   static ManagedBuffer bufferChunk;
+  static ManagedBuffer bufferToDiskChunk;
+
   static ManagedBuffer fileChunk;
 
+  private class FetchChunkDownloadTestCallback implements StreamCallback<StreamChunkId> {
+    private WritableByteChannel channel;
+    private File targetFile;
+    private FetchResult fetchResult;
+    private Semaphore semaphore;
+
+    FetchChunkDownloadTestCallback(FetchResult fetchResult, Semaphore semaphore) {
+      this.fetchResult = fetchResult;
+      this.semaphore = semaphore;
+      try {
+        this.targetFile = File.createTempFile("shuffle-test-file-download-", "txt");
+        this.targetFile.deleteOnExit();
+        this.channel = Channels.newChannel(new FileOutputStream(targetFile));
+      } catch (IOException e) {
+        throw new IllegalStateException(e);
+      }
+    }
+
+    @Override
+    public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException {
+      while (buf.hasRemaining()) {
+        channel.write(buf);
+      }
+    }
+
+    @Override
+    public void onComplete(StreamChunkId streamId) throws IOException {
+      channel.close();
+      ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
+        targetFile.length());
+      fetchResult.successChunks.add(streamId.chunkIndex);
+      fetchResult.buffers.add(buffer);
+      semaphore.release();
+    }
+
+    @Override
+    public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException {
+      channel.close();
+      this.fetchResult.failedChunks.add(streamId.chunkIndex);
+      semaphore.release();
+    }
+  }
+
   @BeforeClass
   public static void setUp() throws Exception {
-    int bufSize = 100000;
-    final ByteBuffer buf = ByteBuffer.allocate(bufSize);
-    for (int i = 0; i < bufSize; i ++) {
-      buf.put((byte) i);
+    int bufSize = MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM + 100;
+    final ByteBuffer hugeBuf = ByteBuffer.allocate(bufSize);
+    for (int i = 0; i < MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM; i ++) {
+      hugeBuf.put((byte) i);
+    }
+    ByteBuffer smallBuff = hugeBuf.duplicate();
+    smallBuff.flip();
+    bufferChunk = new NioManagedBuffer(smallBuff);
+    for (int i = MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM; i < bufSize; i ++) {
+      hugeBuf.put((byte) i);
     }
-    buf.flip();
-    bufferChunk = new NioManagedBuffer(buf);
+    hugeBuf.flip();
+    bufferToDiskChunk = new NioManagedBuffer(hugeBuf);
 
     testFile = File.createTempFile("shuffle-test-file", "txt");
     testFile.deleteOnExit();
@@ -87,19 +143,21 @@ public static void setUp() throws Exception {
       Closeables.close(fp, shouldSuppressIOException);
     }
 
-    final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY);
-    fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
+    fileChunk = new FileSegmentManagedBuffer(transportConf, testFile, 10, testFile.length() - 25);
 
     streamManager = new StreamManager() {
       @Override
       public ManagedBuffer getChunk(long streamId, int chunkIndex) {
         assertEquals(STREAM_ID, streamId);
-        if (chunkIndex == BUFFER_CHUNK_INDEX) {
-          return new NioManagedBuffer(buf);
-        } else if (chunkIndex == FILE_CHUNK_INDEX) {
-          return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);
-        } else {
-          throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex);
+        switch (chunkIndex) {
+          case BUFFER_CHUNK_INDEX:
+            return new NioManagedBuffer(smallBuff);
+          case FILE_CHUNK_INDEX:
+            return new FileSegmentManagedBuffer(transportConf, testFile, 10, testFile.length() - 25);
+          case BUFFER_FETCH_TO_DISK_CHUNK_INDEX:
+            return new NioManagedBuffer(hugeBuf);
+          default:
+            throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex);
         }
       }
     };
@@ -117,7 +175,7 @@ public StreamManager getStreamManager() {
         return streamManager;
       }
     };
-    TransportContext context = new TransportContext(conf, handler);
+    TransportContext context = new TransportContext(transportConf, handler);
     server = context.createServer();
     clientFactory = context.createClientFactory();
   }
@@ -125,6 +183,7 @@ public StreamManager getStreamManager() {
   @AfterClass
   public static void tearDown() {
     bufferChunk.release();
+    bufferToDiskChunk.release();
     server.close();
     clientFactory.close();
    testFile.delete();
@@ -168,7 +227,8 @@ public void onFailure(int chunkIndex, Throwable e) {
      };
 
       for (int chunkIndex : chunkIndices) {
-        client.fetchChunk(STREAM_ID, chunkIndex, callback);
+        client.fetchChunk(STREAM_ID, chunkIndex, callback,
+          () -> new FetchChunkDownloadTestCallback(res, sem));
       }
       if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) {
         fail("Timeout getting response from the server");
@@ -182,6 +242,7 @@ public void fetchBufferChunk() throws Exception {
     FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX));
     assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX), res.successChunks);
     assertTrue(res.failedChunks.isEmpty());
+    assertNumFileSegments(0, res.buffers);
     assertBufferListsEqual(Arrays.asList(bufferChunk), res.buffers);
     res.releaseBuffers();
   }
@@ -191,6 +252,7 @@ public void fetchFileChunk() throws Exception {
     FetchResult res = fetchChunks(Arrays.asList(FILE_CHUNK_INDEX));
     assertEquals(Sets.newHashSet(FILE_CHUNK_INDEX), res.successChunks);
     assertTrue(res.failedChunks.isEmpty());
+    assertNumFileSegments(0, res.buffers);
     assertBufferListsEqual(Arrays.asList(fileChunk), res.buffers);
     res.releaseBuffers();
   }
@@ -208,10 +270,25 @@ public void fetchBothChunks() throws Exception {
     FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
     assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX), res.successChunks);
     assertTrue(res.failedChunks.isEmpty());
+    assertNumFileSegments(0, res.buffers);
     assertBufferListsEqual(Arrays.asList(bufferChunk, fileChunk), res.buffers);
     res.releaseBuffers();
   }
 
+  @Test
+  public void fetchSomeChunksToDisk() throws Exception {
+    FetchResult res = fetchChunks(
+      Arrays.asList(BUFFER_CHUNK_INDEX, BUFFER_FETCH_TO_DISK_CHUNK_INDEX, FILE_CHUNK_INDEX));
+    assertEquals(
+      Sets.newHashSet(BUFFER_CHUNK_INDEX, BUFFER_FETCH_TO_DISK_CHUNK_INDEX, FILE_CHUNK_INDEX),
+      res.successChunks);
+    assertTrue(res.failedChunks.isEmpty());
+    assertNumFileSegments(1, res.buffers);
+    assertBufferListsEqual(Arrays.asList(bufferChunk, bufferToDiskChunk, fileChunk), res.buffers);
+    res.releaseBuffers();
+  }
+
   @Test
   public void fetchChunkAndNonExistent() throws Exception {
     FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, 12345));
@@ -221,6 +298,11 @@ public void fetchChunkAndNonExistent() throws Exception {
     res.releaseBuffers();
   }
 
+  private void assertNumFileSegments(int expected, List<ManagedBuffer> buffers) {
+    assertEquals(expected,
+      buffers.stream().filter(b -> b instanceof FileSegmentManagedBuffer).count());
+  }
+
   private static void assertBufferListsEqual(List<ManagedBuffer> list0, List<ManagedBuffer> list1)
       throws Exception {
     assertEquals(list0.size(), list1.size());
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index bc94f7ca63a96..fd318b632919e 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -28,6 +28,8 @@
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
 
 import org.apache.spark.network.protocol.ChunkFetchFailure;
 import org.apache.spark.network.protocol.ChunkFetchRequest;
@@ -47,20 +49,26 @@
 import org.apache.spark.network.util.NettyUtils;
 
 public class ProtocolSuite {
-  private void testServerToClient(Message msg) {
+  private Message decodedClientMessageFromChannel(Message msg, long maxRemoteBlockSizeFetchToMem) {
     EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
-      MessageEncoder.INSTANCE);
+        MessageEncoder.INSTANCE);
     serverChannel.writeOutbound(msg);
 
     EmbeddedChannel clientChannel = new EmbeddedChannel(
-      NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
+        NettyUtils.createFrameDecoder(maxRemoteBlockSizeFetchToMem, false),
+        MessageDecoder.INSTANCE);
 
     while (!serverChannel.outboundMessages().isEmpty()) {
       clientChannel.writeOneInbound(serverChannel.readOutbound());
     }
 
     assertEquals(1, clientChannel.inboundMessages().size());
-    assertEquals(msg, clientChannel.readInbound());
+    return clientChannel.readInbound();
+  }
+
+  private void testServerToClient(Message msg) {
+    Message clientMessage = decodedClientMessageFromChannel(msg, Integer.MAX_VALUE);
+    assertEquals(msg, clientMessage);
   }
 
   private void testClientToServer(Message msg) {
@@ -69,7 +77,8 @@ private void testClientToServer(Message msg) {
     clientChannel.writeOutbound(msg);
 
     EmbeddedChannel serverChannel = new EmbeddedChannel(
-      NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
+      NettyUtils.createFrameDecoder(Integer.MAX_VALUE, false),
+      MessageDecoder.INSTANCE);
 
     while (!clientChannel.outboundMessages().isEmpty()) {
       serverChannel.writeOneInbound(clientChannel.readOutbound());
@@ -79,6 +88,31 @@ private void testClientToServer(Message msg) {
     assertEquals(msg, serverChannel.readInbound());
   }
 
+  private ChunkFetchSuccess chunkFetchSuccessWith100Bytes() {
+    return new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(100));
+  }
+
+  private void testChunkFetchSuccess() {
+    // test without fetch to disk, maxRemoteBlockSizeFetchToMem is Integer.MAX_VALUE
+    testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10)));
+    testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
+
+    // test with fetch to disk
+    // under (and at) the fetch to mem limit
+    Message chunkFetchSuccessUnderFetchToMemBlockSize =
+      decodedClientMessageFromChannel(chunkFetchSuccessWith100Bytes(), 101);
+    assertEquals(chunkFetchSuccessWith100Bytes(), chunkFetchSuccessUnderFetchToMemBlockSize);
+    chunkFetchSuccessUnderFetchToMemBlockSize =
+      decodedClientMessageFromChannel(chunkFetchSuccessWith100Bytes(), 100);
+    assertEquals(chunkFetchSuccessWith100Bytes(), chunkFetchSuccessUnderFetchToMemBlockSize);
+
+    // above the fetch to mem limit
+    Message chunkFetchSuccessAboveFetchToMemBlockSize =
+      decodedClientMessageFromChannel(chunkFetchSuccessWith100Bytes(), 99);
+    assertNull("message body must be not included", chunkFetchSuccessAboveFetchToMemBlockSize.body());
+    assertFalse("message body must be not included", chunkFetchSuccessAboveFetchToMemBlockSize.isBodyInFrame());
+  }
+
   @Test
   public void requests() {
     testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
@@ -88,10 +122,10 @@ public void requests() {
     testClientToServer(new OneWayMessage(new TestManagedBuffer(10)));
   }
 
+  @Test
   public void responses() {
-    testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10)));
-    testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
+    testChunkFetchSuccess();
     testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
     testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
     testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0)));
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
index c0724e018263f..1d6dcf4a11280 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -20,10 +20,8 @@
 import com.google.common.util.concurrent.Uninterruptibles;
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.buffer.NioManagedBuffer;
-import org.apache.spark.network.client.ChunkReceivedCallback;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.client.*;
+import org.apache.spark.network.protocol.StreamChunkId;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.server.TransportServer;
@@ -31,6 +29,7 @@
 import org.apache.spark.network.util.TransportConf;
 import org.junit.*;
 import static org.junit.Assert.*;
+import static org.mockito.Mockito.mock;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -38,6 +37,7 @@
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
 
 /**
  * Suite which ensures that requests that go without a response for the network timeout period are
@@ -211,12 +211,13 @@ public StreamManager getStreamManager() {
 
     // Send one request, which will eventually fail.
     TestCallback callback0 = new TestCallback();
-    client.fetchChunk(0, 0, callback0);
+    Supplier<StreamCallback<StreamChunkId>> streamCallbackFactory = mock(Supplier.class);
+    client.fetchChunk(0, 0, callback0, streamCallbackFactory);
     Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
 
     // Send a second request before the first has failed.
     TestCallback callback1 = new TestCallback();
-    client.fetchChunk(0, 1, callback1);
+    client.fetchChunk(0, 1, callback1, streamCallbackFactory);
     Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS);
 
     // not complete yet, but should complete soon
diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java
index f253a07e64be1..04ae42efc0abb 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -308,7 +308,7 @@ private void waitForCompletion(TestCallback callback) throws Exception {
     }
   }
 
-  private static class TestCallback implements StreamCallback {
+  private static class TestCallback implements StreamCallback<String> {
 
     private final OutputStream out;
     public volatile boolean completed;
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
index b4032c4c3f031..f5935e4c404c1 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
@@ -19,19 +19,17 @@
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.function.Supplier;
 
 import io.netty.channel.Channel;
 import io.netty.channel.local.LocalChannel;
+import org.apache.spark.network.client.*;
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
 import static org.mockito.Mockito.*;
 
 import org.apache.spark.network.buffer.NioManagedBuffer;
-import org.apache.spark.network.client.ChunkReceivedCallback;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.StreamCallback;
-import org.apache.spark.network.client.TransportResponseHandler;
 import org.apache.spark.network.protocol.ChunkFetchFailure;
 import org.apache.spark.network.protocol.ChunkFetchSuccess;
 import org.apache.spark.network.protocol.RpcFailure;
@@ -48,7 +46,8 @@ public void handleSuccessfulFetch() throws Exception {
 
     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
     ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    handler.addFetchRequest(streamChunkId, callback);
+    Supplier<StreamCallback<StreamChunkId>> streamCallbackFactory = mock(Supplier.class);
+    handler.addFetchRequest(streamChunkId, callback, streamCallbackFactory);
     assertEquals(1, handler.numOutstandingRequests());
 
     handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123)));
@@ -61,7 +60,8 @@ public void handleFailedFetch() throws Exception {
     StreamChunkId streamChunkId = new StreamChunkId(1, 0);
     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
     ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    handler.addFetchRequest(streamChunkId, callback);
+    Supplier<StreamCallback<StreamChunkId>> streamCallbackFactory = mock(Supplier.class);
+    handler.addFetchRequest(streamChunkId, callback, streamCallbackFactory);
     assertEquals(1, handler.numOutstandingRequests());
 
     handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg"));
@@ -73,9 +73,11 @@ public void handleFailedFetch() throws Exception {
   public void clearAllOutstandingRequests() throws Exception {
     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
     ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
-    handler.addFetchRequest(new StreamChunkId(1, 0), callback);
-    handler.addFetchRequest(new StreamChunkId(1, 1), callback);
-    handler.addFetchRequest(new StreamChunkId(1, 2), callback);
+    Supplier<StreamCallback<StreamChunkId>> streamCallback = mock(Supplier.class);
+
+    handler.addFetchRequest(new StreamChunkId(1, 0), callback, streamCallback);
+    handler.addFetchRequest(new StreamChunkId(1, 1), callback, streamCallback);
+    handler.addFetchRequest(new StreamChunkId(1, 2), callback, streamCallback);
     assertEquals(3, handler.numOutstandingRequests());
 
     handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12)));
@@ -123,7 +125,8 @@ public void handleFailedRPC() throws Exception {
   @Test
   public void testActiveStreams() throws Exception {
     Channel c = new LocalChannel();
-    c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
+    c.pipeline()
+      .addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder(Integer.MAX_VALUE));
     TransportResponseHandler handler = new TransportResponseHandler(c);
 
     StreamResponse response = new StreamResponse("stream", 1234L, null);
@@ -145,7 +148,8 @@ public void testActiveStreams() throws Exception {
   @Test
   public void failOutstandingStreamCallbackOnClose() throws Exception {
     Channel c = new LocalChannel();
-    c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
+    c.pipeline()
+      .addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder(Integer.MAX_VALUE));
     TransportResponseHandler handler = new TransportResponseHandler(c);
 
     StreamCallback cb = mock(StreamCallback.class);
@@ -158,7 +162,8 @@ public void failOutstandingStreamCallbackOnClose() throws Exception {
   @Test
   public void failOutstandingStreamCallbackOnException() throws Exception {
     Channel c = new LocalChannel();
-    c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
+    c.pipeline()
+      .addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder(Integer.MAX_VALUE));
     TransportResponseHandler handler = new TransportResponseHandler(c);
 
     StreamCallback cb = mock(StreamCallback.class);
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index 6f15718bd8705..397e4f63f608f 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -33,6 +33,7 @@
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
 import javax.security.sasl.SaslException;
 
 import com.google.common.collect.ImmutableMap;
@@ -44,16 +45,14 @@
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelOutboundHandlerAdapter;
 import io.netty.channel.ChannelPromise;
+import org.apache.spark.network.client.*;
+import org.apache.spark.network.protocol.StreamChunkId;
 import org.junit.Test;
 
 import org.apache.spark.network.TestUtils;
 import org.apache.spark.network.TransportContext;
 import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
 import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.client.ChunkReceivedCallback;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.client.TransportClientBootstrap;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.server.TransportServer;
@@ -281,7 +280,8 @@ public void testFileRegionEncryption() throws Exception {
           return null;
         }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
 
-      ctx.client.fetchChunk(0, 0, callback);
+      Supplier<StreamCallback<StreamChunkId>> streamCallbackFactory = mock(Supplier.class);
+      ctx.client.fetchChunk(0, 0, callback, streamCallbackFactory);
       lock.await(10, TimeUnit.SECONDS);
 
       verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
index b53e41303751c..388bb41a57a5e 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
@@ -17,8 +17,8 @@
 
 package org.apache.spark.network.util;
 
-import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Random;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -26,6 +26,8 @@
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandlerContext;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.ParsedFrame;
 import org.junit.AfterClass;
 import org.junit.Test;
 import static org.junit.Assert.*;
@@ -42,7 +44,7 @@ public static void cleanup() {
 
   @Test
   public void testFrameDecoding() throws Exception {
-    TransportFrameDecoder decoder = new TransportFrameDecoder();
+    TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE);
     ChannelHandlerContext ctx = mockChannelHandlerContext();
     ByteBuf data = createAndFeedFrames(100, decoder, ctx);
     verifyAndCloseDecoder(decoder, ctx, data);
@@ -51,7 +53,7 @@ public void testFrameDecoding() throws Exception {
   @Test
   public void testInterception() throws Exception {
     int interceptedReads = 3;
-    TransportFrameDecoder decoder = new TransportFrameDecoder();
+    TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE);
     TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
     ChannelHandlerContext ctx = mockChannelHandlerContext();
 
@@ -69,7 +71,7 @@ public void testInterception() throws Exception {
       decoder.channelRead(ctx, len);
       decoder.channelRead(ctx, dataBuf);
       verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
-      verify(ctx).fireChannelRead(any(ByteBuffer.class));
+      verify(ctx).fireChannelRead(any(ParsedFrame.class));
       assertEquals(0, len.refCnt());
       assertEquals(0, dataBuf.refCnt());
     } finally {
@@ -80,19 +82,19 @@ public void testInterception() throws Exception {
 
   @Test
   public void testRetainedFrames() throws Exception {
-    TransportFrameDecoder decoder = new TransportFrameDecoder();
+    TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE);
 
     AtomicInteger count = new AtomicInteger();
-    List<ByteBuf> retained = new ArrayList<>();
+    List<ParsedFrame> retained = new ArrayList<>();
 
     ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
     when(ctx.fireChannelRead(any())).thenAnswer(in -> {
       // Retain a few frames but not others.
-      ByteBuf buf = (ByteBuf) in.getArguments()[0];
+      ParsedFrame parsedFrame = (ParsedFrame) in.getArguments()[0];
       if (count.incrementAndGet() % 2 == 0) {
-        retained.add(buf);
+        retained.add(parsedFrame);
      } else {
-        buf.release();
+        parsedFrame.byteBuf.release();
       }
       return null;
     });
@@ -100,15 +102,15 @@ public void testRetainedFrames() throws Exception {
     ByteBuf data = createAndFeedFrames(100, decoder, ctx);
     try {
       // Verify all retained buffers are readable.
-      for (ByteBuf b : retained) {
-        byte[] tmp = new byte[b.readableBytes()];
-        b.readBytes(tmp);
-        b.release();
+      for (ParsedFrame b : retained) {
+        byte[] tmp = new byte[b.byteBuf.readableBytes()];
+        b.byteBuf.readBytes(tmp);
+        b.byteBuf.release();
       }
       verifyAndCloseDecoder(decoder, ctx, data);
     } finally {
-      for (ByteBuf b : retained) {
-        release(b);
+      for (ParsedFrame b : retained) {
+        release(b.byteBuf);
       }
     }
   }
@@ -120,13 +122,13 @@ public void testSplitLengthField() throws Exception {
     buf.writeLong(frame.length + 8);
     buf.writeBytes(frame);
 
-    TransportFrameDecoder decoder = new TransportFrameDecoder();
+    TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE);
     ChannelHandlerContext ctx = mockChannelHandlerContext();
     try {
       decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain());
-      verify(ctx, never()).fireChannelRead(any(ByteBuf.class));
+      verify(ctx, never()).fireChannelRead(any(ParsedFrame.class));
       decoder.channelRead(ctx, buf);
-      verify(ctx).fireChannelRead(any(ByteBuf.class));
+      verify(ctx).fireChannelRead(any(ParsedFrame.class));
       assertEquals(0, buf.refCnt());
     } finally {
       decoder.channelInactive(ctx);
@@ -154,8 +156,14 @@ private ByteBuf createAndFeedFrames(
       TransportFrameDecoder decoder,
       ChannelHandlerContext ctx) throws Exception {
     ByteBuf data = Unpooled.buffer();
+    Message.Type msgTypes[] = Arrays.stream(Message.Type.values())
+      .filter(t -> t != Message.Type.User)
+      .toArray(Message.Type[]::new);
+
     for (int i = 0; i < frameCount; i++) {
       byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
+      Message.Type randomMsgType = msgTypes[RND.nextInt(msgTypes.length)];
+      frame[0] = randomMsgType.id();
       data.writeLong(frame.length + 8);
       data.writeBytes(frame);
     }
@@ -166,7 +174,7 @@ private ByteBuf createAndFeedFrames(
         decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain());
       }
 
-      verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
+      verify(ctx, times(frameCount)).fireChannelRead(any(ParsedFrame.class));
     } catch (Exception e) {
       release(data);
       throw e;
@@ -187,7 +195,7 @@ private void verifyAndCloseDecoder(
   }
 
   private void testInvalidFrame(long size) throws Exception {
-    TransportFrameDecoder decoder = new TransportFrameDecoder();
+    TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE);
     ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
     ByteBuf frame = Unpooled.copyLong(size);
     try {
@@ -200,8 +208,8 @@ private void testInvalidFrame(long size) throws Exception {
   private ChannelHandlerContext mockChannelHandlerContext() {
     ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
     when(ctx.fireChannelRead(any())).thenAnswer(in -> {
-      ByteBuf buf = (ByteBuf) in.getArguments()[0];
-      buf.release();
+      ParsedFrame parsedFrame = (ParsedFrame) in.getArguments()[0];
+      parsedFrame.byteBuf.release();
       return null;
     });
     return ctx;
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 7ed0b6e93a7a8..f2f445da748fb 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -86,12 +86,13 @@ public void init(String appId) {
 
   @Override
   public void fetchBlocks(
-      String host,
-      int port,
-      String execId,
-      String[] blockIds,
-      BlockFetchingListener listener,
-      TempFileManager tempFileManager) {
+      String host,
+      int port,
+      String execId,
+      String[] blockIds,
+      BlockFetchingListener listener,
+      TempFileManager tempFileManager,
+      boolean useStreamRequestMessage) {
     checkInit();
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
@@ -99,7 +100,7 @@ public void fetchBlocks(
           (blockIds1, listener1) -> {
             TransportClient client = clientFactory.createClient(host, port);
             new OneForOneBlockFetcher(client, appId, execId,
-              blockIds1, listener1, conf, tempFileManager).start();
+              blockIds1, listener1, conf, tempFileManager, useStreamRequestMessage).start();
           };
 
       int maxRetries = conf.maxIORetries();
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 0bc571874f07c..5d1ec48ddaf0c 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -24,16 +24,18 @@
 import java.nio.channels.Channels;
 import java.nio.channels.WritableByteChannel;
 import java.util.Arrays;
+import java.util.function.Supplier;
 
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+import org.apache.spark.network.client.*;
+import org.apache.spark.network.protocol.StreamChunkId;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
 import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.client.ChunkReceivedCallback;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.StreamCallback;
-import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.server.OneForOneStreamManager;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
@@ -57,8 +59,10 @@ public class OneForOneBlockFetcher {
   private final String[] blockIds;
   private final BlockFetchingListener listener;
   private final ChunkReceivedCallback chunkCallback;
+  private final Supplier<StreamCallback<StreamChunkId>> fetchChunkDownloadCallbackFactory;
   private final TransportConf transportConf;
   private final TempFileManager tempFileManager;
+  private final boolean useStreamRequestMessage;
 
   private StreamHandle streamHandle = null;
 
@@ -69,7 +73,7 @@ public OneForOneBlockFetcher(
       String[] blockIds,
       BlockFetchingListener listener,
       TransportConf transportConf) {
-    this(client, appId, execId, blockIds, listener, transportConf, null);
+    this(client, appId, execId, blockIds, listener, transportConf, null, false);
   }
 
   public OneForOneBlockFetcher(
@@ -79,18 +83,24 @@ public OneForOneBlockFetcher(
       String[] blockIds,
       BlockFetchingListener listener,
      TransportConf transportConf,
-      TempFileManager tempFileManager) {
+      TempFileManager tempFileManager,
+      boolean useStreamRequestMessage) {
     this.client = client;
     this.openMessage = new OpenBlocks(appId, execId, blockIds);
     this.blockIds = blockIds;
     this.listener = listener;
     this.chunkCallback = new ChunkCallback();
     this.transportConf = transportConf;
+    // TODO extend tests to pass a valid tempFileManager and use:
+    // this.tempFileManager = Preconditions.checkNotNull(tempFileManager);
+    fetchChunkDownloadCallbackFactory = () -> new FetchChunkDownloadCallback();
     this.tempFileManager = tempFileManager;
this.useStreamRequestMessage = useStreamRequestMessage; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ private class ChunkCallback implements ChunkReceivedCallback { + @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { // On receipt of a chunk, pass it upwards as a block. @@ -125,11 +135,12 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. for (int i = 0; i < streamHandle.numChunks; i++) { - if (tempFileManager != null) { + if (useStreamRequestMessage) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), new DownloadCallback(i)); } else { - client.fetchChunk(streamHandle.streamId, i, chunkCallback); + client.fetchChunk(streamHandle.streamId, i, + chunkCallback, fetchChunkDownloadCallbackFactory); } } } catch (Exception e) { @@ -157,7 +168,7 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { } } - private class DownloadCallback implements StreamCallback { + private class DownloadCallback implements StreamCallback<String> { private WritableByteChannel channel = null; private File targetFile = null; @@ -196,4 +207,45 @@ public void onFailure(String streamId, Throwable cause) throws IOException { targetFile.delete(); } } + + private class FetchChunkDownloadCallback implements StreamCallback<StreamChunkId> { + private WritableByteChannel channel = null; + private File targetFile = null; + + FetchChunkDownloadCallback() { + this.targetFile = tempFileManager.createTempFile(); + try { + this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + + @Override + public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { + while (buf.hasRemaining()) { + channel.write(buf); + } + } + + @Override + public void onComplete(StreamChunkId streamId) throws IOException { + channel.close(); + ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, + targetFile.length()); + listener.onBlockFetchSuccess(blockIds[streamId.chunkIndex], buffer); + if (!tempFileManager.registerTempFileToClean(targetFile)) { + targetFile.delete(); + } + } + + @Override + public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { + channel.close(); + // On receipt of a failure, fail every block from chunkIndex onwards. + String[] remainingBlockIds = Arrays.copyOfRange(blockIds, streamId.chunkIndex, blockIds.length); + failRemainingBlocks(remainingBlockIds, cause); + targetFile.delete(); + } + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 18b04fedcac5b..eeb5c85b967ba 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -37,24 +37,24 @@ public void init(String appId) { } * Note that this API takes a sequence so the implementation can batch requests, and does not * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. - * * @param host the host of the remote node. * @param port the port of the remote node.
* @param execId the executor id. * @param blockIds block ids to fetch. * @param listener the listener to receive block fetching status. * @param tempFileManager TempFileManager to create and clean temp files. - * If it's not null, the remote blocks will be streamed - * into temp shuffle files to reduce the memory usage, otherwise, - * they will be kept in memory. + * If useStreamRequestMessage is true, the remote blocks will be streamed + * into temp shuffle files to reduce the memory usage; otherwise, they will be kept in memory. + * @param useStreamRequestMessage whether to fetch the remote blocks to disk because the request is too large to be held in memory */ public abstract void fetchBlocks( - String host, - int port, - String execId, - String[] blockIds, - BlockFetchingListener listener, - TempFileManager tempFileManager); + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener, + TempFileManager tempFileManager, + boolean useStreamRequestMessage); /** * Get the shuffle MetricsSet from ShuffleClient, this will be used in MetricsSystem to diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 02e6eb3a4467e..416af7150b67f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -23,7 +23,10 @@ import java.util.Arrays; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.StreamChunkId; import org.junit.After; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -35,10 +38,6 @@ import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; @@ -232,6 +231,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) { CountDownLatch chunkReceivedLatch = new CountDownLatch(1); ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { chunkReceivedLatch.countDown(); @@ -244,7 +244,8 @@ public void onFailure(int chunkIndex, Throwable t) { }; exception.set(null); - client2.fetchChunk(streamId, 0, callback); + Supplier<StreamCallback<StreamChunkId>> streamCallbackFactory = mock(Supplier.class); + client2.fetchChunk(streamId, 0, callback, streamCallbackFactory); chunkReceivedLatch.await(); checkSecurityException(exception.get()); } finally { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index a6a1b8d0ac3f1..e3a85ea67741a 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@
-158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { } } } - }, null); + }, null, false); if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index dc947a619bf02..ab8a37cdb3154 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -165,7 +165,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap 0) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4103dfb10175e..44884395984a7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -52,6 +52,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.STREAM_REQUEST_MESSAGE_ENABLED), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0276a4dc4224..197154814e169 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -212,6 +212,7 @@ private[spark] class BlockManager( private[storage] val remoteBlockTempFileManager = new BlockManager.RemoteBlockTempFileManager(this) private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + private val streamRequestMessageEnabled = conf.get(config.STREAM_REQUEST_MESSAGE_ENABLED) /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as @@ -671,16 +672,7 @@ private[spark] class BlockManager( b.status.diskSize.max(b.status.memSize) }.getOrElse(0L) val blockLocations = locationsAndStatus.map(_.locations).getOrElse(Seq.empty) - - // If the block size is above the threshold, we should pass our FileManger to - // BlockTransferService, which will leverage it to spill the block; if not, then passed-in - // null value means the block will be persisted in memory. 
- val tempFileManager = if (blockSize > maxRemoteBlockToMem) { - remoteBlockTempFileManager - } else { - null - } - + val useStreamRequestMessage = streamRequestMessageEnabled && blockSize > maxRemoteBlockToMem val locations = sortLocations(blockLocations) val maxFetchFailures = locations.size var locationIterator = locations.iterator @@ -689,7 +681,12 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer() + loc.host, + loc.port, + loc.executorId, + blockId.toString, + remoteBlockTempFileManager, + useStreamRequestMessage).nioByteBuffer() } catch { case NonFatal(e) => runningFailureCount += 1 diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b31862323a895..b7203e01b013d 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -69,6 +69,7 @@ final class ShuffleBlockFetcherIterator( maxBytesInFlight: Long, maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, + streamRequestMessageEnabled: Boolean, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging { @@ -248,13 +249,14 @@ final class ShuffleBlockFetcherIterator( // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch // the data and write it to file directly. - if (req.size > maxReqSizeShuffleToMem) { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, this) - } else { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, null) - } + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + this, + req.size > maxReqSizeShuffleToMem && streamRequestMessageEnabled) } private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 28ea0c6f0bdba..0af967f39612e 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -171,7 +171,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val serializerManager = SparkEnv.get.serializerManager blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, - blockId.toString, null) + blockId.toString, null, false) val deserialized = serializerManager.dataDeserializeStream(blockId, new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 21138bd4a16ba..b574e07cc7b50 100644 --- 
a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -165,7 +165,9 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { promise.success(data.retain()) } - }, null) + }, + null, + false) ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b19d8ebf72c61..ec6a5cbcb280d 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1434,7 +1434,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempFileManager: TempFileManager): Unit = { + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } @@ -1461,13 +1462,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockId: String, - tempFileManager: TempFileManager): ManagedBuffer = { + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): ManagedBuffer = { numCalls += 1 this.tempFileManager = tempFileManager if (numCalls <= maxFailures) { throw new RuntimeException("Failing block fetch in the mock block transfer service") } - super.fetchBlockSync(host, port, execId, blockId, tempFileManager) + super.fetchBlockSync(host, port, execId, blockId, tempFileManager, useStreamRequestMessage) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index a2997dbd1b1ac..1e41a35aa4bdd 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -46,7 +46,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. 
*/ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] @@ -111,6 +111,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -140,7 +141,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -159,7 +160,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -189,6 +190,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -227,7 +229,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -257,6 +259,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -297,7 +300,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -327,6 +330,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -337,7 +341,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) 
.thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -391,6 +395,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 2048, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) // Blocks should be returned without exceptions. @@ -415,7 +420,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -445,6 +450,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, false) @@ -478,12 +484,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var tempFileManager: TempFileManager = null - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + var useStreamRequestMessage: Boolean = false + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - tempFileManager = invocation.getArguments()(5).asInstanceOf[TempFileManager] + useStreamRequestMessage = invocation.getArguments()(6).asInstanceOf[Boolean] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) @@ -505,6 +511,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxBytesInFlight = Int.MaxValue, maxReqsInFlight = Int.MaxValue, maxBlocksInFlightPerAddress = Int.MaxValue, + streamRequestMessageEnabled = true, maxReqSizeShuffleToMem = 200, detectCorrupt = true) } @@ -514,14 +521,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. - assert(tempFileManager == null) + assert(!useStreamRequestMessage) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. - assert(tempFileManager != null) + assert(useStreamRequestMessage) } test("fail zero-size blocks") { @@ -551,6 +558,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true)