Commit 9e6924b: Added support to read Arrow streams asynchronously (#832)
1 parent: 394a699
Showing 7 changed files with 296 additions and 25 deletions.
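For reference, a minimal sketch of how the new API is meant to be used, based on the test added in this commit; the in-memory `Cursor` source and the `collect_chunks` helper are illustrative only, not part of the commit:

```rust
use arrow2::error::Result;
use arrow2::io::ipc::read::stream_async::{read_stream_metadata_async, AsyncStreamReader};
use futures::StreamExt;

// Hypothetical helper: reads every chunk from an in-memory Arrow IPC stream.
// Any `AsyncRead + Unpin + Send` source (e.g. a socket, or a tokio `File`
// wrapped with `tokio_util::compat`) works the same way.
async fn collect_chunks(bytes: Vec<u8>) -> Result<usize> {
    let mut reader = futures::io::Cursor::new(bytes);
    // the stream starts with its metadata (schema, IPC fields, version)
    let metadata = read_stream_metadata_async(&mut reader).await?;
    // the reader takes ownership of the source and yields chunks asynchronously
    let mut stream = AsyncStreamReader::new(reader, metadata);
    let mut rows = 0;
    while let Some(chunk) = stream.next().await {
        rows += chunk?.len(); // number of rows in the chunk
    }
    Ok(rows)
}
```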
@@ -0,0 +1,212 @@ (new file: the async IPC stream reader, imported by the tests below as arrow2::io::ipc::read::stream_async)
//! APIs to read Arrow streams asynchronously
use std::sync::Arc;

use arrow_format::ipc::planus::ReadAsRoot;
use futures::future::BoxFuture;
use futures::AsyncRead;
use futures::AsyncReadExt;
use futures::Stream;

use crate::array::*;
use crate::chunk::Chunk;
use crate::error::{ArrowError, Result};

use super::super::CONTINUATION_MARKER;
use super::common::{read_dictionary, read_record_batch};
use super::schema::deserialize_stream_metadata;
use super::Dictionaries;
use super::StreamMetadata;

/// A (private) state of stream messages
struct ReadState<R> {
    pub reader: R,
    pub metadata: StreamMetadata,
    pub dictionaries: Dictionaries,
    /// The internal buffer to read data inside the messages (records and dictionaries) to
    pub data_buffer: Vec<u8>,
    /// The internal buffer to read messages to
    pub message_buffer: Vec<u8>,
}

/// The state of an Arrow stream
enum StreamState<R> {
    /// The stream does not contain new chunks (and it has not been closed)
    Waiting(ReadState<R>),
    /// The stream contains a new chunk
    Some((ReadState<R>, Chunk<Arc<dyn Array>>)),
}

/// Reads the [`StreamMetadata`] of the Arrow stream asynchronously
pub async fn read_stream_metadata_async<R: AsyncRead + Unpin + Send>(
    reader: &mut R,
) -> Result<StreamMetadata> {
    // determine metadata length
    let mut meta_size: [u8; 4] = [0; 4];
    reader.read_exact(&mut meta_size).await?;
    let meta_len = {
        // If a continuation marker is encountered, skip over it and read
        // the size from the next four bytes.
        if meta_size == CONTINUATION_MARKER {
            reader.read_exact(&mut meta_size).await?;
        }
        i32::from_le_bytes(meta_size)
    };

    let mut meta_buffer = vec![0; meta_len as usize];
    reader.read_exact(&mut meta_buffer).await?;

    deserialize_stream_metadata(&meta_buffer)
}

/// Reads the next item, yielding `None` if the stream has been closed,
/// or a [`StreamState`] otherwise.
async fn maybe_next<R: AsyncRead + Unpin + Send>(
    mut state: ReadState<R>,
) -> Result<Option<StreamState<R>>> {
    // determine metadata length
    let mut meta_length: [u8; 4] = [0; 4];

    match state.reader.read_exact(&mut meta_length).await {
        Ok(()) => (),
        Err(e) => {
            return if e.kind() == std::io::ErrorKind::UnexpectedEof {
                // an EOF without the end-of-stream marker ("0xFFFFFFFF 0x00000000")
                // is valid according to:
                // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
                Ok(Some(StreamState::Waiting(state)))
            } else {
                Err(ArrowError::from(e))
            };
        }
    }

    let meta_length = {
        // If a continuation marker is encountered, skip over it and read
        // the size from the next four bytes.
        if meta_length == CONTINUATION_MARKER {
            state.reader.read_exact(&mut meta_length).await?;
        }
        i32::from_le_bytes(meta_length) as usize
    };

    if meta_length == 0 {
        // the stream has ended, mark the reader as finished
        return Ok(None);
    }

    state.message_buffer.clear();
    state.message_buffer.resize(meta_length, 0);
    state.reader.read_exact(&mut state.message_buffer).await?;

    let message =
        arrow_format::ipc::MessageRef::read_as_root(&state.message_buffer).map_err(|err| {
            ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err))
        })?;
    let header = message.header()?.ok_or_else(|| {
        ArrowError::oos("IPC: unable to fetch the message header. The file or stream is corrupted.")
    })?;

    match header {
        arrow_format::ipc::MessageHeaderRef::Schema(_) => Err(ArrowError::oos(
            "An Arrow stream must not contain a schema message apart from the initial metadata",
        )),
        arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => {
            // read the block that makes up the record batch into a buffer
            state.data_buffer.clear();
            state.data_buffer.resize(message.body_length()? as usize, 0);
            state.reader.read_exact(&mut state.data_buffer).await?;

            read_record_batch(
                batch,
                &state.metadata.schema.fields,
                &state.metadata.ipc_schema,
                None,
                &state.dictionaries,
                state.metadata.version,
                &mut std::io::Cursor::new(&state.data_buffer),
                0,
            )
            .map(|chunk| Some(StreamState::Some((state, chunk))))
        }
        arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => {
            // read the block that makes up the dictionary batch into a buffer
            let mut buf = vec![0; message.body_length()? as usize];
            state.reader.read_exact(&mut buf).await?;

            let mut dict_reader = std::io::Cursor::new(buf);

            read_dictionary(
                batch,
                &state.metadata.schema.fields,
                &state.metadata.ipc_schema,
                &mut state.dictionaries,
                &mut dict_reader,
                0,
            )?;

            // a dictionary batch carries no chunk; signal the caller to read the next message
            Ok(Some(StreamState::Waiting(state)))
        }
        t => Err(ArrowError::OutOfSpec(format!(
            "Reading types other than record batches not yet supported, unable to read {:?}",
            t
        ))),
    }
}

/// A [`Stream`] over an Arrow IPC stream that asynchronously yields [`Chunk`]s.
pub struct AsyncStreamReader<R: AsyncRead + Unpin + Send + 'static> {
    metadata: StreamMetadata,
    future: Option<BoxFuture<'static, Result<Option<StreamState<R>>>>>,
}

impl<R: AsyncRead + Unpin + Send + 'static> AsyncStreamReader<R> {
    /// Creates a new [`AsyncStreamReader`]
    pub fn new(reader: R, metadata: StreamMetadata) -> Self {
        let state = ReadState {
            reader,
            metadata: metadata.clone(),
            dictionaries: Default::default(),
            data_buffer: Default::default(),
            message_buffer: Default::default(),
        };
        let future = Some(Box::pin(maybe_next(state)) as _);
        Self { metadata, future }
    }

    /// Returns the [`StreamMetadata`] of the stream (including its schema)
    pub fn metadata(&self) -> &StreamMetadata {
        &self.metadata
    }
}

impl<R: AsyncRead + Unpin + Send> Stream for AsyncStreamReader<R> {
    type Item = Result<Chunk<Arc<dyn Array>>>;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::pin::Pin;
        use std::task::Poll;
        let me = Pin::into_inner(self);

        match &mut me.future {
            Some(fut) => match fut.as_mut().poll(cx) {
                Poll::Ready(Ok(None)) => {
                    me.future = None;
                    Poll::Ready(None)
                }
                Poll::Ready(Ok(Some(StreamState::Some((state, batch))))) => {
                    me.future = Some(Box::pin(maybe_next(state)));
                    Poll::Ready(Some(Ok(batch)))
                }
                Poll::Ready(Ok(Some(StreamState::Waiting(state)))) => {
                    // no chunk was produced (e.g. a dictionary batch was read);
                    // re-arm the future so the next message is read, and request
                    // a re-poll since the new future has not registered a waker yet
                    me.future = Some(Box::pin(maybe_next(state)));
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                Poll::Ready(Err(err)) => {
                    me.future = None;
                    Poll::Ready(Some(Err(err)))
                }
                Poll::Pending => Poll::Pending,
            },
            None => Poll::Ready(None),
        }
    }
}
@@ -0,0 +1,45 @@ (new file: integration test for the async stream reader)
use futures::StreamExt;
use tokio::fs::File;
use tokio_util::compat::*;

use arrow2::error::Result;
use arrow2::io::ipc::read::stream_async::*;

use crate::io::ipc::common::read_gzip_json;

async fn test_file(version: &str, file_name: &str) -> Result<()> {
    let testdata = crate::test_util::arrow_test_data();
    let mut file = File::open(format!(
        "{}/arrow-ipc-stream/integration/{}/{}.stream",
        testdata, version, file_name
    ))
    .await?
    .compat();

    let metadata = read_stream_metadata_async(&mut file).await?;
    let mut reader = AsyncStreamReader::new(file, metadata);

    // read expected JSON output
    let (schema, ipc_fields, batches) = read_gzip_json(version, file_name)?;

    assert_eq!(&schema, &reader.metadata().schema);
    assert_eq!(&ipc_fields, &reader.metadata().ipc_schema.fields);

    let mut items = vec![];
    while let Some(item) = reader.next().await {
        items.push(item?)
    }

    batches
        .iter()
        .zip(items.into_iter())
        .for_each(|(lhs, rhs)| {
            assert_eq!(lhs, &rhs);
        });
    Ok(())
}

#[tokio::test]
async fn read_async() -> Result<()> {
    test_file("1.0.0-littleendian", "generated_primitive").await
}