Skip to content

Commit

Permalink
Add example of a streaming response (#1862)
Browse files Browse the repository at this point in the history
* Add example of a streaming response

* Add model example

* Resolve build / lint issues
  • Loading branch information
heaths authored Oct 22, 2024
1 parent d688a1a commit 8a12d75
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 11 deletions.
4 changes: 4 additions & 0 deletions sdk/typespec/typespec_client_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,9 @@ reqwest_rustls = ["reqwest/rustls-tls"]
tokio_sleep = ["tokio/time"]
xml = ["dep:quick-xml"]

[[example]]
name = "stream_response"
required-features = ["derive"]

[package.metadata.docs.rs]
all-features = true
124 changes: 124 additions & 0 deletions sdk/typespec/typespec_client_core/examples/stream_response.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

use futures::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Get a response from a service client.
let response = client::get_binary_data_response()?;

// Normally you'd deserialize into a type or `collect()` the body,
// but this better simulates fetching multiple chunks from a slow response.
let mut body = response.into_body();
while let Some(data) = body.next().await {
// Assume bytes are a string in this example.
let page = String::from_utf8(data?.into())?;
println!("{page}");
}

// You can also deserialize into a model from a slow response.
let team = client::get_model_response()?.deserialize_body().await?;
println!("{team:#?}");

Ok(())
}

#[allow(dead_code)]
mod client {
use futures::Stream;
use serde::Deserialize;
use std::{cmp::min, task::Poll, time::Duration};
use typespec_client_core::{
http::{headers::Headers, Model, Response, StatusCode},
Bytes,
};

#[derive(Debug, Model, Deserialize)]
pub struct Team {
pub name: Option<String>,
#[serde(default)]
pub members: Vec<Person>,
}

#[derive(Debug, Model, Deserialize)]
pub struct Person {
pub id: u32,
pub name: Option<String>,
}

pub fn get_binary_data_response() -> typespec_client_core::Result<Response<()>> {
let bytes = Bytes::from_static(b"Hello, world!");
let response = SlowResponse {
bytes: bytes.repeat(5).into(),
bytes_per_read: bytes.len(),
bytes_read: 0,
};

Ok(Response::new(
StatusCode::Ok,
Headers::new(),
Box::pin(response),
))
}

pub fn get_model_response() -> typespec_client_core::Result<Response<Team>> {
let bytes = br#"{
"name": "Contoso Dev Team",
"members": [
{
"id": 1234,
"name": "Jan"
},
{
"id": 5678,
"name": "Bill"
}
]
}"#;
let response = SlowResponse {
bytes: Bytes::from_static(bytes),
bytes_per_read: 64,
bytes_read: 0,
};

Ok(Response::new(
StatusCode::Ok,
Headers::new(),
Box::pin(response),
))
}

struct SlowResponse {
bytes: Bytes,
bytes_per_read: usize,
bytes_read: usize,
}

impl Stream for SlowResponse {
type Item = typespec_client_core::Result<Bytes>;

fn poll_next(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let self_mut = self.get_mut();
if self_mut.bytes_read < self_mut.bytes.len() {
eprintln!("getting partial response...");
std::thread::sleep(Duration::from_millis(200));

let end = self_mut.bytes_read
+ min(
self_mut.bytes_per_read,
self_mut.bytes.len() - self_mut.bytes_read,
);
let bytes = self_mut.bytes.slice(self_mut.bytes_read..end);
self_mut.bytes_read += bytes.len();
Poll::Ready(Some(Ok(bytes)))
} else {
eprintln!("done");
Poll::Ready(None)
}
}
}
}
1 change: 1 addition & 0 deletions sdk/typespec/typespec_client_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod stream;
pub mod xml;

pub use crate::error::{Error, Result};
pub use bytes::Bytes;
pub use uuid::Uuid;

#[cfg(feature = "derive")]
Expand Down
21 changes: 10 additions & 11 deletions sdk/typespec/typespec_client_core/src/stream/bytes_stream.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

use super::SeekableStream;
use bytes::Bytes;
use super::{Bytes, SeekableStream};
use futures::io::AsyncRead;
use futures::stream::Stream;
use std::pin::Pin;
Expand Down Expand Up @@ -111,14 +110,14 @@ mod tests {
use futures::stream::StreamExt;

// Test BytesStream Stream
#[test]
fn test_bytes_stream() {
#[tokio::test]
async fn bytes_stream() {
let bytes = Bytes::from("hello world");
let mut stream = BytesStream::new(bytes.clone());

let mut buf = Vec::new();
let mut bytes_read = 0;
while let Some(Ok(bytes)) = futures::executor::block_on(stream.next()) {
while let Some(Ok(bytes)) = stream.next().await {
buf.extend_from_slice(&bytes);
bytes_read += bytes.len();
}
Expand All @@ -128,26 +127,26 @@ mod tests {
}

// Test BytesStream AsyncRead, all bytes at once
#[test]
fn test_async_read_all_bytes_at_once() {
#[tokio::test]
async fn async_read_all_bytes_at_once() {
let bytes = Bytes::from("hello world");
let mut stream = BytesStream::new(bytes.clone());

let mut buf = [0; 11];
let bytes_read = futures::executor::block_on(stream.read(&mut buf)).unwrap();
let bytes_read = stream.read(&mut buf).await.unwrap();
assert_eq!(bytes_read, 11);
assert_eq!(&buf[..], &bytes);
}

// Test BytesStream AsyncRead, one byte at a time
#[test]
fn test_async_read_one_byte_at_a_time() {
#[tokio::test]
async fn async_read_one_byte_at_a_time() {
let bytes = Bytes::from("hello world");
let mut stream = BytesStream::new(bytes.clone());

for i in 0..bytes.len() {
let mut buf = [0; 1];
let bytes_read = futures::executor::block_on(stream.read(&mut buf)).unwrap();
let bytes_read = stream.read(&mut buf).await.unwrap();
assert_eq!(bytes_read, 1);
assert_eq!(buf[0], bytes[i]);
}
Expand Down

0 comments on commit 8a12d75

Please sign in to comment.