Skip to content

Commit

Permalink
Add nio transport to security plugin (#31942)
Browse files Browse the repository at this point in the history
This is related to #27260. It adds the SecurityNioTransport to the
security plugin. Additionally, it adds support for ip filtering. And it
randomly uses the nio transport in security integration tests.
  • Loading branch information
Tim-Brooks authored Jul 12, 2018
1 parent 334c255 commit c375d5a
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,19 @@

import java.io.IOException;
import java.util.function.Consumer;
import java.util.function.Predicate;

public class BytesChannelContext extends SocketChannelContext {

public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
ReadWriteHandler handler, InboundChannelBuffer channelBuffer) {
super(channel, selector, exceptionHandler, handler, channelBuffer);
this(channel, selector, exceptionHandler, handler, channelBuffer, ALWAYS_ALLOW_CHANNEL);
}

public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
ReadWriteHandler handler, InboundChannelBuffer channelBuffer,
Predicate<NioSocketChannel> allowChannelPredicate) {
super(channel, selector, exceptionHandler, handler, channelBuffer, allowChannelPredicate);
}

@Override
Expand Down Expand Up @@ -77,7 +84,7 @@ public void closeChannel() {

@Override
public boolean selectorShouldClose() {
return isPeerClosed() || hasIOException() || isClosing.get();
return closeNow() || isClosing.get();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ public abstract class ChannelContext<S extends SelectableChannel & NetworkChanne
}

protected void register() throws IOException {
doSelectorRegister();
}

// Package private for testing
void doSelectorRegister() throws IOException {
setSelectionKey(rawChannel.register(getSelector().rawSelector(), 0));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Predicate;

/**
* This context should implement the specific logic for a channel. When a channel receives a notification
Expand All @@ -43,24 +44,28 @@
*/
public abstract class SocketChannelContext extends ChannelContext<SocketChannel> {

public static final Predicate<NioSocketChannel> ALWAYS_ALLOW_CHANNEL = (c) -> true;

protected final NioSocketChannel channel;
protected final InboundChannelBuffer channelBuffer;
protected final AtomicBoolean isClosing = new AtomicBoolean(false);
private final ReadWriteHandler readWriteHandler;
private final Predicate<NioSocketChannel> allowChannelPredicate;
private final NioSelector selector;
private final CompletableContext<Void> connectContext = new CompletableContext<>();
private final LinkedList<FlushOperation> pendingFlushes = new LinkedList<>();
private boolean ioException;
private boolean peerClosed;
private boolean closeNow;
private Exception connectException;

protected SocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer,
Predicate<NioSocketChannel> allowChannelPredicate) {
super(channel.getRawChannel(), exceptionHandler);
this.selector = selector;
this.channel = channel;
this.readWriteHandler = readWriteHandler;
this.channelBuffer = channelBuffer;
this.allowChannelPredicate = allowChannelPredicate;
}

@Override
Expand Down Expand Up @@ -161,6 +166,14 @@ protected FlushOperation getPendingFlush() {
return pendingFlushes.peekFirst();
}

@Override
protected void register() throws IOException {
super.register();
if (allowChannelPredicate.test(channel) == false) {
closeNow = true;
}
}

@Override
public void closeFromSelector() throws IOException {
getSelector().assertOnSelectorThread();
Expand Down Expand Up @@ -217,24 +230,20 @@ public boolean readyForFlush() {
*/
public abstract boolean selectorShouldClose();

protected boolean hasIOException() {
return ioException;
}

protected boolean isPeerClosed() {
return peerClosed;
protected boolean closeNow() {
return closeNow;
}

protected int readFromChannel(ByteBuffer buffer) throws IOException {
try {
int bytesRead = rawChannel.read(buffer);
if (bytesRead < 0) {
peerClosed = true;
closeNow = true;
bytesRead = 0;
}
return bytesRead;
} catch (IOException e) {
ioException = true;
closeNow = true;
throw e;
}
}
Expand All @@ -243,12 +252,12 @@ protected int readFromChannel(ByteBuffer[] buffers) throws IOException {
try {
int bytesRead = (int) rawChannel.read(buffers);
if (bytesRead < 0) {
peerClosed = true;
closeNow = true;
bytesRead = 0;
}
return bytesRead;
} catch (IOException e) {
ioException = true;
closeNow = true;
throw e;
}
}
Expand All @@ -257,7 +266,7 @@ protected int flushToChannel(ByteBuffer buffer) throws IOException {
try {
return rawChannel.write(buffer);
} catch (IOException e) {
ioException = true;
closeNow = true;
throw e;
}
}
Expand All @@ -266,7 +275,7 @@ protected int flushToChannel(ByteBuffer[] buffers) throws IOException {
try {
return (int) rawChannel.write(buffers);
} catch (IOException e) {
ioException = true;
closeNow = true;
throw e;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.function.Supplier;

import static org.mockito.Matchers.any;
Expand Down Expand Up @@ -77,23 +78,39 @@ public void testIOExceptionSetIfEncountered() throws IOException {
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException());
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(new IOException());
when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException());
assertFalse(context.hasIOException());
assertFalse(context.closeNow());
expectThrows(IOException.class, () -> {
if (randomBoolean()) {
context.read();
} else {
context.flushChannel();
}
});
assertTrue(context.hasIOException());
assertTrue(context.closeNow());
}

public void testSignalWhenPeerClosed() throws IOException {
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenReturn(-1L);
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
assertFalse(context.isPeerClosed());
assertFalse(context.closeNow());
context.read();
assertTrue(context.isPeerClosed());
assertTrue(context.closeNow());
}

public void testValidateInRegisterCanSucceed() throws IOException {
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, (c) -> true);
assertFalse(context.closeNow());
context.register();
assertFalse(context.closeNow());
}

public void testValidateInRegisterCanFail() throws IOException {
InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, (c) -> false);
assertFalse(context.closeNow());
context.register();
assertTrue(context.closeNow());
}

public void testConnectSucceeds() throws IOException {
Expand Down Expand Up @@ -277,7 +294,13 @@ private static class TestSocketChannelContext extends SocketChannelContext {

private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
this(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, ALWAYS_ALLOW_CHANNEL);
}

private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer,
Predicate<NioSocketChannel> allowChannelPredicate) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
}

@Override
Expand Down Expand Up @@ -309,6 +332,11 @@ public boolean selectorShouldClose() {
public void closeChannel() {
isClosing.set(true);
}

@Override
void doSelectorRegister() {
// We do not want to call the actual register with selector method as it will throw a NPE
}
}

private static byte[] createMessage(int length) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
public final class SecurityField {

public static final String NAME4 = XPackField.SECURITY + "4";
public static final String NIO = XPackField.SECURITY + "-nio";
public static final Setting<Optional<String>> USER_SETTING =
new Setting<>(setting("user"), (String) null, Optional::ofNullable, Setting.Property.NodeScope);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ public static Settings addTransportSettings(final Settings settings) {
final Settings.Builder builder = Settings.builder();
if (NetworkModule.TRANSPORT_TYPE_SETTING.exists(settings)) {
final String transportType = NetworkModule.TRANSPORT_TYPE_SETTING.get(settings);
if (SecurityField.NAME4.equals(transportType) == false) {
if (SecurityField.NAME4.equals(transportType) == false && SecurityField.NIO.equals(transportType) == false) {
throw new IllegalArgumentException("transport type setting [" + NetworkModule.TRANSPORT_TYPE_KEY
+ "] must be [" + SecurityField.NAME4 + "] but is [" + transportType + "]");
+ "] must be [" + SecurityField.NAME4 + "] or [" + SecurityField.NIO + "]" + " but is ["
+ transportType + "]");
}
} else {
// default to security4
Expand All @@ -39,7 +40,7 @@ public static Settings addUserSettings(final Settings settings) {
final int i = userSetting.indexOf(":");
if (i < 0 || i == userSetting.length() - 1) {
throw new IllegalArgumentException("invalid [" + SecurityField.USER_SETTING.getKey()
+ "] setting. must be in the form of \"<username>:<password>\"");
+ "] setting. must be in the form of \"<username>:<password>\"");
}
String username = userSetting.substring(0, i);
String password = userSetting.substring(i + 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@
import org.elasticsearch.xpack.security.transport.netty4.SecurityNetty4HttpServerTransport;
import org.elasticsearch.xpack.security.transport.netty4.SecurityNetty4ServerTransport;
import org.elasticsearch.xpack.core.template.TemplateUtils;
import org.elasticsearch.xpack.security.transport.nio.SecurityNioTransport;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;

Expand Down Expand Up @@ -846,8 +847,14 @@ public Map<String, Supplier<Transport>> getTransports(Settings settings, ThreadP
if (transportClientMode || enabled == false) { // don't register anything if we are not enabled, or in transport client mode
return Collections.emptyMap();
}
return Collections.singletonMap(SecurityField.NAME4, () -> new SecurityNetty4ServerTransport(settings, threadPool,
networkService, bigArrays, namedWriteableRegistry, circuitBreakerService, ipFilter.get(), getSslService()));

Map<String, Supplier<Transport>> transports = new HashMap<>();
transports.put(SecurityField.NAME4, () -> new SecurityNetty4ServerTransport(settings, threadPool,
networkService, bigArrays, namedWriteableRegistry, circuitBreakerService, ipFilter.get(), getSslService()));
transports.put(SecurityField.NIO, () -> new SecurityNioTransport(settings, threadPool,
networkService, bigArrays, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService, ipFilter.get(), getSslService()));

return Collections.unmodifiableMap(transports);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.io.IOException;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Predicate;

/**
* Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake
Expand All @@ -30,7 +31,13 @@ public final class SSLChannelContext extends SocketChannelContext {

SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer, ALWAYS_ALLOW_CHANNEL);
}

SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer,
Predicate<NioSocketChannel> allowChannelPredicate) {
super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
this.sslDriver = sslDriver;
}

Expand All @@ -52,7 +59,7 @@ public void queueWriteOperation(WriteOperation writeOperation) {

@Override
public void flushChannel() throws IOException {
if (hasIOException()) {
if (closeNow()) {
return;
}
// If there is currently data in the outbound write buffer, flush the buffer.
Expand Down Expand Up @@ -116,7 +123,7 @@ public boolean readyForFlush() {
@Override
public int read() throws IOException {
int bytesRead = 0;
if (hasIOException()) {
if (closeNow()) {
return bytesRead;
}
bytesRead = readFromChannel(sslDriver.getNetworkReadBuffer());
Expand All @@ -133,7 +140,7 @@ public int read() throws IOException {

@Override
public boolean selectorShouldClose() {
return isPeerClosed() || hasIOException() || sslDriver.isClosed();
return closeNow() || sslDriver.isClosed();
}

@Override
Expand Down
Loading

0 comments on commit c375d5a

Please sign in to comment.