Skip to content

Commit

Permalink
feat: record_batch! macro
Browse files Browse the repository at this point in the history
closes: apache#6553
  • Loading branch information
sk314e committed Oct 18, 2024
1 parent 1666a4d commit 7813d8a
Showing 1 changed file with 128 additions and 2 deletions.
130 changes: 128 additions & 2 deletions arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,76 @@ pub struct RecordBatch {
row_count: usize,
}

/// Creates an array from a literal slice of values,
/// suitable for rapid testing and development.
///
/// Example:
///
/// ```rust
///
/// use arrow_array::create_array;
///
/// let array = create_array!(Int32, [1, 2, 3, 4, 5]);
/// ```
#[macro_export]
macro_rules! create_array {
(@array_of Boolean) => { $crate::BooleanArray };
(@array_of Int8) => { $crate::Int8Array };
(@array_of Int16) => { $crate::Int16Array };
(@array_of Int32) => { $crate::Int32Array };
(@array_of Int64) => { $crate::Int64Array };
(@array_of UInt8) => { $crate::UInt8Array };
(@array_of UInt16) => { $crate::UInt16Array };
(@array_of UInt32) => { $crate::UInt32Array };
(@array_of UInt64) => { $crate::UInt64Array };
(@array_of Float16) => { $crate::Float16Array };
(@array_of Float32) => { $crate::Float32Array };
(@array_of Float64) => { $crate::Float64Array };
(@array_of Utf8) => { $crate::StringArray };


($ty: tt, [$($values: expr),*]) => {
std::sync::Arc::new(<$crate::create_array!(@array_of $ty)>::from(vec![$($values),*]))
};
}

/// Creates a record batch from literal slice of values, suitable for rapid
/// testing and development.
///
/// Example:
///
/// ```rust
/// use arrow_array::record_batch;
/// use arrow_schema;
///
/// let batch = record_batch!(
/// ("a", Int32, [1, 2, 3]),
/// ("b", Float64, [Some(4.0), None, Some(5.0)]),
/// ("c", Utf8, ["alpha", "beta", "gamma"])
/// );
/// ```
#[macro_export]
macro_rules! record_batch {
($(($name: expr, $type: ident, [$($values: expr),*])),*) => {
{
let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![
$(
arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
)*
]));

let batch = $crate::RecordBatch::try_new(
schema,
vec![$(
$crate::create_array!($type, [$($values),*]),
)*]
);

batch
}
}
}

impl RecordBatch {
/// Creates a `RecordBatch` from a schema and columns.
///
Expand Down Expand Up @@ -623,11 +693,12 @@ where

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::{any::Any, collections::HashMap};

use super::*;
use crate::{
BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray,
BooleanArray, Float64Array, Int32Array, Int64Array, Int8Array, ListArray, StringArray,
StringViewArray,
};
use arrow_buffer::{Buffer, ToByteSlice};
use arrow_data::{ArrayData, ArrayDataBuilder};
Expand All @@ -648,6 +719,61 @@ mod tests {
check_batch(record_batch, 5)
}

fn downcast_as<T: Any>(array: &ArrayRef) -> Result<&T, ArrowError> {
array.as_any().downcast_ref::<T>().ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot downcast array to {}",
std::any::type_name::<T>()
))
})
}

#[test]
fn test_create_record_batch() -> Result<(), ArrowError> {
let result = record_batch!(
("a", Int32, [1, 2, 3, 4]),
("b", Float64, [Some(4.0), None, Some(5.0), None]),
("c", Utf8, ["alpha", "beta", "gamma", "delta"])
);

assert!(result.is_ok());
let batch = result.unwrap();

assert_eq!(3, batch.num_columns());
assert_eq!(4, batch.num_rows());

let values = downcast_as::<Int32Array>(batch.column(0))?
.values()
.iter()
.copied()
.collect::<Vec<_>>();

assert_eq!(values, vec![1, 2, 3, 4]);

let values = downcast_as::<Float64Array>(batch.column(1))?
.values()
.iter()
.copied()
.collect::<Vec<_>>();

assert_eq!(values, vec![4.0, 0.0, 5.0, 0.0]);

let nulls = downcast_as::<Float64Array>(batch.column(1))?
.nulls()
.unwrap()
.iter()
.collect::<Vec<_>>();
assert_eq!(nulls, vec![true, false, true, false]);

let values = downcast_as::<StringArray>(batch.column(2))?
.iter()
.flatten()
.collect::<Vec<_>>();

assert_eq!(values, vec!["alpha", "beta", "gamma", "delta"]);
Ok(())
}

#[test]
fn create_string_view_record_batch() {
let schema = Schema::new(vec![
Expand Down

0 comments on commit 7813d8a

Please sign in to comment.