Skip to content

Commit

Permalink
Close OkHttp ResponseBody after writeBodyTo and writeBodyToAsync (#36650
Browse files Browse the repository at this point in the history
)

Close OkHttp ResponseBody after writeBodyTo and writeBodyToAsync
  • Loading branch information
alzimmermsft authored Sep 12, 2023
1 parent d26e243 commit f4c38bb
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2676,4 +2676,11 @@
<Field name="instance"/>
<Bug pattern="LI_LAZY_INIT_STATIC"/>
</Match>

<!-- False positive, complaining about dereference of null 'source' on some paths, which can't happen based on the null validation checks. -->
<Match>
<Class name="com.azure.core.util.io.IOUtils"/>
<Method name="transfer"/>
<Bug pattern="NP_NULL_ON_SOME_PATH"/>
</Match>
</FindBugsFilter>
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import reactor.netty.NettyPipeline;
import reactor.netty.http.client.HttpClientRequest;
import reactor.netty.http.client.HttpClientResponse;
import reactor.util.retry.Retry;

import javax.net.ssl.SSLException;
import java.io.IOException;
Expand Down Expand Up @@ -114,6 +113,13 @@ public Mono<HttpResponse> send(HttpRequest request, Context context) {
.orElse(null);
ProgressReporter progressReporter = Contexts.with(context).getHttpRequestProgressReporter();

return attemptAsync(request, eagerlyReadResponse, ignoreResponseBody, headersEagerlyConverted, responseTimeout,
progressReporter, false);
}

