Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve API for bidi and server streaming calls #130

Merged
merged 8 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@ package com.connectrpc.conformance

import com.connectrpc.Code
import com.connectrpc.ConnectException
import com.connectrpc.Headers
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.RequestCompression
import com.connectrpc.StreamResult
import com.connectrpc.Trailers
import com.connectrpc.compression.GzipCompressionPool
import com.connectrpc.conformance.ssl.sslContext
import com.connectrpc.conformance.v1.ErrorDetail
import com.connectrpc.conformance.v1.PayloadType
import com.connectrpc.conformance.v1.StreamingOutputCallResponse
import com.connectrpc.conformance.v1.TestServiceClient
import com.connectrpc.conformance.v1.UnimplementedServiceClient
import com.connectrpc.conformance.v1.echoStatus
Expand All @@ -43,7 +41,6 @@ import com.google.protobuf.ByteString
import com.google.protobuf.empty
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
Expand All @@ -63,7 +60,6 @@ import java.time.Duration
import java.util.Base64
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean

@RunWith(Parameterized::class)
class Conformance(
Expand Down Expand Up @@ -177,17 +173,18 @@ class Conformance(
responseParameters += params
},
).getOrThrow()
val results = streamResults(stream.resultChannel())
assertThat(results.cause).isNull()
assertThat(results.code).isEqualTo(Code.OK)
assertThat(results.messages.map { it.payload.type }.toSet()).isEqualTo(setOf(PayloadType.COMPRESSABLE))
assertThat(results.messages.map { it.payload.body.size() }).isEqualTo(sizes)
val responses = mutableListOf<StreamingOutputCallResponse>()
for (response in stream.responseChannel()) {
responses.add(response)
}
assertThat(responses.map { it.payload.type }.toSet()).isEqualTo(setOf(PayloadType.COMPRESSABLE))
assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes)
}

@Test
fun pingPong(): Unit = runBlocking {
val stream = testServiceConnectClient.fullDuplexCall()
var readHeaders = false
val responseChannel = stream.responseChannel()
listOf(512_000, 16, 2_028, 65_536).forEach {
val param = responseParameters { size = it }
stream.send(
Expand All @@ -196,25 +193,14 @@ class Conformance(
responseParameters += param
},
).getOrThrow()
if (!readHeaders) {
val headersResult = stream.resultChannel().receive()
assertThat(headersResult).isInstanceOf(StreamResult.Headers::class.java)
readHeaders = true
}
val result = stream.resultChannel().receive()
assertThat(result).isInstanceOf(StreamResult.Message::class.java)
val messageResult = result as StreamResult.Message
val payload = messageResult.message.payload
val response = responseChannel.receive()
val payload = response.payload
assertThat(payload.type).isEqualTo(PayloadType.COMPRESSABLE)
assertThat(payload.body).hasSize(it)
}
stream.sendClose()
val results = streamResults(stream.resultChannel())
// We've already read all the messages
assertThat(results.messages).isEmpty()
assertThat(results.cause).isNull()
assertThat(results.code).isEqualTo(Code.OK)
stream.receiveClose()
assertThat(responseChannel.receiveCatching().isClosed).isTrue()
}

