Skip to content

Commit

Permalink
okhttp: Avoid test-specific transport.start()
Browse files Browse the repository at this point in the history
With the completely different constructor it was hard to track which
fields were different during the test and reduced confidence. Now the
test code flows are much closer to the real-life code flows.
  • Loading branch information
ejona86 committed Apr 4, 2022
1 parent 004ee10 commit a978c9e
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,14 @@ final class ExceptionHandlingFrameWriter implements FrameWriter {

private final FrameWriter frameWriter;

private final OkHttpFrameLogger frameLogger;
private final OkHttpFrameLogger frameLogger =
new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class);

ExceptionHandlingFrameWriter(
TransportExceptionHandler transportExceptionHandler, FrameWriter frameWriter) {
this(transportExceptionHandler, frameWriter,
new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class));
}

@VisibleForTesting
ExceptionHandlingFrameWriter(
TransportExceptionHandler transportExceptionHandler,
FrameWriter frameWriter,
OkHttpFrameLogger frameLogger) {
this.transportExceptionHandler =
checkNotNull(transportExceptionHandler, "transportExceptionHandler");
this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter");
this.frameLogger = Preconditions.checkNotNull(frameLogger, "frameLogger");
}

@Override
Expand Down
152 changes: 81 additions & 71 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,8 @@ private static Map<ErrorCode, Status> buildErrorCodeToStatusMap() {
// Returns new unstarted stopwatches
private final Supplier<Stopwatch> stopwatchFactory;
private final int initialWindowSize;
private final Variant variant;
private Listener listener;
private FrameReader testFrameReader;
private OkHttpFrameLogger testFrameLogger;
@GuardedBy("lock")
private ExceptionHandlingFrameWriter frameWriter;
private OutboundFlowController outboundFlow;
Expand Down Expand Up @@ -192,7 +191,6 @@ private static Map<ErrorCode, Status> buildErrorCodeToStatusMap() {
@GuardedBy("lock")
private final Deque<OkHttpClientStream> pendingStreams = new LinkedList<>();
private final ConnectionSpec connectionSpec;
private FrameWriter testFrameWriter;
private ScheduledExecutorService scheduler;
private KeepAliveManager keepAliveManager;
private boolean enableKeepAlive;
Expand Down Expand Up @@ -228,7 +226,7 @@ protected void handleNotInUse() {
Runnable connectingCallback;
SettableFuture<Void> connectedFuture;

OkHttpClientTransport(
public OkHttpClientTransport(
InetSocketAddress address,
String authority,
@Nullable String userAgent,
Expand All @@ -245,6 +243,46 @@ protected void handleNotInUse() {
int maxInboundMetadataSize,
TransportTracer transportTracer,
boolean useGetForSafeMethods) {
this(
address,
authority,
userAgent,
eagAttrs,
executor,
socketFactory,
sslSocketFactory,
hostnameVerifier,
connectionSpec,
GrpcUtil.STOPWATCH_SUPPLIER,
new Http2(),
maxMessageSize,
initialWindowSize,
proxiedAddr,
tooManyPingsRunnable,
maxInboundMetadataSize,
transportTracer,
useGetForSafeMethods);
}

private OkHttpClientTransport(
InetSocketAddress address,
String authority,
@Nullable String userAgent,
Attributes eagAttrs,
Executor executor,
@Nullable SocketFactory socketFactory,
@Nullable SSLSocketFactory sslSocketFactory,
@Nullable HostnameVerifier hostnameVerifier,
ConnectionSpec connectionSpec,
Supplier<Stopwatch> stopwatchFactory,
Variant variant,
int maxMessageSize,
int initialWindowSize,
@Nullable HttpConnectProxiedSocketAddress proxiedAddr,
Runnable tooManyPingsRunnable,
int maxInboundMetadataSize,
TransportTracer transportTracer,
boolean useGetForSafeMethods) {
this.address = Preconditions.checkNotNull(address, "address");
this.defaultAuthority = authority;
this.maxMessageSize = maxMessageSize;
Expand All @@ -258,7 +296,8 @@ protected void handleNotInUse() {
this.sslSocketFactory = sslSocketFactory;
this.hostnameVerifier = hostnameVerifier;
this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec");
this.stopwatchFactory = GrpcUtil.STOPWATCH_SUPPLIER;
this.stopwatchFactory = Preconditions.checkNotNull(stopwatchFactory, "stopwatchFactory");
this.variant = Preconditions.checkNotNull(variant, "variant");
this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent);
this.proxiedAddr = proxiedAddr;
this.tooManyPingsRunnable =
Expand All @@ -279,43 +318,36 @@ protected void handleNotInUse() {
OkHttpClientTransport(
String userAgent,
Executor executor,
FrameReader frameReader,
FrameWriter testFrameWriter,
OkHttpFrameLogger testFrameLogger,
int nextStreamId,
Socket socket,
@Nullable SocketFactory socketFactory,
Supplier<Stopwatch> stopwatchFactory,
Variant variant,
@Nullable Runnable connectingCallback,
SettableFuture<Void> connectedFuture,
int maxMessageSize,
int initialWindowSize,
Runnable tooManyPingsRunnable,
TransportTracer transportTracer) {
useGetForSafeMethods = false;
address = null;
this.maxMessageSize = maxMessageSize;
this.initialWindowSize = initialWindowSize;
defaultAuthority = "notarealauthority:80";
this.userAgent = GrpcUtil.getGrpcUserAgent("okhttp", userAgent);
this.executor = Preconditions.checkNotNull(executor, "executor");
serializingExecutor = new SerializingExecutor(executor);
this.socketFactory = SocketFactory.getDefault();
this.testFrameReader = Preconditions.checkNotNull(frameReader, "frameReader");
this.testFrameWriter = Preconditions.checkNotNull(testFrameWriter, "testFrameWriter");
this.testFrameLogger = Preconditions.checkNotNull(testFrameLogger, "testFrameLogger");
this.socket = Preconditions.checkNotNull(socket, "socket");
this.nextStreamId = nextStreamId;
this.stopwatchFactory = stopwatchFactory;
this.connectionSpec = null;
this(
new InetSocketAddress("127.0.0.1", 80),
"notarealauthority:80",
userAgent,
Attributes.EMPTY,
executor,
socketFactory,
null,
null,
OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC,
stopwatchFactory,
variant,
maxMessageSize,
initialWindowSize,
null,
tooManyPingsRunnable,
Integer.MAX_VALUE,
transportTracer,
false);
this.connectingCallback = connectingCallback;
this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture");
this.proxiedAddr = null;
this.tooManyPingsRunnable =
Preconditions.checkNotNull(tooManyPingsRunnable, "tooManyPingsRunnable");
this.maxInboundMetadataSize = Integer.MAX_VALUE;
this.transportTracer = Preconditions.checkNotNull(transportTracer, "transportTracer");
this.logId = InternalLogId.allocate(getClass(), String.valueOf(socket.getInetAddress()));
initTransportTracer();
}

// sslSocketFactory is set to null when use plaintext.
Expand Down Expand Up @@ -349,10 +381,6 @@ void enableKeepAlive(boolean enable, long keepAliveTimeNanos,
this.keepAliveWithoutCalls = keepAliveWithoutCalls;
}

private boolean isForTest() {
return address == null;
}

@Override
public void ping(final PingCallback callback, Executor executor) {
long data = 0;
Expand Down Expand Up @@ -488,32 +516,8 @@ public Runnable start(Listener listener) {
keepAliveWithoutCalls);
keepAliveManager.onTransportStarted();
}
if (isForTest()) {
synchronized (lock) {
frameWriter = new ExceptionHandlingFrameWriter(OkHttpClientTransport.this, testFrameWriter,
testFrameLogger);
outboundFlow = new OutboundFlowController(OkHttpClientTransport.this, frameWriter);
}
serializingExecutor.execute(new Runnable() {
@Override
public void run() {
if (connectingCallback != null) {
connectingCallback.run();
}
clientFrameHandler = new ClientFrameHandler(testFrameReader, testFrameLogger);
executor.execute(clientFrameHandler);
synchronized (lock) {
maxConcurrentStreams = Integer.MAX_VALUE;
startPendingStreams();
}
connectedFuture.set(null);
}
});
return null;
}

final AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this);
final Variant variant = new Http2();
FrameWriter rawFrameWriter = variant.newWriter(Okio.buffer(asyncSink), true);

synchronized (lock) {
Expand Down Expand Up @@ -616,13 +620,19 @@ sslSocketFactory, hostnameVerifier, sock, getOverridenHost(), getOverridenPort()
serializingExecutor.execute(new Runnable() {
@Override
public void run() {
if (connectingCallback != null) {
connectingCallback.run();
}
// ClientFrameHandler need to be started after connectionPreface / settings, otherwise it
// may send goAway immediately.
executor.execute(clientFrameHandler);
synchronized (lock) {
maxConcurrentStreams = Integer.MAX_VALUE;
startPendingStreams();
}
if (connectedFuture != null) {
connectedFuture.set(null);
}
}
});
return null;
Expand All @@ -631,8 +641,7 @@ public void run() {
/**
* Should only be called once when the transport is first established.
*/
@VisibleForTesting
void sendConnectionPrefaceAndSettings() {
private void sendConnectionPrefaceAndSettings() {
synchronized (lock) {
frameWriter.connectionPreface();
Settings settings = new Settings();
Expand Down Expand Up @@ -855,6 +864,13 @@ int getPendingStreamSize() {
}
}

@VisibleForTesting
void setNextStreamId(int nextStreamId) {
synchronized (lock) {
this.nextStreamId = nextStreamId;
}
}

/**
* Finish all active streams due to an IOException, then close the transport.
*/
Expand Down Expand Up @@ -1081,21 +1097,15 @@ public ListenableFuture<SocketStats> getStats() {
/**
* Runnable which reads frames and dispatches them to in flight calls.
*/
@VisibleForTesting
class ClientFrameHandler implements FrameReader.Handler, Runnable {

private final OkHttpFrameLogger logger;
private final OkHttpFrameLogger logger =
new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class);
FrameReader frameReader;
boolean firstSettings = true;

ClientFrameHandler(FrameReader frameReader) {
this(frameReader, new OkHttpFrameLogger(Level.FINE, OkHttpClientTransport.class));
}

@VisibleForTesting
ClientFrameHandler(FrameReader frameReader, OkHttpFrameLogger frameLogger) {
this.frameReader = frameReader;
logger = frameLogger;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ public class ExceptionHandlingFrameWriterTest {
private final TransportExceptionHandler transportExceptionHandler =
mock(TransportExceptionHandler.class);
private final ExceptionHandlingFrameWriter exceptionHandlingFrameWriter =
new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter,
new OkHttpFrameLogger(Level.FINE, logger));
new ExceptionHandlingFrameWriter(transportExceptionHandler, mockedFrameWriter);

@Test
public void exception() throws IOException {
Expand Down Expand Up @@ -194,4 +193,4 @@ public void close() throws SecurityException {

logger.removeHandler(handler);
}
}
}
Loading

0 comments on commit a978c9e

Please sign in to comment.