diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 393e4afdf865..5a913e91b0f8 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1293,7 +1293,6 @@ mod tests { use crate::execution::context::QueryPlanner; use crate::from_slice::FromSlice; use crate::logical_plan::{binary_expr, lit, Operator}; - use crate::physical_plan::collect; use crate::physical_plan::functions::{make_scalar_function, Volatility}; use crate::test; use crate::variable::VarType; @@ -1302,16 +1301,14 @@ mod tests { logical_plan::{col, create_udf, sum, Expr}, }; use crate::{ - datasource::{empty::EmptyTable, MemTable}, - logical_plan::create_udaf, + datasource::MemTable, logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator, }; use arrow::array::{ Array, ArrayRef, DictionaryArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, + Int32Array, Int64Array, Int8Array, LargeStringArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }; - use arrow::compute::add; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; use async_trait::async_trait; @@ -1399,79 +1396,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn sort() -> Result<()> { - let results = - execute("SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC", 4).await?; - assert_eq!(results.len(), 1); - - let expected: Vec<&str> = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 3 | 1 |", - "| 3 | 2 |", - "| 3 | 3 |", - "| 3 | 4 |", - "| 3 | 5 |", - "| 3 | 6 |", - "| 3 | 7 |", - "| 3 | 8 |", - "| 3 | 9 |", - "| 3 | 10 |", - "| 2 | 1 |", - "| 2 | 2 |", - "| 2 | 3 |", - "| 2 | 4 |", - "| 2 | 5 |", - "| 2 | 6 |", - "| 2 | 7 |", - "| 2 | 8 |", - "| 2 | 9 |", - "| 2 | 10 |", - "| 1 | 1 |", - "| 1 | 2 |", - "| 1 | 3 |", - "| 1 | 4 |", - "| 1 | 5 |", - "| 1 | 6 |", - "| 1 | 7 |", - "| 1 | 8 |", - "| 1 | 9 |", - "| 1 | 10 |", - "| 0 | 1 |", - "| 0 | 2 |", - "| 0 | 3 |", - "| 0 | 4 |", - "| 0 | 5 |", - "| 0 | 6 |", - "| 0 | 7 |", - "| 0 | 8 |", - "| 0 | 9 |", - "| 0 | 10 |", - "+----+----+", - ]; - - // Note it is important to NOT use assert_batches_sorted_eq - // here as we are testing the sortedness of the output - assert_batches_eq!(expected, &results); - - Ok(()) - } - - #[tokio::test] - async fn sort_empty() -> Result<()> { - // The predicate on this query purposely generates no results - let results = execute( - "SELECT c1, c2 FROM test WHERE c1 > 100000 ORDER BY c1 DESC, c2 ASC", - 4, - ) - .await - .unwrap(); - assert_eq!(results.len(), 0); - Ok(()) - } - #[tokio::test] async fn left_join_using() -> Result<()> { let results = execute( @@ -2828,7 +2752,7 @@ mod tests { let tmp_dir = TempDir::new()?; let mut ctx = create_ctx(&tmp_dir, 4).await?; - // execute a simple query and write the results to CSV + // execute a simple query and write the results to parquet let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; write_parquet(&mut ctx, "SELECT c1, c2 FROM test", &out_dir, None).await?; @@ -2939,111 +2863,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn scalar_udf() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![ - Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])), - Arc::new(Int32Array::from_slice(&[2, 12, 12, 120])), - ], - )?; - - let mut ctx = 
ExecutionContext::new(); - - let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; - ctx.register_table("t", Arc::new(provider))?; - - let myfunc = |args: &[ArrayRef]| { - let l = &args[0] - .as_any() - .downcast_ref::<Int32Array>() - .expect("cast failed"); - let r = &args[1] - .as_any() - .downcast_ref::<Int32Array>() - .expect("cast failed"); - Ok(Arc::new(add(l, r)?) as ArrayRef) - }; - let myfunc = make_scalar_function(myfunc); - - ctx.register_udf(create_udf( - "my_add", - vec![DataType::Int32, DataType::Int32], - Arc::new(DataType::Int32), - Volatility::Immutable, - myfunc, - )); - - // from here on, we may be in a different scope. We would still like to be able - // to call UDFs. - - let t = ctx.table("t")?; - - let plan = LogicalPlanBuilder::from(t.to_logical_plan()) - .project(vec![ - col("a"), - col("b"), - ctx.udf("my_add")?.call(vec![col("a"), col("b")]), - ])? - .build()?; - - assert_eq!( - format!("{:?}", plan), - "Projection: #t.a, #t.b, my_add(#t.a, #t.b)\n TableScan: t projection=None" - ); - - let plan = ctx.optimize(&plan)?; - let plan = ctx.create_physical_plan(&plan).await?; - let runtime = ctx.state.lock().runtime_env.clone(); - let result = collect(plan, runtime).await?; - - let expected = vec![ - "+-----+-----+-----------------+", - "| a | b | my_add(t.a,t.b) |", - "+-----+-----+-----------------+", - "| 1 | 2 | 3 |", - "| 10 | 12 | 22 |", - "| 10 | 12 | 22 |", - "| 100 | 120 | 220 |", - "+-----+-----+-----------------+", - ]; - assert_batches_eq!(expected, &result); - - let batch = &result[0]; - let a = batch - .column(0) - .as_any() - .downcast_ref::<Int32Array>() - .expect("failed to cast a"); - let b = batch - .column(1) - .as_any() - .downcast_ref::<Int32Array>() - .expect("failed to cast b"); - let sum = batch - .column(2) - .as_any() - .downcast_ref::<Int32Array>() - .expect("failed to cast sum"); - - assert_eq!(4, a.len()); - assert_eq!(4, b.len()); - assert_eq!(4, sum.len()); - for i in 0..sum.len() { - assert_eq!(a.value(i) + b.value(i), sum.value(i)); - } - - ctx.deregister_table("t")?; - - Ok(()) - } - #[tokio::test] async fn simple_avg() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -3080,52 +2899,6 @@ mod tests { Ok(()) } - /// tests the creation, registration and usage of a UDAF - #[tokio::test] - async fn simple_udaf() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let batch1 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from_slice(&[1, 2, 3]))], - )?; - let batch2 = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from_slice(&[4, 5]))], - )?; - - let mut ctx = ExecutionContext::new(); - - let provider = - MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - - // define a udaf, using a DataFusion's accumulator - let my_avg = create_udaf( - "my_avg", - DataType::Float64, - Arc::new(DataType::Float64), - Volatility::Immutable, - Arc::new(|| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))), - Arc::new(vec![DataType::UInt64, DataType::Float64]), - ); - - ctx.register_udaf(my_avg); - - let result = plan_and_collect(&mut ctx, "SELECT MY_AVG(a) FROM t").await?; - - let expected = vec![ - "+-------------+", - "| my_avg(t.a) |", - "+-------------+", - "| 3 |", - "+-------------+", - ]; - assert_batches_eq!(expected, &result); - - Ok(()) - } - #[tokio::test] async fn custom_query_planner() -> Result<()> { let mut ctx = ExecutionContext::with_config( @@ -3234,65 
+3007,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn create_external_table_with_timestamps() { - let mut ctx = ExecutionContext::new(); - - let data = "Jorge,2018-12-13T12:12:10.011Z\n\ - Andrew,2018-11-13T17:11:10.011Z"; - - let tmp_dir = TempDir::new().unwrap(); - let file_path = tmp_dir.path().join("timestamps.csv"); - - // scope to ensure the file is closed and written - { - File::create(&file_path) - .expect("creating temp file") - .write_all(data.as_bytes()) - .expect("writing data"); - } - - let sql = format!( - "CREATE EXTERNAL TABLE csv_with_timestamps ( - name VARCHAR, - ts TIMESTAMP - ) - STORED AS CSV - LOCATION '{}' - ", - file_path.to_str().expect("path is utf8") - ); - - plan_and_collect(&mut ctx, &sql) - .await - .expect("Executing CREATE EXTERNAL TABLE"); - - let sql = "SELECT * from csv_with_timestamps"; - let result = plan_and_collect(&mut ctx, sql).await.unwrap(); - let expected = vec![ - "+--------+-------------------------+", - "| name | ts |", - "+--------+-------------------------+", - "| Andrew | 2018-11-13 17:11:10.011 |", - "| Jorge | 2018-12-13 12:12:10.011 |", - "+--------+-------------------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - } - - #[tokio::test] - async fn query_empty_table() { - let mut ctx = ExecutionContext::new(); - let empty_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty()))); - ctx.register_table("test_tbl", empty_table).unwrap(); - let sql = "SELECT * FROM test_tbl"; - let result = plan_and_collect(&mut ctx, sql) - .await - .expect("Query empty table"); - let expected = vec!["++", "++"]; - assert_batches_sorted_eq!(expected, &result); - } - #[tokio::test] async fn catalogs_not_leaked() { // the information schema used to introduce cyclic Arcs @@ -3316,60 +3030,6 @@ mod tests { assert_eq!(Weak::strong_count(&catalog_weak), 0); } - #[tokio::test] - async fn schema_merge_ignores_metadata() { - // Create two parquet files in same table with same schema but different metadata - let tmp_dir = TempDir::new().unwrap(); - let table_dir = tmp_dir.path().join("parquet_test"); - let table_path = Path::new(&table_dir); - - let mut non_empty_metadata: HashMap<String, String> = HashMap::new(); - non_empty_metadata.insert("testing".to_string(), "metadata".to_string()); - - let fields = vec![ - Field::new("id", DataType::Int32, true), - Field::new("name", DataType::Utf8, true), - ]; - let schemas = vec![ - Arc::new(Schema::new_with_metadata( - fields.clone(), - non_empty_metadata.clone(), - )), - Arc::new(Schema::new(fields.clone())), - ]; - - if let Ok(()) = fs::create_dir(table_path) { - for (i, schema) in schemas.iter().enumerate().take(2) { - let filename = format!("part-{}.parquet", i); - let path = table_path.join(&filename); - let file = fs::File::create(path).unwrap(); - let mut writer = - ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), None) - .unwrap(); - - // create mock record batch - let ids = Arc::new(Int32Array::from_slice(&[i as i32])); - let names = Arc::new(StringArray::from_slice(&["test"])); - let rec_batch = - RecordBatch::try_new(schema.clone(), vec![ids, names]).unwrap(); - - writer.write(&rec_batch).unwrap(); - writer.close().unwrap(); - } - } - - // Read the parquet files into a dataframe to confirm results - // (no errors) - let mut ctx = ExecutionContext::new(); - let df = ctx - .read_parquet(table_dir.to_str().unwrap().to_string()) - .await - .unwrap(); - let result = df.collect().await.unwrap(); - - assert_eq!(result[0].schema().metadata(), result[1].schema().metadata()); - } - #[tokio::test] 
async fn normalized_column_identifiers() { // create local execution context diff --git a/datafusion/tests/sql/create_drop.rs b/datafusion/tests/sql/create_drop.rs index 7dcca46710b7..45f2a36047c5 100644 --- a/datafusion/tests/sql/create_drop.rs +++ b/datafusion/tests/sql/create_drop.rs @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. +use std::io::Write; + +use tempfile::TempDir; + use super::*; #[tokio::test] @@ -76,3 +80,49 @@ async fn csv_query_create_external_table() { ]; assert_batches_eq!(expected, &actual); } + +#[tokio::test] +async fn create_external_table_with_timestamps() { + let mut ctx = ExecutionContext::new(); + + let data = "Jorge,2018-12-13T12:12:10.011Z\n\ + Andrew,2018-11-13T17:11:10.011Z"; + + let tmp_dir = TempDir::new().unwrap(); + let file_path = tmp_dir.path().join("timestamps.csv"); + + // scope to ensure the file is closed and written + { + std::fs::File::create(&file_path) + .expect("creating temp file") + .write_all(data.as_bytes()) + .expect("writing data"); + } + + let sql = format!( + "CREATE EXTERNAL TABLE csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP + ) + STORED AS CSV + LOCATION '{}' + ", + file_path.to_str().expect("path is utf8") + ); + + plan_and_collect(&mut ctx, &sql) + .await + .expect("Executing CREATE EXTERNAL TABLE"); + + let sql = "SELECT * from csv_with_timestamps"; + let result = plan_and_collect(&mut ctx, sql).await.unwrap(); + let expected = vec![ + "+--------+-------------------------+", + "| name | ts |", + "+--------+-------------------------+", + "| Andrew | 2018-11-13 17:11:10.011 |", + "| Jorge | 2018-12-13 12:12:10.011 |", + "+--------+-------------------------+", + ]; + assert_batches_sorted_eq!(expected, &result); +} diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index 468762ea05bb..d9088f51820f 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -883,7 +883,7 @@ async fn nyc() -> Result<()> { }, _ => unreachable!(), }, - _ => unreachable!(false), + _ => unreachable!(), } Ok(()) diff --git a/datafusion/tests/sql/order.rs b/datafusion/tests/sql/order.rs index fa59d9d19661..d23c81778951 100644 --- a/datafusion/tests/sql/order.rs +++ b/datafusion/tests/sql/order.rs @@ -124,3 +124,77 @@ async fn test_specific_nulls_first_asc() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn sort() -> Result<()> { + let results = + partitioned_csv::execute("SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC", 4) + .await?; + assert_eq!(results.len(), 1); + + let expected: Vec<&str> = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 3 | 1 |", + "| 3 | 2 |", + "| 3 | 3 |", + "| 3 | 4 |", + "| 3 | 5 |", + "| 3 | 6 |", + "| 3 | 7 |", + "| 3 | 8 |", + "| 3 | 9 |", + "| 3 | 10 |", + "| 2 | 1 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "| 2 | 10 |", + "| 1 | 1 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 1 | 10 |", + "| 0 | 1 |", + "| 0 | 2 |", + "| 0 | 3 |", + "| 0 | 4 |", + "| 0 | 5 |", + "| 0 | 6 |", + "| 0 | 7 |", + "| 0 | 8 |", + "| 0 | 9 |", + "| 0 | 10 |", + "+----+----+", + ]; + + // Note it is important to NOT use assert_batches_sorted_eq + // here as we are testing the sortedness of the output + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn sort_empty() -> Result<()> { + // The predicate on 
this query purposely generates no results + let results = partitioned_csv::execute( + "SELECT c1, c2 FROM test WHERE c1 > 100000 ORDER BY c1 DESC, c2 ASC", + 4, + ) + .await + .unwrap(); + assert_eq!(results.len(), 0); + Ok(()) +} diff --git a/datafusion/tests/sql/parquet.rs b/datafusion/tests/sql/parquet.rs index 3d4f49ea2619..37912c8751c8 100644 --- a/datafusion/tests/sql/parquet.rs +++ b/datafusion/tests/sql/parquet.rs @@ -15,6 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::{collections::HashMap, fs, path::Path}; + +use ::parquet::arrow::ArrowWriter; +use tempfile::TempDir; + use super::*; #[tokio::test] @@ -162,3 +167,57 @@ async fn parquet_list_columns() { assert_eq!(result.value(2), "hij"); assert_eq!(result.value(3), "xyz"); } + +#[tokio::test] +async fn schema_merge_ignores_metadata() { + // Create two parquet files in same table with same schema but different metadata + let tmp_dir = TempDir::new().unwrap(); + let table_dir = tmp_dir.path().join("parquet_test"); + let table_path = Path::new(&table_dir); + + let mut non_empty_metadata: HashMap<String, String> = HashMap::new(); + non_empty_metadata.insert("testing".to_string(), "metadata".to_string()); + + let fields = vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]; + let schemas = vec![ + Arc::new(Schema::new_with_metadata( + fields.clone(), + non_empty_metadata.clone(), + )), + Arc::new(Schema::new(fields.clone())), + ]; + + if let Ok(()) = fs::create_dir(table_path) { + for (i, schema) in schemas.iter().enumerate().take(2) { + let filename = format!("part-{}.parquet", i); + let path = table_path.join(&filename); + let file = fs::File::create(path).unwrap(); + let mut writer = + ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), None) + .unwrap(); + + // create mock record batch + let ids = Arc::new(Int32Array::from_slice(&[i as i32])); + let names = Arc::new(StringArray::from_slice(&["test"])); + let rec_batch = + RecordBatch::try_new(schema.clone(), vec![ids, names]).unwrap(); + + writer.write(&rec_batch).unwrap(); + writer.close().unwrap(); + } + } + + // Read the parquet files into a dataframe to confirm results + // (no errors) + let mut ctx = ExecutionContext::new(); + let df = ctx + .read_parquet(table_dir.to_str().unwrap().to_string()) + .await + .unwrap(); + let result = df.collect().await.unwrap(); + + assert_eq!(result[0].schema().metadata(), result[1].schema().metadata()); +} diff --git a/datafusion/tests/sql/partitioned_csv.rs b/datafusion/tests/sql/partitioned_csv.rs index 5efc837d5c95..3394887ad0b8 100644 --- a/datafusion/tests/sql/partitioned_csv.rs +++ b/datafusion/tests/sql/partitioned_csv.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Utility functions for running with a partitioned csv dataset: +//! Utility functions for creating and running with a partitioned csv dataset. use std::{io::Write, sync::Arc}; diff --git a/datafusion/tests/sql/select.rs b/datafusion/tests/sql/select.rs index 02869dd99d2b..6ba190856a46 100644 --- a/datafusion/tests/sql/select.rs +++ b/datafusion/tests/sql/select.rs @@ -16,7 +16,10 @@ // under the License. 
use super::*; -use datafusion::{from_slice::FromSlice, physical_plan::collect_partitioned}; +use datafusion::{ + datasource::empty::EmptyTable, from_slice::FromSlice, + physical_plan::collect_partitioned, +}; use tempfile::TempDir; #[tokio::test] @@ -985,3 +988,16 @@ async fn parallel_query_with_filter() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn query_empty_table() { + let mut ctx = ExecutionContext::new(); + let empty_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty()))); + ctx.register_table("test_tbl", empty_table).unwrap(); + let sql = "SELECT * FROM test_tbl"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("Query empty table"); + let expected = vec!["++", "++"]; + assert_batches_sorted_eq!(expected, &result); +} diff --git a/datafusion/tests/sql/udf.rs b/datafusion/tests/sql/udf.rs index db42574c1bd0..6b714cb368b8 100644 --- a/datafusion/tests/sql/udf.rs +++ b/datafusion/tests/sql/udf.rs @@ -16,6 +16,11 @@ // under the License. use super::*; +use arrow::compute::add; +use datafusion::{ + logical_plan::{create_udaf, FunctionRegistry, LogicalPlanBuilder}, + physical_plan::{expressions::AvgAccumulator, functions::make_scalar_function}, +}; /// test that casting happens on udfs. /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and @@ -30,3 +35,153 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> { assert_float_eq(&expected, &actual); Ok(()) } + +#[tokio::test] +async fn scalar_udf() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])), + Arc::new(Int32Array::from_slice(&[2, 12, 12, 120])), + ], + )?; + + let mut ctx = ExecutionContext::new(); + + let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let myfunc = |args: &[ArrayRef]| { + let l = &args[0] + .as_any() + .downcast_ref::<Int32Array>() + .expect("cast failed"); + let r = &args[1] + .as_any() + .downcast_ref::<Int32Array>() + .expect("cast failed"); + Ok(Arc::new(add(l, r)?) as ArrayRef) + }; + let myfunc = make_scalar_function(myfunc); + + ctx.register_udf(create_udf( + "my_add", + vec![DataType::Int32, DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + myfunc, + )); + + // from here on, we may be in a different scope. We would still like to be able + // to call UDFs. + + let t = ctx.table("t")?; + + let plan = LogicalPlanBuilder::from(t.to_logical_plan()) + .project(vec![ + col("a"), + col("b"), + ctx.udf("my_add")?.call(vec![col("a"), col("b")]), + ])? 
+ .build()?; + + assert_eq!( + format!("{:?}", plan), + "Projection: #t.a, #t.b, my_add(#t.a, #t.b)\n TableScan: t projection=None" + ); + + let plan = ctx.optimize(&plan)?; + let plan = ctx.create_physical_plan(&plan).await?; + let runtime = ctx.state.lock().runtime_env.clone(); + let result = collect(plan, runtime).await?; + + let expected = vec![ + "+-----+-----+-----------------+", + "| a | b | my_add(t.a,t.b) |", + "+-----+-----+-----------------+", + "| 1 | 2 | 3 |", + "| 10 | 12 | 22 |", + "| 10 | 12 | 22 |", + "| 100 | 120 | 220 |", + "+-----+-----+-----------------+", + ]; + assert_batches_eq!(expected, &result); + + let batch = &result[0]; + let a = batch + .column(0) + .as_any() + .downcast_ref::<Int32Array>() + .expect("failed to cast a"); + let b = batch + .column(1) + .as_any() + .downcast_ref::<Int32Array>() + .expect("failed to cast b"); + let sum = batch + .column(2) + .as_any() + .downcast_ref::<Int32Array>() + .expect("failed to cast sum"); + + assert_eq!(4, a.len()); + assert_eq!(4, b.len()); + assert_eq!(4, sum.len()); + for i in 0..sum.len() { + assert_eq!(a.value(i) + b.value(i), sum.value(i)); + } + + ctx.deregister_table("t")?; + + Ok(()) +} + +/// tests the creation, registration and usage of a UDAF +#[tokio::test] +async fn simple_udaf() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let batch1 = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from_slice(&[1, 2, 3]))], + )?; + let batch2 = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from_slice(&[4, 5]))], + )?; + + let mut ctx = ExecutionContext::new(); + + let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Arc::new(provider))?; + + // define a udaf, using a DataFusion's accumulator + let my_avg = create_udaf( + "my_avg", + DataType::Float64, + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(|| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))), + Arc::new(vec![DataType::UInt64, DataType::Float64]), + ); + + ctx.register_udaf(my_avg); + + let result = plan_and_collect(&mut ctx, "SELECT MY_AVG(a) FROM t").await?; + + let expected = vec![ + "+-------------+", + "| my_avg(t.a) |", + "+-------------+", + "| 3 |", + "+-------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +}