diff --git a/core/src/main/java/com/linecorp/armeria/common/MediaType.java b/core/src/main/java/com/linecorp/armeria/common/MediaType.java
index adc7c3c7816d..3ee9e94cd4d0 100644
--- a/core/src/main/java/com/linecorp/armeria/common/MediaType.java
+++ b/core/src/main/java/com/linecorp/armeria/common/MediaType.java
@@ -662,6 +662,8 @@ private static MediaType addKnownType(MediaType mediaType) {
* Protocol buffers.
*/
public static final MediaType PROTOBUF = createConstant(APPLICATION_TYPE, "protobuf");
+ public static final MediaType X_PROTOBUF = createConstant(APPLICATION_TYPE, "x-protobuf");
+ public static final MediaType X_GOOGLE_PROTOBUF = createConstant(APPLICATION_TYPE, "x-google-protobuf");
/**
* RDF/XML documents, which are XML
@@ -1052,6 +1054,20 @@ public boolean isJson() {
return is(JSON) || subtype().endsWith("+json");
}
+ /**
+ * Returns {@code true} when the subtype is one of {@link MediaType#PROTOBUF}, {@link MediaType#X_PROTOBUF}
+ * and {@link MediaType#X_GOOGLE_PROTOBUF}. Otherwise {@code false}.
+ *
+ *
{@code
+ * PROTOBUF.isProtobuf() // true
+ * X_PROTOBUF.isProtobuf() // true
+ * X_GOOGLE_PROTOBUF.isProtobuf() // true
+ * }
+ */
+ public boolean isProtobuf() {
+ return is(PROTOBUF) || is(X_PROTOBUF)|| is(X_GOOGLE_PROTOBUF);
+ }
+
/**
* Returns {@code true} if this {@link MediaType} belongs to the given {@link MediaType}.
* Similar to what {@link MediaType#is(MediaType)} does except that this one compares the parameters
diff --git a/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java b/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java
index ee8b0983943b..3efd086c7249 100644
--- a/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java
+++ b/core/src/main/java/com/linecorp/armeria/common/MediaTypeNames.java
@@ -505,6 +505,14 @@ public final class MediaTypeNames {
* {@value #PROTOBUF}.
*/
public static final String PROTOBUF = "application/protobuf";
+ /**
+ * {@value #X_PROTOBUF}.
+ */
+ public static final String X_PROTOBUF = "application/x-protobuf";
+ /**
+ * {@value #X_GOOGLE_PROTOBUF}.
+ */
+ public static final String X_GOOGLE_PROTOBUF = "application/x-google-protobuf";
/**
* {@value #RDF_XML_UTF_8}.
*/
diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/AbstractUnframedGrpcService.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/AbstractUnframedGrpcService.java
index e3c79095bdc0..f4655117e337 100644
--- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/AbstractUnframedGrpcService.java
+++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/AbstractUnframedGrpcService.java
@@ -138,7 +138,8 @@ protected void frameAndServe(
RequestHeaders grpcHeaders,
HttpData content,
CompletableFuture res,
- @Nullable Function responseBodyConverter) {
+ @Nullable Function responseBodyConverter,
+ MediaType responseContentType) {
final HttpRequest grpcRequest;
try (ArmeriaMessageFramer framer = new ArmeriaMessageFramer(
ctx.alloc(), ArmeriaMessageFramer.NO_MAX_OUTBOUND_MESSAGE_SIZE, false)) {
@@ -170,7 +171,7 @@ protected void frameAndServe(
res.completeExceptionally(t);
} else {
deframeAndRespond(ctx, framedResponse, res, unframedGrpcErrorHandler,
- responseBodyConverter);
+ responseBodyConverter, responseContentType);
}
}
return null;
@@ -182,7 +183,8 @@ static void deframeAndRespond(ServiceRequestContext ctx,
AggregatedHttpResponse grpcResponse,
CompletableFuture res,
UnframedGrpcErrorHandler unframedGrpcErrorHandler,
- @Nullable Function responseBodyConverter) {
+ @Nullable Function responseBodyConverter,
+ MediaType responseContentType) {
final HttpHeaders trailers = !grpcResponse.trailers().isEmpty() ?
grpcResponse.trailers() : grpcResponse.headers();
final String grpcStatusCode = trailers.get(GrpcHeaderNames.GRPC_STATUS);
@@ -210,15 +212,15 @@ static void deframeAndRespond(ServiceRequestContext ctx,
}
final MediaType grpcMediaType = grpcResponse.contentType();
+ if (grpcMediaType == null) {
+ PooledObjects.close(grpcResponse.content());
+ res.completeExceptionally(new NullPointerException("MediaType is undefined"));
+ return;
+ }
+
final ResponseHeadersBuilder unframedHeaders = grpcResponse.headers().toBuilder();
unframedHeaders.set(GrpcHeaderNames.GRPC_STATUS, grpcStatusCode); // grpcStatusCode is 0 which is OK.
- if (grpcMediaType != null) {
- if (grpcMediaType.is(GrpcSerializationFormats.PROTO.mediaType())) {
- unframedHeaders.contentType(MediaType.PROTOBUF);
- } else if (grpcMediaType.is(GrpcSerializationFormats.JSON.mediaType())) {
- unframedHeaders.contentType(MediaType.JSON_UTF_8);
- }
- }
+ unframedHeaders.contentType(responseContentType);
final ArmeriaMessageDeframer deframer = new ArmeriaMessageDeframer(
// Max outbound message size is handled by the GrpcService, so we don't need to set it here.
diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java
index 2eb9496ad2c5..a87d23138684 100644
--- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java
+++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/HttpJsonTranscodingService.java
@@ -557,8 +557,9 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req,
"gRPC encoding is not supported for non-framed requests.");
}
+ final MediaType jsonContentType = GrpcSerializationFormats.JSON.mediaType();
grpcHeaders.method(HttpMethod.POST)
- .contentType(GrpcSerializationFormats.JSON.mediaType());
+ .contentType(jsonContentType);
// All clients support no encoding, and we don't support gRPC encoding for non-framed requests, so just
// clear the header if it's present.
grpcHeaders.remove(GrpcHeaderNames.GRPC_ACCEPT_ENCODING);
@@ -576,7 +577,7 @@ private HttpResponse serve0(ServiceRequestContext ctx, HttpRequest req,
ctx.setAttr(FramedGrpcService.RESOLVED_GRPC_METHOD, spec.method);
frameAndServe(unwrap(), ctx, grpcHeaders.build(),
convertToJson(ctx, clientRequest, spec),
- responseFuture, generateResponseBodyConverter(spec));
+ responseFuture, generateResponseBodyConverter(spec), jsonContentType);
} catch (IllegalArgumentException iae) {
responseFuture.completeExceptionally(
HttpStatusException.of(HttpStatus.BAD_REQUEST, iae));
diff --git a/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcService.java b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcService.java
index ae9da44a64b2..9318ff68876b 100644
--- a/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcService.java
+++ b/grpc/src/main/java/com/linecorp/armeria/server/grpc/UnframedGrpcService.java
@@ -119,14 +119,16 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc
final RequestHeadersBuilder grpcHeaders = clientHeaders.toBuilder();
final MediaType framedContentType;
- if (contentType.is(MediaType.PROTOBUF)) {
+ if (contentType.isProtobuf()) {
framedContentType = GrpcSerializationFormats.PROTO.mediaType();
} else if (contentType.is(MediaType.JSON)) {
framedContentType = GrpcSerializationFormats.JSON.mediaType();
} else {
return HttpResponse.of(HttpStatus.UNSUPPORTED_MEDIA_TYPE,
MediaType.PLAIN_TEXT_UTF_8,
- "Unsupported media type. Only application/protobuf is supported.");
+ "Unsupported media type. Only application/protobuf, " +
+ "application/x-protobuf, application/x-google-protobuf" +
+ "and application/json are supported.");
}
grpcHeaders.contentType(framedContentType);
@@ -149,8 +151,8 @@ public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exc
if (t != null) {
responseFuture.completeExceptionally(t);
} else {
- frameAndServe(unwrap(), ctx, grpcHeaders.build(),
- clientRequest.content(), responseFuture, null);
+ frameAndServe(unwrap(), ctx, grpcHeaders.build(), clientRequest.content(),
+ responseFuture, null, contentType);
}
}
return null;
diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceResponseMediaTypeTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceResponseMediaTypeTest.java
new file mode 100644
index 000000000000..d7fea89ace53
--- /dev/null
+++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceResponseMediaTypeTest.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2022 LINE Corporation
+ *
+ * LINE Corporation licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.linecorp.armeria.server.grpc;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.stream.Stream;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtensionContext;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.ArgumentsProvider;
+import org.junit.jupiter.params.provider.ArgumentsSource;
+
+import com.linecorp.armeria.common.AggregatedHttpResponse;
+import com.linecorp.armeria.common.HttpMethod;
+import com.linecorp.armeria.common.HttpRequest;
+import com.linecorp.armeria.common.HttpStatus;
+import com.linecorp.armeria.common.MediaType;
+import com.linecorp.armeria.grpc.testing.TestServiceGrpc;
+import com.linecorp.armeria.protobuf.EmptyProtos;
+import com.linecorp.armeria.server.ServiceRequestContext;
+import com.linecorp.armeria.testing.junit5.common.EventLoopExtension;
+
+import io.grpc.BindableService;
+import io.grpc.stub.StreamObserver;
+
+public class UnframedGrpcServiceResponseMediaTypeTest {
+
+ @RegisterExtension
+ static EventLoopExtension eventLoop = new EventLoopExtension();
+
+ private static class TestService extends TestServiceGrpc.TestServiceImplBase {
+
+ @Override
+ public void emptyCall(EmptyProtos.Empty request, StreamObserver responseObserver) {
+ responseObserver.onNext(EmptyProtos.Empty.newBuilder().build());
+ responseObserver.onCompleted();
+ }
+ }
+
+ private static final TestService testService = new TestService();
+ private static final int MAX_MESSAGE_BYTES = 1024;
+
+ @Test
+ void respondWithCorrespondingJsonMediaType() throws Exception {
+ final UnframedGrpcService unframedGrpcService = buildUnframedGrpcService(testService);
+
+ final HttpRequest request = HttpRequest.of(HttpMethod.POST,
+ "/armeria.grpc.testing.TestService/EmptyCall",
+ MediaType.JSON_UTF_8, "{}");
+ final ServiceRequestContext ctx = ServiceRequestContext.builder(request)
+ .build();
+
+ final AggregatedHttpResponse res = unframedGrpcService.serve(ctx, request).aggregate().join();
+ assertThat(res.status()).isEqualTo(HttpStatus.OK);
+ assertThat(res.contentType()).isEqualTo(MediaType.JSON_UTF_8);
+ }
+
+ @ParameterizedTest
+ @ArgumentsSource(ProtobufMediaTypeProvider.class)
+ void respondWithCorrespondingProtobufMediaType(MediaType protobufType) throws Exception {
+ final UnframedGrpcService unframedGrpcService = buildUnframedGrpcService(testService);
+
+ final HttpRequest request = HttpRequest.of(HttpMethod.POST,
+ "/armeria.grpc.testing.TestService/EmptyCall",
+ protobufType,
+ EmptyProtos.Empty.getDefaultInstance().toByteArray());
+ final ServiceRequestContext ctx = ServiceRequestContext.builder(request)
+ .build();
+
+ final AggregatedHttpResponse res = unframedGrpcService.serve(ctx, request).aggregate().join();
+ assertThat(res.status()).isEqualTo(HttpStatus.OK);
+ assertThat(res.contentType()).isEqualTo(protobufType);
+ }
+
+ private static class ProtobufMediaTypeProvider implements ArgumentsProvider {
+ @Override
+ public Stream extends Arguments> provideArguments(ExtensionContext context) {
+ return Stream.of(MediaType.PROTOBUF, MediaType.X_PROTOBUF, MediaType.X_GOOGLE_PROTOBUF)
+ .map(Arguments::of);
+ }
+ }
+
+ private static UnframedGrpcService buildUnframedGrpcService(BindableService bindableService) {
+ return (UnframedGrpcService) GrpcService.builder()
+ .addService(bindableService)
+ .maxRequestMessageLength(MAX_MESSAGE_BYTES)
+ .maxResponseMessageLength(MAX_MESSAGE_BYTES)
+ .enableUnframedRequests(true)
+ .build();
+ }
+}
diff --git a/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceTest.java b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceTest.java
index 26fb91af80be..906e6a8304a4 100644
--- a/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceTest.java
+++ b/grpc/src/test/java/com/linecorp/armeria/server/grpc/UnframedGrpcServiceTest.java
@@ -122,10 +122,26 @@ void shouldClosePooledObjectsForNonOK() {
final ByteBuf byteBuf = Unpooled.buffer();
final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
.add(GrpcHeaderNames.GRPC_STATUS, "1")
+ .contentType(MediaType.PROTOBUF)
.build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse.of(responseHeaders,
HttpData.wrap(byteBuf));
- UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), null);
+ UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
+ null, MediaType.PROTOBUF);
+ assertThat(byteBuf.refCnt()).isZero();
+ }
+
+ @Test
+ void shouldClosePooledObjectsForMissingMediaType() {
+ final CompletableFuture res = new CompletableFuture<>();
+ final ByteBuf byteBuf = Unpooled.buffer();
+ final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
+ .add(GrpcHeaderNames.GRPC_STATUS, "0")
+ .build();
+ final AggregatedHttpResponse framedResponse = AggregatedHttpResponse
+ .of(responseHeaders, HttpData.wrap(byteBuf));
+ AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
+ null, MediaType.PROTOBUF);
assertThat(byteBuf.refCnt()).isZero();
}
@@ -133,13 +149,31 @@ void shouldClosePooledObjectsForNonOK() {
void shouldClosePooledObjectsForMissingGrpcStatus() {
final CompletableFuture res = new CompletableFuture<>();
final ByteBuf byteBuf = Unpooled.buffer();
- final ResponseHeaders responseHeaders = ResponseHeaders.of(HttpStatus.OK);
+ final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
+ .contentType(MediaType.PROTOBUF)
+ .build();
final AggregatedHttpResponse framedResponse = AggregatedHttpResponse.of(responseHeaders,
- HttpData.wrap(byteBuf));
- UnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(), null);
+ HttpData.wrap(byteBuf));
+ AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
+ null, MediaType.PROTOBUF);
assertThat(byteBuf.refCnt()).isZero();
}
+ @Test
+ void succeedWithAllRequiredHeaders() throws Exception {
+ final CompletableFuture res = new CompletableFuture<>();
+ final ByteBuf byteBuf = Unpooled.buffer();
+ final ResponseHeaders responseHeaders = ResponseHeaders.builder(HttpStatus.OK)
+ .add(GrpcHeaderNames.GRPC_STATUS, "0")
+ .contentType(MediaType.PROTOBUF)
+ .build();
+ final AggregatedHttpResponse framedResponse = AggregatedHttpResponse
+ .of(responseHeaders, HttpData.wrap(byteBuf));
+ AbstractUnframedGrpcService.deframeAndRespond(ctx, framedResponse, res, UnframedGrpcErrorHandler.of(),
+ null, MediaType.PROTOBUF);
+ assertThat(HttpResponse.from(res).aggregate().get().status()).isEqualTo(HttpStatus.OK);
+ }
+
@Test
void unframedGrpcStatusFunction() throws Exception {
final TestService spyTestService = spy(testService);