From c20e058cd2c7a8b87574b3373526bad420e9930c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Klemen=20Ko=C5=A1ir?= Date: Wed, 24 Jan 2024 17:50:56 +0900 Subject: [PATCH] Support RxJava-wrapped `HttpResult` response in annotated services (#5386) Motivation: When returning `Single>` from an annotated service, the response is always serialized as `{}`, since we currently just try to serialize `HttpResult` with Jackson. Instead, we should convert it to `HttpResponse` and only do serialization on `HttpResult.content()`. Modifications: - Create `HttpResultUtil` with the logic needed to build `HttpResponse` headers from `HttpResult` - Apply the same `HttpResult` conversion logic from `AnnotatedService` to `CompositeResponseConverterFunction` Result: - Closes #5380 - Annotated services support `HttpResult` wrapped in RxJava objects --- .../server/annotation/AnnotatedService.java | 38 ++--- .../CompositeResponseConverterFunction.java | 7 + .../server/annotation/HttpResultUtil.java | 71 ++++++++++ .../annotation/AnnotatedServiceTest.java | 19 +++ .../server/annotation/HttpResultUtilTest.java | 133 ++++++++++++++++++ ...servableResponseConverterFunctionTest.java | 19 +++ ...servableResponseConverterFunctionTest.java | 19 +++ 7 files changed, 278 insertions(+), 28 deletions(-) create mode 100644 core/src/main/java/com/linecorp/armeria/internal/server/annotation/HttpResultUtil.java create mode 100644 core/src/test/java/com/linecorp/armeria/internal/server/annotation/HttpResultUtilTest.java diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/annotation/AnnotatedService.java b/core/src/main/java/com/linecorp/armeria/internal/server/annotation/AnnotatedService.java index f9147e87f06..183b3a0653c 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/server/annotation/AnnotatedService.java +++ b/core/src/main/java/com/linecorp/armeria/internal/server/annotation/AnnotatedService.java @@ -41,7 +41,6 @@ import com.linecorp.armeria.common.AggregatedHttpResponse; import com.linecorp.armeria.common.ExchangeType; import com.linecorp.armeria.common.Flags; -import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; @@ -249,6 +248,11 @@ Route route() { return route; } + // TODO: Expose through `AnnotatedServiceConfig`, see #5382. + HttpStatus defaultStatus() { + return defaultStatus; + } + HttpService withExceptionHandler(HttpService service) { if (exceptionHandler == null) { return service; @@ -398,7 +402,7 @@ private HttpResponse convertResponse(ServiceRequestContext ctx, @Nullable Object if (result instanceof HttpResult) { final HttpResult httpResult = (HttpResult) result; - headers = buildResponseHeaders(ctx, httpResult.headers()); + headers = HttpResultUtil.buildResponseHeaders(ctx, httpResult); result = httpResult.content(); trailers = httpResult.trailers(); } else { @@ -426,39 +430,17 @@ private HttpResponse convertResponseInternal(ServiceRequestContext ctx, } } - private ResponseHeaders buildResponseHeaders(ServiceRequestContext ctx, HttpHeaders customHeaders) { - final ResponseHeadersBuilder builder; - - // Prefer ResponseHeaders#toBuilder because builder#add(Iterable) is an expensive operation. - if (customHeaders instanceof ResponseHeaders) { - builder = ((ResponseHeaders) customHeaders).toBuilder(); - } else { - builder = ResponseHeaders.builder(); - builder.add(customHeaders); - if (!builder.contains(HttpHeaderNames.STATUS)) { - builder.status(defaultStatus); - } - } - return maybeAddContentType(ctx, builder).build(); - } - private ResponseHeaders buildResponseHeaders(ServiceRequestContext ctx) { - return maybeAddContentType(ctx, ResponseHeaders.builder(defaultStatus)).build(); - } - - private static ResponseHeadersBuilder maybeAddContentType(ServiceRequestContext ctx, - ResponseHeadersBuilder builder) { + final ResponseHeadersBuilder builder = ResponseHeaders.builder(defaultStatus); if (builder.status().isContentAlwaysEmpty()) { - return builder; - } - if (builder.contentType() != null) { - return builder; + return builder.build(); } + final MediaType negotiatedResponseMediaType = ctx.negotiatedResponseMediaType(); if (negotiatedResponseMediaType != null) { builder.contentType(negotiatedResponseMediaType); } - return builder; + return builder.build(); } /** diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/annotation/CompositeResponseConverterFunction.java b/core/src/main/java/com/linecorp/armeria/internal/server/annotation/CompositeResponseConverterFunction.java index 67f0f72c3e7..d73d8fc0eb6 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/server/annotation/CompositeResponseConverterFunction.java +++ b/core/src/main/java/com/linecorp/armeria/internal/server/annotation/CompositeResponseConverterFunction.java @@ -29,6 +29,7 @@ import com.linecorp.armeria.common.util.SafeCloseable; import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.server.annotation.FallthroughException; +import com.linecorp.armeria.server.annotation.HttpResult; import com.linecorp.armeria.server.annotation.ResponseConverterFunction; /** @@ -62,6 +63,12 @@ public HttpResponse convertResponse(ServiceRequestContext ctx, if (result instanceof HttpResponse) { return (HttpResponse) result; } + if (result instanceof HttpResult) { + final HttpResult httpResult = (HttpResult) result; + headers = HttpResultUtil.buildResponseHeaders(ctx, httpResult); + result = httpResult.content(); + trailers = httpResult.trailers(); + } try (SafeCloseable ignored = ctx.push()) { for (final ResponseConverterFunction func : functions) { try { diff --git a/core/src/main/java/com/linecorp/armeria/internal/server/annotation/HttpResultUtil.java b/core/src/main/java/com/linecorp/armeria/internal/server/annotation/HttpResultUtil.java new file mode 100644 index 00000000000..fb3d804885b --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/server/annotation/HttpResultUtil.java @@ -0,0 +1,71 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.server.annotation; + +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.HttpHeaders; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.ResponseHeadersBuilder; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.annotation.HttpResult; + +final class HttpResultUtil { + static ResponseHeaders buildResponseHeaders(ServiceRequestContext ctx, HttpResult result) { + final ResponseHeadersBuilder builder; + final HttpHeaders customHeaders = result.headers(); + + // Prefer ResponseHeaders#toBuilder because builder#add(Iterable) is an expensive operation. + if (customHeaders instanceof ResponseHeaders) { + builder = ((ResponseHeaders) customHeaders).toBuilder(); + } else { + builder = ResponseHeaders.builder(); + builder.add(customHeaders); + + if (!builder.contains(HttpHeaderNames.STATUS)) { + final AnnotatedService service = ctx.config().service().as(AnnotatedService.class); + if (service != null) { + builder.status(service.defaultStatus()); + } else { + builder.status(HttpStatus.OK); + } + } + } + + return maybeAddContentType(ctx, builder).build(); + } + + private static ResponseHeadersBuilder maybeAddContentType(ServiceRequestContext ctx, + ResponseHeadersBuilder builder) { + if (builder.status().isContentAlwaysEmpty()) { + return builder; + } + if (builder.contentType() != null) { + return builder; + } + + final MediaType negotiatedResponseMediaType = ctx.negotiatedResponseMediaType(); + if (negotiatedResponseMediaType != null) { + builder.contentType(negotiatedResponseMediaType); + } + + return builder; + } + + private HttpResultUtil() {} +} diff --git a/core/src/test/java/com/linecorp/armeria/internal/server/annotation/AnnotatedServiceTest.java b/core/src/test/java/com/linecorp/armeria/internal/server/annotation/AnnotatedServiceTest.java index 6789e378c10..807631b1ac4 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/server/annotation/AnnotatedServiceTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/server/annotation/AnnotatedServiceTest.java @@ -86,6 +86,7 @@ import com.linecorp.armeria.server.annotation.Delimiter; import com.linecorp.armeria.server.annotation.Get; import com.linecorp.armeria.server.annotation.Header; +import com.linecorp.armeria.server.annotation.HttpResult; import com.linecorp.armeria.server.annotation.Order; import com.linecorp.armeria.server.annotation.Param; import com.linecorp.armeria.server.annotation.Path; @@ -200,6 +201,20 @@ public CompletableFuture returnIntAsync(@Param int var) { return UnmodifiableFuture.completedFuture(var).thenApply(n -> n + 1); } + @Get + @Path("/string-response-async/:var") + public CompletableFuture returnStringResponseAsync(@Param String var) { + return CompletableFuture.supplyAsync(() -> HttpResponse.of(var)); + } + + // Wrapped content is handled by a custom String -> HttpResponse converter. + @Get + @Path("/string-result-async/:var") + @ResponseConverter(NaiveStringConverterFunction.class) + public CompletableFuture> returnStringResultAsync(@Param String var) { + return CompletableFuture.supplyAsync(() -> HttpResult.of(var)); + } + @Get @Path("/path/ctx/async/:var") public static CompletableFuture returnPathCtxAsync(@Param int var, @@ -845,6 +860,10 @@ void testAnnotatedService() throws Exception { testBody(hc, get("/1/string/%F0%90%8D%88"), "String: \uD800\uDF48", // 𐍈 StandardCharsets.UTF_8); + // Deferred HttpResponse and HttpResult. + testBody(hc, get("/1/string-response-async/blah"), "blah"); + testBody(hc, get("/1/string-result-async/blah"), "String: blah"); + // Get a requested path as typed string from ServiceRequestContext or HttpRequest testBody(hc, get("/1/path/ctx/async/1"), "String[/1/path/ctx/async/1]"); testBody(hc, get("/1/path/req/async/1"), "String[/1/path/req/async/1]"); diff --git a/core/src/test/java/com/linecorp/armeria/internal/server/annotation/HttpResultUtilTest.java b/core/src/test/java/com/linecorp/armeria/internal/server/annotation/HttpResultUtilTest.java new file mode 100644 index 00000000000..ffcfe65c2a9 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/internal/server/annotation/HttpResultUtilTest.java @@ -0,0 +1,133 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.server.annotation; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import org.junit.jupiter.api.Test; + +import com.linecorp.armeria.common.HttpHeaders; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.server.Server; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.annotation.Get; +import com.linecorp.armeria.server.annotation.HttpResult; +import com.linecorp.armeria.server.annotation.StatusCode; + +class HttpResultUtilTest { + + @Test + void shouldReuseResponseHeaders() { + final ServiceRequestContext ctx = mock(ServiceRequestContext.class); + + final ResponseHeaders headers = ResponseHeaders + .builder(HttpStatus.OK) + .contentType(MediaType.PLAIN_TEXT_UTF_8) + .add("foo", "bar") + .build(); + final HttpResult result = HttpResult.of(headers, 123); + + final ResponseHeaders actual = HttpResultUtil.buildResponseHeaders(ctx, result); + assertThat(actual).isEqualTo(headers); + assertThat(actual.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); + assertThat(actual.get("foo")).isEqualTo("bar"); + + verifyNoInteractions(ctx); + } + + @Test + void shouldNotAddContentTypeWhenNoContent() { + final ServiceRequestContext ctx = mock(ServiceRequestContext.class); + + final ResponseHeaders headers = ResponseHeaders.of(HttpStatus.NO_CONTENT); + final HttpResult result = HttpResult.of(headers, 123); + + final ResponseHeaders actual = HttpResultUtil.buildResponseHeaders(ctx, result); + assertThat(actual).isEqualTo(headers); + assertThat(actual.contentType()).isNull(); + + verifyNoInteractions(ctx); + } + + @Test + void shouldNegotiateContentType() { + final ServiceRequestContext ctx = mock(ServiceRequestContext.class); + when(ctx.negotiatedResponseMediaType()).thenReturn(MediaType.JSON_UTF_8); + + final ResponseHeaders headers = ResponseHeaders.of(HttpStatus.OK, "foo", "bar"); + final HttpResult result = HttpResult.of(headers, 123); + + final ResponseHeaders actual = HttpResultUtil.buildResponseHeaders(ctx, result); + assertThat(actual.status()).isEqualTo(HttpStatus.OK); + assertThat(actual.contentType()).isEqualTo(MediaType.JSON_UTF_8); + assertThat(actual.get("foo")).isEqualTo("bar"); + } + + @Test + void shouldAddStatusFromAnnotatedService() { + final Server server = Server + .builder() + .annotatedService("/", new MyAnnotatedService()) + .build(); + + final ServiceRequestContext ctx = mock(ServiceRequestContext.class); + when(ctx.config()).thenReturn(server.serviceConfigs().get(0)); + when(ctx.negotiatedResponseMediaType()).thenReturn(MediaType.PLAIN_TEXT_UTF_8); + + final HttpHeaders headers = HttpHeaders.of("foo", "bar"); + final HttpResult result = HttpResult.of(headers, 123); + + final ResponseHeaders actual = HttpResultUtil.buildResponseHeaders(ctx, result); + assertThat(actual.status()).isEqualTo(HttpStatus.ACCEPTED); + assertThat(actual.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); + assertThat(actual.get("foo")).isEqualTo("bar"); + } + + @Test + void shouldUseOkStatusWhenNotAnnotatedService() { + final Server server = Server + .builder() + .service("/", (ctx, req) -> HttpResponse.of(HttpStatus.ACCEPTED)) + .build(); + + final ServiceRequestContext ctx = mock(ServiceRequestContext.class); + when(ctx.config()).thenReturn(server.serviceConfigs().get(0)); + when(ctx.negotiatedResponseMediaType()).thenReturn(MediaType.JSON_UTF_8); + + final HttpHeaders headers = HttpHeaders.of("foo", "bar"); + final HttpResult result = HttpResult.of(headers, 123); + + final ResponseHeaders actual = HttpResultUtil.buildResponseHeaders(ctx, result); + assertThat(actual.status()).isEqualTo(HttpStatus.OK); + assertThat(actual.contentType()).isEqualTo(MediaType.JSON_UTF_8); + assertThat(actual.get("foo")).isEqualTo("bar"); + } + + public class MyAnnotatedService { + @Get + @StatusCode(202) + public int myMethod() { + return 123; + } + } +} diff --git a/rxjava2/src/test/java/com/linecorp/armeria/server/rxjava2/ObservableResponseConverterFunctionTest.java b/rxjava2/src/test/java/com/linecorp/armeria/server/rxjava2/ObservableResponseConverterFunctionTest.java index 2f17429ab2f..948a77e3b34 100644 --- a/rxjava2/src/test/java/com/linecorp/armeria/server/rxjava2/ObservableResponseConverterFunctionTest.java +++ b/rxjava2/src/test/java/com/linecorp/armeria/server/rxjava2/ObservableResponseConverterFunctionTest.java @@ -52,6 +52,7 @@ import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.server.annotation.Get; +import com.linecorp.armeria.server.annotation.HttpResult; import com.linecorp.armeria.server.annotation.Post; import com.linecorp.armeria.server.annotation.ProducesJson; import com.linecorp.armeria.server.annotation.ProducesJsonSequences; @@ -110,6 +111,11 @@ public Maybe httpResponse() { return Maybe.just(HttpResponse.of("a")); } + @Get("/http-result") + public Maybe> httpResult() { + return Maybe.just(HttpResult.of("a")); + } + @Post("/defer-empty-post") public Maybe deferEmptyPost() { final RequestContext ctx = RequestContext.current(); @@ -153,6 +159,11 @@ public Single httpResponse() { return Single.just(HttpResponse.of("a")); } + @Get("/http-result") + public Single> httpResult() { + return Single.just(HttpResult.of("a")); + } + @Post("/defer-empty-post") public Single deferEmptyPost() { final RequestContext ctx = RequestContext.current(); @@ -361,6 +372,10 @@ void maybe() { assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); assertThat(res.contentUtf8()).isEqualTo("a"); + res = client.get("/http-result").aggregate().join(); + assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); + assertThat(res.contentUtf8()).isEqualTo("a"); + res = client.post("/defer-empty-post", "").aggregate().join(); assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); assertThat(res.contentUtf8()).isEqualTo("a"); @@ -391,6 +406,10 @@ void single() { assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); assertThat(res.contentUtf8()).isEqualTo("a"); + res = client.get("/http-result").aggregate().join(); + assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); + assertThat(res.contentUtf8()).isEqualTo("a"); + res = client.post("/defer-empty-post", "").aggregate().join(); assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); assertThat(res.contentUtf8()).isEqualTo("a"); diff --git a/rxjava3/src/test/java/com/linecorp/armeria/server/rxjava3/ObservableResponseConverterFunctionTest.java b/rxjava3/src/test/java/com/linecorp/armeria/server/rxjava3/ObservableResponseConverterFunctionTest.java index 319d26c7688..12f1a282695 100644 --- a/rxjava3/src/test/java/com/linecorp/armeria/server/rxjava3/ObservableResponseConverterFunctionTest.java +++ b/rxjava3/src/test/java/com/linecorp/armeria/server/rxjava3/ObservableResponseConverterFunctionTest.java @@ -40,6 +40,7 @@ import com.linecorp.armeria.internal.testing.GenerateNativeImageTrace; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.server.annotation.Get; +import com.linecorp.armeria.server.annotation.HttpResult; import com.linecorp.armeria.server.annotation.ProducesJson; import com.linecorp.armeria.server.annotation.ProducesJsonSequences; import com.linecorp.armeria.server.annotation.ProducesText; @@ -84,6 +85,11 @@ public Maybe error() { public Maybe httpResponse() { return Maybe.just(HttpResponse.of("a")); } + + @Get("/http-result") + public Maybe> httpResult() { + return Maybe.just(HttpResult.of("a")); + } }); sb.annotatedService("/single", new Object() { @@ -107,6 +113,11 @@ public Single error() { public Single httpResponse() { return Single.just(HttpResponse.of("a")); } + + @Get("/http-result") + public Single> httpResult() { + return Single.just(HttpResult.of("a")); + } }); sb.annotatedService("/completable", new Object() { @@ -205,6 +216,10 @@ void maybe() { res = client.get("/http-response").aggregate().join(); assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); assertThat(res.contentUtf8()).isEqualTo("a"); + + res = client.get("/http-result").aggregate().join(); + assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); + assertThat(res.contentUtf8()).isEqualTo("a"); } @Test @@ -227,6 +242,10 @@ void single() { res = client.get("/http-response").aggregate().join(); assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); assertThat(res.contentUtf8()).isEqualTo("a"); + + res = client.get("/http-result").aggregate().join(); + assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8); + assertThat(res.contentUtf8()).isEqualTo("a"); } @Test