private Mono<HttpResponse> attemptAsync(HttpRequest request, boolean eagerlyReadResponse,
boolean ignoreResponseBody, boolean headersEagerlyConverted, Long responseTimeout,
ProgressReporter progressReporter, boolean proxyRetry) {
Flux<HttpResponse> nettyRequest = nettyClient.request(toReactorNettyHttpMethod(request.getHttpMethod()))
.uri(request.getUrl().toString())
.send(bodySendDelegate(request))
Expand All @@ -128,22 +134,27 @@ public Mono<HttpResponse> send(HttpRequest request, Context context) {
return nettyRequest.single()
.flatMap(response -> {
if (addProxyHandler && response.getStatusCode() == 407) {
return Mono.error(new ProxyConnectException("First attempt to connect to proxy failed."));
return proxyRetry
? Mono.error(new ProxyConnectException("Connection to proxy failed."))
: Mono.error(new ProxyConnectException("First attempt to connect to proxy failed."));
} else {
return Mono.just(response);
}
})
.onErrorMap(throwable -> {
// The exception was an SSLException that was caused by a failure to connect to a proxy.
// Extract the inner ProxyConnectException and propagate that instead.
if (throwable instanceof SSLException && throwable.getCause() instanceof ProxyConnectException) {
return throwable.getCause();
}
.onErrorResume(throwable -> shouldRetryProxyError(proxyRetry, throwable)
? attemptAsync(request, eagerlyReadResponse, ignoreResponseBody, headersEagerlyConverted,
responseTimeout, progressReporter, true)
: Mono.error(throwable));
}

return throwable;
})
.retryWhen(Retry.max(1).filter(throwable -> throwable instanceof ProxyConnectException)
.onRetryExhaustedThrow((ignoredSpec, signal) -> signal.failure()));
private static boolean shouldRetryProxyError(boolean proxyRetry, Throwable throwable) {
// Only retry if this is the first attempt to connect to a proxy and the exception was caused by a failure to
// connect to the proxy.
// Sometimes connecting to the proxy may return an SSLException that wraps the ProxyConnectException, this
// generally happens if the proxy is using SSL.
return !proxyRetry
&& (throwable instanceof ProxyConnectException
|| (throwable instanceof SSLException && throwable.getCause() instanceof ProxyConnectException));
}

@Override
Expand Down Expand Up @@ -323,8 +334,8 @@ private static HttpMethod toReactorNettyHttpMethod(com.azure.core.http.HttpMetho
case TRACE: return HttpMethod.TRACE;
case CONNECT: return HttpMethod.CONNECT;
case OPTIONS: return HttpMethod.OPTIONS;
default: throw LOGGER.logExceptionAsError(new IllegalStateException("Unknown HttpMethod '"
+ azureHttpMethod + "'.")); // Should never happen
default: throw LOGGER.logExceptionAsError(
new IllegalStateException("Unknown HttpMethod '" + azureHttpMethod + "'.")); // Should never happen
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ public Mono<InputStream> getBodyAsInputStream() {

@Override
public Mono<Void> writeBodyToAsync(AsynchronousByteChannel channel) {
return Mono.<Void>create(sink -> bodyIntern().subscribe(
new ByteBufWriteSubscriber(byteBuffer -> channel.write(byteBuffer).get(), sink, getContentLength())))
.doFinally(ignored -> close());
Long length = getContentLength();
return Mono.using(() -> this, response -> Mono.create(sink -> response.bodyIntern()
.subscribe(new ByteBufWriteSubscriber(byteBuffer -> channel.write(byteBuffer).get(), sink, length))),
NettyAsyncHttpResponse::close);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
import com.azure.core.util.Contexts;
import com.azure.core.util.FluxUtil;
import com.azure.core.util.ProgressReporter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.proxy.ProxyConnectException;
import io.netty.resolver.DefaultAddressResolverGroup;
import io.netty.resolver.NoopAddressResolverGroup;
Expand All @@ -46,6 +49,7 @@
import reactor.test.StepVerifier;
import reactor.test.StepVerifierOptions;

import javax.net.ssl.SSLException;
import javax.servlet.ServletException;
import java.io.ByteArrayOutputStream;
import java.io.FileOutputStream;
Expand Down Expand Up @@ -619,6 +623,31 @@ public void failedProxyAuthenticationReturnsCorrectError() {
}
}

@Test
public void sslExceptionWrappedProxyConnectExceptionDoesNotRetryInfinitely() {
try (MockProxyServer mockProxyServer = new MockProxyServer("1", "1")) {
HttpPipeline httpPipeline = new HttpPipelineBuilder()
.httpClient(new NettyAsyncHttpClientBuilder(reactor.netty.http.client.HttpClient.create()
.doOnRequest((req, conn) -> {
conn.addHandlerLast("sslException", new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
promise.setFailure(new SSLException(
new ProxyConnectException("Simulated SSLException")));
}
});
}))
.proxy(new ProxyOptions(ProxyOptions.Type.HTTP, mockProxyServer.socketAddress())
.setCredentials("1", "1"))
.build())
.build();

StepVerifier.create(httpPipeline.send(new HttpRequest(HttpMethod.GET, url(server, PROXY_TO_ADDRESS))))
.verifyErrorMatches(exception -> exception instanceof SSLException
&& exception.getCause() instanceof ProxyConnectException);
}
}

@Test
public void httpClientWithDefaultResolverUsesNoopResolverWithProxy() {
try (MockProxyServer mockProxyServer = new MockProxyServer()) {
Expand Down
3 changes: 3 additions & 0 deletions sdk/core/azure-core-http-okhttp/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

### Other Changes

- Changed buffer read size from `4096` to `8192` when returning `Flux<ByteBuffer>` in `HttpResponse` to reduce number
of reads and elements emitted by the `Flux`.

## 1.11.13 (2023-09-07)

### Other Changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

import com.azure.core.http.HttpRequest;
import com.azure.core.util.BinaryData;
import com.azure.core.util.FluxUtil;
import com.azure.core.util.io.IOUtils;
import okhttp3.Response;
import okhttp3.ResponseBody;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;

import java.io.IOException;
import java.io.InputStream;
Expand All @@ -23,10 +22,9 @@
* Default HTTP response for OkHttp.
*/
public final class OkHttpAsyncResponse extends OkHttpAsyncResponseBase {
// using 4K as default buffer size: https://stackoverflow.com/a/237495/1473510
private static final int BYTE_BUFFER_CHUNK_SIZE = 4096;

private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocate(0);
// Previously, this was 4096, but it is being changed to 8192 as that more closely aligns to what Netty uses as a
// default and will reduce the number of small allocations we'll need to make.
private static final int BYTE_BUFFER_CHUNK_SIZE = 8192;

private final ResponseBody responseBody;

Expand All @@ -52,31 +50,11 @@ public Flux<ByteBuffer> getBody() {
}

// Use Flux.using to close the stream after complete emission
return Flux.using(this.responseBody::byteStream, OkHttpAsyncResponse::toFluxByteBuffer,
return Flux.using(this.responseBody::byteStream,
bodyStream -> FluxUtil.toFluxByteBuffer(bodyStream, BYTE_BUFFER_CHUNK_SIZE),
bodyStream -> this.close(), false);
}

private static Flux<ByteBuffer> toFluxByteBuffer(InputStream responseBody) {
return Flux.just(true)
.repeat()
.flatMap(ignored -> {
byte[] buffer = new byte[BYTE_BUFFER_CHUNK_SIZE];
try {
int read = responseBody.read(buffer);
if (read > 0) {
return Mono.just(Tuples.of(read, ByteBuffer.wrap(buffer, 0, read)));
} else {
return Mono.just(Tuples.of(read, EMPTY_BYTE_BUFFER));
}
} catch (IOException ex) {
return Mono.error(ex);
}
})
.takeUntil(tuple -> tuple.getT1() == -1)
.filter(tuple -> tuple.getT1() > 0)
.map(Tuple2::getT2);
}

@Override
public Mono<byte[]> getBodyAsByteArray() {
return Mono.fromCallable(() -> {
Expand Down Expand Up @@ -108,14 +86,20 @@ public Mono<InputStream> getBodyAsInputStream() {
@Override
public void writeBodyTo(WritableByteChannel channel) throws IOException {
if (responseBody != null) {
IOUtils.transfer(responseBody.source(), channel);
try {
IOUtils.transfer(responseBody.source(), channel, responseBody.contentLength());
} finally {
close();
}
}
}

@Override
public Mono<Void> writeBodyToAsync(AsynchronousByteChannel channel) {
if (responseBody != null) {
return IOUtils.transferAsync(responseBody.source(), channel);
return Mono.using(() -> this,
ignored -> IOUtils.transferAsync(responseBody.source(), channel, responseBody.contentLength()),
OkHttpAsyncResponse::close);
} else {
return Mono.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ public final class IOUtils {
private static final ClientLogger LOGGER = new ClientLogger(IOUtils.class);

private static final int DEFAULT_BUFFER_SIZE = 8192;
private static final int SIXTY_FOUR_KB = 64 * 1024;
private static final int THIRTY_TWO_KB = 32 * 1024;
private static final int MB = 1024 * 1024;
private static final int GB = 1024 * MB;

/**
* Adapts {@link AsynchronousFileChannel} to {@link AsynchronousByteChannel}.
Expand All @@ -40,8 +44,8 @@ public final class IOUtils {
* @throws NullPointerException When {@code fileChannel} is null.
* @throws IllegalArgumentException When {@code position} is negative.
*/
public static AsynchronousByteChannel toAsynchronousByteChannel(
AsynchronousFileChannel fileChannel, long position) {
public static AsynchronousByteChannel toAsynchronousByteChannel(AsynchronousFileChannel fileChannel,
long position) {
Objects.requireNonNull(fileChannel, "'fileChannel' must not be null");
if (position < 0) {
throw LOGGER.logExceptionAsError(new IllegalArgumentException("'position' cannot be less than 0."));
Expand All @@ -51,16 +55,38 @@ public static AsynchronousByteChannel toAsynchronousByteChannel(

/**
* Transfers bytes from {@link ReadableByteChannel} to {@link WritableByteChannel}.
*
* @param source A source {@link ReadableByteChannel}.
* @param destination A destination {@link WritableByteChannel}.
* @throws IOException When I/O operation fails.
* @throws NullPointerException When {@code source} is null.
* @throws NullPointerException When {@code destination} is null.
* @throws NullPointerException When {@code source} or {@code destination} is null.
*/
public static void transfer(ReadableByteChannel source, WritableByteChannel destination) throws IOException {
Objects.requireNonNull(source, "'source' must not be null");
Objects.requireNonNull(source, "'destination' must not be null");
ByteBuffer buffer = ByteBuffer.allocate(getBufferSize(source));
transfer(source, destination, null);
}

/**
* Transfers bytes from {@link ReadableByteChannel} to {@link WritableByteChannel}.
*
* @param source A source {@link ReadableByteChannel}.
* @param destination A destination {@link WritableByteChannel}.
* @param estimatedSourceSize An estimated size of the source channel, may be null. Used to better determine the
* size of the buffer used to transfer data in an attempt to reduce read and write calls.
* @throws IOException When I/O operation fails.
* @throws NullPointerException When {@code source} or {@code destination} is null.
*/
public static void transfer(ReadableByteChannel source, WritableByteChannel destination, Long estimatedSourceSize)
throws IOException {
if (source == null && destination == null) {
throw new NullPointerException("'source' and 'destination' cannot be null.");
} else if (source == null) {
throw new NullPointerException("'source' cannot be null.");
} else if (destination == null) {
throw new NullPointerException("'destination' cannot be null.");
}

int bufferSize = (estimatedSourceSize == null) ? getBufferSize(source) : getBufferSize(estimatedSourceSize);
ByteBuffer buffer = ByteBuffer.allocate(bufferSize);
int read;
do {
buffer.clear();
Expand All @@ -74,17 +100,39 @@ public static void transfer(ReadableByteChannel source, WritableByteChannel dest

/**
* Transfers bytes from {@link ReadableByteChannel} to {@link AsynchronousByteChannel}.
*
* @param source A source {@link ReadableByteChannel}.
* @param destination A destination {@link AsynchronousByteChannel}.
* @return A {@link Mono} that completes when transfer is finished.
* @throws NullPointerException When {@code source} is null.
* @throws NullPointerException When {@code destination} is null.
* @throws NullPointerException When {@code source} or {@code destination} is null.
*/
public static Mono<Void> transferAsync(ReadableByteChannel source, AsynchronousByteChannel destination) {
Objects.requireNonNull(source, "'source' must not be null");
Objects.requireNonNull(source, "'destination' must not be null");
return transferAsync(source, destination, null);
}

/**
* Transfers bytes from {@link ReadableByteChannel} to {@link AsynchronousByteChannel}.
*
* @param source A source {@link ReadableByteChannel}.
* @param destination A destination {@link AsynchronousByteChannel}.
* @param estimatedSourceSize An estimated size of the source channel, may be null. Used to better determine the
* size of the buffer used to transfer data in an attempt to reduce read and write calls.
* @return A {@link Mono} that completes when transfer is finished.
* @throws NullPointerException When {@code source} or {@code destination} is null.
*/
public static Mono<Void> transferAsync(ReadableByteChannel source, AsynchronousByteChannel destination,
Long estimatedSourceSize) {
if (source == null && destination == null) {
return Mono.error(new NullPointerException("'source' and 'destination' cannot be null."));
} else if (source == null) {
return Mono.error(new NullPointerException("'source' cannot be null."));
} else if (destination == null) {
return Mono.error(new NullPointerException("'destination' cannot be null."));
}

int bufferSize = (estimatedSourceSize == null) ? getBufferSize(source) : getBufferSize(estimatedSourceSize);
return Mono.create(sink -> sink.onRequest(value -> {
ByteBuffer buffer = ByteBuffer.allocate(getBufferSize(source));
ByteBuffer buffer = ByteBuffer.allocate(bufferSize);
try {
transferAsynchronously(source, destination, buffer, sink);
} catch (IOException e) {
Expand All @@ -93,8 +141,7 @@ public static Mono<Void> transferAsync(ReadableByteChannel source, AsynchronousB
}));
}

private static void transferAsynchronously(
ReadableByteChannel source, AsynchronousByteChannel destination,
private static void transferAsynchronously(ReadableByteChannel source, AsynchronousByteChannel destination,
ByteBuffer buffer, MonoSink<Void> sink) throws IOException {
buffer.clear();
int read = source.read(buffer);
Expand Down Expand Up @@ -194,21 +241,23 @@ private static int getBufferSize(ReadableByteChannel source) {
long size = seekableSource.size();
long position = seekableSource.position();

long count = size - position;

if (count > 1024 * 1024 * 1024) {
return 65536;
} else if (count > 1024 * 1024) {
return 32768;
} else {
return DEFAULT_BUFFER_SIZE;
}
return getBufferSize(size - position);
} catch (IOException ex) {
// Don't let an IOException prevent transfer when we are only trying to gain information.
return DEFAULT_BUFFER_SIZE;
}
}

private static int getBufferSize(long dataSize) {
if (dataSize > GB) {
return SIXTY_FOUR_KB;
} else if (dataSize > MB) {
return THIRTY_TWO_KB;
} else {
return DEFAULT_BUFFER_SIZE;
}
}

private IOUtils() {
}
}

0 comments on commit f4c38bb

Please sign in to comment.