diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index c56b1fd308cf..f4cfc27f09b1 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -79,6 +79,76 @@ pub struct RecordBatch { row_count: usize, } +/// Creates an array from a literal slice of values, +/// suitable for rapid testing and development. +/// +/// Example: +/// +/// ```rust +/// +/// use arrow_array::create_array; +/// +/// let array = create_array!(Int32, [1, 2, 3, 4, 5]); +/// ``` +#[macro_export] +macro_rules! create_array { + (@array_of Boolean) => { $crate::BooleanArray }; + (@array_of Int8) => { $crate::Int8Array }; + (@array_of Int16) => { $crate::Int16Array }; + (@array_of Int32) => { $crate::Int32Array }; + (@array_of Int64) => { $crate::Int64Array }; + (@array_of UInt8) => { $crate::UInt8Array }; + (@array_of UInt16) => { $crate::UInt16Array }; + (@array_of UInt32) => { $crate::UInt32Array }; + (@array_of UInt64) => { $crate::UInt64Array }; + (@array_of Float16) => { $crate::Float16Array }; + (@array_of Float32) => { $crate::Float32Array }; + (@array_of Float64) => { $crate::Float64Array }; + (@array_of Utf8) => { $crate::StringArray }; + + + ($ty: tt, [$($values: expr),*]) => { + std::sync::Arc::new(<$crate::create_array!(@array_of $ty)>::from(vec![$($values),*])) + }; +} + +/// Creates a record batch from literal slice of values, suitable for rapid +/// testing and development. +/// +/// Example: +/// +/// ```rust +/// use arrow_array::record_batch; +/// use arrow_schema; +/// +/// let batch = record_batch!( +/// ("a", Int32, [1, 2, 3]), +/// ("b", Float64, [Some(4.0), None, Some(5.0)]), +/// ("c", Utf8, ["alpha", "beta", "gamma"]) +/// ); +/// ``` +#[macro_export] +macro_rules! record_batch { + ($(($name: expr, $type: ident, [$($values: expr),*])),*) => { + { + let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![ + $( + arrow_schema::Field::new($name, arrow_schema::DataType::$type, true), + )* + ])); + + let batch = $crate::RecordBatch::try_new( + schema, + vec![$( + $crate::create_array!($type, [$($values),*]), + )*] + ); + + batch + } + } +} + impl RecordBatch { /// Creates a `RecordBatch` from a schema and columns. /// @@ -623,11 +693,12 @@ where #[cfg(test)] mod tests { - use std::collections::HashMap; + use std::{any::Any, collections::HashMap}; use super::*; use crate::{ - BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray, + BooleanArray, Float64Array, Int32Array, Int64Array, Int8Array, ListArray, StringArray, + StringViewArray, }; use arrow_buffer::{Buffer, ToByteSlice}; use arrow_data::{ArrayData, ArrayDataBuilder}; @@ -648,6 +719,61 @@ mod tests { check_batch(record_batch, 5) } + fn downcast_as(array: &ArrayRef) -> Result<&T, ArrowError> { + array.as_any().downcast_ref::().ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot downcast array to {}", + std::any::type_name::() + )) + }) + } + + #[test] + fn test_create_record_batch() -> Result<(), ArrowError> { + let result = record_batch!( + ("a", Int32, [1, 2, 3, 4]), + ("b", Float64, [Some(4.0), None, Some(5.0), None]), + ("c", Utf8, ["alpha", "beta", "gamma", "delta"]) + ); + + assert!(result.is_ok()); + let batch = result.unwrap(); + + assert_eq!(3, batch.num_columns()); + assert_eq!(4, batch.num_rows()); + + let values = downcast_as::(batch.column(0))? + .values() + .iter() + .copied() + .collect::>(); + + assert_eq!(values, vec![1, 2, 3, 4]); + + let values = downcast_as::(batch.column(1))? + .values() + .iter() + .copied() + .collect::>(); + + assert_eq!(values, vec![4.0, 0.0, 5.0, 0.0]); + + let nulls = downcast_as::(batch.column(1))? + .nulls() + .unwrap() + .iter() + .collect::>(); + assert_eq!(nulls, vec![true, false, true, false]); + + let values = downcast_as::(batch.column(2))? + .iter() + .flatten() + .collect::>(); + + assert_eq!(values, vec!["alpha", "beta", "gamma", "delta"]); + Ok(()) + } + #[test] fn create_string_view_record_batch() { let schema = Schema::new(vec![