Skip to content

Commit

Permalink
Changes in this commit: (#9505) (#9559)
Browse files Browse the repository at this point in the history
- Special handling for identity compressor to ensure compress flag is not turned on. Identity compression is a special case, and for compatibility it should not be treated as a normal compressor.
- If a gRPC service is not found, include ":status" header in response to comply with gRPC spec
- A few null checks were missing in GrpcRouteHandler
- New tests for all the changes above

Co-authored-by: Santiago Pericas-Geertsen <[email protected]>
  • Loading branch information
barchetta and spericas authored Dec 4, 2024
1 parent d9f48b4 commit b30b099
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 42 deletions.
17 changes: 17 additions & 0 deletions webserver/grpc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,23 @@
</dependency>
</dependencies>
</plugin>
<plugin>
<groupId>org.xolstice.maven.plugins</groupId>
<artifactId>protobuf-maven-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>test-compile</goal>
</goals>
</execution>
</executions>
<configuration>
<protocArtifact>com.google.protobuf:protoc:${version.lib.google-protobuf}:exe:${os.detected.classifier}</protocArtifact>
<pluginId>grpc-java</pluginId>
<pluginArtifact>io.grpc:protoc-gen-grpc-java:${version.lib.grpc}:exe:${os.detected.classifier}
</pluginArtifact>
</configuration>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import io.helidon.http.http2.StreamFlowControl;
import io.helidon.webserver.http2.spi.Http2SubProtocolSelector;

import io.grpc.Codec;
import io.grpc.Compressor;
import io.grpc.CompressorRegistry;
import io.grpc.Decompressor;
Expand Down Expand Up @@ -90,6 +91,7 @@ class GrpcProtocolHandler<REQ, RES> implements Http2SubProtocolSelector.SubProto
private BufferData entityBytes;
private Compressor compressor;
private Decompressor decompressor;
private boolean isIdentityCompressor;

GrpcProtocolHandler(HttpPrologue prologue,
Http2Headers headers,
Expand All @@ -115,41 +117,10 @@ class GrpcProtocolHandler<REQ, RES> implements Http2SubProtocolSelector.SubProto
public void init() {
try {
ServerCall<REQ, RES> serverCall = createServerCall();

Headers httpHeaders = headers.httpHeaders();

// check for encoding and respond using same algorithm
if (httpHeaders.contains(GRPC_ENCODING)) {
Header grpcEncoding = httpHeaders.get(GRPC_ENCODING);
String encoding = grpcEncoding.asString().get();
decompressor = DECOMPRESSOR_REGISTRY.lookupDecompressor(encoding);
compressor = COMPRESSOR_REGISTRY.lookupCompressor(encoding);

// report encoding not supported
if (decompressor == null || compressor == null) {
Metadata metadata = new Metadata();
Set<String> encodings = DECOMPRESSOR_REGISTRY.getAdvertisedMessageEncodings();
metadata.put(Metadata.Key.of(GRPC_ACCEPT_ENCODING.defaultCase(), Metadata.ASCII_STRING_MARSHALLER),
String.join(",", encodings));
serverCall.close(Status.UNIMPLEMENTED, metadata);
currentStreamState = Http2StreamState.CLOSED; // stops processing
return;
}
} else if (httpHeaders.contains(GRPC_ACCEPT_ENCODING)) {
Header acceptEncoding = httpHeaders.get(GRPC_ACCEPT_ENCODING);

// check for matching encoding
for (String encoding : acceptEncoding.allValues()) {
compressor = COMPRESSOR_REGISTRY.lookupCompressor(encoding);
if (compressor != null) {
decompressor = DECOMPRESSOR_REGISTRY.lookupDecompressor(encoding);
if (decompressor != null) {
break; // found match
}
compressor = null;
}
}
}
// setup compression
initCompression(serverCall, httpHeaders);

// initiate server call
ServerCallHandler<REQ, RES> callHandler = route.callHandler();
Expand All @@ -161,10 +132,6 @@ public void init() {
}
}

private void addNumMessages(int n) {
numMessages.getAndAdd(n);
}

@Override
public Http2StreamState streamState() {
return currentStreamState;
Expand Down Expand Up @@ -224,6 +191,52 @@ public void data(Http2FrameHeader header, BufferData data) {
}
}

void initCompression(ServerCall<REQ, RES> serverCall, Headers httpHeaders) {
// check for encoding and respond using same algorithm
if (httpHeaders.contains(GRPC_ENCODING)) {
Header grpcEncoding = httpHeaders.get(GRPC_ENCODING);
String encoding = grpcEncoding.asString().get();
decompressor = DECOMPRESSOR_REGISTRY.lookupDecompressor(encoding);
compressor = COMPRESSOR_REGISTRY.lookupCompressor(encoding);

// report encoding not supported
if (decompressor == null || compressor == null) {
Metadata metadata = new Metadata();
Set<String> encodings = DECOMPRESSOR_REGISTRY.getAdvertisedMessageEncodings();
metadata.put(Metadata.Key.of(GRPC_ACCEPT_ENCODING.defaultCase(), Metadata.ASCII_STRING_MARSHALLER),
String.join(",", encodings));
serverCall.close(Status.UNIMPLEMENTED, metadata);
currentStreamState = Http2StreamState.CLOSED; // stops processing
return;
}
} else if (httpHeaders.contains(GRPC_ACCEPT_ENCODING)) {
Header acceptEncoding = httpHeaders.get(GRPC_ACCEPT_ENCODING);

// check for matching encoding
for (String encoding : acceptEncoding.allValues()) {
compressor = COMPRESSOR_REGISTRY.lookupCompressor(encoding);
if (compressor != null) {
decompressor = DECOMPRESSOR_REGISTRY.lookupDecompressor(encoding);
if (decompressor != null) {
break; // found match
}
compressor = null;
}
}
}

// special handling for identity compressor
isIdentityCompressor = (compressor instanceof Codec.Identity);
}

boolean isIdentityCompressor() {
return isIdentityCompressor;
}

private void addNumMessages(int n) {
numMessages.getAndAdd(n);
}

private void flushQueue() {
if (listener != null) {
while (!listenerQueue.isEmpty() && numMessages.getAndDecrement() > 0) {
Expand Down Expand Up @@ -268,10 +281,10 @@ public void sendMessage(RES message) {
try (InputStream inputStream = route.method().streamResponse(message)) {
// prepare buffer for writing
BufferData bufferData;
if (compressor == null) {
if (compressor == null || isIdentityCompressor) {
byte[] bytes = inputStream.readAllBytes();
bufferData = BufferData.create(5 + bytes.length);
bufferData.write(0);
bufferData.write(0); // off for identity compressor
bufferData.writeUnsignedInt32(bytes.length);
bufferData.write(bytes);
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, 2023 Oracle and/or its affiliates.
* Copyright (c) 2022, 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
package io.helidon.webserver.grpc;

import io.helidon.common.buffers.BufferData;
import io.helidon.http.Status;
import io.helidon.http.WritableHeaders;
import io.helidon.http.http2.FlowControl;
import io.helidon.http.http2.Http2Flag;
Expand Down Expand Up @@ -45,6 +46,7 @@ class GrpcProtocolHandlerNotFound implements Http2SubProtocolSelector.SubProtoco
@Override
public void init() {
WritableHeaders<?> writable = WritableHeaders.create();
writable.set(Http2Headers.STATUS_NAME, Status.NOT_FOUND_404.code());
writable.set(GrpcStatus.NOT_FOUND);
Http2Headers http2Headers = Http2Headers.create(writable);
streamWriter.writeHeaders(http2Headers,
Expand All @@ -70,5 +72,4 @@ public void windowUpdate(Http2WindowUpdate update) {
@Override
public void data(Http2FrameHeader header, BufferData data) {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,14 @@ private static <ResT, ReqT> GrpcRouteHandler<ReqT, ResT> grpc(Descriptors.FileDe
String methodName,
ServerCallHandler<ReqT, ResT> callHandler) {
Descriptors.ServiceDescriptor svc = proto.findServiceByName(serviceName);
if (svc == null) {
throw new IllegalArgumentException("Unable to find gRPC service '" + serviceName + "'");
}
Descriptors.MethodDescriptor mtd = svc.findMethodByName(methodName);

if (mtd == null) {
throw new IllegalArgumentException("Unable to find gRPC method '" + methodName
+ "' in service '" + serviceName + "'");
}
String path = svc.getFullName() + "/" + methodName;

/*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.helidon.webserver.grpc;

import io.helidon.http.Status;
import io.helidon.http.http2.FlowControl;
import io.helidon.http.http2.Http2Flag;
import io.helidon.http.http2.Http2FrameData;
import io.helidon.http.http2.Http2Headers;
import io.helidon.http.http2.Http2StreamState;
import io.helidon.http.http2.Http2StreamWriter;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.CoreMatchers.is;

class GrpcProtocolHandlerNotFoundTest {

private boolean validateHeaders;

@Test
void testNotFoundHeaders() {
Http2StreamWriter writer = new Http2StreamWriter() {
@Override
public void write(Http2FrameData frame) {
throw new UnsupportedOperationException("Unsupported");
}

@Override
public void writeData(Http2FrameData frame, FlowControl.Outbound flowControl) {
throw new UnsupportedOperationException("Unsupported");

}

@Override
public int writeHeaders(Http2Headers headers, int streamId, Http2Flag.HeaderFlags flags, FlowControl.Outbound flowControl) {
validateHeaders = (headers.status() == Status.NOT_FOUND_404);
try {
headers.validateResponse();
} catch (Exception e) {
validateHeaders = false;
}
return 0;
}

@Override
public int writeHeaders(Http2Headers headers, int streamId, Http2Flag.HeaderFlags flags, Http2FrameData dataFrame, FlowControl.Outbound flowControl) {
throw new UnsupportedOperationException("Unsupported");
}
};
GrpcProtocolHandlerNotFound handler = new GrpcProtocolHandlerNotFound(writer, 1, Http2StreamState.OPEN);
assertThat(validateHeaders, is(false));
handler.init();
assertThat(validateHeaders, is(true));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (c) 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.helidon.webserver.grpc;

import io.helidon.http.HeaderName;
import io.helidon.http.HeaderNames;
import io.helidon.http.WritableHeaders;
import io.helidon.http.http2.Http2Headers;
import io.helidon.http.http2.Http2Settings;
import io.helidon.http.http2.Http2StreamState;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.CoreMatchers.is;

class GrpcProtocolHandlerTest {

private static final HeaderName GRPC_ACCEPT_ENCODING = HeaderNames.create("grpc-accept-encoding");

@Test
@SuppressWarnings("unchecked")
void testIdentityCompressorFlag() {
WritableHeaders<?> headers = WritableHeaders.create();
headers.add(GRPC_ACCEPT_ENCODING, "identity");
GrpcProtocolHandler handler = new GrpcProtocolHandler(null,
Http2Headers.create(headers),
null,
1,
Http2Settings.builder().build(),
Http2Settings.builder().build(),
null,
Http2StreamState.OPEN,
null);
handler.initCompression(null, headers);
assertThat(handler.isIdentityCompressor(), is(true));
}

@Test
@SuppressWarnings("unchecked")
void testGzipCompressor() {
WritableHeaders<?> headers = WritableHeaders.create();
headers.add(GRPC_ACCEPT_ENCODING, "gzip");
GrpcProtocolHandler handler = new GrpcProtocolHandler(null,
Http2Headers.create(headers),
null,
1,
Http2Settings.builder().build(),
Http2Settings.builder().build(),
null,
Http2StreamState.OPEN,
null);
handler.initCompression(null, headers);
assertThat(handler.isIdentityCompressor(), is(false));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2024 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.helidon.webserver.grpc;

import com.google.protobuf.Descriptors;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertThrows;

public class GrpcRouteHandlerTest {

@Test
void testBadServiceNames() throws Descriptors.DescriptorValidationException {
assertThrows(IllegalArgumentException.class,
() -> GrpcRouteHandler.unary(Strings.getDescriptor(), "foo", "Upper", null));
assertThrows(IllegalArgumentException.class,
() -> GrpcRouteHandler.unary(Strings.getDescriptor(), "StringService", "foo", null));
}
}
Loading

0 comments on commit b30b099

Please sign in to comment.