Skip to content

Commit

Permalink
Support RxJava-wrapped HttpResult response in annotated services (#…
Browse files Browse the repository at this point in the history
…5386)

Motivation:

When returning `Single<HttpResult<T>>` from an annotated service, the
response is always serialized as `{}`, since we currently just try to
serialize `HttpResult` with Jackson. Instead, we should convert it to
`HttpResponse` and only do serialization on `HttpResult.content()`.

Modifications:

- Create `HttpResultUtil` with the logic needed to build `HttpResponse`
headers from `HttpResult`
- Apply the same `HttpResult` conversion logic from `AnnotatedService`
to `CompositeResponseConverterFunction`

Result:

- Closes #5380 
- Annotated services support `HttpResult` wrapped in RxJava objects
  • Loading branch information
KarboniteKream authored Jan 24, 2024
1 parent 7d5ed55 commit c20e058
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.ExchangeType;
import com.linecorp.armeria.common.Flags;
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpHeaders;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
Expand Down Expand Up @@ -249,6 +248,11 @@ Route route() {
return route;
}

// TODO: Expose through `AnnotatedServiceConfig`, see #5382.
HttpStatus defaultStatus() {
return defaultStatus;
}

HttpService withExceptionHandler(HttpService service) {
if (exceptionHandler == null) {
return service;
Expand Down Expand Up @@ -398,7 +402,7 @@ private HttpResponse convertResponse(ServiceRequestContext ctx, @Nullable Object

if (result instanceof HttpResult) {
final HttpResult<?> httpResult = (HttpResult<?>) result;
headers = buildResponseHeaders(ctx, httpResult.headers());
headers = HttpResultUtil.buildResponseHeaders(ctx, httpResult);
result = httpResult.content();
trailers = httpResult.trailers();
} else {
Expand Down Expand Up @@ -426,39 +430,17 @@ private HttpResponse convertResponseInternal(ServiceRequestContext ctx,
}
}

private ResponseHeaders buildResponseHeaders(ServiceRequestContext ctx, HttpHeaders customHeaders) {
final ResponseHeadersBuilder builder;

// Prefer ResponseHeaders#toBuilder because builder#add(Iterable) is an expensive operation.
if (customHeaders instanceof ResponseHeaders) {
builder = ((ResponseHeaders) customHeaders).toBuilder();
} else {
builder = ResponseHeaders.builder();
builder.add(customHeaders);
if (!builder.contains(HttpHeaderNames.STATUS)) {
builder.status(defaultStatus);
}
}
return maybeAddContentType(ctx, builder).build();
}

private ResponseHeaders buildResponseHeaders(ServiceRequestContext ctx) {
return maybeAddContentType(ctx, ResponseHeaders.builder(defaultStatus)).build();
}

private static ResponseHeadersBuilder maybeAddContentType(ServiceRequestContext ctx,
ResponseHeadersBuilder builder) {
final ResponseHeadersBuilder builder = ResponseHeaders.builder(defaultStatus);
if (builder.status().isContentAlwaysEmpty()) {
return builder;
}
if (builder.contentType() != null) {
return builder;
return builder.build();
}

final MediaType negotiatedResponseMediaType = ctx.negotiatedResponseMediaType();
if (negotiatedResponseMediaType != null) {
builder.contentType(negotiatedResponseMediaType);
}
return builder;
return builder.build();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.linecorp.armeria.common.util.SafeCloseable;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.annotation.FallthroughException;
import com.linecorp.armeria.server.annotation.HttpResult;
import com.linecorp.armeria.server.annotation.ResponseConverterFunction;

/**
Expand Down Expand Up @@ -62,6 +63,12 @@ public HttpResponse convertResponse(ServiceRequestContext ctx,
if (result instanceof HttpResponse) {
return (HttpResponse) result;
}
if (result instanceof HttpResult) {
final HttpResult<?> httpResult = (HttpResult<?>) result;
headers = HttpResultUtil.buildResponseHeaders(ctx, httpResult);
result = httpResult.content();
trailers = httpResult.trailers();
}
try (SafeCloseable ignored = ctx.push()) {
for (final ResponseConverterFunction func : functions) {
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright 2024 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.internal.server.annotation;

import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpHeaders;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.ResponseHeadersBuilder;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.annotation.HttpResult;

final class HttpResultUtil {
static ResponseHeaders buildResponseHeaders(ServiceRequestContext ctx, HttpResult<?> result) {
final ResponseHeadersBuilder builder;
final HttpHeaders customHeaders = result.headers();

// Prefer ResponseHeaders#toBuilder because builder#add(Iterable) is an expensive operation.
if (customHeaders instanceof ResponseHeaders) {
builder = ((ResponseHeaders) customHeaders).toBuilder();
} else {
builder = ResponseHeaders.builder();
builder.add(customHeaders);

if (!builder.contains(HttpHeaderNames.STATUS)) {
final AnnotatedService service = ctx.config().service().as(AnnotatedService.class);
if (service != null) {
builder.status(service.defaultStatus());
} else {
builder.status(HttpStatus.OK);
}
}
}

return maybeAddContentType(ctx, builder).build();
}

private static ResponseHeadersBuilder maybeAddContentType(ServiceRequestContext ctx,
ResponseHeadersBuilder builder) {
if (builder.status().isContentAlwaysEmpty()) {
return builder;
}
if (builder.contentType() != null) {
return builder;
}

final MediaType negotiatedResponseMediaType = ctx.negotiatedResponseMediaType();
if (negotiatedResponseMediaType != null) {
builder.contentType(negotiatedResponseMediaType);
}

return builder;
}

private HttpResultUtil() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
import com.linecorp.armeria.server.annotation.Delimiter;
import com.linecorp.armeria.server.annotation.Get;
import com.linecorp.armeria.server.annotation.Header;
import com.linecorp.armeria.server.annotation.HttpResult;
import com.linecorp.armeria.server.annotation.Order;
import com.linecorp.armeria.server.annotation.Param;
import com.linecorp.armeria.server.annotation.Path;
Expand Down Expand Up @@ -200,6 +201,20 @@ public CompletableFuture<Integer> returnIntAsync(@Param int var) {
return UnmodifiableFuture.completedFuture(var).thenApply(n -> n + 1);
}

@Get
@Path("/string-response-async/:var")
public CompletableFuture<HttpResponse> returnStringResponseAsync(@Param String var) {
return CompletableFuture.supplyAsync(() -> HttpResponse.of(var));
}

// Wrapped content is handled by a custom String -> HttpResponse converter.
@Get
@Path("/string-result-async/:var")
@ResponseConverter(NaiveStringConverterFunction.class)
public CompletableFuture<HttpResult<String>> returnStringResultAsync(@Param String var) {
return CompletableFuture.supplyAsync(() -> HttpResult.of(var));
}

@Get
@Path("/path/ctx/async/:var")
public static CompletableFuture<String> returnPathCtxAsync(@Param int var,
Expand Down Expand Up @@ -845,6 +860,10 @@ void testAnnotatedService() throws Exception {
testBody(hc, get("/1/string/%F0%90%8D%88"), "String: \uD800\uDF48", // 𐍈
StandardCharsets.UTF_8);

// Deferred HttpResponse and HttpResult.
testBody(hc, get("/1/string-response-async/blah"), "blah");
testBody(hc, get("/1/string-result-async/blah"), "String: blah");

// Get a requested path as typed string from ServiceRequestContext or HttpRequest
testBody(hc, get("/1/path/ctx/async/1"), "String[/1/path/ctx/async/1]");
testBody(hc, get("/1/path/req/async/1"), "String[/1/path/req/async/1]");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright 2024 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.internal.server.annotation;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import org.junit.jupiter.api.Test;

import com.linecorp.armeria.common.HttpHeaders;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.server.Server;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.annotation.Get;
import com.linecorp.armeria.server.annotation.HttpResult;
import com.linecorp.armeria.server.annotation.StatusCode;

class HttpResultUtilTest {

@Test
void shouldReuseResponseHeaders() {
final ServiceRequestContext ctx = mock(ServiceRequestContext.class);

final ResponseHeaders headers = ResponseHeaders
.builder(HttpStatus.OK)
.contentType(MediaType.PLAIN_TEXT_UTF_8)
.add("foo", "bar")
.build();
final HttpResult<Integer> result = HttpResult.of(headers, 123);

final ResponseHeaders actual = HttpResultUtil.buildResponseHeaders(ctx, result);
assertThat(actual).isEqualTo(headers);
assertThat(actual.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8);
assertThat(actual.get("foo")).isEqualTo("bar");

verifyNoInteractions(ctx);
}

@Test
void shouldNotAddContentTypeWhenNoContent() {
final ServiceRequestContext ctx = mock(ServiceRequestContext.class);

final ResponseHeaders headers = ResponseHeaders.of(HttpStatus.NO_CONTENT);
final HttpResult<Integer> result = HttpResult.of(headers, 123);

final ResponseHeaders actual = HttpResultUtil.buildResponseHeaders(ctx, result);
assertThat(actual).isEqualTo(headers);
assertThat(actual.contentType()).isNull();

verifyNoInteractions(ctx);
}

@Test
void shouldNegotiateContentType() {
final ServiceRequestContext ctx = mock(ServiceRequestContext.class);
when(ctx.negotiatedResponseMediaType()).thenReturn(MediaType.JSON_UTF_8);

final ResponseHeaders headers = ResponseHeaders.of(HttpStatus.OK, "foo", "bar");
final HttpResult<Integer> result = HttpResult.of(headers, 123);

final ResponseHeaders actual = HttpResultUtil.buildResponseHeaders(ctx, result);
assertThat(actual.status()).isEqualTo(HttpStatus.OK);
assertThat(actual.contentType()).isEqualTo(MediaType.JSON_UTF_8);
assertThat(actual.get("foo")).isEqualTo("bar");
}

@Test
void shouldAddStatusFromAnnotatedService() {
final Server server = Server
.builder()
.annotatedService("/", new MyAnnotatedService())
.build();

final ServiceRequestContext ctx = mock(ServiceRequestContext.class);
when(ctx.config()).thenReturn(server.serviceConfigs().get(0));
when(ctx.negotiatedResponseMediaType()).thenReturn(MediaType.PLAIN_TEXT_UTF_8);

final HttpHeaders headers = HttpHeaders.of("foo", "bar");
final HttpResult<Integer> result = HttpResult.of(headers, 123);

final ResponseHeaders actual = HttpResultUtil.buildResponseHeaders(ctx, result);
assertThat(actual.status()).isEqualTo(HttpStatus.ACCEPTED);
assertThat(actual.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8);
assertThat(actual.get("foo")).isEqualTo("bar");
}

@Test
void shouldUseOkStatusWhenNotAnnotatedService() {
final Server server = Server
.builder()
.service("/", (ctx, req) -> HttpResponse.of(HttpStatus.ACCEPTED))
.build();

final ServiceRequestContext ctx = mock(ServiceRequestContext.class);
when(ctx.config()).thenReturn(server.serviceConfigs().get(0));
when(ctx.negotiatedResponseMediaType()).thenReturn(MediaType.JSON_UTF_8);

final HttpHeaders headers = HttpHeaders.of("foo", "bar");
final HttpResult<Integer> result = HttpResult.of(headers, 123);

final ResponseHeaders actual = HttpResultUtil.buildResponseHeaders(ctx, result);
assertThat(actual.status()).isEqualTo(HttpStatus.OK);
assertThat(actual.contentType()).isEqualTo(MediaType.JSON_UTF_8);
assertThat(actual.get("foo")).isEqualTo("bar");
}

public class MyAnnotatedService {
@Get
@StatusCode(202)
public int myMethod() {
return 123;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.annotation.Get;
import com.linecorp.armeria.server.annotation.HttpResult;
import com.linecorp.armeria.server.annotation.Post;
import com.linecorp.armeria.server.annotation.ProducesJson;
import com.linecorp.armeria.server.annotation.ProducesJsonSequences;
Expand Down Expand Up @@ -110,6 +111,11 @@ public Maybe<HttpResponse> httpResponse() {
return Maybe.just(HttpResponse.of("a"));
}

@Get("/http-result")
public Maybe<HttpResult<String>> httpResult() {
return Maybe.just(HttpResult.of("a"));
}

@Post("/defer-empty-post")
public Maybe<String> deferEmptyPost() {
final RequestContext ctx = RequestContext.current();
Expand Down Expand Up @@ -153,6 +159,11 @@ public Single<HttpResponse> httpResponse() {
return Single.just(HttpResponse.of("a"));
}

@Get("/http-result")
public Single<HttpResult<String>> httpResult() {
return Single.just(HttpResult.of("a"));
}

@Post("/defer-empty-post")
public Single<String> deferEmptyPost() {
final RequestContext ctx = RequestContext.current();
Expand Down Expand Up @@ -361,6 +372,10 @@ void maybe() {
assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8);
assertThat(res.contentUtf8()).isEqualTo("a");

res = client.get("/http-result").aggregate().join();
assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8);
assertThat(res.contentUtf8()).isEqualTo("a");

res = client.post("/defer-empty-post", "").aggregate().join();
assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8);
assertThat(res.contentUtf8()).isEqualTo("a");
Expand Down Expand Up @@ -391,6 +406,10 @@ void single() {
assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8);
assertThat(res.contentUtf8()).isEqualTo("a");

res = client.get("/http-result").aggregate().join();
assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8);
assertThat(res.contentUtf8()).isEqualTo("a");

res = client.post("/defer-empty-post", "").aggregate().join();
assertThat(res.contentType()).isEqualTo(MediaType.PLAIN_TEXT_UTF_8);
assertThat(res.contentUtf8()).isEqualTo("a");
Expand Down
Loading

0 comments on commit c20e058

Please sign in to comment.