diff --git a/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt b/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt index 404990da..d2b532b4 100644 --- a/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt +++ b/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt @@ -26,13 +26,13 @@ import io.grpc.ServerCallHandler import io.grpc.ServerMethodDefinition import io.grpc.Status import io.grpc.StatusException +import io.grpc.StatusRuntimeException import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.cancel import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.onFailure import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.launch @@ -40,7 +40,6 @@ import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import java.util.concurrent.atomic.AtomicBoolean import kotlin.coroutines.CoroutineContext -import kotlinx.coroutines.channels.onFailure import io.grpc.Metadata as GrpcMetadata /** @@ -262,6 +261,7 @@ object ServerCalls { val closeStatus = when (failure) { null -> Status.OK is CancellationException -> Status.CANCELLED.withCause(failure) + is StatusException, is StatusRuntimeException -> Status.fromThrowable(failure) else -> Status.fromThrowable(failure).withCause(failure) } val trailers = failure?.let { Status.trailersFromThrowable(it) } ?: GrpcMetadata() diff --git a/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt b/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt index 4996c2aa..d47a4587 100644 --- a/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt @@ -20,12 +20,10 @@ import com.google.common.truth.Truth.assertThat import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import io.grpc.* import io.grpc.examples.helloworld.GreeterGrpc -import io.grpc.examples.helloworld.GreeterGrpcKt import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineStub import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineImplBase -import io.grpc.stub.StreamObserver import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.ExperimentalCoroutinesApi @@ -1015,7 +1013,8 @@ class ServerCallsTest : AbstractCallsTest() { } @Test - fun testStatusExceptionPropagatesStack() = runBlocking { + fun testPropagateStackTraceForStatusException() = runBlocking { + val thrownStatusCause = CompletableDeferred() val serverImpl = object : GreeterCoroutineImplBase() { override suspend fun sayHello(request: HelloRequest): HelloReply { @@ -1023,7 +1022,9 @@ class ServerCallsTest : AbstractCallsTest() { } private fun internalServerCall(): Nothing { - throw StatusException(Status.INTERNAL) + val exception = Exception("causal exception") + thrownStatusCause.complete(exception) + throw Status.INTERNAL.withCause(exception).asException() } } @@ -1059,7 +1060,111 @@ class ServerCallsTest : AbstractCallsTest() { assertThat(clientException.status.code).isEqualTo(Status.Code.INTERNAL) val statusCause = receivedStatusCause.await() // but the exception should propagate to server interceptors, with stack trace intact - assertThat(statusCause).isNotNull() + assertThat(statusCause).isEqualTo(thrownStatusCause.await()) + assertThat(statusCause!!.stackTraceToString()).contains("internalServerCall") + } + + @Test + fun testPropagateStackTraceForStatusRuntimeException() = runBlocking { + val thrownStatusCause = CompletableDeferred() + + val serverImpl = object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + internalServerCall() + } + + private fun internalServerCall(): Nothing { + val exception = Exception("causal exception") + thrownStatusCause.complete(exception) + throw Status.INTERNAL.withCause(exception).asRuntimeException() + } + } + + val receivedStatusCause = CompletableDeferred() + + val interceptor = object : ServerInterceptor { + override fun interceptCall( + call: ServerCall, + requestHeaders: Metadata, + next: ServerCallHandler + ): ServerCall.Listener = + next.startCall( + object : ForwardingServerCall.SimpleForwardingServerCall(call) { + override fun close(status: Status, trailers: Metadata) { + receivedStatusCause.complete(status.cause) + super.close(status, trailers) + } + }, + requestHeaders + ) + } + + val channel = makeChannel(serverImpl, interceptor) + + val stub = GreeterGrpc.newBlockingStub(channel) + val clientException = assertThrows { + stub.sayHello(helloRequest("")) + } + + // the exception should not propagate to the client + assertThat(clientException.cause).isNull() + + assertThat(clientException.status.code).isEqualTo(Status.Code.INTERNAL) + val statusCause = receivedStatusCause.await() + // but the exception should propagate to server interceptors, with stack trace intact + assertThat(statusCause).isEqualTo(thrownStatusCause.await()) + assertThat(statusCause!!.stackTraceToString()).contains("internalServerCall") + } + + @Test + fun testPropagateStackTraceForNonStatusException() = runBlocking { + val thrownStatusCause = CompletableDeferred() + + val serverImpl = object : GreeterCoroutineImplBase() { + override suspend fun sayHello(request: HelloRequest): HelloReply { + internalServerCall() + } + + private fun internalServerCall(): Nothing { + val exception = Exception("causal exception") + thrownStatusCause.complete(exception) + throw exception + } + } + + val receivedStatusCause = CompletableDeferred() + + val interceptor = object : ServerInterceptor { + override fun interceptCall( + call: ServerCall, + requestHeaders: Metadata, + next: ServerCallHandler + ): ServerCall.Listener = + next.startCall( + object : ForwardingServerCall.SimpleForwardingServerCall(call) { + override fun close(status: Status, trailers: Metadata) { + receivedStatusCause.complete(status.cause) + super.close(status, trailers) + } + }, + requestHeaders + ) + } + + val channel = makeChannel(serverImpl, interceptor) + + val stub = GreeterGrpc.newBlockingStub(channel) + val clientException = assertThrows { + stub.sayHello(helloRequest("")) + } + + // the exception should not propagate to the client + assertThat(clientException.cause).isNull() + + assertThat(clientException.status.code).isEqualTo(Status.Code.UNKNOWN) + val statusCause = receivedStatusCause.await() + // but the exception should propagate to server interceptors, with stack trace intact + assertThat(statusCause).isEqualTo(thrownStatusCause.await()) assertThat(statusCause!!.stackTraceToString()).contains("internalServerCall") }