@Test
Expand Down Expand Up @@ -244,17 +230,17 @@ class Conformance(
val countDownLatch = CountDownLatch(1)
withContext(Dispatchers.IO) {
val job = async {
val responses = mutableListOf<StreamingOutputCallResponse>()
try {
val result = streamResults(stream.resultChannel())
assertThat(result.messages.map { it.payload.body.size() }).isEqualTo(sizes)
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.cause).isInstanceOf(ConnectException::class.java)
val connectException = result.cause as ConnectException
assertThat(connectException.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(connectException.message).isEqualTo("soirée 🎉")
assertThat(connectException.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail,
)
for (response in stream.responseChannel()) {
responses.add(response)
}
fail("expected call to fail with ConnectException")
} catch (e: ConnectException) {
assertThat(responses.map { it.payload.body.size() }).isEqualTo(sizes)
assertThat(e.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(e.message).isEqualTo("soirée 🎉")
assertThat(e.unpackedDetails(ErrorDetail::class)).containsExactly(expectedErrorDetail)
} finally {
countDownLatch.countDown()
}
Expand Down Expand Up @@ -363,10 +349,11 @@ class Conformance(
withContext(Dispatchers.IO) {
val job = launch {
try {
val result = streamResults(stream.resultChannel())
assertThat(result.cause).isInstanceOf(ConnectException::class.java)
assertThat(result.code)
.withFailMessage { "Expected Code.DEADLINE_EXCEEDED but got ${result.code}" }
stream.responseChannel().receive()
fail("unexpected ConnectException to be thrown")
} catch (e: ConnectException) {
assertThat(e.code)
.withFailMessage { "Expected Code.DEADLINE_EXCEEDED but got ${e.code}" }
.isEqualTo(Code.DEADLINE_EXCEEDED)
} finally {
countDownLatch.countDown()
Expand Down Expand Up @@ -437,11 +424,10 @@ class Conformance(
withContext(Dispatchers.IO) {
val job = async {
try {
val result = streamResults(stream.resultChannel())
assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED)
assertThat(result.cause).isInstanceOf(ConnectException::class.java)
val exception = result.cause as ConnectException
assertThat(exception.code).isEqualTo(Code.UNIMPLEMENTED)
stream.responseChannel().receive()
fail("expected call to fail with a ConnectException")
} catch (e: ConnectException) {
assertThat(e.code).isEqualTo(Code.UNIMPLEMENTED)
} finally {
countDownLatch.countDown()
}
Expand Down Expand Up @@ -801,8 +787,8 @@ class Conformance(
withContext(Dispatchers.IO) {
val job = async {
try {
val result = stream.receiveAndClose().getOrThrow()
assertThat(result.aggregatedPayloadSize).isEqualTo(sum)
val response = stream.receiveAndClose()
assertThat(response.aggregatedPayloadSize).isEqualTo(sum)
} finally {
countDownLatch.countDown()
}
Expand All @@ -813,56 +799,6 @@ class Conformance(
}
}

private data class ServerStreamingResult<Output>(
val headers: Headers,
val messages: List<Output>,
val code: Code,
val trailers: Trailers,
val cause: Throwable?,
)

/*
* Convenience method to return all results (with sanity checking) for calls which stream results from the server
* (bidi and server streaming).
*
* This allows us to easily verify headers, messages, trailers, and errors without having to use fold/maybeFold
* manually in each location.
*/
private suspend fun <Output> streamResults(channel: ReceiveChannel<StreamResult<Output>>): ServerStreamingResult<Output> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now users don't need to handle this complexity to use bidi or server streaming calls.

val seenHeaders = AtomicBoolean(false)
var headers: Headers = emptyMap()
val messages: MutableList<Output> = mutableListOf()
val seenCompletion = AtomicBoolean(false)
var code: Code = Code.UNKNOWN
var trailers: Headers = emptyMap()
var error: Throwable? = null
for (response in channel) {
response.maybeFold(
onHeaders = {
if (!seenHeaders.compareAndSet(false, true)) {
throw IllegalStateException("multiple onHeaders callbacks")
}
headers = it.headers
},
onMessage = {
messages.add(it.message)
},
onCompletion = {
if (!seenCompletion.compareAndSet(false, true)) {
throw IllegalStateException("multiple onCompletion callbacks")
}
code = it.code
trailers = it.trailers
error = it.cause
},
)
}
if (!seenCompletion.get()) {
throw IllegalStateException("didn't get completion message")
}
return ServerStreamingResult(headers, messages, code, trailers, error)
}

private fun b64Encode(trailingValue: ByteArray): String {
return String(Base64.getEncoder().encode(trailingValue))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import android.widget.TextView
import androidx.appcompat.app.AppCompatActivity
import androidx.lifecycle.lifecycleScope
import androidx.recyclerview.widget.RecyclerView
import com.connectrpc.ConnectException
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.eliza.v1.ConverseRequest
import com.connectrpc.eliza.v1.ElizaServiceClient
Expand Down Expand Up @@ -135,29 +136,25 @@ class ElizaChatActivity : AppCompatActivity() {
lifecycleScope.launch(Dispatchers.IO) {
// Initialize a bidi stream with Eliza.
val stream = elizaServiceClient.converse()

for (streamResult in stream.resultChannel()) {
streamResult.maybeFold(
onMessage = { result ->
// A stream message is received: Eliza has said something to us.
val elizaResponse = result.message.sentence
if (elizaResponse?.isNotBlank() == true) {
adapter.add(MessageData(elizaResponse, true))
} else {
// Something odd occurred.
adapter.add(MessageData("...No response from Eliza...", true))
}
},
onCompletion = {
// This should only be called once.
adapter.add(
MessageData(
"Session has ended.",
true,
),
)
},
try {
for (message in stream.responseChannel()) {
// A stream message is received: Eliza has said something to us.
val elizaResponse = message.sentence
if (elizaResponse?.isNotBlank() == true) {
adapter.add(MessageData(elizaResponse, true))
} else {
// Something odd occurred.
adapter.add(MessageData("...No response from Eliza...", true))
}
}
adapter.add(
MessageData(
"Session has ended.",
true,
),
)
} catch (e: ConnectException) {
adapter.add(MessageData("Session failed with code ${e.code}", true))
}
lifecycleScope.launch(Dispatchers.Main) {
buttonView.setOnClickListener {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

package com.connectrpc.examples.kotlin

import com.connectrpc.Code
import com.connectrpc.ConnectException
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.eliza.v1.ElizaServiceClient
Expand Down Expand Up @@ -63,23 +62,8 @@ class Main {
// Add the message the user is sending to the views.
stream.send(converseRequest { sentence = "hello" })
stream.sendClose()
for (streamResult in stream.resultChannel()) {
streamResult.maybeFold(
onMessage = { result ->
// Update the view with the response.
val elizaResponse = result.message
println(elizaResponse.sentence)
},
onCompletion = { result ->
if (result.code != Code.OK) {
val exception = result.connectException()
if (exception != null) {
throw exception
}
throw ConnectException(code = result.code, metadata = result.trailers)
}
},
)
for (response in stream.responseChannel()) {
println(response.sentence)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's an example of the improved API experience for callers.

}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

package com.connectrpc.examples.kotlin

import com.connectrpc.Code
import com.connectrpc.ConnectException
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.eliza.v1.ElizaServiceClient
import com.connectrpc.eliza.v1.converseRequest
Expand Down Expand Up @@ -63,23 +61,8 @@ class Main {
// Add the message the user is sending to the views.
stream.send(converseRequest { sentence = "hello" })
stream.sendClose()
for (streamResult in stream.resultChannel()) {
streamResult.maybeFold(
onMessage = { result ->
// Update the view with the response.
val elizaResponse = result.message
println(elizaResponse.sentence)
},
onCompletion = { result ->
if (result.code != Code.OK) {
val exception = result.connectException()
if (exception != null) {
throw exception
}
throw ConnectException(code = result.code, metadata = result.trailers)
}
},
)
for (response in stream.responseChannel()) {
println(response.sentence)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ import kotlinx.coroutines.channels.ReceiveChannel
*/
interface BidirectionalStreamInterface<Input, Output> {
/**
* The Channel for received StreamResults.
* The Channel for responses.
*
* @return ReceiveChannel for iterating over the received results.
* @return ReceiveChannel for iterating over the responses.
*/
fun resultChannel(): ReceiveChannel<StreamResult<Output>>
fun responseChannel(): ReceiveChannel<Output>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean consumers have no way of reading stream response headers? Where is StreamResult used now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consumers can continue to access stream response headers from interceptors - I've opened #131 with some ideas on making it easier for people to access these based on patterns in connect-go.


/**
* Send a request to the server over the stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ interface ClientOnlyStreamInterface<Input, Output> {
/**
* Receive a single response and close the stream.
*
* @return the single response [ResponseMessage].
* @return the single response [Output].
*/
suspend fun receiveAndClose(): ResponseMessage<Output>
suspend fun receiveAndClose(): Output
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now all of the streaming interfaces work the same.


/**
* Close the stream. No calls to [send] are valid after calling [sendClose].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package com.connectrpc

import kotlinx.coroutines.channels.ReceiveChannel

/**
* Represents a server-only stream (a stream where the server streams data to the client after
* receiving an initial request) that can send request messages.
Expand All @@ -25,7 +26,7 @@ interface ServerOnlyStreamInterface<Input, Output> {
*
* @return ReceiveChannel for iterating over the received results.
*/
fun resultChannel(): ReceiveChannel<StreamResult<Output>>
fun responseChannel(): ReceiveChannel<Output>

/**
* Send a request to the server over the stream and closes the request.
Expand Down
Loading
Loading