Skip to content

Commit

Permalink
MockTcpTransport to connect asynchronously (#28203)
Browse files Browse the repository at this point in the history
The method `initiateChannel` on `TcpTransport` is explicit in that
channels can be connect asynchronously. All production implementations
do connect asynchronously. Only the blocking `MockTcpTransport`
connects in a synchronous manner. This avoids testing some of the
blocking code in `TcpTransport` that waits on connections to complete.
Additionally, it requires a more extensive method signature than
required for other transports.

This commit modifies the `MockTcpTransport` to make these connections
asynchronously on a different thread. Additionally, it simplifies that
`initiateChannel` method signature.
  • Loading branch information
Tim-Brooks authored Jan 15, 2018
1 parent 190f1e1 commit ee7eac8
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
Expand All @@ -51,12 +50,10 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.TransportRequestOptions;

Expand Down Expand Up @@ -239,9 +236,8 @@ protected final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
}

@Override
protected NettyTcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> listener)
throws IOException {
ChannelFuture channelFuture = bootstrap.connect(node.getAddress().address());
protected NettyTcpChannel initiateChannel(InetSocketAddress address, ActionListener<Void> listener) throws IOException {
ChannelFuture channelFuture = bootstrap.connect(address);
Channel channel = channelFuture.channel();
if (channel == null) {
Netty4Utils.maybeDie(channelFuture.cause());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.common.util.concurrent.EsExecutors;
Expand Down Expand Up @@ -93,9 +91,8 @@ protected TcpNioServerSocketChannel bind(String name, InetSocketAddress address)
}

@Override
protected TcpNioSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
TcpNioSocketChannel channel = nioGroup.openChannel(node.getAddress().address(), clientChannelFactory);
protected TcpNioSocketChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
TcpNioSocketChannel channel = nioGroup.openChannel(address, clientChannelFactory);
channel.addConnectListener(ActionListener.toBiConsumer(connectListener));
return channel;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c
try {
PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture();
connectionFutures.add(connectFuture);
TcpChannel channel = initiateChannel(node, connectionProfile.getConnectTimeout(), connectFuture);
TcpChannel channel = initiateChannel(node.getAddress().address(), connectFuture);
logger.trace(() -> new ParameterizedMessage("Tcp transport client channel opened: {}", channel));
channels.add(channel);
} catch (Exception e) {
Expand Down Expand Up @@ -1057,17 +1057,14 @@ protected void serverAcceptedChannel(TcpChannel channel) {
protected abstract TcpChannel bind(String name, InetSocketAddress address) throws IOException;

/**
* Initiate a single tcp socket channel to a node. Implementations do not have to observe the connectTimeout.
* It is provided for synchronous connection implementations.
* Initiate a single tcp socket channel.
*
* @param node the node
* @param connectTimeout the connection timeout
* @param connectListener listener to be called when connection complete
* @param address address for the initiated connection
* @param connectListener listener to be called when connection complete
* @return the pending connection
* @throws IOException if an I/O exception occurs while opening the channel
*/
protected abstract TcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException;
protected abstract TcpChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException;

/**
* Called to tear down internal resources
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.compress.CompressorFactory;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
Expand All @@ -41,15 +40,13 @@
import java.io.IOException;
import java.io.StreamCorruptedException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.core.IsInstanceOf.instanceOf;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;

/** Unit tests for {@link TcpTransport} */
public class TcpTransportTests extends ESTestCase {
Expand Down Expand Up @@ -193,8 +190,7 @@ protected FakeChannel bind(String name, InetSocketAddress address) throws IOExce
}

@Override
protected FakeChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
protected FakeChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
return new FakeChannel(messageCaptor);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.lucene.util.IOUtils;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.InputStreamStreamInput;
Expand All @@ -30,7 +29,6 @@
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.CancellableThreads;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
Expand All @@ -49,7 +47,6 @@
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
Expand All @@ -61,7 +58,6 @@
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

/**
* This is a socket based blocking TcpTransport implementation that is used for tests
Expand Down Expand Up @@ -164,28 +160,32 @@ private void readMessage(MockChannel mockChannel, StreamInput input) throws IOEx
}

@Override
protected MockChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
InetSocketAddress address = node.getAddress().address();
protected MockChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
final MockSocket socket = new MockSocket();
final MockChannel channel = new MockChannel(socket, address, "none");

boolean success = false;
try {
configureSocket(socket);
try {
socket.connect(address, Math.toIntExact(connectTimeout.millis()));
} catch (SocketTimeoutException ex) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", ex);
}
MockChannel channel = new MockChannel(socket, address, "none", (c) -> {});
channel.loopRead(executor);
success = true;
connectListener.onResponse(null);
return channel;
} finally {
if (success == false) {
IOUtils.close(socket);
}

}

executor.submit(() -> {
try {
socket.connect(address);
channel.loopRead(executor);
connectListener.onResponse(null);
} catch (Exception ex) {
connectListener.onFailure(ex);
}
});

return channel;
}

@Override
Expand Down Expand Up @@ -218,7 +218,6 @@ public final class MockChannel implements Closeable, TcpChannel {
private final Socket activeChannel;
private final String profile;
private final CancellableThreads cancellableThreads = new CancellableThreads();
private final Closeable onClose;
private final CompletableFuture<Void> closeFuture = new CompletableFuture<>();

/**
Expand All @@ -227,14 +226,12 @@ public final class MockChannel implements Closeable, TcpChannel {
* @param socket The client socket. Mut not be null.
* @param localAddress Address associated with the corresponding local server socket. Must not be null.
* @param profile The associated profile name.
* @param onClose Callback to execute when this channel is closed.
*/
public MockChannel(Socket socket, InetSocketAddress localAddress, String profile, Consumer<MockChannel> onClose) {
public MockChannel(Socket socket, InetSocketAddress localAddress, String profile) {
this.localAddress = localAddress;
this.activeChannel = socket;
this.serverSocket = null;
this.profile = profile;
this.onClose = () -> onClose.accept(this);
synchronized (openChannels) {
openChannels.add(this);
}
Expand All @@ -246,12 +243,11 @@ public MockChannel(Socket socket, InetSocketAddress localAddress, String profile
* @param serverSocket The associated server socket. Must not be null.
* @param profile The associated profile name.
*/
public MockChannel(ServerSocket serverSocket, String profile) {
MockChannel(ServerSocket serverSocket, String profile) {
this.localAddress = (InetSocketAddress) serverSocket.getLocalSocketAddress();
this.serverSocket = serverSocket;
this.profile = profile;
this.activeChannel = null;
this.onClose = null;
synchronized (openChannels) {
openChannels.add(this);
}
Expand All @@ -266,8 +262,19 @@ public void accept(Executor executor) throws IOException {
synchronized (this) {
if (isOpen.get()) {
incomingChannel = new MockChannel(incomingSocket,
new InetSocketAddress(incomingSocket.getLocalAddress(), incomingSocket.getPort()), profile,
workerChannels::remove);
new InetSocketAddress(incomingSocket.getLocalAddress(), incomingSocket.getPort()), profile);
MockChannel finalIncomingChannel = incomingChannel;
incomingChannel.addCloseListener(new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
workerChannels.remove(finalIncomingChannel);
}

@Override
public void onFailure(Exception e) {
workerChannels.remove(finalIncomingChannel);
}
});
serverAcceptedChannel(incomingChannel);
//establish a happens-before edge between closing and accepting a new connection
workerChannels.add(incomingChannel);
Expand All @@ -287,7 +294,7 @@ public void accept(Executor executor) throws IOException {
}
}

public void loopRead(Executor executor) {
void loopRead(Executor executor) {
executor.execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
Expand All @@ -312,7 +319,7 @@ protected void doRun() throws Exception {
});
}

public synchronized void close0() throws IOException {
synchronized void close0() throws IOException {
// establish a happens-before edge between closing and accepting a new connection
// we have to sync this entire block to ensure that our openChannels checks work correctly.
// The close block below will close all worker channels but if one of the worker channels runs into an exception
Expand All @@ -325,7 +332,7 @@ public synchronized void close0() throws IOException {
removedChannel = openChannels.remove(this);
}
IOUtils.close(serverSocket, activeChannel, () -> IOUtils.close(workerChannels),
() -> cancellableThreads.cancel("channel closed"), onClose);
() -> cancellableThreads.cancel("channel closed"));
assert removedChannel: "Channel was not removed or removed twice?";
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
Expand Down Expand Up @@ -83,9 +81,8 @@ protected MockServerChannel bind(String name, InetSocketAddress address) throws
}

@Override
protected MockSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
MockSocketChannel channel = nioGroup.openChannel(node.getAddress().address(), clientChannelFactory);
protected MockSocketChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
MockSocketChannel channel = nioGroup.openChannel(address, clientChannelFactory);
channel.addConnectListener(ActionListener.toBiConsumer(connectListener));
return channel;
}
Expand Down

0 comments on commit ee7eac8

Please sign in to comment.