diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java b/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java index 3501deb8d34..26df1050e4d 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpRequest.java @@ -32,6 +32,7 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.FormatMethod; @@ -43,6 +44,7 @@ import com.linecorp.armeria.common.FixedHttpRequest.TwoElementFixedHttpRequest; import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.stream.HttpDecoder; +import com.linecorp.armeria.common.stream.PublisherBasedStreamMessage; import com.linecorp.armeria.common.stream.StreamMessage; import com.linecorp.armeria.common.stream.SubscriptionOption; import com.linecorp.armeria.common.util.EventLoopCheckingFuture; @@ -263,6 +265,10 @@ static HttpRequest of(RequestHeaders headers, HttpData... contents) { /** * Creates a new instance from an existing {@link RequestHeaders} and {@link Publisher}. + * + *

Note that the {@link HttpObject}s in the {@link Publisher} are not released when + * {@link Subscription#cancel()} or {@link #abort()} is called. You should add a hook in order to + * release the elements. See {@link PublisherBasedStreamMessage} for more information. */ static HttpRequest of(RequestHeaders headers, Publisher publisher) { requireNonNull(headers, "headers"); diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java b/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java index 26d229db8d3..c17a4a84a82 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpResponse.java @@ -32,6 +32,7 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import com.google.errorprone.annotations.CheckReturnValue; import com.google.errorprone.annotations.FormatMethod; @@ -43,6 +44,7 @@ import com.linecorp.armeria.common.FixedHttpResponse.TwoElementFixedHttpResponse; import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.stream.HttpDecoder; +import com.linecorp.armeria.common.stream.PublisherBasedStreamMessage; import com.linecorp.armeria.common.stream.StreamMessage; import com.linecorp.armeria.common.stream.SubscriptionOption; import com.linecorp.armeria.common.util.EventLoopCheckingFuture; @@ -387,6 +389,10 @@ static HttpResponse of(HttpObject... objs) { /** * Creates a new HTTP response whose stream is produced from an existing {@link Publisher}. + * + *

Note that the {@link HttpObject}s in the {@link Publisher} are not released when + * {@link Subscription#cancel()} or {@link #abort()} is called. You should add a hook in order to + * release the elements. See {@link PublisherBasedStreamMessage} for more information. */ static HttpResponse of(Publisher publisher) { requireNonNull(publisher, "publisher"); @@ -400,6 +406,10 @@ static HttpResponse of(Publisher publisher) { /** * Creates a new HTTP response with the specified headers whose stream is produced from an existing * {@link Publisher}. + * + *

Note that the {@link HttpObject}s in the {@link Publisher} are not released when + * {@link Subscription#cancel()} or {@link #abort()} is called. You should add a hook in order to + * release the elements. See {@link PublisherBasedStreamMessage} for more information. */ static HttpResponse of(ResponseHeaders headers, Publisher publisher) { requireNonNull(headers, "headers"); diff --git a/core/src/main/java/com/linecorp/armeria/common/stream/PublisherBasedStreamMessage.java b/core/src/main/java/com/linecorp/armeria/common/stream/PublisherBasedStreamMessage.java index 840acce9c91..5cc4e80032f 100644 --- a/core/src/main/java/com/linecorp/armeria/common/stream/PublisherBasedStreamMessage.java +++ b/core/src/main/java/com/linecorp/armeria/common/stream/PublisherBasedStreamMessage.java @@ -17,6 +17,7 @@ package com.linecorp.armeria.common.stream; import static com.linecorp.armeria.common.stream.StreamMessageUtil.containsNotifyCancellation; +import static com.linecorp.armeria.common.stream.StreamMessageUtil.containsWithPooledObjects; import static com.linecorp.armeria.common.stream.SubscriberUtil.abortedOrLate; import static com.linecorp.armeria.common.util.Exceptions.throwIfFatal; import static java.util.Objects.requireNonNull; @@ -39,6 +40,7 @@ import com.linecorp.armeria.common.util.CompositeException; import com.linecorp.armeria.common.util.EventLoopCheckingFuture; import com.linecorp.armeria.internal.common.stream.NoopSubscription; +import com.linecorp.armeria.unsafe.PooledObjects; import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.ImmediateEventExecutor; @@ -46,6 +48,13 @@ /** * Adapts a {@link Publisher} into a {@link StreamMessage}. * + *

Note that the elements in the {@link Publisher} are not released when {@link Subscription#cancel()} or + * {@link #abort()} is called. So you should add a hook in order to release the elements. You can use + * doOnDiscard + * if you are using Reactor, or you can use + * doOnDispose + * if you are using RxJava. + * * @param the type of element signaled */ @UnstableApi @@ -97,24 +106,24 @@ public final long demand() { @Override public final void subscribe(Subscriber subscriber, EventExecutor executor) { - subscribe0(subscriber, executor, false); + subscribe0(subscriber, executor, false, false); } @Override public final void subscribe(Subscriber subscriber, EventExecutor executor, SubscriptionOption... options) { requireNonNull(options, "options"); - + final boolean withPooledObjects = containsWithPooledObjects(options); final boolean notifyCancellation = containsNotifyCancellation(options); - subscribe0(subscriber, executor, notifyCancellation); + subscribe0(subscriber, executor, withPooledObjects, notifyCancellation); } private void subscribe0(Subscriber subscriber, EventExecutor executor, - boolean notifyCancellation) { + boolean withPooledObjects, boolean notifyCancellation) { requireNonNull(subscriber, "subscriber"); requireNonNull(executor, "executor"); - if (!subscribe1(subscriber, executor, notifyCancellation)) { + if (!subscribe1(subscriber, executor, withPooledObjects, notifyCancellation)) { final AbortableSubscriber oldSubscriber = this.subscriber; assert oldSubscriber != null; failLateSubscriber(executor, subscriber, oldSubscriber.subscriber); @@ -122,8 +131,9 @@ private void subscribe0(Subscriber subscriber, EventExecutor executor } private boolean subscribe1(Subscriber subscriber, EventExecutor executor, - boolean notifyCancellation) { - final AbortableSubscriber s = new AbortableSubscriber(this, subscriber, executor, notifyCancellation); + boolean withPooledObjects, boolean notifyCancellation) { + final AbortableSubscriber s = + new AbortableSubscriber(this, subscriber, executor, withPooledObjects, notifyCancellation); if (!subscriberUpdater.compareAndSet(this, null, s)) { return false; } @@ -173,7 +183,7 @@ private void abort0(Throwable cause) { final AbortableSubscriber abortable = new AbortableSubscriber(this, AbortingSubscriber.get(cause), ImmediateEventExecutor.INSTANCE, - false); + false, false); if (!subscriberUpdater.compareAndSet(this, null, abortable)) { this.subscriber.abort(cause); return; @@ -192,6 +202,7 @@ public final CompletableFuture whenComplete() { static final class AbortableSubscriber implements Subscriber, Subscription { private final PublisherBasedStreamMessage parent; private final EventExecutor executor; + private boolean withPooledObjects; private final boolean notifyCancellation; private Subscriber subscriber; @Nullable @@ -201,10 +212,11 @@ static final class AbortableSubscriber implements Subscriber, Subscripti @SuppressWarnings("unchecked") AbortableSubscriber(PublisherBasedStreamMessage parent, Subscriber subscriber, - EventExecutor executor, boolean notifyCancellation) { + EventExecutor executor, boolean withPooledObjects, boolean notifyCancellation) { this.parent = parent; this.subscriber = (Subscriber) subscriber; this.executor = executor; + this.withPooledObjects = withPooledObjects; this.notifyCancellation = notifyCancellation; } @@ -324,6 +336,9 @@ private void onNext0(Object obj) { parent.demand--; } try { + if (!withPooledObjects) { + obj = PooledObjects.copyAndClose(obj); + } subscriber.onNext(obj); } catch (Throwable t) { abort(t); diff --git a/core/src/test/java/com/linecorp/armeria/common/stream/AbortableSubscriberBlackboxTckTest.java b/core/src/test/java/com/linecorp/armeria/common/stream/AbortableSubscriberBlackboxTckTest.java index 00c97e74cbf..bb4811e7927 100644 --- a/core/src/test/java/com/linecorp/armeria/common/stream/AbortableSubscriberBlackboxTckTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/stream/AbortableSubscriberBlackboxTckTest.java @@ -53,7 +53,8 @@ public Long createElement(int element) { @Override public Subscriber createSubscriber() { - return new AbortableSubscriber(publisher, NoopSubscriber.get(), ImmediateEventExecutor.INSTANCE, false); + return new AbortableSubscriber(publisher, NoopSubscriber.get(), ImmediateEventExecutor.INSTANCE, + false, false); } @Test(enabled = false) diff --git a/core/src/test/java/com/linecorp/armeria/common/stream/FilteredStreamMessageTest.java b/core/src/test/java/com/linecorp/armeria/common/stream/FilteredStreamMessageTest.java index c95aced6b8f..9e015fbfcca 100644 --- a/core/src/test/java/com/linecorp/armeria/common/stream/FilteredStreamMessageTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/stream/FilteredStreamMessageTest.java @@ -107,7 +107,8 @@ public void onComplete() { @Test void notifyCancellation() { - final HttpData data = HttpData.wrap(newPooledBuffer()).withEndOfStream(); + final ByteBuf buf = newPooledBuffer(); + final HttpData data = HttpData.wrap(buf).withEndOfStream(); final DefaultStreamMessage stream = new DefaultStreamMessage<>(); stream.write(data); stream.close(); @@ -119,7 +120,7 @@ protected HttpData filter(HttpData obj) { return obj; } }; - SubscriptionOptionTest.notifyCancellation(filtered); + SubscriptionOptionTest.notifyCancellation(buf, filtered); } @Test diff --git a/core/src/test/java/com/linecorp/armeria/common/stream/PublisherBasedStreamMessageTest.java b/core/src/test/java/com/linecorp/armeria/common/stream/PublisherBasedStreamMessageTest.java index 9ee7da40e5c..5e5cbb06669 100644 --- a/core/src/test/java/com/linecorp/armeria/common/stream/PublisherBasedStreamMessageTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/stream/PublisherBasedStreamMessageTest.java @@ -16,6 +16,7 @@ package com.linecorp.armeria.common.stream; +import static com.linecorp.armeria.common.stream.StreamMessageTest.newPooledBuffer; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.fail; @@ -40,6 +41,8 @@ import com.linecorp.armeria.common.HttpData; import com.linecorp.armeria.common.stream.PublisherBasedStreamMessage.AbortableSubscriber; +import io.netty.buffer.ByteBuf; + class PublisherBasedStreamMessageTest { /** @@ -97,9 +100,11 @@ void testAbortWithoutSubscriber(@Nullable Throwable cause) { @Test void notifyCancellation() { + final ByteBuf buf = newPooledBuffer(); final DefaultStreamMessage delegate = new DefaultStreamMessage<>(); + delegate.write(HttpData.wrap(buf)); final PublisherBasedStreamMessage p = new PublisherBasedStreamMessage<>(delegate); - SubscriptionOptionTest.notifyCancellation(p); + SubscriptionOptionTest.notifyCancellation(buf, p); } @Test diff --git a/core/src/test/java/com/linecorp/armeria/common/stream/SubscriptionOptionTest.java b/core/src/test/java/com/linecorp/armeria/common/stream/SubscriptionOptionTest.java index 15d348129f8..ac19c74898f 100644 --- a/core/src/test/java/com/linecorp/armeria/common/stream/SubscriptionOptionTest.java +++ b/core/src/test/java/com/linecorp/armeria/common/stream/SubscriptionOptionTest.java @@ -39,6 +39,7 @@ import com.linecorp.armeria.common.HttpData; import io.netty.buffer.ByteBuf; +import reactor.core.publisher.Mono; class SubscriptionOptionTest { @@ -100,11 +101,11 @@ public void onComplete() { @ParameterizedTest @ArgumentsSource(PooledHttpDataStreamProvider.class) - void notifyCancellation(HttpData unused1, ByteBuf unused2, StreamMessage stream) { - notifyCancellation(stream); + void notifyCancellation(HttpData data, ByteBuf buf, StreamMessage stream) { + notifyCancellation(buf, stream); } - static void notifyCancellation(StreamMessage stream) { + static void notifyCancellation(ByteBuf buf, StreamMessage stream) { final AtomicBoolean completed = new AtomicBoolean(); stream.subscribe(new Subscriber() { @Override @@ -131,6 +132,7 @@ public void onComplete() { await().untilAsserted(() -> assertThat(completed).isTrue()); await().untilAsserted(() -> assertThat(stream.whenComplete()).isCompletedExceptionally()); + assertThat(buf.refCnt()).isZero(); } static SubscriptionOption[] subscriptionOptions(boolean subscribedWithPooledObjects) { @@ -145,7 +147,7 @@ private static class PooledHttpDataStreamProvider implements ArgumentsProvider { @Override public Stream provideArguments(ExtensionContext context) { - return Stream.of(defaultStream(), fixedStream(), deferredStream()); + return Stream.of(defaultStream(), fixedStream(), deferredStream(), publisherBasedStream()); } private static Arguments defaultStream() { @@ -174,5 +176,14 @@ private static Arguments deferredStream() { d.close(); return of(data, buf, deferredStream); } + + private static Arguments publisherBasedStream() { + final ByteBuf buf = newPooledBuffer(); + final HttpData data = HttpData.wrap(buf).withEndOfStream(); + final PublisherBasedStreamMessage publisherBasedStream = + new PublisherBasedStreamMessage<>(Mono.just(data) + .doOnDiscard(HttpData.class, HttpData::close)); + return of(data, buf, publisherBasedStream); + } } } diff --git a/spring/boot2-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/ReactiveWebServerCompressionLeakTest.java b/spring/boot2-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/ReactiveWebServerCompressionLeakTest.java new file mode 100644 index 00000000000..5ad8943a8d6 --- /dev/null +++ b/spring/boot2-webflux-autoconfigure/src/test/java/com/linecorp/armeria/spring/web/reactive/ReactiveWebServerCompressionLeakTest.java @@ -0,0 +1,121 @@ +/* + * Copyright 2021 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.spring.web.reactive; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.boot.web.server.LocalServerPort; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.NettyDataBuffer; +import org.springframework.http.MediaType; +import org.springframework.http.server.reactive.ServerHttpResponseDecorator; +import org.springframework.stereotype.Component; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.server.RequestPredicates; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; + +import com.linecorp.armeria.client.WebClient; +import com.linecorp.armeria.client.encoding.DecodingClient; +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.server.encoding.EncodingService; +import com.linecorp.armeria.spring.ArmeriaServerConfigurator; + +import reactor.core.publisher.Mono; + +@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT) +class ReactiveWebServerCompressionLeakTest { + + private static final List nettyData = new ArrayList<>(); + + @SpringBootApplication + @Configuration + static class TestConfiguration { + + @Bean + public RouterFunction route(TestHandler testHandler) { + return RouterFunctions.route(RequestPredicates.GET("/hello"), testHandler::hello); + } + + @Component + static class TestHandler { + Mono hello(ServerRequest request) { + return ServerResponse.ok().contentType(MediaType.TEXT_PLAIN) + .body(BodyInserters.fromValue("Hello Armeria")); + } + } + + @Bean + public ArmeriaServerConfigurator configurator() { + return builder -> builder.decorator(EncodingService.builder() + .minBytesToForceChunkedEncoding(5) + .newDecorator()); + } + + @Component + private static final class HttpDataCaptor implements WebFilter { + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + final ServerHttpResponseDecorator httpResponseDecorator = + new ServerHttpResponseDecorator(exchange.getResponse()) { + @Override + public Mono writeWith(Publisher body) { + final Mono buffer = Mono.from(body); + return super.writeWith(buffer.doOnNext(b -> { + assert b instanceof NettyDataBuffer; + nettyData.add((NettyDataBuffer) b); + })); + } + }; + return chain.filter(exchange.mutate().response(httpResponseDecorator).build()); + } + } + } + + @LocalServerPort + int port; + + @Test + void nettyDataBufferShouldBeReleaseWhenCompressionEnabled() throws Exception { + final WebClient client = webClient(); + final AggregatedHttpResponse response = client.get("/hello").aggregate().join(); + assertThat(response.contentUtf8()).isEqualTo("Hello Armeria"); + assertThat(nettyData.size()).isOne(); + assertThat(nettyData.get(0).getNativeBuffer().refCnt()).isZero(); + } + + private WebClient webClient() { + return WebClient.builder("http://127.0.0.1:" + port) + .decorator(DecodingClient.newDecorator()) + .build(); + } +}