diff --git a/core/src/main/java/com/google/net/stubby/Status.java b/core/src/main/java/com/google/net/stubby/Status.java index d6828975851..d491cf68ab6 100644 --- a/core/src/main/java/com/google/net/stubby/Status.java +++ b/core/src/main/java/com/google/net/stubby/Status.java @@ -14,6 +14,7 @@ public class Status { public static final Status OK = new Status(Transport.Code.OK); + public static final Status CANCELLED = new Status(Transport.Code.CANCELLED); public static Status fromThrowable(Throwable t) { for (Throwable cause : Throwables.getCausalChain(t)) { diff --git a/core/src/main/java/com/google/net/stubby/newtransport/ClientStream.java b/core/src/main/java/com/google/net/stubby/newtransport/ClientStream.java index b7b760ce033..5f9df7570d7 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/ClientStream.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/ClientStream.java @@ -1,16 +1,14 @@ package com.google.net.stubby.newtransport; - /** * Extension of {@link Stream} to support client-side termination semantics. */ public interface ClientStream extends Stream { /** - * Used to abnormally terminate the stream. Any internally buffered messages are dropped. After - * this is called, no further messages may be sent and no further {@link StreamListener} callbacks - * (with the exception of onClosed) will be invoked for this stream. Any frames received for this - * stream after returning from this method will be discarded. + * Used to abnormally terminate the stream. After calling this method, no further messages will be + * sent or received, however it may still be possible to receive buffered messages for a brief + * period until {@link StreamListener#closed} is called. */ void cancel(); } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/ClientTransportFactory.java b/core/src/main/java/com/google/net/stubby/newtransport/ClientTransportFactory.java index 059bab47cca..0a574e06c6b 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/ClientTransportFactory.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/ClientTransportFactory.java @@ -1,5 +1,6 @@ package com.google.net.stubby.newtransport; + /** Pre-configured factory for created {@link ClientTransport} instances. */ public interface ClientTransportFactory { /** Create an unstarted transport for exclusive use. */ diff --git a/core/src/main/java/com/google/net/stubby/newtransport/Deframer.java b/core/src/main/java/com/google/net/stubby/newtransport/Deframer.java index 06b26e50a65..6c7387dc3ef 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/Deframer.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/Deframer.java @@ -5,6 +5,7 @@ import com.google.net.stubby.Operation; import com.google.net.stubby.Status; import com.google.net.stubby.transport.Transport; +import com.google.protobuf.ByteString; import java.io.ByteArrayInputStream; import java.io.DataInputStream; @@ -26,6 +27,7 @@ public abstract class Deframer implements Framer.Sink { private boolean inFrame; private byte currentFlags; private int currentLength = LENGTH_NOT_SET; + private boolean statusDelivered; public Deframer(Framer target) { this.target = target; @@ -34,8 +36,12 @@ public Deframer(Framer target) { @Override public void deliverFrame(F frame, boolean endOfStream) { int remaining = internalDeliverFrame(frame); - if (endOfStream && remaining > 0) { - target.writeStatus(new Status(Transport.Code.UNKNOWN, "EOF on incomplete frame")); + if (endOfStream) { + if (remaining > 0) { + writeStatus(new Status(Transport.Code.UNKNOWN, "EOF on incomplete frame")); + } else if (!statusDelivered) { + writeStatus(Status.OK); + } } } @@ -91,8 +97,8 @@ private int internalDeliverFrame(F frame) { // deal with out-of-order tags etc. Transport.ContextValue contextValue = Transport.ContextValue.parseFrom(framedChunk); try { - target.writeContext(contextValue.getKey(), - contextValue.getValue().newInput(), currentLength); + ByteString value = contextValue.getValue(); + target.writeContext(contextValue.getKey(), value.newInput(), value.size()); } finally { currentLength = LENGTH_NOT_SET; inFrame = false; @@ -104,10 +110,9 @@ private int internalDeliverFrame(F frame) { try { if (code == null) { // Log for unknown code - target.writeStatus( - new Status(Transport.Code.UNKNOWN, "Unknown status code " + status)); + writeStatus(new Status(Transport.Code.UNKNOWN, "Unknown status code " + status)); } else { - target.writeStatus(new Status(code)); + writeStatus(new Status(code)); } } finally { currentLength = LENGTH_NOT_SET; @@ -121,7 +126,7 @@ private int internalDeliverFrame(F frame) { } } catch (IOException ioe) { Status status = new Status(Transport.Code.UNKNOWN, ioe); - target.writeStatus(status); + writeStatus(status); throw status.asRuntimeException(); } } @@ -148,6 +153,11 @@ private boolean ensure(InputStream input, int len) throws IOException { return (input.available() >= len); } + private void writeStatus(Status status) { + target.writeStatus(status); + statusDelivered = true; + } + /** * Return a message of {@code len} bytes than can be read from the buffer. If sufficient * bytes are unavailable then buffer the available bytes and return null. diff --git a/core/src/main/java/com/google/net/stubby/newtransport/Framer.java b/core/src/main/java/com/google/net/stubby/newtransport/Framer.java index 125ae597fd9..c5b273dadec 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/Framer.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/Framer.java @@ -11,12 +11,12 @@ public interface Framer { /** - * Sink implemented by the transport layer to receive frames and forward them to their - * destination + * Sink implemented by the transport layer to receive frames and forward them to their destination */ public interface Sink { /** * Deliver a frame via the transport. + * * @param frame the contents of the frame to deliver * @param endOfStream whether the frame is the last one for the GRPC stream */ @@ -47,12 +47,20 @@ public interface Sink { public void flush(); /** - * Flushes and closes the framer and releases any buffers. + * Indicates whether or not this {@link Framer} has been closed via a call to either + * {@link #close()} or {@link #dispose()}. + */ + public boolean isClosed(); + + /** + * Flushes and closes the framer and releases any buffers. After the {@link Framer} is closed or + * disposed, additional calls to this method will have no affect. */ public void close(); /** - * Closes the framer and releases any buffers, but does not flush. + * Closes the framer and releases any buffers, but does not flush. After the {@link Framer} is + * closed or disposed, additional calls to this method will have no affect. */ public void dispose(); } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/HttpUtil.java b/core/src/main/java/com/google/net/stubby/newtransport/HttpUtil.java new file mode 100644 index 00000000000..25b992c7ba5 --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/HttpUtil.java @@ -0,0 +1,24 @@ +package com.google.net.stubby.newtransport; + +/** + * Constants for GRPC-over-HTTP (or HTTP/2) + */ +public final class HttpUtil { + /** + * The Content-Type header name. Defined here since it is not explicitly defined by the HTTP/2 + * spec. + */ + public static final String CONTENT_TYPE_HEADER = "content-type"; + + /** + * Content-Type used for GRPC-over-HTTP/2. + */ + public static final String CONTENT_TYPE_PROTORPC = "application/protorpc"; + + /** + * The HTTP method used for GRPC requests. + */ + public static final String HTTP_METHOD = "POST"; + + private HttpUtil() {} +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/MessageFramer.java b/core/src/main/java/com/google/net/stubby/newtransport/MessageFramer.java index 7a0502d5b43..87cb153d820 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/MessageFramer.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/MessageFramer.java @@ -67,6 +67,7 @@ public MessageFramer(Sink sink, int maxFrameSize) { * Sets whether compression is encouraged. */ public void setAllowCompression(boolean enable) { + verifyNotClosed(); framer.setAllowCompression(enable); } @@ -76,11 +77,13 @@ public void setAllowCompression(boolean enable) { * @see java.util.zip.Deflater#setLevel */ public void setCompressionLevel(int level) { + verifyNotClosed(); framer.setCompressionLevel(level); } @Override public void writePayload(InputStream message, int messageLength) { + verifyNotClosed(); try { scratch.clear(); scratch.put(GrpcFramingUtil.PAYLOAD_FRAME); @@ -98,6 +101,7 @@ public void writePayload(InputStream message, int messageLength) { @Override public void writeContext(String key, InputStream message, int messageLen) { + verifyNotClosed(); try { scratch.clear(); scratch.put(GrpcFramingUtil.CONTEXT_VALUE_FRAME); @@ -132,6 +136,7 @@ public void writeContext(String key, InputStream message, int messageLen) { @Override public void writeStatus(Status status) { + verifyNotClosed(); short code = (short) status.getCode().getNumber(); scratch.clear(); scratch.put(GrpcFramingUtil.STATUS_FRAME); @@ -144,14 +149,22 @@ public void writeStatus(Status status) { @Override public void flush() { + verifyNotClosed(); framer.flush(); } + @Override + public boolean isClosed() { + return framer == null; + } + @Override public void close() { - // TODO(user): Returning buffer to a pool would go here - framer.close(); - framer = null; + if (!isClosed()) { + // TODO(user): Returning buffer to a pool would go here + framer.close(); + framer = null; + } } @Override @@ -160,6 +173,12 @@ public void dispose() { framer = null; } + private void verifyNotClosed() { + if (isClosed()) { + throw new IllegalStateException("Framer already closed"); + } + } + /** * Write a raw VarInt32 to the buffer */ diff --git a/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java b/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java index e543961d9b0..9c402b81951 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java @@ -1,16 +1,17 @@ package com.google.net.stubby.newtransport; + /** * Utility functions for transport layer framing. * - * Within a given transport frame we reserve the first byte to indicate the - * type of compression used for the contents of the transport frame. + *

Within a given transport frame we reserve the first byte to indicate the type of compression + * used for the contents of the transport frame. */ -public class TransportFrameUtil { +public final class TransportFrameUtil { // Compression modes (lowest order 3 bits of frame flags) public static final byte NO_COMPRESS_FLAG = 0x0; - public static final byte FLATE_FLAG = 0x1; + public static final byte FLATE_FLAG = 0x1; public static final byte COMPRESSION_FLAG_MASK = 0x7; public static boolean isNotCompressed(int b) { @@ -20,4 +21,6 @@ public static boolean isNotCompressed(int b) { public static boolean isFlateCompressed(int b) { return ((b & COMPRESSION_FLAG_MASK) == FLATE_FLAG); } + + private TransportFrameUtil() {} } diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/ByteBufDeframer.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/ByteBufDeframer.java new file mode 100644 index 00000000000..84944ab3f19 --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/ByteBufDeframer.java @@ -0,0 +1,71 @@ +package com.google.net.stubby.newtransport.netty; + +import com.google.net.stubby.newtransport.Deframer; +import com.google.net.stubby.newtransport.Framer; +import com.google.net.stubby.newtransport.TransportFrameUtil; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; + +import java.io.DataInputStream; +import java.io.IOException; +import java.nio.ByteOrder; + +/** + * Parse a sequence of {@link ByteBuf} instances that represent the frames of a GRPC call + */ +public class ByteBufDeframer extends Deframer { + + private final CompositeByteBuf buffer; + + public ByteBufDeframer(Framer target) { + this(UnpooledByteBufAllocator.DEFAULT, target); + } + + public ByteBufDeframer(ByteBufAllocator alloc, Framer target) { + super(target); + buffer = alloc.compositeBuffer(); + } + + public void dispose() { + // Remove the components from the composite buffer. This should set the reference + // count on all buffers to zero. + buffer.removeComponents(0, buffer.numComponents()); + + // Release the composite buffer + buffer.release(); + } + + @Override + protected DataInputStream prefix(ByteBuf frame) throws IOException { + buffer.addComponent(frame); + buffer.writerIndex(buffer.writerIndex() + frame.writerIndex() - frame.readerIndex()); + return new DataInputStream(new ByteBufInputStream(buffer)); + } + + @Override + protected int consolidate() { + buffer.consolidate(); + return buffer.readableBytes(); + } + + @Override + protected ByteBuf decompress(ByteBuf frame) throws IOException { + frame = frame.order(ByteOrder.BIG_ENDIAN); + int compressionType = frame.readUnsignedByte(); + int frameLength = frame.readUnsignedMedium(); + if (frameLength != frame.readableBytes()) { + throw new IllegalArgumentException("GRPC and buffer lengths misaligned. Frame length=" + + frameLength + ", readableBytes=" + frame.readableBytes()); + } + if (TransportFrameUtil.isNotCompressed(compressionType)) { + // Need to retain the frame as we may be holding it over channel events + frame.retain(); + return frame; + } + throw new IOException("Unknown compression type " + compressionType); + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/CancelStreamCommand.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/CancelStreamCommand.java new file mode 100644 index 00000000000..ae42c75c9f5 --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/CancelStreamCommand.java @@ -0,0 +1,18 @@ +package com.google.net.stubby.newtransport.netty; + +import com.google.common.base.Preconditions; + +/** + * Command sent from a Netty client stream to the handler to cancel the stream. + */ +class CancelStreamCommand { + private final NettyClientStream stream; + + CancelStreamCommand(NettyClientStream stream) { + this.stream = Preconditions.checkNotNull(stream, "stream"); + } + + NettyClientStream stream() { + return stream; + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/CreateStreamCommand.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/CreateStreamCommand.java new file mode 100644 index 00000000000..2107796df62 --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/CreateStreamCommand.java @@ -0,0 +1,26 @@ +package com.google.net.stubby.newtransport.netty; + +import com.google.common.base.Preconditions; +import com.google.net.stubby.MethodDescriptor; + +/** + * A command to create a new stream. This is created by {@link NettyClientStream} and passed to the + * {@link NettyClientHandler} for processing in the Channel thread. + */ +class CreateStreamCommand { + final MethodDescriptor method; + final NettyClientStream stream; + + CreateStreamCommand(MethodDescriptor method, NettyClientStream stream) { + this.method = Preconditions.checkNotNull(method, "method"); + this.stream = Preconditions.checkNotNull(stream, "stream"); + } + + MethodDescriptor method() { + return method; + } + + NettyClientStream stream() { + return stream; + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClient.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClient.java deleted file mode 100644 index 017a6c4bcec..00000000000 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClient.java +++ /dev/null @@ -1,81 +0,0 @@ -package com.google.net.stubby.newtransport.netty; - -import static io.netty.channel.ChannelOption.SO_KEEPALIVE; - -import com.google.common.util.concurrent.AbstractService; - -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.socket.SocketChannel; -import io.netty.channel.socket.nio.NioSocketChannel; - -/** - * Implementation of the {@link com.google.common.util.concurrent.Service} interface for a - * Netty-based client. - */ -public class NettyClient extends AbstractService { - private final String host; - private final int port; - private final ChannelInitializer channelInitializer; - private Channel channel; - private EventLoopGroup eventGroup; - - public NettyClient(String host, int port, ChannelInitializer channelInitializer) { - this.host = host; - this.port = port; - this.channelInitializer = channelInitializer; - } - - public Channel channel() { - return channel; - } - - @Override - protected void doStart() { - eventGroup = new NioEventLoopGroup(); - - Bootstrap b = new Bootstrap(); - b.group(eventGroup); - b.channel(NioSocketChannel.class); - b.option(SO_KEEPALIVE, true); - b.handler(channelInitializer); - - // Start the connection operation to the server. - b.connect(host, port).addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (future.isSuccess()) { - channel = future.channel(); - notifyStarted(); - } else { - notifyFailed(future.cause()); - } - } - }); - } - - @Override - protected void doStop() { - if (channel != null && channel.isOpen()) { - channel.close().addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (future.isSuccess()) { - notifyStopped(); - } else { - notifyFailed(future.cause()); - } - } - }); - } - - if (eventGroup != null) { - eventGroup.shutdownGracefully(); - } - } -} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java new file mode 100644 index 00000000000..2c565e2de02 --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientHandler.java @@ -0,0 +1,407 @@ +package com.google.net.stubby.newtransport.netty; + +import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_HEADER; +import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_PROTORPC; +import static com.google.net.stubby.newtransport.HttpUtil.HTTP_METHOD; +import static com.google.net.stubby.newtransport.netty.NettyClientStream.PENDING_STREAM_ID; + +import com.google.common.base.Preconditions; +import com.google.net.stubby.MethodDescriptor; +import com.google.net.stubby.Status; +import com.google.net.stubby.transport.Transport; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http2.AbstractHttp2ConnectionHandler; +import io.netty.handler.codec.http2.DefaultHttp2Connection; +import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Connection; +import io.netty.handler.codec.http2.Http2ConnectionAdapter; +import io.netty.handler.codec.http2.Http2Error; +import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.handler.codec.http2.Http2Stream; +import io.netty.handler.codec.http2.Http2StreamException; +import io.netty.handler.codec.http2.Http2StreamRemovalPolicy; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; + +/** + * Client-side Netty handler for GRPC processing. All event handlers are executed entirely within + * the context of the Netty Channel thread. + */ +class NettyClientHandler extends AbstractHttp2ConnectionHandler { + private static final Status GOAWAY_STATUS = new Status(Transport.Code.UNAVAILABLE); + + /** + * A pending stream creation. + */ + private final class PendingStream { + private final MethodDescriptor method; + private final NettyClientStream stream; + private final ChannelPromise promise; + + public PendingStream(CreateStreamCommand command, ChannelPromise promise) { + method = command.method(); + stream = command.stream(); + this.promise = promise; + } + } + + private final String host; + private final String scheme; + private final Deque pendingStreams = new ArrayDeque(); + private Status goAwayStatus = GOAWAY_STATUS; + + public NettyClientHandler(String host, boolean ssl, + Http2StreamRemovalPolicy streamRemovalPolicy) { + this(host, ssl, new DefaultHttp2Connection(false, false, streamRemovalPolicy)); + } + + private NettyClientHandler(String host, boolean ssl, Http2Connection connection) { + super(connection); + this.host = Preconditions.checkNotNull(host, "host"); + this.scheme = ssl ? "https" : "http"; + + // Disallow stream creation by the server. + connection.remote().maxStreams(0); + + // Observe the HTTP/2 connection for events. + connection.addListener(new Http2ConnectionAdapter() { + @Override + public void streamHalfClosed(Http2Stream stream) { + // Check for disallowed state: HALF_CLOSED_REMOTE. + terminateIfInvalidState(stream); + } + + @Override + public void streamInactive(Http2Stream stream) { + // Whenever a stream has been closed, try to create a pending stream to fill its place. + createPendingStreams(); + } + + @Override + public void goingAway() { + NettyClientHandler.this.goingAway(); + } + }); + } + + /** + * Handler for commands sent from the stream. + */ + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + try { + if (msg instanceof CreateStreamCommand) { + createStream((CreateStreamCommand) msg, promise); + } else if (msg instanceof SendGrpcFrameCommand) { + sendGrpcFrame(ctx, (SendGrpcFrameCommand) msg, promise); + } else if (msg instanceof CancelStreamCommand) { + cancelStream(ctx, (CancelStreamCommand) msg, promise); + } else { + throw new AssertionError("Write called for unexpected type: " + msg.getClass().getName()); + } + } catch (Throwable t) { + promise.setFailure(t); + } + } + + /** + * Handler for an inbound HTTP/2 DATA frame. + */ + @Override + public void onDataRead(ChannelHandlerContext ctx, + int streamId, + ByteBuf data, + int padding, + boolean endOfStream, + boolean endOfSegment, + boolean compressed) throws Http2Exception { + NettyClientStream stream = clientStream(connection().requireStream(streamId)); + + // TODO(user): update flow controller to use a promise. + stream.inboundDataReceived(data, endOfStream, ctx.newPromise()); + } + + /** + * Handler for an inbound HTTP/2 RST_STREAM frame, terminating a stream. + */ + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) + throws Http2Exception { + // TODO(user): do something with errorCode? + Http2Stream http2Stream = connection().requireStream(streamId); + NettyClientStream stream = clientStream(http2Stream); + stream.setStatus(new Status(Transport.Code.UNKNOWN)); + } + + /** + * Handler for the Channel shutting down. + */ + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + super.channelInactive(ctx); + + // Fail any streams that are awaiting creation. + failPendingStreams(goAwayStatus); + + // Any streams that are still active must be closed. + for (Http2Stream stream : http2Streams()) { + clientStream(stream).setStatus(goAwayStatus); + } + } + + /** + * Handler for connection errors that have occurred during HTTP/2 frame processing. + */ + @Override + protected void onConnectionError(ChannelHandlerContext ctx, Http2Exception cause) { + // Save the exception that is causing us to send a GO_AWAY. + goAwayStatus = Status.fromThrowable(cause); + + // Call the base class to send the GOAWAY. This will call the goingAway handler. + super.onConnectionError(ctx, cause); + } + + /** + * Handler for stream errors that have occurred during HTTP/2 frame processing. + */ + @Override + protected void onStreamError(ChannelHandlerContext ctx, Http2StreamException cause) { + // Close the stream with a status that contains the cause. + Http2Stream stream = connection().stream(cause.streamId()); + if (stream != null) { + clientStream(stream).setStatus(Status.fromThrowable(cause)); + } + super.onStreamError(ctx, cause); + } + + /** + * Attempts to create a new stream from the given command. If there are too many active streams, + * the creation request is queued. + */ + private void createStream(CreateStreamCommand command, ChannelPromise promise) { + // Add the creation request to the queue. + pendingStreams.addLast(new PendingStream(command, promise)); + + // Process the pending streams queue. + createPendingStreams(); + } + + /** + * Cancels this stream. + */ + private void cancelStream(ChannelHandlerContext ctx, CancelStreamCommand cmd, + ChannelPromise promise) throws Http2Exception { + NettyClientStream stream = cmd.stream(); + stream.setStatus(Status.CANCELLED); + + // No need to set the stream status for a cancellation. It should already have been + // set prior to sending the command. + + // If the stream hasn't been created yet, remove it from the pending queue. + if (stream.id() == PENDING_STREAM_ID) { + removePendingStream(stream); + promise.setSuccess(); + return; + } + + // Send a RST_STREAM frame to terminate this stream. + Http2Stream http2Stream = connection().requireStream(stream.id()); + if (http2Stream.state() != Http2Stream.State.CLOSED) { + writeRstStream(ctx, promise, stream.id(), Http2Error.CANCEL.code()); + } + } + + /** + * Sends the given GRPC frame for the stream. + */ + private void sendGrpcFrame(ChannelHandlerContext ctx, SendGrpcFrameCommand cmd, + ChannelPromise promise) throws Http2Exception { + NettyClientStream stream = cmd.stream(); + Http2Stream http2Stream = connection().requireStream(stream.id()); + switch (http2Stream.state()) { + case CLOSED: + case HALF_CLOSED_LOCAL: + case IDLE: + case RESERVED_LOCAL: + case RESERVED_REMOTE: + cmd.release(); + promise.setFailure(new Exception("Closed before write could occur")); + return; + default: + break; + } + + // Call the base class to write the HTTP/2 DATA frame. + writeData(ctx, + promise, + stream.id(), + cmd.content(), + 0, + cmd.endStream(), + cmd.endSegment(), + false); + } + + /** + * Handler for a GOAWAY being either sent or received. + */ + private void goingAway() { + // Fail any streams that are awaiting creation. + failPendingStreams(goAwayStatus); + + if (connection().local().isGoAwayReceived()) { + // Received a GOAWAY from the remote endpoint. Fail any streams that were created after the + // last known stream. + int lastKnownStream = connection().local().lastKnownStream(); + for (Http2Stream stream : http2Streams()) { + if (lastKnownStream < stream.id()) { + clientStream(stream).setStatus(goAwayStatus); + stream.close(); + } + } + } + } + + /** + * Processes the pending stream creation requests. This considers several conditions: + * + *

+ * 1) The HTTP/2 connection has exhausted its stream IDs. In this case all pending streams are + * immediately failed. + *

+ * 2) The HTTP/2 connection is going away. In this case all pending streams are immediately + * failed. + *

+ * 3) The HTTP/2 connection's MAX_CONCURRENT_STREAMS limit has been reached. In this case, + * processing of pending streams stops until an active stream has been closed. + */ + private void createPendingStreams() { + Http2Connection connection = connection(); + Http2Connection.Endpoint local = connection.local(); + while (!pendingStreams.isEmpty()) { + final int streamId = local.nextStreamId(); + if (streamId <= 0) { + // The HTTP/2 connection has exhausted its stream IDs. Permanently fail all stream creation + // attempts for this transport. + // TODO(user): send GO_AWAY? + failPendingStreams(goAwayStatus); + return; + } + + if (connection.isGoAway()) { + failPendingStreams(goAwayStatus); + return; + } + + if (!local.acceptingNewStreams()) { + // We're bumping up against the MAX_CONCURRENT_STEAMS threshold for this endpoint. Need to + // wait until the endpoint is accepting new streams. + return; + } + + // Finish creation of the stream by writing a headers frame. + final PendingStream pendingStream = pendingStreams.remove(); + // TODO(user): Change Netty to not send priority, just use default. + Http2Headers headers = DefaultHttp2Headers + .newBuilder() + .method(HTTP_METHOD) + .authority(host) + .scheme(scheme) + .add(CONTENT_TYPE_HEADER, CONTENT_TYPE_PROTORPC) + .path("/" + pendingStream.method.getName()) + .build(); + writeHeaders(ctx(), ctx().newPromise(), streamId, headers, 0, false, false).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + streamCreated(pendingStream.stream, streamId, pendingStream.promise); + } else { + // Fail the creation request. + pendingStream.promise.setFailure(future.cause()); + } + } + }); + } + } + + /** + * Handles the successful creation of a new stream. + */ + private void streamCreated(NettyClientStream stream, int streamId, ChannelPromise promise) + throws Http2Exception { + // Attach the client stream to the HTTP/2 stream object as user data. + Http2Stream http2Stream = connection().requireStream(streamId); + http2Stream.data(stream); + + // Notify the stream that it has been created. + stream.id(streamId); + promise.setSuccess(); + } + + /** + * Gets the client stream associated to the given HTTP/2 stream object. + */ + private NettyClientStream clientStream(Http2Stream stream) { + return stream.data(); + } + + /** + * Fails all pending streams with the given status and clears the queue. + */ + private void failPendingStreams(Status status) { + while (!pendingStreams.isEmpty()) { + PendingStream pending = pendingStreams.remove(); + pending.promise.setFailure(status.asException()); + } + } + + /** + * Removes the given stream from the pending queue + * + * @param stream the stream to be removed. + */ + private void removePendingStream(NettyClientStream stream) { + for (Iterator iter = pendingStreams.iterator(); iter.hasNext();) { + PendingStream pending = iter.next(); + if (pending.stream == stream) { + iter.remove(); + return; + } + } + } + + /** + * Gets a copy of the streams currently in the connection. + */ + private Http2Stream[] http2Streams() { + return connection().activeStreams().toArray(new Http2Stream[0]); + } + + /** + * Terminates the stream if it's in an unsupported state. + */ + private void terminateIfInvalidState(Http2Stream stream) { + switch (stream.state()) { + case HALF_CLOSED_REMOTE: + case IDLE: + case RESERVED_LOCAL: + case RESERVED_REMOTE: + // Disallowed state, terminate the stream. + clientStream(stream).setStatus( + new Status(Transport.Code.INTERNAL, "Stream in invalid state: " + stream.state())); + writeRstStream(ctx(), ctx().newPromise(), stream.id(), Http2Error.INTERNAL_ERROR.code()); + break; + default: + break; + } + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java new file mode 100644 index 00000000000..8881ad08518 --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientStream.java @@ -0,0 +1,327 @@ +package com.google.net.stubby.newtransport.netty; + +import static com.google.net.stubby.newtransport.StreamState.CLOSED; +import static com.google.net.stubby.newtransport.StreamState.OPEN; +import static com.google.net.stubby.newtransport.StreamState.READ_ONLY; + +import com.google.common.base.Preconditions; +import com.google.common.io.Closeables; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.net.stubby.Status; +import com.google.net.stubby.newtransport.ClientStream; +import com.google.net.stubby.newtransport.Deframer; +import com.google.net.stubby.newtransport.Framer; +import com.google.net.stubby.newtransport.MessageFramer; +import com.google.net.stubby.newtransport.StreamListener; +import com.google.net.stubby.newtransport.StreamState; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelPromise; + +import java.io.InputStream; +import java.nio.ByteBuffer; + +import javax.annotation.Nullable; + +/** + * Client stream for a Netty transport. + */ +class NettyClientStream implements ClientStream { + public static final int PENDING_STREAM_ID = -1; + + /** + * Indicates the phase of the GRPC stream in one direction. + */ + private enum Phase { + CONTEXT, MESSAGE, STATUS + } + + /** + * Guards transition of stream state. + */ + private final Object stateLock = new Object(); + + /** + * Guards access to the frame writer. + */ + private final Object writeLock = new Object(); + + private volatile StreamState state = OPEN; + private volatile int id = PENDING_STREAM_ID; + private Status status; + private Phase inboundPhase = Phase.CONTEXT; + private Phase outboundPhase = Phase.CONTEXT; + private final StreamListener listener; + private final Channel channel; + private final Framer framer; + private final Deframer deframer; + + private final Framer.Sink outboundFrameHandler = new Framer.Sink() { + @Override + public void deliverFrame(ByteBuffer buffer, boolean endStream) { + ByteBuf buf = toByteBuf(buffer); + send(buf, endStream, endStream); + } + }; + + private final Framer inboundMessageHandler = new Framer() { + @Override + public void writeContext(String name, InputStream value, int length) { + ListenableFuture future = null; + try { + inboundPhase(Phase.CONTEXT); + future = listener.contextRead(name, value, length); + } finally { + closeWhenDone(future, value); + } + } + + @Override + public void writePayload(InputStream input, int length) { + ListenableFuture future = null; + try { + inboundPhase(Phase.MESSAGE); + future = listener.messageRead(input, length); + } finally { + closeWhenDone(future, input); + } + } + + @Override + public void writeStatus(Status status) { + inboundPhase(Phase.STATUS); + setStatus(status); + } + + @Override + public void flush() {} + + @Override + public boolean isClosed() { + return false; + } + + @Override + public void close() {} + + @Override + public void dispose() {} + }; + + NettyClientStream(StreamListener listener, Channel channel) { + this.listener = Preconditions.checkNotNull(listener, "listener"); + this.channel = Preconditions.checkNotNull(channel, "channel"); + this.deframer = new ByteBufDeframer(channel.alloc(), inboundMessageHandler); + this.framer = new MessageFramer(outboundFrameHandler, 4096); + } + + /** + * Returns the HTTP/2 ID for this stream. + */ + public int id() { + return id; + } + + void id(int id) { + this.id = id; + } + + @Override + public StreamState state() { + return state; + } + + @Override + public void close() { + outboundPhase(Phase.STATUS); + // Transition the state to mark the close the local side of the stream. + synchronized (stateLock) { + state = state == OPEN ? READ_ONLY : CLOSED; + } + + // Close the frame writer and send any buffered frames. + synchronized (writeLock) { + framer.close(); + } + } + + @Override + public void cancel() { + outboundPhase = Phase.STATUS; + + // Send the cancel command to the handler. + channel.writeAndFlush(new CancelStreamCommand(this)); + } + + /** + * Free any resources associated with this stream. + */ + public void dispose() { + synchronized (writeLock) { + framer.dispose(); + } + } + + @Override + public void writeContext(String name, InputStream value, int length, + @Nullable final Runnable accepted) { + Preconditions.checkNotNull(name, "name"); + Preconditions.checkNotNull(value, "value"); + Preconditions.checkArgument(length >= 0, "length must be >= 0"); + outboundPhase(Phase.CONTEXT); + synchronized (writeLock) { + if (!framer.isClosed()) { + framer.writeContext(name, value, length); + } + } + + // TODO(user): add flow control. + if (accepted != null) { + accepted.run(); + } + } + + @Override + public void writeMessage(InputStream message, int length, @Nullable final Runnable accepted) { + Preconditions.checkNotNull(message, "message"); + Preconditions.checkArgument(length >= 0, "length must be >= 0"); + outboundPhase(Phase.MESSAGE); + synchronized (writeLock) { + if (!framer.isClosed()) { + framer.writePayload(message, length); + } + } + + // TODO(user): add flow control. + if (accepted != null) { + accepted.run(); + } + } + + @Override + public void flush() { + synchronized (writeLock) { + if (!framer.isClosed()) { + framer.flush(); + } + } + } + + /** + * Called in the channel thread to process the content of an inbound DATA frame. + * + * @param frame the inbound HTTP/2 DATA frame. If this buffer is not used immediately, it must be + * retained. + * @param promise the promise to be set after the application has finished processing the frame. + */ + public void inboundDataReceived(ByteBuf frame, boolean endOfStream, ChannelPromise promise) { + Preconditions.checkNotNull(frame, "frame"); + Preconditions.checkNotNull(promise, "promise"); + if (state == CLOSED) { + promise.setSuccess(); + return; + } + + // Retain the ByteBuf until it is released by the deframer. + deframer.deliverFrame(frame.retain(), endOfStream); + + // TODO(user): add flow control. + promise.setSuccess(); + } + + /** + * Sets the status if not already set and notifies the stream listener that the stream was closed. + * This method must be called from the Netty channel thread. + * + * @param newStatus the new status to set + * @return {@code} true if the status was not already set. + */ + public boolean setStatus(final Status newStatus) { + Preconditions.checkNotNull(newStatus, "newStatus"); + synchronized (stateLock) { + if (status != null) { + // Disallow override of current status. + return false; + } + + status = newStatus; + state = CLOSED; + } + + // Invoke the observer callback. + listener.closed(newStatus); + + // Free any resources. + dispose(); + + return true; + } + + /** + * Writes the given frame to the channel. + * + * @param data the grpc frame to be written. + * @param endStream indicates whether this is the last frame to be sent for this stream. + * @param endMessage indicates whether the data ends at a message boundary. + */ + private void send(ByteBuf data, boolean endStream, boolean endMessage) { + SendGrpcFrameCommand frame = new SendGrpcFrameCommand(this, data, endStream, endMessage); + channel.writeAndFlush(frame); + } + + /** + * Copies the content of the given {@link ByteBuffer} to a new {@link ByteBuf} instance. + */ + private ByteBuf toByteBuf(ByteBuffer source) { + ByteBuf buf = channel.alloc().buffer(source.remaining()); + buf.writeBytes(source); + return buf; + } + + /** + * Transitions the inbound phase. If the transition is disallowed, throws a + * {@link IllegalStateException}. + */ + private void inboundPhase(Phase nextPhase) { + inboundPhase = verifyNextPhase(inboundPhase, nextPhase); + } + + /** + * Transitions the outbound phase. If the transition is disallowed, throws a + * {@link IllegalStateException}. + */ + private void outboundPhase(Phase nextPhase) { + outboundPhase = verifyNextPhase(outboundPhase, nextPhase); + } + + private Phase verifyNextPhase(Phase currentPhase, Phase nextPhase) { + // Only allow forward progression. + if (nextPhase.ordinal() < currentPhase.ordinal() || currentPhase == Phase.STATUS) { + throw new IllegalStateException( + String.format("Cannot transition phase from %s to %s", currentPhase, nextPhase)); + } + return nextPhase; + } + + /** + * If the given future is provided, closes the {@link InputStream} when it completes. Otherwise + * the {@link InputStream} is closed immediately. + */ + private static void closeWhenDone(@Nullable ListenableFuture future, + final InputStream input) { + if (future == null) { + Closeables.closeQuietly(input); + return; + } + + // Close the buffer when the future completes. + future.addListener(new Runnable() { + @Override + public void run() { + Closeables.closeQuietly(input); + } + }, MoreExecutors.sameThreadExecutor()); + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransport.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransport.java new file mode 100644 index 00000000000..28a652d9e0f --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransport.java @@ -0,0 +1,142 @@ +package com.google.net.stubby.newtransport.netty; + +import static io.netty.channel.ChannelOption.SO_KEEPALIVE; + +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.AbstractService; +import com.google.net.stubby.MethodDescriptor; +import com.google.net.stubby.newtransport.ClientStream; +import com.google.net.stubby.newtransport.ClientTransport; +import com.google.net.stubby.newtransport.StreamListener; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http2.DefaultHttp2StreamRemovalPolicy; + +import java.util.concurrent.ExecutionException; + +/** + * A Netty-based {@link ClientTransport} implementation. + */ +class NettyClientTransport extends AbstractService implements ClientTransport { + + private final String host; + private final int port; + private final EventLoopGroup eventGroup; + private final ChannelInitializer channelInitializer; + private Channel channel; + + NettyClientTransport(String host, int port, boolean ssl) { + this(host, port, ssl, new NioEventLoopGroup()); + } + + NettyClientTransport(String host, int port, boolean ssl, EventLoopGroup eventGroup) { + Preconditions.checkNotNull(host, "host"); + Preconditions.checkArgument(port >= 0, "port must be positive"); + Preconditions.checkNotNull(eventGroup, "eventGroup"); + this.host = host; + this.port = port; + this.eventGroup = eventGroup; + final DefaultHttp2StreamRemovalPolicy streamRemovalPolicy = + new DefaultHttp2StreamRemovalPolicy(); + final NettyClientHandler handler = new NettyClientHandler(host, ssl, streamRemovalPolicy); + // TODO(user): handle SSL. + channelInitializer = new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline().addLast(streamRemovalPolicy); + ch.pipeline().addLast(handler); + } + }; + } + + @Override + public ClientStream newStream(MethodDescriptor method, StreamListener listener) { + Preconditions.checkNotNull(method, "method"); + Preconditions.checkNotNull(listener, "listener"); + switch (state()) { + case STARTING: + // Wait until the transport is running before creating the new stream. + awaitRunning(); + break; + case NEW: + case TERMINATED: + case FAILED: + throw new IllegalStateException("Unable to create new stream in state: " + state()); + default: + break; + } + + // Create the stream. + NettyClientStream stream = new NettyClientStream(listener, channel); + + try { + // Write the request and await creation of the stream. + channel.writeAndFlush(new CreateStreamCommand(method, stream)).get(); + } catch (InterruptedException e) { + // Restore the interrupt. + Thread.currentThread().interrupt(); + stream.dispose(); + throw new RuntimeException(e); + } catch (ExecutionException e) { + stream.dispose(); + throw new RuntimeException(e); + } + + return stream; + } + + @Override + protected void doStart() { + Bootstrap b = new Bootstrap(); + b.group(eventGroup); + b.channel(NioSocketChannel.class); + b.option(SO_KEEPALIVE, true); + b.handler(channelInitializer); + + // Start the connection operation to the server. + b.connect(host, port).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + channel = future.channel(); + notifyStarted(); + + // Listen for the channel close event. + channel.closeFuture().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + notifyStopped(); + } else { + notifyFailed(future.cause()); + } + } + }); + } else { + notifyFailed(future.cause()); + } + } + }); + } + + @Override + protected void doStop() { + // No explicit call to notifyStopped() here, since this is automatically done when the + // channel closes. + if (channel != null && channel.isOpen()) { + channel.close(); + } + + if (eventGroup != null) { + eventGroup.shutdownGracefully(); + } + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransportFactory.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransportFactory.java new file mode 100644 index 00000000000..3a3a30a7219 --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyClientTransportFactory.java @@ -0,0 +1,30 @@ +package com.google.net.stubby.newtransport.netty; + +import com.google.common.base.Preconditions; +import com.google.net.stubby.newtransport.ClientTransportFactory; + +import io.netty.channel.EventLoopGroup; + +/** + * Factory that manufactures instances of {@link NettyClientTransport}. + */ +public class NettyClientTransportFactory implements ClientTransportFactory { + + private final String host; + private final int port; + private final boolean ssl; + private final EventLoopGroup group; + + public NettyClientTransportFactory(String host, int port, boolean ssl, EventLoopGroup group) { + this.group = Preconditions.checkNotNull(group, "group"); + Preconditions.checkArgument(port > 0, "Port must be positive"); + this.host = Preconditions.checkNotNull(host, "host"); + this.port = port; + this.ssl = ssl; + } + + @Override + public NettyClientTransport newClientTransport() { + return new NettyClientTransport(host, port, ssl, group); + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServer.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServer.java index 048235a2139..630cdd787fa 100644 --- a/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServer.java +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/NettyServer.java @@ -3,6 +3,7 @@ import static io.netty.channel.ChannelOption.SO_BACKLOG; import static io.netty.channel.ChannelOption.SO_KEEPALIVE; +import com.google.common.base.Preconditions; import com.google.common.util.concurrent.AbstractService; import io.netty.bootstrap.ServerBootstrap; @@ -22,20 +23,28 @@ public class NettyServer extends AbstractService { private final int port; private final ChannelInitializer channelInitializer; + private final EventLoopGroup bossGroup; + private final EventLoopGroup workerGroup; private Channel channel; - private EventLoopGroup bossGroup; - private EventLoopGroup workerGroup; public NettyServer(int port, ChannelInitializer channelInitializer) { + this(port, channelInitializer, new NioEventLoopGroup(), new NioEventLoopGroup()); + } + + public NettyServer(int port, ChannelInitializer channelInitializer, + EventLoopGroup bossGroup, EventLoopGroup workerGroup) { + Preconditions.checkNotNull(channelInitializer, "channelInitializer"); + Preconditions.checkNotNull(bossGroup, "bossGroup"); + Preconditions.checkNotNull(workerGroup, "workerGroup"); + Preconditions.checkArgument(port >= 0, "port must be positive"); this.port = port; this.channelInitializer = channelInitializer; + this.bossGroup = bossGroup; + this.workerGroup = workerGroup; } @Override protected void doStart() { - bossGroup = new NioEventLoopGroup(); - workerGroup = new NioEventLoopGroup(); - ServerBootstrap b = new ServerBootstrap(); b.group(bossGroup, workerGroup); b.channel(NioServerSocketChannel.class); diff --git a/core/src/main/java/com/google/net/stubby/newtransport/netty/SendGrpcFrameCommand.java b/core/src/main/java/com/google/net/stubby/newtransport/netty/SendGrpcFrameCommand.java new file mode 100644 index 00000000000..d61f8624250 --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/netty/SendGrpcFrameCommand.java @@ -0,0 +1,70 @@ +package com.google.net.stubby.newtransport.netty; + +import com.google.common.base.Preconditions; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.buffer.DefaultByteBufHolder; + +/** + * Command sent from the transport to the Netty channel to send a GRPC frame to the remote endpoint. + */ +class SendGrpcFrameCommand extends DefaultByteBufHolder { + private final NettyClientStream stream; + private final boolean endStream; + private final boolean endSegment; + + SendGrpcFrameCommand(NettyClientStream stream, ByteBuf content, boolean endStream, + boolean endSegment) { + super(content); + this.stream = Preconditions.checkNotNull(stream, "stream"); + this.endStream = endStream; + this.endSegment = endSegment; + } + + NettyClientStream stream() { + return stream; + } + + boolean endStream() { + return endStream; + } + + boolean endSegment() { + return endSegment; + } + + @Override + public ByteBufHolder copy() { + return new SendGrpcFrameCommand(stream, content().copy(), endStream, endSegment); + } + + @Override + public ByteBufHolder duplicate() { + return new SendGrpcFrameCommand(stream, content().duplicate(), endStream, endSegment); + } + + @Override + public SendGrpcFrameCommand retain() { + super.retain(); + return this; + } + + @Override + public SendGrpcFrameCommand retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public SendGrpcFrameCommand touch() { + super.touch(); + return this; + } + + @Override + public SendGrpcFrameCommand touch(Object hint) { + super.touch(hint); + return this; + } +} diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientHandlerTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientHandlerTest.java new file mode 100644 index 00000000000..98279da2479 --- /dev/null +++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientHandlerTest.java @@ -0,0 +1,314 @@ +package com.google.net.stubby.newtransport.netty; + +import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_HEADER; +import static com.google.net.stubby.newtransport.HttpUtil.CONTENT_TYPE_PROTORPC; +import static com.google.net.stubby.newtransport.HttpUtil.HTTP_METHOD; +import static io.netty.handler.codec.http2.Http2CodecUtil.immediateRemovalPolicy; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.calls; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.net.stubby.MethodDescriptor; +import com.google.net.stubby.Status; +import com.google.net.stubby.newtransport.StreamState; +import com.google.net.stubby.transport.Transport; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http2.DefaultHttp2FrameReader; +import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; +import io.netty.handler.codec.http2.Http2CodecUtil; +import io.netty.handler.codec.http2.Http2Error; +import io.netty.handler.codec.http2.Http2FrameObserver; +import io.netty.handler.codec.http2.Http2FrameReader; +import io.netty.handler.codec.http2.Http2FrameWriter; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.handler.codec.http2.Http2Settings; + +/** + * Tests for {@link NettyClientHandler}. + */ +@RunWith(JUnit4.class) +public class NettyClientHandlerTest { + + private NettyClientHandler handler; + + @Mock + private Channel channel; + + @Mock + private ChannelHandlerContext ctx; + + @Mock + private ChannelFuture future; + + @Mock + private ChannelPromise promise; + + @Mock + private NettyClientStream stream; + + @Mock + private MethodDescriptor method; + + @Mock + private Http2FrameObserver frameObserver; + + private Http2FrameWriter frameWriter; + private Http2FrameReader frameReader; + private ByteBuf content; + + @Before + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + + frameWriter = new DefaultHttp2FrameWriter(); + frameReader = new DefaultHttp2FrameReader(); + handler = new NettyClientHandler("www.fake.com", true, immediateRemovalPolicy()); + content = Unpooled.copiedBuffer("hello world", UTF_8); + + when(channel.isActive()).thenReturn(true); + mockContext(); + mockFuture(true); + + when(method.getName()).thenReturn("fakemethod"); + when(stream.state()).thenReturn(StreamState.OPEN); + + // Simulate activation of the handler to force writing of the initial settings + handler.handlerAdded(ctx); + + // Simulate receipt of initial remote settings. + ByteBuf serializedSettings = serializeSettings(new Http2Settings()); + handler.channelRead(ctx, serializedSettings); + + // Reset the context to clear any interactions resulting from the HTTP/2 + // connection preface handshake. + mockContext(); + } + + @Test + public void createStreamShouldSucceed() throws Exception { + handler.write(ctx, new CreateStreamCommand(method, stream), promise); + verify(promise).setSuccess(); + verify(stream).id(eq(3)); + + // Capture and verify the written headers frame. + ByteBuf serializedHeaders = captureWrite(ctx); + ChannelHandlerContext ctx = newContext(); + frameReader.readFrame(ctx, serializedHeaders, frameObserver); + ArgumentCaptor captor = ArgumentCaptor.forClass(Http2Headers.class); + verify(frameObserver).onHeadersRead(eq(ctx), + eq(3), + captor.capture(), + eq(0), + eq(Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT), + eq(false), + eq(0), + eq(false), + eq(false)); + Http2Headers headers = captor.getValue(); + assertEquals("https", headers.scheme()); + assertEquals(HTTP_METHOD, headers.method()); + assertEquals("www.fake.com", headers.authority()); + assertEquals(CONTENT_TYPE_PROTORPC, headers.get(CONTENT_TYPE_HEADER)); + assertEquals("/fakemethod", headers.path()); + } + + @Test + public void cancelShouldSucceed() throws Exception { + createStream(); + + handler.write(ctx, new CancelStreamCommand(stream), promise); + + ByteBuf expected = rstStreamFrame(3, Http2Error.CANCEL.code()); + verify(ctx).writeAndFlush(eq(expected), eq(promise)); + } + + @Test + public void cancelForUnknownStreamShouldFail() throws Exception { + when(stream.id()).thenReturn(3); + handler.write(ctx, new CancelStreamCommand(stream), promise); + verify(promise).setFailure(any(Throwable.class)); + } + + @Test + public void sendFrameShouldSucceed() throws Exception { + createStream(); + + // Send a frame and verify that it was written. + handler.write(ctx, new SendGrpcFrameCommand(stream, content, true, true), promise); + verify(promise, never()).setFailure(any(Throwable.class)); + verify(ctx).writeAndFlush(any(ByteBuf.class), eq(promise)); + } + + @Test + public void sendForUnknownStreamShouldFail() throws Exception { + when(stream.id()).thenReturn(3); + handler.write(ctx, new SendGrpcFrameCommand(stream, content, true, true), promise); + verify(promise).setFailure(any(Throwable.class)); + } + + @Test + public void inboundDataShouldForwardToStream() throws Exception { + createStream(); + + // Create a data frame and then trigger the handler to read it. + // Need to retain to simulate what is done by the stream. + ByteBuf frame = dataFrame(3, false).retain(); + handler.channelRead(this.ctx, frame); + verify(stream).inboundDataReceived(eq(content), eq(false), eq(promise)); + } + + @Test + public void createShouldQueueStream() throws Exception { + // Disallow stream creation to force the stream to get added to the pending queue. + setMaxConcurrentStreams(0); + handler.write(ctx, new CreateStreamCommand(method, stream), promise); + + // Make sure the write never occurred. + verify(frameObserver, never()).onHeadersRead(eq(ctx), + eq(3), + any(Http2Headers.class), + eq(0), + eq(Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT), + eq(false), + eq(0), + eq(false), + eq(false)); + } + + @Test + public void receivedGoAwayShouldFailQueuedStreams() throws Exception { + // Force a stream to get added to the pending queue. + setMaxConcurrentStreams(0); + handler.write(ctx, new CreateStreamCommand(method, stream), promise); + + handler.channelRead(ctx, goAwayFrame(0)); + verify(promise).setFailure(any(Throwable.class)); + } + + @Test + public void receivedGoAwayShouldFailUnknownStreams() throws Exception { + // Force a stream to get added to the pending queue. + handler.write(ctx, new CreateStreamCommand(method, stream), promise); + + // Read a GOAWAY that indicates our stream was never processed by the server. + handler.channelRead(ctx, goAwayFrame(0)); + ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); + InOrder inOrder = inOrder(stream); + inOrder.verify(stream, calls(1)).setStatus(captor.capture()); + assertEquals(Transport.Code.UNAVAILABLE, captor.getValue().getCode()); + } + + private void setMaxConcurrentStreams(int max) throws Exception { + ByteBuf serializedSettings = serializeSettings(new Http2Settings().maxConcurrentStreams(max)); + handler.channelRead(ctx, serializedSettings); + // Reset the context to clear this write. + mockContext(); + } + + private ByteBuf dataFrame(int streamId, boolean endStream) { + // Need to retain the content since the frameWriter releases it. + content.retain(); + ChannelHandlerContext ctx = newContext(); + frameWriter.writeData(ctx, newPromise(), streamId, content, 0, endStream, false, false); + return captureWrite(ctx); + } + + private ByteBuf goAwayFrame(int lastStreamId) { + ChannelHandlerContext ctx = newContext(); + frameWriter.writeGoAway(ctx, newPromise(), lastStreamId, 0, Unpooled.EMPTY_BUFFER); + return captureWrite(ctx); + } + + private ByteBuf rstStreamFrame(int streamId, int errorCode) { + ChannelHandlerContext ctx = newContext(); + frameWriter.writeRstStream(ctx, newPromise(), streamId, errorCode); + return captureWrite(ctx); + } + + private ByteBuf serializeSettings(Http2Settings settings) { + ChannelHandlerContext ctx = newContext(); + frameWriter.writeSettings(ctx, newPromise(), settings); + return captureWrite(ctx); + } + + private ChannelHandlerContext newContext() { + ChannelHandlerContext ctx = Mockito.mock(ChannelHandlerContext.class); + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + return ctx; + } + + private ChannelPromise newPromise() { + return Mockito.mock(ChannelPromise.class); + } + + private ByteBuf captureWrite(ChannelHandlerContext ctx) { + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + verify(ctx).writeAndFlush(captor.capture(), any(ChannelPromise.class)); + return captor.getValue(); + } + + private void createStream() throws Exception { + // Create the stream. + handler.write(ctx, new CreateStreamCommand(method, stream), promise); + when(stream.id()).thenReturn(3); + // Reset the context mock to clear recording of sent headers frame. + mockContext(); + } + + private void mockContext() { + Mockito.reset(ctx); + Mockito.reset(promise); + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(ctx.channel()).thenReturn(channel); + when(ctx.write(any())).thenReturn(future); + when(ctx.write(any(), eq(promise))).thenReturn(future); + when(ctx.writeAndFlush(any())).thenReturn(future); + when(ctx.writeAndFlush(any(), eq(promise))).thenReturn(future); + when(ctx.newPromise()).thenReturn(promise); + } + + private void mockFuture(boolean succeeded) { + when(future.isDone()).thenReturn(true); + when(future.isCancelled()).thenReturn(false); + when(future.isSuccess()).thenReturn(succeeded); + if (!succeeded) { + when(future.cause()).thenReturn(new Exception("fake")); + } + + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + ChannelFutureListener listener = (ChannelFutureListener) invocation.getArguments()[0]; + listener.operationComplete(future); + return future; + } + }).when(future).addListener(any(ChannelFutureListener.class)); + } +} diff --git a/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientStreamTest.java b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientStreamTest.java new file mode 100644 index 00000000000..97c63060624 --- /dev/null +++ b/core/src/test/java/com/google/net/stubby/newtransport/netty/NettyClientStreamTest.java @@ -0,0 +1,267 @@ +package com.google.net.stubby.newtransport.netty; + +import static com.google.net.stubby.GrpcFramingUtil.CONTEXT_VALUE_FRAME; +import static com.google.net.stubby.GrpcFramingUtil.PAYLOAD_FRAME; +import static com.google.net.stubby.GrpcFramingUtil.STATUS_FRAME; +import static io.netty.util.CharsetUtil.UTF_8; +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.io.ByteStreams; +import com.google.net.stubby.Status; +import com.google.net.stubby.newtransport.StreamListener; +import com.google.net.stubby.newtransport.StreamState; +import com.google.net.stubby.transport.Transport; +import com.google.net.stubby.transport.Transport.ContextValue; +import com.google.protobuf.ByteString; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoop; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.InputStream; +import java.util.concurrent.TimeUnit; + +/** + * Tests for {@link NettyClientStream}. + */ +@RunWith(JUnit4.class) +public class NettyClientStreamTest { + private static final String CONTEXT_KEY = "key"; + private static final String MESSAGE = "hello world"; + + private NettyClientStream stream; + + @Mock + private StreamListener listener; + + @Mock + private Channel channel; + + @Mock + private ChannelFuture future; + + @Mock + private ChannelPromise promise; + + @Mock + private EventLoop eventLoop; + + private InputStream input; + + @Mock + private Runnable accepted; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + + mockChannelFuture(true); + when(channel.write(any())).thenReturn(future); + when(channel.writeAndFlush(any())).thenReturn(future); + when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(channel.eventLoop()).thenReturn(eventLoop); + when(eventLoop.inEventLoop()).thenReturn(true); + + stream = new NettyClientStream(listener, channel); + assertEquals(StreamState.OPEN, stream.state()); + input = new ByteArrayInputStream(MESSAGE.getBytes(UTF_8)); + } + + @Test + public void closeShouldSucceed() { + // Force stream creation. + stream.id(1); + stream.close(); + assertEquals(StreamState.READ_ONLY, stream.state()); + } + + @Test + public void cancelShouldSendCommand() { + stream.cancel(); + verify(channel).writeAndFlush(any(CancelStreamCommand.class)); + } + + @Test + public void writeContextShouldSendRequest() throws Exception { + // Force stream creation. + stream.id(1); + stream.writeContext(CONTEXT_KEY, input, input.available(), accepted); + stream.flush(); + ArgumentCaptor captor = + ArgumentCaptor.forClass(SendGrpcFrameCommand.class); + verify(channel).writeAndFlush(captor.capture()); + assertEquals(contextFrame(), captor.getValue().content()); + verify(accepted).run(); + } + + @Test + public void writeMessageShouldSendRequest() throws Exception { + // Force stream creation. + stream.id(1); + stream.writeMessage(input, input.available(), accepted); + stream.flush(); + ArgumentCaptor captor = + ArgumentCaptor.forClass(SendGrpcFrameCommand.class); + verify(channel).writeAndFlush(captor.capture()); + assertEquals(messageFrame(), captor.getValue().content()); + verify(accepted).run(); + } + + @Test + public void setStatusWithOkShouldCloseStream() { + stream.id(1); + stream.setStatus(Status.OK); + verify(listener).closed(Status.OK); + assertEquals(StreamState.CLOSED, stream.state()); + } + + @Test + public void setStatusWithErrorShouldCloseStream() { + Status errorStatus = new Status(Transport.Code.INTERNAL); + stream.setStatus(errorStatus); + verify(listener).closed(eq(errorStatus)); + assertEquals(StreamState.CLOSED, stream.state()); + } + + @Test + public void setStatusWithOkShouldNotOverrideError() { + Status errorStatus = new Status(Transport.Code.INTERNAL); + stream.setStatus(errorStatus); + stream.setStatus(Status.OK); + verify(listener).closed(any(Status.class)); + assertEquals(StreamState.CLOSED, stream.state()); + } + + @Test + public void setStatusWithErrorShouldNotOverridePreviousError() { + Status errorStatus = new Status(Transport.Code.INTERNAL); + stream.setStatus(errorStatus); + stream.setStatus(Status.fromThrowable(new RuntimeException("fake"))); + verify(listener).closed(any(Status.class)); + assertEquals(StreamState.CLOSED, stream.state()); + } + + @Test + public void inboundContextShouldCallListener() throws Exception { + stream.inboundDataReceived(contextFrame(), false, promise); + ArgumentCaptor captor = ArgumentCaptor.forClass(InputStream.class); + verify(listener).contextRead(eq(CONTEXT_KEY), captor.capture(), eq(MESSAGE.length())); + verify(promise).setSuccess(); + assertEquals(MESSAGE, toString(captor.getValue())); + } + + @Test + public void inboundMessageShouldCallListener() throws Exception { + stream.inboundDataReceived(messageFrame(), false, promise); + ArgumentCaptor captor = ArgumentCaptor.forClass(InputStream.class); + verify(listener).messageRead(captor.capture(), eq(MESSAGE.length())); + verify(promise).setSuccess(); + assertEquals(MESSAGE, toString(captor.getValue())); + } + + @Test + public void inboundStatusShouldSetStatus() throws Exception { + stream.id(1); + stream.inboundDataReceived(statusFrame(), false, promise); + ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); + verify(listener).closed(captor.capture()); + assertEquals(Transport.Code.INTERNAL, captor.getValue().getCode()); + verify(promise).setSuccess(); + assertEquals(StreamState.CLOSED, stream.state()); + } + + private String toString(InputStream in) throws Exception { + byte[] bytes = new byte[in.available()]; + ByteStreams.readFully(in, bytes); + return new String(bytes, UTF_8); + } + + private ByteBuf contextFrame() throws Exception { + byte[] body = ContextValue.newBuilder().setKey(CONTEXT_KEY) + .setValue(ByteString.copyFromUtf8(MESSAGE)).build().toByteArray(); + ByteArrayOutputStream os = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(os); + dos.write(CONTEXT_VALUE_FRAME); + dos.writeInt(body.length); + dos.write(body); + dos.close(); + + // Write the compression header followed by the context frame. + return compressionFrame(os.toByteArray()); + } + + private ByteBuf messageFrame() throws Exception { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(os); + dos.write(PAYLOAD_FRAME); + dos.writeInt(MESSAGE.length()); + dos.write(MESSAGE.getBytes(UTF_8)); + dos.close(); + + // Write the compression header followed by the context frame. + return compressionFrame(os.toByteArray()); + } + + private ByteBuf statusFrame() throws Exception { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(os); + short code = (short) Transport.Code.INTERNAL.getNumber(); + dos.write(STATUS_FRAME); + int length = 2; + dos.writeInt(length); + dos.writeShort(code); + + // Write the compression header followed by the context frame. + return compressionFrame(os.toByteArray()); + } + + private ByteBuf compressionFrame(byte[] data) { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(data.length); + buf.writeBytes(data); + return buf; + } + + private void mockChannelFuture(boolean succeeded) { + when(future.isDone()).thenReturn(true); + when(future.isCancelled()).thenReturn(false); + when(future.isSuccess()).thenReturn(succeeded); + when(future.awaitUninterruptibly(anyLong(), any(TimeUnit.class))).thenReturn(true); + if (!succeeded) { + when(future.cause()).thenReturn(new Exception("fake")); + } + + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + ChannelFutureListener listener = (ChannelFutureListener) invocation.getArguments()[0]; + listener.operationComplete(future); + return future; + } + }).when(future).addListener(any(ChannelFutureListener.class)); + } +}