diff --git a/Cargo.toml b/Cargo.toml index 9aa2b01eae0..895467ed86a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,6 +111,7 @@ full = [ "io_ipc", "io_flight", "io_ipc_write_async", + "io_ipc_read_async", "io_ipc_compression", "io_json_integration", "io_print", @@ -132,6 +133,7 @@ io_csv_write = ["csv", "streaming-iterator", "lexical-core"] io_json = ["serde", "serde_json", "streaming-iterator", "fallible-streaming-iterator", "indexmap", "lexical-core"] io_ipc = ["arrow-format"] io_ipc_write_async = ["io_ipc", "futures"] +io_ipc_read_async = ["io_ipc", "futures"] io_ipc_compression = ["lz4", "zstd"] io_flight = ["io_ipc", "arrow-format/flight-data"] # base64 + io_ipc because arrow schemas are stored as base64-encoded ipc format. diff --git a/src/io/ipc/read/mod.rs b/src/io/ipc/read/mod.rs index 22b3a2fe448..3a45d4ecac6 100644 --- a/src/io/ipc/read/mod.rs +++ b/src/io/ipc/read/mod.rs @@ -16,6 +16,9 @@ mod read_basic; mod reader; mod schema; mod stream; +#[cfg(feature = "io_ipc_read_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] +pub mod stream_async; pub use common::{read_dictionary, read_record_batch}; pub use reader::{read_file_metadata, FileMetadata, FileReader}; diff --git a/src/io/ipc/read/schema.rs b/src/io/ipc/read/schema.rs index 7dbeeedfd73..bde3deba1e2 100644 --- a/src/io/ipc/read/schema.rs +++ b/src/io/ipc/read/schema.rs @@ -8,7 +8,10 @@ use crate::{ error::{ArrowError, Result}, }; -use super::super::{IpcField, IpcSchema}; +use super::{ + super::{IpcField, IpcSchema}, + StreamMetadata, +}; fn try_unzip_vec>>(iter: I) -> Result<(Vec, Vec)> { let mut a = vec![]; @@ -370,3 +373,28 @@ pub(super) fn fb_to_schema(schema: arrow_format::ipc::SchemaRef) -> Result<(Sche }, )) } + +pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> Result { + let message = arrow_format::ipc::MessageRef::read_as_root(meta).map_err(|err| { + ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err)) + })?; + let version = message.version()?; + // message header is a Schema, so read it + let header = message + .header()? + .ok_or_else(|| ArrowError::oos("Unable to read the first IPC message"))?; + let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header { + schema + } else { + return Err(ArrowError::oos( + "The first IPC message of the stream must be a schema", + )); + }; + let (schema, ipc_schema) = fb_to_schema(schema)?; + + Ok(StreamMetadata { + schema, + version, + ipc_schema, + }) +} diff --git a/src/io/ipc/read/stream.rs b/src/io/ipc/read/stream.rs index 370ea9f429d..81ede969cca 100644 --- a/src/io/ipc/read/stream.rs +++ b/src/io/ipc/read/stream.rs @@ -12,7 +12,7 @@ use crate::io::ipc::IpcSchema; use super::super::CONTINUATION_MARKER; use super::common::*; -use super::schema::fb_to_schema; +use super::schema::deserialize_stream_metadata; use super::Dictionaries; /// Metadata of an Arrow IPC stream, written at the start of the stream @@ -45,29 +45,7 @@ pub fn read_stream_metadata(reader: &mut R) -> Result { let mut meta_buffer = vec![0; meta_len as usize]; reader.read_exact(&mut meta_buffer)?; - let message = - arrow_format::ipc::MessageRef::read_as_root(meta_buffer.as_slice()).map_err(|err| { - ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err)) - })?; - let version = message.version()?; - // message header is a Schema, so read it - let header = message - .header()? - .ok_or_else(|| ArrowError::oos("Unable to read the first IPC message"))?; - let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header { - schema - } else { - return Err(ArrowError::oos( - "The first IPC message of the stream must be a schema", - )); - }; - let (schema, ipc_schema) = fb_to_schema(schema)?; - - Ok(StreamMetadata { - schema, - version, - ipc_schema, - }) + deserialize_stream_metadata(&meta_buffer) } /// Encodes the stream's status after each read. diff --git a/src/io/ipc/read/stream_async.rs b/src/io/ipc/read/stream_async.rs new file mode 100644 index 00000000000..9e054cd07ce --- /dev/null +++ b/src/io/ipc/read/stream_async.rs @@ -0,0 +1,212 @@ +//! APIs to read Arrow streams asynchronously +use std::sync::Arc; + +use arrow_format::ipc::planus::ReadAsRoot; +use futures::future::BoxFuture; +use futures::AsyncRead; +use futures::AsyncReadExt; +use futures::Stream; + +use crate::array::*; +use crate::chunk::Chunk; +use crate::error::{ArrowError, Result}; + +use super::super::CONTINUATION_MARKER; +use super::common::{read_dictionary, read_record_batch}; +use super::schema::deserialize_stream_metadata; +use super::Dictionaries; +use super::StreamMetadata; + +/// A (private) state of stream messages +struct ReadState { + pub reader: R, + pub metadata: StreamMetadata, + pub dictionaries: Dictionaries, + /// The internal buffer to read data inside the messages (records and dictionaries) to + pub data_buffer: Vec, + /// The internal buffer to read messages to + pub message_buffer: Vec, +} + +/// The state of an Arrow stream +enum StreamState { + /// The stream does not contain new chunks (and it has not been closed) + Waiting(ReadState), + /// The stream contain a new chunk + Some((ReadState, Chunk>)), +} + +/// Reads the [`StreamMetadata`] of the Arrow stream asynchronously +pub async fn read_stream_metadata_async( + reader: &mut R, +) -> Result { + // determine metadata length + let mut meta_size: [u8; 4] = [0; 4]; + reader.read_exact(&mut meta_size).await?; + let meta_len = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. + if meta_size == CONTINUATION_MARKER { + reader.read_exact(&mut meta_size).await?; + } + i32::from_le_bytes(meta_size) + }; + + let mut meta_buffer = vec![0; meta_len as usize]; + reader.read_exact(&mut meta_buffer).await?; + + deserialize_stream_metadata(&meta_buffer) +} + +/// Reads the next item, yielding `None` if the stream has been closed, +/// or a [`StreamState`] otherwise. +async fn maybe_next( + mut state: ReadState, +) -> Result>> { + // determine metadata length + let mut meta_length: [u8; 4] = [0; 4]; + + match state.reader.read_exact(&mut meta_length).await { + Ok(()) => (), + Err(e) => { + return if e.kind() == std::io::ErrorKind::UnexpectedEof { + // Handle EOF without the "0xFFFFFFFF 0x00000000" + // valid according to: + // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format + Ok(Some(StreamState::Waiting(state))) + } else { + Err(ArrowError::from(e)) + }; + } + } + + let meta_length = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. + if meta_length == CONTINUATION_MARKER { + state.reader.read_exact(&mut meta_length).await?; + } + i32::from_le_bytes(meta_length) as usize + }; + + if meta_length == 0 { + // the stream has ended, mark the reader as finished + return Ok(None); + } + + state.message_buffer.clear(); + state.message_buffer.resize(meta_length, 0); + state.reader.read_exact(&mut state.message_buffer).await?; + + let message = + arrow_format::ipc::MessageRef::read_as_root(&state.message_buffer).map_err(|err| { + ArrowError::OutOfSpec(format!("Unable to get root as message: {:?}", err)) + })?; + let header = message.header()?.ok_or_else(|| { + ArrowError::oos("IPC: unable to fetch the message header. The file or stream is corrupted.") + })?; + + match header { + arrow_format::ipc::MessageHeaderRef::Schema(_) => Err(ArrowError::oos("A stream ")), + arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => { + // read the block that makes up the record batch into a buffer + state.data_buffer.clear(); + state.data_buffer.resize(message.body_length()? as usize, 0); + state.reader.read_exact(&mut state.data_buffer).await?; + + read_record_batch( + batch, + &state.metadata.schema.fields, + &state.metadata.ipc_schema, + None, + &state.dictionaries, + state.metadata.version, + &mut std::io::Cursor::new(&state.data_buffer), + 0, + ) + .map(|chunk| Some(StreamState::Some((state, chunk)))) + } + arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { + // read the block that makes up the dictionary batch into a buffer + let mut buf = vec![0; message.body_length()? as usize]; + state.reader.read_exact(&mut buf).await?; + + let mut dict_reader = std::io::Cursor::new(buf); + + read_dictionary( + batch, + &state.metadata.schema.fields, + &state.metadata.ipc_schema, + &mut state.dictionaries, + &mut dict_reader, + 0, + )?; + + // read the next message until we encounter a Chunk> message + Ok(Some(StreamState::Waiting(state))) + } + t => Err(ArrowError::OutOfSpec(format!( + "Reading types other than record batches not yet supported, unable to read {:?} ", + t + ))), + } +} + +/// A [`Stream`] over an Arrow IPC stream that asynchronously yields [`Chunk`]s. +pub struct AsyncStreamReader { + metadata: StreamMetadata, + future: Option>>>>, +} + +impl AsyncStreamReader { + /// Creates a new [`AsyncStreamReader`] + pub fn new(reader: R, metadata: StreamMetadata) -> Self { + let state = ReadState { + reader, + metadata: metadata.clone(), + dictionaries: Default::default(), + data_buffer: Default::default(), + message_buffer: Default::default(), + }; + let future = Some(Box::pin(maybe_next(state)) as _); + Self { metadata, future } + } + + /// Return the schema of the stream + pub fn metadata(&self) -> &StreamMetadata { + &self.metadata + } +} + +impl Stream for AsyncStreamReader { + type Item = Result>>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use std::pin::Pin; + use std::task::Poll; + let me = Pin::into_inner(self); + + match &mut me.future { + Some(fut) => match fut.as_mut().poll(cx) { + Poll::Ready(Ok(None)) => { + me.future = None; + Poll::Ready(None) + } + Poll::Ready(Ok(Some(StreamState::Some((state, batch))))) => { + me.future = Some(Box::pin(maybe_next(state))); + Poll::Ready(Some(Ok(batch))) + } + Poll::Ready(Ok(Some(StreamState::Waiting(_)))) => Poll::Pending, + Poll::Ready(Err(err)) => { + me.future = None; + Poll::Ready(Some(Err(err))) + } + Poll::Pending => Poll::Pending, + }, + None => Poll::Ready(None), + } + } +} diff --git a/tests/it/io/ipc/mod.rs b/tests/it/io/ipc/mod.rs index 6dafdaed47c..e1a5f79b4f0 100644 --- a/tests/it/io/ipc/mod.rs +++ b/tests/it/io/ipc/mod.rs @@ -6,3 +6,6 @@ pub use common::read_gzip_json; #[cfg(feature = "io_ipc_write_async")] mod write_async; + +#[cfg(feature = "io_ipc_read_async")] +mod read_stream_async; diff --git a/tests/it/io/ipc/read_stream_async.rs b/tests/it/io/ipc/read_stream_async.rs new file mode 100644 index 00000000000..593e2b0aeba --- /dev/null +++ b/tests/it/io/ipc/read_stream_async.rs @@ -0,0 +1,45 @@ +use futures::StreamExt; +use tokio::fs::File; +use tokio_util::compat::*; + +use arrow2::error::Result; +use arrow2::io::ipc::read::stream_async::*; + +use crate::io::ipc::common::read_gzip_json; + +async fn test_file(version: &str, file_name: &str) -> Result<()> { + let testdata = crate::test_util::arrow_test_data(); + let mut file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.stream", + testdata, version, file_name + )) + .await? + .compat(); + + let metadata = read_stream_metadata_async(&mut file).await?; + let mut reader = AsyncStreamReader::new(file, metadata); + + // read expected JSON output + let (schema, ipc_fields, batches) = read_gzip_json(version, file_name)?; + + assert_eq!(&schema, &reader.metadata().schema); + assert_eq!(&ipc_fields, &reader.metadata().ipc_schema.fields); + + let mut items = vec![]; + while let Some(item) = reader.next().await { + items.push(item?) + } + + batches + .iter() + .zip(items.into_iter()) + .for_each(|(lhs, rhs)| { + assert_eq!(lhs, &rhs); + }); + Ok(()) +} + +#[tokio::test] +async fn write_async() -> Result<()> { + test_file("1.0.0-littleendian", "generated_primitive").await +}