Skip to content

Commit

Permalink
Add support for alternative protobuf content types (line#4364)
Browse files Browse the repository at this point in the history
Motivation:
Armeria `UnframedGrpcService` doesn't support alternative protobuf content types

Modifications:
- add `application/x-protobuf` and `application/x-google-protobuf` to MediaType
- add isProtobuf() function
- identify Protobuf contentType in UnframedGrpcService using isProtobuf()
- respond with request-given protobuf content type for deframed response

Result:

- Closes line#4355
- Armeria `UnframedGrpcService` now supports `application/x-protobuf` and `application/x-google-protobuf` media types.
  • Loading branch information
mscheong01 authored and heowc committed Sep 24, 2022
1 parent 5e95716 commit aab6852
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 20 deletions.
16 changes: 16 additions & 0 deletions core/src/main/java/com/linecorp/armeria/common/MediaType.java
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,8 @@ private static MediaType addKnownType(MediaType mediaType) {
* <a href="https://developers.google.com/protocol-buffers">Protocol buffers</a>.
*/
public static final MediaType PROTOBUF = createConstant(APPLICATION_TYPE, "protobuf");
public static final MediaType X_PROTOBUF = createConstant(APPLICATION_TYPE, "x-protobuf");
public static final MediaType X_GOOGLE_PROTOBUF = createConstant(APPLICATION_TYPE, "x-google-protobuf");

/**
* <a href="https://en.wikipedia.org/wiki/RDF/XML">RDF/XML</a> documents, which are XML
Expand Down Expand Up @@ -1052,6 +1054,20 @@ public boolean isJson() {
return is(JSON) || subtype().endsWith("+json");
}

/**
* Returns {@code true} when the subtype is one of {@link MediaType#PROTOBUF}, {@link MediaType#X_PROTOBUF}
* and {@link MediaType#X_GOOGLE_PROTOBUF}. Otherwise {@code false}.
*
* <pre>{@code
* PROTOBUF.isProtobuf() // true
* X_PROTOBUF.isProtobuf() // true
* X_GOOGLE_PROTOBUF.isProtobuf() // true
* }</pre>
*/
public boolean isProtobuf() {
return is(PROTOBUF) || is(X_PROTOBUF)|| is(X_GOOGLE_PROTOBUF);
}

/**
* Returns {@code true} if this {@link MediaType} belongs to the given {@link MediaType}.
* Similar to what {@link MediaType#is(MediaType)} does except that this one compares the parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,14 @@ public final class MediaTypeNames {
* {@value #PROTOBUF}.
*/
public static final String PROTOBUF = "application/protobuf";
/**
* {@value #X_PROTOBUF}.
*/
public static final String X_PROTOBUF = "application/x-protobuf";
/**
* {@value #X_GOOGLE_PROTOBUF}.
*/
public static final String X_GOOGLE_PROTOBUF = "application/x-google-protobuf";
/**
* {@value #RDF_XML_UTF_8}.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ protected void frameAndServe(
RequestHeaders grpcHeaders,
HttpData content,
CompletableFuture<HttpResponse> res,
@Nullable Function<HttpData, HttpData> responseBodyConverter) {
@Nullable Function<HttpData, HttpData> responseBodyConverter,
MediaType responseContentType) {
final HttpRequest grpcRequest;
try (ArmeriaMessageFramer framer = new ArmeriaMessageFramer(
ctx.alloc(), ArmeriaMessageFramer.NO_MAX_OUTBOUND_MESSAGE_SIZE, false)) {
Expand Down Expand Up @@ -170,7 +171,7 @@ protected void frameAndServe(
res.completeExceptionally(t);
} else {
deframeAndRespond(ctx, framedResponse, res, unframedGrpcErrorHandler,
responseBodyConverter);
responseBodyConverter, responseContentType);
}
}
return null;
Expand All @@ -182,7 +183,8 @@ static void deframeAndRespond(ServiceRequestContext ctx,
AggregatedHttpResponse grpcResponse,
CompletableFuture<HttpResponse> res,
UnframedGrpcErrorHandler unframedGrpcErrorHandler,
@Nullable Function<HttpData, HttpData> responseBodyConverter) {
@Nullable Function<HttpData, HttpData> responseBodyConverter,
MediaType responseContentType) {
final HttpHeaders trailers = !grpcResponse.trailers().isEmpty() ?
grpcResponse.trailers() : grpcResponse.headers();
final String grpcStatusCode = trailers.get(GrpcHeaderNames.GRPC_STATUS);
Expand Down Expand Up @@ -210,15 +212,15 @@ static void deframeAndRespond(ServiceRequestContext ctx,
}

final MediaType grpcMediaType = grpcResponse.contentType();
if (grpcMediaType == null) {
PooledObjects.close(grpcResponse.content());
res.completeExceptionally(new NullPointerException("MediaType is undefined"));
return;
}

final ResponseHeadersBuilder unframedHeaders = grpcResponse.headers().toBuilder();
unframedHeaders.set(GrpcHeaderNames.GRPC_STATUS, grpcStatusCode); // grpcStatusCode is 0 which is OK.
if (grpcMediaType != null) {
if (grpcMediaType.is(GrpcSerializationFormats.PROTO.mediaType())) {
unframedHeaders.contentType(MediaType.PROTOBUF);
} else if (grpcMediaType.is(GrpcSerializationFormats.JSON.mediaType())) {
unframedHeaders.contentType(MediaType.JSON_UTF_8);
}
}
unframedHeaders.contentType(responseContentType);

final ArmeriaMessageDeframer deframer = new ArmeriaMessageDeframer(
// Max outbound message size is handled by the GrpcService, so we don't need to set it here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,9 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req,
"gRPC encoding is not supported for non-framed requests.");
}

final MediaType jsonContentType = GrpcSerializationFormats.JSON.mediaType();
grpcHeaders.method(HttpMethod.POST)
.contentType(GrpcSerializationFormats.JSON.mediaType());
.contentType(jsonContentType);
// All clients support no encoding, and we don't support gRPC encoding for non-framed requests, so just
// clear the header if it's present.
grpcHeaders.remove(GrpcHeaderNames.GRPC_ACCEPT_ENCODING);
Expand All @@ -576,7 +577,7 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req,
ctx.setAttr(FramedGrpcService.RESOLVED_GRPC_METHOD, spec.method);
frameAndServe(unwrap(), ctx, grpcHeaders.build(),
convertToJson(ctx, clientRequest, spec),
responseFuture, generateResponseBodyConverter(spec));
responseFuture, generateResponseBodyConverter(spec), jsonContentType);
} catch (IllegalArgumentException iae) {
responseFuture.completeExceptionally(
HttpStatusException.of(HttpStatus.BAD_REQUEST, iae));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,16 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc
final RequestHeadersBuilder grpcHeaders = clientHeaders.toBuilder();

final MediaType framedContentType;
if (contentType.is(MediaType.PROTOBUF)) {
if (contentType.isProtobuf()) {
framedContentType = GrpcSerializationFormats.PROTO.mediaType();
} else if (contentType.is(MediaType.JSON)) {
framedContentType = GrpcSerializationFormats.JSON.mediaType();
} else {
return HttpResponse.of(HttpStatus.UNSUPPORTED_MEDIA_TYPE,
MediaType.PLAIN_TEXT_UTF_8,
"Unsupported media type. Only application/protobuf is supported.");
"Unsupported media type. Only application/protobuf, " +
"application/x-protobuf, application/x-google-protobuf" +
"and application/json are supported.");
}
grpcHeaders.contentType(framedContentType);

Expand All @@ -149,8 +151,8 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc
if (t != null) {
responseFuture.completeExceptionally(t);
} else {
frameAndServe(unwrap(), ctx, grpcHeaders.build(),
clientRequest.content(), responseFuture, null);
frameAndServe(unwrap(), ctx, grpcHeaders.build(), clientRequest.content(),
responseFuture, null, contentType);
}
}
return null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright 2022 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.server.grpc;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;

import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.grpc.testing.TestServiceGrpc;
import com.linecorp.armeria.protobuf.EmptyProtos;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.testing.junit5.common.EventLoopExtension;

import io.grpc.BindableService;
import io.grpc.stub.StreamObserver;

public class UnframedGrpcServiceResponseMediaTypeTest {

@RegisterExtension
static EventLoopExtension eventLoop = new EventLoopExtension();

private static class TestService extends TestServiceGrpc.TestServiceImplBase {

@Override
public void emptyCall(EmptyProtos.Empty request, StreamObserver<EmptyProtos.Empty> responseObserver) {
responseObserver.onNext(EmptyProtos.Empty.newBuilder().build());
responseObserver.onCompleted();
}
}

private static final TestService testService = new TestService();
private static final int MAX_MESSAGE_BYTES = 1024;

@Test
void respondWithCorrespondingJsonMediaType() throws Exception {
final UnframedGrpcService unframedGrpcService = buildUnframedGrpcService(testService);

final HttpRequest request = HttpRequest.of(HttpMethod.POST,
"/armeria.grpc.testing.TestService/EmptyCall",
MediaType.JSON_UTF_8, "{}");
final ServiceRequestContext ctx = ServiceRequestContext.builder(request)
.build();

final AggregatedHttpResponse res = unframedGrpcService.serve(ctx, request).aggregate().join();
assertThat(res.status()).isEqualTo(HttpStatus.OK);
assertThat(res.contentType()).isEqualTo(MediaType.JSON_UTF_8);
}

@ParameterizedTest
@ArgumentsSource(ProtobufMediaTypeProvider.class)
void respondWithCorrespondingProtobufMediaType(MediaType protobufType) throws Exception {
final UnframedGrpcService unframedGrpcService = buildUnframedGrpcService(testService);

final HttpRequest request = HttpRequest.of(HttpMethod.POST,
"/armeria.grpc.testing.TestService/EmptyCall",
protobufType,
EmptyProtos.Empty.getDefaultInstance().toByteArray());
final ServiceRequestContext ctx = ServiceRequestContext.builder(request)
.build();

final AggregatedHttpResponse res = unframedGrpcService.serve(ctx, request).aggregate().join();
assertThat(res.status()).isEqualTo(HttpStatus.OK);
assertThat(res.contentType()).isEqualTo(protobufType);
}

private static class ProtobufMediaTypeProvider implements ArgumentsProvider {
@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
return Stream.of(MediaType.PROTOBUF, MediaType.X_PROTOBUF, MediaType.X_GOOGLE_PROTOBUF)
.map(Arguments::of);
}
}

private static UnframedGrpcService buildUnframedGrpcService(BindableService bindableService) {
return (UnframedGrpcService) GrpcService.builder()
.addService(bindableService)
.maxRequestMessageLength(MAX_MESSAGE_BYTES)
.maxResponseMessageLength(MAX_MESSAGE_BYTES)
.enableUnframedRequests(true)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,24 +122,58 @@ void shouldClosePooledObjectsForNonOK() {
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.add(GrpcHeaderNames.GRPC_STATUS, "1")
.contentType(MediaType.PROTOBUF)
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse.of(responseHeaders,
HttpData.wrap(byteBuf));
UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), null);
UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(byteBuf.refCnt()).isZero();
}

@Test
void shouldClosePooledObjectsForMissingMediaType() {
final CompletableFuture<HttpResponse> res = new CompletableFuture<>();
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.add(GrpcHeaderNames.GRPC_STATUS, "0")
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse
.of(responseHeaders, HttpData.wrap(byteBuf));
AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(byteBuf.refCnt()).isZero();
}

@Test
void shouldClosePooledObjectsForMissingGrpcStatus() {
final CompletableFuture<HttpResponse> res = new CompletableFuture<>();
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.of(HttpStatus.OK);
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.contentType(MediaType.PROTOBUF)
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse.of(responseHeaders,
HttpData.wrap(byteBuf));
UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), null);
HttpData.wrap(byteBuf));
AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(byteBuf.refCnt()).isZero();
}

@Test
void succeedWithAllRequiredHeaders() throws Exception {
final CompletableFuture<HttpResponse> res = new CompletableFuture<>();
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.add(GrpcHeaderNames.GRPC_STATUS, "0")
.contentType(MediaType.PROTOBUF)
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse
.of(responseHeaders, HttpData.wrap(byteBuf));
AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
null, MediaType.PROTOBUF);
assertThat(HttpResponse.from(res).aggregate().get().status()).isEqualTo(HttpStatus.OK);
}

@Test
void unframedGrpcStatusFunction() throws Exception {
final TestService spyTestService = spy(testService);
Expand Down

0 comments on commit aab6852

Please sign in to comment.