From 7b8d72c5342610a40827f23df7d5604cf24133fd Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Fri, 28 Jan 2022 02:17:28 +0800 Subject: [PATCH 01/50] feat: add join type for logical plan display (#1674) --- datafusion/src/logical_plan/builder.rs | 2 +- datafusion/src/logical_plan/plan.rs | 25 ++++++++++++++-- datafusion/src/optimizer/filter_push_down.rs | 30 +++++++++---------- .../src/optimizer/projection_push_down.rs | 6 ++-- datafusion/src/sql/planner.rs | 22 +++++++------- datafusion/tests/sql/explain_analyze.rs | 6 ++-- 6 files changed, 56 insertions(+), 35 deletions(-) diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index fc609390bcc0..613c8e950c93 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -1150,7 +1150,7 @@ mod tests { // id column should only show up once in projection let expected = "Projection: #t1.id, #t1.first_name, #t1.last_name, #t1.state, #t1.salary, #t2.first_name, #t2.last_name, #t2.state, #t2.salary\ - \n Join: Using #t1.id = #t2.id\ + \n Inner Join: Using #t1.id = #t2.id\ \n TableScan: t1 projection=None\ \n TableScan: t2 projection=None"; diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index b40dfc0103fc..3d49e5484eab 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -25,6 +25,7 @@ use crate::error::DataFusionError; use crate::logical_plan::dfschema::DFSchemaRef; use crate::sql::parser::FileType; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use std::fmt::Formatter; use std::{ collections::HashSet, fmt::{self, Display}, @@ -48,6 +49,20 @@ pub enum JoinType { Anti, } +impl Display for JoinType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let join_type = match self { + JoinType::Inner => "Inner", + JoinType::Left => "Left", + JoinType::Right => "Right", + JoinType::Full => "Full", + JoinType::Semi => "Semi", + JoinType::Anti => "Anti", + }; + write!(f, "{}", join_type) + } +} + /// Join constraint #[derive(Debug, Clone, Copy)] pub enum JoinConstraint { @@ -934,16 +949,22 @@ impl LogicalPlan { LogicalPlan::Join(Join { on: ref keys, join_constraint, + join_type, .. 
}) => { let join_expr: Vec = keys.iter().map(|(l, r)| format!("{} = {}", l, r)).collect(); match join_constraint { JoinConstraint::On => { - write!(f, "Join: {}", join_expr.join(", ")) + write!(f, "{} Join: {}", join_type, join_expr.join(", ")) } JoinConstraint::Using => { - write!(f, "Join: Using {}", join_expr.join(", ")) + write!( + f, + "{} Join: Using {}", + join_type, + join_expr.join(", ") + ) } } } diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index 6141af18a780..ababb52020d7 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -1014,7 +1014,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.a <= Int64(1)\ - \n Join: #test.a = #test2.a\ + \n Inner Join: #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1022,7 +1022,7 @@ mod tests { // filter sent to side before the join let expected = "\ - Join: #test.a = #test2.a\ + Inner Join: #test.a = #test2.a\ \n Filter: #test.a <= Int64(1)\ \n TableScan: test projection=None\ \n Projection: #test2.a\ @@ -1055,7 +1055,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Inner Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1063,7 +1063,7 @@ mod tests { // filter sent to side before the join let expected = "\ - Join: Using #test.a = #test2.a\ + Inner Join: Using #test.a = #test2.a\ \n Filter: #test.a <= Int64(1)\ \n TableScan: test projection=None\ \n Projection: #test2.a\ @@ -1099,7 +1099,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.c <= #test2.b\ - \n Join: #test.a = #test2.a\ + \n Inner Join: #test.a = #test2.a\ \n Projection: #test.a, #test.c\ \n TableScan: test projection=None\ \n Projection: #test2.a, #test2.b\ @@ -1138,7 +1138,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.b <= Int64(1)\ - \n Join: #test.a = #test2.a\ + \n Inner Join: #test.a = #test2.a\ \n Projection: #test.a, #test.b\ \n TableScan: test projection=None\ \n Projection: #test2.a, #test2.c\ @@ -1146,7 +1146,7 @@ mod tests { ); let expected = "\ - Join: #test.a = #test2.a\ + Inner Join: #test.a = #test2.a\ \n Projection: #test.a, #test.b\ \n Filter: #test.b <= Int64(1)\ \n TableScan: test projection=None\ @@ -1180,7 +1180,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test2.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Left Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1189,7 +1189,7 @@ mod tests { // filter not duplicated nor pushed down - i.e. noop let expected = "\ Filter: #test2.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Left Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None"; @@ -1221,7 +1221,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Right Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1230,7 +1230,7 @@ mod tests { // filter not duplicated nor pushed down - i.e. 
noop let expected = "\ Filter: #test.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Right Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None"; @@ -1262,7 +1262,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Left Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1270,7 +1270,7 @@ mod tests { // filter sent to left side of the join, not the right let expected = "\ - Join: Using #test.a = #test2.a\ + Left Join: Using #test.a = #test2.a\ \n Filter: #test.a <= Int64(1)\ \n TableScan: test projection=None\ \n Projection: #test2.a\ @@ -1303,7 +1303,7 @@ mod tests { format!("{:?}", plan), "\ Filter: #test2.a <= Int64(1)\ - \n Join: Using #test.a = #test2.a\ + \n Right Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n TableScan: test2 projection=None" @@ -1311,7 +1311,7 @@ mod tests { // filter sent to right side of join, not duplicated to the left let expected = "\ - Join: Using #test.a = #test2.a\ + Right Join: Using #test.a = #test2.a\ \n TableScan: test projection=None\ \n Projection: #test2.a\ \n Filter: #test2.a <= Int64(1)\ diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index fb45e981612e..1a64a44c52d6 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -593,7 +593,7 @@ mod tests { // make sure projections are pushed down to both table scans let expected = "Projection: #test.a, #test.b, #test2.c1\ - \n Join: #test.a = #test2.c1\ + \n Left Join: #test.a = #test2.c1\ \n TableScan: test projection=Some([0, 1])\ \n TableScan: test2 projection=Some([0])"; @@ -634,7 +634,7 @@ mod tests { // make sure projections are pushed down to both table scans let expected = "Projection: #test.a, #test.b\ - \n Join: #test.a = #test2.c1\ + \n Left Join: #test.a = #test2.c1\ \n TableScan: test projection=Some([0, 1])\ \n TableScan: test2 projection=Some([0])"; @@ -673,7 +673,7 @@ mod tests { // make sure projections are pushed down to table scan let expected = "Projection: #test.a, #test.b\ - \n Join: Using #test.a = #test2.a\ + \n Left Join: Using #test.a = #test2.a\ \n TableScan: test projection=Some([0, 1])\ \n TableScan: test2 projection=Some([0])"; diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index e951a3a702a1..9da54cad4daa 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -3289,7 +3289,7 @@ mod tests { JOIN orders \ ON id = customer_id"; let expected = "Projection: #person.id, #orders.order_id\ - \n Join: #person.id = #orders.customer_id\ + \n Inner Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -3303,7 +3303,7 @@ mod tests { ON id = customer_id AND order_id > 1 "; let expected = "Projection: #person.id, #orders.order_id\ \n Filter: #orders.order_id > Int64(1)\ - \n Join: #person.id = #orders.customer_id\ + \n Inner Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -3316,7 +3316,7 @@ mod tests { LEFT JOIN orders \ ON id = customer_id AND order_id > 1"; let expected = "Projection: #person.id, #orders.order_id\ - \n Join: #person.id = 
#orders.customer_id\ + \n Left Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n Filter: #orders.order_id > Int64(1)\ \n TableScan: orders projection=None"; @@ -3330,7 +3330,7 @@ mod tests { RIGHT JOIN orders \ ON id = customer_id AND id > 1"; let expected = "Projection: #person.id, #orders.order_id\ - \n Join: #person.id = #orders.customer_id\ + \n Right Join: #person.id = #orders.customer_id\ \n Filter: #person.id > Int64(1)\ \n TableScan: person projection=None\ \n TableScan: orders projection=None"; @@ -3344,7 +3344,7 @@ mod tests { JOIN orders \ ON person.id = orders.customer_id"; let expected = "Projection: #person.id, #orders.order_id\ - \n Join: #person.id = #orders.customer_id\ + \n Inner Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None"; quick_test(sql, expected); @@ -3357,7 +3357,7 @@ mod tests { JOIN person as person2 \ USING (id)"; let expected = "Projection: #person.first_name, #person.id\ - \n Join: Using #person.id = #person2.id\ + \n Inner Join: Using #person.id = #person2.id\ \n TableScan: person projection=None\ \n TableScan: person2 projection=None"; quick_test(sql, expected); @@ -3370,7 +3370,7 @@ mod tests { JOIN lineitem as lineitem2 \ USING (l_item_id)"; let expected = "Projection: #lineitem.l_item_id, #lineitem.l_description, #lineitem.price, #lineitem2.l_description, #lineitem2.price\ - \n Join: Using #lineitem.l_item_id = #lineitem2.l_item_id\ + \n Inner Join: Using #lineitem.l_item_id = #lineitem2.l_item_id\ \n TableScan: lineitem projection=None\ \n TableScan: lineitem2 projection=None"; quick_test(sql, expected); @@ -3384,8 +3384,8 @@ mod tests { JOIN lineitem ON o_item_id = l_item_id"; let expected = "Projection: #person.id, #orders.order_id, #lineitem.l_description\ - \n Join: #orders.o_item_id = #lineitem.l_item_id\ - \n Join: #person.id = #orders.customer_id\ + \n Inner Join: #orders.o_item_id = #lineitem.l_item_id\ + \n Inner Join: #person.id = #orders.customer_id\ \n TableScan: person projection=None\ \n TableScan: orders projection=None\ \n TableScan: lineitem projection=None"; @@ -3918,8 +3918,8 @@ mod tests { fn cross_join_to_inner_join() { let sql = "select person.id from person, orders, lineitem where person.id = lineitem.l_item_id and orders.o_item_id = lineitem.l_description;"; let expected = "Projection: #person.id\ - \n Join: #lineitem.l_description = #orders.o_item_id\ - \n Join: #person.id = #lineitem.l_item_id\ + \n Inner Join: #lineitem.l_description = #orders.o_item_id\ + \n Inner Join: #person.id = #lineitem.l_item_id\ \n TableScan: person projection=None\ \n TableScan: lineitem projection=None\ \n TableScan: orders projection=None"; diff --git a/datafusion/tests/sql/explain_analyze.rs b/datafusion/tests/sql/explain_analyze.rs index 7c1fa69ab73f..2bd78ec728f5 100644 --- a/datafusion/tests/sql/explain_analyze.rs +++ b/datafusion/tests/sql/explain_analyze.rs @@ -616,9 +616,9 @@ order by Sort: #revenue DESC NULLS FIRST\ \n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\ \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * Int64(1) - #lineitem.l_discount)]]\ - \n Join: #customer.c_nationkey = #nation.n_nationkey\ - \n Join: 
#orders.o_orderkey = #lineitem.l_orderkey\ - \n Join: #customer.c_custkey = #orders.o_custkey\ + \n Inner Join: #customer.c_nationkey = #nation.n_nationkey\ + \n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\ + \n Inner Join: #customer.c_custkey = #orders.o_custkey\ \n TableScan: customer projection=Some([0, 1, 2, 3, 4, 5, 7])\ \n Filter: #orders.o_orderdate >= Date32(\"8674\") AND #orders.o_orderdate < Date32(\"8766\")\ \n TableScan: orders projection=Some([0, 1, 4]), filters=[#orders.o_orderdate >= Date32(\"8674\"), #orders.o_orderdate < Date32(\"8766\")]\ From 18ced8dce282a9979ecb3ee45c108f5e22d2557d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 27 Jan 2022 20:58:06 -0500 Subject: [PATCH 02/50] (minor) Reduce memory manager and disk manager logs from `info!` to `debug!` (#1689) --- datafusion/src/execution/disk_manager.rs | 4 ++-- datafusion/src/execution/memory_manager.rs | 8 ++++---- datafusion/src/physical_plan/sorts/sort.rs | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/datafusion/src/execution/disk_manager.rs b/datafusion/src/execution/disk_manager.rs index 79b70f1f8b9a..31565fec130d 100644 --- a/datafusion/src/execution/disk_manager.rs +++ b/datafusion/src/execution/disk_manager.rs @@ -19,7 +19,7 @@ //! hashed among the directories listed in RuntimeConfig::local_dirs. use crate::error::{DataFusionError, Result}; -use log::{debug, info}; +use log::debug; use rand::{thread_rng, Rng}; use std::path::PathBuf; use std::sync::Arc; @@ -88,7 +88,7 @@ impl DiskManager { } DiskManagerConfig::NewSpecified(conf_dirs) => { let local_dirs = create_local_dirs(conf_dirs)?; - info!( + debug!( "Created local dirs {:?} as DataFusion working directory", local_dirs ); diff --git a/datafusion/src/execution/memory_manager.rs b/datafusion/src/execution/memory_manager.rs index 32f79750a70d..53eb720c4729 100644 --- a/datafusion/src/execution/memory_manager.rs +++ b/datafusion/src/execution/memory_manager.rs @@ -20,7 +20,7 @@ use crate::error::{DataFusionError, Result}; use async_trait::async_trait; use hashbrown::HashMap; -use log::info; +use log::debug; use std::fmt; use std::fmt::{Debug, Display, Formatter}; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -169,7 +169,7 @@ pub trait MemoryConsumer: Send + Sync { /// reached for this consumer. async fn try_grow(&self, required: usize) -> Result<()> { let current = self.mem_used(); - info!( + debug!( "trying to acquire {} whiling holding {} from consumer {}", human_readable_size(required), human_readable_size(current), @@ -181,7 +181,7 @@ pub trait MemoryConsumer: Send + Sync { .can_grow_directly(required, current) .await; if !can_grow_directly { - info!( + debug!( "Failed to grow memory of {} directly from consumer {}, spilling first ...", human_readable_size(required), self.id() @@ -261,7 +261,7 @@ impl MemoryManager { match config { MemoryManagerConfig::Existing(manager) => manager, MemoryManagerConfig::New { .. 
} => { - info!( + debug!( "Creating memory manager with initial size {}", human_readable_size(pool_size) ); diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs index a2df6453ee82..d40d6cf170e4 100644 --- a/datafusion/src/physical_plan/sorts/sort.rs +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ -44,7 +44,7 @@ use arrow::record_batch::RecordBatch; use async_trait::async_trait; use futures::lock::Mutex; use futures::StreamExt; -use log::{error, info}; +use log::{debug, error}; use std::any::Any; use std::fmt; use std::fmt::{Debug, Formatter}; @@ -207,7 +207,7 @@ impl MemoryConsumer for ExternalSorter { } async fn spill(&self) -> Result { - info!( + debug!( "{}[{}] spilling sort data of {} to disk while inserting ({} time(s) so far)", self.name(), self.id(), @@ -331,7 +331,7 @@ fn write_sorted( writer.write(&batch?)?; } writer.finish()?; - info!( + debug!( "Spilled {} batches of total {} rows to disk, memory released {}", writer.num_batches, writer.num_rows, writer.num_bytes ); From ed1de63aea89ba3b025ef431a60bff68adce1bcd Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 28 Jan 2022 09:47:57 -0500 Subject: [PATCH 03/50] Move `information_schema` tests out of execution/context.rs to `sql_integration` tests (#1684) * Move tests from context.rs to information_schema.rs * Fix up tests to compile --- datafusion/src/execution/context.rs | 477 +------------------- datafusion/tests/sql/information_schema.rs | 502 +++++++++++++++++++++ datafusion/tests/sql/mod.rs | 17 + 3 files changed, 522 insertions(+), 474 deletions(-) create mode 100644 datafusion/tests/sql/information_schema.rs diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 61cbf3abc8a1..9cc54dfe1f37 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1277,14 +1277,13 @@ mod tests { logical_plan::{col, create_udf, sum, Expr}, }; use crate::{ - datasource::{empty::EmptyTable, MemTable, TableType}, + datasource::{empty::EmptyTable, MemTable}, logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator, }; use arrow::array::{ - Array, ArrayRef, BinaryArray, DictionaryArray, Float32Array, Float64Array, - Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, - LargeStringArray, StringArray, TimestampNanosecondArray, UInt16Array, + Array, ArrayRef, DictionaryArray, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::compute::add; @@ -3551,476 +3550,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn information_schema_tables_not_exist_by_default() { - let mut ctx = ExecutionContext::new(); - - let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") - .await - .unwrap_err(); - assert_eq!( - err.to_string(), - "Error during planning: Table or CTE with name 'information_schema.tables' not found" - ); - } - - #[tokio::test] - async fn information_schema_tables_no_tables() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - let result = - plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------------+------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+------------+------------+", - "| datafusion | 
information_schema | columns | VIEW |", - "| datafusion | information_schema | tables | VIEW |", - "+---------------+--------------------+------------+------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - } - - #[tokio::test] - async fn information_schema_tables_tables_default_catalog() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - // Now, register an empty table - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let result = - plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------------+------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+------------+------------+", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | public | t | BASE TABLE |", - "+---------------+--------------------+------------+------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - - // Newly added tables should appear - ctx.register_table("t2", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let result = - plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------------+------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+------------+------------+", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | public | t | BASE TABLE |", - "| datafusion | public | t2 | BASE TABLE |", - "+---------------+--------------------+------------+------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - } - - #[tokio::test] - async fn information_schema_tables_tables_with_multiple_catalogs() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - schema - .register_table("t1".to_owned(), test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - schema - .register_table("t2".to_owned(), test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - catalog.register_schema("my_schema", Arc::new(schema)); - ctx.register_catalog("my_catalog", Arc::new(catalog)); - - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - schema - .register_table("t3".to_owned(), test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - catalog.register_schema("my_other_schema", Arc::new(schema)); - ctx.register_catalog("my_other_catalog", Arc::new(catalog)); - - let result = - plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") - .await - .unwrap(); - - let expected = vec![ - "+------------------+--------------------+------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+------------------+--------------------+------------+------------+", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | tables | VIEW |", - "| my_catalog | information_schema | columns | VIEW |", - "| my_catalog | information_schema | tables | VIEW |", - "| my_catalog | my_schema | t1 | BASE TABLE |", - "| my_catalog | my_schema | t2 | BASE TABLE |", - "| 
my_other_catalog | information_schema | columns | VIEW |",
-            "| my_other_catalog | information_schema | tables | VIEW |",
-            "| my_other_catalog | my_other_schema | t3 | BASE TABLE |",
-            "+------------------+--------------------+------------+------------+",
-        ];
-        assert_batches_sorted_eq!(expected, &result);
-    }
-
-    #[tokio::test]
-    async fn information_schema_tables_table_types() {
-        struct TestTable(TableType);
-
-        #[async_trait]
-        impl TableProvider for TestTable {
-            fn as_any(&self) -> &dyn std::any::Any {
-                self
-            }
-
-            fn table_type(&self) -> TableType {
-                self.0
-            }
-
-            fn schema(&self) -> SchemaRef {
-                unimplemented!()
-            }
-
-            async fn scan(
-                &self,
-                _: &Option<Vec<usize>>,
-                _: &[Expr],
-                _: Option<usize>,
-            ) -> Result<Arc<dyn ExecutionPlan>> {
-                unimplemented!()
-            }
-        }
-
-        let mut ctx = ExecutionContext::with_config(
-            ExecutionConfig::new().with_information_schema(true),
-        );
-
-        ctx.register_table("physical", Arc::new(TestTable(TableType::Base)))
-            .unwrap();
-        ctx.register_table("query", Arc::new(TestTable(TableType::View)))
-            .unwrap();
-        ctx.register_table("temp", Arc::new(TestTable(TableType::Temporary)))
-            .unwrap();
-
-        let result =
-            plan_and_collect(&mut ctx, "SELECT * from information_schema.tables")
-                .await
-                .unwrap();
-
-        let expected = vec![
-            "+---------------+--------------------+------------+-----------------+",
-            "| table_catalog | table_schema | table_name | table_type |",
-            "+---------------+--------------------+------------+-----------------+",
-            "| datafusion | information_schema | tables | VIEW |",
-            "| datafusion | information_schema | columns | VIEW |",
-            "| datafusion | public | physical | BASE TABLE |",
-            "| datafusion | public | query | VIEW |",
-            "| datafusion | public | temp | LOCAL TEMPORARY |",
-            "+---------------+--------------------+------------+-----------------+",
-        ];
-        assert_batches_sorted_eq!(expected, &result);
-    }
-
-    #[tokio::test]
-    async fn information_schema_show_tables_no_information_schema() {
-        let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());
-
-        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
-            .unwrap();
-
-        // use show tables alias
-        let err = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap_err();
-
-        assert_eq!(err.to_string(), "Error during planning: SHOW TABLES is not supported unless information_schema is enabled");
-    }
-
-    #[tokio::test]
-    async fn information_schema_show_tables() {
-        let mut ctx = ExecutionContext::with_config(
-            ExecutionConfig::new().with_information_schema(true),
-        );
-
-        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
-            .unwrap();
-
-        // use show tables alias
-        let result = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap();
-
-        let expected = vec![
-            "+---------------+--------------------+------------+------------+",
-            "| table_catalog | table_schema | table_name | table_type |",
-            "+---------------+--------------------+------------+------------+",
-            "| datafusion | information_schema | columns | VIEW |",
-            "| datafusion | information_schema | tables | VIEW |",
-            "| datafusion | public | t | BASE TABLE |",
-            "+---------------+--------------------+------------+------------+",
-        ];
-        assert_batches_sorted_eq!(expected, &result);
-
-        let result = plan_and_collect(&mut ctx, "SHOW tables").await.unwrap();
-
-        assert_batches_sorted_eq!(expected, &result);
-    }
-
-    #[tokio::test]
-    async fn information_schema_show_columns_no_information_schema() {
-        let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());
-
-        ctx.register_table("t", test::table_with_sequence(1, 
1).unwrap()) - .unwrap(); - - let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") - .await - .unwrap_err(); - - assert_eq!(err.to_string(), "Error during planning: SHOW COLUMNS is not supported unless information_schema is enabled"); - } - - #[tokio::test] - async fn information_schema_show_columns_like_where() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); - - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let expected = - "Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported"; - - let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t LIKE 'f'") - .await - .unwrap_err(); - assert_eq!(err.to_string(), expected); - - let err = - plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t WHERE column_name = 'bar'") - .await - .unwrap_err(); - assert_eq!(err.to_string(), expected); - } - - #[tokio::test] - async fn information_schema_show_columns() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------+------------+-------------+-----------+-------------+", - "| table_catalog | table_schema | table_name | column_name | data_type | is_nullable |", - "+---------------+--------------+------------+-------------+-----------+-------------+", - "| datafusion | public | t | i | Int32 | YES |", - "+---------------+--------------+------------+-------------+-----------+-------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - - let result = plan_and_collect(&mut ctx, "SHOW columns from t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &result); - - // This isn't ideal but it is consistent behavior for `SELECT * from T` - let err = plan_and_collect(&mut ctx, "SHOW columns from T") - .await - .unwrap_err(); - assert_eq!( - err.to_string(), - "Error during planning: Unknown relation for SHOW COLUMNS: T" - ); - } - - // test errors with WHERE and LIKE - #[tokio::test] - async fn information_schema_show_columns_full_extended() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let result = plan_and_collect(&mut ctx, "SHOW FULL COLUMNS FROM t") - .await - .unwrap(); - let expected = vec![ - "+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - "| table_catalog | table_schema | table_name | column_name | ordinal_position | column_default | is_nullable | data_type | character_maximum_length | character_octet_length | numeric_precision | numeric_precision_radix | numeric_scale | datetime_precision | interval_type |", - "+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - "| datafusion | public | t | i | 0 | | YES | Int32 | | | 32 | 2 | | | |", - 
"+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - - let result = plan_and_collect(&mut ctx, "SHOW EXTENDED COLUMNS FROM t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &result); - } - - #[tokio::test] - async fn information_schema_show_table_table_names() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); - - ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) - .unwrap(); - - let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM public.t") - .await - .unwrap(); - - let expected = vec![ - "+---------------+--------------+------------+-------------+-----------+-------------+", - "| table_catalog | table_schema | table_name | column_name | data_type | is_nullable |", - "+---------------+--------------+------------+-------------+-----------+-------------+", - "| datafusion | public | t | i | Int32 | YES |", - "+---------------+--------------+------------+-------------+-----------+-------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - - let result = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t") - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &result); - - let err = plan_and_collect(&mut ctx, "SHOW columns from t2") - .await - .unwrap_err(); - assert_eq!( - err.to_string(), - "Error during planning: Unknown relation for SHOW COLUMNS: t2" - ); - - let err = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t2") - .await - .unwrap_err(); - assert_eq!(err.to_string(), "Error during planning: Unknown relation for SHOW COLUMNS: datafusion.public.t2"); - } - - #[tokio::test] - async fn show_unsupported() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); - - let err = plan_and_collect(&mut ctx, "SHOW SOMETHING_UNKNOWN") - .await - .unwrap_err(); - - assert_eq!(err.to_string(), "This feature is not implemented: SHOW SOMETHING_UNKNOWN not implemented. 
Supported syntax: SHOW <TABLES>");
-    }
-
-    #[tokio::test]
-    async fn information_schema_columns_not_exist_by_default() {
-        let mut ctx = ExecutionContext::new();
-
-        let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.columns")
-            .await
-            .unwrap_err();
-        assert_eq!(
-            err.to_string(),
-            "Error during planning: Table or CTE with name 'information_schema.columns' not found"
-        );
-    }
-
-    fn table_with_many_types() -> Arc<dyn TableProvider> {
-        let schema = Schema::new(vec![
-            Field::new("int32_col", DataType::Int32, false),
-            Field::new("float64_col", DataType::Float64, true),
-            Field::new("utf8_col", DataType::Utf8, true),
-            Field::new("large_utf8_col", DataType::LargeUtf8, false),
-            Field::new("binary_col", DataType::Binary, false),
-            Field::new("large_binary_col", DataType::LargeBinary, false),
-            Field::new(
-                "timestamp_nanos",
-                DataType::Timestamp(TimeUnit::Nanosecond, None),
-                false,
-            ),
-        ]);
-
-        let batch = RecordBatch::try_new(
-            Arc::new(schema.clone()),
-            vec![
-                Arc::new(Int32Array::from_slice(&[1])),
-                Arc::new(Float64Array::from_slice(&[1.0])),
-                Arc::new(StringArray::from(vec![Some("foo")])),
-                Arc::new(LargeStringArray::from(vec![Some("bar")])),
-                Arc::new(BinaryArray::from_slice(&[b"foo" as &[u8]])),
-                Arc::new(LargeBinaryArray::from_slice(&[b"foo" as &[u8]])),
-                Arc::new(TimestampNanosecondArray::from_opt_vec(
-                    vec![Some(123)],
-                    None,
-                )),
-            ],
-        )
-        .unwrap();
-        let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap();
-        Arc::new(provider)
-    }
-
-    #[tokio::test]
-    async fn information_schema_columns() {
-        let mut ctx = ExecutionContext::with_config(
-            ExecutionConfig::new().with_information_schema(true),
-        );
-        let catalog = MemoryCatalogProvider::new();
-        let schema = MemorySchemaProvider::new();
-
-        schema
-            .register_table("t1".to_owned(), test::table_with_sequence(1, 1).unwrap())
-            .unwrap();
-
-        schema
-            .register_table("t2".to_owned(), table_with_many_types())
-            .unwrap();
-        catalog.register_schema("my_schema", Arc::new(schema));
-        ctx.register_catalog("my_catalog", Arc::new(catalog));
-
-        let result =
-            plan_and_collect(&mut ctx, "SELECT * from information_schema.columns")
-                .await
-                .unwrap();
-
-        let expected = vec![
-            "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+",
-            "| table_catalog | table_schema | table_name | column_name | ordinal_position | column_default | is_nullable | data_type | character_maximum_length | character_octet_length | numeric_precision | numeric_precision_radix | numeric_scale | datetime_precision | interval_type |",
-            "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+",
-            "| my_catalog | my_schema | t1 | i | 0 | | YES | Int32 | | | 32 | 2 | | | |",
-            "| my_catalog | my_schema | t2 | binary_col | 4 | | NO | Binary | | 2147483647 | | | | | |",
-            "| my_catalog | my_schema | t2 | float64_col | 1 | | YES | Float64 | | | 24 | 2 | | | |",
-            "| my_catalog | my_schema | t2 | int32_col | 0 | | NO | Int32 | | | 32 | 2 | | | |",
-            "| my_catalog | my_schema | t2 | large_binary_col | 5 | | NO | LargeBinary | | 9223372036854775807 | | | | | |",
-            "| 
my_catalog | my_schema | t2 | large_utf8_col | 3 | | NO | LargeUtf8 | | 9223372036854775807 | | | | | |", - "| my_catalog | my_schema | t2 | timestamp_nanos | 6 | | NO | Timestamp(Nanosecond, None) | | | | | | | |", - "| my_catalog | my_schema | t2 | utf8_col | 2 | | YES | Utf8 | | 2147483647 | | | | | |", - "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", - ]; - assert_batches_sorted_eq!(expected, &result); - } - #[tokio::test] async fn disabled_default_catalog_and_schema() -> Result<()> { let mut ctx = ExecutionContext::with_config( diff --git a/datafusion/tests/sql/information_schema.rs b/datafusion/tests/sql/information_schema.rs new file mode 100644 index 000000000000..d93f0d7328d3 --- /dev/null +++ b/datafusion/tests/sql/information_schema.rs @@ -0,0 +1,502 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use async_trait::async_trait; +use datafusion::{ + catalog::{ + catalog::MemoryCatalogProvider, + schema::{MemorySchemaProvider, SchemaProvider}, + }, + datasource::{TableProvider, TableType}, + logical_plan::Expr, +}; + +use super::*; + +#[tokio::test] +async fn information_schema_tables_not_exist_by_default() { + let mut ctx = ExecutionContext::new(); + + let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Table or CTE with name 'information_schema.tables' not found" + ); +} + +#[tokio::test] +async fn information_schema_tables_no_tables() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------------+------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+------------+------------+", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | information_schema | tables | VIEW |", + "+---------------+--------------------+------------+------------+", + ]; + assert_batches_sorted_eq!(expected, &result); +} + +#[tokio::test] +async fn information_schema_tables_tables_default_catalog() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + // Now, register an empty table + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------------+------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+------------+------------+", + "| datafusion | information_schema | tables | VIEW |", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | public | t | BASE TABLE |", + "+---------------+--------------------+------------+------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + + // Newly added tables should appear + ctx.register_table("t2", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------------+------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+------------+------------+", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | information_schema | tables | VIEW |", + "| datafusion | public | t | BASE TABLE |", + "| datafusion | public | t2 | BASE TABLE |", + "+---------------+--------------------+------------+------------+", + ]; + assert_batches_sorted_eq!(expected, &result); +} + +#[tokio::test] +async fn information_schema_tables_tables_with_multiple_catalogs() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + let catalog = MemoryCatalogProvider::new(); + let schema = MemorySchemaProvider::new(); + schema + .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap()) + .unwrap(); + schema + .register_table("t2".to_owned(), table_with_sequence(1, 1).unwrap()) + .unwrap(); + 
catalog.register_schema("my_schema", Arc::new(schema));
+    ctx.register_catalog("my_catalog", Arc::new(catalog));
+
+    let catalog = MemoryCatalogProvider::new();
+    let schema = MemorySchemaProvider::new();
+    schema
+        .register_table("t3".to_owned(), table_with_sequence(1, 1).unwrap())
+        .unwrap();
+    catalog.register_schema("my_other_schema", Arc::new(schema));
+    ctx.register_catalog("my_other_catalog", Arc::new(catalog));
+
+    let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables")
+        .await
+        .unwrap();
+
+    let expected = vec![
+        "+------------------+--------------------+------------+------------+",
+        "| table_catalog | table_schema | table_name | table_type |",
+        "+------------------+--------------------+------------+------------+",
+        "| datafusion | information_schema | columns | VIEW |",
+        "| datafusion | information_schema | tables | VIEW |",
+        "| my_catalog | information_schema | columns | VIEW |",
+        "| my_catalog | information_schema | tables | VIEW |",
+        "| my_catalog | my_schema | t1 | BASE TABLE |",
+        "| my_catalog | my_schema | t2 | BASE TABLE |",
+        "| my_other_catalog | information_schema | columns | VIEW |",
+        "| my_other_catalog | information_schema | tables | VIEW |",
+        "| my_other_catalog | my_other_schema | t3 | BASE TABLE |",
+        "+------------------+--------------------+------------+------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &result);
+}
+
+#[tokio::test]
+async fn information_schema_tables_table_types() {
+    struct TestTable(TableType);
+
+    #[async_trait]
+    impl TableProvider for TestTable {
+        fn as_any(&self) -> &dyn std::any::Any {
+            self
+        }
+
+        fn table_type(&self) -> TableType {
+            self.0
+        }
+
+        fn schema(&self) -> SchemaRef {
+            unimplemented!()
+        }
+
+        async fn scan(
+            &self,
+            _: &Option<Vec<usize>>,
+            _: &[Expr],
+            _: Option<usize>,
+        ) -> Result<Arc<dyn ExecutionPlan>> {
+            unimplemented!()
+        }
+    }
+
+    let mut ctx = ExecutionContext::with_config(
+        ExecutionConfig::new().with_information_schema(true),
+    );
+
+    ctx.register_table("physical", Arc::new(TestTable(TableType::Base)))
+        .unwrap();
+    ctx.register_table("query", Arc::new(TestTable(TableType::View)))
+        .unwrap();
+    ctx.register_table("temp", Arc::new(TestTable(TableType::Temporary)))
+        .unwrap();
+
+    let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables")
+        .await
+        .unwrap();
+
+    let expected = vec![
+        "+---------------+--------------------+------------+-----------------+",
+        "| table_catalog | table_schema | table_name | table_type |",
+        "+---------------+--------------------+------------+-----------------+",
+        "| datafusion | information_schema | tables | VIEW |",
+        "| datafusion | information_schema | columns | VIEW |",
+        "| datafusion | public | physical | BASE TABLE |",
+        "| datafusion | public | query | VIEW |",
+        "| datafusion | public | temp | LOCAL TEMPORARY |",
+        "+---------------+--------------------+------------+-----------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &result);
+}
+
+#[tokio::test]
+async fn information_schema_show_tables_no_information_schema() {
+    let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());
+
+    ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        .unwrap();
+
+    // use show tables alias
+    let err = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap_err();
+
+    assert_eq!(err.to_string(), "Error during planning: SHOW TABLES is not supported unless information_schema is enabled");
+}
+
+#[tokio::test]
+async fn information_schema_show_tables() {
+    let mut ctx = ExecutionContext::with_config(
+        
ExecutionConfig::new().with_information_schema(true), + ); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + // use show tables alias + let result = plan_and_collect(&mut ctx, "SHOW TABLES").await.unwrap(); + + let expected = vec![ + "+---------------+--------------------+------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+------------+------------+", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | information_schema | tables | VIEW |", + "| datafusion | public | t | BASE TABLE |", + "+---------------+--------------------+------------+------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let result = plan_and_collect(&mut ctx, "SHOW tables").await.unwrap(); + + assert_batches_sorted_eq!(expected, &result); +} + +#[tokio::test] +async fn information_schema_show_columns_no_information_schema() { + let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") + .await + .unwrap_err(); + + assert_eq!(err.to_string(), "Error during planning: SHOW COLUMNS is not supported unless information_schema is enabled"); +} + +#[tokio::test] +async fn information_schema_show_columns_like_where() { + let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let expected = + "Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported"; + + let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t LIKE 'f'") + .await + .unwrap_err(); + assert_eq!(err.to_string(), expected); + + let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t WHERE column_name = 'bar'") + .await + .unwrap_err(); + assert_eq!(err.to_string(), expected); +} + +#[tokio::test] +async fn information_schema_show_columns() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------+------------+-------------+-----------+-------------+", + "| table_catalog | table_schema | table_name | column_name | data_type | is_nullable |", + "+---------------+--------------+------------+-------------+-----------+-------------+", + "| datafusion | public | t | i | Int32 | YES |", + "+---------------+--------------+------------+-------------+-----------+-------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let result = plan_and_collect(&mut ctx, "SHOW columns from t") + .await + .unwrap(); + assert_batches_sorted_eq!(expected, &result); + + // This isn't ideal but it is consistent behavior for `SELECT * from T` + let err = plan_and_collect(&mut ctx, "SHOW columns from T") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Unknown relation for SHOW COLUMNS: T" + ); +} + +// test errors with WHERE and LIKE +#[tokio::test] +async fn information_schema_show_columns_full_extended() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "SHOW FULL COLUMNS FROM 
t") + .await + .unwrap(); + let expected = vec![ + "+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", + "| table_catalog | table_schema | table_name | column_name | ordinal_position | column_default | is_nullable | data_type | character_maximum_length | character_octet_length | numeric_precision | numeric_precision_radix | numeric_scale | datetime_precision | interval_type |", + "+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", + "| datafusion | public | t | i | 0 | | YES | Int32 | | | 32 | 2 | | | |", + "+---------------+--------------+------------+-------------+------------------+----------------+-------------+-----------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let result = plan_and_collect(&mut ctx, "SHOW EXTENDED COLUMNS FROM t") + .await + .unwrap(); + assert_batches_sorted_eq!(expected, &result); +} + +#[tokio::test] +async fn information_schema_show_table_table_names() { + let mut ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM public.t") + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------+------------+-------------+-----------+-------------+", + "| table_catalog | table_schema | table_name | column_name | data_type | is_nullable |", + "+---------------+--------------+------------+-------------+-----------+-------------+", + "| datafusion | public | t | i | Int32 | YES |", + "+---------------+--------------+------------+-------------+-----------+-------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let result = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t") + .await + .unwrap(); + assert_batches_sorted_eq!(expected, &result); + + let err = plan_and_collect(&mut ctx, "SHOW columns from t2") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Unknown relation for SHOW COLUMNS: t2" + ); + + let err = plan_and_collect(&mut ctx, "SHOW columns from datafusion.public.t2") + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Unknown relation for SHOW COLUMNS: datafusion.public.t2" + ); +} + +#[tokio::test] +async fn show_unsupported() { + let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + + let err = plan_and_collect(&mut ctx, "SHOW SOMETHING_UNKNOWN") + .await + .unwrap_err(); + + assert_eq!(err.to_string(), "This feature is not implemented: SHOW SOMETHING_UNKNOWN not implemented. 
Supported syntax: SHOW <TABLES>");
+}
+
+#[tokio::test]
+async fn information_schema_columns_not_exist_by_default() {
+    let mut ctx = ExecutionContext::new();
+
+    let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.columns")
+        .await
+        .unwrap_err();
+    assert_eq!(
+        err.to_string(),
+        "Error during planning: Table or CTE with name 'information_schema.columns' not found"
+    );
+}
+
+fn table_with_many_types() -> Arc<dyn TableProvider> {
+    let schema = Schema::new(vec![
+        Field::new("int32_col", DataType::Int32, false),
+        Field::new("float64_col", DataType::Float64, true),
+        Field::new("utf8_col", DataType::Utf8, true),
+        Field::new("large_utf8_col", DataType::LargeUtf8, false),
+        Field::new("binary_col", DataType::Binary, false),
+        Field::new("large_binary_col", DataType::LargeBinary, false),
+        Field::new(
+            "timestamp_nanos",
+            DataType::Timestamp(TimeUnit::Nanosecond, None),
+            false,
+        ),
+    ]);
+
+    let batch = RecordBatch::try_new(
+        Arc::new(schema.clone()),
+        vec![
+            Arc::new(Int32Array::from_slice(&[1])),
+            Arc::new(Float64Array::from_slice(&[1.0])),
+            Arc::new(StringArray::from(vec![Some("foo")])),
+            Arc::new(LargeStringArray::from(vec![Some("bar")])),
+            Arc::new(BinaryArray::from_slice(&[b"foo" as &[u8]])),
+            Arc::new(LargeBinaryArray::from_slice(&[b"foo" as &[u8]])),
+            Arc::new(TimestampNanosecondArray::from_opt_vec(
+                vec![Some(123)],
+                None,
+            )),
+        ],
+    )
+    .unwrap();
+    let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap();
+    Arc::new(provider)
+}
+
+#[tokio::test]
+async fn information_schema_columns() {
+    let mut ctx = ExecutionContext::with_config(
+        ExecutionConfig::new().with_information_schema(true),
+    );
+    let catalog = MemoryCatalogProvider::new();
+    let schema = MemorySchemaProvider::new();
+
+    schema
+        .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap())
+        .unwrap();
+
+    schema
+        .register_table("t2".to_owned(), table_with_many_types())
+        .unwrap();
+    catalog.register_schema("my_schema", Arc::new(schema));
+    ctx.register_catalog("my_catalog", Arc::new(catalog));
+
+    let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.columns")
+        .await
+        .unwrap();
+
+    let expected = vec![
+        "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+",
+        "| table_catalog | table_schema | table_name | column_name | ordinal_position | column_default | is_nullable | data_type | character_maximum_length | character_octet_length | numeric_precision | numeric_precision_radix | numeric_scale | datetime_precision | interval_type |",
+        "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+",
+        "| my_catalog | my_schema | t1 | i | 0 | | YES | Int32 | | | 32 | 2 | | | |",
+        "| my_catalog | my_schema | t2 | binary_col | 4 | | NO | Binary | | 2147483647 | | | | | |",
+        "| my_catalog | my_schema | t2 | float64_col | 1 | | YES | Float64 | | | 24 | 2 | | | |",
+        "| my_catalog | my_schema | t2 | int32_col | 0 | | NO | Int32 | | | 32 | 2 | | | |",
+        "| my_catalog | my_schema | t2 | large_binary_col | 5 | | NO | LargeBinary | | 9223372036854775807 | | | | | |",
+        "| my_catalog | 
my_schema | t2 | large_utf8_col | 3 | | NO | LargeUtf8 | | 9223372036854775807 | | | | | |",
+        "| my_catalog | my_schema | t2 | timestamp_nanos | 6 | | NO | Timestamp(Nanosecond, None) | | | | | | | |",
+        "| my_catalog | my_schema | t2 | utf8_col | 2 | | YES | Utf8 | | 2147483647 | | | | | |",
+        "+---------------+--------------+------------+------------------+------------------+----------------+-------------+-----------------------------+--------------------------+------------------------+-------------------+-------------------------+---------------+--------------------+---------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &result);
+}
+
+/// Execute SQL and return results
+async fn plan_and_collect(
+    ctx: &mut ExecutionContext,
+    sql: &str,
+) -> Result<Vec<RecordBatch>> {
+    ctx.sql(sql).await?.collect().await
+}
diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs
index 55715af4f164..f2496c36814b 100644
--- a/datafusion/tests/sql/mod.rs
+++ b/datafusion/tests/sql/mod.rs
@@ -29,6 +29,7 @@ use datafusion::assert_batches_eq;
 use datafusion::assert_batches_sorted_eq;
 use datafusion::assert_contains;
 use datafusion::assert_not_contains;
+use datafusion::datasource::TableProvider;
 use datafusion::from_slice::FromSlice;
 use datafusion::logical_plan::plan::{Aggregate, Projection};
 use datafusion::logical_plan::LogicalPlan;
@@ -95,6 +96,7 @@ pub mod udf;
 pub mod union;
 pub mod window;
 
+pub mod information_schema;
 #[cfg_attr(not(feature = "unicode_expressions"), ignore)]
 pub mod unicode;
 
@@ -693,6 +695,21 @@ fn make_timestamp_nano_table() -> Result<Arc<dyn TableProvider>> {
     make_timestamp_table::<TimestampNanosecondType>()
 }
 
+/// Return a new table provider that has a single Int32 column with
+/// values between `seq_start` and `seq_end`
+pub fn table_with_sequence(
+    seq_start: i32,
+    seq_end: i32,
+) -> Result<Arc<dyn TableProvider>> {
+    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
+    let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::<Vec<_>>()));
+    let partitions = vec![vec![RecordBatch::try_new(
+        schema.clone(),
+        vec![arr as ArrayRef],
+    )?]];
+    Ok(Arc::new(MemTable::try_new(schema, partitions)?))
+}
+
 // Normalizes parts of an explain plan that vary from run to run (such as path)
 fn normalize_for_explain(s: &str) -> String {
     // Convert things like /Users/alamb/Software/arrow/testing/data/csv/aggregate_test_100.csv

From ab145c801e17cd9d7f9be820f92cfa61ed086df1 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Fri, 28 Jan 2022 13:22:04 -0500
Subject: [PATCH 04/50] Move timestamp related tests out of context.rs and into sql integration test (#1696)

* Move some tests out of context.rs and into sql

* Move support test out of context.rs and into sql tests

* Fixup tests and make them compile
---
 datafusion/src/execution/context.rs | 165 ----------------------------
 datafusion/src/test/mod.rs | 92 +---------------
 datafusion/tests/sql/aggregates.rs | 102 +++++++++++++++++
 datafusion/tests/sql/joins.rs | 48 ++++++++
 datafusion/tests/sql/mod.rs | 98 ++++++++++++++++-
 5 files changed, 247 insertions(+), 258 deletions(-)

diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 9cc54dfe1f37..6ed8223f0c52 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -2265,121 +2265,6 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
-    async fn aggregate_timestamps_sum() -> Result<()> {
-        let tmp_dir = TempDir::new()?;
-        let mut ctx = create_ctx(&tmp_dir, 1).await?;
-        ctx.register_table("t", test::table_with_timestamps())
-            .unwrap();
-
-        let results = 
plan_and_collect( - &mut ctx, - "SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t", - ) - .await - .unwrap_err(); - - assert_eq!(results.to_string(), "Error during planning: The function Sum does not support inputs of type Timestamp(Nanosecond, None)."); - - Ok(()) - } - - #[tokio::test] - async fn aggregate_timestamps_count() -> Result<()> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; - ctx.register_table("t", test::table_with_timestamps()) - .unwrap(); - - let results = plan_and_collect( - &mut ctx, - "SELECT count(nanos), count(micros), count(millis), count(secs) FROM t", - ) - .await - .unwrap(); - - let expected = vec![ - "+----------------+-----------------+-----------------+---------------+", - "| COUNT(t.nanos) | COUNT(t.micros) | COUNT(t.millis) | COUNT(t.secs) |", - "+----------------+-----------------+-----------------+---------------+", - "| 3 | 3 | 3 | 3 |", - "+----------------+-----------------+-----------------+---------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - - #[tokio::test] - async fn aggregate_timestamps_min() -> Result<()> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; - ctx.register_table("t", test::table_with_timestamps()) - .unwrap(); - - let results = plan_and_collect( - &mut ctx, - "SELECT min(nanos), min(micros), min(millis), min(secs) FROM t", - ) - .await - .unwrap(); - - let expected = vec![ - "+----------------------------+----------------------------+-------------------------+---------------------+", - "| MIN(t.nanos) | MIN(t.micros) | MIN(t.millis) | MIN(t.secs) |", - "+----------------------------+----------------------------+-------------------------+---------------------+", - "| 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 |", - "+----------------------------+----------------------------+-------------------------+---------------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - - #[tokio::test] - async fn aggregate_timestamps_max() -> Result<()> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; - ctx.register_table("t", test::table_with_timestamps()) - .unwrap(); - - let results = plan_and_collect( - &mut ctx, - "SELECT max(nanos), max(micros), max(millis), max(secs) FROM t", - ) - .await - .unwrap(); - - let expected = vec![ - "+-------------------------+-------------------------+-------------------------+---------------------+", - "| MAX(t.nanos) | MAX(t.micros) | MAX(t.millis) | MAX(t.secs) |", - "+-------------------------+-------------------------+-------------------------+---------------------+", - "| 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 |", - "+-------------------------+-------------------------+-------------------------+---------------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - - #[tokio::test] - async fn aggregate_timestamps_avg() -> Result<()> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; - ctx.register_table("t", test::table_with_timestamps()) - .unwrap(); - - let results = plan_and_collect( - &mut ctx, - "SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t", - ) - .await - .unwrap_err(); - - assert_eq!(results.to_string(), "Error during planning: The function Avg does not support inputs of type Timestamp(Nanosecond, None)."); - Ok(()) - } - #[tokio::test] 
async fn aggregate_avg_add() -> Result<()> { let results = execute( @@ -2418,56 +2303,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn join_timestamp() -> Result<()> { - let tmp_dir = TempDir::new()?; - let mut ctx = create_ctx(&tmp_dir, 1).await?; - ctx.register_table("t", test::table_with_timestamps()) - .unwrap(); - - let expected = vec![ - "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", - "| nanos | micros | millis | secs | name | nanos | micros | millis | secs | name |", - "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", - "| 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 | Row 1 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 | Row 1 |", - "| 2018-11-13 17:11:10.011375885 | 2018-11-13 17:11:10.011375 | 2018-11-13 17:11:10.011 | 2018-11-13 17:11:10 | Row 0 | 2018-11-13 17:11:10.011375885 | 2018-11-13 17:11:10.011375 | 2018-11-13 17:11:10.011 | 2018-11-13 17:11:10 | Row 0 |", - "| 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 | Row 3 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 | Row 3 |", - "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", - ]; - - let results = plan_and_collect( - &mut ctx, - "SELECT * FROM t as t1 \ - JOIN (SELECT * FROM t) as t2 \ - ON t1.nanos = t2.nanos", - ) - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); - - let results = plan_and_collect( - &mut ctx, - "SELECT * FROM t as t1 \ - JOIN (SELECT * FROM t) as t2 \ - ON t1.micros = t2.micros", - ) - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); - - let results = plan_and_collect( - &mut ctx, - "SELECT * FROM t as t1 \ - JOIN (SELECT * FROM t) as t2 \ - ON t1.millis = t2.millis", - ) - .await - .unwrap(); - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - #[tokio::test] async fn count_basic() -> Result<()> { let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 1).await?; diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs index 844d03188eae..497bfe59e1a1 100644 --- a/datafusion/src/test/mod.rs +++ b/datafusion/src/test/mod.rs @@ -22,10 +22,7 @@ use crate::datasource::{MemTable, PartitionedFile, TableProvider}; use crate::error::Result; use crate::from_slice::FromSlice; use crate::logical_plan::{LogicalPlan, LogicalPlanBuilder}; -use array::{ - Array, ArrayRef, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, -}; +use array::{Array, ArrayRef}; use arrow::array::{self, DecimalBuilder, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -185,14 +182,6 @@ pub fn make_partition(sz: i32) -> RecordBatch { RecordBatch::try_new(schema, vec![arr]).unwrap() } -/// Return a new table provider containing all of the supported timestamp types -pub 
fn table_with_timestamps() -> Arc { - let batch = make_timestamps(); - let schema = batch.schema(); - let partitions = vec![vec![batch]]; - Arc::new(MemTable::try_new(schema, partitions).unwrap()) -} - /// Return a new table which provide this decimal column pub fn table_with_decimal() -> Arc { let batch_decimal = make_decimal(); @@ -214,85 +203,6 @@ fn make_decimal() -> RecordBatch { RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap() } -/// Return record batch with all of the supported timestamp types -/// values -/// -/// Columns are named: -/// "nanos" --> TimestampNanosecondArray -/// "micros" --> TimestampMicrosecondArray -/// "millis" --> TimestampMillisecondArray -/// "secs" --> TimestampSecondArray -/// "names" --> StringArray -pub fn make_timestamps() -> RecordBatch { - let ts_strings = vec![ - Some("2018-11-13T17:11:10.011375885995"), - Some("2011-12-13T11:13:10.12345"), - None, - Some("2021-1-1T05:11:10.432"), - ]; - - let ts_nanos = ts_strings - .into_iter() - .map(|t| { - t.map(|t| { - t.parse::() - .unwrap() - .timestamp_nanos() - }) - }) - .collect::>(); - - let ts_micros = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000)) - .collect::>(); - - let ts_millis = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000)) - .collect::>(); - - let ts_secs = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000000)) - .collect::>(); - - let names = ts_nanos - .iter() - .enumerate() - .map(|(i, _)| format!("Row {}", i)) - .collect::>(); - - let arr_nanos = TimestampNanosecondArray::from_opt_vec(ts_nanos, None); - let arr_micros = TimestampMicrosecondArray::from_opt_vec(ts_micros, None); - let arr_millis = TimestampMillisecondArray::from_opt_vec(ts_millis, None); - let arr_secs = TimestampSecondArray::from_opt_vec(ts_secs, None); - - let names = names.iter().map(|s| s.as_str()).collect::>(); - let arr_names = StringArray::from(names); - - let schema = Schema::new(vec![ - Field::new("nanos", arr_nanos.data_type().clone(), true), - Field::new("micros", arr_micros.data_type().clone(), true), - Field::new("millis", arr_millis.data_type().clone(), true), - Field::new("secs", arr_secs.data_type().clone(), true), - Field::new("name", arr_names.data_type().clone(), true), - ]); - let schema = Arc::new(schema); - - RecordBatch::try_new( - schema, - vec![ - Arc::new(arr_nanos), - Arc::new(arr_micros), - Arc::new(arr_millis), - Arc::new(arr_secs), - Arc::new(arr_names), - ], - ) - .unwrap() -} - /// Asserts that given future is pending. 
pub fn assert_is_pending<'a, T>(fut: &mut Pin + Send + 'a>>) { let waker = futures::task::noop_waker(); diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 9d72752b091d..2d4287054388 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -473,3 +473,105 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn aggregate_timestamps_sum() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_timestamps()).unwrap(); + + let results = plan_and_collect( + &mut ctx, + "SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t", + ) + .await + .unwrap_err(); + + assert_eq!(results.to_string(), "Error during planning: The function Sum does not support inputs of type Timestamp(Nanosecond, None)."); + + Ok(()) +} + +#[tokio::test] +async fn aggregate_timestamps_count() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_timestamps()).unwrap(); + + let results = execute_to_batches( + &mut ctx, + "SELECT count(nanos), count(micros), count(millis), count(secs) FROM t", + ) + .await; + + let expected = vec![ + "+----------------+-----------------+-----------------+---------------+", + "| COUNT(t.nanos) | COUNT(t.micros) | COUNT(t.millis) | COUNT(t.secs) |", + "+----------------+-----------------+-----------------+---------------+", + "| 3 | 3 | 3 | 3 |", + "+----------------+-----------------+-----------------+---------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn aggregate_timestamps_min() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_timestamps()).unwrap(); + + let results = execute_to_batches( + &mut ctx, + "SELECT min(nanos), min(micros), min(millis), min(secs) FROM t", + ) + .await; + + let expected = vec![ + "+----------------------------+----------------------------+-------------------------+---------------------+", + "| MIN(t.nanos) | MIN(t.micros) | MIN(t.millis) | MIN(t.secs) |", + "+----------------------------+----------------------------+-------------------------+---------------------+", + "| 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 |", + "+----------------------------+----------------------------+-------------------------+---------------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn aggregate_timestamps_max() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_timestamps()).unwrap(); + + let results = execute_to_batches( + &mut ctx, + "SELECT max(nanos), max(micros), max(millis), max(secs) FROM t", + ) + .await; + + let expected = vec![ + "+-------------------------+-------------------------+-------------------------+---------------------+", + "| MAX(t.nanos) | MAX(t.micros) | MAX(t.millis) | MAX(t.secs) |", + "+-------------------------+-------------------------+-------------------------+---------------------+", + "| 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 |", + "+-------------------------+-------------------------+-------------------------+---------------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn aggregate_timestamps_avg() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", 
table_with_timestamps()).unwrap(); + + let results = plan_and_collect( + &mut ctx, + "SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t", + ) + .await + .unwrap_err(); + + assert_eq!(results.to_string(), "Error during planning: The function Avg does not support inputs of type Timestamp(Nanosecond, None)."); + Ok(()) +} diff --git a/datafusion/tests/sql/joins.rs b/datafusion/tests/sql/joins.rs index 70d824b12e1a..04436ed460b1 100644 --- a/datafusion/tests/sql/joins.rs +++ b/datafusion/tests/sql/joins.rs @@ -882,3 +882,51 @@ async fn join_tables_with_duplicated_column_name_not_in_on_constraint() -> Resul assert_batches_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn join_timestamp() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("t", table_with_timestamps()).unwrap(); + + let expected = vec![ + "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", + "| nanos | micros | millis | secs | name | nanos | micros | millis | secs | name |", + "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", + "| 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 | Row 1 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123450 | 2011-12-13 11:13:10.123 | 2011-12-13 11:13:10 | Row 1 |", + "| 2018-11-13 17:11:10.011375885 | 2018-11-13 17:11:10.011375 | 2018-11-13 17:11:10.011 | 2018-11-13 17:11:10 | Row 0 | 2018-11-13 17:11:10.011375885 | 2018-11-13 17:11:10.011375 | 2018-11-13 17:11:10.011 | 2018-11-13 17:11:10 | Row 0 |", + "| 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 | Row 3 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10.432 | 2021-01-01 05:11:10 | Row 3 |", + "+-------------------------------+----------------------------+-------------------------+---------------------+-------+-------------------------------+----------------------------+-------------------------+---------------------+-------+", + ]; + + let results = execute_to_batches( + &mut ctx, + "SELECT * FROM t as t1 \ + JOIN (SELECT * FROM t) as t2 \ + ON t1.nanos = t2.nanos", + ) + .await; + + assert_batches_sorted_eq!(expected, &results); + + let results = execute_to_batches( + &mut ctx, + "SELECT * FROM t as t1 \ + JOIN (SELECT * FROM t) as t2 \ + ON t1.micros = t2.micros", + ) + .await; + + assert_batches_sorted_eq!(expected, &results); + + let results = execute_to_batches( + &mut ctx, + "SELECT * FROM t as t1 \ + JOIN (SELECT * FROM t) as t2 \ + ON t1.millis = t2.millis", + ) + .await; + + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index f2496c36814b..90fe5138ac44 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -521,8 +521,15 @@ async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { Ok(()) } -/// Execute query and return result set as 2-d table of Vecs -/// `result[row][column]` +/// Execute SQL and return results as a RecordBatch +async fn plan_and_collect( + ctx: &mut ExecutionContext, + sql: &str, +) -> Result> { + ctx.sql(sql).await?.collect().await +} + 
+/// Execute query and return results as a Vec of RecordBatches async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec { let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx.create_logical_plan(sql).expect(&msg); @@ -734,6 +741,93 @@ fn normalize_vec_for_explain(v: Vec>) -> Vec> { .collect::>() } +/// Return a new table provider containing all of the supported timestamp types +pub fn table_with_timestamps() -> Arc { + let batch = make_timestamps(); + let schema = batch.schema(); + let partitions = vec![vec![batch]]; + Arc::new(MemTable::try_new(schema, partitions).unwrap()) +} + +/// Return record batch with all of the supported timestamp types +/// values +/// +/// Columns are named: +/// "nanos" --> TimestampNanosecondArray +/// "micros" --> TimestampMicrosecondArray +/// "millis" --> TimestampMillisecondArray +/// "secs" --> TimestampSecondArray +/// "names" --> StringArray +pub fn make_timestamps() -> RecordBatch { + let ts_strings = vec![ + Some("2018-11-13T17:11:10.011375885995"), + Some("2011-12-13T11:13:10.12345"), + None, + Some("2021-1-1T05:11:10.432"), + ]; + + let ts_nanos = ts_strings + .into_iter() + .map(|t| { + t.map(|t| { + t.parse::() + .unwrap() + .timestamp_nanos() + }) + }) + .collect::>(); + + let ts_micros = ts_nanos + .iter() + .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000)) + .collect::>(); + + let ts_millis = ts_nanos + .iter() + .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000)) + .collect::>(); + + let ts_secs = ts_nanos + .iter() + .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000000)) + .collect::>(); + + let names = ts_nanos + .iter() + .enumerate() + .map(|(i, _)| format!("Row {}", i)) + .collect::>(); + + let arr_nanos = TimestampNanosecondArray::from_opt_vec(ts_nanos, None); + let arr_micros = TimestampMicrosecondArray::from_opt_vec(ts_micros, None); + let arr_millis = TimestampMillisecondArray::from_opt_vec(ts_millis, None); + let arr_secs = TimestampSecondArray::from_opt_vec(ts_secs, None); + + let names = names.iter().map(|s| s.as_str()).collect::>(); + let arr_names = StringArray::from(names); + + let schema = Schema::new(vec![ + Field::new("nanos", arr_nanos.data_type().clone(), true), + Field::new("micros", arr_micros.data_type().clone(), true), + Field::new("millis", arr_millis.data_type().clone(), true), + Field::new("secs", arr_secs.data_type().clone(), true), + Field::new("name", arr_names.data_type().clone(), true), + ]); + let schema = Arc::new(schema); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(arr_nanos), + Arc::new(arr_micros), + Arc::new(arr_millis), + Arc::new(arr_secs), + Arc::new(arr_names), + ], + ) + .unwrap() +} + #[tokio::test] async fn nyc() -> Result<()> { // schema for nyxtaxi csv files From 641338f726549c10c5bafee34537dc1e56cdec04 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 30 Jan 2022 00:02:49 +0800 Subject: [PATCH 05/50] Add `MemTrackingMetrics` to ease memory tracking for non-limited memory consumers (#1691) * Memory manager no longer track consumers, update aggregatedMetricsSet * Easy memory tracking with metrics * use tracking metrics in SPMS * tests * fix * doc * Update datafusion/src/physical_plan/sorts/sort.rs Co-authored-by: Andrew Lamb * make tracker AtomicUsize Co-authored-by: Andrew Lamb --- datafusion/src/execution/memory_manager.rs | 134 ++++++------ datafusion/src/execution/runtime_env.rs | 24 +- datafusion/src/physical_plan/common.rs | 12 +- datafusion/src/physical_plan/explain.rs | 6 +- .../src/physical_plan/metrics/aggregated.rs | 155 
-------------
 .../src/physical_plan/metrics/baseline.rs     |  14 +-
 .../src/physical_plan/metrics/composite.rs    | 205 ++++++++++++++++++
 datafusion/src/physical_plan/metrics/mod.rs   |   6 +-
 .../src/physical_plan/metrics/tracker.rs      | 131 +++++++++++
 datafusion/src/physical_plan/sorts/mod.rs     |   9 -
 datafusion/src/physical_plan/sorts/sort.rs    |  95 ++++----
 .../sorts/sort_preserving_merge.rs            | 114 ++--------
 datafusion/tests/provider_filter_pushdown.rs  |   6 +-
 13 files changed, 525 insertions(+), 386 deletions(-)
 delete mode 100644 datafusion/src/physical_plan/metrics/aggregated.rs
 create mode 100644 datafusion/src/physical_plan/metrics/composite.rs
 create mode 100644 datafusion/src/physical_plan/metrics/tracker.rs

diff --git a/datafusion/src/execution/memory_manager.rs b/datafusion/src/execution/memory_manager.rs
index 53eb720c4729..0fb3cfbb4ecf 100644
--- a/datafusion/src/execution/memory_manager.rs
+++ b/datafusion/src/execution/memory_manager.rs
@@ -19,12 +19,12 @@
 use crate::error::{DataFusionError, Result};
 use async_trait::async_trait;
-use hashbrown::HashMap;
+use hashbrown::HashSet;
 use log::debug;
 use std::fmt;
 use std::fmt::{Debug, Display, Formatter};
 use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::{Arc, Condvar, Mutex, Weak};
+use std::sync::{Arc, Condvar, Mutex};
 
 static CONSUMER_ID: AtomicUsize = AtomicUsize::new(0);
 
@@ -245,10 +245,10 @@ The memory management architecture is the following:
 /// Manage memory usage during physical plan execution
 #[derive(Debug)]
 pub struct MemoryManager {
-    requesters: Arc<Mutex<HashMap<MemoryConsumerId, Weak<dyn MemoryConsumer>>>>,
-    trackers: Arc<Mutex<HashMap<MemoryConsumerId, Weak<dyn MemoryConsumer>>>>,
+    requesters: Arc<Mutex<HashSet<MemoryConsumerId>>>,
     pool_size: usize,
     requesters_total: Arc<Mutex<usize>>,
+    trackers_total: AtomicUsize,
     cv: Condvar,
 }
 
@@ -267,10 +267,10 @@ impl MemoryManager {
         );
 
         Arc::new(Self {
-            requesters: Arc::new(Mutex::new(HashMap::new())),
-            trackers: Arc::new(Mutex::new(HashMap::new())),
+            requesters: Arc::new(Mutex::new(HashSet::new())),
            pool_size,
            requesters_total: Arc::new(Mutex::new(0)),
+            trackers_total: AtomicUsize::new(0),
            cv: Condvar::new(),
        })
    }
@@ -278,30 +278,36 @@ impl MemoryManager {
     }
 
     fn get_tracker_total(&self) -> usize {
-        let trackers = self.trackers.lock().unwrap();
-        if trackers.len() > 0 {
-            trackers.values().fold(0usize, |acc, y| match y.upgrade() {
-                None => acc,
-                Some(t) => acc + t.mem_used(),
-            })
-        } else {
-            0
-        }
+        self.trackers_total.load(Ordering::SeqCst)
     }
 
-    /// Register a new memory consumer for memory usage tracking
-    pub(crate) fn register_consumer(&self, consumer: &Arc<dyn MemoryConsumer>) {
-        let id = consumer.id().clone();
-        match consumer.type_() {
-            ConsumerType::Requesting => {
-                let mut requesters = self.requesters.lock().unwrap();
-                requesters.insert(id, Arc::downgrade(consumer));
-            }
-            ConsumerType::Tracking => {
-                let mut trackers = self.trackers.lock().unwrap();
-                trackers.insert(id, Arc::downgrade(consumer));
-            }
-        }
+    pub(crate) fn grow_tracker_usage(&self, delta: usize) {
+        self.trackers_total.fetch_add(delta, Ordering::SeqCst);
+    }
+
+    pub(crate) fn shrink_tracker_usage(&self, delta: usize) {
+        let update =
+            self.trackers_total
+                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| {
+                    if x >= delta {
+                        Some(x - delta)
+                    } else {
+                        None
+                    }
+                });
+        update.expect(&*format!(
+            "Tracker total memory shrink by {} underflow, current value is ",
+            delta
+        ));
+    }
+
+    fn get_requester_total(&self) -> usize {
+        *self.requesters_total.lock().unwrap()
+    }
+
+    /// Register a new memory requester
+    pub(crate) fn register_requester(&self, requester_id: &MemoryConsumerId) {
+        self.requesters.lock().unwrap().insert(requester_id.clone());
     }
 
     fn
max_mem_for_requesters(&self) -> usize { @@ -317,7 +323,6 @@ impl MemoryManager { let granted; loop { - let remaining = rqt_max - *rqt_current_used; let max_per_rqt = rqt_max / num_rqt; let min_per_rqt = max_per_rqt / 2; @@ -326,6 +331,7 @@ impl MemoryManager { break; } + let remaining = rqt_max.checked_sub(*rqt_current_used).unwrap_or_default(); if remaining >= required { granted = true; *rqt_current_used += required; @@ -347,46 +353,37 @@ impl MemoryManager { fn record_free_then_acquire(&self, freed: usize, acquired: usize) { let mut requesters_total = self.requesters_total.lock().unwrap(); + assert!(*requesters_total >= freed); *requesters_total -= freed; *requesters_total += acquired; self.cv.notify_all() } - /// Drop a memory consumer from memory usage tracking - pub(crate) fn drop_consumer(&self, id: &MemoryConsumerId) { + /// Drop a memory consumer and reclaim the memory + pub(crate) fn drop_consumer(&self, id: &MemoryConsumerId, mem_used: usize) { // find in requesters first { let mut requesters = self.requesters.lock().unwrap(); - if requesters.remove(id).is_some() { - return; + if requesters.remove(id) { + let mut total = self.requesters_total.lock().unwrap(); + assert!(*total >= mem_used); + *total -= mem_used; } } - let mut trackers = self.trackers.lock().unwrap(); - trackers.remove(id); + self.shrink_tracker_usage(mem_used); + self.cv.notify_all(); } } impl Display for MemoryManager { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let requesters = - self.requesters - .lock() - .unwrap() - .values() - .fold(vec![], |mut acc, consumer| match consumer.upgrade() { - None => acc, - Some(c) => { - acc.push(format!("{}", c)); - acc - } - }); - let tracker_mem = self.get_tracker_total(); write!(f, - "MemoryManager usage statistics: total {}, tracker used {}, total {} requesters detail: \n {},", - human_readable_size(self.pool_size), - human_readable_size(tracker_mem), - &requesters.len(), - requesters.join("\n")) + "MemoryManager usage statistics: total {}, trackers used {}, total {} requesters used: {}", + human_readable_size(self.pool_size), + human_readable_size(self.get_tracker_total()), + self.requesters.lock().unwrap().len(), + human_readable_size(self.get_requester_total()), + ) } } @@ -418,6 +415,8 @@ mod tests { use super::*; use crate::error::Result; use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::execution::MemoryConsumer; + use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics}; use async_trait::async_trait; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -487,6 +486,7 @@ mod tests { impl DummyTracker { fn new(partition: usize, runtime: Arc, mem_used: usize) -> Self { + runtime.grow_tracker_usage(mem_used); Self { id: MemoryConsumerId::new(partition), runtime, @@ -528,23 +528,29 @@ mod tests { .with_memory_manager(MemoryManagerConfig::try_new_limit(100, 1.0).unwrap()); let runtime = Arc::new(RuntimeEnv::new(config).unwrap()); - let tracker1 = Arc::new(DummyTracker::new(0, runtime.clone(), 5)); - runtime.register_consumer(&(tracker1.clone() as Arc)); + DummyTracker::new(0, runtime.clone(), 5); assert_eq!(runtime.memory_manager.get_tracker_total(), 5); - let tracker2 = Arc::new(DummyTracker::new(0, runtime.clone(), 10)); - runtime.register_consumer(&(tracker2.clone() as Arc)); + let tracker1 = DummyTracker::new(0, runtime.clone(), 10); assert_eq!(runtime.memory_manager.get_tracker_total(), 15); - let tracker3 = Arc::new(DummyTracker::new(0, runtime.clone(), 15)); - 
runtime.register_consumer(&(tracker3.clone() as Arc)); + DummyTracker::new(0, runtime.clone(), 15); assert_eq!(runtime.memory_manager.get_tracker_total(), 30); - runtime.drop_consumer(tracker2.id()); + runtime.drop_consumer(tracker1.id(), tracker1.mem_used); + assert_eq!(runtime.memory_manager.get_tracker_total(), 20); + + // MemTrackingMetrics as an easy way to track memory + let ms = ExecutionPlanMetricsSet::new(); + let tracking_metric = MemTrackingMetrics::new_with_rt(&ms, 0, runtime.clone()); + tracking_metric.init_mem_used(15); + assert_eq!(runtime.memory_manager.get_tracker_total(), 35); + + drop(tracking_metric); assert_eq!(runtime.memory_manager.get_tracker_total(), 20); - let requester1 = Arc::new(DummyRequester::new(0, runtime.clone())); - runtime.register_consumer(&(requester1.clone() as Arc)); + let requester1 = DummyRequester::new(0, runtime.clone()); + runtime.register_requester(requester1.id()); // first requester entered, should be able to use any of the remaining 80 requester1.do_with_mem(40).await.unwrap(); @@ -553,8 +559,8 @@ mod tests { assert_eq!(requester1.mem_used(), 50); assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 50); - let requester2 = Arc::new(DummyRequester::new(0, runtime.clone())); - runtime.register_consumer(&(requester2.clone() as Arc)); + let requester2 = DummyRequester::new(0, runtime.clone()); + runtime.register_requester(requester2.id()); requester2.do_with_mem(20).await.unwrap(); requester2.do_with_mem(30).await.unwrap(); diff --git a/datafusion/src/execution/runtime_env.rs b/datafusion/src/execution/runtime_env.rs index cdcd1f71b4f5..e993b385ecd4 100644 --- a/datafusion/src/execution/runtime_env.rs +++ b/datafusion/src/execution/runtime_env.rs @@ -22,9 +22,7 @@ use crate::{ error::Result, execution::{ disk_manager::{DiskManager, DiskManagerConfig}, - memory_manager::{ - MemoryConsumer, MemoryConsumerId, MemoryManager, MemoryManagerConfig, - }, + memory_manager::{MemoryConsumerId, MemoryManager, MemoryManagerConfig}, }, }; @@ -71,13 +69,23 @@ impl RuntimeEnv { } /// Register the consumer to get it tracked - pub fn register_consumer(&self, memory_consumer: &Arc) { - self.memory_manager.register_consumer(memory_consumer); + pub fn register_requester(&self, id: &MemoryConsumerId) { + self.memory_manager.register_requester(id); } - /// Drop the consumer from get tracked - pub fn drop_consumer(&self, id: &MemoryConsumerId) { - self.memory_manager.drop_consumer(id) + /// Drop the consumer from get tracked, reclaim memory + pub fn drop_consumer(&self, id: &MemoryConsumerId, mem_used: usize) { + self.memory_manager.drop_consumer(id, mem_used) + } + + /// Grow tracker memory of `delta` + pub fn grow_tracker_usage(&self, delta: usize) { + self.memory_manager.grow_tracker_usage(delta) + } + + /// Shrink tracker memory of `delta` + pub fn shrink_tracker_usage(&self, delta: usize) { + self.memory_manager.shrink_tracker_usage(delta) } } diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs index 390f004fb469..bc4400d98186 100644 --- a/datafusion/src/physical_plan/common.rs +++ b/datafusion/src/physical_plan/common.rs @@ -20,7 +20,7 @@ use super::{RecordBatchStream, SendableRecordBatchStream}; use crate::error::{DataFusionError, Result}; use crate::execution::runtime_env::RuntimeEnv; -use crate::physical_plan::metrics::BaselineMetrics; +use crate::physical_plan::metrics::MemTrackingMetrics; use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::compute::concat; use 
arrow::datatypes::{Schema, SchemaRef}; @@ -43,7 +43,7 @@ pub struct SizedRecordBatchStream { schema: SchemaRef, batches: Vec>, index: usize, - baseline_metrics: BaselineMetrics, + metrics: MemTrackingMetrics, } impl SizedRecordBatchStream { @@ -51,13 +51,15 @@ impl SizedRecordBatchStream { pub fn new( schema: SchemaRef, batches: Vec>, - baseline_metrics: BaselineMetrics, + metrics: MemTrackingMetrics, ) -> Self { + let size = batches.iter().map(|b| batch_byte_size(b)).sum::(); + metrics.init_mem_used(size); SizedRecordBatchStream { schema, index: 0, batches, - baseline_metrics, + metrics, } } } @@ -75,7 +77,7 @@ impl Stream for SizedRecordBatchStream { } else { None }); - self.baseline_metrics.record_poll(poll) + self.metrics.record_poll(poll) } } diff --git a/datafusion/src/physical_plan/explain.rs b/datafusion/src/physical_plan/explain.rs index f827dc32eca4..eb18926f9466 100644 --- a/datafusion/src/physical_plan/explain.rs +++ b/datafusion/src/physical_plan/explain.rs @@ -32,7 +32,7 @@ use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatc use super::SendableRecordBatchStream; use crate::execution::runtime_env::RuntimeEnv; -use crate::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; +use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics}; use async_trait::async_trait; /// Explain execution plan operator. This operator contains the string @@ -148,12 +148,12 @@ impl ExecutionPlan for ExplainExec { )?; let metrics = ExecutionPlanMetricsSet::new(); - let baseline_metrics = BaselineMetrics::new(&metrics, partition); + let tracking_metrics = MemTrackingMetrics::new(&metrics, partition); Ok(Box::pin(SizedRecordBatchStream::new( self.schema.clone(), vec![Arc::new(record_batch)], - baseline_metrics, + tracking_metrics, ))) } diff --git a/datafusion/src/physical_plan/metrics/aggregated.rs b/datafusion/src/physical_plan/metrics/aggregated.rs deleted file mode 100644 index c55cc1601768..000000000000 --- a/datafusion/src/physical_plan/metrics/aggregated.rs +++ /dev/null @@ -1,155 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Metrics common for complex operators with multiple steps. - -use crate::physical_plan::metrics::{ - BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricsSet, Time, -}; -use std::sync::Arc; -use std::time::Duration; - -#[derive(Debug, Clone)] -/// Aggregates all metrics during a complex operation, which is composed of multiple steps and -/// each stage reports its statistics separately. -/// Give sort as an example, when the dataset is more significant than available memory, it will report -/// multiple in-mem sort metrics and final merge-sort metrics from `SortPreservingMergeStream`. 
-/// Therefore, We need a separation of metrics for which are final metrics (for output_rows accumulation), -/// and which are intermediate metrics that we only account for elapsed_compute time. -pub struct AggregatedMetricsSet { - intermediate: Arc>>, - final_: Arc>>, -} - -impl Default for AggregatedMetricsSet { - fn default() -> Self { - Self::new() - } -} - -impl AggregatedMetricsSet { - /// Create a new aggregated set - pub fn new() -> Self { - Self { - intermediate: Arc::new(std::sync::Mutex::new(vec![])), - final_: Arc::new(std::sync::Mutex::new(vec![])), - } - } - - /// create a new intermediate baseline - pub fn new_intermediate_baseline(&self, partition: usize) -> BaselineMetrics { - let ms = ExecutionPlanMetricsSet::new(); - let result = BaselineMetrics::new(&ms, partition); - self.intermediate.lock().unwrap().push(ms); - result - } - - /// create a new final baseline - pub fn new_final_baseline(&self, partition: usize) -> BaselineMetrics { - let ms = ExecutionPlanMetricsSet::new(); - let result = BaselineMetrics::new(&ms, partition); - self.final_.lock().unwrap().push(ms); - result - } - - fn merge_compute_time(&self, dest: &Time) { - let time1 = self - .intermediate - .lock() - .unwrap() - .iter() - .map(|es| { - es.clone_inner() - .elapsed_compute() - .map_or(0u64, |v| v as u64) - }) - .sum(); - let time2 = self - .final_ - .lock() - .unwrap() - .iter() - .map(|es| { - es.clone_inner() - .elapsed_compute() - .map_or(0u64, |v| v as u64) - }) - .sum(); - dest.add_duration(Duration::from_nanos(time1)); - dest.add_duration(Duration::from_nanos(time2)); - } - - fn merge_spill_count(&self, dest: &Count) { - let count1 = self - .intermediate - .lock() - .unwrap() - .iter() - .map(|es| es.clone_inner().spill_count().map_or(0, |v| v)) - .sum(); - let count2 = self - .final_ - .lock() - .unwrap() - .iter() - .map(|es| es.clone_inner().spill_count().map_or(0, |v| v)) - .sum(); - dest.add(count1); - dest.add(count2); - } - - fn merge_spilled_bytes(&self, dest: &Count) { - let count1 = self - .intermediate - .lock() - .unwrap() - .iter() - .map(|es| es.clone_inner().spilled_bytes().map_or(0, |v| v)) - .sum(); - let count2 = self - .final_ - .lock() - .unwrap() - .iter() - .map(|es| es.clone_inner().spilled_bytes().map_or(0, |v| v)) - .sum(); - dest.add(count1); - dest.add(count2); - } - - fn merge_output_count(&self, dest: &Count) { - let count = self - .final_ - .lock() - .unwrap() - .iter() - .map(|es| es.clone_inner().output_rows().map_or(0, |v| v)) - .sum(); - dest.add(count); - } - - /// Aggregate all metrics into a one - pub fn aggregate_all(&self) -> MetricsSet { - let metrics = ExecutionPlanMetricsSet::new(); - let baseline = BaselineMetrics::new(&metrics, 0); - self.merge_compute_time(baseline.elapsed_compute()); - self.merge_spill_count(baseline.spill_count()); - self.merge_spilled_bytes(baseline.spilled_bytes()); - self.merge_output_count(baseline.output_rows()); - metrics.clone_inner() - } -} diff --git a/datafusion/src/physical_plan/metrics/baseline.rs b/datafusion/src/physical_plan/metrics/baseline.rs index 50c49ece141b..8dff5ee3fd77 100644 --- a/datafusion/src/physical_plan/metrics/baseline.rs +++ b/datafusion/src/physical_plan/metrics/baseline.rs @@ -113,7 +113,7 @@ impl BaselineMetrics { /// Records the fact that this operator's execution is complete /// (recording the `end_time` metric). 
/// - /// Note care should be taken to call `done()` maually if + /// Note care should be taken to call `done()` manually if /// `BaselineMetrics` is not `drop`ped immediately upon operator /// completion, as async streams may not be dropped immediately /// depending on the consumer. @@ -129,6 +129,13 @@ impl BaselineMetrics { self.output_rows.add(num_rows); } + /// If not previously recorded `done()`, record + pub fn try_done(&self) { + if self.end_time.value().is_none() { + self.end_time.record() + } + } + /// Process a poll result of a stream producing output for an /// operator, recording the output rows and stream done time and /// returning the same poll result @@ -151,10 +158,7 @@ impl BaselineMetrics { impl Drop for BaselineMetrics { fn drop(&mut self) { - // if not previously recorded, record - if self.end_time.value().is_none() { - self.end_time.record() - } + self.try_done() } } diff --git a/datafusion/src/physical_plan/metrics/composite.rs b/datafusion/src/physical_plan/metrics/composite.rs new file mode 100644 index 000000000000..cd4d5c38a9ec --- /dev/null +++ b/datafusion/src/physical_plan/metrics/composite.rs @@ -0,0 +1,205 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Metrics common for complex operators with multiple steps. + +use crate::execution::runtime_env::RuntimeEnv; +use crate::physical_plan::metrics::tracker::MemTrackingMetrics; +use crate::physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricValue, MetricsSet, Time, + Timestamp, +}; +use crate::physical_plan::Metric; +use chrono::{TimeZone, Utc}; +use std::sync::Arc; +use std::time::Duration; + +#[derive(Debug, Clone)] +/// Collects all metrics during a complex operation, which is composed of multiple steps and +/// each stage reports its statistics separately. +/// Give sort as an example, when the dataset is more significant than available memory, it will report +/// multiple in-mem sort metrics and final merge-sort metrics from `SortPreservingMergeStream`. +/// Therefore, We need a separation of metrics for which are final metrics (for output_rows accumulation), +/// and which are intermediate metrics that we only account for elapsed_compute time. 
+pub struct CompositeMetricsSet { + mid: ExecutionPlanMetricsSet, + final_: ExecutionPlanMetricsSet, +} + +impl Default for CompositeMetricsSet { + fn default() -> Self { + Self::new() + } +} + +impl CompositeMetricsSet { + /// Create a new aggregated set + pub fn new() -> Self { + Self { + mid: ExecutionPlanMetricsSet::new(), + final_: ExecutionPlanMetricsSet::new(), + } + } + + /// create a new intermediate baseline + pub fn new_intermediate_baseline(&self, partition: usize) -> BaselineMetrics { + BaselineMetrics::new(&self.mid, partition) + } + + /// create a new final baseline + pub fn new_final_baseline(&self, partition: usize) -> BaselineMetrics { + BaselineMetrics::new(&self.final_, partition) + } + + /// create a new intermediate memory tracking metrics + pub fn new_intermediate_tracking( + &self, + partition: usize, + runtime: Arc, + ) -> MemTrackingMetrics { + MemTrackingMetrics::new_with_rt(&self.mid, partition, runtime) + } + + /// create a new final memory tracking metrics + pub fn new_final_tracking( + &self, + partition: usize, + runtime: Arc, + ) -> MemTrackingMetrics { + MemTrackingMetrics::new_with_rt(&self.final_, partition, runtime) + } + + fn merge_compute_time(&self, dest: &Time) { + let time1 = self + .mid + .clone_inner() + .elapsed_compute() + .map_or(0u64, |v| v as u64); + let time2 = self + .final_ + .clone_inner() + .elapsed_compute() + .map_or(0u64, |v| v as u64); + dest.add_duration(Duration::from_nanos(time1)); + dest.add_duration(Duration::from_nanos(time2)); + } + + fn merge_spill_count(&self, dest: &Count) { + let count1 = self.mid.clone_inner().spill_count().map_or(0, |v| v); + let count2 = self.final_.clone_inner().spill_count().map_or(0, |v| v); + dest.add(count1); + dest.add(count2); + } + + fn merge_spilled_bytes(&self, dest: &Count) { + let count1 = self.mid.clone_inner().spilled_bytes().map_or(0, |v| v); + let count2 = self.final_.clone_inner().spill_count().map_or(0, |v| v); + dest.add(count1); + dest.add(count2); + } + + fn merge_output_count(&self, dest: &Count) { + let count = self.final_.clone_inner().output_rows().map_or(0, |v| v); + dest.add(count); + } + + fn merge_start_time(&self, dest: &Timestamp) { + let start1 = self + .mid + .clone_inner() + .sum(|metric| matches!(metric.value(), MetricValue::StartTimestamp(_))) + .map(|v| v.as_usize()); + let start2 = self + .final_ + .clone_inner() + .sum(|metric| matches!(metric.value(), MetricValue::StartTimestamp(_))) + .map(|v| v.as_usize()); + match (start1, start2) { + (Some(start1), Some(start2)) => { + dest.set(Utc.timestamp_nanos(start1.min(start2) as i64)) + } + (Some(start1), None) => dest.set(Utc.timestamp_nanos(start1 as i64)), + (None, Some(start2)) => dest.set(Utc.timestamp_nanos(start2 as i64)), + (None, None) => {} + } + } + + fn merge_end_time(&self, dest: &Timestamp) { + let start1 = self + .mid + .clone_inner() + .sum(|metric| matches!(metric.value(), MetricValue::EndTimestamp(_))) + .map(|v| v.as_usize()); + let start2 = self + .final_ + .clone_inner() + .sum(|metric| matches!(metric.value(), MetricValue::EndTimestamp(_))) + .map(|v| v.as_usize()); + match (start1, start2) { + (Some(start1), Some(start2)) => { + dest.set(Utc.timestamp_nanos(start1.max(start2) as i64)) + } + (Some(start1), None) => dest.set(Utc.timestamp_nanos(start1 as i64)), + (None, Some(start2)) => dest.set(Utc.timestamp_nanos(start2 as i64)), + (None, None) => {} + } + } + + /// Aggregate all metrics into a one + pub fn aggregate_all(&self) -> MetricsSet { + let mut metrics = MetricsSet::new(); + let 
elapsed_time = Time::new(); + let spill_count = Count::new(); + let spilled_bytes = Count::new(); + let output_count = Count::new(); + let start_time = Timestamp::new(); + let end_time = Timestamp::new(); + + metrics.push(Arc::new(Metric::new( + MetricValue::ElapsedCompute(elapsed_time.clone()), + None, + ))); + metrics.push(Arc::new(Metric::new( + MetricValue::SpillCount(spill_count.clone()), + None, + ))); + metrics.push(Arc::new(Metric::new( + MetricValue::SpilledBytes(spilled_bytes.clone()), + None, + ))); + metrics.push(Arc::new(Metric::new( + MetricValue::OutputRows(output_count.clone()), + None, + ))); + metrics.push(Arc::new(Metric::new( + MetricValue::StartTimestamp(start_time.clone()), + None, + ))); + metrics.push(Arc::new(Metric::new( + MetricValue::EndTimestamp(end_time.clone()), + None, + ))); + + self.merge_compute_time(&elapsed_time); + self.merge_spill_count(&spill_count); + self.merge_spilled_bytes(&spilled_bytes); + self.merge_output_count(&output_count); + self.merge_start_time(&start_time); + self.merge_end_time(&end_time); + metrics + } +} diff --git a/datafusion/src/physical_plan/metrics/mod.rs b/datafusion/src/physical_plan/metrics/mod.rs index d48959974e8d..e609beb08c37 100644 --- a/datafusion/src/physical_plan/metrics/mod.rs +++ b/datafusion/src/physical_plan/metrics/mod.rs @@ -17,9 +17,10 @@ //! Metrics for recording information about execution -mod aggregated; mod baseline; mod builder; +mod composite; +mod tracker; mod value; use std::{ @@ -31,9 +32,10 @@ use std::{ use hashbrown::HashMap; // public exports -pub use aggregated::AggregatedMetricsSet; pub use baseline::{BaselineMetrics, RecordOutput}; pub use builder::MetricBuilder; +pub use composite::CompositeMetricsSet; +pub use tracker::MemTrackingMetrics; pub use value::{Count, Gauge, MetricValue, ScopedTimerGuard, Time, Timestamp}; /// Something that tracks a value of interest (metric) of a DataFusion diff --git a/datafusion/src/physical_plan/metrics/tracker.rs b/datafusion/src/physical_plan/metrics/tracker.rs new file mode 100644 index 000000000000..bdceadb8a190 --- /dev/null +++ b/datafusion/src/physical_plan/metrics/tracker.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Metrics with memory usage tracking capability
+
+use crate::execution::runtime_env::RuntimeEnv;
+use crate::execution::MemoryConsumerId;
+use crate::physical_plan::metrics::{
+    BaselineMetrics, Count, ExecutionPlanMetricsSet, Time,
+};
+use std::sync::Arc;
+use std::task::Poll;
+
+use arrow::{error::ArrowError, record_batch::RecordBatch};
+
+/// Simplified version of tracking memory consumer,
+/// see also: [`Tracking`](crate::execution::memory_manager::ConsumerType::Tracking)
+///
+/// You could use this to replace [BaselineMetrics], report the memory,
+/// and get the memory usage bookkeeping in the memory manager easily.
+#[derive(Debug)]
+pub struct MemTrackingMetrics {
+    id: MemoryConsumerId,
+    runtime: Option<Arc<RuntimeEnv>>,
+    metrics: BaselineMetrics,
+}
+
+/// Delegates most of the metrics functionalities to the inner BaselineMetrics,
+/// intercept memory metrics functionalities and do memory manager bookkeeping.
+impl MemTrackingMetrics {
+    /// Create metrics similar to [BaselineMetrics]
+    pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
+        let id = MemoryConsumerId::new(partition);
+        Self {
+            id,
+            runtime: None,
+            metrics: BaselineMetrics::new(metrics, partition),
+        }
+    }
+
+    /// Create memory tracking metrics with reference to runtime
+    pub fn new_with_rt(
+        metrics: &ExecutionPlanMetricsSet,
+        partition: usize,
+        runtime: Arc<RuntimeEnv>,
+    ) -> Self {
+        let id = MemoryConsumerId::new(partition);
+        Self {
+            id,
+            runtime: Some(runtime),
+            metrics: BaselineMetrics::new(metrics, partition),
+        }
+    }
+
+    /// return the metric for cpu time spend in this operator
+    pub fn elapsed_compute(&self) -> &Time {
+        self.metrics.elapsed_compute()
+    }
+
+    /// return the size for current memory usage
+    pub fn mem_used(&self) -> usize {
+        self.metrics.mem_used().value()
+    }
+
+    /// setup initial memory usage and register it with memory manager
+    pub fn init_mem_used(&self, size: usize) {
+        self.metrics.mem_used().set(size);
+        if let Some(rt) = self.runtime.as_ref() {
+            rt.memory_manager.grow_tracker_usage(size);
+        }
+    }
+
+    /// return the metric for the total number of output rows produced
+    pub fn output_rows(&self) -> &Count {
+        self.metrics.output_rows()
+    }
+
+    /// Records the fact that this operator's execution is complete
+    /// (recording the `end_time` metric).
+    ///
+    /// Note care should be taken to call `done()` manually if
+    /// `MemTrackingMetrics` is not `drop`ped immediately upon operator
+    /// completion, as async streams may not be dropped immediately
+    /// depending on the consumer.
+ pub fn done(&self) { + self.metrics.done() + } + + /// Record that some number of rows have been produced as output + /// + /// See the [`RecordOutput`] for conveniently recording record + /// batch output for other thing + pub fn record_output(&self, num_rows: usize) { + self.metrics.record_output(num_rows) + } + + /// Process a poll result of a stream producing output for an + /// operator, recording the output rows and stream done time and + /// returning the same poll result + pub fn record_poll( + &self, + poll: Poll>>, + ) -> Poll>> { + self.metrics.record_poll(poll) + } +} + +impl Drop for MemTrackingMetrics { + fn drop(&mut self) { + self.metrics.try_done(); + if self.mem_used() != 0 { + if let Some(rt) = self.runtime.as_ref() { + rt.drop_consumer(&self.id, self.mem_used()); + } + } + } +} diff --git a/datafusion/src/physical_plan/sorts/mod.rs b/datafusion/src/physical_plan/sorts/mod.rs index 785556864ce8..64ec29179b19 100644 --- a/datafusion/src/physical_plan/sorts/mod.rs +++ b/datafusion/src/physical_plan/sorts/mod.rs @@ -248,15 +248,6 @@ enum StreamWrapper { Stream(Option), } -impl StreamWrapper { - fn mem_used(&self) -> usize { - match &self { - StreamWrapper::Stream(Some(s)) => s.mem_used, - _ => 0, - } - } -} - impl Stream for StreamWrapper { type Item = ArrowResult; diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs index d40d6cf170e4..7266b6cace47 100644 --- a/datafusion/src/physical_plan/sorts/sort.rs +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ -26,7 +26,9 @@ use crate::execution::memory_manager::{ use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::common::{batch_byte_size, IPCWriter, SizedRecordBatchStream}; use crate::physical_plan::expressions::PhysicalSortExpr; -use crate::physical_plan::metrics::{AggregatedMetricsSet, BaselineMetrics, MetricsSet}; +use crate::physical_plan::metrics::{ + BaselineMetrics, CompositeMetricsSet, MemTrackingMetrics, MetricsSet, +}; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream; use crate::physical_plan::sorts::SortedStream; use crate::physical_plan::stream::RecordBatchReceiverStream; @@ -73,8 +75,8 @@ struct ExternalSorter { /// Sort expressions expr: Vec, runtime: Arc, - metrics: AggregatedMetricsSet, - inner_metrics: BaselineMetrics, + metrics_set: CompositeMetricsSet, + metrics: BaselineMetrics, } impl ExternalSorter { @@ -82,10 +84,10 @@ impl ExternalSorter { partition_id: usize, schema: SchemaRef, expr: Vec, - metrics: AggregatedMetricsSet, + metrics_set: CompositeMetricsSet, runtime: Arc, ) -> Self { - let inner_metrics = metrics.new_intermediate_baseline(partition_id); + let metrics = metrics_set.new_intermediate_baseline(partition_id); Self { id: MemoryConsumerId::new(partition_id), schema, @@ -93,8 +95,8 @@ impl ExternalSorter { spills: Mutex::new(vec![]), expr, runtime, + metrics_set, metrics, - inner_metrics, } } @@ -102,7 +104,7 @@ impl ExternalSorter { if input.num_rows() > 0 { let size = batch_byte_size(&input); self.try_grow(size).await?; - self.inner_metrics.mem_used().add(size); + self.metrics.mem_used().add(size); let mut in_mem_batches = self.in_mem_batches.lock().await; in_mem_batches.push(input); } @@ -120,16 +122,18 @@ impl ExternalSorter { let mut in_mem_batches = self.in_mem_batches.lock().await; if self.spilled_before().await { - let baseline_metrics = self.metrics.new_intermediate_baseline(partition); + let tracking_metrics = self + .metrics_set + .new_intermediate_tracking(partition, 
self.runtime.clone()); let mut streams: Vec = vec![]; if in_mem_batches.len() > 0 { let in_mem_stream = in_mem_partial_sort( &mut *in_mem_batches, self.schema.clone(), &self.expr, - baseline_metrics, + tracking_metrics, )?; - let prev_used = self.inner_metrics.mem_used().set(0); + let prev_used = self.metrics.mem_used().set(0); streams.push(SortedStream::new(in_mem_stream, prev_used)); } @@ -139,25 +143,28 @@ impl ExternalSorter { let stream = read_spill_as_stream(spill, self.schema.clone())?; streams.push(SortedStream::new(stream, 0)); } - let baseline_metrics = self.metrics.new_final_baseline(partition); + let tracking_metrics = self + .metrics_set + .new_final_tracking(partition, self.runtime.clone()); Ok(Box::pin(SortPreservingMergeStream::new_from_streams( streams, self.schema.clone(), &self.expr, - baseline_metrics, - partition, + tracking_metrics, self.runtime.clone(), ))) } else if in_mem_batches.len() > 0 { - let baseline_metrics = self.metrics.new_final_baseline(partition); + let tracking_metrics = self + .metrics_set + .new_final_tracking(partition, self.runtime.clone()); let result = in_mem_partial_sort( &mut *in_mem_batches, self.schema.clone(), &self.expr, - baseline_metrics, + tracking_metrics, ); - self.inner_metrics.mem_used().set(0); - // TODO: the result size is not tracked + // Report to the memory manager we are no longer using memory + self.metrics.mem_used().set(0); result } else { Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone()))) @@ -165,15 +172,15 @@ impl ExternalSorter { } fn used(&self) -> usize { - self.inner_metrics.mem_used().value() + self.metrics.mem_used().value() } fn spilled_bytes(&self) -> usize { - self.inner_metrics.spilled_bytes().value() + self.metrics.spilled_bytes().value() } fn spill_count(&self) -> usize { - self.inner_metrics.spill_count().value() + self.metrics.spill_count().value() } } @@ -188,6 +195,12 @@ impl Debug for ExternalSorter { } } +impl Drop for ExternalSorter { + fn drop(&mut self) { + self.runtime.drop_consumer(self.id(), self.used()); + } +} + #[async_trait] impl MemoryConsumer for ExternalSorter { fn name(&self) -> String { @@ -222,27 +235,29 @@ impl MemoryConsumer for ExternalSorter { return Ok(0); } - let baseline_metrics = self.metrics.new_intermediate_baseline(partition); + let tracking_metrics = self + .metrics_set + .new_intermediate_tracking(partition, self.runtime.clone()); let spillfile = self.runtime.disk_manager.create_tmp_file()?; let stream = in_mem_partial_sort( &mut *in_mem_batches, self.schema.clone(), &*self.expr, - baseline_metrics, + tracking_metrics, ); spill_partial_sorted_stream(&mut stream?, spillfile.path(), self.schema.clone()) .await?; let mut spills = self.spills.lock().await; - let used = self.inner_metrics.mem_used().set(0); - self.inner_metrics.record_spill(used); + let used = self.metrics.mem_used().set(0); + self.metrics.record_spill(used); spills.push(spillfile); Ok(used) } fn mem_used(&self) -> usize { - self.inner_metrics.mem_used().value() + self.metrics.mem_used().value() } } @@ -251,14 +266,14 @@ fn in_mem_partial_sort( buffered_batches: &mut Vec, schema: SchemaRef, expressions: &[PhysicalSortExpr], - baseline_metrics: BaselineMetrics, + tracking_metrics: MemTrackingMetrics, ) -> Result { assert_ne!(buffered_batches.len(), 0); let result = { // NB timer records time taken on drop, so there are no // calls to `timer.done()` below. 
- let _timer = baseline_metrics.elapsed_compute().timer(); + let _timer = tracking_metrics.elapsed_compute().timer(); let pre_sort = if buffered_batches.len() == 1 { buffered_batches.pop() @@ -276,7 +291,7 @@ fn in_mem_partial_sort( Ok(Box::pin(SizedRecordBatchStream::new( schema, vec![Arc::new(result.unwrap())], - baseline_metrics, + tracking_metrics, ))) } @@ -357,7 +372,7 @@ pub struct SortExec { /// Sort expressions expr: Vec, /// Containing all metrics set created during sort - all_metrics: AggregatedMetricsSet, + metrics_set: CompositeMetricsSet, /// Preserve partitions of input plan preserve_partitioning: bool, } @@ -381,7 +396,7 @@ impl SortExec { Self { expr, input, - all_metrics: AggregatedMetricsSet::new(), + metrics_set: CompositeMetricsSet::new(), preserve_partitioning, } } @@ -470,14 +485,14 @@ impl ExecutionPlan for SortExec { input, partition, self.expr.clone(), - self.all_metrics.clone(), + self.metrics_set.clone(), runtime, ) .await } fn metrics(&self) -> Option { - Some(self.all_metrics.aggregate_all()) + Some(self.metrics_set.aggregate_all()) } fn fmt_as( @@ -537,27 +552,23 @@ async fn do_sort( mut input: SendableRecordBatchStream, partition_id: usize, expr: Vec, - metrics: AggregatedMetricsSet, + metrics_set: CompositeMetricsSet, runtime: Arc, ) -> Result { let schema = input.schema(); - let sorter = Arc::new(ExternalSorter::new( + let sorter = ExternalSorter::new( partition_id, schema.clone(), expr, - metrics, + metrics_set, runtime.clone(), - )); - runtime.register_consumer(&(sorter.clone() as Arc)); - + ); + runtime.register_requester(sorter.id()); while let Some(batch) = input.next().await { let batch = batch?; sorter.insert_batch(batch).await?; } - - let result = sorter.sort().await; - runtime.drop_consumer(sorter.id()); - result + sorter.sort().await } #[cfg(test)] diff --git a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs index 2ac468b35508..7b9d5d5de328 100644 --- a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs @@ -19,11 +19,11 @@ use crate::physical_plan::common::AbortOnDropMany; use crate::physical_plan::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, + ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet, }; use std::any::Any; use std::collections::{BinaryHeap, VecDeque}; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; @@ -41,9 +41,6 @@ use futures::stream::FusedStream; use futures::{Stream, StreamExt}; use crate::error::{DataFusionError, Result}; -use crate::execution::memory_manager::{ - ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, -}; use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream, StreamWrapper}; use crate::physical_plan::{ @@ -161,7 +158,7 @@ impl ExecutionPlan for SortPreservingMergeExec { ))); } - let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let tracking_metrics = MemTrackingMetrics::new(&self.metrics, partition); let input_partitions = self.input.output_partitioning().partition_count(); match input_partitions { @@ -193,8 +190,7 @@ impl ExecutionPlan for SortPreservingMergeExec { AbortOnDropMany(join_handles), self.schema(), &self.expr, - baseline_metrics, - partition, + tracking_metrics, runtime, ))) } @@ -223,36 +219,19 @@ impl ExecutionPlan for SortPreservingMergeExec { 
} } +#[derive(Debug)] struct MergingStreams { - /// ConsumerId - id: MemoryConsumerId, /// The sorted input streams to merge together streams: Mutex>, /// number of streams num_streams: usize, - /// Runtime - runtime: Arc, -} - -impl Debug for MergingStreams { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - f.debug_struct("MergingStreams") - .field("id", &self.id()) - .finish() - } } impl MergingStreams { - fn new( - partition: usize, - input_streams: Vec, - runtime: Arc, - ) -> Self { + fn new(input_streams: Vec) -> Self { Self { - id: MemoryConsumerId::new(partition), num_streams: input_streams.len(), streams: Mutex::new(input_streams), - runtime, } } @@ -261,45 +240,13 @@ impl MergingStreams { } } -#[async_trait] -impl MemoryConsumer for MergingStreams { - fn name(&self) -> String { - "MergingStreams".to_owned() - } - - fn id(&self) -> &MemoryConsumerId { - &self.id - } - - fn memory_manager(&self) -> Arc { - self.runtime.memory_manager.clone() - } - - fn type_(&self) -> &ConsumerType { - &ConsumerType::Tracking - } - - async fn spill(&self) -> Result { - return Err(DataFusionError::Internal(format!( - "Calling spill on a tracking only consumer {}, {}", - self.name(), - self.id, - ))); - } - - fn mem_used(&self) -> usize { - let streams = self.streams.lock().unwrap(); - streams.iter().map(StreamWrapper::mem_used).sum::() - } -} - #[derive(Debug)] pub(crate) struct SortPreservingMergeStream { /// The schema of the RecordBatches yielded by this stream schema: SchemaRef, /// The sorted input streams to merge together - streams: Arc, + streams: MergingStreams, /// Drop helper for tasks feeding the [`receivers`](Self::receivers) _drop_helper: AbortOnDropMany<()>, @@ -324,7 +271,7 @@ pub(crate) struct SortPreservingMergeStream { sort_options: Arc>, /// used to record execution metrics - baseline_metrics: BaselineMetrics, + tracking_metrics: MemTrackingMetrics, /// If the stream has encountered an error aborted: bool, @@ -335,25 +282,17 @@ pub(crate) struct SortPreservingMergeStream { /// min heap for record comparison min_heap: BinaryHeap, - /// runtime - runtime: Arc, -} - -impl Drop for SortPreservingMergeStream { - fn drop(&mut self) { - self.runtime.drop_consumer(self.streams.id()) - } + /// target batch size + batch_size: usize, } impl SortPreservingMergeStream { - #[allow(clippy::too_many_arguments)] pub(crate) fn new_from_receivers( receivers: Vec>>, _drop_helper: AbortOnDropMany<()>, schema: SchemaRef, expressions: &[PhysicalSortExpr], - baseline_metrics: BaselineMetrics, - partition: usize, + tracking_metrics: MemTrackingMetrics, runtime: Arc, ) -> Self { let stream_count = receivers.len(); @@ -362,23 +301,21 @@ impl SortPreservingMergeStream { .map(|_| VecDeque::new()) .collect(); let wrappers = receivers.into_iter().map(StreamWrapper::Receiver).collect(); - let streams = Arc::new(MergingStreams::new(partition, wrappers, runtime.clone())); - runtime.register_consumer(&(streams.clone() as Arc)); SortPreservingMergeStream { schema, batches, cursor_finished: vec![true; stream_count], - streams, + streams: MergingStreams::new(wrappers), _drop_helper, column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()), - baseline_metrics, + tracking_metrics, aborted: false, in_progress: vec![], next_batch_id: 0, min_heap: BinaryHeap::with_capacity(stream_count), - runtime, + batch_size: runtime.batch_size(), } } @@ -386,8 +323,7 @@ impl SortPreservingMergeStream { streams: Vec, schema: SchemaRef, 
expressions: &[PhysicalSortExpr], - baseline_metrics: BaselineMetrics, - partition: usize, + tracking_metrics: MemTrackingMetrics, runtime: Arc, ) -> Self { let stream_count = streams.len(); @@ -395,27 +331,26 @@ impl SortPreservingMergeStream { .into_iter() .map(|_| VecDeque::new()) .collect(); + tracking_metrics.init_mem_used(streams.iter().map(|s| s.mem_used).sum()); let wrappers = streams .into_iter() .map(|s| StreamWrapper::Stream(Some(s))) .collect(); - let streams = Arc::new(MergingStreams::new(partition, wrappers, runtime.clone())); - runtime.register_consumer(&(streams.clone() as Arc)); Self { schema, batches, cursor_finished: vec![true; stream_count], - streams, + streams: MergingStreams::new(wrappers), _drop_helper: AbortOnDropMany(vec![]), column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()), - baseline_metrics, + tracking_metrics, aborted: false, in_progress: vec![], next_batch_id: 0, min_heap: BinaryHeap::with_capacity(stream_count), - runtime, + batch_size: runtime.batch_size(), } } @@ -577,7 +512,7 @@ impl Stream for SortPreservingMergeStream { cx: &mut Context<'_>, ) -> Poll> { let poll = self.poll_next_inner(cx); - self.baseline_metrics.record_poll(poll) + self.tracking_metrics.record_poll(poll) } } @@ -606,7 +541,7 @@ impl SortPreservingMergeStream { loop { // NB timer records time taken on drop, so there are no // calls to `timer.done()` below. - let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let elapsed_compute = self.tracking_metrics.elapsed_compute().clone(); let _timer = elapsed_compute.timer(); match self.min_heap.pop() { @@ -630,7 +565,7 @@ impl SortPreservingMergeStream { row_idx, }); - if self.in_progress.len() == self.runtime.batch_size() { + if self.in_progress.len() == self.batch_size { return Poll::Ready(Some(self.build_record_batch())); } @@ -1263,7 +1198,7 @@ mod tests { } let metrics = ExecutionPlanMetricsSet::new(); - let baseline_metrics = BaselineMetrics::new(&metrics, 0); + let tracking_metrics = MemTrackingMetrics::new(&metrics, 0); let merge_stream = SortPreservingMergeStream::new_from_receivers( receivers, @@ -1271,8 +1206,7 @@ mod tests { AbortOnDropMany(vec![]), batches.schema(), sort.as_slice(), - baseline_metrics, - 0, + tracking_metrics, runtime.clone(), ); diff --git a/datafusion/tests/provider_filter_pushdown.rs b/datafusion/tests/provider_filter_pushdown.rs index 5a4f90702ecb..3aac5a8f3662 100644 --- a/datafusion/tests/provider_filter_pushdown.rs +++ b/datafusion/tests/provider_filter_pushdown.rs @@ -25,7 +25,7 @@ use datafusion::execution::context::ExecutionContext; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::Expr; use datafusion::physical_plan::common::SizedRecordBatchStream; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics}; use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; @@ -86,11 +86,11 @@ impl ExecutionPlan for CustomPlan { _runtime: Arc, ) -> Result { let metrics = ExecutionPlanMetricsSet::new(); - let baseline_metrics = BaselineMetrics::new(&metrics, partition); + let tracking_metrics = MemTrackingMetrics::new(&metrics, partition); Ok(Box::pin(SizedRecordBatchStream::new( self.schema(), self.batches.clone(), - baseline_metrics, + tracking_metrics, ))) } From 
0d6d1ce5133c7bac9cf313e72f5a529aaefc766c Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 30 Jan 2022 05:42:18 -0500 Subject: [PATCH 06/50] Implement TableProvider for DataFrameImpl (#1699) * Add TableProvider impl for DataFrameImpl * Add physical plan in * Clean up plan construction and names construction * Remove duplicate comments * Remove unused parameter * Add test * Remove duplicate limit comment * Use cloned instead of individual clone * Reduce the amount of code to get a schema Co-authored-by: Andrew Lamb * Add comments to test * Fix plan comparison * Compare only the results of execution * Remove println * Refer to df_impl instead of table in test Co-authored-by: Andrew Lamb * Fix the register_table test to use the correct result set for comparison * Consolidate group/agg exprs * Format * Remove outdated comment Co-authored-by: Andrew Lamb --- datafusion/src/execution/dataframe_impl.rs | 114 +++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index f2d0385a3fe0..d3f62bbb46db 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -17,8 +17,11 @@ //! Implementation of DataFrame API. +use std::any::Any; use std::sync::{Arc, Mutex}; +use crate::arrow::datatypes::Schema; +use crate::arrow::datatypes::SchemaRef; use crate::arrow::record_batch::RecordBatch; use crate::error::Result; use crate::execution::context::{ExecutionContext, ExecutionContextState}; @@ -26,12 +29,15 @@ use crate::logical_plan::{ col, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, }; +use crate::scalar::ScalarValue; use crate::{ dataframe::*, physical_plan::{collect, collect_partitioned}, }; use crate::arrow::util::pretty; +use crate::datasource::TableProvider; +use crate::datasource::TableType; use crate::physical_plan::{ execute_stream, execute_stream_partitioned, ExecutionPlan, SendableRecordBatchStream, }; @@ -62,6 +68,59 @@ impl DataFrameImpl { } } +#[async_trait] +impl TableProvider for DataFrameImpl { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + let schema: Schema = self.plan.schema().as_ref().into(); + Arc::new(schema) + } + + fn table_type(&self) -> TableType { + TableType::View + } + + async fn scan( + &self, + projection: &Option>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let expr = projection + .as_ref() + // construct projections + .map_or_else( + || Ok(Arc::new(Self::new(self.ctx_state.clone(), &self.plan)) as Arc<_>), + |projection| { + let schema = TableProvider::schema(self).project(projection)?; + let names = schema + .fields() + .iter() + .map(|field| field.name().as_str()) + .collect::>(); + self.select_columns(names.as_slice()) + }, + )? + // add predicates, otherwise use `true` as the predicate + .filter(filters.iter().cloned().fold( + Expr::Literal(ScalarValue::Boolean(Some(true))), + |acc, new| acc.and(new), + ))?; + // add a limit if given + Self::new( + self.ctx_state.clone(), + &limit + .map_or_else(|| Ok(expr.clone()), |n| expr.limit(n))? 
+ .to_logical_plan(), + ) + .create_physical_plan() + .await + } +} + #[async_trait] impl DataFrame for DataFrameImpl { /// Apply a projection based on a list of column names @@ -488,6 +547,61 @@ mod tests { Ok(()) } + #[tokio::test] + async fn register_table() -> Result<()> { + let df = test_table().await?.select_columns(&["c1", "c12"])?; + let mut ctx = ExecutionContext::new(); + let df_impl = + Arc::new(DataFrameImpl::new(ctx.state.clone(), &df.to_logical_plan())); + + // register a dataframe as a table + ctx.register_table("test_table", df_impl.clone())?; + + // pull the table out + let table = ctx.table("test_table")?; + + let group_expr = vec![col("c1")]; + let aggr_expr = vec![sum(col("c12"))]; + + // check that we correctly read from the table + let df_results = &df_impl + .aggregate(group_expr.clone(), aggr_expr.clone())? + .collect() + .await?; + let table_results = &table.aggregate(group_expr, aggr_expr)?.collect().await?; + + assert_batches_sorted_eq!( + vec![ + "+----+-----------------------------+", + "| c1 | SUM(aggregate_test_100.c12) |", + "+----+-----------------------------+", + "| a | 10.238448667882977 |", + "| b | 7.797734760124923 |", + "| c | 13.860958726523545 |", + "| d | 8.793968289758968 |", + "| e | 10.206140546981722 |", + "+----+-----------------------------+", + ], + df_results + ); + + // the results are the same as the results from the view, modulo the leaf table name + assert_batches_sorted_eq!( + vec![ + "+----+---------------------+", + "| c1 | SUM(test_table.c12) |", + "+----+---------------------+", + "| a | 10.238448667882977 |", + "| b | 7.797734760124923 |", + "| c | 13.860958726523545 |", + "| d | 8.793968289758968 |", + "| e | 10.206140546981722 |", + "+----+---------------------+", + ], + table_results + ); + Ok(()) + } /// Compare the formatted string representation of two plans for equality fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) { assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2)); From 75c7578cd2d510c0814742fc78f7745ca6873c3f Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Sun, 30 Jan 2022 21:11:17 +0800 Subject: [PATCH 07/50] refine test in repartition.rs & coalesce_batches.rs (#1707) --- .../src/physical_plan/coalesce_batches.rs | 20 +---------------- datafusion/src/physical_plan/mod.rs | 6 ++--- datafusion/src/physical_plan/planner.rs | 4 ++-- datafusion/src/physical_plan/repartition.rs | 22 +++---------------- datafusion/src/test/mod.rs | 20 +++++++++++++++++ 5 files changed, 28 insertions(+), 44 deletions(-) diff --git a/datafusion/src/physical_plan/coalesce_batches.rs b/datafusion/src/physical_plan/coalesce_batches.rs index 586b05219bdf..ec238ad68cf8 100644 --- a/datafusion/src/physical_plan/coalesce_batches.rs +++ b/datafusion/src/physical_plan/coalesce_batches.rs @@ -295,9 +295,8 @@ pub fn concat_batches( #[cfg(test)] mod tests { use super::*; - use crate::from_slice::FromSlice; use crate::physical_plan::{memory::MemoryExec, repartition::RepartitionExec}; - use arrow::array::UInt32Array; + use crate::test::create_vec_batches; use arrow::datatypes::{DataType, Field, Schema}; #[tokio::test(flavor = "multi_thread")] @@ -325,23 +324,6 @@ mod tests { Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) } - fn create_vec_batches(schema: &Arc, num_batches: usize) -> Vec { - let batch = create_batch(schema); - let mut vec = Vec::with_capacity(num_batches); - for _ in 0..num_batches { - vec.push(batch.clone()); - } - vec - } - - fn create_batch(schema: &Arc) -> RecordBatch { - 
RecordBatch::try_new( - schema.clone(), - vec![Arc::new(UInt32Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]))], - ) - .unwrap() - } - async fn coalesce_batches( schema: &SchemaRef, input_partitions: Vec>, diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 216d4a65e639..24aa6ad38339 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -59,7 +59,7 @@ pub type SendableRecordBatchStream = Pin usize { use Partitioning::*; match self { - RoundRobinBatch(n) => *n, - Hash(_, n) => *n, - UnknownPartitioning(n) => *n, + RoundRobinBatch(n) | Hash(_, n) | UnknownPartitioning(n) => *n, } } } diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 2dcde9d11333..226e3f392497 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -226,9 +226,9 @@ pub trait PhysicalPlanner { /// /// `expr`: the expression to convert /// - /// `input_dfschema`: the logical plan schema for evaluating `e` + /// `input_dfschema`: the logical plan schema for evaluating `expr` /// - /// `input_schema`: the physical schema for evaluating `e` + /// `input_schema`: the physical schema for evaluating `expr` fn create_physical_expr( &self, expr: &Expr, diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index 746075429a45..86866728cdda 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -447,7 +447,7 @@ struct RepartitionStream { /// Number of input partitions that have finished sending batches to this output channel num_input_partitions_processed: usize, - /// Schema + /// Schema wrapped by Arc schema: SchemaRef, /// channel containing the repartitioned batches @@ -494,6 +494,7 @@ impl RecordBatchStream for RepartitionStream { mod tests { use super::*; use crate::from_slice::FromSlice; + use crate::test::create_vec_batches; use crate::{ assert_batches_sorted_eq, physical_plan::{collect, expressions::col, memory::MemoryExec}, @@ -508,7 +509,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use arrow::{ - array::{ArrayRef, StringArray, UInt32Array}, + array::{ArrayRef, StringArray}, error::ArrowError, }; use futures::FutureExt; @@ -601,23 +602,6 @@ mod tests { Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) } - fn create_vec_batches(schema: &Arc, n: usize) -> Vec { - let batch = create_batch(schema); - let mut vec = Vec::with_capacity(n); - for _ in 0..n { - vec.push(batch.clone()); - } - vec - } - - fn create_batch(schema: &Arc) -> RecordBatch { - RecordBatch::try_new( - schema.clone(), - vec![Arc::new(UInt32Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]))], - ) - .unwrap() - } - async fn repartition( schema: &SchemaRef, input_partitions: Vec>, diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs index 497bfe59e1a1..cebd9ee02d1c 100644 --- a/datafusion/src/test/mod.rs +++ b/datafusion/src/test/mod.rs @@ -17,6 +17,7 @@ //! 
Common unit test utility methods +use crate::arrow::array::UInt32Array; use crate::datasource::object_store::local::local_unpartitioned_file; use crate::datasource::{MemTable, PartitionedFile, TableProvider}; use crate::error::Result; @@ -212,6 +213,25 @@ pub fn assert_is_pending<'a, T>(fut: &mut Pin + Send assert!(poll.is_pending()); } +/// Create vector batches +pub fn create_vec_batches(schema: &Arc, n: usize) -> Vec { + let batch = create_batch(schema); + let mut vec = Vec::with_capacity(n); + for _ in 0..n { + vec.push(batch.clone()); + } + vec +} + +/// Create batch +fn create_batch(schema: &Arc) -> RecordBatch { + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(UInt32Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]))], + ) + .unwrap() +} + pub mod exec; pub mod object_store; pub mod user_defined; From a7f0156b33be22c6c3fa66db3754a56844b3c99f Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 30 Jan 2022 21:17:47 +0800 Subject: [PATCH 08/50] Fuzz test for spillable sort (#1706) --- datafusion/Cargo.toml | 1 + datafusion/fuzz-utils/Cargo.toml | 28 +++++ datafusion/fuzz-utils/src/lib.rs | 73 +++++++++++++ datafusion/src/execution/memory_manager.rs | 3 +- datafusion/src/execution/mod.rs | 4 +- datafusion/src/physical_plan/sorts/sort.rs | 6 +- datafusion/tests/merge_fuzz.rs | 50 +-------- datafusion/tests/order_spill_fuzz.rs | 121 +++++++++++++++++++++ 8 files changed, 234 insertions(+), 52 deletions(-) create mode 100644 datafusion/fuzz-utils/Cargo.toml create mode 100644 datafusion/fuzz-utils/src/lib.rs create mode 100644 datafusion/tests/order_spill_fuzz.rs diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index e0e880dba3dc..422a776448d9 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -82,6 +82,7 @@ tempfile = "3" [dev-dependencies] criterion = "0.3" doc-comment = "0.3" +fuzz-utils = { path = "fuzz-utils" } [[bench]] name = "aggregate_query_sql" diff --git a/datafusion/fuzz-utils/Cargo.toml b/datafusion/fuzz-utils/Cargo.toml new file mode 100644 index 000000000000..304cbfea8434 --- /dev/null +++ b/datafusion/fuzz-utils/Cargo.toml @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "fuzz-utils" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { version = "8.0.0", features = ["prettyprint"] } +rand = "0.8" +env_logger = "0.9.0" diff --git a/datafusion/fuzz-utils/src/lib.rs b/datafusion/fuzz-utils/src/lib.rs new file mode 100644 index 000000000000..e021f55f8724 --- /dev/null +++ b/datafusion/fuzz-utils/src/lib.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common utils for fuzz tests +use arrow::{array::Int32Array, record_batch::RecordBatch}; +use rand::prelude::StdRng; +use rand::Rng; + +pub use env_logger; + +/// Extracts the i32 values from the set of batches and returns them as a single Vec +pub fn batches_to_vec(batches: &[RecordBatch]) -> Vec> { + batches + .iter() + .map(|batch| { + assert_eq!(batch.num_columns(), 1); + batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + }) + .flatten() + .collect() +} + +/// extract values from batches and sort them +pub fn partitions_to_sorted_vec(partitions: &[Vec]) -> Vec> { + let mut values: Vec<_> = partitions + .iter() + .map(|batches| batches_to_vec(batches).into_iter()) + .flatten() + .collect(); + + values.sort_unstable(); + values +} + +/// Adds a random number of empty record batches into the stream +pub fn add_empty_batches( + batches: Vec, + rng: &mut StdRng, +) -> Vec { + let schema = batches[0].schema(); + + batches + .into_iter() + .map(|batch| { + // insert 0, or 1 empty batches before and after the current batch + let empty_batch = RecordBatch::new_empty(schema.clone()); + std::iter::repeat(empty_batch.clone()) + .take(rng.gen_range(0..2)) + .chain(std::iter::once(batch)) + .chain(std::iter::repeat(empty_batch).take(rng.gen_range(0..2))) + }) + .flatten() + .collect() +} diff --git a/datafusion/src/execution/memory_manager.rs b/datafusion/src/execution/memory_manager.rs index 0fb3cfbb4ecf..5015f466c674 100644 --- a/datafusion/src/execution/memory_manager.rs +++ b/datafusion/src/execution/memory_manager.rs @@ -392,7 +392,8 @@ const GB: u64 = 1 << 30; const MB: u64 = 1 << 20; const KB: u64 = 1 << 10; -fn human_readable_size(size: usize) -> String { +/// Present size in human readable form +pub fn human_readable_size(size: usize) -> String { let size = size as u64; let (value, unit) = { if size >= 2 * TB { diff --git a/datafusion/src/execution/mod.rs b/datafusion/src/execution/mod.rs index e3b42ae254a9..427c539cc75b 100644 --- a/datafusion/src/execution/mod.rs +++ b/datafusion/src/execution/mod.rs @@ -25,4 +25,6 @@ pub mod options; pub mod runtime_env; pub use disk_manager::DiskManager; -pub use memory_manager::{MemoryConsumer, MemoryConsumerId, MemoryManager}; +pub use memory_manager::{ + human_readable_size, MemoryConsumer, MemoryConsumerId, MemoryManager, +}; diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs index 7266b6cace47..7f7f58104fc8 100644 --- a/datafusion/src/physical_plan/sorts/sort.rs +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ -21,7 +21,7 @@ use crate::error::{DataFusionError, Result}; use crate::execution::memory_manager::{ - ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, + human_readable_size, ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, }; use 
crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::common::{batch_byte_size, IPCWriter, SizedRecordBatchStream}; @@ -348,7 +348,9 @@ fn write_sorted( writer.finish()?; debug!( "Spilled {} batches of total {} rows to disk, memory released {}", - writer.num_batches, writer.num_rows, writer.num_bytes + writer.num_batches, + writer.num_rows, + human_readable_size(writer.num_bytes as usize), ); Ok(()) } diff --git a/datafusion/tests/merge_fuzz.rs b/datafusion/tests/merge_fuzz.rs index 81920549c18a..6821c6ba52d0 100644 --- a/datafusion/tests/merge_fuzz.rs +++ b/datafusion/tests/merge_fuzz.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Fuzz Test for various corner cases merging streams of RecordBatchs +//! Fuzz Test for various corner cases merging streams of RecordBatches use std::sync::Arc; use arrow::{ @@ -32,6 +32,7 @@ use datafusion::{ sorts::sort_preserving_merge::SortPreservingMergeExec, }, }; +use fuzz_utils::{add_empty_batches, batches_to_vec, partitions_to_sorted_vec}; use rand::{prelude::StdRng, Rng, SeedableRng}; #[tokio::test] @@ -147,35 +148,6 @@ async fn run_merge_test(input: Vec>) { } } -/// Extracts the i32 values from the set of batches and returns them as a single Vec -fn batches_to_vec(batches: &[RecordBatch]) -> Vec> { - batches - .iter() - .map(|batch| { - assert_eq!(batch.num_columns(), 1); - batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap() - .iter() - }) - .flatten() - .collect() -} - -// extract values from batches and sort them -fn partitions_to_sorted_vec(partitions: &[Vec]) -> Vec> { - let mut values: Vec<_> = partitions - .iter() - .map(|batches| batches_to_vec(batches).into_iter()) - .flatten() - .collect(); - - values.sort_unstable(); - values -} - /// Return the values `low..high` in order, in randomly sized /// record batches in a field named 'x' of type `Int32` fn make_staggered_batches(low: i32, high: i32, seed: u64) -> Vec { @@ -199,24 +171,6 @@ fn make_staggered_batches(low: i32, high: i32, seed: u64) -> Vec { add_empty_batches(batches, &mut rng) } -/// Adds a random number of empty record batches into the stream -fn add_empty_batches(batches: Vec, rng: &mut StdRng) -> Vec { - let schema = batches[0].schema(); - - batches - .into_iter() - .map(|batch| { - // insert 0, or 1 empty batches before and after the current batch - let empty_batch = RecordBatch::new_empty(schema.clone()); - std::iter::repeat(empty_batch.clone()) - .take(rng.gen_range(0..2)) - .chain(std::iter::once(batch)) - .chain(std::iter::repeat(empty_batch).take(rng.gen_range(0..2))) - }) - .flatten() - .collect() -} - fn concat(mut v1: Vec, v2: Vec) -> Vec { v1.extend(v2); v1 diff --git a/datafusion/tests/order_spill_fuzz.rs b/datafusion/tests/order_spill_fuzz.rs new file mode 100644 index 000000000000..049fe6a4f4fd --- /dev/null +++ b/datafusion/tests/order_spill_fuzz.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fuzz Test for various corner cases sorting RecordBatches exceeds available memory and should spill + +use arrow::{ + array::{ArrayRef, Int32Array}, + compute::SortOptions, + record_batch::RecordBatch, +}; +use datafusion::execution::memory_manager::MemoryManagerConfig; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::physical_plan::expressions::{col, PhysicalSortExpr}; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use fuzz_utils::{add_empty_batches, batches_to_vec, partitions_to_sorted_vec}; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; + +#[tokio::test] +async fn test_sort_1k_mem() { + run_sort(1024, vec![(5, false), (2000, true), (1000000, true)]).await +} + +#[tokio::test] +async fn test_sort_100k_mem() { + run_sort(102400, vec![(5, false), (2000, false), (1000000, true)]).await +} + +#[tokio::test] +async fn test_sort_unlimited_mem() { + run_sort( + usize::MAX, + vec![(5, false), (2000, false), (1000000, false)], + ) + .await +} + +/// Sort the input using SortExec and ensure the results are correct according to `Vec::sort` +async fn run_sort(pool_size: usize, size_spill: Vec<(usize, bool)>) { + for (size, spill) in size_spill { + let input = vec![make_staggered_batches(size)]; + let first_batch = input + .iter() + .map(|p| p.iter()) + .flatten() + .next() + .expect("at least one batch"); + let schema = first_batch.schema(); + + let sort = vec![PhysicalSortExpr { + expr: col("x", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + + let exec = MemoryExec::try_new(&input, schema, None).unwrap(); + let sort = Arc::new(SortExec::try_new(sort, Arc::new(exec)).unwrap()); + + let runtime_config = RuntimeConfig::new().with_memory_manager( + MemoryManagerConfig::try_new_limit(pool_size, 1.0).unwrap(), + ); + let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); + let collected = collect(sort.clone(), runtime).await.unwrap(); + + let expected = partitions_to_sorted_vec(&input); + let actual = batches_to_vec(&collected); + + if spill { + assert_ne!(sort.metrics().unwrap().spill_count().unwrap(), 0); + } else { + assert_eq!(sort.metrics().unwrap().spill_count().unwrap(), 0); + } + + assert_eq!(expected, actual, "failure in @ pool_size {}", pool_size); + } +} + +/// Return randomly sized record batches in a field named 'x' of type `Int32` +/// with randomized i32 content +fn make_staggered_batches(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut input: Vec = vec![0; len]; + rng.fill(&mut input[..]); + let input = Int32Array::from_iter_values(input.into_iter()); + + // split into several record batches + let mut remainder = + RecordBatch::try_from_iter(vec![("x", Arc::new(input) as ArrayRef)]).unwrap(); + + let mut batches = vec![]; + + // use a random number generator to pick a random sized output + let mut rng = StdRng::seed_from_u64(42); + while remainder.num_rows() > 0 { + let batch_size = 
rng.gen_range(0..remainder.num_rows() + 1); + + batches.push(remainder.slice(0, batch_size)); + remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); + } + + add_empty_batches(batches, &mut rng) +} From fecce97b519cbdaa16c9974af58bb12d9d73d327 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 30 Jan 2022 08:18:09 -0500 Subject: [PATCH 09/50] Lazy TempDir creation in DiskManager (#1695) --- datafusion/src/execution/disk_manager.rs | 99 +++++++++++++++++------- 1 file changed, 73 insertions(+), 26 deletions(-) diff --git a/datafusion/src/execution/disk_manager.rs b/datafusion/src/execution/disk_manager.rs index 31565fec130d..4486f53a21b8 100644 --- a/datafusion/src/execution/disk_manager.rs +++ b/datafusion/src/execution/disk_manager.rs @@ -21,8 +21,8 @@ use crate::error::{DataFusionError, Result}; use log::debug; use rand::{thread_rng, Rng}; -use std::path::PathBuf; use std::sync::Arc; +use std::{path::PathBuf, sync::Mutex}; use tempfile::{Builder, NamedTempFile, TempDir}; /// Configuration for temporary disk access @@ -67,7 +67,9 @@ impl DiskManagerConfig { /// while processing dataset larger than available memory. #[derive(Debug)] pub struct DiskManager { - local_dirs: Vec, + /// TempDirs to put temporary files in. A new OS specified + /// temporary directory will be created if this list is empty. + local_dirs: Mutex>, } impl DiskManager { @@ -75,31 +77,39 @@ impl DiskManager { pub fn try_new(config: DiskManagerConfig) -> Result> { match config { DiskManagerConfig::Existing(manager) => Ok(manager), - DiskManagerConfig::NewOs => { - let tempdir = tempfile::tempdir().map_err(DataFusionError::IoError)?; - - debug!( - "Created directory {:?} as DataFusion working directory", - tempdir - ); - Ok(Arc::new(Self { - local_dirs: vec![tempdir], - })) - } + DiskManagerConfig::NewOs => Ok(Arc::new(Self { + local_dirs: Mutex::new(vec![]), + })), DiskManagerConfig::NewSpecified(conf_dirs) => { let local_dirs = create_local_dirs(conf_dirs)?; debug!( "Created local dirs {:?} as DataFusion working directory", local_dirs ); - Ok(Arc::new(Self { local_dirs })) + Ok(Arc::new(Self { + local_dirs: Mutex::new(local_dirs), + })) } } } /// Return a temporary file from a randomized choice in the configured locations pub fn create_tmp_file(&self) -> Result { - create_tmp_file(&self.local_dirs) + let mut local_dirs = self.local_dirs.lock().unwrap(); + + // Create a temporary directory if needed + if local_dirs.is_empty() { + let tempdir = tempfile::tempdir().map_err(DataFusionError::IoError)?; + + debug!( + "Created directory '{:?}' as DataFusion tempfile directory", + tempdir.path().to_string_lossy() + ); + + local_dirs.push(tempdir); + } + + create_tmp_file(&local_dirs) } } @@ -129,10 +139,42 @@ fn create_tmp_file(local_dirs: &[TempDir]) -> Result { #[cfg(test)] mod tests { + use std::path::Path; + use super::*; use crate::error::Result; use tempfile::TempDir; + #[test] + fn lazy_temp_dir_creation() -> Result<()> { + // A default configuration should not create temp files until requested + let config = DiskManagerConfig::new(); + let dm = DiskManager::try_new(config)?; + + assert_eq!(0, local_dir_snapshot(&dm).len()); + + // can still create a tempfile however: + let actual = dm.create_tmp_file()?; + + // Now the tempdir has been created on demand + assert_eq!(1, local_dir_snapshot(&dm).len()); + + // the returned tempfile file should be in the temp directory + let local_dirs = local_dir_snapshot(&dm); + assert_path_in_dirs(actual.path(), local_dirs.iter().map(|p| p.as_path())); + + 
Ok(()) + } + + fn local_dir_snapshot(dm: &DiskManager) -> Vec { + dm.local_dirs + .lock() + .unwrap() + .iter() + .map(|p| p.path().into()) + .collect() + } + #[test] fn file_in_right_dir() -> Result<()> { let local_dir1 = TempDir::new()?; @@ -147,19 +189,24 @@ mod tests { let actual = dm.create_tmp_file()?; // the file should be in one of the specified local directories - let found = local_dirs.iter().any(|p| { - actual - .path() + assert_path_in_dirs(actual.path(), local_dirs.into_iter()); + + Ok(()) + } + + /// Asserts that `file_path` is found anywhere in any of `dir` directories + fn assert_path_in_dirs<'a>( + file_path: &'a Path, + dirs: impl Iterator, + ) { + let dirs: Vec<&Path> = dirs.collect(); + + let found = dirs.iter().any(|file_path| { + file_path .ancestors() - .any(|candidate_path| *p == candidate_path) + .any(|candidate_path| *file_path == candidate_path) }); - assert!( - found, - "Can't find {:?} in specified local dirs: {:?}", - actual, local_dirs - ); - - Ok(()) + assert!(found, "Can't find {:?} in dirs: {:?}", file_path, dirs); } } From 3494e9ce2cd41f246240caebada4e3faad1f5fa9 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Sun, 30 Jan 2022 08:25:32 -0500 Subject: [PATCH 10/50] Incorporate dyn scalar kernels (#1685) * Rebase * impl ToNumeric for ScalarValue * Update macro to be based on * Add floats * Cleanup * Newline --- .../src/physical_plan/expressions/binary.rs | 122 ++++++++++++++++-- 1 file changed, 114 insertions(+), 8 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index 604cb3c49931..4680dd0a49d9 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::convert::TryInto; use std::{any::Any, sync::Arc}; use arrow::array::TimestampMillisecondArray; @@ -29,6 +30,18 @@ use arrow::compute::kernels::comparison::{ eq_bool, eq_bool_scalar, gt_bool, gt_bool_scalar, gt_eq_bool, gt_eq_bool_scalar, lt_bool, lt_bool_scalar, lt_eq_bool, lt_eq_bool_scalar, neq_bool, neq_bool_scalar, }; +use arrow::compute::kernels::comparison::{ + eq_dyn_bool_scalar, gt_dyn_bool_scalar, gt_eq_dyn_bool_scalar, lt_dyn_bool_scalar, + lt_eq_dyn_bool_scalar, neq_dyn_bool_scalar, +}; +use arrow::compute::kernels::comparison::{ + eq_dyn_scalar, gt_dyn_scalar, gt_eq_dyn_scalar, lt_dyn_scalar, lt_eq_dyn_scalar, + neq_dyn_scalar, +}; +use arrow::compute::kernels::comparison::{ + eq_dyn_utf8_scalar, gt_dyn_utf8_scalar, gt_eq_dyn_utf8_scalar, lt_dyn_utf8_scalar, + lt_eq_dyn_utf8_scalar, neq_dyn_utf8_scalar, +}; use arrow::compute::kernels::comparison::{ eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar, }; @@ -430,6 +443,23 @@ macro_rules! compute_utf8_op_scalar { }}; } +/// Invoke a compute kernel on a data array and a scalar value +macro_rules! compute_utf8_op_dyn_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + if let Some(string_value) = $RIGHT { + Ok(Arc::new(paste::expr! {[<$OP _dyn_utf8_scalar>]}( + $LEFT, + &string_value, + )?)) + } else { + Err(DataFusionError::Internal(format!( + "compute_utf8_op_scalar for '{}' failed with literal 'none' value", + stringify!($OP), + ))) + } + }}; +} + /// Invoke a compute kernel on a boolean data array and a scalar value macro_rules! compute_bool_op_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -447,6 +477,25 @@ macro_rules! 
compute_bool_op_scalar { }}; } +/// Invoke a compute kernel on a boolean data array and a scalar value +macro_rules! compute_bool_op_dyn_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + // generate the scalar function name, such as lt_dyn_bool_scalar, from the $OP parameter + // (which could have a value of lt) and the suffix _scalar + if let Some(b) = $RIGHT { + Ok(Arc::new(paste::expr! {[<$OP _dyn_bool_scalar>]}( + $LEFT, + b, + )?)) + } else { + Err(DataFusionError::Internal(format!( + "compute_utf8_op_scalar for '{}' failed with literal 'none' value", + stringify!($OP), + ))) + } + }}; +} + /// Invoke a bool compute kernel on array(s) macro_rules! compute_bool_op { // invoke binary operator @@ -475,7 +524,6 @@ macro_rules! compute_bool_op { /// LEFT is array, RIGHT is scalar value macro_rules! compute_op_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - use std::convert::TryInto; let ll = $LEFT .as_any() .downcast_ref::<$DT>() @@ -489,6 +537,26 @@ macro_rules! compute_op_scalar { }}; } +/// Invoke a dyn compute kernel on a data array and a scalar value +/// LEFT is Primitive or Dictionart array of numeric values, RIGHT is scalar value +macro_rules! compute_op_dyn_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + // generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter + // (which could have a value of lt_dyn) and the suffix _scalar + if let Some(value) = $RIGHT { + Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]}( + $LEFT, + value, + )?)) + } else { + Err(DataFusionError::Internal(format!( + "compute_utf8_op_scalar for '{}' failed with literal 'none' value", + stringify!($OP), + ))) + } + }}; +} + /// Invoke a compute kernel on array(s) macro_rules! compute_op { // invoke binary operator @@ -879,26 +947,64 @@ impl PhysicalExpr for BinaryExpr { } } +/// The binary_array_op_dyn_scalar macro includes types that extend beyond the primitive, +/// such as Utf8 strings. +#[macro_export] +macro_rules! binary_array_op_dyn_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + let result: Result> = match $RIGHT { + ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, $OP), + ScalarValue::Decimal128(..) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray), + ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP), + ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP), + ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP), + ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP), + ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP), + ScalarValue::Int64(v) => compute_op_dyn_scalar!($LEFT, v, $OP), + ScalarValue::UInt8(v) => compute_op_dyn_scalar!($LEFT, v, $OP), + ScalarValue::UInt16(v) => compute_op_dyn_scalar!($LEFT, v, $OP), + ScalarValue::UInt32(v) => compute_op_dyn_scalar!($LEFT, v, $OP), + ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP), + ScalarValue::Float32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array), + ScalarValue::Float64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array), + ScalarValue::Date32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array), + ScalarValue::Date64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array), + ScalarValue::TimestampSecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray), + ScalarValue::TimestampMillisecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMillisecondArray), + ScalarValue::TimestampMicrosecond(..) 
=> compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMicrosecondArray), + ScalarValue::TimestampNanosecond(..) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray), + other => Err(DataFusionError::Internal(format!("Data type {:?} not supported for scalar operation '{}' on dyn array", other, stringify!($OP)))) + }; + Some(result) + }} +} + impl BinaryExpr { /// Evaluate the expression of the left input is an array and /// right is literal - use scalar operations fn evaluate_array_scalar( &self, - array: &ArrayRef, + array: &dyn Array, scalar: &ScalarValue, ) -> Result>> { let scalar_result = match &self.op { - Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), lt), + Operator::Lt => { + binary_array_op_dyn_scalar!(array, scalar.clone(), lt) + } Operator::LtEq => { - binary_array_op_scalar!(array, scalar.clone(), lt_eq) + binary_array_op_dyn_scalar!(array, scalar.clone(), lt_eq) + } + Operator::Gt => { + binary_array_op_dyn_scalar!(array, scalar.clone(), gt) } - Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), gt), Operator::GtEq => { - binary_array_op_scalar!(array, scalar.clone(), gt_eq) + binary_array_op_dyn_scalar!(array, scalar.clone(), gt_eq) + } + Operator::Eq => { + binary_array_op_dyn_scalar!(array, scalar.clone(), eq) } - Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq), Operator::NotEq => { - binary_array_op_scalar!(array, scalar.clone(), neq) + binary_array_op_dyn_scalar!(array, scalar.clone(), neq) } Operator::Like => { binary_string_array_op_scalar!(array, scalar.clone(), like) From 251260849ebee7d5d87f8091a70b4d11ca5ae91a Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Mon, 31 Jan 2022 22:45:14 +0800 Subject: [PATCH 11/50] add annotation for select_to_plan (#1714) --- datafusion/src/sql/planner.rs | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 9da54cad4daa..a74c44665de1 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -697,14 +697,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - /// Generate a logic plan from an SQL select - fn select_to_plan( + /// Generate a logic plan from selection clause, the function contain optimization for cross join to inner join + /// Related PR: https://github.com/apache/arrow-datafusion/pull/1566 + fn plan_selection( &self, select: &Select, - ctes: &mut HashMap, - alias: Option, + plans: Vec, ) -> Result { - let plans = self.plan_from_tables(&select.from, ctes)?; let plan = match &select.selection { Some(predicate_expr) => { // build join schema @@ -822,9 +821,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } }; - let plan = plan?; + plan + } - // The SELECT expressions, with wildcards expanded. + /// Generate a logic plan from an SQL select + fn select_to_plan( + &self, + select: &Select, + ctes: &mut HashMap, + alias: Option, + ) -> Result { + // process `from` clause + let plans = self.plan_from_tables(&select.from, ctes)?; + + // process `where` clause + let plan = self.plan_selection(select, plans)?; + + // process the SELECT expressions, with wildcards expanded. let select_exprs = self.prepare_select_exprs(&plan, select)?; // having and group by clause may reference aliases defined in select projection @@ -873,6 +886,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // All of the aggregate expressions (deduplicated). 
let aggr_exprs = find_aggregate_exprs(&aggr_expr_haystack); + // All of the group by expressions let group_by_exprs = select .group_by .iter() @@ -891,6 +905,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect::>>()?; + // process group by, aggregation or having let (plan, select_exprs_post_aggr, having_expr_post_aggr_opt) = if !group_by_exprs .is_empty() || !aggr_exprs.is_empty() @@ -931,7 +946,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - // window function + // process window function let window_func_exprs = find_window_exprs(&select_exprs_post_aggr); let plan = if window_func_exprs.is_empty() { @@ -940,6 +955,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { LogicalPlanBuilder::window_plan(plan, window_func_exprs)? }; + // process distinct clause let plan = if select.distinct { return LogicalPlanBuilder::from(plan) .aggregate(select_exprs_post_aggr, iter::empty::())? @@ -947,6 +963,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { plan }; + + // generate the final projection plan project_with_alias(plan, select_exprs_post_aggr, alias) } From 1caf52ae311ac425704916fd92f7e275cb3be5cb Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 31 Jan 2022 09:46:59 -0500 Subject: [PATCH 12/50] Support `create_physical_expr` and `ExecutionContextState` or `DefaultPhysicalPlanner` for faster speed (#1700) * Change physical_expr creation API * Refactor API usage to avoid creating ExecutionContextState * Fixup ballista * clippy! --- .../src/serde/physical_plan/from_proto.rs | 3 +- datafusion/src/execution/context.rs | 40 +- .../src/optimizer/simplify_expressions.rs | 30 +- datafusion/src/physical_optimizer/pruning.rs | 13 +- datafusion/src/physical_plan/functions.rs | 34 +- datafusion/src/physical_plan/planner.rs | 935 +++++++++--------- 6 files changed, 525 insertions(+), 530 deletions(-) diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 0b3d50306063..5dd57d1b7079 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -622,7 +622,6 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { let ctx_state = ExecutionContextState { catalog_list, scalar_functions: Default::default(), - var_provider: Default::default(), aggregate_functions: Default::default(), config: ExecutionConfig::new(), execution_props: ExecutionProps::new(), @@ -632,7 +631,7 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { let fun_expr = functions::create_physical_fun( &(&scalar_function).into(), - &ctx_state, + &ctx_state.execution_props, )?; Arc::new(ScalarFunctionExpr::new( diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 6ed8223f0c52..ca86d0f0a019 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -190,7 +190,6 @@ impl ExecutionContext { state: Arc::new(Mutex::new(ExecutionContextState { catalog_list, scalar_functions: HashMap::new(), - var_provider: HashMap::new(), aggregate_functions: HashMap::new(), config, execution_props: ExecutionProps::new(), @@ -324,8 +323,8 @@ impl ExecutionContext { self.state .lock() .unwrap() - .var_provider - .insert(variable_type, provider); + .execution_props + .add_var_provider(variable_type, provider); } /// Registers a scalar UDF within this context. 
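
A minimal sketch of the refactored entry point introduced by this patch, mirroring the call sites in pruning.rs and simplify_expressions.rs: `create_physical_expr` becomes a free function that needs only an `ExecutionProps` (which now also carries any registered variable providers via `add_var_provider`), so callers no longer construct a full `ExecutionContextState`. The schema, column name, and literal below are illustrative assumptions, not part of the diff.

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::execution::context::ExecutionProps;
use datafusion::logical_plan::{col, lit, DFSchema};
use datafusion::physical_plan::planner::create_physical_expr;

fn build_physical_expr_sketch() -> Result<()> {
    // Hypothetical single-column schema used only for illustration
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let df_schema = DFSchema::try_from(schema.clone())?;

    // Logical predicate `a < 5`
    let expr = col("a").lt(lit(5));

    // With this change, a lightweight ExecutionProps is sufficient;
    // no ExecutionContextState or DefaultPhysicalPlanner is required.
    let execution_props = ExecutionProps::new();
    let _physical_expr =
        create_physical_expr(&expr, &df_schema, &schema, &execution_props)?;
    Ok(())
}
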
@@ -1115,9 +1114,14 @@ impl ExecutionConfig { /// An instance of this struct is created each time a [`LogicalPlan`] is prepared for /// execution (optimized). If the same plan is optimized multiple times, a new /// `ExecutionProps` is created each time. +/// +/// It is important that this structure be cheap to create as it is +/// done so during predicate pruning and expression simplification #[derive(Clone)] pub struct ExecutionProps { pub(crate) query_execution_start_time: DateTime, + /// providers for scalar variables + pub var_providers: Option>>, } impl Default for ExecutionProps { @@ -1131,6 +1135,7 @@ impl ExecutionProps { pub fn new() -> Self { ExecutionProps { query_execution_start_time: chrono::Utc::now(), + var_providers: None, } } @@ -1139,6 +1144,32 @@ impl ExecutionProps { self.query_execution_start_time = chrono::Utc::now(); &*self } + + /// Registers a variable provider, returning the existing + /// provider, if any + pub fn add_var_provider( + &mut self, + var_type: VarType, + provider: Arc, + ) -> Option> { + let mut var_providers = self.var_providers.take().unwrap_or_else(HashMap::new); + + let old_provider = var_providers.insert(var_type, provider); + + self.var_providers = Some(var_providers); + + old_provider + } + + /// Returns the provider for the var_type, if any + pub fn get_var_provider( + &self, + var_type: VarType, + ) -> Option> { + self.var_providers + .as_ref() + .and_then(|var_providers| var_providers.get(&var_type).map(Arc::clone)) + } } /// Execution context for registering data sources and executing queries @@ -1148,8 +1179,6 @@ pub struct ExecutionContextState { pub catalog_list: Arc, /// Scalar functions that are registered with the context pub scalar_functions: HashMap>, - /// Variable provider that are registered with the context - pub var_provider: HashMap>, /// Aggregate functions registered in the context pub aggregate_functions: HashMap>, /// Context configuration @@ -1174,7 +1203,6 @@ impl ExecutionContextState { ExecutionContextState { catalog_list: Arc::new(MemoryCatalogList::new()), scalar_functions: HashMap::new(), - var_provider: HashMap::new(), aggregate_functions: HashMap::new(), config: ExecutionConfig::new(), execution_props: ExecutionProps::new(), diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 7127a8fa94d6..6f5235e852b7 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -22,13 +22,13 @@ use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use crate::error::DataFusionError; -use crate::execution::context::{ExecutionContextState, ExecutionProps}; +use crate::execution::context::ExecutionProps; use crate::logical_plan::{lit, DFSchemaRef, Expr}; use crate::logical_plan::{DFSchema, ExprRewriter, LogicalPlan, RewriteRecursion}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::physical_plan::functions::Volatility; -use crate::physical_plan::planner::DefaultPhysicalPlanner; +use crate::physical_plan::planner::create_physical_expr; use crate::scalar::ScalarValue; use crate::{error::Result, logical_plan::Operator}; @@ -223,7 +223,7 @@ impl SimplifyExpressions { /// let rewritten = expr.rewrite(&mut const_evaluator).unwrap(); /// assert_eq!(rewritten, lit(3) + col("a")); /// ``` -pub struct ConstEvaluator { +pub struct ConstEvaluator<'a> { /// can_evaluate is used during the depth-first-search of the /// Expr tree to track if any 
siblings (or their descendants) were /// non evaluatable (e.g. had a column reference or volatile @@ -238,13 +238,12 @@ pub struct ConstEvaluator { /// descendants) so this Expr can be evaluated can_evaluate: Vec, - ctx_state: ExecutionContextState, - planner: DefaultPhysicalPlanner, + execution_props: &'a ExecutionProps, input_schema: DFSchema, input_batch: RecordBatch, } -impl ExprRewriter for ConstEvaluator { +impl<'a> ExprRewriter for ConstEvaluator<'a> { fn pre_visit(&mut self, expr: &Expr) -> Result { // Default to being able to evaluate this node self.can_evaluate.push(true); @@ -282,16 +281,11 @@ impl ExprRewriter for ConstEvaluator { } } -impl ConstEvaluator { +impl<'a> ConstEvaluator<'a> { /// Create a new `ConstantEvaluator`. Session constants (such as /// the time for `now()` are taken from the passed /// `execution_props`. - pub fn new(execution_props: &ExecutionProps) -> Self { - let planner = DefaultPhysicalPlanner::default(); - let ctx_state = ExecutionContextState { - execution_props: execution_props.clone(), - ..ExecutionContextState::new() - }; + pub fn new(execution_props: &'a ExecutionProps) -> Self { let input_schema = DFSchema::empty(); // The dummy column name is unused and doesn't matter as only @@ -306,8 +300,7 @@ impl ConstEvaluator { Self { can_evaluate: vec![], - ctx_state, - planner, + execution_props, input_schema, input_batch, } @@ -364,11 +357,11 @@ impl ConstEvaluator { return Ok(s); } - let phys_expr = self.planner.create_physical_expr( + let phys_expr = create_physical_expr( &expr, &self.input_schema, &self.input_batch.schema(), - &self.ctx_state, + self.execution_props, )?; let col_val = phys_expr.evaluate(&self.input_batch)?; match col_val { @@ -1141,6 +1134,7 @@ mod tests { ) { let execution_props = ExecutionProps { query_execution_start_time: *date_time, + var_providers: None, }; let mut const_evaluator = ConstEvaluator::new(&execution_props); @@ -1622,6 +1616,7 @@ mod tests { let rule = SimplifyExpressions::new(); let execution_props = ExecutionProps { query_execution_start_time: *date_time, + var_providers: None, }; let err = rule @@ -1638,6 +1633,7 @@ mod tests { let rule = SimplifyExpressions::new(); let execution_props = ExecutionProps { query_execution_start_time: *date_time, + var_providers: None, }; let optimized_plan = rule diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs index 22b854b93a59..7bbffd1546fd 100644 --- a/datafusion/src/physical_optimizer/pruning.rs +++ b/datafusion/src/physical_optimizer/pruning.rs @@ -37,13 +37,14 @@ use arrow::{ record_batch::RecordBatch, }; +use crate::execution::context::ExecutionProps; +use crate::physical_plan::planner::create_physical_expr; use crate::prelude::lit; use crate::{ error::{DataFusionError, Result}, - execution::context::ExecutionContextState, logical_plan::{Column, DFSchema, Expr, Operator}, optimizer::utils, - physical_plan::{planner::DefaultPhysicalPlanner, ColumnarValue, PhysicalExpr}, + physical_plan::{ColumnarValue, PhysicalExpr}, }; /// Interface to pass statistics information to [`PruningPredicates`] @@ -129,12 +130,14 @@ impl PruningPredicate { .collect::>(); let stat_schema = Schema::new(stat_fields); let stat_dfschema = DFSchema::try_from(stat_schema.clone())?; - let execution_context_state = ExecutionContextState::new(); - let predicate_expr = DefaultPhysicalPlanner::default().create_physical_expr( + + // TODO allow these properties to be passed in + let execution_props = ExecutionProps::new(); + let predicate_expr = 
create_physical_expr( &logical_predicate_expr, &stat_dfschema, &stat_schema, - &execution_context_state, + &execution_props, )?; Ok(Self { schema, diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 2c1946e9da37..644defce1545 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -33,7 +33,7 @@ use super::{ type_coercion::{coerce, data_types}, ColumnarValue, PhysicalExpr, }; -use crate::execution::context::ExecutionContextState; +use crate::execution::context::ExecutionProps; use crate::physical_plan::array_expressions; use crate::physical_plan::datetime_expressions; use crate::physical_plan::expressions::{ @@ -723,7 +723,7 @@ macro_rules! invoke_if_unicode_expressions_feature_flag { /// Create a physical scalar function. pub fn create_physical_fun( fun: &BuiltinScalarFunction, - ctx_state: &ExecutionContextState, + execution_props: &ExecutionProps, ) -> Result { Ok(match fun { // math functions @@ -820,7 +820,7 @@ pub fn create_physical_fun( BuiltinScalarFunction::Now => { // bind value for now at plan time Arc::new(datetime_expressions::make_now( - ctx_state.execution_props.query_execution_start_time, + execution_props.query_execution_start_time, )) } BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { @@ -1157,7 +1157,7 @@ pub fn create_physical_expr( fun: &BuiltinScalarFunction, input_phy_exprs: &[Arc], input_schema: &Schema, - ctx_state: &ExecutionContextState, + execution_props: &ExecutionProps, ) -> Result> { let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &signature(fun))?; @@ -1254,7 +1254,7 @@ pub fn create_physical_expr( } }), // These don't need args and input schema - _ => create_physical_fun(fun, ctx_state)?, + _ => create_physical_fun(fun, execution_props)?, }; Ok(Arc::new(ScalarFunctionExpr::new( @@ -1720,14 +1720,14 @@ mod tests { ($FUNC:ident, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $DATA_TYPE: ident, $ARRAY_TYPE:ident) => { // used to provide type annotation let expected: Result> = $EXPECTED; - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); // any type works here: we evaluate against a literal of `value` let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let columns: Vec = vec![Arc::new(Int32Array::from_slice(&[1]))]; let expr = - create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &ctx_state)?; + create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &execution_props)?; // type is correct assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE); @@ -3888,7 +3888,7 @@ mod tests { #[test] fn test_empty_arguments_error() -> Result<()> { - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); // pick some arbitrary functions to test @@ -3900,7 +3900,7 @@ mod tests { ]; for fun in funs.iter() { - let expr = create_physical_expr(fun, &[], &schema, &ctx_state); + let expr = create_physical_expr(fun, &[], &schema, &execution_props); match expr { Ok(..) 
=> { @@ -3931,13 +3931,13 @@ mod tests { #[test] fn test_empty_arguments() -> Result<()> { - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let funs = [BuiltinScalarFunction::Now, BuiltinScalarFunction::Random]; for fun in funs.iter() { - create_physical_expr(fun, &[], &schema, &ctx_state)?; + create_physical_expr(fun, &[], &schema, &execution_props)?; } Ok(()) } @@ -3954,13 +3954,13 @@ mod tests { Field::new("b", value2.data_type().clone(), false), ]); let columns: Vec = vec![value1, value2]; - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); let expr = create_physical_expr( &BuiltinScalarFunction::Array, &[col("a", &schema)?, col("b", &schema)?], &schema, - &ctx_state, + &execution_props, )?; // type is correct @@ -4017,7 +4017,7 @@ mod tests { fn test_regexp_match() -> Result<()> { use arrow::array::ListArray; let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); let col_value: ArrayRef = Arc::new(StringArray::from_slice(&["aaa-555"])); let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); @@ -4026,7 +4026,7 @@ mod tests { &BuiltinScalarFunction::RegexpMatch, &[col("a", &schema)?, pattern], &schema, - &ctx_state, + &execution_props, )?; // type is correct @@ -4056,7 +4056,7 @@ mod tests { fn test_regexp_match_all_literals() -> Result<()> { use arrow::array::ListArray; let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let ctx_state = ExecutionContextState::new(); + let execution_props = ExecutionProps::new(); let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string()))); let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); @@ -4065,7 +4065,7 @@ mod tests { &BuiltinScalarFunction::RegexpMatch, &[col_value, pattern], &schema, - &ctx_state, + &execution_props, )?; // type is correct diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 226e3f392497..bf8be3df720b 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -22,7 +22,7 @@ use super::{ aggregates, empty::EmptyExec, expressions::binary, functions, hash_join::PartitionMode, udaf, union::UnionExec, values::ValuesExec, windows, }; -use crate::execution::context::ExecutionContextState; +use crate::execution::context::{ExecutionContextState, ExecutionProps}; use crate::logical_plan::plan::{ Aggregate, EmptyRelation, Filter, Join, Projection, Sort, TableScan, Window, }; @@ -299,12 +299,11 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { input_schema: &Schema, ctx_state: &ExecutionContextState, ) -> Result> { - DefaultPhysicalPlanner::create_physical_expr( - self, + create_physical_expr( expr, input_dfschema, input_schema, - ctx_state, + &ctx_state.execution_props, ) } } @@ -440,7 +439,7 @@ impl DefaultPhysicalPlanner { expr, asc, nulls_first, - } => self.create_physical_sort_expr( + } => create_physical_sort_expr( expr, logical_input_schema, &physical_input_schema, @@ -448,7 +447,7 @@ impl DefaultPhysicalPlanner { descending: !*asc, nulls_first: *nulls_first, }, - ctx_state, + &ctx_state.execution_props, ), _ => unreachable!(), }) @@ -464,11 +463,11 @@ impl DefaultPhysicalPlanner { let window_expr = window_expr .iter() .map(|e| { - self.create_window_expr( + create_window_expr( e, logical_input_schema, 
&physical_input_schema, - ctx_state, + &ctx_state.execution_props, ) }) .collect::>>()?; @@ -507,11 +506,11 @@ impl DefaultPhysicalPlanner { let aggregates = aggr_expr .iter() .map(|e| { - self.create_aggregate_expr( + create_aggregate_expr( e, logical_input_schema, &physical_input_schema, - ctx_state, + &ctx_state.execution_props, ) }) .collect::>>()?; @@ -688,7 +687,7 @@ impl DefaultPhysicalPlanner { expr, asc, nulls_first, - } => self.create_physical_sort_expr( + } => create_physical_sort_expr( expr, input_dfschema, &input_schema, @@ -696,7 +695,7 @@ impl DefaultPhysicalPlanner { descending: !*asc, nulls_first: *nulls_first, }, - ctx_state, + &ctx_state.execution_props, ), _ => Err(DataFusionError::Plan( "Sort only accepts sort expressions".to_string(), @@ -866,517 +865,487 @@ impl DefaultPhysicalPlanner { exec_plan }.boxed() } +} - /// Create a physical expression from a logical expression - pub fn create_physical_expr( - &self, - e: &Expr, - input_dfschema: &DFSchema, - input_schema: &Schema, - ctx_state: &ExecutionContextState, - ) -> Result> { - match e { - Expr::Alias(expr, ..) => Ok(self.create_physical_expr( - expr, - input_dfschema, - input_schema, - ctx_state, - )?), - Expr::Column(c) => { - let idx = input_dfschema.index_of_column(c)?; - Ok(Arc::new(Column::new(&c.name, idx))) - } - Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), - Expr::ScalarVariable(variable_names) => { - if &variable_names[0][0..2] == "@@" { - match ctx_state.var_provider.get(&VarType::System) { - Some(provider) => { - let scalar_value = - provider.get_value(variable_names.clone())?; - Ok(Arc::new(Literal::new(scalar_value))) - } - _ => Err(DataFusionError::Plan( - "No system variable provider found".to_string(), - )), +/// Create a physical expression from a logical expression ([Expr]) +pub fn create_physical_expr( + e: &Expr, + input_dfschema: &DFSchema, + input_schema: &Schema, + execution_props: &ExecutionProps, +) -> Result> { + match e { + Expr::Alias(expr, ..) 
=> Ok(create_physical_expr( + expr, + input_dfschema, + input_schema, + execution_props, + )?), + Expr::Column(c) => { + let idx = input_dfschema.index_of_column(c)?; + Ok(Arc::new(Column::new(&c.name, idx))) + } + Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), + Expr::ScalarVariable(variable_names) => { + if &variable_names[0][0..2] == "@@" { + match execution_props.get_var_provider(VarType::System) { + Some(provider) => { + let scalar_value = provider.get_value(variable_names.clone())?; + Ok(Arc::new(Literal::new(scalar_value))) } - } else { - match ctx_state.var_provider.get(&VarType::UserDefined) { - Some(provider) => { - let scalar_value = - provider.get_value(variable_names.clone())?; - Ok(Arc::new(Literal::new(scalar_value))) - } - _ => Err(DataFusionError::Plan( - "No user defined variable provider found".to_string(), - )), + _ => Err(DataFusionError::Plan( + "No system variable provider found".to_string(), + )), + } + } else { + match execution_props.get_var_provider(VarType::UserDefined) { + Some(provider) => { + let scalar_value = provider.get_value(variable_names.clone())?; + Ok(Arc::new(Literal::new(scalar_value))) } + _ => Err(DataFusionError::Plan( + "No user defined variable provider found".to_string(), + )), } } - Expr::BinaryExpr { left, op, right } => { - let lhs = self.create_physical_expr( - left, - input_dfschema, - input_schema, - ctx_state, - )?; - let rhs = self.create_physical_expr( - right, + } + Expr::BinaryExpr { left, op, right } => { + let lhs = create_physical_expr( + left, + input_dfschema, + input_schema, + execution_props, + )?; + let rhs = create_physical_expr( + right, + input_dfschema, + input_schema, + execution_props, + )?; + binary(lhs, *op, rhs, input_schema) + } + Expr::Case { + expr, + when_then_expr, + else_expr, + .. + } => { + let expr: Option> = if let Some(e) = expr { + Some(create_physical_expr( + e.as_ref(), input_dfschema, input_schema, - ctx_state, - )?; - binary(lhs, *op, rhs, input_schema) - } - Expr::Case { - expr, - when_then_expr, - else_expr, - .. - } => { - let expr: Option> = if let Some(e) = expr { - Some(self.create_physical_expr( - e.as_ref(), + execution_props, + )?) + } else { + None + }; + let when_expr = when_then_expr + .iter() + .map(|(w, _)| { + create_physical_expr( + w.as_ref(), input_dfschema, input_schema, - ctx_state, - )?) - } else { - None - }; - let when_expr = when_then_expr - .iter() - .map(|(w, _)| { - self.create_physical_expr( - w.as_ref(), - input_dfschema, - input_schema, - ctx_state, - ) - }) - .collect::>>()?; - let then_expr = when_then_expr - .iter() - .map(|(_, t)| { - self.create_physical_expr( - t.as_ref(), - input_dfschema, - input_schema, - ctx_state, - ) - }) - .collect::>>()?; - let when_then_expr: Vec<(Arc, Arc)> = - when_expr - .iter() - .zip(then_expr.iter()) - .map(|(w, t)| (w.clone(), t.clone())) - .collect(); - let else_expr: Option> = if let Some(e) = else_expr - { - Some(self.create_physical_expr( - e.as_ref(), + execution_props, + ) + }) + .collect::>>()?; + let then_expr = when_then_expr + .iter() + .map(|(_, t)| { + create_physical_expr( + t.as_ref(), input_dfschema, input_schema, - ctx_state, - )?) 
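// Illustrative sketch (not part of this patch): with the planner helpers
// turned into free functions, a physical expression can be built from a
// logical one with only an ExecutionProps, instead of constructing a
// DefaultPhysicalPlanner plus an ExecutionContextState. The paths mirror the
// `use crate::physical_plan::planner::create_physical_expr` import added to
// pruning.rs above; the schema and predicate are made-up examples.
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::execution::context::ExecutionProps;
use datafusion::logical_plan::{col, lit, DFSchema};
use datafusion::physical_plan::planner::create_physical_expr;

fn plan_simple_predicate() -> datafusion::error::Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let df_schema = DFSchema::try_from(schema.clone())?;
    let predicate = col("a").gt(lit(5));          // logical: a > 5
    let execution_props = ExecutionProps::new();  // session constants, e.g. for now()
    let _physical = create_physical_expr(&predicate, &df_schema, &schema, &execution_props)?;
    Ok(())
}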
- } else { - None - }; - Ok(Arc::new(CaseExpr::try_new( - expr, - &when_then_expr, - else_expr, - )?)) - } - Expr::Cast { expr, data_type } => expressions::cast( - self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, - input_schema, - data_type.clone(), - ), - Expr::TryCast { expr, data_type } => expressions::try_cast( - self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, - input_schema, - data_type.clone(), - ), - Expr::Not(expr) => expressions::not( - self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, - input_schema, - ), - Expr::Negative(expr) => expressions::negative( - self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, - input_schema, - ), - Expr::IsNull(expr) => expressions::is_null(self.create_physical_expr( + execution_props, + ) + }) + .collect::>>()?; + let when_then_expr: Vec<(Arc, Arc)> = + when_expr + .iter() + .zip(then_expr.iter()) + .map(|(w, t)| (w.clone(), t.clone())) + .collect(); + let else_expr: Option> = if let Some(e) = else_expr { + Some(create_physical_expr( + e.as_ref(), + input_dfschema, + input_schema, + execution_props, + )?) + } else { + None + }; + Ok(Arc::new(CaseExpr::try_new( expr, - input_dfschema, + &when_then_expr, + else_expr, + )?)) + } + Expr::Cast { expr, data_type } => expressions::cast( + create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + input_schema, + data_type.clone(), + ), + Expr::TryCast { expr, data_type } => expressions::try_cast( + create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + input_schema, + data_type.clone(), + ), + Expr::Not(expr) => expressions::not( + create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + input_schema, + ), + Expr::Negative(expr) => expressions::negative( + create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + input_schema, + ), + Expr::IsNull(expr) => expressions::is_null(create_physical_expr( + expr, + input_dfschema, + input_schema, + execution_props, + )?), + Expr::IsNotNull(expr) => expressions::is_not_null(create_physical_expr( + expr, + input_dfschema, + input_schema, + execution_props, + )?), + Expr::GetIndexedField { expr, key } => Ok(Arc::new(GetIndexedFieldExpr::new( + create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, + key.clone(), + ))), + + Expr::ScalarFunction { fun, args } => { + let physical_args = args + .iter() + .map(|e| { + create_physical_expr(e, input_dfschema, input_schema, execution_props) + }) + .collect::>>()?; + functions::create_physical_expr( + fun, + &physical_args, input_schema, - ctx_state, - )?), - Expr::IsNotNull(expr) => expressions::is_not_null( - self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, - ), - Expr::GetIndexedField { expr, key } => { - Ok(Arc::new(GetIndexedFieldExpr::new( - self.create_physical_expr( - expr, - input_dfschema, - input_schema, - ctx_state, - )?, - key.clone(), - ))) - } - - Expr::ScalarFunction { fun, args } => { - let physical_args = args - .iter() - .map(|e| { - self.create_physical_expr( - e, - input_dfschema, - input_schema, - ctx_state, - ) - }) - .collect::>>()?; - functions::create_physical_expr( - fun, - &physical_args, + execution_props, + ) + } + Expr::ScalarUDF { fun, args } => { + let mut physical_args = vec![]; + for e in args { + physical_args.push(create_physical_expr( + e, + input_dfschema, input_schema, - ctx_state, - ) + execution_props, + )?); } - Expr::ScalarUDF { fun, args } => { - let mut 
physical_args = vec![]; - for e in args { - physical_args.push(self.create_physical_expr( - e, - input_dfschema, - input_schema, - ctx_state, - )?); - } - udf::create_physical_expr( - fun.clone().as_ref(), - &physical_args, - input_schema, - ) - } - Expr::Between { + udf::create_physical_expr(fun.clone().as_ref(), &physical_args, input_schema) + } + Expr::Between { + expr, + negated, + low, + high, + } => { + let value_expr = create_physical_expr( expr, - negated, - low, + input_dfschema, + input_schema, + execution_props, + )?; + let low_expr = + create_physical_expr(low, input_dfschema, input_schema, execution_props)?; + let high_expr = create_physical_expr( high, - } => { - let value_expr = self.create_physical_expr( + input_dfschema, + input_schema, + execution_props, + )?; + + // rewrite the between into the two binary operators + let binary_expr = binary( + binary(value_expr.clone(), Operator::GtEq, low_expr, input_schema)?, + Operator::And, + binary(value_expr.clone(), Operator::LtEq, high_expr, input_schema)?, + input_schema, + ); + + if *negated { + expressions::not(binary_expr?, input_schema) + } else { + binary_expr + } + } + Expr::InList { + expr, + list, + negated, + } => match expr.as_ref() { + Expr::Literal(ScalarValue::Utf8(None)) => { + Ok(expressions::lit(ScalarValue::Boolean(None))) + } + _ => { + let value_expr = create_physical_expr( expr, input_dfschema, input_schema, - ctx_state, + execution_props, )?; - let low_expr = self.create_physical_expr( - low, - input_dfschema, - input_schema, - ctx_state, - )?; - let high_expr = self.create_physical_expr( - high, - input_dfschema, - input_schema, - ctx_state, - )?; - - // rewrite the between into the two binary operators - let binary_expr = binary( - binary(value_expr.clone(), Operator::GtEq, low_expr, input_schema)?, - Operator::And, - binary(value_expr.clone(), Operator::LtEq, high_expr, input_schema)?, - input_schema, - ); - - if *negated { - expressions::not(binary_expr?, input_schema) - } else { - binary_expr - } - } - Expr::InList { - expr, - list, - negated, - } => match expr.as_ref() { - Expr::Literal(ScalarValue::Utf8(None)) => { - Ok(expressions::lit(ScalarValue::Boolean(None))) - } - _ => { - let value_expr = self.create_physical_expr( - expr, - input_dfschema, - input_schema, - ctx_state, - )?; - let value_expr_data_type = value_expr.data_type(input_schema)?; + let value_expr_data_type = value_expr.data_type(input_schema)?; - let list_exprs = list - .iter() - .map(|expr| match expr { - Expr::Literal(ScalarValue::Utf8(None)) => self - .create_physical_expr( - expr, - input_dfschema, - input_schema, - ctx_state, - ), - _ => { - let list_expr = self.create_physical_expr( - expr, - input_dfschema, - input_schema, - ctx_state, - )?; - let list_expr_data_type = - list_expr.data_type(input_schema)?; - - if list_expr_data_type == value_expr_data_type { - Ok(list_expr) - } else if can_cast_types( - &list_expr_data_type, - &value_expr_data_type, - ) { - expressions::cast( - list_expr, - input_schema, - value_expr.data_type(input_schema)?, - ) - } else { - Err(DataFusionError::Plan(format!( - "Unsupported CAST from {:?} to {:?}", - list_expr_data_type, value_expr_data_type - ))) - } - } - }) - .collect::>>()?; - - expressions::in_list(value_expr, list_exprs, negated) - } - }, - other => Err(DataFusionError::NotImplemented(format!( - "Physical plan does not support logical expression {:?}", - other - ))), - } - } - - /// Create a window expression with a name from a logical expression - pub fn create_window_expr_with_name( 
- &self, - e: &Expr, - name: impl Into, - logical_input_schema: &DFSchema, - physical_input_schema: &Schema, - ctx_state: &ExecutionContextState, - ) -> Result> { - let name = name.into(); - match e { - Expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - } => { - let args = args + let list_exprs = list .iter() - .map(|e| { - self.create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - ctx_state, - ) - }) - .collect::>>()?; - let partition_by = partition_by - .iter() - .map(|e| { - self.create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - ctx_state, - ) - }) - .collect::>>()?; - let order_by = order_by - .iter() - .map(|e| match e { - Expr::Sort { + .map(|expr| match expr { + Expr::Literal(ScalarValue::Utf8(None)) => create_physical_expr( expr, - asc, - nulls_first, - } => self.create_physical_sort_expr( - expr, - logical_input_schema, - physical_input_schema, - SortOptions { - descending: !*asc, - nulls_first: *nulls_first, - }, - ctx_state, + input_dfschema, + input_schema, + execution_props, ), - _ => Err(DataFusionError::Plan( - "Sort only accepts sort expressions".to_string(), - )), + _ => { + let list_expr = create_physical_expr( + expr, + input_dfschema, + input_schema, + execution_props, + )?; + let list_expr_data_type = + list_expr.data_type(input_schema)?; + + if list_expr_data_type == value_expr_data_type { + Ok(list_expr) + } else if can_cast_types( + &list_expr_data_type, + &value_expr_data_type, + ) { + expressions::cast( + list_expr, + input_schema, + value_expr.data_type(input_schema)?, + ) + } else { + Err(DataFusionError::Plan(format!( + "Unsupported CAST from {:?} to {:?}", + list_expr_data_type, value_expr_data_type + ))) + } + } }) .collect::>>()?; - if window_frame.is_some() { - return Err(DataFusionError::NotImplemented( - "window expression with window frame definition is not yet supported" - .to_owned(), - )); - } - windows::create_window_expr( - fun, - name, - &args, - &partition_by, - &order_by, - *window_frame, - physical_input_schema, - ) + + expressions::in_list(value_expr, list_exprs, negated) } - other => Err(DataFusionError::Internal(format!( - "Invalid window expression '{:?}'", - other - ))), - } + }, + other => Err(DataFusionError::NotImplemented(format!( + "Physical plan does not support logical expression {:?}", + other + ))), } +} - /// Create a window expression from a logical expression or an alias - pub fn create_window_expr( - &self, - e: &Expr, - logical_input_schema: &DFSchema, - physical_input_schema: &Schema, - ctx_state: &ExecutionContextState, - ) -> Result> { - // unpack aliased logical expressions, e.g. 
"sum(col) over () as total" - let (name, e) = match e { - Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), - _ => (physical_name(e)?, e), - }; - self.create_window_expr_with_name( - e, - name, - logical_input_schema, - physical_input_schema, - ctx_state, - ) +/// Create a window expression with a name from a logical expression +pub fn create_window_expr_with_name( + e: &Expr, + name: impl Into, + logical_input_schema: &DFSchema, + physical_input_schema: &Schema, + execution_props: &ExecutionProps, +) -> Result> { + let name = name.into(); + match e { + Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + } => { + let args = args + .iter() + .map(|e| { + create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?; + let partition_by = partition_by + .iter() + .map(|e| { + create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?; + let order_by = order_by + .iter() + .map(|e| match e { + Expr::Sort { + expr, + asc, + nulls_first, + } => create_physical_sort_expr( + expr, + logical_input_schema, + physical_input_schema, + SortOptions { + descending: !*asc, + nulls_first: *nulls_first, + }, + execution_props, + ), + _ => Err(DataFusionError::Plan( + "Sort only accepts sort expressions".to_string(), + )), + }) + .collect::>>()?; + if window_frame.is_some() { + return Err(DataFusionError::NotImplemented( + "window expression with window frame definition is not yet supported" + .to_owned(), + )); + } + windows::create_window_expr( + fun, + name, + &args, + &partition_by, + &order_by, + *window_frame, + physical_input_schema, + ) + } + other => Err(DataFusionError::Internal(format!( + "Invalid window expression '{:?}'", + other + ))), } +} - /// Create an aggregate expression with a name from a logical expression - pub fn create_aggregate_expr_with_name( - &self, - e: &Expr, - name: impl Into, - logical_input_schema: &DFSchema, - physical_input_schema: &Schema, - ctx_state: &ExecutionContextState, - ) -> Result> { - match e { - Expr::AggregateFunction { +/// Create a window expression from a logical expression or an alias +pub fn create_window_expr( + e: &Expr, + logical_input_schema: &DFSchema, + physical_input_schema: &Schema, + execution_props: &ExecutionProps, +) -> Result> { + // unpack aliased logical expressions, e.g. "sum(col) over () as total" + let (name, e) = match e { + Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), + _ => (physical_name(e)?, e), + }; + create_window_expr_with_name( + e, + name, + logical_input_schema, + physical_input_schema, + execution_props, + ) +} + +/// Create an aggregate expression with a name from a logical expression +pub fn create_aggregate_expr_with_name( + e: &Expr, + name: impl Into, + logical_input_schema: &DFSchema, + physical_input_schema: &Schema, + execution_props: &ExecutionProps, +) -> Result> { + match e { + Expr::AggregateFunction { + fun, + distinct, + args, + .. + } => { + let args = args + .iter() + .map(|e| { + create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?; + aggregates::create_aggregate_expr( fun, - distinct, - args, - .. 
- } => { - let args = args - .iter() - .map(|e| { - self.create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - ctx_state, - ) - }) - .collect::>>()?; - aggregates::create_aggregate_expr( - fun, - *distinct, - &args, - physical_input_schema, - name, - ) - } - Expr::AggregateUDF { fun, args, .. } => { - let args = args - .iter() - .map(|e| { - self.create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - ctx_state, - ) - }) - .collect::>>()?; + *distinct, + &args, + physical_input_schema, + name, + ) + } + Expr::AggregateUDF { fun, args, .. } => { + let args = args + .iter() + .map(|e| { + create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?; - udaf::create_aggregate_expr(fun, &args, physical_input_schema, name) - } - other => Err(DataFusionError::Internal(format!( - "Invalid aggregate expression '{:?}'", - other - ))), + udaf::create_aggregate_expr(fun, &args, physical_input_schema, name) } + other => Err(DataFusionError::Internal(format!( + "Invalid aggregate expression '{:?}'", + other + ))), } +} - /// Create an aggregate expression from a logical expression or an alias - pub fn create_aggregate_expr( - &self, - e: &Expr, - logical_input_schema: &DFSchema, - physical_input_schema: &Schema, - ctx_state: &ExecutionContextState, - ) -> Result> { - // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" - let (name, e) = match e { - Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), - _ => (physical_name(e)?, e), - }; - - self.create_aggregate_expr_with_name( - e, - name, - logical_input_schema, - physical_input_schema, - ctx_state, - ) - } +/// Create an aggregate expression from a logical expression or an alias +pub fn create_aggregate_expr( + e: &Expr, + logical_input_schema: &DFSchema, + physical_input_schema: &Schema, + execution_props: &ExecutionProps, +) -> Result> { + // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" + let (name, e) = match e { + Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), + _ => (physical_name(e)?, e), + }; - /// Create a physical sort expression from a logical expression - pub fn create_physical_sort_expr( - &self, - e: &Expr, - input_dfschema: &DFSchema, - input_schema: &Schema, - options: SortOptions, - ctx_state: &ExecutionContextState, - ) -> Result { - Ok(PhysicalSortExpr { - expr: self.create_physical_expr( - e, - input_dfschema, - input_schema, - ctx_state, - )?, - options, - }) - } + create_aggregate_expr_with_name( + e, + name, + logical_input_schema, + physical_input_schema, + execution_props, + ) +} + +/// Create a physical sort expression from a logical expression +pub fn create_physical_sort_expr( + e: &Expr, + input_dfschema: &DFSchema, + input_schema: &Schema, + options: SortOptions, + execution_props: &ExecutionProps, +) -> Result { + Ok(PhysicalSortExpr { + expr: create_physical_expr(e, input_dfschema, input_schema, execution_props)?, + options, + }) +} +impl DefaultPhysicalPlanner { /// Handles capturing the various plans for EXPLAIN queries /// /// Returns From f849968057ddddccc9aa19915ef3ea56bf14d80d Mon Sep 17 00:00:00 2001 From: Yang <37145547+Ted-Jiang@users.noreply.github.com> Date: Mon, 31 Jan 2022 22:48:37 +0800 Subject: [PATCH 13/50] Fix can not load parquet table form spark in datafusion-cli. (#1665) * fix can not load parquet table form spark * add Invalid file in log. 
* fix fmt --- benchmarks/src/bin/tpch.rs | 6 ++-- .../examples/parquet_sql_multiple_files.rs | 6 ++-- datafusion/src/datasource/file_format/avro.rs | 2 ++ datafusion/src/datasource/file_format/csv.rs | 2 ++ datafusion/src/datasource/file_format/json.rs | 2 ++ datafusion/src/datasource/listing/table.rs | 6 ++-- datafusion/src/execution/context.rs | 31 ++++++++++--------- datafusion/src/execution/options.rs | 3 +- .../src/physical_plan/file_format/parquet.rs | 9 ++++-- 9 files changed, 43 insertions(+), 24 deletions(-) diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 0b9fba52140b..59bb55162a8e 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -54,6 +54,8 @@ use datafusion::{ }, }; +use datafusion::datasource::file_format::csv::DEFAULT_CSV_EXTENSION; +use datafusion::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION; use structopt::StructOpt; #[cfg(feature = "snmalloc")] @@ -652,13 +654,13 @@ fn get_table( .with_delimiter(b',') .with_has_header(true); - (Arc::new(format), path, ".csv") + (Arc::new(format), path, DEFAULT_CSV_EXTENSION) } "parquet" => { let path = format!("{}/{}", path, table); let format = ParquetFormat::default().with_enable_pruning(true); - (Arc::new(format), path, ".parquet") + (Arc::new(format), path, DEFAULT_PARQUET_EXTENSION) } other => { unimplemented!("Invalid file format '{}'", other); diff --git a/datafusion-examples/examples/parquet_sql_multiple_files.rs b/datafusion-examples/examples/parquet_sql_multiple_files.rs index 2e954276083e..7485bc72f193 100644 --- a/datafusion-examples/examples/parquet_sql_multiple_files.rs +++ b/datafusion-examples/examples/parquet_sql_multiple_files.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::parquet::{ + ParquetFormat, DEFAULT_PARQUET_EXTENSION, +}; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; @@ -33,7 +35,7 @@ async fn main() -> Result<()> { // Configure listing options let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions { - file_extension: ".parquet".to_owned(), + file_extension: DEFAULT_PARQUET_EXTENSION.to_owned(), format: Arc::new(file_format), table_partition_cols: vec![], collect_stat: true, diff --git a/datafusion/src/datasource/file_format/avro.rs b/datafusion/src/datasource/file_format/avro.rs index 08eb34386fb2..fa02d1ae2833 100644 --- a/datafusion/src/datasource/file_format/avro.rs +++ b/datafusion/src/datasource/file_format/avro.rs @@ -34,6 +34,8 @@ use crate::physical_plan::file_format::{AvroExec, FileScanConfig}; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; +/// The default file extension of avro files +pub const DEFAULT_AVRO_EXTENSION: &str = ".avro"; /// Avro `FileFormat` implementation. 
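// Illustrative sketch (not part of this patch): the shared extension
// constants replace hard-coded strings when building a ListingOptions by
// hand. Field values other than `file_extension` are placeholders and assume
// the same ListingOptions fields used in the examples and tests in this diff.
use std::sync::Arc;
use datafusion::datasource::file_format::parquet::{
    ParquetFormat, DEFAULT_PARQUET_EXTENSION,
};
use datafusion::datasource::listing::ListingOptions;

fn parquet_listing_options() -> ListingOptions {
    ListingOptions {
        file_extension: DEFAULT_PARQUET_EXTENSION.to_owned(), // ".parquet"
        format: Arc::new(ParquetFormat::default().with_enable_pruning(true)),
        table_partition_cols: vec![],
        collect_stat: true,
        target_partitions: 8,
    }
}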
#[derive(Default, Debug)] pub struct AvroFormat; diff --git a/datafusion/src/datasource/file_format/csv.rs b/datafusion/src/datasource/file_format/csv.rs index f0a70d9176db..6aa0d21235a4 100644 --- a/datafusion/src/datasource/file_format/csv.rs +++ b/datafusion/src/datasource/file_format/csv.rs @@ -33,6 +33,8 @@ use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; +/// The default file extension of csv files +pub const DEFAULT_CSV_EXTENSION: &str = ".csv"; /// Character Separated Value `FileFormat` implementation. #[derive(Debug)] pub struct CsvFormat { diff --git a/datafusion/src/datasource/file_format/json.rs b/datafusion/src/datasource/file_format/json.rs index d7a278d72a6e..bdd5ef81d559 100644 --- a/datafusion/src/datasource/file_format/json.rs +++ b/datafusion/src/datasource/file_format/json.rs @@ -37,6 +37,8 @@ use crate::physical_plan::file_format::NdJsonExec; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; +/// The default file extension of json files +pub const DEFAULT_JSON_EXTENSION: &str = ".json"; /// New line delimited JSON `FileFormat` implementation. #[derive(Debug, Default)] pub struct JsonFormat { diff --git a/datafusion/src/datasource/listing/table.rs b/datafusion/src/datasource/listing/table.rs index 2f8f70f5ede5..1501b8bd7a18 100644 --- a/datafusion/src/datasource/listing/table.rs +++ b/datafusion/src/datasource/listing/table.rs @@ -266,6 +266,8 @@ impl ListingTable { mod tests { use arrow::datatypes::DataType; + use crate::datasource::file_format::avro::DEFAULT_AVRO_EXTENSION; + use crate::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION; use crate::{ datasource::{ file_format::{avro::AvroFormat, parquet::ParquetFormat}, @@ -318,7 +320,7 @@ mod tests { let store = TestObjectStore::new_arc(&[("table/p1=v1/file.avro", 100)]); let opt = ListingOptions { - file_extension: ".avro".to_owned(), + file_extension: DEFAULT_AVRO_EXTENSION.to_owned(), format: Arc::new(AvroFormat {}), table_partition_cols: vec![String::from("p1")], target_partitions: 4, @@ -419,7 +421,7 @@ mod tests { let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/{}", testdata, name); let opt = ListingOptions { - file_extension: "parquet".to_owned(), + file_extension: DEFAULT_PARQUET_EXTENSION.to_owned(), format: Arc::new(ParquetFormat::default()), table_partition_cols: vec![], target_partitions: 2, diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index ca86d0f0a019..023d3a0023be 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -24,8 +24,8 @@ use crate::{ datasource::listing::{ListingOptions, ListingTable}, datasource::{ file_format::{ - avro::AvroFormat, - csv::CsvFormat, + avro::{AvroFormat, DEFAULT_AVRO_EXTENSION}, + csv::{CsvFormat, DEFAULT_CSV_EXTENSION}, parquet::{ParquetFormat, DEFAULT_PARQUET_EXTENSION}, FileFormat, }, @@ -218,17 +218,20 @@ impl ExecutionContext { ref file_type, ref has_header, }) => { - let file_format = match file_type { - FileType::CSV => { - Ok(Arc::new(CsvFormat::default().with_has_header(*has_header)) - as Arc) - } - FileType::Parquet => { - Ok(Arc::new(ParquetFormat::default()) as Arc) - } - FileType::Avro => { - Ok(Arc::new(AvroFormat::default()) as Arc) - } + let (file_format, file_extension) = match file_type { + FileType::CSV => Ok(( + Arc::new(CsvFormat::default().with_has_header(*has_header)) + as Arc, + DEFAULT_CSV_EXTENSION, + 
)), + FileType::Parquet => Ok(( + Arc::new(ParquetFormat::default()) as Arc, + DEFAULT_PARQUET_EXTENSION, + )), + FileType::Avro => Ok(( + Arc::new(AvroFormat::default()) as Arc, + DEFAULT_AVRO_EXTENSION, + )), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported file type {:?}.", file_type @@ -238,7 +241,7 @@ impl ExecutionContext { let options = ListingOptions { format: file_format, collect_stat: false, - file_extension: String::new(), + file_extension: file_extension.to_owned(), target_partitions: self .state .lock() diff --git a/datafusion/src/execution/options.rs b/datafusion/src/execution/options.rs index 219e2fd89700..79b07536acb3 100644 --- a/datafusion/src/execution/options.rs +++ b/datafusion/src/execution/options.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::datatypes::{Schema, SchemaRef}; +use crate::datasource::file_format::json::DEFAULT_JSON_EXTENSION; use crate::datasource::{ file_format::{avro::AvroFormat, csv::CsvFormat}, listing::ListingOptions, @@ -173,7 +174,7 @@ impl<'a> Default for NdJsonReadOptions<'a> { Self { schema: None, schema_infer_max_records: 1000, - file_extension: ".json", + file_extension: DEFAULT_JSON_EXTENSION, } } } diff --git a/datafusion/src/physical_plan/file_format/parquet.rs b/datafusion/src/physical_plan/file_format/parquet.rs index d240fe27c58a..905bb1e28f9a 100644 --- a/datafusion/src/physical_plan/file_format/parquet.rs +++ b/datafusion/src/physical_plan/file_format/parquet.rs @@ -221,7 +221,7 @@ impl ExecutionPlan for ParquetExec { object_store.as_ref(), file_schema_ref, partition_index, - partition, + &partition, metrics, &projection, &pruning_predicate, @@ -230,7 +230,10 @@ impl ExecutionPlan for ParquetExec { limit, partition_col_proj, ) { - println!("Parquet reader thread terminated due to error: {:?}", e); + println!( + "Parquet reader thread terminated due to error: {:?} for files: {:?}", + e, partition + ); } }); @@ -445,7 +448,7 @@ fn read_partition( object_store: &dyn ObjectStore, file_schema: SchemaRef, partition_index: usize, - partition: Vec, + partition: &[PartitionedFile], metrics: ExecutionPlanMetricsSet, projection: &[usize], pruning_predicate: &Option, From d01d8d5e56a47d2309d0946fdae89ab3c2c550e0 Mon Sep 17 00:00:00 2001 From: Remzi Yang <59198230+HaoYang670@users.noreply.github.com> Date: Tue, 1 Feb 2022 04:36:49 +0800 Subject: [PATCH 14/50] add upper bound for pub fn (#1713) Signed-off-by: remzi <13716567376yh@gmail.com> --- datafusion/src/physical_plan/aggregates.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index f7beb76df3bc..c40fd7104201 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -359,7 +359,7 @@ static TIMESTAMPS: &[DataType] = &[ static DATES: &[DataType] = &[DataType::Date32, DataType::Date64]; /// the signatures supported by the function `fun`. -pub fn signature(fun: &AggregateFunction) -> Signature { +pub(super) fn signature(fun: &AggregateFunction) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. 
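// Illustrative sketch (not part of this patch): with the per-format default
// extension wired into CREATE EXTERNAL TABLE above, a directory written by
// Spark (part-*.parquet files next to a _SUCCESS marker) can be registered
// from SQL without the marker file breaking the scan. Paths and table names
// are made up, and this assumes the async ExecutionContext::sql API of this
// DataFusion version.
use datafusion::prelude::*;

async fn register_spark_output() -> datafusion::error::Result<()> {
    let mut ctx = ExecutionContext::new();
    // Only files ending in ".parquet" are listed for the table.
    ctx.sql("CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION '/data/spark_output/'")
        .await?;
    let df = ctx.sql("SELECT count(*) FROM t").await?;
    df.show().await?;
    Ok(())
}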
match fun { AggregateFunction::Count From 7bec762d1f1ebef4801af2eefd7a5033c474fe77 Mon Sep 17 00:00:00 2001 From: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> Date: Mon, 31 Jan 2022 15:41:21 -0500 Subject: [PATCH 15/50] Create SchemaAdapter trait to map table schema to file schemas (#1709) * Create SchemaAdapter trait to map table schema to file schemas * Linting fix * Remove commented code --- .../src/physical_plan/file_format/avro.rs | 64 +++++++- .../src/physical_plan/file_format/csv.rs | 47 ++++++ .../src/physical_plan/file_format/json.rs | 43 ++++++ .../src/physical_plan/file_format/mod.rs | 140 ++++++++++++++++++ .../src/physical_plan/file_format/parquet.rs | 68 ++------- datafusion/src/test_util.rs | 26 ++++ 6 files changed, 330 insertions(+), 58 deletions(-) diff --git a/datafusion/src/physical_plan/file_format/avro.rs b/datafusion/src/physical_plan/file_format/avro.rs index 9de8ffded5f7..b38968665015 100644 --- a/datafusion/src/physical_plan/file_format/avro.rs +++ b/datafusion/src/physical_plan/file_format/avro.rs @@ -165,13 +165,14 @@ impl ExecutionPlan for AvroExec { #[cfg(test)] #[cfg(feature = "avro")] mod tests { - use crate::datasource::file_format::{avro::AvroFormat, FileFormat}; use crate::datasource::object_store::local::{ local_object_reader_stream, local_unpartitioned_file, LocalFileSystem, }; use crate::scalar::ScalarValue; + use arrow::datatypes::{DataType, Field, Schema}; use futures::StreamExt; + use sqlparser::ast::ObjectType::Schema; use super::*; @@ -228,6 +229,67 @@ mod tests { Ok(()) } + #[tokio::test] + async fn avro_exec_missing_column() -> Result<()> { + let testdata = crate::test_util::arrow_test_data(); + let filename = format!("{}/avro/alltypes_plain.avro", testdata); + let actual_schema = AvroFormat {} + .infer_schema(local_object_reader_stream(vec![filename])) + .await?; + + let mut fields = actual_schema.fields().clone(); + fields.push(Field::new("missing_col", DataType::Int32, true)); + + let file_schema = Arc::new(Schema::new(fields)); + + let avro_exec = AvroExec::new(FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_groups: vec![vec![local_unpartitioned_file(filename.clone())]], + file_schema, + statistics: Statistics::default(), + // Include the missing column in the projection + projection: Some(vec![0, 1, 2, file_schema.fields().len()]), + limit: None, + table_partition_cols: vec![], + }); + assert_eq!(avro_exec.output_partitioning().partition_count(), 1); + + let mut results = avro_exec.execute(0).await.expect("plan execution failed"); + let batch = results + .next() + .await + .expect("plan iterator empty") + .expect("plan iterator returned an error"); + + let expected = vec![ + "+----+----------+-------------+-------------+", + "| id | bool_col | tinyint_col | missing_col |", + "+----+----------+-------------+-------------+", + "| 4 | true | 0 | |", + "| 5 | false | 1 | |", + "| 6 | true | 0 | |", + "| 7 | false | 1 | |", + "| 2 | true | 0 | |", + "| 3 | false | 1 | |", + "| 0 | true | 0 | |", + "| 1 | false | 1 | |", + "+----+----------+-------------+-------------+", + ]; + + crate::assert_batches_eq!(expected, &[batch]); + + let batch = results.next().await; + assert!(batch.is_none()); + + let batch = results.next().await; + assert!(batch.is_none()); + + let batch = results.next().await; + assert!(batch.is_none()); + + Ok(()) + } + #[tokio::test] async fn avro_exec_with_partition() -> Result<()> { let testdata = crate::test_util::arrow_test_data(); diff --git 
a/datafusion/src/physical_plan/file_format/csv.rs b/datafusion/src/physical_plan/file_format/csv.rs index 5cff3b6c7296..4cf70f6e5cfd 100644 --- a/datafusion/src/physical_plan/file_format/csv.rs +++ b/datafusion/src/physical_plan/file_format/csv.rs @@ -170,6 +170,7 @@ impl ExecutionPlan for CsvExec { #[cfg(test)] mod tests { use super::*; + use crate::test_util::aggr_test_schema_with_missing_col; use crate::{ datasource::object_store::local::{local_unpartitioned_file, LocalFileSystem}, scalar::ScalarValue, @@ -269,6 +270,52 @@ mod tests { Ok(()) } + #[tokio::test] + async fn csv_exec_with_missing_column() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let file_schema = aggr_test_schema_with_missing_col(); + let testdata = crate::test_util::arrow_test_data(); + let filename = "aggregate_test_100.csv"; + let path = format!("{}/csv/{}", testdata, filename); + let csv = CsvExec::new( + FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_schema, + file_groups: vec![vec![local_unpartitioned_file(path)]], + statistics: Statistics::default(), + projection: None, + limit: Some(5), + table_partition_cols: vec![], + }, + true, + b',', + ); + assert_eq!(14, csv.base_config.file_schema.fields().len()); + assert_eq!(14, csv.projected_schema.fields().len()); + assert_eq!(14, csv.schema().fields().len()); + + let mut it = csv.execute(0, runtime).await?; + let batch = it.next().await.unwrap()?; + assert_eq!(14, batch.num_columns()); + assert_eq!(5, batch.num_rows()); + + let expected = vec![ + "+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+-------------+", + "| c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | c10 | c11 | c12 | c13 | missing_col |", + "+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+-------------+", + "| c | 2 | 1 | 18109 | 2033001162 | -6513304855495910254 | 25 | 43062 | 1491205016 | 5863949479783605708 | 0.110830784 | 0.9294097332465232 | 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW | |", + "| d | 5 | -40 | 22614 | 706441268 | -7542719935673075327 | 155 | 14337 | 3373581039 | 11720144131976083864 | 0.69632107 | 0.3114712539863804 | C2GT5KVyOPZpgKVl110TyZO0NcJ434 | |", + "| b | 1 | 29 | -18218 | 994303988 | 5983957848665088916 | 204 | 9489 | 3275293996 | 14857091259186476033 | 0.53840446 | 0.17909035118828576 | AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz | |", + "| a | 1 | -85 | -15154 | 1171968280 | 1919439543497968449 | 77 | 52286 | 774637006 | 12101411955859039553 | 0.12285209 | 0.6864391962767343 | 0keZ5G8BffGwgF2RwQD59TFzMStxCB | |", + "| b | 5 | -82 | 22080 | 1824882165 | 7373730676428214987 | 208 | 34331 | 3342719438 | 3330177516592499461 | 0.82634634 | 0.40975383525297016 | Ig1QcuKsjHXkproePdERo2w0mYzIqd | |", + "+----+----+-----+--------+------------+----------------------+-----+-------+------------+----------------------+-------------+---------------------+--------------------------------+-------------+", + ]; + + crate::assert_batches_eq!(expected, &[batch]); + + Ok(()) + } + #[tokio::test] async fn csv_exec_with_partition() -> Result<()> { let runtime = Arc::new(RuntimeEnv::default()); diff --git a/datafusion/src/physical_plan/file_format/json.rs b/datafusion/src/physical_plan/file_format/json.rs index 0fc95d1e4933..ac413062caf8 100644 --- a/datafusion/src/physical_plan/file_format/json.rs +++ 
b/datafusion/src/physical_plan/file_format/json.rs @@ -137,6 +137,8 @@ impl ExecutionPlan for NdJsonExec { #[cfg(test)] mod tests { + use arrow::array::Array; + use arrow::datatypes::{Field, Schema}; use futures::StreamExt; use crate::datasource::{ @@ -211,6 +213,47 @@ mod tests { Ok(()) } + #[tokio::test] + async fn nd_json_exec_file_with_missing_column() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + use arrow::datatypes::DataType; + let path = format!("{}/1.json", TEST_DATA_BASE); + + let actual_schema = infer_schema(path.clone()).await?; + + let mut fields = actual_schema.fields().clone(); + fields.push(Field::new("missing_col", DataType::Int32, true)); + let missing_field_idx = fields.len() - 1; + + let file_schema = Arc::new(Schema::new(fields)); + + let exec = NdJsonExec::new(FileScanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_groups: vec![vec![local_unpartitioned_file(path.clone())]], + file_schema, + statistics: Statistics::default(), + projection: None, + limit: Some(3), + table_partition_cols: vec![], + }); + + let mut it = exec.execute(0, runtime).await?; + let batch = it.next().await.unwrap()?; + + assert_eq!(batch.num_rows(), 3); + let values = batch + .column(missing_field_idx) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.len(), 3); + assert!(values.is_null(0)); + assert!(values.is_null(1)); + assert!(values.is_null(2)); + + Ok(()) + } + #[tokio::test] async fn nd_json_exec_file_projection() -> Result<()> { let runtime = Arc::new(RuntimeEnv::default()); diff --git a/datafusion/src/physical_plan/file_format/mod.rs b/datafusion/src/physical_plan/file_format/mod.rs index b655cdb09dbb..7658addd3561 100644 --- a/datafusion/src/physical_plan/file_format/mod.rs +++ b/datafusion/src/physical_plan/file_format/mod.rs @@ -35,11 +35,15 @@ pub use avro::AvroExec; pub use csv::CsvExec; pub use json::NdJsonExec; +use crate::error::DataFusionError; use crate::{ datasource::{object_store::ObjectStore, PartitionedFile}, + error::Result, scalar::ScalarValue, }; +use arrow::array::new_null_array; use lazy_static::lazy_static; +use log::info; use std::{ collections::HashMap, fmt::{Display, Formatter, Result as FmtResult}, @@ -165,6 +169,87 @@ impl<'a> Display for FileGroupsDisplay<'a> { } } +/// A utility which can adapt file-level record batches to a table schema which may have a schema +/// obtained from merging multiple file-level schemas. +/// +/// This is useful for enabling schema evolution in partitioned datasets. +/// +/// This has to be done in two stages. +/// +/// 1. Before reading the file, we have to map projected column indexes from the table schema to +/// the file schema. +/// +/// 2. After reading a record batch we need to map the read columns back to the expected columns +/// indexes and insert null-valued columns wherever the file schema was missing a colum present +/// in the table schema. +#[derive(Clone, Debug)] +pub(crate) struct SchemaAdapter { + /// Schema for the table + table_schema: SchemaRef, +} + +impl SchemaAdapter { + pub(crate) fn new(table_schema: SchemaRef) -> SchemaAdapter { + Self { table_schema } + } + + /// Map projected column indexes to the file schema. This will fail if the table schema + /// and the file schema contain a field with the same name and different types. 
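// Illustrative sketch (not part of this patch): SchemaAdapter itself is
// crate-private, so this stand-alone snippet only mimics the projection
// mapping it performs, using the same shapes as the unit test below: a table
// schema (c1, c2, c3) and a file that is missing c3. adapt_batch() later
// fills the missing column with nulls.
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};

fn illustrate_projection_mapping() {
    let table_schema = Arc::new(Schema::new(vec![
        Field::new("c1", DataType::Utf8, true),
        Field::new("c2", DataType::Int64, true),
        Field::new("c3", DataType::Int8, true),
    ]));
    let file_schema = Schema::new(vec![
        Field::new("c1", DataType::Utf8, true),
        Field::new("c2", DataType::Int64, true),
    ]);
    // A table-level projection of [0, 1, 2] maps onto file columns [0, 1].
    let projections = [0usize, 1, 2];
    let mapped: Vec<usize> = projections
        .iter()
        .filter_map(|i| file_schema.index_of(table_schema.field(*i).name()).ok())
        .collect();
    assert_eq!(mapped, vec![0, 1]);
}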
+ pub fn map_projections( + &self, + file_schema: &Schema, + projections: &[usize], + ) -> Result> { + let mut mapped: Vec = vec![]; + for idx in projections { + let field = self.table_schema.field(*idx); + if let Ok(mapped_idx) = file_schema.index_of(field.name().as_str()) { + if file_schema.field(mapped_idx).data_type() == field.data_type() { + mapped.push(mapped_idx) + } else { + let msg = format!("Failed to map column projection for field {}. Incompatible data types {:?} and {:?}", field.name(), file_schema.field(mapped_idx).data_type(), field.data_type()); + info!("{}", msg); + return Err(DataFusionError::Execution(msg)); + } + } + } + Ok(mapped) + } + + /// Re-order projected columns by index in record batch to match table schema column ordering. If the record + /// batch does not contain a column for an expected field, insert a null-valued column at the + /// required column index. + pub fn adapt_batch( + &self, + batch: RecordBatch, + projections: &[usize], + ) -> Result { + let batch_rows = batch.num_rows(); + + let batch_schema = batch.schema(); + + let mut cols: Vec = Vec::with_capacity(batch.columns().len()); + let batch_cols = batch.columns().to_vec(); + + for field_idx in projections { + let table_field = &self.table_schema.fields()[*field_idx]; + if let Some((batch_idx, _name)) = + batch_schema.column_with_name(table_field.name().as_str()) + { + cols.push(batch_cols[batch_idx].clone()); + } else { + cols.push(new_null_array(table_field.data_type(), batch_rows)) + } + } + + let projected_schema = Arc::new(self.table_schema.clone().project(projections)?); + + let merged_batch = RecordBatch::try_new(projected_schema, cols)?; + + Ok(merged_batch) + } +} + /// A helper that projects partition columns into the file record batches. /// /// One interesting trick is the usage of a cache for the key buffers of the partition column @@ -467,6 +552,61 @@ mod tests { crate::assert_batches_eq!(expected, &[projected_batch]); } + #[test] + fn schema_adapter_adapt_projections() { + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Int64, true), + Field::new("c3", DataType::Int8, true), + ])); + + let file_schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Int64, true), + ]); + + let file_schema_2 = Arc::new(Schema::new(vec![ + Field::new("c3", DataType::Int8, true), + Field::new("c2", DataType::Int64, true), + ])); + + let file_schema_3 = + Arc::new(Schema::new(vec![Field::new("c3", DataType::Float32, true)])); + + let adapter = SchemaAdapter::new(table_schema); + + let projections1: Vec = vec![0, 1, 2]; + let projections2: Vec = vec![2]; + + let mapped = adapter + .map_projections(&file_schema, projections1.as_slice()) + .expect("mapping projections"); + + assert_eq!(mapped, vec![0, 1]); + + let mapped = adapter + .map_projections(&file_schema, projections2.as_slice()) + .expect("mapping projections"); + + assert!(mapped.is_empty()); + + let mapped = adapter + .map_projections(&file_schema_2, projections1.as_slice()) + .expect("mapping projections"); + + assert_eq!(mapped, vec![1, 0]); + + let mapped = adapter + .map_projections(&file_schema_2, projections2.as_slice()) + .expect("mapping projections"); + + assert_eq!(mapped, vec![0]); + + let mapped = adapter.map_projections(&file_schema_3, projections1.as_slice()); + + assert!(mapped.is_err()); + } + // sets default for configs that play no role in projections fn config_for_projection( file_schema: SchemaRef, diff --git 
a/datafusion/src/physical_plan/file_format/parquet.rs b/datafusion/src/physical_plan/file_format/parquet.rs index 905bb1e28f9a..40acf5a51c17 100644 --- a/datafusion/src/physical_plan/file_format/parquet.rs +++ b/datafusion/src/physical_plan/file_format/parquet.rs @@ -44,14 +44,13 @@ use arrow::{ error::{ArrowError, Result as ArrowResult}, record_batch::RecordBatch, }; -use log::{debug, info}; +use log::debug; use parquet::file::{ metadata::RowGroupMetaData, reader::{FileReader, SerializedFileReader}, statistics::Statistics as ParquetStatistics, }; -use arrow::array::new_null_array; use fmt::Debug; use parquet::arrow::{ArrowReader, ParquetFileArrowReader}; @@ -61,6 +60,7 @@ use tokio::{ }; use crate::execution::runtime_env::RuntimeEnv; +use crate::physical_plan::file_format::SchemaAdapter; use async_trait::async_trait; use super::PartitionColumnProjector; @@ -215,11 +215,12 @@ impl ExecutionPlan for ParquetExec { &self.base_config.table_partition_cols, ); - let file_schema_ref = self.base_config().file_schema.clone(); + let adapter = SchemaAdapter::new(self.base_config.file_schema.clone()); + let join_handle = task::spawn_blocking(move || { if let Err(e) = read_partition( object_store.as_ref(), - file_schema_ref, + adapter, partition_index, &partition, metrics, @@ -420,33 +421,10 @@ fn build_row_group_predicate( } } -// Map projections from the schema which merges all file schemas to projections on a particular -// file -fn map_projections( - merged_schema: &Schema, - file_schema: &Schema, - projections: &[usize], -) -> Result> { - let mut mapped: Vec = vec![]; - for idx in projections { - let field = merged_schema.field(*idx); - if let Ok(mapped_idx) = file_schema.index_of(field.name().as_str()) { - if file_schema.field(mapped_idx).data_type() == field.data_type() { - mapped.push(mapped_idx) - } else { - let msg = format!("Failed to map column projection for field {}. 
Incompatible data types {:?} and {:?}", field.name(), file_schema.field(mapped_idx).data_type(), field.data_type()); - info!("{}", msg); - return Err(DataFusionError::Execution(msg)); - } - } - } - Ok(mapped) -} - #[allow(clippy::too_many_arguments)] fn read_partition( object_store: &dyn ObjectStore, - file_schema: SchemaRef, + schema_adapter: SchemaAdapter, partition_index: usize, partition: &[PartitionedFile], metrics: ExecutionPlanMetricsSet, @@ -480,44 +458,20 @@ fn read_partition( } let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(file_reader)); - let mapped_projections = - map_projections(&file_schema, &arrow_reader.get_schema()?, projection)?; + let adapted_projections = + schema_adapter.map_projections(&arrow_reader.get_schema()?, projection)?; let mut batch_reader = - arrow_reader.get_record_reader_by_columns(mapped_projections, batch_size)?; + arrow_reader.get_record_reader_by_columns(adapted_projections, batch_size)?; loop { match batch_reader.next() { Some(Ok(batch)) => { - let total_cols = &file_schema.fields().len(); - let batch_rows = batch.num_rows(); total_rows += batch.num_rows(); - let batch_schema = batch.schema(); - - let mut cols: Vec = Vec::with_capacity(*total_cols); - let batch_cols = batch.columns().to_vec(); - - for field_idx in projection { - let merged_field = &file_schema.fields()[*field_idx]; - if let Some((batch_idx, _name)) = - batch_schema.column_with_name(merged_field.name().as_str()) - { - cols.push(batch_cols[batch_idx].clone()); - } else { - cols.push(new_null_array( - merged_field.data_type(), - batch_rows, - )) - } - } - - let projected_schema = file_schema.clone().project(projection)?; - - let merged_batch = - RecordBatch::try_new(Arc::new(projected_schema), cols)?; + let adapted_batch = schema_adapter.adapt_batch(batch, projection)?; let proj_batch = partition_column_projector - .project(merged_batch, &partitioned_file.partition_values); + .project(adapted_batch, &partitioned_file.partition_values); send_result(&response_tx, proj_batch)?; if limit.map(|l| total_rows >= l).unwrap_or(false) { diff --git a/datafusion/src/test_util.rs b/datafusion/src/test_util.rs index af6650361123..8ee0298f72ce 100644 --- a/datafusion/src/test_util.rs +++ b/datafusion/src/test_util.rs @@ -257,6 +257,32 @@ pub fn aggr_test_schema() -> SchemaRef { Arc::new(schema) } +/// Get the schema for the aggregate_test_* csv files with an additional filed not present in the files. 
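// Illustrative sketch (not part of this patch): the pattern the new tests
// use to exercise schema evolution: infer the schema from the files, then
// append a nullable column the files do not contain; scans are expected to
// return nulls for it. The column name here is a placeholder.
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

fn with_missing_col(actual_schema: &Schema) -> SchemaRef {
    let mut fields = actual_schema.fields().clone();
    fields.push(Field::new("missing_col", DataType::Int32, true));
    Arc::new(Schema::new(fields))
}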
+pub fn aggr_test_schema_with_missing_col() -> SchemaRef { + let mut f1 = Field::new("c1", DataType::Utf8, false); + f1.set_metadata(Some(BTreeMap::from_iter( + vec![("testing".into(), "test".into())].into_iter(), + ))); + let schema = Schema::new(vec![ + f1, + Field::new("c2", DataType::UInt32, false), + Field::new("c3", DataType::Int8, false), + Field::new("c4", DataType::Int16, false), + Field::new("c5", DataType::Int32, false), + Field::new("c6", DataType::Int64, false), + Field::new("c7", DataType::UInt8, false), + Field::new("c8", DataType::UInt16, false), + Field::new("c9", DataType::UInt32, false), + Field::new("c10", DataType::UInt64, false), + Field::new("c11", DataType::Float32, false), + Field::new("c12", DataType::Float64, false), + Field::new("c13", DataType::Utf8, false), + Field::new("missing_col", DataType::Int64, true), + ]); + + Arc::new(schema) +} + #[cfg(test)] mod tests { use super::*; From cfb655dc09013d161ef15d9502718998a6c4f86e Mon Sep 17 00:00:00 2001 From: Dom Date: Mon, 31 Jan 2022 20:41:56 +0000 Subject: [PATCH 16/50] approx_quantile() aggregation function (#1539) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: implement TDigest for approx quantile Adds a [TDigest] implementation providing approximate quantile estimations of large inputs using a small amount of (bounded) memory. A TDigest is most accurate near either "end" of the quantile range (that is, 0.1, 0.9, 0.95, etc) due to the use of a scalaing function that increases resolution at the tails. The paper claims single digit part per million errors for q ≤ 0.001 or q ≥ 0.999 using 100 centroids, and in practice I have found accuracy to be more than acceptable for an apprixmate function across the entire quantile range. The implementation is a modified copy of https://github.com/MnO2/t-digest, itself a Rust port of [Facebook's C++ implementation]. Both Facebook's implementation, and Mn02's Rust port are Apache 2.0 licensed. [TDigest]: https://arxiv.org/abs/1902.04023 [Facebook's C++ implementation]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h * feat: approx_quantile aggregation Adds the ApproxQuantile physical expression, plumbing & test cases. The function signature is: approx_quantile(column, quantile) Where column can be any numeric type (that can be cast to a float64) and quantile is a float64 literal between 0 and 1. * feat: approx_quantile dataframe function Adds the approx_quantile() dataframe function, and exports it in the prelude. * refactor: bastilla approx_quantile support Adds bastilla wire encoding for approx_quantile. Adding support for this required modifying the AggregateExprNode proto message to support propigating multiple LogicalExprNode aggregate arguments - all the existing aggregations take a single argument, so this wasn't needed before. This commit adds "repeated" to the expr field, which I believe is backwards compatible as described here: https://developers.google.com/protocol-buffers/docs/proto3#updating Specifically, adding "repeated" to an existing message field: "For ... message fields, optional is compatible with repeated" No existing tests needed fixing, and a new roundtrip test is included that covers the change to allow multiple expr. * refactor: use input type as return type Casts the calculated quantile value to the same type as the input data. * fixup! 
refactor: bastilla approx_quantile support * refactor: rebase onto main * refactor: validate quantile value Ensures the quantile values is between 0 and 1, emitting a plan error if not. * refactor: rename to approx_percentile_cont * refactor: clippy lints --- ballista/rust/core/proto/ballista.proto | 3 +- .../core/src/serde/logical_plan/from_proto.rs | 6 +- .../rust/core/src/serde/logical_plan/mod.rs | 21 +- .../core/src/serde/logical_plan/to_proto.rs | 14 +- ballista/rust/core/src/serde/mod.rs | 3 + datafusion/src/logical_plan/expr.rs | 9 + datafusion/src/logical_plan/mod.rs | 14 +- datafusion/src/physical_plan/aggregates.rs | 87 +- .../coercion_rule/aggregate_rule.rs | 118 ++- .../expressions/approx_percentile_cont.rs | 313 +++++++ .../src/physical_plan/expressions/mod.rs | 4 + datafusion/src/physical_plan/mod.rs | 1 + datafusion/src/physical_plan/tdigest/mod.rs | 818 ++++++++++++++++++ datafusion/src/prelude.rs | 12 +- datafusion/tests/dataframe_functions.rs | 20 + datafusion/tests/sql/aggregates.rs | 89 ++ 16 files changed, 1485 insertions(+), 47 deletions(-) create mode 100644 datafusion/src/physical_plan/expressions/approx_percentile_cont.rs create mode 100644 datafusion/src/physical_plan/tdigest/mod.rs diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 15a7342d7b14..fb006e532ff3 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -176,11 +176,12 @@ enum AggregateFunction { STDDEV=11; STDDEV_POP=12; CORRELATION=13; + APPROX_PERCENTILE_CONT = 14; } message AggregateExprNode { AggregateFunction aggr_function = 1; - LogicalExprNode expr = 2; + repeated LogicalExprNode expr = 2; } enum BuiltInWindowFunction { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 568485591425..044f823251a8 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1065,7 +1065,11 @@ impl TryInto for &protobuf::LogicalExprNode { Ok(Expr::AggregateFunction { fun, - args: vec![parse_required_expr(&expr.expr)?], + args: expr + .expr + .iter() + .map(|e| e.try_into()) + .collect::, _>>()?, distinct: false, //TODO }) } diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index c09b8a57d4aa..c00e3e42912a 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -24,16 +24,14 @@ mod roundtrip_tests { use super::super::{super::error::Result, protobuf}; use crate::error::BallistaError; use core::panic; - use datafusion::arrow::datatypes::UnionMode; - use datafusion::logical_plan::Repartition; use datafusion::{ - arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}, + arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode}, datasource::object_store::local::LocalFileSystem, logical_plan::{ col, CreateExternalTable, Expr, LogicalPlan, LogicalPlanBuilder, - Partitioning, ToDFSchema, + Partitioning, Repartition, ToDFSchema, }, - physical_plan::functions::BuiltinScalarFunction::Sqrt, + physical_plan::{aggregates, functions::BuiltinScalarFunction::Sqrt}, prelude::*, scalar::ScalarValue, sql::parser::FileType, @@ -1001,4 +999,17 @@ mod roundtrip_tests { Ok(()) } + + #[test] + fn roundtrip_approx_percentile_cont() -> Result<()> { + let test_expr = Expr::AggregateFunction { + fun: 
aggregates::AggregateFunction::ApproxPercentileCont, + args: vec![col("bananas"), lit(0.42)], + distinct: false, + }; + + roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr); + + Ok(()) + } } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index eb5d8102de42..4b13ce577cfb 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1074,6 +1074,9 @@ impl TryInto for &Expr { AggregateFunction::ApproxDistinct => { protobuf::AggregateFunction::ApproxDistinct } + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont + } AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, @@ -1099,11 +1102,13 @@ impl TryInto for &Expr { } }; - let arg = &args[0]; - let aggregate_expr = Box::new(protobuf::AggregateExprNode { + let aggregate_expr = protobuf::AggregateExprNode { aggr_function: aggr_function.into(), - expr: Some(Box::new(arg.try_into()?)), - }); + expr: args + .iter() + .map(|v| v.try_into()) + .collect::, _>>()?, + }; Ok(protobuf::LogicalExprNode { expr_type: Some(ExprType::AggregateExpr(aggregate_expr)), }) @@ -1334,6 +1339,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, AggregateFunction::Correlation => Self::Correlation, + AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 4026273a9eb7..64a60dc4da5d 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -129,6 +129,9 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev, protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop, protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation, + protobuf::AggregateFunction::ApproxPercentileCont => { + AggregateFunction::ApproxPercentileCont + } } } } diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 98c296939bc5..a1e51e07422e 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1647,6 +1647,15 @@ pub fn approx_distinct(expr: Expr) -> Expr { } } +/// Calculate an approximation of the specified `percentile` for `expr`. +pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::ApproxPercentileCont, + distinct: false, + args: vec![expr, percentile], + } +} + // TODO(kszucs): this seems buggy, unary_scalar_expr! 
is used for many // varying arity functions /// Create an convenience function representing a unary scalar function diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 56fec3cf1a0c..06c6bf90c790 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -36,13 +36,13 @@ pub use builder::{ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, approx_distinct, array, ascii, asin, atan, avg, binary_expr, - bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr, - combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, - create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, - initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, - max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random, - regexp_match, regexp_replace, repeat, replace, replace_col, reverse, + abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, + avg, binary_expr, bit_length, btrim, case, ceil, character_length, chr, col, + columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, + create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, + floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, + lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length, + or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse, rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index c40fd7104201..8fc94d386014 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -27,7 +27,7 @@ //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. use super::{ - functions::{Signature, Volatility}, + functions::{Signature, TypeSignature, Volatility}, Accumulator, AggregateExpr, PhysicalExpr, }; use crate::error::{DataFusionError, Result}; @@ -80,6 +80,8 @@ pub enum AggregateFunction { CovariancePop, /// Correlation Correlation, + /// Approximate continuous percentile function + ApproxPercentileCont, } impl fmt::Display for AggregateFunction { @@ -110,6 +112,7 @@ impl FromStr for AggregateFunction { "covar_samp" => AggregateFunction::Covariance, "covar_pop" => AggregateFunction::CovariancePop, "corr" => AggregateFunction::Correlation, + "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -157,6 +160,7 @@ pub fn return_type( coerced_data_types[0].clone(), true, )))), + AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), } } @@ -331,6 +335,20 @@ pub fn create_aggregate_expr( "CORR(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::ApproxPercentileCont, false) => { + Arc::new(expressions::ApproxPercentileCont::new( + // Pass in the desired percentile expr + coerced_phy_exprs, + name, + return_type, + )?) 
+ } + (AggregateFunction::ApproxPercentileCont, true) => { + return Err(DataFusionError::NotImplemented( + "approx_percentile_cont(DISTINCT) aggregations are not available" + .to_string(), + )); + } }) } @@ -389,17 +407,25 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature { AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } + AggregateFunction::ApproxPercentileCont => Signature::one_of( + // Accept any numeric value paired with a float64 percentile + NUMERICS + .iter() + .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64])) + .collect(), + Volatility::Immutable, + ), } } #[cfg(test)] mod tests { use super::*; - use crate::error::Result; use crate::physical_plan::expressions::{ - ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, DistinctArrayAgg, - DistinctCount, Max, Min, Stddev, Sum, Variance, + ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count, + Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, }; + use crate::{error::Result, scalar::ScalarValue}; #[test] fn test_count_arragg_approx_expr() -> Result<()> { @@ -513,6 +539,59 @@ mod tests { Ok(()) } + #[test] + fn test_agg_approx_percentile_phy_expr() { + for data_type in NUMERICS { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![ + Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + ), + Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))), + ]; + let result_agg_phy_exprs = create_aggregate_expr( + &AggregateFunction::ApproxPercentileCont, + false, + &input_phy_exprs[..], + &input_schema, + "c1", + ) + .expect("failed to create aggregate expr"); + + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), false), + result_agg_phy_exprs.field().unwrap() + ); + } + } + + #[test] + fn test_agg_approx_percentile_invalid_phy_expr() { + for data_type in NUMERICS { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![ + Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + ), + Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), + ]; + let err = create_aggregate_expr( + &AggregateFunction::ApproxPercentileCont, + false, + &input_phy_exprs[..], + &input_schema, + "c1", + ) + .expect_err("should fail due to invalid percentile"); + + assert!(matches!(err, DataFusionError::Plan(_))); + } + } + #[test] fn test_min_max_expr() -> Result<()> { let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index c151fb70a084..bae2de74c7b7 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -17,7 +17,6 @@ //! Support the coercion rule for aggregate function. 
-use crate::arrow::datatypes::Schema; use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ @@ -27,6 +26,10 @@ use crate::physical_plan::expressions::{ }; use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; +use crate::{ + arrow::datatypes::Schema, + physical_plan::expressions::is_approx_percentile_cont_supported_arg_type, +}; use arrow::datatypes::DataType; use std::ops::Deref; use std::sync::Arc; @@ -38,24 +41,9 @@ pub(crate) fn coerce_types( input_types: &[DataType], signature: &Signature, ) -> Result> { - match signature.type_signature { - TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { - if input_types.len() != agg_count { - return Err(DataFusionError::Plan(format!( - "The function {:?} expects {:?} arguments, but {:?} were provided", - agg_fun, - agg_count, - input_types.len() - ))); - } - } - _ => { - return Err(DataFusionError::Internal(format!( - "Aggregate functions do not support this {:?}", - signature - ))); - } - }; + // Validate input_types matches (at least one of) the func signature. + check_arg_count(agg_fun, input_types, &signature.type_signature)?; + match agg_fun { AggregateFunction::Count | AggregateFunction::ApproxDistinct => { Ok(input_types.to_vec()) @@ -151,7 +139,75 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::ApproxPercentileCont => { + if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + if !matches!(input_types[1], DataType::Float64) { + return Err(DataFusionError::Plan(format!( + "The percentile argument for {:?} must be Float64, not {:?}.", + agg_fun, input_types[1] + ))); + } + Ok(input_types.to_vec()) + } + } +} + +/// Validate the length of `input_types` matches the `signature` for `agg_fun`. +/// +/// This method DOES NOT validate the argument types - only that (at least one, +/// in the case of [`TypeSignature::OneOf`]) signature matches the desired +/// number of input types. 
+fn check_arg_count( + agg_fun: &AggregateFunction, + input_types: &[DataType], + signature: &TypeSignature, +) -> Result<()> { + match signature { + TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { + if input_types.len() != *agg_count { + return Err(DataFusionError::Plan(format!( + "The function {:?} expects {:?} arguments, but {:?} were provided", + agg_fun, + agg_count, + input_types.len() + ))); + } + } + TypeSignature::Exact(types) => { + if types.len() != input_types.len() { + return Err(DataFusionError::Plan(format!( + "The function {:?} expects {:?} arguments, but {:?} were provided", + agg_fun, + types.len(), + input_types.len() + ))); + } + } + TypeSignature::OneOf(variants) => { + let ok = variants + .iter() + .any(|v| check_arg_count(agg_fun, input_types, v).is_ok()); + if !ok { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not accept {:?} function arguments.", + agg_fun, + input_types.len() + ))); + } + } + _ => { + return Err(DataFusionError::Internal(format!( + "Aggregate functions do not support this {:?}", + signature + ))); + } } + Ok(()) } fn get_min_max_result_type(input_types: &[DataType]) -> Result> { @@ -267,5 +323,29 @@ mod tests { assert_eq!(*input_type, result.unwrap()); } } + + // ApproxPercentileCont input types + let input_types = vec![ + vec![DataType::Int8, DataType::Float64], + vec![DataType::Int16, DataType::Float64], + vec![DataType::Int32, DataType::Float64], + vec![DataType::Int64, DataType::Float64], + vec![DataType::UInt8, DataType::Float64], + vec![DataType::UInt16, DataType::Float64], + vec![DataType::UInt32, DataType::Float64], + vec![DataType::UInt64, DataType::Float64], + vec![DataType::Float32, DataType::Float64], + vec![DataType::Float64, DataType::Float64], + ]; + for input_type in &input_types { + let signature = + aggregates::signature(&AggregateFunction::ApproxPercentileCont); + let result = coerce_types( + &AggregateFunction::ApproxPercentileCont, + input_type, + &signature, + ); + assert_eq!(*input_type, result.unwrap()); + } } } diff --git a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs new file mode 100644 index 000000000000..cba30ee481ab --- /dev/null +++ b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs @@ -0,0 +1,313 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
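For orientation, a minimal usage sketch of the aggregate implemented in this new file, assuming a registered table "t" with a numeric column "b" (both hypothetical). Only the function name, its (column, Float64 percentile literal) signature, and the approx_percentile_cont prelude export come from this patch:

use datafusion::error::Result;
use datafusion::prelude::*;

async fn p50_of_b(ctx: &mut ExecutionContext) -> Result<()> {
    // SQL form: the percentile must be a Float64 literal between 0 and 1.
    let df = ctx.sql("SELECT approx_percentile_cont(b, 0.5) FROM t").await?;
    let _batches = df.collect().await?;

    // Equivalent DataFrame form, via the helper exported from the prelude.
    let df = ctx
        .table("t")?
        .aggregate(vec![], vec![approx_percentile_cont(col("b"), lit(0.5))])?;
    let _batches = df.collect().await?;
    Ok(())
}

Both forms plan to the same ApproxPercentileCont physical expression defined in the file below.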
+ +use std::{any::Any, iter, sync::Arc}; + +use arrow::{ + array::{ + ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + }, + datatypes::{DataType, Field}, +}; + +use crate::{ + error::DataFusionError, + physical_plan::{tdigest::TDigest, Accumulator, AggregateExpr, PhysicalExpr}, + scalar::ScalarValue, +}; + +use crate::error::Result; + +use super::{format_state_name, Literal}; + +/// Return `true` if `arg_type` is of a [`DataType`] that the +/// [`ApproxPercentileCont`] aggregation can operate on. +pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +/// APPROX_PERCENTILE_CONT aggregate expression +#[derive(Debug)] +pub struct ApproxPercentileCont { + name: String, + input_data_type: DataType, + expr: Arc, + percentile: f64, +} + +impl ApproxPercentileCont { + /// Create a new [`ApproxPercentileCont`] aggregate function. + pub fn new( + expr: Vec>, + name: impl Into, + input_data_type: DataType, + ) -> Result { + // Arguments should be [ColumnExpr, DesiredPercentileLiteral] + debug_assert_eq!(expr.len(), 2); + + // Extract the desired percentile literal + let lit = expr[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "desired percentile argument must be float literal".to_string(), + ) + })? + .value(); + let percentile = match lit { + ScalarValue::Float32(Some(q)) => *q as f64, + ScalarValue::Float64(Some(q)) => *q as f64, + got => return Err(DataFusionError::NotImplemented(format!( + "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", + got + ))) + }; + + // Ensure the percentile is between 0 and 1. + if !(0.0..=1.0).contains(&percentile) { + return Err(DataFusionError::Plan(format!( + "Percentile value must be between 0.0 and 1.0 inclusive, {} is invalid", + percentile + ))); + } + + Ok(Self { + name: name.into(), + input_data_type, + // The physical expr to evaluate during accumulation + expr: expr[0].clone(), + percentile, + }) + } +} + +impl AggregateExpr for ApproxPercentileCont { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.input_data_type.clone(), false)) + } + + /// See [`TDigest::to_scalar_state()`] for a description of the serialised + /// state. 
+ fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "max_size"), + DataType::UInt64, + false, + ), + Field::new( + &format_state_name(&self.name, "sum"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "count"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "max"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "min"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "centroids"), + DataType::List(Box::new(Field::new("item", DataType::Float64, true))), + false, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn create_accumulator(&self) -> Result> { + let accumulator: Box = match &self.input_data_type { + t @ (DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64) => { + Box::new(ApproxPercentileAccumulator::new(self.percentile, t.clone())) + } + other => { + return Err(DataFusionError::NotImplemented(format!( + "Support for 'APPROX_PERCENTILE_CONT' for data type {} is not implemented", + other + ))) + } + }; + Ok(accumulator) + } + + fn name(&self) -> &str { + &self.name + } +} + +#[derive(Debug)] +pub struct ApproxPercentileAccumulator { + digest: TDigest, + percentile: f64, + return_type: DataType, +} + +impl ApproxPercentileAccumulator { + pub fn new(percentile: f64, return_type: DataType) -> Self { + Self { + digest: TDigest::new(100), + percentile, + return_type, + } + } +} + +impl Accumulator for ApproxPercentileAccumulator { + fn state(&self) -> Result> { + Ok(self.digest.to_scalar_state()) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + debug_assert_eq!( + values.len(), + 1, + "invalid number of values in batch percentile update" + ); + let values = &values[0]; + + self.digest = match values.data_type() { + DataType::Float64 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Float32 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int64 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int32 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int16 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::Int8 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt64 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt32 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt16 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? + } + DataType::UInt8 => { + let array = values.as_any().downcast_ref::().unwrap(); + self.digest.merge_unsorted(array.values().iter().cloned())? 
+ } + e => { + return Err(DataFusionError::Internal(format!( + "APPROX_PERCENTILE_CONT is not expected to receive the type {:?}", + e + ))); + } + }; + + Ok(()) + } + + fn evaluate(&self) -> Result { + let q = self.digest.estimate_quantile(self.percentile); + + // These acceptable return types MUST match the validation in + // ApproxPercentile::create_accumulator. + Ok(match &self.return_type { + DataType::Int8 => ScalarValue::Int8(Some(q as i8)), + DataType::Int16 => ScalarValue::Int16(Some(q as i16)), + DataType::Int32 => ScalarValue::Int32(Some(q as i32)), + DataType::Int64 => ScalarValue::Int64(Some(q as i64)), + DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)), + DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)), + DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)), + DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)), + DataType::Float32 => ScalarValue::Float32(Some(q as f32)), + DataType::Float64 => ScalarValue::Float64(Some(q as f64)), + v => unreachable!("unexpected return type {:?}", v), + }) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + }; + + let states = (0..states[0].len()) + .map(|index| { + states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>() + .map(|state| TDigest::from_scalar_state(&state)) + }) + .chain(iter::once(Ok(self.digest.clone()))) + .collect::>>()?; + + self.digest = TDigest::merge_digests(&states); + + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index ca14d7fa1a8d..9344fbd6b1bc 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -26,6 +26,7 @@ use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; mod approx_distinct; +mod approx_percentile_cont; mod array_agg; mod average; #[macro_use] @@ -64,6 +65,9 @@ pub mod helpers { } pub use approx_distinct::ApproxDistinct; +pub use approx_percentile_cont::{ + is_approx_percentile_cont_supported_arg_type, ApproxPercentileCont, +}; pub use array_agg::ArrayAgg; pub(crate) use average::is_avg_support_arg_type; pub use average::{avg_return_type, Avg, AvgAccumulator}; diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 24aa6ad38339..725e475335ca 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -659,6 +659,7 @@ pub mod repartition; pub mod sorts; pub mod stream; pub mod string_expressions; +pub(crate) mod tdigest; pub mod type_coercion; pub mod udaf; pub mod udf; diff --git a/datafusion/src/physical_plan/tdigest/mod.rs b/datafusion/src/physical_plan/tdigest/mod.rs new file mode 100644 index 000000000000..6780adc84cd1 --- /dev/null +++ b/datafusion/src/physical_plan/tdigest/mod.rs @@ -0,0 +1,818 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with this +// work for additional information regarding copyright ownership. The ASF +// licenses this file to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +//! An implementation of the [TDigest sketch algorithm] providing approximate +//! quantile calculations. +//! +//! The TDigest code in this module is modified from +//! https://github.com/MnO2/t-digest, itself a rust reimplementation of +//! [Facebook's Folly TDigest] implementation. +//! +//! Alterations include reduction of runtime heap allocations, broader type +//! support, (de-)serialisation support, reduced type conversions and null value +//! tolerance. +//! +//! [TDigest sketch algorithm]: https://arxiv.org/abs/1902.04023 +//! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h + +use arrow::datatypes::DataType; +use ordered_float::OrderedFloat; +use std::cmp::Ordering; + +use crate::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; + +// Cast a non-null [`ScalarValue::Float64`] to an [`OrderedFloat`], or +// panic. +macro_rules! cast_scalar_f64 { + ($value:expr ) => { + match &$value { + ScalarValue::Float64(Some(v)) => OrderedFloat::from(*v), + v => panic!("invalid type {:?}", v), + } + }; +} + +/// This trait is implemented for each type a [`TDigest`] can operate on, +/// allowing it to support both numerical rust types (obtained from +/// `PrimitiveArray` instances), and [`ScalarValue`] instances. +pub(crate) trait TryIntoOrderedF64 { + /// A fallible conversion of a possibly null `self` into a [`OrderedFloat`]. + /// + /// If `self` is null, this method must return `Ok(None)`. + /// + /// If `self` cannot be coerced to the desired type, this method must return + /// an `Err` variant. + fn try_as_f64(&self) -> Result>>; +} + +/// Generate an infallible conversion from `type` to an [`OrderedFloat`]. +macro_rules! 
impl_try_ordered_f64 { + ($type:ty) => { + impl TryIntoOrderedF64 for $type { + fn try_as_f64(&self) -> Result>> { + Ok(Some(OrderedFloat::from(*self as f64))) + } + } + }; +} + +impl_try_ordered_f64!(f64); +impl_try_ordered_f64!(f32); +impl_try_ordered_f64!(i64); +impl_try_ordered_f64!(i32); +impl_try_ordered_f64!(i16); +impl_try_ordered_f64!(i8); +impl_try_ordered_f64!(u64); +impl_try_ordered_f64!(u32); +impl_try_ordered_f64!(u16); +impl_try_ordered_f64!(u8); + +impl TryIntoOrderedF64 for ScalarValue { + fn try_as_f64(&self) -> Result>> { + match self { + ScalarValue::Float32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Float64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int8(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int16(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt8(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt16(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + + got => { + return Err(DataFusionError::NotImplemented(format!( + "Support for 'APPROX_PERCENTILE_CONT' for data type {} is not implemented", + got + ))) + } + } + } +} + +/// Centroid implementation to the cluster mentioned in the paper. +#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) struct Centroid { + mean: OrderedFloat, + weight: OrderedFloat, +} + +impl PartialOrd for Centroid { + fn partial_cmp(&self, other: &Centroid) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Centroid { + fn cmp(&self, other: &Centroid) -> Ordering { + self.mean.cmp(&other.mean) + } +} + +impl Centroid { + pub(crate) fn new( + mean: impl Into>, + weight: impl Into>, + ) -> Self { + Centroid { + mean: mean.into(), + weight: weight.into(), + } + } + + #[inline] + pub(crate) fn mean(&self) -> OrderedFloat { + self.mean + } + + #[inline] + pub(crate) fn weight(&self) -> OrderedFloat { + self.weight + } + + pub(crate) fn add( + &mut self, + sum: impl Into>, + weight: impl Into>, + ) -> f64 { + let new_sum = sum.into() + self.weight * self.mean; + let new_weight = self.weight + weight.into(); + self.weight = new_weight; + self.mean = new_sum / new_weight; + new_sum.into_inner() + } +} + +impl Default for Centroid { + fn default() -> Self { + Centroid { + mean: OrderedFloat::from(0.0), + weight: OrderedFloat::from(1.0), + } + } +} + +/// T-Digest to be operated on. 
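A smoke-test style sketch of how the digest is driven; it mirrors the test_int64_uniform case in the tests at the bottom of this module, and the roughly 1% error bound is the one those tests assert:

fn tdigest_smoke() -> Result<()> {
    // Feed 1..=1000 as unsorted scalar values into a 100-centroid digest
    // (100 is the size ApproxPercentileAccumulator uses).
    let values = (1i64..=1000).map(|v| ScalarValue::Int64(Some(v)));
    let t = TDigest::new(100);
    let t = t.merge_unsorted(values)?;

    // The median estimate should land within about 1% of the true value, 500.
    let p50 = t.estimate_quantile(0.5);
    assert!((p50 - 500.0).abs() / 500.0 < 0.01);
    Ok(())
}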
+#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) struct TDigest { + centroids: Vec, + max_size: usize, + sum: OrderedFloat, + count: OrderedFloat, + max: OrderedFloat, + min: OrderedFloat, +} + +impl TDigest { + pub(crate) fn new(max_size: usize) -> Self { + TDigest { + centroids: Vec::new(), + max_size, + sum: OrderedFloat::from(0.0), + count: OrderedFloat::from(0.0), + max: OrderedFloat::from(std::f64::NAN), + min: OrderedFloat::from(std::f64::NAN), + } + } + + #[inline] + pub(crate) fn count(&self) -> f64 { + self.count.into_inner() + } + + #[inline] + pub(crate) fn max(&self) -> f64 { + self.max.into_inner() + } + + #[inline] + pub(crate) fn min(&self) -> f64 { + self.min.into_inner() + } + + #[inline] + pub(crate) fn max_size(&self) -> usize { + self.max_size + } +} + +impl Default for TDigest { + fn default() -> Self { + TDigest { + centroids: Vec::new(), + max_size: 100, + sum: OrderedFloat::from(0.0), + count: OrderedFloat::from(0.0), + max: OrderedFloat::from(std::f64::NAN), + min: OrderedFloat::from(std::f64::NAN), + } + } +} + +impl TDigest { + fn k_to_q(k: f64, d: f64) -> OrderedFloat { + let k_div_d = k / d; + if k_div_d >= 0.5 { + let base = 1.0 - k_div_d; + 1.0 - 2.0 * base * base + } else { + 2.0 * k_div_d * k_div_d + } + .into() + } + + fn clamp( + v: OrderedFloat, + lo: OrderedFloat, + hi: OrderedFloat, + ) -> OrderedFloat { + if v > hi { + hi + } else if v < lo { + lo + } else { + v + } + } + + pub(crate) fn merge_unsorted( + &self, + unsorted_values: impl IntoIterator, + ) -> Result { + let mut values = unsorted_values + .into_iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?; + + values.sort(); + + Ok(self.merge_sorted_f64(&values)) + } + + fn merge_sorted_f64(&self, sorted_values: &[OrderedFloat]) -> TDigest { + debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest"); + + if sorted_values.is_empty() { + return self.clone(); + } + + let mut result = TDigest::new(self.max_size()); + result.count = OrderedFloat::from(self.count() + (sorted_values.len() as f64)); + + let maybe_min = *sorted_values.first().unwrap(); + let maybe_max = *sorted_values.last().unwrap(); + + if self.count() > 0.0 { + result.min = std::cmp::min(self.min, maybe_min); + result.max = std::cmp::max(self.max, maybe_max); + } else { + result.min = maybe_min; + result.max = maybe_max; + } + + let mut compressed: Vec = Vec::with_capacity(self.max_size); + + let mut k_limit: f64 = 1.0; + let mut q_limit_times_count = + Self::k_to_q(k_limit, self.max_size as f64) * result.count(); + k_limit += 1.0; + + let mut iter_centroids = self.centroids.iter().peekable(); + let mut iter_sorted_values = sorted_values.iter().peekable(); + + let mut curr: Centroid = if let Some(c) = iter_centroids.peek() { + let curr = **iter_sorted_values.peek().unwrap(); + if c.mean() < curr { + iter_centroids.next().unwrap().clone() + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + } + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + }; + + let mut weight_so_far = curr.weight(); + + let mut sums_to_merge = OrderedFloat::from(0.0); + let mut weights_to_merge = OrderedFloat::from(0.0); + + while iter_centroids.peek().is_some() || iter_sorted_values.peek().is_some() { + let next: Centroid = if let Some(c) = iter_centroids.peek() { + if iter_sorted_values.peek().is_none() + || c.mean() < **iter_sorted_values.peek().unwrap() + { + iter_centroids.next().unwrap().clone() + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + } + } else { + 
Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + }; + + let next_sum = next.mean() * next.weight(); + weight_so_far += next.weight(); + + if weight_so_far <= q_limit_times_count { + sums_to_merge += next_sum; + weights_to_merge += next.weight(); + } else { + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + sums_to_merge = 0.0.into(); + weights_to_merge = 0.0.into(); + + compressed.push(curr.clone()); + q_limit_times_count = + Self::k_to_q(k_limit, self.max_size as f64) * result.count(); + k_limit += 1.0; + curr = next; + } + } + + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + compressed.push(curr); + compressed.shrink_to_fit(); + compressed.sort(); + + result.centroids = compressed; + result + } + + fn external_merge( + centroids: &mut Vec, + first: usize, + middle: usize, + last: usize, + ) { + let mut result: Vec = Vec::with_capacity(centroids.len()); + + let mut i = first; + let mut j = middle; + + while i < middle && j < last { + match centroids[i].cmp(¢roids[j]) { + Ordering::Less => { + result.push(centroids[i].clone()); + i += 1; + } + Ordering::Greater => { + result.push(centroids[j].clone()); + j += 1; + } + Ordering::Equal => { + result.push(centroids[i].clone()); + i += 1; + } + } + } + + while i < middle { + result.push(centroids[i].clone()); + i += 1; + } + + while j < last { + result.push(centroids[j].clone()); + j += 1; + } + + i = first; + for centroid in result.into_iter() { + centroids[i] = centroid; + i += 1; + } + } + + // Merge multiple T-Digests + pub(crate) fn merge_digests(digests: &[TDigest]) -> TDigest { + let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum(); + if n_centroids == 0 { + return TDigest::default(); + } + + let max_size = digests.first().unwrap().max_size; + let mut centroids: Vec = Vec::with_capacity(n_centroids); + let mut starts: Vec = Vec::with_capacity(digests.len()); + + let mut count: f64 = 0.0; + let mut min = OrderedFloat::from(std::f64::INFINITY); + let mut max = OrderedFloat::from(std::f64::NEG_INFINITY); + + let mut start: usize = 0; + for digest in digests.iter() { + starts.push(start); + + let curr_count: f64 = digest.count(); + if curr_count > 0.0 { + min = std::cmp::min(min, digest.min); + max = std::cmp::max(max, digest.max); + count += curr_count; + for centroid in &digest.centroids { + centroids.push(centroid.clone()); + start += 1; + } + } + } + + let mut digests_per_block: usize = 1; + while digests_per_block < starts.len() { + for i in (0..starts.len()).step_by(digests_per_block * 2) { + if i + digests_per_block < starts.len() { + let first = starts[i]; + let middle = starts[i + digests_per_block]; + let last = if i + 2 * digests_per_block < starts.len() { + starts[i + 2 * digests_per_block] + } else { + centroids.len() + }; + + debug_assert!(first <= middle && middle <= last); + Self::external_merge(&mut centroids, first, middle, last); + } + } + + digests_per_block *= 2; + } + + let mut result = TDigest::new(max_size); + let mut compressed: Vec = Vec::with_capacity(max_size); + + let mut k_limit: f64 = 1.0; + let mut q_limit_times_count = + Self::k_to_q(k_limit, max_size as f64) * (count as f64); + + let mut iter_centroids = centroids.iter_mut(); + let mut curr = iter_centroids.next().unwrap(); + let mut weight_so_far = curr.weight(); + let mut sums_to_merge = OrderedFloat::from(0.0); + let mut weights_to_merge = OrderedFloat::from(0.0); + + for centroid in iter_centroids { + 
weight_so_far += centroid.weight(); + + if weight_so_far <= q_limit_times_count { + sums_to_merge += centroid.mean() * centroid.weight(); + weights_to_merge += centroid.weight(); + } else { + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + sums_to_merge = OrderedFloat::from(0.0); + weights_to_merge = OrderedFloat::from(0.0); + compressed.push(curr.clone()); + q_limit_times_count = + Self::k_to_q(k_limit, max_size as f64) * (count as f64); + k_limit += 1.0; + curr = centroid; + } + } + + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + compressed.push(curr.clone()); + compressed.shrink_to_fit(); + compressed.sort(); + + result.count = OrderedFloat::from(count as f64); + result.min = min; + result.max = max; + result.centroids = compressed; + result + } + + /// To estimate the value located at `q` quantile + pub(crate) fn estimate_quantile(&self, q: f64) -> f64 { + if self.centroids.is_empty() { + return 0.0; + } + + let count_ = self.count; + let rank = OrderedFloat::from(q) * count_; + + let mut pos: usize; + let mut t; + if q > 0.5 { + if q >= 1.0 { + return self.max(); + } + + pos = 0; + t = count_; + + for (k, centroid) in self.centroids.iter().enumerate().rev() { + t -= centroid.weight(); + + if rank >= t { + pos = k; + break; + } + } + } else { + if q <= 0.0 { + return self.min(); + } + + pos = self.centroids.len() - 1; + t = OrderedFloat::from(0.0); + + for (k, centroid) in self.centroids.iter().enumerate() { + if rank < t + centroid.weight() { + pos = k; + break; + } + + t += centroid.weight(); + } + } + + let mut delta = OrderedFloat::from(0.0); + let mut min = self.min; + let mut max = self.max; + + if self.centroids.len() > 1 { + if pos == 0 { + delta = self.centroids[pos + 1].mean() - self.centroids[pos].mean(); + max = self.centroids[pos + 1].mean(); + } else if pos == (self.centroids.len() - 1) { + delta = self.centroids[pos].mean() - self.centroids[pos - 1].mean(); + min = self.centroids[pos - 1].mean(); + } else { + delta = (self.centroids[pos + 1].mean() - self.centroids[pos - 1].mean()) + / 2.0; + min = self.centroids[pos - 1].mean(); + max = self.centroids[pos + 1].mean(); + } + } + + let value = self.centroids[pos].mean() + + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta; + Self::clamp(value, min, max).into_inner() + } + + /// This method decomposes the [`TDigest`] and its [`Centroid`] instances + /// into a series of primitive scalar values. + /// + /// First the values of the TDigest are packed, followed by the variable + /// number of centroids packed into a [`ScalarValue::List`] of + /// [`ScalarValue::Float64`]: + /// + /// ```text + /// + /// ┌────────┬────────┬────────┬───────┬────────┬────────┐ + /// │max_size│ sum │ count │ max │ min │centroid│ + /// └────────┴────────┴────────┴───────┴────────┴────────┘ + /// │ + /// ┌─────────────────────┘ + /// ▼ + /// ┌ List ───┐ + /// │┌ ─ ─ ─ ┐│ + /// │ mean │ + /// │├ ─ ─ ─ ┼│─ ─ Centroid 1 + /// │ weight │ + /// │└ ─ ─ ─ ┘│ + /// │ │ + /// │┌ ─ ─ ─ ┐│ + /// │ mean │ + /// │├ ─ ─ ─ ┼│─ ─ Centroid 2 + /// │ weight │ + /// │└ ─ ─ ─ ┘│ + /// │ │ + /// ... + /// + /// ``` + /// + /// The [`TDigest::from_scalar_state()`] method reverses this processes, + /// consuming the output of this method and returning an unpacked + /// [`TDigest`]. 
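A sketch of the round trip this pair of methods provides, equivalent to the assert_state_roundtrip! macro used by the tests below: packing a digest into its scalar state and unpacking it again yields an equal digest.

fn state_roundtrip(digest: &TDigest) {
    let state: Vec<ScalarValue> = digest.to_scalar_state();
    let restored = TDigest::from_scalar_state(&state);
    assert_eq!(*digest, restored);
}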
+ pub(crate) fn to_scalar_state(&self) -> Vec { + // Gather up all the centroids + let centroids: Vec<_> = self + .centroids + .iter() + .flat_map(|c| [c.mean().into_inner(), c.weight().into_inner()]) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + + vec![ + ScalarValue::UInt64(Some(self.max_size as u64)), + ScalarValue::Float64(Some(self.sum.into_inner())), + ScalarValue::Float64(Some(self.count.into_inner())), + ScalarValue::Float64(Some(self.max.into_inner())), + ScalarValue::Float64(Some(self.min.into_inner())), + ScalarValue::List(Some(Box::new(centroids)), Box::new(DataType::Float64)), + ] + } + + /// Unpack the serialised state of a [`TDigest`] produced by + /// [`Self::to_scalar_state()`]. + /// + /// # Correctness + /// + /// Providing input to this method that was not obtained from + /// [`Self::to_scalar_state()`] results in undefined behaviour and may + /// panic. + pub(crate) fn from_scalar_state(state: &[ScalarValue]) -> Self { + assert_eq!(state.len(), 6, "invalid TDigest state"); + + let max_size = match &state[0] { + ScalarValue::UInt64(Some(v)) => *v as usize, + v => panic!("invalid max_size type {:?}", v), + }; + + let centroids: Vec<_> = match &state[5] { + ScalarValue::List(Some(c), d) if **d == DataType::Float64 => c + .chunks(2) + .map(|v| Centroid::new(cast_scalar_f64!(v[0]), cast_scalar_f64!(v[1]))) + .collect(), + v => panic!("invalid centroids type {:?}", v), + }; + + let max = cast_scalar_f64!(&state[3]); + let min = cast_scalar_f64!(&state[4]); + assert!(max >= min); + + Self { + max_size, + sum: cast_scalar_f64!(state[1]), + count: cast_scalar_f64!(&state[2]), + max, + min, + centroids, + } + } +} + +#[cfg(debug_assertions)] +fn is_sorted(values: &[OrderedFloat]) -> bool { + values.windows(2).all(|w| w[0] <= w[1]) +} + +#[cfg(test)] +mod tests { + use std::iter; + + use super::*; + + // A macro to assert the specified `quantile` estimated by `t` is within the + // allowable relative error bound. + macro_rules! assert_error_bounds { + ($t:ident, quantile = $quantile:literal, want = $want:literal) => { + assert_error_bounds!( + $t, + quantile = $quantile, + want = $want, + allowable_error = 0.01 + ) + }; + ($t:ident, quantile = $quantile:literal, want = $want:literal, allowable_error = $re:literal) => { + let ans = $t.estimate_quantile($quantile); + let expected: f64 = $want; + let percentage: f64 = (expected - ans).abs() / expected; + assert!( + percentage < $re, + "relative error {} is more than {}% (got quantile {}, want {})", + percentage, + $re, + ans, + expected + ); + }; + } + + macro_rules! 
assert_state_roundtrip { + ($t:ident) => { + let state = $t.to_scalar_state(); + let other = TDigest::from_scalar_state(&state); + assert_eq!($t, other); + }; + } + + #[test] + fn test_int64_uniform() { + let values = (1i64..=1000).map(|v| ScalarValue::Int64(Some(v))); + + let t = TDigest::new(100); + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 0.1, want = 100.0); + assert_error_bounds!(t, quantile = 0.5, want = 500.0); + assert_error_bounds!(t, quantile = 0.9, want = 900.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_int64_uniform_with_nulls() { + let values = (1i64..=1000).map(|v| ScalarValue::Int64(Some(v))); + // Prepend some NULLs + let values = iter::repeat(ScalarValue::Int64(None)) + .take(10) + .chain(values); + // Append some more NULLs + let values = values.chain(iter::repeat(ScalarValue::Int64(None)).take(10)); + + let t = TDigest::new(100); + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 0.1, want = 100.0); + assert_error_bounds!(t, quantile = 0.5, want = 500.0); + assert_error_bounds!(t, quantile = 0.9, want = 900.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_centroid_addition_regression() { + //https://github.com/MnO2/t-digest/pull/1 + + let vals = vec![1.0, 1.0, 1.0, 2.0, 1.0, 1.0]; + let mut t = TDigest::new(10); + + for v in vals { + t = t.merge_unsorted([ScalarValue::Float64(Some(v))]).unwrap(); + } + + assert_error_bounds!(t, quantile = 0.5, want = 1.0); + assert_error_bounds!(t, quantile = 0.95, want = 2.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_merge_unsorted_against_uniform_distro() { + let t = TDigest::new(100); + let values: Vec<_> = (1..=1_000_000) + .map(f64::from) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 1.0, want = 1_000_000.0); + assert_error_bounds!(t, quantile = 0.99, want = 990_000.0); + assert_error_bounds!(t, quantile = 0.01, want = 10_000.0); + assert_error_bounds!(t, quantile = 0.0, want = 1.0); + assert_error_bounds!(t, quantile = 0.5, want = 500_000.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_merge_unsorted_against_skewed_distro() { + let t = TDigest::new(100); + let mut values: Vec<_> = (1..=600_000) + .map(f64::from) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + for _ in 0..400_000 { + values.push(ScalarValue::Float64(Some(1_000_000.0))); + } + + let t = t.merge_unsorted(values).unwrap(); + + assert_error_bounds!(t, quantile = 0.99, want = 1_000_000.0); + assert_error_bounds!(t, quantile = 0.01, want = 10_000.0); + assert_error_bounds!(t, quantile = 0.5, want = 500_000.0); + assert_state_roundtrip!(t); + } + + #[test] + fn test_merge_digests() { + let mut digests: Vec = Vec::new(); + + for _ in 1..=100 { + let t = TDigest::new(100); + let values: Vec<_> = (1..=1_000) + .map(f64::from) + .map(|v| ScalarValue::Float64(Some(v))) + .collect(); + let t = t.merge_unsorted(values).unwrap(); + digests.push(t) + } + + let t = TDigest::merge_digests(&digests); + + assert_error_bounds!(t, quantile = 1.0, want = 1000.0); + assert_error_bounds!(t, quantile = 0.99, want = 990.0); + assert_error_bounds!(t, quantile = 0.01, want = 10.0, allowable_error = 0.2); + assert_error_bounds!(t, quantile = 0.0, want = 1.0); + assert_error_bounds!(t, quantile = 0.5, want = 500.0); + assert_state_roundtrip!(t); + } +} diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs index abc75829ea17..0aff006c7896 100644 --- 
a/datafusion/src/prelude.rs +++ b/datafusion/src/prelude.rs @@ -30,10 +30,10 @@ pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::execution::options::AvroReadOptions; pub use crate::execution::options::{CsvReadOptions, NdJsonReadOptions}; pub use crate::logical_plan::{ - array, ascii, avg, bit_length, btrim, character_length, chr, col, concat, concat_ws, - count, create_udf, date_part, date_trunc, digest, in_list, initcap, left, length, - lit, lower, lpad, ltrim, max, md5, min, now, octet_length, random, regexp_match, - regexp_replace, repeat, replace, reverse, right, rpad, rtrim, sha224, sha256, sha384, - sha512, split_part, starts_with, strpos, substr, sum, to_hex, translate, trim, upper, - Column, JoinType, Partitioning, + approx_percentile_cont, array, ascii, avg, bit_length, btrim, character_length, chr, + col, concat, concat_ws, count, create_udf, date_part, date_trunc, digest, in_list, + initcap, left, length, lit, lower, lpad, ltrim, max, md5, min, now, octet_length, + random, regexp_match, regexp_replace, repeat, replace, reverse, right, rpad, rtrim, + sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr, sum, to_hex, + translate, trim, upper, Column, JoinType, Partitioning, }; diff --git a/datafusion/tests/dataframe_functions.rs b/datafusion/tests/dataframe_functions.rs index b8efc9815636..d5118b30d2af 100644 --- a/datafusion/tests/dataframe_functions.rs +++ b/datafusion/tests/dataframe_functions.rs @@ -153,6 +153,26 @@ async fn test_fn_btrim_with_chars() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_fn_approx_percentile_cont() -> Result<()> { + let expr = approx_percentile_cont(col("b"), lit(0.5)); + + let expected = vec![ + "+-------------------------------------------+", + "| APPROXPERCENTILECONT(test.b,Float64(0.5)) |", + "+-------------------------------------------+", + "| 10 |", + "+-------------------------------------------+", + ]; + + let df = create_test_table()?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_batches_eq!(expected, &batches); + + Ok(()) +} + #[tokio::test] async fn test_fn_character_length() -> Result<()> { let expr = character_length(col("a")); diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 2d4287054388..a025d4eeec86 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -354,6 +354,95 @@ async fn csv_query_approx_count() -> Result<()> { Ok(()) } +// This test executes the APPROX_PERCENTILE_CONT aggregation against the test +// data, asserting the estimated quantiles are ±5% their actual values. +// +// Actual quantiles calculated with: +// +// ```r +// read_csv("./testing/data/csv/aggregate_test_100.csv") |> +// select_if(is.numeric) |> +// summarise_all(~ quantile(., c(0.1, 0.5, 0.9))) +// ``` +// +// Giving: +// +// ```text +// c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 +// +// 1 1 -95.3 -22925. -1882606710 -7.25e18 18.9 2671. 472608672. 1.83e18 0.109 0.0714 +// 2 3 15.5 4599 377164262 1.13e18 134. 30634 2365817608. 9.30e18 0.491 0.551 +// 3 5 102. 25334. 1991374996. 7.37e18 231 57518. 3776538487. 1.61e19 0.834 0.946 +// ``` +// +// Column `c12` is omitted due to a large relative error (~10%) due to the small +// float values. 
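For a concrete sense of the assertion the percentile_test! macro below generates, the c2 median check expands to roughly the following query, which must return a single row containing true:

SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.5) AS DOUBLE) / 3) < 0.05) AS q
FROM aggregate_test_100

In other words, the estimate has to fall within 5% relative error of the known quantile for the test to pass.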
+#[tokio::test] +async fn csv_query_approx_percentile_cont() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + + // Generate an assertion that the estimated $percentile value for $column is + // within 5% of the $actual percentile value. + macro_rules! percentile_test { + ($ctx:ident, column=$column:literal, percentile=$percentile:literal, actual=$actual:literal) => { + let sql = format!("SELECT (ABS(1 - CAST(approx_percentile_cont({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $percentile, $actual); + let actual = execute_to_batches(&mut ctx, &sql).await; + // + // "+------+", + // "| q |", + // "+------+", + // "| true |", + // "+------+", + // + let want = ["+------+", "| q |", "+------+", "| true |", "+------+"]; + assert_batches_eq!(want, &actual); + }; + } + + percentile_test!(ctx, column = "c2", percentile = 0.1, actual = 1.0); + percentile_test!(ctx, column = "c2", percentile = 0.5, actual = 3.0); + percentile_test!(ctx, column = "c2", percentile = 0.9, actual = 5.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c3", percentile = 0.1, actual = -95.3); + percentile_test!(ctx, column = "c3", percentile = 0.5, actual = 15.5); + percentile_test!(ctx, column = "c3", percentile = 0.9, actual = 102.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c4", percentile = 0.1, actual = -22925.0); + percentile_test!(ctx, column = "c4", percentile = 0.5, actual = 4599.0); + percentile_test!(ctx, column = "c4", percentile = 0.9, actual = 25334.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c5", percentile = 0.1, actual = -1882606710.0); + percentile_test!(ctx, column = "c5", percentile = 0.5, actual = 377164262.0); + percentile_test!(ctx, column = "c5", percentile = 0.9, actual = 1991374996.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c6", percentile = 0.1, actual = -7.25e18); + percentile_test!(ctx, column = "c6", percentile = 0.5, actual = 1.13e18); + percentile_test!(ctx, column = "c6", percentile = 0.9, actual = 7.37e18); + //////////////////////////////////// + percentile_test!(ctx, column = "c7", percentile = 0.1, actual = 18.9); + percentile_test!(ctx, column = "c7", percentile = 0.5, actual = 134.0); + percentile_test!(ctx, column = "c7", percentile = 0.9, actual = 231.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c8", percentile = 0.1, actual = 2671.0); + percentile_test!(ctx, column = "c8", percentile = 0.5, actual = 30634.0); + percentile_test!(ctx, column = "c8", percentile = 0.9, actual = 57518.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c9", percentile = 0.1, actual = 472608672.0); + percentile_test!(ctx, column = "c9", percentile = 0.5, actual = 2365817608.0); + percentile_test!(ctx, column = "c9", percentile = 0.9, actual = 3776538487.0); + //////////////////////////////////// + percentile_test!(ctx, column = "c10", percentile = 0.1, actual = 1.83e18); + percentile_test!(ctx, column = "c10", percentile = 0.5, actual = 9.30e18); + percentile_test!(ctx, column = "c10", percentile = 0.9, actual = 1.61e19); + //////////////////////////////////// + percentile_test!(ctx, column = "c11", percentile = 0.1, actual = 0.109); + percentile_test!(ctx, column = "c11", percentile = 0.5, actual = 0.491); + percentile_test!(ctx, column = "c11", percentile = 0.9, actual = 0.834); + + Ok(()) +} + #[tokio::test] async fn query_count_without_from() 
-> Result<()> { let mut ctx = ExecutionContext::new(); From 940d4eb60e76a3d4062489e872bf241dbfe0031a Mon Sep 17 00:00:00 2001 From: Kun Liu Date: Tue, 1 Feb 2022 05:13:19 +0800 Subject: [PATCH 17/50] suppport bitwise and as an example (#1653) * suppport bitwise and as an example * Use $OP in macro rather than `&` * fix: change signature to &dyn Array * fmt Co-authored-by: Andrew Lamb --- datafusion/src/logical_plan/operators.rs | 3 + .../coercion_rule/binary_rule.rs | 20 +++ .../src/physical_plan/expressions/binary.rs | 134 ++++++++++++++++++ datafusion/src/sql/planner.rs | 1 + 4 files changed, 158 insertions(+) diff --git a/datafusion/src/logical_plan/operators.rs b/datafusion/src/logical_plan/operators.rs index fdfd3f3ca267..14ccab0537bd 100644 --- a/datafusion/src/logical_plan/operators.rs +++ b/datafusion/src/logical_plan/operators.rs @@ -64,6 +64,8 @@ pub enum Operator { RegexNotMatch, /// Case insensitive regex not match RegexNotIMatch, + /// Bitwise and, like `&` + BitwiseAnd, } impl fmt::Display for Operator { @@ -90,6 +92,7 @@ impl fmt::Display for Operator { Operator::RegexNotIMatch => "!~*", Operator::IsDistinctFrom => "IS DISTINCT FROM", Operator::IsNotDistinctFrom => "IS NOT DISTINCT FROM", + Operator::BitwiseAnd => "&", }; write!(f, "{}", display) } diff --git a/datafusion/src/physical_plan/coercion_rule/binary_rule.rs b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs index 982a4cb1bbc4..426d59f033e9 100644 --- a/datafusion/src/physical_plan/coercion_rule/binary_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs @@ -31,6 +31,7 @@ pub(crate) fn coerce_types( ) -> Result { // This result MUST be compatible with `binary_coerce` let result = match op { + Operator::BitwiseAnd => bitwise_coercion(lhs_type, rhs_type), Operator::And | Operator::Or => match (lhs_type, rhs_type) { // logical binary boolean operators can only be evaluated in bools (DataType::Boolean, DataType::Boolean) => Some(DataType::Boolean), @@ -72,6 +73,25 @@ pub(crate) fn coerce_types( } } +fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + + if !is_numeric(left_type) || !is_numeric(right_type) { + return None; + } + if left_type == right_type && !is_dictionary(left_type) { + return Some(left_type.clone()); + } + // TODO support other data type + match (left_type, right_type) { + (Int64, _) | (_, Int64) => Some(Int64), + (Int32, _) | (_, Int32) => Some(Int32), + (Int16, _) | (_, Int16) => Some(Int16), + (Int8, _) | (_, Int8) => Some(Int8), + _ => None, + } +} + fn comparison_eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { // can't compare dictionaries directly due to // https://github.com/apache/arrow-rs/issues/1201 diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index 4680dd0a49d9..d1fc3bcdc029 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -348,6 +348,103 @@ fn modulus_decimal(left: &DecimalArray, right: &DecimalArray) -> Result {{ + let len = $LEFT.len(); + let left = $LEFT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + let right = $RIGHT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + let result = (0..len) + .into_iter() + .map(|i| { + if left.is_null(i) || right.is_null(i) { + None + } else { + Some(left.value(i) $OP right.value(i)) + } + }) + .collect::<$ARRAY_TYPE>(); + Ok(Arc::new(result)) + }}; +} + +/// The binary_bitwise_array_op macro only 
evaluates for integer types +/// like int64, int32. +/// It is used to do bitwise operation on an array with a scalar. +macro_rules! binary_bitwise_array_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:tt, $ARRAY_TYPE:ident, $TYPE:ty) => {{ + let len = $LEFT.len(); + let array = $LEFT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + let scalar = $RIGHT; + if scalar.is_null() { + Ok(new_null_array(array.data_type(), len)) + } else { + let right: $TYPE = scalar.try_into().unwrap(); + let result = (0..len) + .into_iter() + .map(|i| { + if array.is_null(i) { + None + } else { + Some(array.value(i) $OP right) + } + }) + .collect::<$ARRAY_TYPE>(); + Ok(Arc::new(result) as ArrayRef) + } + }}; +} + +fn bitwise_and(left: ArrayRef, right: ArrayRef) -> Result { + match &left.data_type() { + DataType::Int8 => { + binary_bitwise_array_op!(left, right, &, Int8Array, i8) + } + DataType::Int16 => { + binary_bitwise_array_op!(left, right, &, Int16Array, i16) + } + DataType::Int32 => { + binary_bitwise_array_op!(left, right, &, Int32Array, i32) + } + DataType::Int64 => { + binary_bitwise_array_op!(left, right, &, Int64Array, i64) + } + other => Err(DataFusionError::Internal(format!( + "Data type {:?} not supported for binary operation '{}' on dyn arrays", + other, + Operator::BitwiseAnd + ))), + } +} + +fn bitwise_and_scalar( + array: &dyn Array, + scalar: ScalarValue, +) -> Option> { + let result = match array.data_type() { + DataType::Int8 => { + binary_bitwise_array_scalar!(array, scalar, &, Int8Array, i8) + } + DataType::Int16 => { + binary_bitwise_array_scalar!(array, scalar, &, Int16Array, i16) + } + DataType::Int32 => { + binary_bitwise_array_scalar!(array, scalar, &, Int32Array, i32) + } + DataType::Int64 => { + binary_bitwise_array_scalar!(array, scalar, &, Int64Array, i64) + } + other => Err(DataFusionError::Internal(format!( + "Data type {:?} not supported for binary operation '{}' on dyn arrays", + other, + Operator::BitwiseAnd + ))), + }; + Some(result) +} + /// Binary expression #[derive(Debug)] pub struct BinaryExpr { @@ -880,6 +977,8 @@ pub fn binary_operator_data_type( | Operator::RegexNotIMatch | Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => Ok(DataType::Boolean), + // bitwise operations return the common coerced type + Operator::BitwiseAnd => Ok(result_type), // math operations return the same value as the common coerced type Operator::Plus | Operator::Minus @@ -1055,6 +1154,7 @@ impl BinaryExpr { true, true ), + Operator::BitwiseAnd => bitwise_and_scalar(array, scalar.clone()), // if scalar operation is not supported - fallback to array implementation _ => None, }; @@ -1143,6 +1243,7 @@ impl BinaryExpr { Operator::RegexNotIMatch => { binary_string_array_flag_op!(left, right, regexp_is_match, true, true) } + Operator::BitwiseAnd => bitwise_and(left, right), } } } @@ -1580,6 +1681,18 @@ mod tests { DataType::Boolean, vec![false, false, false, false, true] ); + test_coercion!( + Int16Array, + DataType::Int16, + vec![1i16, 2i16, 3i16], + Int64Array, + DataType::Int64, + vec![10i64, 4i64, 5i64], + Operator::BitwiseAnd, + Int64Array, + DataType::Int64, + vec![0i64, 0i64, 1i64] + ); Ok(()) } @@ -2954,4 +3067,25 @@ mod tests { Ok(()) } + + #[test] + fn bitwise_array_test() -> Result<()> { + let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; + let right = + Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef; + let result = bitwise_and(left, right)?; + let expected = Int32Array::from(vec![Some(0), None, Some(3)]); + 
assert_eq!(result.as_ref(), &expected); + Ok(()) + } + + #[test] + fn bitwise_scalar_test() -> Result<()> { + let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; + let right = ScalarValue::from(3i32); + let result = bitwise_and_scalar(&left, right).unwrap()?; + let expected = Int32Array::from(vec![Some(0), None, Some(3)]); + assert_eq!(result.as_ref(), &expected); + Ok(()) + } } diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index a74c44665de1..462977274ecb 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1276,6 +1276,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch), BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch), BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch), + BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported SQL binary operator {:?}", op From b6ace166e7fa2657791f20a4990ff07c9f95c4c2 Mon Sep 17 00:00:00 2001 From: Dmitry Patsura Date: Tue, 1 Feb 2022 00:47:18 +0300 Subject: [PATCH 18/50] fix: substr - correct behaivour with negative start pos (#1660) --- datafusion/src/physical_plan/functions.rs | 67 +++++++++++++++++++ .../src/physical_plan/unicode_expressions.rs | 28 +++++--- 2 files changed, 85 insertions(+), 10 deletions(-) diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 644defce1545..7d7cda75e867 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -3582,6 +3582,18 @@ mod tests { StringArray ); #[cfg(feature = "unicode_expressions")] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("joséésoj".to_string()))), + lit(ScalarValue::Int64(Some(-5))), + ], + Ok(Some("joséésoj")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ @@ -3680,6 +3692,61 @@ mod tests { StringArray ); #[cfg(feature = "unicode_expressions")] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(0))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from 5 (10 + -5) + #[cfg(feature = "unicode_expressions")] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(-5))), + lit(ScalarValue::Int64(Some(10))), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from -1 (4 + -5) + #[cfg(feature = "unicode_expressions")] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(-5))), + lit(ScalarValue::Int64(Some(4))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + // starting from 0 (5 + -5) + #[cfg(feature = "unicode_expressions")] + test_function!( + Substr, + &[ + lit(ScalarValue::Utf8(Some("alphabet".to_string()))), + lit(ScalarValue::Int64(Some(-5))), + lit(ScalarValue::Int64(Some(5))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] test_function!( Substr, &[ diff --git a/datafusion/src/physical_plan/unicode_expressions.rs b/datafusion/src/physical_plan/unicode_expressions.rs index 5a20d0599e3e..1d2c4b765a2d 100644 --- a/datafusion/src/physical_plan/unicode_expressions.rs +++ b/datafusion/src/physical_plan/unicode_expressions.rs @@ -457,21 +457,29 @@ pub fn 
substr(args: &[ArrayRef]) -> Result { start, count ))) - } else if start <= 0 { - Ok(Some(string.to_string())) } else { let graphemes = string.graphemes(true).collect::>(); - let start_pos = start as usize - 1; - let count_usize = count as usize; - if graphemes.len() < start_pos { + let (start_pos, end_pos) = if start <= 0 { + let end_pos = start + count - 1; + ( + 0_usize, + if end_pos < 0 { + // we use 0 as workaround for usize to return empty string + 0 + } else { + end_pos as usize + }, + ) + } else { + ((start - 1) as usize, (start + count - 1) as usize) + }; + + if end_pos == 0 || graphemes.len() < start_pos { Ok(Some("".to_string())) - } else if graphemes.len() < start_pos + count_usize { + } else if graphemes.len() < end_pos { Ok(Some(graphemes[start_pos..].concat())) } else { - Ok(Some( - graphemes[start_pos..start_pos + count_usize] - .concat(), - )) + Ok(Some(graphemes[start_pos..end_pos].concat())) } } } From bacf10df49372cfff4764c59ef034008d8ef8445 Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Wed, 2 Feb 2022 01:56:34 +0800 Subject: [PATCH 19/50] minor: fix cargo run --release error (#1723) --- datafusion/src/physical_plan/tdigest/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/src/physical_plan/tdigest/mod.rs b/datafusion/src/physical_plan/tdigest/mod.rs index 6780adc84cd1..603cfd867c48 100644 --- a/datafusion/src/physical_plan/tdigest/mod.rs +++ b/datafusion/src/physical_plan/tdigest/mod.rs @@ -266,6 +266,7 @@ impl TDigest { } fn merge_sorted_f64(&self, sorted_values: &[OrderedFloat]) -> TDigest { + #[cfg(debug_assertions)] debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest"); if sorted_values.is_empty() { From b9a8f151893b2487ee4d5e9018af71713565e26d Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Tue, 1 Feb 2022 20:28:56 +0000 Subject: [PATCH 20/50] Convert boolean case expressions to boolean logic (#1719) * Convert boolean case expressions to boolean logic * Review feedback --- .../src/optimizer/simplify_expressions.rs | 123 +++++++++++++++++- 1 file changed, 120 insertions(+), 3 deletions(-) diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 6f5235e852b7..00739ccff5ac 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -655,6 +655,54 @@ impl<'a> ExprRewriter for Simplifier<'a> { _ => unreachable!(), }, + // + // Rules for Case + // + + // CASE + // WHEN X THEN A + // WHEN Y THEN B + // ... + // ELSE Q + // END + // + // ---> (X AND A) OR (Y AND B AND NOT X) OR ... (NOT (X OR Y) AND Q) + // + // Note: the rationale for this rewrite is that the expr can then be further + // simplified using the existing rules for AND/OR + Case { + expr: None, + when_then_expr, + else_expr, + } if !when_then_expr.is_empty() + && when_then_expr.len() < 3 // The rewrite is O(n!) 
so limit to small number + && self.is_boolean_type(&when_then_expr[0].1) => + { + // The disjunction of all the when predicates encountered so far + let mut filter_expr = lit(false); + // The disjunction of all the cases + let mut out_expr = lit(false); + + for (when, then) in when_then_expr { + let case_expr = when + .as_ref() + .clone() + .and(filter_expr.clone().not()) + .and(*then); + + out_expr = out_expr.or(case_expr); + filter_expr = filter_expr.or(*when); + } + + if let Some(else_expr) = else_expr { + let case_expr = filter_expr.not().and(*else_expr); + out_expr = out_expr.or(case_expr); + } + + // Do a first pass at simplification + out_expr.rewrite(self)? + } + expr => { // no additional rewrites possible expr @@ -1169,6 +1217,8 @@ mod tests { .expect("expected to simplify") .rewrite(&mut const_evaluator) .expect("expected to const evaluate") + .rewrite(&mut rewriter) + .expect("expected to simplify") } fn expr_test_schema() -> DFSchemaRef { @@ -1285,6 +1335,11 @@ mod tests { #[test] fn simplify_expr_case_when_then_else() { + // CASE WHERE c2 != false THEN "ok" == "not_ok" ELSE c2 == true + // --> + // CASE WHERE c2 THEN false ELSE c2 + // --> + // false assert_eq!( simplify(Expr::Case { expr: None, @@ -1294,11 +1349,73 @@ mod tests { )], else_expr: Some(Box::new(col("c2").eq(lit(true)))), }), - Expr::Case { + col("c2").not().and(col("c2")) // #1716 + ); + + // CASE WHERE c2 != false THEN "ok" == "ok" ELSE c2 + // --> + // CASE WHERE c2 THEN true ELSE c2 + // --> + // c2 + assert_eq!( + simplify(Expr::Case { expr: None, - when_then_expr: vec![(Box::new(col("c2")), Box::new(lit(false)))], + when_then_expr: vec![( + Box::new(col("c2").not_eq(lit(false))), + Box::new(lit("ok").eq(lit("ok"))), + )], + else_expr: Some(Box::new(col("c2").eq(lit(true)))), + }), + col("c2").or(col("c2").not().and(col("c2"))) // #1716 + ); + + // CASE WHERE ISNULL(c2) THEN true ELSE c2 + // --> + // ISNULL(c2) OR c2 + assert_eq!( + simplify(Expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(col("c2").is_null()), + Box::new(lit(true)), + )], else_expr: Some(Box::new(col("c2"))), - } + }), + col("c2") + .is_null() + .or(col("c2").is_null().not().and(col("c2"))) + ); + + // CASE WHERE c1 then true WHERE c2 then false ELSE true + // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE) + // --> c1 OR (NOT(c1 OR c2)) + // --> NOT(c1) AND c2 + assert_eq!( + simplify(Expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(col("c1")), Box::new(lit(true)),), + (Box::new(col("c2")), Box::new(lit(false)),) + ], + else_expr: Some(Box::new(lit(true))), + }), + col("c1").or(col("c1").or(col("c2")).not()) + ); + + // CASE WHERE c1 then true WHERE c2 then true ELSE false + // --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE) + // --> c1 OR (NOT(c1) AND c2) + // --> c1 OR c2 + assert_eq!( + simplify(Expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(col("c1")), Box::new(lit(true)),), + (Box::new(col("c2")), Box::new(lit(false)),) + ], + else_expr: Some(Box::new(lit(true))), + }), + col("c1").or(col("c1").or(col("c2")).not()) ); } From 46879f16186c52fa3b3dab9c61edbc5a5bd920dd Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Wed, 2 Feb 2022 19:20:44 +0800 Subject: [PATCH 21/50] substitute `parking_lot::Mutex` for `std::sync::Mutex` (#1720) * Substitute parking_lot::Mutex for std::sync::Mutex * enable parking_lot feature in tokio --- ballista-examples/Cargo.toml | 2 +- ballista/rust/client/Cargo.toml | 1 + ballista/rust/client/src/context.rs | 19 +++--- 
ballista/rust/core/Cargo.toml | 2 + ballista/rust/core/src/client.rs | 5 +- .../src/execution_plans/shuffle_writer.rs | 3 +- ballista/rust/executor/Cargo.toml | 3 +- ballista/rust/scheduler/Cargo.toml | 1 + benchmarks/Cargo.toml | 2 +- benchmarks/src/bin/nyctaxi.rs | 2 +- benchmarks/src/bin/tpch.rs | 4 +- datafusion-cli/Cargo.toml | 2 +- datafusion-examples/Cargo.toml | 2 +- datafusion/Cargo.toml | 3 +- datafusion/benches/aggregate_query_sql.rs | 5 +- datafusion/benches/math_query_sql.rs | 5 +- datafusion/benches/sort_limit_query_sql.rs | 13 ++-- datafusion/benches/window_query_sql.rs | 5 +- datafusion/src/catalog/catalog.rs | 15 +++-- datafusion/src/catalog/schema.rs | 13 ++-- datafusion/src/datasource/object_store/mod.rs | 16 ++--- datafusion/src/execution/context.rs | 67 +++++++------------ datafusion/src/execution/dataframe_impl.rs | 15 +++-- datafusion/src/execution/disk_manager.rs | 6 +- datafusion/src/execution/memory_manager.rs | 27 ++++---- datafusion/src/physical_plan/cross_join.rs | 8 +-- datafusion/src/physical_plan/metrics/mod.rs | 7 +- datafusion/src/physical_plan/metrics/value.rs | 12 ++-- datafusion/src/physical_plan/sorts/mod.rs | 9 +-- .../sorts/sort_preserving_merge.rs | 5 +- datafusion/tests/custom_sources.rs | 4 +- datafusion/tests/parquet_pruning.rs | 2 +- datafusion/tests/sql/aggregates.rs | 2 +- datafusion/tests/sql/avro.rs | 2 +- datafusion/tests/sql/errors.rs | 2 +- datafusion/tests/sql/explain_analyze.rs | 6 +- datafusion/tests/sql/mod.rs | 2 +- datafusion/tests/sql/parquet.rs | 4 +- datafusion/tests/sql/timestamp.rs | 2 +- 39 files changed, 154 insertions(+), 151 deletions(-) diff --git a/ballista-examples/Cargo.toml b/ballista-examples/Cargo.toml index d5f7d65d83ef..063ef8ae4831 100644 --- a/ballista-examples/Cargo.toml +++ b/ballista-examples/Cargo.toml @@ -33,6 +33,6 @@ datafusion = { path = "../datafusion" } ballista = { path = "../ballista/rust/client", version = "0.6.0"} prost = "0.9" tonic = "0.6" -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } +tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } futures = "0.3" num_cpus = "1.13.0" diff --git a/ballista/rust/client/Cargo.toml b/ballista/rust/client/Cargo.toml index aa8297f8d06d..4ec1abe77654 100644 --- a/ballista/rust/client/Cargo.toml +++ b/ballista/rust/client/Cargo.toml @@ -35,6 +35,7 @@ log = "0.4" tokio = "1.0" tempfile = "3" sqlparser = "0.13" +parking_lot = "0.11" datafusion = { path = "../../../datafusion", version = "6.0.0" } diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs index 3fb347bddbce..4cd5a219461e 100644 --- a/ballista/rust/client/src/context.rs +++ b/ballista/rust/client/src/context.rs @@ -17,11 +17,12 @@ //! Distributed execution context. 
+use parking_lot::Mutex; use sqlparser::ast::Statement; use std::collections::HashMap; use std::fs; use std::path::PathBuf; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use ballista_core::config::BallistaConfig; use ballista_core::utils::create_df_ctx_with_ballista_query_planner; @@ -142,7 +143,7 @@ impl BallistaContext { // use local DataFusion context for now but later this might call the scheduler let mut ctx = { - let guard = self.state.lock().unwrap(); + let guard = self.state.lock(); create_df_ctx_with_ballista_query_planner( &guard.scheduler_host, guard.scheduler_port, @@ -162,7 +163,7 @@ impl BallistaContext { // use local DataFusion context for now but later this might call the scheduler let mut ctx = { - let guard = self.state.lock().unwrap(); + let guard = self.state.lock(); create_df_ctx_with_ballista_query_planner( &guard.scheduler_host, guard.scheduler_port, @@ -186,7 +187,7 @@ impl BallistaContext { // use local DataFusion context for now but later this might call the scheduler let mut ctx = { - let guard = self.state.lock().unwrap(); + let guard = self.state.lock(); create_df_ctx_with_ballista_query_planner( &guard.scheduler_host, guard.scheduler_port, @@ -203,7 +204,7 @@ impl BallistaContext { name: &str, table: Arc, ) -> Result<()> { - let mut state = self.state.lock().unwrap(); + let mut state = self.state.lock(); state.tables.insert(name.to_owned(), table); Ok(()) } @@ -280,7 +281,7 @@ impl BallistaContext { /// might require the schema to be inferred. pub async fn sql(&self, sql: &str) -> Result> { let mut ctx = { - let state = self.state.lock().unwrap(); + let state = self.state.lock(); create_df_ctx_with_ballista_query_planner( &state.scheduler_host, state.scheduler_port, @@ -291,7 +292,7 @@ impl BallistaContext { let is_show = self.is_show_statement(sql).await?; // the show tables、 show columns sql can not run at scheduler because the tables is store at client if is_show { - let state = self.state.lock().unwrap(); + let state = self.state.lock(); ctx = ExecutionContext::with_config( ExecutionConfig::new().with_information_schema( state.config.default_with_information_schema(), @@ -301,7 +302,7 @@ impl BallistaContext { // register tables with DataFusion context { - let state = self.state.lock().unwrap(); + let state = self.state.lock(); for (name, prov) in &state.tables { ctx.register_table( TableReference::Bare { table: name }, @@ -483,7 +484,7 @@ mod tests { .unwrap(); { - let mut guard = context.state.lock().unwrap(); + let mut guard = context.state.lock(); let csv_table = guard.tables.get("single_nan"); if let Some(table_provide) = csv_table { diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index fa68be6b0ead..043f79a962b2 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -48,6 +48,8 @@ parse_arg = "0.1.3" arrow-flight = { version = "8.0.0" } datafusion = { path = "../../../datafusion", version = "6.0.0" } +parking_lot = "0.11" + [dev-dependencies] tempfile = "3" diff --git a/ballista/rust/core/src/client.rs b/ballista/rust/core/src/client.rs index b40c8788320a..aae4b2bb1bb2 100644 --- a/ballista/rust/core/src/client.rs +++ b/ballista/rust/core/src/client.rs @@ -17,7 +17,8 @@ //! Client API for sending requests to executors. 
-use std::sync::{Arc, Mutex}; +use parking_lot::Mutex; +use std::sync::Arc; use std::{collections::HashMap, pin::Pin}; use std::{ convert::{TryFrom, TryInto}, @@ -154,7 +155,7 @@ impl Stream for FlightDataStream { self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let mut stream = self.stream.lock().expect("mutex is bad"); + let mut stream = self.stream.lock(); stream.poll_next_unpin(cx).map(|x| match x { Some(flight_data_chunk_result) => { let converted_chunk = flight_data_chunk_result diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs index 4f027c1f28bd..724bb3518d74 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs @@ -20,10 +20,11 @@ //! partition is re-partitioned and streamed to disk in Arrow IPC format. Future stages of the query //! will use the ShuffleReaderExec to read these results. +use parking_lot::Mutex; use std::fs::File; use std::iter::Iterator; use std::path::PathBuf; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::Instant; use std::{any::Any, pin::Pin}; diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index 4eba2a152328..fb456a93ddc6 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -41,11 +41,12 @@ futures = "0.3" log = "0.4" snmalloc-rs = {version = "0.2", features= ["cache-friendly"], optional = true} tempfile = "3" -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } +tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "parking_lot"] } tokio-stream = { version = "0.1", features = ["net"] } tonic = "0.6" uuid = { version = "0.8", features = ["v4"] } hyper = "0.14.4" +parking_lot = "0.11" [dev-dependencies] diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml index 10b3723712da..fdeb7e726d57 100644 --- a/ballista/rust/scheduler/Cargo.toml +++ b/ballista/rust/scheduler/Cargo.toml @@ -53,6 +53,7 @@ tokio-stream = { version = "0.1", features = ["net"], optional = true } tonic = "0.6" tower = { version = "0.4" } warp = "0.3" +parking_lot = "0.11" [dev-dependencies] ballista-core = { path = "../core", version = "0.6.0" } diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index a8def45b53e0..cf55284bfa9b 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -35,7 +35,7 @@ snmalloc = ["snmalloc-rs"] datafusion = { path = "../datafusion" } ballista = { path = "../ballista/rust/client" } structopt = { version = "0.3", default-features = false } -tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread"] } +tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread", "parking_lot"] } futures = "0.3" env_logger = "0.9" mimalloc = { version = "0.1", optional = true, default-features = false } diff --git a/benchmarks/src/bin/nyctaxi.rs b/benchmarks/src/bin/nyctaxi.rs index ad2494c6aff2..49679f46d7eb 100644 --- a/benchmarks/src/bin/nyctaxi.rs +++ b/benchmarks/src/bin/nyctaxi.rs @@ -116,7 +116,7 @@ async fn datafusion_sql_benchmarks( } async fn execute_sql(ctx: &mut ExecutionContext, sql: &str, debug: bool) -> Result<()> { - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let plan = ctx.create_logical_plan(sql)?; let plan = ctx.optimize(&plan)?; if debug { diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs 
index 59bb55162a8e..d8cf3fb9ef4b 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -265,7 +265,7 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result>, sql: &str) { let rt = Runtime::new().unwrap(); - let df = rt.block_on(ctx.lock().unwrap().sql(sql)).unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); criterion::black_box(rt.block_on(df.collect()).unwrap()); } diff --git a/datafusion/benches/math_query_sql.rs b/datafusion/benches/math_query_sql.rs index 85d11fa67ab5..6195937dc4e5 100644 --- a/datafusion/benches/math_query_sql.rs +++ b/datafusion/benches/math_query_sql.rs @@ -19,7 +19,8 @@ extern crate criterion; use criterion::Criterion; -use std::sync::{Arc, Mutex}; +use parking_lot::Mutex; +use std::sync::Arc; use tokio::runtime::Runtime; @@ -40,7 +41,7 @@ fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query - let df = rt.block_on(ctx.lock().unwrap().sql(sql)).unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); rt.block_on(df.collect()).unwrap(); } diff --git a/datafusion/benches/sort_limit_query_sql.rs b/datafusion/benches/sort_limit_query_sql.rs index 41f8c171e236..3828014e9892 100644 --- a/datafusion/benches/sort_limit_query_sql.rs +++ b/datafusion/benches/sort_limit_query_sql.rs @@ -22,7 +22,8 @@ use datafusion::datasource::file_format::csv::CsvFormat; use datafusion::datasource::listing::{ListingOptions, ListingTable}; use datafusion::datasource::object_store::local::LocalFileSystem; -use std::sync::{Arc, Mutex}; +use parking_lot::Mutex; +use std::sync::Arc; extern crate arrow; extern crate datafusion; @@ -38,7 +39,7 @@ fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query - let df = rt.block_on(ctx.lock().unwrap().sql(sql)).unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); rt.block_on(df.collect()).unwrap(); } @@ -81,18 +82,18 @@ fn create_context() -> Arc> { rt.block_on(async { // create local execution context let mut ctx = ExecutionContext::new(); - ctx.state.lock().unwrap().config.target_partitions = 1; - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + ctx.state.lock().config.target_partitions = 1; + let runtime = ctx.state.lock().runtime_env.clone(); let mem_table = MemTable::load(Arc::new(csv), Some(partitions), runtime) .await .unwrap(); ctx.register_table("aggregate_test_100", Arc::new(mem_table)) .unwrap(); - ctx_holder.lock().unwrap().push(Arc::new(Mutex::new(ctx))) + ctx_holder.lock().push(Arc::new(Mutex::new(ctx))) }); - let ctx = ctx_holder.lock().unwrap().get(0).unwrap().clone(); + let ctx = ctx_holder.lock().get(0).unwrap().clone(); ctx } diff --git a/datafusion/benches/window_query_sql.rs b/datafusion/benches/window_query_sql.rs index bca4a38360fe..dad838eb7f62 100644 --- a/datafusion/benches/window_query_sql.rs +++ b/datafusion/benches/window_query_sql.rs @@ -25,12 +25,13 @@ use crate::criterion::Criterion; use data_utils::create_table_provider; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; -use std::sync::{Arc, Mutex}; +use parking_lot::Mutex; +use std::sync::Arc; use tokio::runtime::Runtime; fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); - let df = rt.block_on(ctx.lock().unwrap().sql(sql)).unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); criterion::black_box(rt.block_on(df.collect()).unwrap()); } diff --git a/datafusion/src/catalog/catalog.rs b/datafusion/src/catalog/catalog.rs index 7dbfa5a80c3e..d5f509f62bcc 100644 --- 
a/datafusion/src/catalog/catalog.rs +++ b/datafusion/src/catalog/catalog.rs @@ -19,9 +19,10 @@ //! representing collections of named schemas. use crate::catalog::schema::SchemaProvider; +use parking_lot::RwLock; use std::any::Any; use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; /// Represent a list of named catalogs pub trait CatalogList: Sync + Send { @@ -75,17 +76,17 @@ impl CatalogList for MemoryCatalogList { name: String, catalog: Arc, ) -> Option> { - let mut catalogs = self.catalogs.write().unwrap(); + let mut catalogs = self.catalogs.write(); catalogs.insert(name, catalog) } fn catalog_names(&self) -> Vec { - let catalogs = self.catalogs.read().unwrap(); + let catalogs = self.catalogs.read(); catalogs.keys().map(|s| s.to_string()).collect() } fn catalog(&self, name: &str) -> Option> { - let catalogs = self.catalogs.read().unwrap(); + let catalogs = self.catalogs.read(); catalogs.get(name).cloned() } } @@ -129,7 +130,7 @@ impl MemoryCatalogProvider { name: impl Into, schema: Arc, ) -> Option> { - let mut schemas = self.schemas.write().unwrap(); + let mut schemas = self.schemas.write(); schemas.insert(name.into(), schema) } } @@ -140,12 +141,12 @@ impl CatalogProvider for MemoryCatalogProvider { } fn schema_names(&self) -> Vec { - let schemas = self.schemas.read().unwrap(); + let schemas = self.schemas.read(); schemas.keys().cloned().collect() } fn schema(&self, name: &str) -> Option> { - let schemas = self.schemas.read().unwrap(); + let schemas = self.schemas.read(); schemas.get(name).cloned() } } diff --git a/datafusion/src/catalog/schema.rs b/datafusion/src/catalog/schema.rs index 08707ea3347a..60894d1098d0 100644 --- a/datafusion/src/catalog/schema.rs +++ b/datafusion/src/catalog/schema.rs @@ -18,9 +18,10 @@ //! Describes the interface and built-in implementations of schemas, //! representing collections of named tables. 
+use parking_lot::RwLock; use std::any::Any; use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; @@ -91,12 +92,12 @@ impl SchemaProvider for MemorySchemaProvider { } fn table_names(&self) -> Vec { - let tables = self.tables.read().unwrap(); + let tables = self.tables.read(); tables.keys().cloned().collect() } fn table(&self, name: &str) -> Option> { - let tables = self.tables.read().unwrap(); + let tables = self.tables.read(); tables.get(name).cloned() } @@ -111,17 +112,17 @@ impl SchemaProvider for MemorySchemaProvider { name ))); } - let mut tables = self.tables.write().unwrap(); + let mut tables = self.tables.write(); Ok(tables.insert(name, table)) } fn deregister_table(&self, name: &str) -> Result>> { - let mut tables = self.tables.write().unwrap(); + let mut tables = self.tables.write(); Ok(tables.remove(name)) } fn table_exist(&self, name: &str) -> bool { - let tables = self.tables.read().unwrap(); + let tables = self.tables.read(); tables.contains_key(name) } } diff --git a/datafusion/src/datasource/object_store/mod.rs b/datafusion/src/datasource/object_store/mod.rs index c77489689a86..4ca0d54c4092 100644 --- a/datafusion/src/datasource/object_store/mod.rs +++ b/datafusion/src/datasource/object_store/mod.rs @@ -19,11 +19,12 @@ pub mod local; +use parking_lot::RwLock; use std::collections::HashMap; use std::fmt::{self, Debug}; use std::io::Read; use std::pin::Pin; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use async_trait::async_trait; use chrono::{DateTime, Utc}; @@ -175,12 +176,7 @@ impl fmt::Debug for ObjectStoreRegistry { f.debug_struct("ObjectStoreRegistry") .field( "schemes", - &self - .object_stores - .read() - .unwrap() - .keys() - .collect::>(), + &self.object_stores.read().keys().collect::>(), ) .finish() } @@ -211,13 +207,13 @@ impl ObjectStoreRegistry { scheme: String, store: Arc, ) -> Option> { - let mut stores = self.object_stores.write().unwrap(); + let mut stores = self.object_stores.write(); stores.insert(scheme, store) } /// Get the store registered for scheme pub fn get(&self, scheme: &str) -> Option> { - let stores = self.object_stores.read().unwrap(); + let stores = self.object_stores.read(); stores.get(scheme).cloned() } @@ -231,7 +227,7 @@ impl ObjectStoreRegistry { uri: &'a str, ) -> Result<(Arc, &'a str)> { if let Some((scheme, path)) = uri.split_once("://") { - let stores = self.object_stores.read().unwrap(); + let stores = self.object_stores.read(); let store = stores .get(&*scheme.to_lowercase()) .map(Clone::clone) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 023d3a0023be..deec84d5a0ff 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -39,13 +39,11 @@ use crate::{ }, }; use log::debug; +use parking_lot::Mutex; +use std::collections::{HashMap, HashSet}; use std::path::Path; use std::string::String; use std::sync::Arc; -use std::{ - collections::{HashMap, HashSet}, - sync::Mutex, -}; use std::{fs, path::PathBuf}; use futures::{StreamExt, TryStreamExt}; @@ -201,7 +199,7 @@ impl ExecutionContext { /// Return the [RuntimeEnv] used to run queries with this [ExecutionContext] pub fn runtime_env(&self) -> Arc { - self.state.lock().unwrap().runtime_env.clone() + self.state.lock().runtime_env.clone() } /// Creates a dataframe that will execute a SQL query. 
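The `.unwrap()` removals in these hunks follow from parking_lot's API: its `Mutex::lock()` returns the guard directly and there is no lock poisoning, so there is no `Result` to unwrap. A minimal sketch of the pattern, for illustration only (it assumes the `parking_lot` dependency added in the Cargo.toml hunks above and is not part of the patch itself):

use parking_lot::Mutex;
use std::sync::Arc;

fn main() {
    // Shared state guarded by a parking_lot mutex.
    let state = Arc::new(Mutex::new(0_u64));

    {
        // lock() returns the MutexGuard directly; with std::sync::Mutex this
        // would be lock().unwrap() because of the poisoning error path.
        let mut guard = state.lock();
        *guard += 1;
    } // guard dropped here, releasing the lock

    assert_eq!(*state.lock(), 1);
}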
@@ -242,12 +240,7 @@ impl ExecutionContext { format: file_format, collect_stat: false, file_extension: file_extension.to_owned(), - target_partitions: self - .state - .lock() - .unwrap() - .config - .target_partitions, + target_partitions: self.state.lock().config.target_partitions, table_partition_cols: vec![], }; @@ -312,7 +305,7 @@ impl ExecutionContext { } // create a query planner - let state = self.state.lock().unwrap().clone(); + let state = self.state.lock().clone(); let query_planner = SqlToRel::new(&state); query_planner.statement_to_plan(&statements[0]) } @@ -325,7 +318,6 @@ impl ExecutionContext { ) { self.state .lock() - .unwrap() .execution_props .add_var_provider(variable_type, provider); } @@ -340,7 +332,6 @@ impl ExecutionContext { pub fn register_udf(&mut self, f: ScalarUDF) { self.state .lock() - .unwrap() .scalar_functions .insert(f.name.clone(), Arc::new(f)); } @@ -355,7 +346,6 @@ impl ExecutionContext { pub fn register_udaf(&mut self, f: AggregateUDF) { self.state .lock() - .unwrap() .aggregate_functions .insert(f.name.clone(), Arc::new(f)); } @@ -369,7 +359,7 @@ impl ExecutionContext { ) -> Result> { let uri: String = uri.into(); let (object_store, path) = self.object_store(&uri)?; - let target_partitions = self.state.lock().unwrap().config.target_partitions; + let target_partitions = self.state.lock().config.target_partitions; Ok(Arc::new(DataFrameImpl::new( self.state.clone(), &LogicalPlanBuilder::scan_avro( @@ -400,7 +390,7 @@ impl ExecutionContext { ) -> Result> { let uri: String = uri.into(); let (object_store, path) = self.object_store(&uri)?; - let target_partitions = self.state.lock().unwrap().config.target_partitions; + let target_partitions = self.state.lock().config.target_partitions; Ok(Arc::new(DataFrameImpl::new( self.state.clone(), &LogicalPlanBuilder::scan_csv( @@ -422,7 +412,7 @@ impl ExecutionContext { ) -> Result> { let uri: String = uri.into(); let (object_store, path) = self.object_store(&uri)?; - let target_partitions = self.state.lock().unwrap().config.target_partitions; + let target_partitions = self.state.lock().config.target_partitions; let logical_plan = LogicalPlanBuilder::scan_parquet(object_store, path, None, target_partitions) .await? @@ -477,8 +467,8 @@ impl ExecutionContext { uri: &str, options: CsvReadOptions<'_>, ) -> Result<()> { - let listing_options = options - .to_listing_options(self.state.lock().unwrap().config.target_partitions); + let listing_options = + options.to_listing_options(self.state.lock().config.target_partitions); self.register_listing_table( name, @@ -495,7 +485,7 @@ impl ExecutionContext { /// executed against this context. 
pub async fn register_parquet(&mut self, name: &str, uri: &str) -> Result<()> { let (target_partitions, enable_pruning) = { - let m = self.state.lock().unwrap(); + let m = self.state.lock(); (m.config.target_partitions, m.config.parquet_pruning) }; let file_format = ParquetFormat::default().with_enable_pruning(enable_pruning); @@ -521,8 +511,8 @@ impl ExecutionContext { uri: &str, options: AvroReadOptions<'_>, ) -> Result<()> { - let listing_options = options - .to_listing_options(self.state.lock().unwrap().config.target_partitions); + let listing_options = + options.to_listing_options(self.state.lock().config.target_partitions); self.register_listing_table(name, uri, listing_options, options.schema) .await?; @@ -542,7 +532,7 @@ impl ExecutionContext { ) -> Option> { let name = name.into(); - let state = self.state.lock().unwrap(); + let state = self.state.lock(); let catalog = if state.config.information_schema { Arc::new(CatalogWithInformationSchema::new( Arc::downgrade(&state.catalog_list), @@ -557,7 +547,7 @@ impl ExecutionContext { /// Retrieves a `CatalogProvider` instance by name pub fn catalog(&self, name: &str) -> Option> { - self.state.lock().unwrap().catalog_list.catalog(name) + self.state.lock().catalog_list.catalog(name) } /// Registers a object store with scheme using a custom `ObjectStore` so that @@ -573,7 +563,6 @@ impl ExecutionContext { self.state .lock() - .unwrap() .object_store_registry .register_store(scheme, object_store) } @@ -585,7 +574,6 @@ impl ExecutionContext { ) -> Result<(Arc, &'a str)> { self.state .lock() - .unwrap() .object_store_registry .get_by_uri(uri) .map_err(DataFusionError::from) @@ -605,7 +593,6 @@ impl ExecutionContext { let table_ref = table_ref.into(); self.state .lock() - .unwrap() .schema_for_ref(table_ref)? .register_table(table_ref.table().to_owned(), provider) } @@ -620,7 +607,6 @@ impl ExecutionContext { let table_ref = table_ref.into(); self.state .lock() - .unwrap() .schema_for_ref(table_ref)? .deregister_table(table_ref.table()) } @@ -634,7 +620,7 @@ impl ExecutionContext { table_ref: impl Into>, ) -> Result> { let table_ref = table_ref.into(); - let schema = self.state.lock().unwrap().schema_for_ref(table_ref)?; + let schema = self.state.lock().schema_for_ref(table_ref)?; match schema.table(table_ref.table()) { Some(ref provider) => { let plan = LogicalPlanBuilder::scan( @@ -664,7 +650,6 @@ impl ExecutionContext { Ok(self .state .lock() - .unwrap() // a bare reference will always resolve to the default catalog and schema .schema_for_ref(TableReference::Bare { table: "" })? .table_names() @@ -703,7 +688,7 @@ impl ExecutionContext { logical_plan: &LogicalPlan, ) -> Result> { let (state, planner) = { - let mut state = self.state.lock().unwrap(); + let mut state = self.state.lock(); state.execution_props.start_execution(); // We need to clone `state` to release the lock that is not `Send`. 
We could @@ -815,7 +800,7 @@ impl ExecutionContext { where F: FnMut(&LogicalPlan, &dyn OptimizerRule), { - let state = &mut self.state.lock().unwrap(); + let state = &mut self.state.lock(); let execution_props = &mut state.execution_props.clone(); let optimizers = &state.config.optimizers; @@ -840,15 +825,15 @@ impl From>> for ExecutionContext { impl FunctionRegistry for ExecutionContext { fn udfs(&self) -> HashSet { - self.state.lock().unwrap().udfs() + self.state.lock().udfs() } fn udf(&self, name: &str) -> Result> { - self.state.lock().unwrap().udf(name) + self.state.lock().udf(name) } fn udaf(&self, name: &str) -> Result> { - self.state.lock().unwrap().udaf(name) + self.state.lock().udaf(name) } } @@ -1512,7 +1497,7 @@ mod tests { let physical_plan = ctx.create_physical_plan(&logical_plan).await?; - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect_partitioned(physical_plan, runtime).await?; // note that the order of partitions is not deterministic @@ -1561,7 +1546,7 @@ mod tests { let tmp_dir = TempDir::new()?; let partition_count = 4; let ctx = create_ctx(&tmp_dir, partition_count).await?; - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let table = ctx.table("test")?; let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) @@ -1669,7 +1654,7 @@ mod tests { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("b", physical_plan.schema().field(0).name().as_str()); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let batches = collect(physical_plan, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); @@ -3279,7 +3264,7 @@ mod tests { let plan = ctx.optimize(&plan)?; let plan = ctx.create_physical_plan(&plan).await?; - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let result = collect(plan, runtime).await?; let expected = vec![ @@ -3585,7 +3570,7 @@ mod tests { ctx.register_catalog("my_catalog", catalog); let catalog_list_weak = { - let state = ctx.state.lock().unwrap(); + let state = ctx.state.lock(); Arc::downgrade(&state.catalog_list) }; diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index d3f62bbb46db..3fcaa28af973 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -17,8 +17,9 @@ //! Implementation of DataFrame API. +use parking_lot::Mutex; use std::any::Any; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use crate::arrow::datatypes::Schema; use crate::arrow::datatypes::SchemaRef; @@ -61,7 +62,7 @@ impl DataFrameImpl { /// Create a physical plan async fn create_physical_plan(&self) -> Result> { - let state = self.ctx_state.lock().unwrap().clone(); + let state = self.ctx_state.lock().clone(); let ctx = ExecutionContext::from(Arc::new(Mutex::new(state))); let plan = ctx.optimize(&self.plan)?; ctx.create_physical_plan(&plan).await @@ -221,7 +222,7 @@ impl DataFrame for DataFrameImpl { /// execute it, collecting all resulting batches into memory async fn collect(&self) -> Result> { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + let runtime = self.ctx_state.lock().runtime_env.clone(); Ok(collect(plan, runtime).await?) 
} @@ -241,7 +242,7 @@ impl DataFrame for DataFrameImpl { /// execute it, returning a stream over a single partition async fn execute_stream(&self) -> Result { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + let runtime = self.ctx_state.lock().runtime_env.clone(); execute_stream(plan, runtime).await } @@ -250,7 +251,7 @@ impl DataFrame for DataFrameImpl { /// partitioning async fn collect_partitioned(&self) -> Result>> { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + let runtime = self.ctx_state.lock().runtime_env.clone(); Ok(collect_partitioned(plan, runtime).await?) } @@ -258,7 +259,7 @@ impl DataFrame for DataFrameImpl { /// execute it, returning a stream for each partition async fn execute_stream_partitioned(&self) -> Result> { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + let runtime = self.ctx_state.lock().runtime_env.clone(); Ok(execute_stream_partitioned(plan, runtime).await?) } @@ -275,7 +276,7 @@ impl DataFrame for DataFrameImpl { } fn registry(&self) -> Arc { - let registry = self.ctx_state.lock().unwrap().clone(); + let registry = self.ctx_state.lock().clone(); Arc::new(registry) } diff --git a/datafusion/src/execution/disk_manager.rs b/datafusion/src/execution/disk_manager.rs index 4486f53a21b8..c4fe6b4160fa 100644 --- a/datafusion/src/execution/disk_manager.rs +++ b/datafusion/src/execution/disk_manager.rs @@ -20,9 +20,10 @@ use crate::error::{DataFusionError, Result}; use log::debug; +use parking_lot::Mutex; use rand::{thread_rng, Rng}; +use std::path::PathBuf; use std::sync::Arc; -use std::{path::PathBuf, sync::Mutex}; use tempfile::{Builder, NamedTempFile, TempDir}; /// Configuration for temporary disk access @@ -95,7 +96,7 @@ impl DiskManager { /// Return a temporary file from a randomized choice in the configured locations pub fn create_tmp_file(&self) -> Result { - let mut local_dirs = self.local_dirs.lock().unwrap(); + let mut local_dirs = self.local_dirs.lock(); // Create a temporary directory if needed if local_dirs.is_empty() { @@ -169,7 +170,6 @@ mod tests { fn local_dir_snapshot(dm: &DiskManager) -> Vec { dm.local_dirs .lock() - .unwrap() .iter() .map(|p| p.path().into()) .collect() diff --git a/datafusion/src/execution/memory_manager.rs b/datafusion/src/execution/memory_manager.rs index 5015f466c674..d39eaab3c215 100644 --- a/datafusion/src/execution/memory_manager.rs +++ b/datafusion/src/execution/memory_manager.rs @@ -21,10 +21,11 @@ use crate::error::{DataFusionError, Result}; use async_trait::async_trait; use hashbrown::HashSet; use log::debug; +use parking_lot::{Condvar, Mutex}; use std::fmt; use std::fmt::{Debug, Display, Formatter}; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Condvar, Mutex}; +use std::sync::Arc; static CONSUMER_ID: AtomicUsize = AtomicUsize::new(0); @@ -302,12 +303,12 @@ impl MemoryManager { } fn get_requester_total(&self) -> usize { - *self.requesters_total.lock().unwrap() + *self.requesters_total.lock() } /// Register a new memory requester pub(crate) fn register_requester(&self, requester_id: &MemoryConsumerId) { - self.requesters.lock().unwrap().insert(requester_id.clone()); + self.requesters.lock().insert(requester_id.clone()); } fn max_mem_for_requesters(&self) -> usize { @@ -317,8 +318,8 @@ impl MemoryManager { /// Grow memory attempt from a consumer, return if we could grant that much to it async fn 
can_grow_directly(&self, required: usize, current: usize) -> bool { - let num_rqt = self.requesters.lock().unwrap().len(); - let mut rqt_current_used = self.requesters_total.lock().unwrap(); + let num_rqt = self.requesters.lock().len(); + let mut rqt_current_used = self.requesters_total.lock(); let mut rqt_max = self.max_mem_for_requesters(); let granted; @@ -339,7 +340,7 @@ impl MemoryManager { } else if current < min_per_rqt { // if we cannot acquire at lease 1/2n memory, just wait for others // to spill instead spill self frequently with limited total mem - rqt_current_used = self.cv.wait(rqt_current_used).unwrap(); + self.cv.wait(&mut rqt_current_used); } else { granted = false; break; @@ -351,8 +352,8 @@ impl MemoryManager { granted } - fn record_free_then_acquire(&self, freed: usize, acquired: usize) { - let mut requesters_total = self.requesters_total.lock().unwrap(); + fn record_free_then_acquire(&self, freed: usize, acquired: usize) -> usize { + let mut requesters_total = self.requesters_total.lock(); assert!(*requesters_total >= freed); *requesters_total -= freed; *requesters_total += acquired; @@ -363,9 +364,9 @@ impl MemoryManager { pub(crate) fn drop_consumer(&self, id: &MemoryConsumerId, mem_used: usize) { // find in requesters first { - let mut requesters = self.requesters.lock().unwrap(); + let mut requesters = self.requesters.lock(); if requesters.remove(id) { - let mut total = self.requesters_total.lock().unwrap(); + let mut total = self.requesters_total.lock(); assert!(*total >= mem_used); *total -= mem_used; } @@ -381,7 +382,7 @@ impl Display for MemoryManager { "MemoryManager usage statistics: total {}, trackers used {}, total {} requesters used: {}", human_readable_size(self.pool_size), human_readable_size(self.get_tracker_total()), - self.requesters.lock().unwrap().len(), + self.requesters.lock().len(), human_readable_size(self.get_requester_total()), ) } @@ -558,7 +559,7 @@ mod tests { requester1.do_with_mem(10).await.unwrap(); assert_eq!(requester1.get_spills(), 0); assert_eq!(requester1.mem_used(), 50); - assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 50); + assert_eq!(*runtime.memory_manager.requesters_total.lock(), 50); let requester2 = DummyRequester::new(0, runtime.clone()); runtime.register_requester(requester2.id()); @@ -572,7 +573,7 @@ mod tests { assert_eq!(requester1.get_spills(), 1); assert_eq!(requester1.mem_used(), 10); - assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 40); + assert_eq!(*runtime.memory_manager.requesters_total.lock(), 40); } #[tokio::test] diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/cross_join.rs index 48301f0916fe..e4369c180c85 100644 --- a/datafusion/src/physical_plan/cross_join.rs +++ b/datafusion/src/physical_plan/cross_join.rs @@ -192,7 +192,7 @@ impl ExecutionPlan for CrossJoinExec { schema: self.schema.clone(), left_data, right: stream, - right_batch: Arc::new(std::sync::Mutex::new(None)), + right_batch: Arc::new(parking_lot::Mutex::new(None)), left_index: 0, num_input_batches: 0, num_input_rows: 0, @@ -299,7 +299,7 @@ struct CrossJoinStream { /// Current value on the left left_index: usize, /// Current batch being processed from the right side - right_batch: Arc>>, + right_batch: Arc>>, /// number of input batches num_input_batches: usize, /// number of input rows @@ -354,7 +354,7 @@ impl Stream for CrossJoinStream { if self.left_index > 0 && self.left_index < self.left_data.num_rows() { let start = Instant::now(); let right_batch = { - let 
right_batch = self.right_batch.lock().unwrap(); + let right_batch = self.right_batch.lock(); right_batch.clone().unwrap() }; let result = @@ -389,7 +389,7 @@ impl Stream for CrossJoinStream { } self.left_index = 1; - let mut right_batch = self.right_batch.lock().unwrap(); + let mut right_batch = self.right_batch.lock(); *right_batch = Some(batch); Some(result) diff --git a/datafusion/src/physical_plan/metrics/mod.rs b/datafusion/src/physical_plan/metrics/mod.rs index e609beb08c37..021f2df823ae 100644 --- a/datafusion/src/physical_plan/metrics/mod.rs +++ b/datafusion/src/physical_plan/metrics/mod.rs @@ -23,10 +23,11 @@ mod composite; mod tracker; mod value; +use parking_lot::Mutex; use std::{ borrow::Cow, fmt::{Debug, Display}, - sync::{Arc, Mutex}, + sync::Arc, }; use hashbrown::HashMap; @@ -339,12 +340,12 @@ impl ExecutionPlanMetricsSet { /// Add the specified metric to the underlying metric set pub fn register(&self, metric: Arc) { - self.inner.lock().expect("not poisoned").push(metric) + self.inner.lock().push(metric) } /// Return a clone of the inner MetricsSet pub fn clone_inner(&self) -> MetricsSet { - let guard = self.inner.lock().expect("not poisoned"); + let guard = self.inner.lock(); (*guard).clone() } } diff --git a/datafusion/src/physical_plan/metrics/value.rs b/datafusion/src/physical_plan/metrics/value.rs index 6ac282a496ee..43a0ad236500 100644 --- a/datafusion/src/physical_plan/metrics/value.rs +++ b/datafusion/src/physical_plan/metrics/value.rs @@ -22,11 +22,13 @@ use std::{ fmt::Display, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, Mutex, + Arc, }, time::{Duration, Instant}, }; +use parking_lot::Mutex; + use chrono::{DateTime, Utc}; /// A counter to record things such as number of input or output rows @@ -229,7 +231,7 @@ impl Timestamp { /// Sets the timestamps value to a specified time pub fn set(&self, now: DateTime) { - *self.timestamp.lock().unwrap() = Some(now); + *self.timestamp.lock() = Some(now); } /// return the timestamps value at the last time `record()` was @@ -237,7 +239,7 @@ impl Timestamp { /// /// Returns `None` if `record()` has not been called pub fn value(&self) -> Option> { - *self.timestamp.lock().unwrap() + *self.timestamp.lock() } /// sets the value of this timestamp to the minimum of this and other @@ -249,7 +251,7 @@ impl Timestamp { (Some(v1), Some(v2)) => Some(if v1 < v2 { v1 } else { v2 }), }; - *self.timestamp.lock().unwrap() = min; + *self.timestamp.lock() = min; } /// sets the value of this timestamp to the maximum of this and other @@ -261,7 +263,7 @@ impl Timestamp { (Some(v1), Some(v2)) => Some(if v1 < v2 { v2 } else { v1 }), }; - *self.timestamp.lock().unwrap() = max; + *self.timestamp.lock() = max; } } diff --git a/datafusion/src/physical_plan/sorts/mod.rs b/datafusion/src/physical_plan/sorts/mod.rs index 64ec29179b19..818546f316fc 100644 --- a/datafusion/src/physical_plan/sorts/mod.rs +++ b/datafusion/src/physical_plan/sorts/mod.rs @@ -28,11 +28,12 @@ use futures::channel::mpsc; use futures::stream::FusedStream; use futures::Stream; use hashbrown::HashMap; +use parking_lot::RwLock; use std::borrow::BorrowMut; use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; use std::pin::Pin; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::task::{Context, Poll}; pub mod sort; @@ -135,7 +136,7 @@ impl SortKeyCursor { .collect::>(); self.init_cmp_if_needed(other, &zipped)?; - let map = self.batch_comparators.read().unwrap(); + let map = self.batch_comparators.read(); let cmp = map.get(&other.batch_id).ok_or_else(|| { 
DataFusionError::Execution(format!( "Failed to find comparator for {} cmp {}", @@ -172,10 +173,10 @@ impl SortKeyCursor { other: &SortKeyCursor, zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)], ) -> Result<()> { - let hm = self.batch_comparators.read().unwrap(); + let hm = self.batch_comparators.read(); if !hm.contains_key(&other.batch_id) { drop(hm); - let mut map = self.batch_comparators.write().unwrap(); + let mut map = self.batch_comparators.write(); let cmp = map .borrow_mut() .entry(other.batch_id) diff --git a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs index 7b9d5d5de328..ddc9ff1f9e47 100644 --- a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs @@ -21,11 +21,12 @@ use crate::physical_plan::common::AbortOnDropMany; use crate::physical_plan::metrics::{ ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet, }; +use parking_lot::Mutex; use std::any::Any; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::Debug; use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::task::{Context, Poll}; use arrow::{ @@ -368,7 +369,7 @@ impl SortPreservingMergeStream { } let mut empty_batch = false; { - let mut streams = self.streams.streams.lock().unwrap(); + let mut streams = self.streams.streams.lock(); let stream = &mut streams[idx]; if stream.is_terminated() { diff --git a/datafusion/tests/custom_sources.rs b/datafusion/tests/custom_sources.rs index e069dd750c18..0e7f733232fa 100644 --- a/datafusion/tests/custom_sources.rs +++ b/datafusion/tests/custom_sources.rs @@ -242,7 +242,7 @@ async fn custom_source_dataframe() -> Result<()> { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let batches = collect(physical_plan, runtime).await?; let origin_rec_batch = TEST_CUSTOM_RECORD_BATCH!()?; assert_eq!(1, batches.len()); @@ -289,7 +289,7 @@ async fn optimizers_catch_all_statistics() { ) .unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let actual = collect(physical_plan, runtime).await.unwrap(); assert_eq!(actual.len(), 1); diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs index 9abf3fd55a64..9869a1f6b16a 100644 --- a/datafusion/tests/parquet_pruning.rs +++ b/datafusion/tests/parquet_pruning.rs @@ -537,7 +537,7 @@ impl ContextWithParquet { .await .expect("creating physical plan"); - let runtime = self.ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = self.ctx.state.lock().runtime_env.clone(); let results = datafusion::physical_plan::collect(physical_plan.clone(), runtime) .await .expect("Running"); diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index a025d4eeec86..fd1d15cc0ca7 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -26,7 +26,7 @@ async fn csv_query_avg_multi_batch() -> Result<()> { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(plan, runtime).await.unwrap(); let batch = 
&results[0]; let column = batch.column(0); diff --git a/datafusion/tests/sql/avro.rs b/datafusion/tests/sql/avro.rs index d0cdf71b0868..82d91a0bd481 100644 --- a/datafusion/tests/sql/avro.rs +++ b/datafusion/tests/sql/avro.rs @@ -124,7 +124,7 @@ async fn avro_single_nan_schema() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(plan, runtime).await.unwrap(); for batch in results { assert_eq!(1, batch.num_rows()); diff --git a/datafusion/tests/sql/errors.rs b/datafusion/tests/sql/errors.rs index 05ca0642bae0..92b634dd5e96 100644 --- a/datafusion/tests/sql/errors.rs +++ b/datafusion/tests/sql/errors.rs @@ -37,7 +37,7 @@ async fn test_cast_expressions_error() -> Result<()> { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let result = collect(plan, runtime).await; match result { diff --git a/datafusion/tests/sql/explain_analyze.rs b/datafusion/tests/sql/explain_analyze.rs index 2bd78ec728f5..2051bdd1b80b 100644 --- a/datafusion/tests/sql/explain_analyze.rs +++ b/datafusion/tests/sql/explain_analyze.rs @@ -41,7 +41,7 @@ async fn explain_analyze_baseline_metrics() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(physical_plan.clone(), runtime).await.unwrap(); let formatted = arrow::util::pretty::pretty_format_batches(&results) .unwrap() @@ -329,7 +329,7 @@ async fn csv_explain_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(plan, runtime).await.expect(&msg); let actual = result_vec(&results); // flatten to a single string @@ -527,7 +527,7 @@ async fn csv_explain_verbose_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let results = collect(plan, runtime).await.expect(&msg); let actual = result_vec(&results); // flatten to a single string diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index 90fe5138ac44..ea6829969462 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -543,7 +543,7 @@ async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec Result<()> { let plan = ctx.create_physical_plan(&plan).await.expect(&msg); let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let runtime = ctx.state.lock().runtime_env.clone(); let res = collect(plan, runtime).await.expect(&msg); let actual = result_vec(&res); From e4a056f00c4c5a09bddc1d16b83f771926f7b4a9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 2 Feb 2022 09:28:20 -0500 Subject: [PATCH 22/50] Add Expression Simplification API (#1717) * Add 
Expression Simplification API * fmt --- datafusion/src/logical_plan/expr.rs | 68 +++++++ datafusion/src/logical_plan/mod.rs | 1 + .../src/optimizer/simplify_expressions.rs | 179 ++++++++++-------- datafusion/tests/simplification.rs | 106 +++++++++++ 4 files changed, 273 insertions(+), 81 deletions(-) create mode 100644 datafusion/tests/simplification.rs diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index a1e51e07422e..63bf72cf226c 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -20,10 +20,12 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; +use crate::execution::context::ExecutionProps; use crate::field_util::get_indexed_field; use crate::logical_plan::{ plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan, }; +use crate::optimizer::simplify_expressions::{ConstEvaluator, Simplifier}; use crate::physical_plan::functions::Volatility; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, @@ -971,6 +973,58 @@ impl Expr { Ok(expr) } } + + /// Simplifies this [`Expr`]`s as much as possible, evaluating + /// constants and applying algebraic simplifications + /// + /// # Example: + /// `b > 2 AND b > 2` + /// can be written to + /// `b > 2` + /// + /// ``` + /// use datafusion::logical_plan::*; + /// use datafusion::error::Result; + /// use datafusion::execution::context::ExecutionProps; + /// + /// /// Simple implementation that provides `Simplifier` the information it needs + /// #[derive(Default)] + /// struct Info { + /// execution_props: ExecutionProps, + /// }; + /// + /// impl SimplifyInfo for Info { + /// fn is_boolean_type(&self, expr: &Expr) -> Result { + /// Ok(false) + /// } + /// fn nullable(&self, expr: &Expr) -> Result { + /// Ok(true) + /// } + /// fn execution_props(&self) -> &ExecutionProps { + /// &self.execution_props + /// } + /// } + /// + /// // b < 2 + /// let b_lt_2 = col("b").gt(lit(2)); + /// + /// // (b < 2) OR (b < 2) + /// let expr = b_lt_2.clone().or(b_lt_2.clone()); + /// + /// // (b < 2) OR (b < 2) --> (b < 2) + /// let expr = expr.simplify(&Info::default()).unwrap(); + /// assert_eq!(expr, b_lt_2); + /// ``` + pub fn simplify(self, info: &S) -> Result { + let mut rewriter = Simplifier::new(info); + let mut const_evaluator = ConstEvaluator::new(info.execution_props()); + + // TODO iterate until no changes are made during rewrite + // (evaluating constants can enable new simplifications and + // simplifications can enable new constant evaluation) + // https://github.com/apache/arrow-datafusion/issues/1160 + self.rewrite(&mut const_evaluator)?.rewrite(&mut rewriter) + } } impl Not for Expr { @@ -1092,6 +1146,20 @@ pub trait ExprRewriter: Sized { fn mutate(&mut self, expr: Expr) -> Result; } +/// The information necessary to apply algebraic simplification to an +/// [Expr]. 
See [SimplifyContext] for one implementation +pub trait SimplifyInfo { + /// returns true if this Expr has boolean type + fn is_boolean_type(&self, expr: &Expr) -> Result; + + /// returns true of this expr is nullable (could possibly be NULL) + fn nullable(&self, expr: &Expr) -> Result; + + /// Returns details needed for partial expression evaluation + fn execution_props(&self) -> &ExecutionProps; +} + +/// Helper struct for building [Expr::Case] pub struct CaseBuilder { expr: Option>, when_expr: Vec, diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 06c6bf90c790..25714514d78a 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -47,6 +47,7 @@ pub use expr::{ signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion, + SimplifyInfo, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 00739ccff5ac..c000bdbc2bea 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -23,8 +23,10 @@ use arrow::record_batch::RecordBatch; use crate::error::DataFusionError; use crate::execution::context::ExecutionProps; -use crate::logical_plan::{lit, DFSchemaRef, Expr}; -use crate::logical_plan::{DFSchema, ExprRewriter, LogicalPlan, RewriteRecursion}; +use crate::logical_plan::{ + lit, DFSchema, DFSchemaRef, Expr, ExprRewriter, LogicalPlan, RewriteRecursion, + SimplifyInfo, +}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::physical_plan::functions::Volatility; @@ -32,8 +34,57 @@ use crate::physical_plan::planner::create_physical_expr; use crate::scalar::ScalarValue; use crate::{error::Result, logical_plan::Operator}; -/// Simplifies plans by rewriting [`Expr`]`s evaluating constants -/// and applying algebraic simplifications +/// Provides simplification information based on schema and properties +struct SimplifyContext<'a, 'b> { + schemas: Vec<&'a DFSchemaRef>, + props: &'b ExecutionProps, +} + +impl<'a, 'b> SimplifyContext<'a, 'b> { + /// Create a new SimplifyContext + pub fn new(schemas: Vec<&'a DFSchemaRef>, props: &'b ExecutionProps) -> Self { + Self { schemas, props } + } +} + +impl<'a, 'b> SimplifyInfo for SimplifyContext<'a, 'b> { + /// returns true if this Expr has boolean type + fn is_boolean_type(&self, expr: &Expr) -> Result { + for schema in &self.schemas { + if let Ok(DataType::Boolean) = expr.get_type(schema) { + return Ok(true); + } + } + + Ok(false) + } + /// Returns true if expr is nullable + fn nullable(&self, expr: &Expr) -> Result { + self.schemas + .iter() + .find_map(|schema| { + // expr may be from another input, so ignore errors + // by converting to None to keep trying + expr.nullable(schema.as_ref()).ok() + }) + .ok_or_else(|| { + // This means we weren't able to compute `Expr::nullable` with + // *any* input schemas, signalling a problem + DataFusionError::Internal(format!( + "Could not find find columns in '{}' during simplify", + expr + )) + }) + } + + fn execution_props(&self) -> &ExecutionProps { + self.props + } +} + +/// Optimizer Pass that simplifies [`LogicalPlan`]s by rewriting +/// [`Expr`]`s evaluating constants and applying algebraic +/// simplifications /// /// # 
Introduction /// It uses boolean algebra laws to simplify or reduce the number of terms in expressions. @@ -44,7 +95,7 @@ use crate::{error::Result, logical_plan::Operator}; /// `Filter: b > 2` /// #[derive(Default)] -pub struct SimplifyExpressions {} +pub(crate) struct SimplifyExpressions {} /// returns true if `needle` is found in a chain of search_op /// expressions. Such as: (A AND B) AND C @@ -150,9 +201,7 @@ impl OptimizerRule for SimplifyExpressions { // projected columns. With just the projected schema, it's not possible to infer types for // expressions that references non-projected columns within the same project plan or its // children plans. - let mut simplifier = Simplifier::new(plan.all_schemas()); - - let mut const_evaluator = ConstEvaluator::new(execution_props); + let info = SimplifyContext::new(plan.all_schemas(), execution_props); let new_inputs = plan .inputs() @@ -168,15 +217,8 @@ impl OptimizerRule for SimplifyExpressions { // Constant folding should not change expression name. let name = &e.name(plan.schema()); - // TODO iterate until no changes are made - // during rewrite (evaluating constants can - // enable new simplifications and - // simplifications can enable new constant - // evaluation) - let new_e = e - // fold constants and then simplify - .rewrite(&mut const_evaluator)? - .rewrite(&mut simplifier)?; + // Apply the actual simplification logic + let new_e = e.simplify(&info)?; let new_name = &new_e.name(plan.schema()); @@ -389,52 +431,23 @@ impl<'a> ConstEvaluator<'a> { /// * `false = true` and `true = false` to `false` /// * `!!expr` to `expr` /// * `expr = null` and `expr != null` to `null` -pub(crate) struct Simplifier<'a> { - /// input schemas - schemas: Vec<&'a DFSchemaRef>, +pub(crate) struct Simplifier<'a, S> { + info: &'a S, } -impl<'a> Simplifier<'a> { - pub fn new(schemas: Vec<&'a DFSchemaRef>) -> Self { - Self { schemas } - } - - fn is_boolean_type(&self, expr: &Expr) -> bool { - for schema in &self.schemas { - if let Ok(DataType::Boolean) = expr.get_type(schema) { - return true; - } - } - - false - } - - /// Returns true if expr is nullable - fn nullable(&self, expr: &Expr) -> Result { - self.schemas - .iter() - .find_map(|schema| { - // expr may be from another input, so ignore errors - // by converting to None to keep trying - expr.nullable(schema.as_ref()).ok() - }) - .ok_or_else(|| { - // This means we weren't able to compute `Expr::nullable` with - // *any* input schemas, signalling a problem - DataFusionError::Internal(format!( - "Could not find find columns in '{}' during simplify", - expr - )) - }) +impl<'a, S> Simplifier<'a, S> { + pub fn new(info: &'a S) -> Self { + Self { info } } } -impl<'a> ExprRewriter for Simplifier<'a> { +impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { /// rewrite the expression simplifying any constant expressions fn mutate(&mut self, expr: Expr) -> Result { use Expr::*; use Operator::{And, Divide, Eq, Multiply, NotEq, Or}; + let info = self.info; let new_expr = match expr { // // Rules for Eq @@ -447,7 +460,7 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: Eq, right, - } if is_bool_lit(&left) && self.is_boolean_type(&right) => { + } if is_bool_lit(&left) && info.is_boolean_type(&right)? => { match as_bool_lit(*left) { Some(true) => *right, Some(false) => Not(right), @@ -461,7 +474,7 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: Eq, right, - } if is_bool_lit(&right) && self.is_boolean_type(&left) => { + } if is_bool_lit(&right) && info.is_boolean_type(&left)? 
=> { match as_bool_lit(*right) { Some(true) => *left, Some(false) => Not(left), @@ -480,7 +493,7 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: NotEq, right, - } if is_bool_lit(&left) && self.is_boolean_type(&right) => { + } if is_bool_lit(&left) && info.is_boolean_type(&right)? => { match as_bool_lit(*left) { Some(true) => Not(right), Some(false) => *right, @@ -494,7 +507,7 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: NotEq, right, - } if is_bool_lit(&right) && self.is_boolean_type(&left) => { + } if is_bool_lit(&right) && info.is_boolean_type(&left)? => { match as_bool_lit(*right) { Some(true) => Not(left), Some(false) => *left, @@ -547,13 +560,13 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: Or, right, - } if !self.nullable(&right)? && is_op_with(And, &right, &left) => *left, + } if !info.nullable(&right)? && is_op_with(And, &right, &left) => *left, // (A AND B) OR A --> A (if B not null) BinaryExpr { left, op: Or, right, - } if !self.nullable(&left)? && is_op_with(And, &left, &right) => *right, + } if !info.nullable(&left)? && is_op_with(And, &left, &right) => *right, // // Rules for AND @@ -600,13 +613,13 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: And, right, - } if !self.nullable(&right)? && is_op_with(Or, &right, &left) => *left, + } if !info.nullable(&right)? && is_op_with(Or, &right, &left) => *left, // (A OR B) AND A --> A (if B not null) BinaryExpr { left, op: And, right, - } if !self.nullable(&left)? && is_op_with(Or, &left, &right) => *right, + } if !info.nullable(&left)? && is_op_with(Or, &left, &right) => *right, // // Rules for Multiply @@ -643,7 +656,7 @@ impl<'a> ExprRewriter for Simplifier<'a> { left, op: Divide, right, - } if !self.nullable(&left)? && left == right => lit(1), + } if !info.nullable(&left)? && left == right => lit(1), // // Rules for Not @@ -676,7 +689,7 @@ impl<'a> ExprRewriter for Simplifier<'a> { else_expr, } if !when_then_expr.is_empty() && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number - && self.is_boolean_type(&when_then_expr[0].1) => + && info.is_boolean_type(&when_then_expr[0].1)? 
=> { // The disjunction of all the when predicates encountered so far let mut filter_expr = lit(false); @@ -1208,17 +1221,9 @@ mod tests { fn simplify(expr: Expr) -> Expr { let schema = expr_test_schema(); - let mut rewriter = Simplifier::new(vec![&schema]); - let execution_props = ExecutionProps::new(); - let mut const_evaluator = ConstEvaluator::new(&execution_props); - - expr.rewrite(&mut rewriter) - .expect("expected to simplify") - .rewrite(&mut const_evaluator) - .expect("expected to const evaluate") - .rewrite(&mut rewriter) - .expect("expected to simplify") + let info = SimplifyContext::new(vec![&schema], &execution_props); + expr.simplify(&info).unwrap() } fn expr_test_schema() -> DFSchemaRef { @@ -1357,30 +1362,36 @@ mod tests { // CASE WHERE c2 THEN true ELSE c2 // --> // c2 + // + // Need to call simplify 2x due to + // https://github.com/apache/arrow-datafusion/issues/1160 assert_eq!( - simplify(Expr::Case { + simplify(simplify(Expr::Case { expr: None, when_then_expr: vec![( Box::new(col("c2").not_eq(lit(false))), Box::new(lit("ok").eq(lit("ok"))), )], else_expr: Some(Box::new(col("c2").eq(lit(true)))), - }), + })), col("c2").or(col("c2").not().and(col("c2"))) // #1716 ); // CASE WHERE ISNULL(c2) THEN true ELSE c2 // --> // ISNULL(c2) OR c2 + // + // Need to call simplify 2x due to + // https://github.com/apache/arrow-datafusion/issues/1160 assert_eq!( - simplify(Expr::Case { + simplify(simplify(Expr::Case { expr: None, when_then_expr: vec![( Box::new(col("c2").is_null()), Box::new(lit(true)), )], else_expr: Some(Box::new(col("c2"))), - }), + })), col("c2") .is_null() .or(col("c2").is_null().not().and(col("c2"))) @@ -1390,15 +1401,18 @@ mod tests { // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE) // --> c1 OR (NOT(c1 OR c2)) // --> NOT(c1) AND c2 + // + // Need to call simplify 2x due to + // https://github.com/apache/arrow-datafusion/issues/1160 assert_eq!( - simplify(Expr::Case { + simplify(simplify(Expr::Case { expr: None, when_then_expr: vec![ (Box::new(col("c1")), Box::new(lit(true)),), (Box::new(col("c2")), Box::new(lit(false)),) ], else_expr: Some(Box::new(lit(true))), - }), + })), col("c1").or(col("c1").or(col("c2")).not()) ); @@ -1406,15 +1420,18 @@ mod tests { // --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE) // --> c1 OR (NOT(c1) AND c2) // --> c1 OR c2 + // + // Need to call simplify 2x due to + // https://github.com/apache/arrow-datafusion/issues/1160 assert_eq!( - simplify(Expr::Case { + simplify(simplify(Expr::Case { expr: None, when_then_expr: vec![ (Box::new(col("c1")), Box::new(lit(true)),), (Box::new(col("c2")), Box::new(lit(false)),) ], else_expr: Some(Box::new(lit(true))), - }), + })), col("c1").or(col("c1").or(col("c2")).not()) ); } diff --git a/datafusion/tests/simplification.rs b/datafusion/tests/simplification.rs new file mode 100644 index 000000000000..5edf43f5ccb2 --- /dev/null +++ b/datafusion/tests/simplification.rs @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This program demonstrates the DataFusion expression simplification API. + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::{ + error::Result, + execution::context::ExecutionProps, + logical_plan::{DFSchema, Expr, SimplifyInfo}, + prelude::*, +}; + +/// In order to simplify expressions, DataFusion must have information +/// about the expressions. +/// +/// You can provide that information using DataFusion [DFSchema] +/// objects or from some other implemention +struct MyInfo { + /// The input schema + schema: DFSchema, + + /// Execution specific details needed for constant evaluation such + /// as the current time for `now()` and [VariableProviders] + execution_props: ExecutionProps, +} + +impl SimplifyInfo for MyInfo { + fn is_boolean_type(&self, expr: &Expr) -> Result { + Ok(matches!(expr.get_type(&self.schema)?, DataType::Boolean)) + } + + fn nullable(&self, expr: &Expr) -> Result { + expr.nullable(&self.schema) + } + + fn execution_props(&self) -> &ExecutionProps { + &self.execution_props + } +} + +impl From for MyInfo { + fn from(schema: DFSchema) -> Self { + Self { + schema, + execution_props: ExecutionProps::new(), + } + } +} + +/// A schema like: +/// +/// a: Int32 (possibly with nulls) +/// b: Int32 +/// s: Utf8 +fn schema() -> DFSchema { + Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, false), + Field::new("s", DataType::Utf8, false), + ]) + .try_into() + .unwrap() +} + +#[test] +fn basic() { + let info: MyInfo = schema().into(); + + // The `Expr` is a core concept in DataFusion, and DataFusion can + // help simplify it. + + // For example 'a < (2 + 3)' can be rewritten into the easier to + // optimize form `a < 5` automatically + let expr = col("a").lt(lit(2i32) + lit(3i32)); + + let simplified = expr.simplify(&info).unwrap(); + assert_eq!(simplified, col("a").lt(lit(5i32))); +} + +#[test] +fn fold_and_simplify() { + let info: MyInfo = schema().into(); + + // What will it do with the expression `concat('foo', 'bar') == 'foobar')`? 
+ let expr = concat(&[lit("foo"), lit("bar")]).eq(lit("foobar")); + + // Since datafusion applies both simplification *and* rewriting + // some expressions can be entirely simplified + let simplified = expr.simplify(&info).unwrap(); + assert_eq!(simplified, lit(true)) +} From d1ebdbf89c955d54aa31f4b76407987580f8994c Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 3 Feb 2022 07:58:21 -0800 Subject: [PATCH 23/50] Add tests and CI for optional pyarrow module (#1711) * Implement other side of conversion * Add test workflow * Add (failing) tests * Get unit tests passing * Use python -m pip * Debug LD_LIBRARY_PATH * Set LIBRARY_PATH * Update help with better info --- .github/workflows/rust.yml | 49 ++++++++++++++++++++ datafusion/src/pyarrow.rs | 91 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 136 insertions(+), 4 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 4b633d4bc9e5..d466d67efa6f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -230,6 +230,55 @@ jobs: # do not produce debug symbols to keep memory usage down RUSTFLAGS: "-C debuginfo=0" + test-datafusion-pyarrow: + needs: [linux-build-lib] + runs-on: ubuntu-latest + strategy: + matrix: + arch: [amd64] + rust: [stable] + container: + image: ${{ matrix.arch }}/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v2 + with: + submodules: true + - name: Cache Cargo + uses: actions/cache@v2 + with: + path: /github/home/.cargo + # this key equals the ones on `linux-build-lib` for re-use + key: cargo-cache- + - name: Cache Rust dependencies + uses: actions/cache@v2 + with: + path: /github/home/target + # this key equals the ones on `linux-build-lib` for re-use + key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }} + - uses: actions/setup-python@v2 + with: + python-version: "3.8" + - name: Install PyArrow + run: | + echo "LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV + python -m pip install pyarrow + - name: Setup Rust toolchain + run: | + rustup toolchain install ${{ matrix.rust }} + rustup default ${{ matrix.rust }} + rustup component add rustfmt + - name: Run tests + run: | + cd datafusion + cargo test --features=pyarrow + env: + CARGO_HOME: "/github/home/.cargo" + CARGO_TARGET_DIR: "/github/home/target" + lint: name: Lint runs-on: ubuntu-latest diff --git a/datafusion/src/pyarrow.rs b/datafusion/src/pyarrow.rs index da05d63d8c2c..d819b2b41154 100644 --- a/datafusion/src/pyarrow.rs +++ b/datafusion/src/pyarrow.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-use pyo3::exceptions::{PyException, PyNotImplementedError}; +use pyo3::exceptions::PyException; use pyo3::prelude::*; use pyo3::types::PyList; -use pyo3::PyNativeType; use crate::arrow::array::ArrayData; use crate::arrow::pyarrow::PyArrowConvert; @@ -49,8 +48,13 @@ impl PyArrowConvert for ScalarValue { Ok(scalar) } - fn to_pyarrow(&self, _py: Python) -> PyResult { - Err(PyNotImplementedError::new_err("Not implemented")) + fn to_pyarrow(&self, py: Python) -> PyResult { + let array = self.to_array(); + // convert to pyarrow array using C data interface + let pyarray = array.data_ref().clone().into_py(py); + let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?; + + Ok(pyscalar) } } @@ -65,3 +69,82 @@ impl<'a> IntoPy for ScalarValue { self.to_pyarrow(py).unwrap() } } + +#[cfg(test)] +mod tests { + use super::*; + use pyo3::prepare_freethreaded_python; + use pyo3::py_run; + use pyo3::types::PyDict; + use pyo3::Python; + + fn init_python() { + prepare_freethreaded_python(); + Python::with_gil(|py| { + if let Err(err) = py.run("import pyarrow", None, None) { + let locals = PyDict::new(py); + py.run( + "import sys; executable = sys.executable; python_path = sys.path", + None, + Some(locals), + ) + .expect("Couldn't get python info"); + let executable: String = + locals.get_item("executable").unwrap().extract().unwrap(); + let python_path: Vec<&str> = + locals.get_item("python_path").unwrap().extract().unwrap(); + + Err(err).expect( + format!( + "pyarrow not found\nExecutable: {}\nPython path: {:?}\n\ + HINT: try `pip install pyarrow`\n\ + NOTE: On Mac OS, you must compile against a Framework Python \ + (default in python.org installers and brew, but not pyenv)\n\ + NOTE: On Mac OS, PYO3 might point to incorrect Python library \ + path when using virtual environments. Try \ + `export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n", + executable, python_path + ) + .as_ref(), + ) + } + }) + } + + #[test] + fn test_roundtrip() { + init_python(); + + let example_scalars = vec![ + ScalarValue::Boolean(Some(true)), + ScalarValue::Int32(Some(23)), + ScalarValue::Float64(Some(12.34)), + ScalarValue::Utf8(Some("Hello!".to_string())), + ScalarValue::Date32(Some(1234)), + ]; + + Python::with_gil(|py| { + for scalar in example_scalars.iter() { + let result = + ScalarValue::from_pyarrow(scalar.to_pyarrow(py).unwrap().as_ref(py)) + .unwrap(); + assert_eq!(scalar, &result); + } + }); + } + + #[test] + fn test_py_scalar() { + init_python(); + + Python::with_gil(|py| { + let scalar_float = ScalarValue::Float64(Some(12.34)); + let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap(); + py_run!(py, py_float, "assert py_float == 12.34"); + + let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string())); + let py_string = scalar_string.into_py(py).call_method0(py, "as_py").unwrap(); + py_run!(py, py_string, "assert py_string == 'Hello!'"); + }); + } +} From aca855d27e033f37447a5edf8ac74102e9cd4ce2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 3 Feb 2022 11:11:10 -0500 Subject: [PATCH 24/50] Update parking_lot requirement from 0.11 to 0.12 (#1735) Updates the requirements on [parking_lot](https://github.com/Amanieu/parking_lot) to permit the latest version. 
- [Release notes](https://github.com/Amanieu/parking_lot/releases) - [Changelog](https://github.com/Amanieu/parking_lot/blob/master/CHANGELOG.md) - [Commits](https://github.com/Amanieu/parking_lot/compare/0.11.0...0.12.0) --- updated-dependencies: - dependency-name: parking_lot dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- ballista/rust/client/Cargo.toml | 2 +- ballista/rust/core/Cargo.toml | 2 +- ballista/rust/executor/Cargo.toml | 2 +- ballista/rust/scheduler/Cargo.toml | 2 +- datafusion/Cargo.toml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ballista/rust/client/Cargo.toml b/ballista/rust/client/Cargo.toml index 4ec1abe77654..dff5d1a5c584 100644 --- a/ballista/rust/client/Cargo.toml +++ b/ballista/rust/client/Cargo.toml @@ -35,7 +35,7 @@ log = "0.4" tokio = "1.0" tempfile = "3" sqlparser = "0.13" -parking_lot = "0.11" +parking_lot = "0.12" datafusion = { path = "../../../datafusion", version = "6.0.0" } diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index 043f79a962b2..59a1142c2199 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -48,7 +48,7 @@ parse_arg = "0.1.3" arrow-flight = { version = "8.0.0" } datafusion = { path = "../../../datafusion", version = "6.0.0" } -parking_lot = "0.11" +parking_lot = "0.12" [dev-dependencies] tempfile = "3" diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index fb456a93ddc6..1f7ac61a2071 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -46,7 +46,7 @@ tokio-stream = { version = "0.1", features = ["net"] } tonic = "0.6" uuid = { version = "0.8", features = ["v4"] } hyper = "0.14.4" -parking_lot = "0.11" +parking_lot = "0.12" [dev-dependencies] diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml index fdeb7e726d57..8acb13ba8963 100644 --- a/ballista/rust/scheduler/Cargo.toml +++ b/ballista/rust/scheduler/Cargo.toml @@ -53,7 +53,7 @@ tokio-stream = { version = "0.1", features = ["net"], optional = true } tonic = "0.6" tower = { version = "0.4" } warp = "0.3" -parking_lot = "0.11" +parking_lot = "0.12" [dev-dependencies] ballista-core = { path = "../core", version = "0.6.0" } diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 578542cbbec6..54247cbcf07c 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -78,7 +78,7 @@ avro-rs = { version = "0.13", features = ["snappy"], optional = true } num-traits = { version = "0.2", optional = true } pyo3 = { version = "0.15", optional = true } tempfile = "3" -parking_lot = "0.11" +parking_lot = "0.12" [dev-dependencies] criterion = "0.3" From 78c30b66bbf80733c65dfdef412ccb4e566b2194 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Thu, 3 Feb 2022 16:35:08 +0000 Subject: [PATCH 25/50] Prevent repartitioning of certain operator's direct children (#1731) (#1732) * Prevent repartitioning of certain operator's direct children (#1731) * Update ballista tests * Don't repartition children of RepartitionExec * Revert partition restriction on Repartition and Projection * Review feedback * Lint --- .../src/physical_optimizer/repartition.rs | 212 ++++++++++++------ datafusion/src/physical_plan/limit.rs | 5 + datafusion/src/physical_plan/mod.rs | 18 ++ datafusion/src/physical_plan/union.rs | 4 + 4 files changed, 173 insertions(+), 66 deletions(-) diff 
--git a/datafusion/src/physical_optimizer/repartition.rs b/datafusion/src/physical_optimizer/repartition.rs index 0926ed73ca9f..1f4505324aa3 100644 --- a/datafusion/src/physical_optimizer/repartition.rs +++ b/datafusion/src/physical_optimizer/repartition.rs @@ -19,10 +19,10 @@ use std::sync::Arc; use super::optimizer::PhysicalOptimizerRule; +use crate::physical_plan::Partitioning::*; use crate::physical_plan::{ empty::EmptyExec, repartition::RepartitionExec, ExecutionPlan, }; -use crate::physical_plan::{Distribution, Partitioning::*}; use crate::{error::Result, execution::context::ExecutionConfig}; /// Optimizer that introduces repartition to introduce more parallelism in the plan @@ -38,8 +38,8 @@ impl Repartition { fn optimize_partitions( target_partitions: usize, - requires_single_partition: bool, plan: Arc, + should_repartition: bool, ) -> Result> { // Recurse into children bottom-up (added nodes should be as deep as possible) @@ -47,17 +47,15 @@ fn optimize_partitions( // leaf node - don't replace children plan.clone() } else { + let should_repartition_children = plan.should_repartition_children(); let children = plan .children() .iter() .map(|child| { optimize_partitions( target_partitions, - matches!( - plan.required_child_distribution(), - Distribution::SinglePartition - ), child.clone(), + should_repartition_children, ) }) .collect::>()?; @@ -77,7 +75,7 @@ fn optimize_partitions( // But also not very useful to inlude let is_empty_exec = plan.as_any().downcast_ref::().is_some(); - if perform_repartition && !requires_single_partition && !is_empty_exec { + if perform_repartition && should_repartition && !is_empty_exec { Ok(Arc::new(RepartitionExec::try_new( new_plan, RoundRobinBatch(target_partitions), @@ -97,7 +95,7 @@ impl PhysicalOptimizerRule for Repartition { if config.target_partitions == 1 { Ok(plan) } else { - optimize_partitions(config.target_partitions, true, plan) + optimize_partitions(config.target_partitions, plan, false) } } @@ -107,93 +105,175 @@ impl PhysicalOptimizerRule for Repartition { } #[cfg(test)] mod tests { - use arrow::datatypes::Schema; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use super::*; use crate::datasource::PartitionedFile; + use crate::physical_plan::expressions::col; use crate::physical_plan::file_format::{FileScanConfig, ParquetExec}; - use crate::physical_plan::projection::ProjectionExec; - use crate::physical_plan::Statistics; + use crate::physical_plan::filter::FilterExec; + use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; + use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; + use crate::physical_plan::union::UnionExec; + use crate::physical_plan::{displayable, Statistics}; use crate::test::object_store::TestObjectStore; + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("c1", DataType::Boolean, true)])) + } + + fn parquet_exec() -> Arc { + Arc::new(ParquetExec::new( + FileScanConfig { + object_store: TestObjectStore::new_arc(&[("x", 100)]), + file_schema: schema(), + file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], + statistics: Statistics::default(), + projection: None, + limit: None, + table_partition_cols: vec![], + }, + None, + )) + } + + fn filter_exec(input: Arc) -> Arc { + Arc::new(FilterExec::try_new(col("c1", &schema()).unwrap(), input).unwrap()) + } + + fn hash_aggregate(input: Arc) -> Arc { + let schema = schema(); + Arc::new( + HashAggregateExec::try_new( + AggregateMode::Final, + vec![], + vec![], + Arc::new( + 
HashAggregateExec::try_new( + AggregateMode::Partial, + vec![], + vec![], + input, + schema.clone(), + ) + .unwrap(), + ), + schema, + ) + .unwrap(), + ) + } + + fn limit_exec(input: Arc) -> Arc { + Arc::new(GlobalLimitExec::new( + Arc::new(LocalLimitExec::new(input, 100)), + 100, + )) + } + + fn trim_plan_display(plan: &str) -> Vec<&str> { + plan.split('\n') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect() + } + #[test] fn added_repartition_to_single_partition() -> Result<()> { - let file_schema = Arc::new(Schema::empty()); - let parquet_project = ProjectionExec::try_new( - vec![], - Arc::new(ParquetExec::new( - FileScanConfig { - object_store: TestObjectStore::new_arc(&[("x", 100)]), - file_schema, - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - }, - None, - )), - )?; - let optimizer = Repartition {}; let optimized = optimizer.optimize( - Arc::new(parquet_project), + hash_aggregate(parquet_exec()), &ExecutionConfig::new().with_target_partitions(10), )?; - assert_eq!( - optimized.children()[0] - .output_partitioning() - .partition_count(), - 10 - ); + let plan = displayable(optimized.as_ref()).indent().to_string(); + + let expected = &[ + "HashAggregateExec: mode=Final, gby=[], aggr=[]", + "HashAggregateExec: mode=Partial, gby=[], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10)", + "ParquetExec: limit=None, partitions=[x]", + ]; + assert_eq!(&trim_plan_display(&plan), &expected); Ok(()) } #[test] fn repartition_deepest_node() -> Result<()> { - let file_schema = Arc::new(Schema::empty()); - let parquet_project = ProjectionExec::try_new( - vec![], - Arc::new(ProjectionExec::try_new( - vec![], - Arc::new(ParquetExec::new( - FileScanConfig { - object_store: TestObjectStore::new_arc(&[("x", 100)]), - file_schema, - file_groups: vec![vec![PartitionedFile::new( - "x".to_string(), - 100, - )]], - statistics: Statistics::default(), - projection: None, - limit: None, - table_partition_cols: vec![], - }, - None, - )), - )?), + let optimizer = Repartition {}; + + let optimized = optimizer.optimize( + hash_aggregate(filter_exec(parquet_exec())), + &ExecutionConfig::new().with_target_partitions(10), )?; + let plan = displayable(optimized.as_ref()).indent().to_string(); + + let expected = &[ + "HashAggregateExec: mode=Final, gby=[], aggr=[]", + "HashAggregateExec: mode=Partial, gby=[], aggr=[]", + "FilterExec: c1@0", + "RepartitionExec: partitioning=RoundRobinBatch(10)", + "ParquetExec: limit=None, partitions=[x]", + ]; + + assert_eq!(&trim_plan_display(&plan), &expected); + Ok(()) + } + + #[test] + fn repartition_ignores_limit() -> Result<()> { let optimizer = Repartition {}; let optimized = optimizer.optimize( - Arc::new(parquet_project), + hash_aggregate(limit_exec(filter_exec(limit_exec(parquet_exec())))), &ExecutionConfig::new().with_target_partitions(10), )?; - // RepartitionExec is added to deepest node - assert!(optimized.children()[0] - .as_any() - .downcast_ref::() - .is_none()); - assert!(optimized.children()[0].children()[0] - .as_any() - .downcast_ref::() - .is_some()); + let plan = displayable(optimized.as_ref()).indent().to_string(); + + let expected = &[ + "HashAggregateExec: mode=Final, gby=[], aggr=[]", + "HashAggregateExec: mode=Partial, gby=[], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10)", + "GlobalLimitExec: limit=100", + "LocalLimitExec: limit=100", + "FilterExec: c1@0", + "RepartitionExec: 
partitioning=RoundRobinBatch(10)", + "GlobalLimitExec: limit=100", + "LocalLimitExec: limit=100", + // Expect no repartition to happen for local limit + "ParquetExec: limit=None, partitions=[x]", + ]; + + assert_eq!(&trim_plan_display(&plan), &expected); + Ok(()) + } + + #[test] + fn repartition_ignores_union() -> Result<()> { + let optimizer = Repartition {}; + + let optimized = optimizer.optimize( + Arc::new(UnionExec::new(vec![parquet_exec(); 5])), + &ExecutionConfig::new().with_target_partitions(5), + )?; + + let plan = displayable(optimized.as_ref()).indent().to_string(); + + let expected = &[ + "UnionExec", + // Expect no repartition of ParquetExec + "ParquetExec: limit=None, partitions=[x]", + "ParquetExec: limit=None, partitions=[x]", + "ParquetExec: limit=None, partitions=[x]", + "ParquetExec: limit=None, partitions=[x]", + "ParquetExec: limit=None, partitions=[x]", + ]; + assert_eq!(&trim_plan_display(&plan), &expected); Ok(()) } } diff --git a/datafusion/src/physical_plan/limit.rs b/datafusion/src/physical_plan/limit.rs index f0225579d5a6..587780d9de4a 100644 --- a/datafusion/src/physical_plan/limit.rs +++ b/datafusion/src/physical_plan/limit.rs @@ -300,6 +300,11 @@ impl ExecutionPlan for LocalLimitExec { _ => Statistics::default(), } } + + fn should_repartition_children(&self) -> bool { + // No reason to repartition children as this node is just limiting each input partition. + false + } } /// Truncate a RecordBatch to maximum of n rows diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 725e475335ca..ac70f2f90ae2 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -135,14 +135,32 @@ pub trait ExecutionPlan: Debug + Send + Sync { /// Returns the execution plan as [`Any`](std::any::Any) so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; + /// Get the schema for this execution plan fn schema(&self) -> SchemaRef; + /// Specifies the output partitioning scheme of this plan fn output_partitioning(&self) -> Partitioning; + /// Specifies the data distribution requirements of all the children for this operator fn required_child_distribution(&self) -> Distribution { Distribution::UnspecifiedDistribution } + + /// Returns `true` if the direct children of this `ExecutionPlan` should be repartitioned + /// to introduce greater concurrency to the plan + /// + /// The default implementation returns `true` unless `Self::required_child_distribution` + /// returns `Distribution::SinglePartition` + /// + /// Operators that do not benefit from additional partitioning may want to return `false` + fn should_repartition_children(&self) -> bool { + !matches!( + self.required_child_distribution(), + Distribution::SinglePartition + ) + } + /// Get a list of child execution plans that provide the input for this plan. The returned list /// will be empty for leaf nodes, will contain a single value for unary nodes, or two /// values for binary nodes (such as joins). 
diff --git a/datafusion/src/physical_plan/union.rs b/datafusion/src/physical_plan/union.rs index 93ecf224b7b3..d2c170bc27f8 100644 --- a/datafusion/src/physical_plan/union.rs +++ b/datafusion/src/physical_plan/union.rs @@ -143,6 +143,10 @@ impl ExecutionPlan for UnionExec { .reduce(stats_union) .unwrap_or_default() } + + fn should_repartition_children(&self) -> bool { + false + } } /// Stream wrapper that records `BaselineMetrics` for a particular From b2eaee385035d78ec89ea3ef8e5772eb9f9018fe Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 3 Feb 2022 12:04:40 -0500 Subject: [PATCH 26/50] API to get Expr's type and nullability without a `DFSchema` (#1726) * API to get Expr type and nullability without a `DFSchema` * Add test * publically export * Improve docs --- datafusion/src/logical_plan/expr.rs | 123 +++++++++++++++++++++++++--- datafusion/src/logical_plan/mod.rs | 4 +- 2 files changed, 113 insertions(+), 14 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 63bf72cf226c..300c75137d8a 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -392,20 +392,59 @@ impl PartialOrd for Expr { } } +/// Provides schema information needed by [Expr] methods such as +/// [Expr::nullable] and [Expr::data_type]. +/// +/// Note that this trait is implemented for &[DFSchema] which is +/// widely used in the DataFusion codebase. +pub trait ExprSchema { + /// Is this column reference nullable? + fn nullable(&self, col: &Column) -> Result; + + /// What is the datatype of this column? + fn data_type(&self, col: &Column) -> Result<&DataType>; +} + +// Implement `ExprSchema` for `Arc` +impl> ExprSchema for P { + fn nullable(&self, col: &Column) -> Result { + self.as_ref().nullable(col) + } + + fn data_type(&self, col: &Column) -> Result<&DataType> { + self.as_ref().data_type(col) + } +} + +impl ExprSchema for DFSchema { + fn nullable(&self, col: &Column) -> Result { + Ok(self.field_from_column(col)?.is_nullable()) + } + + fn data_type(&self, col: &Column) -> Result<&DataType> { + Ok(self.field_from_column(col)?.data_type()) + } +} + impl Expr { - /// Returns the [arrow::datatypes::DataType] of the expression based on [arrow::datatypes::Schema]. + /// Returns the [arrow::datatypes::DataType] of the expression + /// based on [ExprSchema] + /// + /// Note: [DFSchema] implements [ExprSchema]. /// /// # Errors /// - /// This function errors when it is not possible to compute its [arrow::datatypes::DataType]. - /// This happens when e.g. the expression refers to a column that does not exist in the schema, or when - /// the expression is incorrectly typed (e.g. `[utf8] + [bool]`). - pub fn get_type(&self, schema: &DFSchema) -> Result { + /// This function errors when it is not possible to compute its + /// [arrow::datatypes::DataType]. This happens when e.g. the + /// expression refers to a column that does not exist in the + /// schema, or when the expression is incorrectly typed + /// (e.g. `[utf8] + [bool]`). + pub fn get_type(&self, schema: &S) -> Result { match self { Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => { expr.get_type(schema) } - Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()), + Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::ScalarVariable(_) => Ok(DataType::Utf8), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Case { when_then_expr, .. 
} => when_then_expr[0].1.get_type(schema), @@ -472,13 +511,16 @@ impl Expr { } } - /// Returns the nullability of the expression based on [arrow::datatypes::Schema]. + /// Returns the nullability of the expression based on [ExprSchema]. + /// + /// Note: [DFSchema] implements [ExprSchema]. /// /// # Errors /// - /// This function errors when it is not possible to compute its nullability. - /// This happens when the expression refers to a column that does not exist in the schema. - pub fn nullable(&self, input_schema: &DFSchema) -> Result { + /// This function errors when it is not possible to compute its + /// nullability. This happens when the expression refers to a + /// column that does not exist in the schema. + pub fn nullable(&self, input_schema: &S) -> Result { match self { Expr::Alias(expr, _) | Expr::Not(expr) @@ -486,7 +528,7 @@ impl Expr { | Expr::Sort { expr, .. } | Expr::Between { expr, .. } | Expr::InList { expr, .. } => expr.nullable(input_schema), - Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()), + Expr::Column(c) => input_schema.nullable(c), Expr::Literal(value) => Ok(value.is_null()), Expr::Case { when_then_expr, @@ -561,7 +603,11 @@ impl Expr { /// /// This function errors when it is impossible to cast the /// expression to the target [arrow::datatypes::DataType]. - pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result { + pub fn cast_to( + self, + cast_to_type: &DataType, + schema: &S, + ) -> Result { // TODO(kszucs): most of the operations do not validate the type correctness // like all of the binary expressions below. Perhaps Expr should track the // type of the expression? @@ -2557,4 +2603,57 @@ mod tests { combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]); assert_eq!(result, Some(and(and(filter1, filter2), filter3))); } + + #[test] + fn expr_schema_nullability() { + let expr = col("foo").eq(lit(1)); + assert!(!expr.nullable(&MockExprSchema::new()).unwrap()); + assert!(expr + .nullable(&MockExprSchema::new().with_nullable(true)) + .unwrap()); + } + + #[test] + fn expr_schema_data_type() { + let expr = col("foo"); + assert_eq!( + DataType::Utf8, + expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8)) + .unwrap() + ); + } + + struct MockExprSchema { + nullable: bool, + data_type: DataType, + } + + impl MockExprSchema { + fn new() -> Self { + Self { + nullable: false, + data_type: DataType::Null, + } + } + + fn with_nullable(mut self, nullable: bool) -> Self { + self.nullable = nullable; + self + } + + fn with_data_type(mut self, data_type: DataType) -> Self { + self.data_type = data_type; + self + } + } + + impl ExprSchema for MockExprSchema { + fn nullable(&self, _col: &Column) -> Result { + Ok(self.nullable) + } + + fn data_type(&self, _col: &Column) -> Result<&DataType> { + Ok(&self.data_type) + } + } } diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 25714514d78a..22521a1bd1fb 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -46,8 +46,8 @@ pub use expr::{ rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, - Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion, - SimplifyInfo, + Column, Expr, ExprRewriter, ExprSchema, ExpressionVisitor, Literal, Recursion, + RewriteRecursion, 
SimplifyInfo, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; From 512475996eb90f7c07b8bdc55a8953b6fd8f0881 Mon Sep 17 00:00:00 2001 From: "r.4ntix" Date: Fri, 4 Feb 2022 03:14:57 +0800 Subject: [PATCH 27/50] Fix typos in crate documentation (#1739) --- datafusion/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 2eb7d9af2ffd..9442f7e5fe9f 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -162,8 +162,8 @@ //! * Sort: [`SortExec`](physical_plan::sort::SortExec) //! * Coalesce partitions: [`CoalescePartitionsExec`](physical_plan::coalesce_partitions::CoalescePartitionsExec) //! * Limit: [`LocalLimitExec`](physical_plan::limit::LocalLimitExec) and [`GlobalLimitExec`](physical_plan::limit::GlobalLimitExec) -//! * Scan a CSV: [`CsvExec`](physical_plan::csv::CsvExec) -//! * Scan a Parquet: [`ParquetExec`](physical_plan::parquet::ParquetExec) +//! * Scan a CSV: [`CsvExec`](physical_plan::file_format::CsvExec) +//! * Scan a Parquet: [`ParquetExec`](physical_plan::file_format::ParquetExec) //! * Scan from memory: [`MemoryExec`](physical_plan::memory::MemoryExec) //! * Explain the plan: [`ExplainExec`](physical_plan::explain::ExplainExec) //! From 97a1b21d01b896cea803b48c7841746d899eb3a1 Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Fri, 4 Feb 2022 22:41:16 +0800 Subject: [PATCH 28/50] add `cargo check --release` to ci (#1737) * remote test * Update .github/workflows/rust.yml Co-authored-by: Andrew Lamb Co-authored-by: Andrew Lamb --- .github/workflows/rust.yml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index d466d67efa6f..046309c71482 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -58,12 +58,18 @@ jobs: rustup toolchain install ${{ matrix.rust }} rustup default ${{ matrix.rust }} rustup component add rustfmt - - name: Build Workspace + - name: Build workspace in debug mode run: | cargo build env: CARGO_HOME: "/github/home/.cargo" - CARGO_TARGET_DIR: "/github/home/target" + CARGO_TARGET_DIR: "/github/home/target/debug" + - name: Build workspace in release mode + run: | + cargo check --release + env: + CARGO_HOME: "/github/home/.cargo" + CARGO_TARGET_DIR: "/github/home/target/release" - name: Check DataFusion Build without default features run: | cargo check --no-default-features -p datafusion From 15cfcbc28305e82891a5a52d252fb23c72fd8458 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 4 Feb 2022 11:44:38 -0500 Subject: [PATCH 29/50] Move optimize test out of context.rs (#1742) * Move optimize test out of context.rs * Update --- datafusion/src/execution/context.rs | 38 ------------------ datafusion/tests/sql/explain.rs | 60 +++++++++++++++++++++++++++++ datafusion/tests/sql/mod.rs | 1 + 3 files changed, 61 insertions(+), 38 deletions(-) create mode 100644 datafusion/tests/sql/explain.rs diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index deec84d5a0ff..fb271a1a7e56 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1347,44 +1347,6 @@ mod tests { )); } - #[test] - fn optimize_explain() { - let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); - - let plan = LogicalPlanBuilder::scan_empty(Some("employee"), &schema, None) - .unwrap() - .explain(true, false) - .unwrap() - .build() - .unwrap(); - - if let LogicalPlan::Explain(e) = &plan { - 
assert_eq!(e.stringified_plans.len(), 1); - } else { - panic!("plan was not an explain: {:?}", plan); - } - - // now optimize the plan and expect to see more plans - let optimized_plan = ExecutionContext::new().optimize(&plan).unwrap(); - if let LogicalPlan::Explain(e) = &optimized_plan { - // should have more than one plan - assert!( - e.stringified_plans.len() > 1, - "plans: {:#?}", - e.stringified_plans - ); - // should have at least one optimized plan - let opt = e - .stringified_plans - .iter() - .any(|p| matches!(p.plan_type, PlanType::OptimizedLogicalPlan { .. })); - - assert!(opt, "plans: {:#?}", e.stringified_plans); - } else { - panic!("plan was not an explain: {:?}", plan); - } - } - #[tokio::test] async fn parallel_projection() -> Result<()> { let partition_count = 4; diff --git a/datafusion/tests/sql/explain.rs b/datafusion/tests/sql/explain.rs new file mode 100644 index 000000000000..00842b5eb8ab --- /dev/null +++ b/datafusion/tests/sql/explain.rs @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::{ + logical_plan::{LogicalPlan, LogicalPlanBuilder, PlanType}, + prelude::ExecutionContext, +}; + +#[test] +fn optimize_explain() { + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + + let plan = LogicalPlanBuilder::scan_empty(Some("employee"), &schema, None) + .unwrap() + .explain(true, false) + .unwrap() + .build() + .unwrap(); + + if let LogicalPlan::Explain(e) = &plan { + assert_eq!(e.stringified_plans.len(), 1); + } else { + panic!("plan was not an explain: {:?}", plan); + } + + // now optimize the plan and expect to see more plans + let optimized_plan = ExecutionContext::new().optimize(&plan).unwrap(); + if let LogicalPlan::Explain(e) = &optimized_plan { + // should have more than one plan + assert!( + e.stringified_plans.len() > 1, + "plans: {:#?}", + e.stringified_plans + ); + // should have at least one optimized plan + let opt = e + .stringified_plans + .iter() + .any(|p| matches!(p.plan_type, PlanType::OptimizedLogicalPlan { .. 
})); + + assert!(opt, "plans: {:#?}", e.stringified_plans); + } else { + panic!("plan was not an explain: {:?}", plan); + } +} diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index ea6829969462..95623d45e467 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -96,6 +96,7 @@ pub mod udf; pub mod union; pub mod window; +mod explain; pub mod information_schema; #[cfg_attr(not(feature = "unicode_expressions"), ignore)] pub mod unicode; From 40df55f7a6d9b816eae4ec073736327009eb8d4f Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sat, 5 Feb 2022 13:51:13 +0800 Subject: [PATCH 30/50] use clap 3 style args parsing for datafusion cli (#1749) * use clap 3 style args parsing for datafusion cli * upgrade cli version --- datafusion-cli/Cargo.toml | 3 +- datafusion-cli/src/command.rs | 11 +- datafusion-cli/src/exec.rs | 10 +- datafusion-cli/src/functions.rs | 2 +- datafusion-cli/src/lib.rs | 1 - datafusion-cli/src/main.rs | 162 +++++++++++------------------ datafusion-cli/src/print_format.rs | 70 +------------ 7 files changed, 80 insertions(+), 179 deletions(-) diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index e8f1e3083014..06a7e873e383 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -17,7 +17,8 @@ [package] name = "datafusion-cli" -version = "5.1.0" +description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model. It supports executing SQL queries against CSV and Parquet files as well as querying directly against in-memory data." +version = "6.0.0" authors = ["Apache Arrow "] edition = "2021" keywords = [ "arrow", "datafusion", "ballista", "query", "sql" ] diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index ef6f67d69b66..0fd43a3071e5 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -20,7 +20,8 @@ use crate::context::Context; use crate::functions::{display_all_functions, Function}; use crate::print_format::PrintFormat; -use crate::print_options::{self, PrintOptions}; +use crate::print_options::PrintOptions; +use clap::ArgEnum; use datafusion::arrow::array::{ArrayRef, StringArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::record_batch::RecordBatch; @@ -206,10 +207,14 @@ impl OutputFormat { Self::ChangeFormat(format) => { if let Ok(format) = format.parse::() { print_options.format = format; - println!("Output format is {}.", print_options.format); + println!("Output format is {:?}.", print_options.format); Ok(()) } else { - Err(DataFusionError::Execution(format!("{} is not a valid format type [possible values: csv, tsv, table, json, ndjson]", format))) + Err(DataFusionError::Execution(format!( + "{:?} is not a valid format type [possible values: {:?}]", + format, + PrintFormat::value_variants() + ))) } } } diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index dad6d6eb559e..17b329b86d9b 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -21,20 +21,14 @@ use crate::{ command::{Command, OutputFormat}, context::Context, helper::CliHelper, - print_format::{all_print_formats, PrintFormat}, print_options::PrintOptions, }; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::arrow::util::pretty; -use datafusion::error::{DataFusionError, Result}; -use rustyline::config::Config; +use datafusion::error::Result; use rustyline::error::ReadlineError; use rustyline::Editor; use std::fs::File; use std::io::prelude::*; use 
std::io::BufReader; -use std::str::FromStr; -use std::sync::Arc; use std::time::Instant; /// run and execute SQL statements and commands from a file, against a context with the given print options @@ -109,7 +103,7 @@ pub async fn exec_from_repl(ctx: &mut Context, print_options: &mut PrintOptions) ); } } else { - println!("Output format is {}.", print_options.format); + println!("Output format is {:?}.", print_options.format); } } _ => { diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 2372e648d0f0..98b698ab5fb6 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -20,7 +20,7 @@ use arrow::array::StringArray; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use datafusion::error::{DataFusionError, Result}; +use datafusion::error::Result; use std::fmt; use std::str::FromStr; use std::sync::Arc; diff --git a/datafusion-cli/src/lib.rs b/datafusion-cli/src/lib.rs index b2bcdd3e48a6..b75be331259b 100644 --- a/datafusion-cli/src/lib.rs +++ b/datafusion-cli/src/lib.rs @@ -16,7 +16,6 @@ // under the License. #![doc = include_str!("../README.md")] -#![allow(unused_imports)] pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION"); pub mod command; diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 4cb9e9ddef14..788bb27f899a 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -15,14 +15,11 @@ // specific language governing permissions and limitations // under the License. -use clap::{crate_version, App, Arg}; +use clap::Parser; use datafusion::error::Result; use datafusion::execution::context::ExecutionConfig; use datafusion_cli::{ - context::Context, - exec, - print_format::{all_print_formats, PrintFormat}, - print_options::PrintOptions, + context::Context, exec, print_format::PrintFormat, print_options::PrintOptions, DATAFUSION_CLI_VERSION, }; use std::env; @@ -30,117 +27,84 @@ use std::fs::File; use std::io::BufReader; use std::path::Path; +#[derive(Debug, Parser, PartialEq)] +#[clap(author, version, about, long_about= None)] +struct Args { + #[clap( + short = 'p', + long, + help = "Path to your data, default to current directory", + validator(is_valid_data_dir) + )] + data_path: Option, + + #[clap( + short = 'c', + long, + help = "The batch size of each query, or use DataFusion default", + validator(is_valid_batch_size) + )] + batch_size: Option, + + #[clap( + short, + long, + multiple_values = true, + help = "Execute commands from file(s), then exit", + validator(is_valid_file) + )] + file: Vec, + + #[clap(long, arg_enum, default_value_t = PrintFormat::Table)] + format: PrintFormat, + + #[clap(long, help = "Ballista scheduler host")] + host: Option, + + #[clap(long, help = "Ballista scheduler port")] + port: Option, + + #[clap( + short, + long, + help = "Reduce printing other than the results and work quietly" + )] + quiet: bool, +} + #[tokio::main] pub async fn main() -> Result<()> { - let matches = App::new("DataFusion") - .version(crate_version!()) - .about( - "DataFusion is an in-memory query engine that uses Apache Arrow \ - as the memory model. 
It supports executing SQL queries against CSV and \ - Parquet files as well as querying directly against in-memory data.", - ) - .arg( - Arg::new("data-path") - .help("Path to your data, default to current directory") - .short('p') - .long("data-path") - .validator(is_valid_data_dir) - .takes_value(true), - ) - .arg( - Arg::new("batch-size") - .help("The batch size of each query, or use DataFusion default") - .short('c') - .long("batch-size") - .validator(is_valid_batch_size) - .takes_value(true), - ) - .arg( - Arg::new("file") - .help("Execute commands from file(s), then exit") - .short('f') - .long("file") - .multiple_occurrences(true) - .validator(is_valid_file) - .takes_value(true), - ) - .arg( - Arg::new("format") - .help("Output format") - .long("format") - .default_value("table") - .possible_values( - &all_print_formats() - .iter() - .map(|format| format.to_string()) - .collect::>() - .iter() - .map(|i| i.as_str()) - .collect::>(), - ) - .takes_value(true), - ) - .arg( - Arg::new("host") - .help("Ballista scheduler host") - .long("host") - .takes_value(true), - ) - .arg( - Arg::new("port") - .help("Ballista scheduler port") - .long("port") - .takes_value(true), - ) - .arg( - Arg::new("quiet") - .help("Reduce printing other than the results and work quietly") - .short('q') - .long("quiet") - .takes_value(false), - ) - .get_matches(); - - let quiet = matches.is_present("quiet"); - - if !quiet { - println!("DataFusion CLI v{}\n", DATAFUSION_CLI_VERSION); - } + let args = Args::parse(); - let host = matches.value_of("host"); - let port = matches - .value_of("port") - .and_then(|port| port.parse::().ok()); + if !args.quiet { + println!("DataFusion CLI v{}", DATAFUSION_CLI_VERSION); + } - if let Some(path) = matches.value_of("data-path") { + if let Some(ref path) = args.data_path { let p = Path::new(path); env::set_current_dir(&p).unwrap(); }; let mut execution_config = ExecutionConfig::new().with_information_schema(true); - if let Some(batch_size) = matches - .value_of("batch-size") - .and_then(|size| size.parse::().ok()) - { + if let Some(batch_size) = args.batch_size { execution_config = execution_config.with_batch_size(batch_size); }; - let mut ctx: Context = match (host, port) { - (Some(h), Some(p)) => Context::new_remote(h, p)?, + let mut ctx: Context = match (args.host, args.port) { + (Some(ref h), Some(p)) => Context::new_remote(h, p)?, _ => Context::new_local(&execution_config), }; - let format = matches - .value_of("format") - .expect("No format is specified") - .parse::() - .expect("Invalid format"); - - let mut print_options = PrintOptions { format, quiet }; + let mut print_options = PrintOptions { + format: args.format, + quiet: args.quiet, + }; - if let Some(file_paths) = matches.values_of("file") { - let files = file_paths + let files = args.file; + if !files.is_empty() { + let files = files + .into_iter() .map(|file_path| File::open(file_path).unwrap()) .collect::>(); for file in files { diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 0320166bcbe9..05a1ef7b10d8 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -21,11 +21,10 @@ use arrow::json::{ArrayWriter, LineDelimitedWriter}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::arrow::util::pretty; use datafusion::error::{DataFusionError, Result}; -use std::fmt; use std::str::FromStr; /// Allow records to be printed in different formats -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, clap::ArgEnum, 
Clone)] pub enum PrintFormat { Csv, Tsv, @@ -34,40 +33,11 @@ pub enum PrintFormat { NdJson, } -/// returns all print formats -pub fn all_print_formats() -> Vec { - vec![ - PrintFormat::Csv, - PrintFormat::Tsv, - PrintFormat::Table, - PrintFormat::Json, - PrintFormat::NdJson, - ] -} - impl FromStr for PrintFormat { - type Err = (); - fn from_str(s: &str) -> std::result::Result { - match s.to_lowercase().as_str() { - "csv" => Ok(Self::Csv), - "tsv" => Ok(Self::Tsv), - "table" => Ok(Self::Table), - "json" => Ok(Self::Json), - "ndjson" => Ok(Self::NdJson), - _ => Err(()), - } - } -} + type Err = String; -impl fmt::Display for PrintFormat { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Self::Csv => write!(f, "csv"), - Self::Tsv => write!(f, "tsv"), - Self::Table => write!(f, "table"), - Self::Json => write!(f, "json"), - Self::NdJson => write!(f, "ndjson"), - } + fn from_str(s: &str) -> std::result::Result { + clap::ArgEnum::from_str(s, true) } } @@ -123,38 +93,6 @@ mod tests { use datafusion::from_slice::FromSlice; use std::sync::Arc; - #[test] - fn test_from_str() { - let format = "csv".parse::().unwrap(); - assert_eq!(PrintFormat::Csv, format); - - let format = "tsv".parse::().unwrap(); - assert_eq!(PrintFormat::Tsv, format); - - let format = "json".parse::().unwrap(); - assert_eq!(PrintFormat::Json, format); - - let format = "ndjson".parse::().unwrap(); - assert_eq!(PrintFormat::NdJson, format); - - let format = "table".parse::().unwrap(); - assert_eq!(PrintFormat::Table, format); - } - - #[test] - fn test_to_str() { - assert_eq!("csv", PrintFormat::Csv.to_string()); - assert_eq!("table", PrintFormat::Table.to_string()); - assert_eq!("tsv", PrintFormat::Tsv.to_string()); - assert_eq!("json", PrintFormat::Json.to_string()); - assert_eq!("ndjson", PrintFormat::NdJson.to_string()); - } - - #[test] - fn test_from_str_failure() { - assert!("pretty".parse::().is_err()); - } - #[test] fn test_print_batches_with_sep() { let batches = vec![]; From e52f844427cc80fc6551851117266f5e2372ec6c Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 5 Feb 2022 06:10:35 -0500 Subject: [PATCH 31/50] Add partitioned_csv setup code to sql_integration test (#1743) --- datafusion/src/execution/context.rs | 239 +----------------------- datafusion/tests/sql/mod.rs | 1 + datafusion/tests/sql/partitioned_csv.rs | 95 ++++++++++ datafusion/tests/sql/projection.rs | 192 +++++++++++++++++++ datafusion/tests/sql/select.rs | 59 +++++- 5 files changed, 347 insertions(+), 239 deletions(-) create mode 100644 datafusion/tests/sql/partitioned_csv.rs diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index fb271a1a7e56..96e49c800f48 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1281,11 +1281,9 @@ mod tests { use super::*; use crate::execution::context::QueryPlanner; use crate::from_slice::FromSlice; - use crate::logical_plan::plan::Projection; - use crate::logical_plan::TableScan; use crate::logical_plan::{binary_expr, lit, Operator}; + use crate::physical_plan::collect; use crate::physical_plan::functions::{make_scalar_function, Volatility}; - use crate::physical_plan::{collect, collect_partitioned}; use crate::test; use crate::variable::VarType; use crate::{ @@ -1311,7 +1309,6 @@ mod tests { use std::thread::{self, JoinHandle}; use std::{io::prelude::*, sync::Mutex}; use tempfile::TempDir; - use test::*; #[tokio::test] async fn shared_memory_and_disk_manager() { @@ -1347,62 +1344,6 @@ mod tests { )); } - 
#[tokio::test] - async fn parallel_projection() -> Result<()> { - let partition_count = 4; - let results = execute("SELECT c1, c2 FROM test", partition_count).await?; - - let expected = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 3 | 1 |", - "| 3 | 2 |", - "| 3 | 3 |", - "| 3 | 4 |", - "| 3 | 5 |", - "| 3 | 6 |", - "| 3 | 7 |", - "| 3 | 8 |", - "| 3 | 9 |", - "| 3 | 10 |", - "| 2 | 1 |", - "| 2 | 2 |", - "| 2 | 3 |", - "| 2 | 4 |", - "| 2 | 5 |", - "| 2 | 6 |", - "| 2 | 7 |", - "| 2 | 8 |", - "| 2 | 9 |", - "| 2 | 10 |", - "| 1 | 1 |", - "| 1 | 2 |", - "| 1 | 3 |", - "| 1 | 4 |", - "| 1 | 5 |", - "| 1 | 6 |", - "| 1 | 7 |", - "| 1 | 8 |", - "| 1 | 9 |", - "| 1 | 10 |", - "| 0 | 1 |", - "| 0 | 2 |", - "| 0 | 3 |", - "| 0 | 4 |", - "| 0 | 5 |", - "| 0 | 6 |", - "| 0 | 7 |", - "| 0 | 8 |", - "| 0 | 9 |", - "| 0 | 10 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - #[tokio::test] async fn create_variable_expr() -> Result<()> { let tmp_dir = TempDir::new()?; @@ -1447,184 +1388,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn parallel_query_with_filter() -> Result<()> { - let tmp_dir = TempDir::new()?; - let partition_count = 4; - let ctx = create_ctx(&tmp_dir, partition_count).await?; - - let logical_plan = - ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")?; - let logical_plan = ctx.optimize(&logical_plan)?; - - let physical_plan = ctx.create_physical_plan(&logical_plan).await?; - - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect_partitioned(physical_plan, runtime).await?; - - // note that the order of partitions is not deterministic - let mut num_rows = 0; - for partition in &results { - for batch in partition { - num_rows += batch.num_rows(); - } - } - assert_eq!(20, num_rows); - - let results: Vec = results.into_iter().flatten().collect(); - let expected = vec![ - "+----+----+", - "| c1 | c2 |", - "+----+----+", - "| 1 | 1 |", - "| 1 | 10 |", - "| 1 | 2 |", - "| 1 | 3 |", - "| 1 | 4 |", - "| 1 | 5 |", - "| 1 | 6 |", - "| 1 | 7 |", - "| 1 | 8 |", - "| 1 | 9 |", - "| 2 | 1 |", - "| 2 | 10 |", - "| 2 | 2 |", - "| 2 | 3 |", - "| 2 | 4 |", - "| 2 | 5 |", - "| 2 | 6 |", - "| 2 | 7 |", - "| 2 | 8 |", - "| 2 | 9 |", - "+----+----+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) - } - - #[tokio::test] - async fn projection_on_table_scan() -> Result<()> { - let tmp_dir = TempDir::new()?; - let partition_count = 4; - let ctx = create_ctx(&tmp_dir, partition_count).await?; - let runtime = ctx.state.lock().runtime_env.clone(); - - let table = ctx.table("test")?; - let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) - .project(vec![col("c2")])? - .build()?; - - let optimized_plan = ctx.optimize(&logical_plan)?; - match &optimized_plan { - LogicalPlan::Projection(Projection { input, .. }) => match &**input { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. 
- }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be TableScan"), - }, - _ => panic!("expect optimized_plan to be projection"), - } - - let expected = "Projection: #test.c2\ - \n TableScan: test projection=Some([1])"; - assert_eq!(format!("{:?}", optimized_plan), expected); - - let physical_plan = ctx.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - - let batches = collect(physical_plan, runtime).await?; - assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::()); - - Ok(()) - } - - #[tokio::test] - async fn preserve_nullability_on_projection() -> Result<()> { - let tmp_dir = TempDir::new()?; - let ctx = create_ctx(&tmp_dir, 1).await?; - - let schema: Schema = ctx.table("test").unwrap().schema().clone().into(); - assert!(!schema.field_with_name("c1")?.is_nullable()); - - let plan = LogicalPlanBuilder::scan_empty(None, &schema, None)? - .project(vec![col("c1")])? - .build()?; - - let plan = ctx.optimize(&plan)?; - let physical_plan = ctx.create_physical_plan(&Arc::new(plan)).await?; - assert!(!physical_plan.schema().field_with_name("c1")?.is_nullable()); - Ok(()) - } - - #[tokio::test] - async fn projection_on_memory_scan() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ]); - let schema = SchemaRef::new(schema); - - let partitions = vec![vec![RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])), - Arc::new(Int32Array::from_slice(&[2, 12, 12, 120])), - Arc::new(Int32Array::from_slice(&[3, 12, 12, 120])), - ], - )?]]; - - let plan = LogicalPlanBuilder::scan_memory(partitions, schema, None)? - .project(vec![col("b")])? - .build()?; - assert_fields_eq(&plan, vec!["b"]); - - let ctx = ExecutionContext::new(); - let optimized_plan = ctx.optimize(&plan)?; - match &optimized_plan { - LogicalPlan::Projection(Projection { input, .. }) => match &**input { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. 
- }) => { - assert_eq!(source.schema().fields().len(), 3); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be InMemoryScan"), - }, - _ => panic!("expect optimized_plan to be projection"), - } - - let expected = format!( - "Projection: #{}.b\ - \n TableScan: {} projection=Some([1])", - UNNAMED_TABLE, UNNAMED_TABLE - ); - assert_eq!(format!("{:?}", optimized_plan), expected); - - let physical_plan = ctx.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("b", physical_plan.schema().field(0).name().as_str()); - - let runtime = ctx.state.lock().runtime_env.clone(); - let batches = collect(physical_plan, runtime).await?; - assert_eq!(1, batches.len()); - assert_eq!(1, batches[0].num_columns()); - assert_eq!(4, batches[0].num_rows()); - - Ok(()) - } - #[tokio::test] async fn sort() -> Result<()> { let results = diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index 95623d45e467..468762ea05bb 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -98,6 +98,7 @@ pub mod window; mod explain; pub mod information_schema; +mod partitioned_csv; #[cfg_attr(not(feature = "unicode_expressions"), ignore)] pub mod unicode; diff --git a/datafusion/tests/sql/partitioned_csv.rs b/datafusion/tests/sql/partitioned_csv.rs new file mode 100644 index 000000000000..5efc837d5c95 --- /dev/null +++ b/datafusion/tests/sql/partitioned_csv.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Utility functions for running with a partitioned csv dataset: + +use std::{io::Write, sync::Arc}; + +use arrow::{ + datatypes::{DataType, Field, Schema, SchemaRef}, + record_batch::RecordBatch, +}; +use datafusion::{ + error::Result, + prelude::{CsvReadOptions, ExecutionConfig, ExecutionContext}, +}; +use tempfile::TempDir; + +/// Execute SQL and return results +async fn plan_and_collect( + ctx: &mut ExecutionContext, + sql: &str, +) -> Result> { + ctx.sql(sql).await?.collect().await +} + +/// Execute SQL and return results +pub async fn execute(sql: &str, partition_count: usize) -> Result> { + let tmp_dir = TempDir::new()?; + let mut ctx = create_ctx(&tmp_dir, partition_count).await?; + plan_and_collect(&mut ctx, sql).await +} + +/// Generate CSV partitions within the supplied directory +fn populate_csv_partitions( + tmp_dir: &TempDir, + partition_count: usize, + file_extension: &str, +) -> Result { + // define schema for data source (csv file) + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::UInt32, false), + Field::new("c2", DataType::UInt64, false), + Field::new("c3", DataType::Boolean, false), + ])); + + // generate a partitioned file + for partition in 0..partition_count { + let filename = format!("partition-{}.{}", partition, file_extension); + let file_path = tmp_dir.path().join(&filename); + let mut file = std::fs::File::create(file_path)?; + + // generate some data + for i in 0..=10 { + let data = format!("{},{},{}\n", partition, i, i % 2 == 0); + file.write_all(data.as_bytes())?; + } + } + + Ok(schema) +} + +/// Generate a partitioned CSV file and register it with an execution context +pub async fn create_ctx( + tmp_dir: &TempDir, + partition_count: usize, +) -> Result { + let mut ctx = + ExecutionContext::with_config(ExecutionConfig::new().with_target_partitions(8)); + + let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; + + // register csv file with the execution context + ctx.register_csv( + "test", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new().schema(&schema), + ) + .await?; + + Ok(ctx) +} diff --git a/datafusion/tests/sql/projection.rs b/datafusion/tests/sql/projection.rs index 57fa598bb754..0a956a9411eb 100644 --- a/datafusion/tests/sql/projection.rs +++ b/datafusion/tests/sql/projection.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. 
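// A minimal usage sketch of the `partitioned_csv` helpers above. The test name and
// query are illustrative assumptions, not part of the patch; the expected 22 rows
// follow from the `0..=10` loop in `populate_csv_partitions` (11 rows per partition
// times 2 partitions).
#[tokio::test]
async fn partitioned_csv_smoke() -> Result<()> {
    // create 2 CSV partitions in a temp dir, register them as table `test`,
    // and run a query through the shared helper
    let batches = partitioned_csv::execute("SELECT c1, c2 FROM test", 2).await?;
    let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
    assert_eq!(row_count, 22);
    Ok(())
}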
+use datafusion::logical_plan::{LogicalPlanBuilder, UNNAMED_TABLE}; +use tempfile::TempDir; + use super::*; #[tokio::test] @@ -73,3 +76,192 @@ async fn csv_query_group_by_avg_with_projection() -> Result<()> { assert_batches_sorted_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn parallel_projection() -> Result<()> { + let partition_count = 4; + let results = + partitioned_csv::execute("SELECT c1, c2 FROM test", partition_count).await?; + + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 3 | 1 |", + "| 3 | 2 |", + "| 3 | 3 |", + "| 3 | 4 |", + "| 3 | 5 |", + "| 3 | 6 |", + "| 3 | 7 |", + "| 3 | 8 |", + "| 3 | 9 |", + "| 3 | 10 |", + "| 2 | 1 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "| 2 | 10 |", + "| 1 | 1 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 1 | 10 |", + "| 0 | 1 |", + "| 0 | 2 |", + "| 0 | 3 |", + "| 0 | 4 |", + "| 0 | 5 |", + "| 0 | 6 |", + "| 0 | 7 |", + "| 0 | 8 |", + "| 0 | 9 |", + "| 0 | 10 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn projection_on_table_scan() -> Result<()> { + let tmp_dir = TempDir::new()?; + let partition_count = 4; + let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; + let runtime = ctx.state.lock().runtime_env.clone(); + + let table = ctx.table("test")?; + let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) + .project(vec![col("c2")])? + .build()?; + + let optimized_plan = ctx.optimize(&logical_plan)?; + match &optimized_plan { + LogicalPlan::Projection(Projection { input, .. }) => match &**input { + LogicalPlan::TableScan(TableScan { + source, + projected_schema, + .. + }) => { + assert_eq!(source.schema().fields().len(), 3); + assert_eq!(projected_schema.fields().len(), 1); + } + _ => panic!("input to projection should be TableScan"), + }, + _ => panic!("expect optimized_plan to be projection"), + } + + let expected = "Projection: #test.c2\ + \n TableScan: test projection=Some([1])"; + assert_eq!(format!("{:?}", optimized_plan), expected); + + let physical_plan = ctx.create_physical_plan(&optimized_plan).await?; + + assert_eq!(1, physical_plan.schema().fields().len()); + assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); + + let batches = collect(physical_plan, runtime).await?; + assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::()); + + Ok(()) +} + +#[tokio::test] +async fn preserve_nullability_on_projection() -> Result<()> { + let tmp_dir = TempDir::new()?; + let ctx = partitioned_csv::create_ctx(&tmp_dir, 1).await?; + + let schema: Schema = ctx.table("test").unwrap().schema().clone().into(); + assert!(!schema.field_with_name("c1")?.is_nullable()); + + let plan = LogicalPlanBuilder::scan_empty(None, &schema, None)? + .project(vec![col("c1")])? 
+ .build()?; + + let plan = ctx.optimize(&plan)?; + let physical_plan = ctx.create_physical_plan(&Arc::new(plan)).await?; + assert!(!physical_plan.schema().field_with_name("c1")?.is_nullable()); + Ok(()) +} + +#[tokio::test] +async fn projection_on_memory_scan() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + let schema = SchemaRef::new(schema); + + let partitions = vec![vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])), + Arc::new(Int32Array::from_slice(&[2, 12, 12, 120])), + Arc::new(Int32Array::from_slice(&[3, 12, 12, 120])), + ], + )?]]; + + let plan = LogicalPlanBuilder::scan_memory(partitions, schema, None)? + .project(vec![col("b")])? + .build()?; + assert_fields_eq(&plan, vec!["b"]); + + let ctx = ExecutionContext::new(); + let optimized_plan = ctx.optimize(&plan)?; + match &optimized_plan { + LogicalPlan::Projection(Projection { input, .. }) => match &**input { + LogicalPlan::TableScan(TableScan { + source, + projected_schema, + .. + }) => { + assert_eq!(source.schema().fields().len(), 3); + assert_eq!(projected_schema.fields().len(), 1); + } + _ => panic!("input to projection should be InMemoryScan"), + }, + _ => panic!("expect optimized_plan to be projection"), + } + + let expected = format!( + "Projection: #{}.b\ + \n TableScan: {} projection=Some([1])", + UNNAMED_TABLE, UNNAMED_TABLE + ); + assert_eq!(format!("{:?}", optimized_plan), expected); + + let physical_plan = ctx.create_physical_plan(&optimized_plan).await?; + + assert_eq!(1, physical_plan.schema().fields().len()); + assert_eq!("b", physical_plan.schema().field(0).name().as_str()); + + let runtime = ctx.state.lock().runtime_env.clone(); + let batches = collect(physical_plan, runtime).await?; + assert_eq!(1, batches.len()); + assert_eq!(1, batches[0].num_columns()); + assert_eq!(4, batches[0].num_rows()); + + Ok(()) +} + +fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { + let actual: Vec = plan + .schema() + .fields() + .iter() + .map(|f| f.name().clone()) + .collect(); + assert_eq!(actual, expected); +} diff --git a/datafusion/tests/sql/select.rs b/datafusion/tests/sql/select.rs index 759a45c9fca9..02869dd99d2b 100644 --- a/datafusion/tests/sql/select.rs +++ b/datafusion/tests/sql/select.rs @@ -16,7 +16,8 @@ // under the License. 
use super::*; -use datafusion::from_slice::FromSlice; +use datafusion::{from_slice::FromSlice, physical_plan::collect_partitioned}; +use tempfile::TempDir; #[tokio::test] async fn all_where_empty() -> Result<()> { @@ -928,3 +929,59 @@ async fn csv_select_nested() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn parallel_query_with_filter() -> Result<()> { + let tmp_dir = TempDir::new()?; + let partition_count = 4; + let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; + + let logical_plan = + ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")?; + let logical_plan = ctx.optimize(&logical_plan)?; + + let physical_plan = ctx.create_physical_plan(&logical_plan).await?; + + let runtime = ctx.state.lock().runtime_env.clone(); + let results = collect_partitioned(physical_plan, runtime).await?; + + // note that the order of partitions is not deterministic + let mut num_rows = 0; + for partition in &results { + for batch in partition { + num_rows += batch.num_rows(); + } + } + assert_eq!(20, num_rows); + + let results: Vec = results.into_iter().flatten().collect(); + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | 1 |", + "| 1 | 10 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 2 | 1 |", + "| 2 | 10 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} From 4f4153ba26a1b0141b2ee36fd521607389d9f611 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 5 Feb 2022 10:26:42 -0700 Subject: [PATCH 32/50] use ordered-float 2.10 (#1756) Signed-off-by: Andy Grove --- datafusion/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 54247cbcf07c..81e2bb14877b 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -68,7 +68,7 @@ md-5 = { version = "^0.10.0", optional = true } sha2 = { version = "^0.10.1", optional = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } -ordered-float = "2.0" +ordered-float = "2.10" unicode-segmentation = { version = "^1.7.1", optional = true } regex = { version = "^1.4.3", optional = true } lazy_static = { version = "^1.4.0" } From f139ef812590dbd65292640499f6e09736167765 Mon Sep 17 00:00:00 2001 From: Rich Date: Mon, 7 Feb 2022 03:00:33 -0500 Subject: [PATCH 33/50] #1768 Support TimeUnit::Second in hasher (#1769) * Support TimeUnit::Second in hasher * fix linter --- datafusion/src/physical_plan/hash_utils.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index 5f7a610db075..27a5376cf749 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -23,8 +23,8 @@ use arrow::array::{ Array, ArrayRef, BooleanArray, Date32Array, Date64Array, DecimalArray, DictionaryArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + UInt16Array, UInt32Array, UInt64Array, 
UInt8Array, }; use arrow::datatypes::{ ArrowDictionaryKeyType, ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, @@ -387,6 +387,16 @@ pub fn create_hashes<'a>( multi_col ); } + DataType::Timestamp(TimeUnit::Second, None) => { + hash_array_primitive!( + TimestampSecondArray, + col, + i64, + hashes_buffer, + random_state, + multi_col + ); + } DataType::Timestamp(TimeUnit::Millisecond, None) => { hash_array_primitive!( TimestampMillisecondArray, From 31d0adf897a59739b587f634a93bbf9a1be9092d Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Mon, 7 Feb 2022 19:53:58 +0800 Subject: [PATCH 34/50] format (#1745) --- README.md | 356 +----------------- docs/source/index.rst | 1 + .../source/specification/quarterly_roadmap.md | 72 ++++ docs/source/user-guide/sql/index.rst | 1 + docs/source/user-guide/sql/sql_status.md | 241 ++++++++++++ 5 files changed, 323 insertions(+), 348 deletions(-) create mode 100644 docs/source/specification/quarterly_roadmap.md create mode 100644 docs/source/user-guide/sql/sql_status.md diff --git a/README.md b/README.md index 1ea972cc4df1..dc350f69bb9c 100644 --- a/README.md +++ b/README.md @@ -73,363 +73,23 @@ Here are some of the projects known to use DataFusion: ## Example Usage -Run a SQL query against data stored in a CSV: +Please see [example usage](https://arrow.apache.org/datafusion/user-guide/example-usage.html) to find how to use DataFusion. -```rust -use datafusion::prelude::*; -use datafusion::arrow::util::pretty::print_batches; -use datafusion::arrow::record_batch::RecordBatch; - -#[tokio::main] -async fn main() -> datafusion::error::Result<()> { - // register the table - let mut ctx = ExecutionContext::new(); - ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await?; - - // create a plan to run a SQL query - let df = ctx.sql("SELECT a, MIN(b) FROM example GROUP BY a LIMIT 100").await?; - - // execute and print results - df.show().await?; - Ok(()) -} -``` - -Use the DataFrame API to process data stored in a CSV: - -```rust -use datafusion::prelude::*; -use datafusion::arrow::util::pretty::print_batches; -use datafusion::arrow::record_batch::RecordBatch; - -#[tokio::main] -async fn main() -> datafusion::error::Result<()> { - // create the dataframe - let mut ctx = ExecutionContext::new(); - let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; - - let df = df.filter(col("a").lt_eq(col("b")))? - .aggregate(vec![col("a")], vec![min(col("b"))])?; - - // execute and print results - df.show_limit(100).await?; - Ok(()) -} -``` - -Both of these examples will produce - -```text -+---+--------+ -| a | MIN(b) | -+---+--------+ -| 1 | 2 | -+---+--------+ -``` - -## Using DataFusion as a library - -DataFusion is [published on crates.io](https://crates.io/crates/datafusion), and is [well documented on docs.rs](https://docs.rs/datafusion/). - -To get started, add the following to your `Cargo.toml` file: - -```toml -[dependencies] -datafusion = "6.0.0" -``` - -## Using DataFusion as a binary - -DataFusion also includes a simple command-line interactive SQL utility. See the [CLI reference](https://arrow.apache.org/datafusion/cli/index.html) for more information. - -# Roadmap - -A quarterly roadmap will be published to give the DataFusion community visibility into the priorities of the projects contributors. This roadmap is not binding. - -## 2022 Q1 - -### DataFusion Core - -- Publish official Arrow2 branch -- Implementation of memory manager (i.e. 
to enable spilling to disk as needed) - -### Benchmarking - -- Inclusion in Db-Benchmark with all quries covered -- All TPCH queries covered - -### Performance Improvements - -- Predicate evaluation -- Improve multi-column comparisons (that can't be vectorized at the moment) -- Null constant support - -### New Features - -- Read JSON as table -- Simplify DDL with Datafusion-Cli -- Add Decimal128 data type and the attendant features such as Arrow Kernel and UDF support -- Add new experimental e-graph based optimizer - -### Ballista - -- Begin work on design documents and plan / priorities for development - -### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib])) - -- Stable S3 support -- Begin design discussions and prototyping of a stream provider - -## Beyond 2022 Q1 - -There is no clear timeline for the below, but community members have expressed interest in working on these topics. - -### DataFusion Core - -- Custom SQL support -- Split DataFusion into multiple crates -- Push based query execution and code generation - -### Ballista - -- Evolve architecture so that it can be deployed in a multi-tenant cloud native environment -- Ensure Ballista is scalable, elastic, and stable for production usage -- Develop distributed ML capabilities - -# Status - -## General - -- [x] SQL Parser -- [x] SQL Query Planner -- [x] Query Optimizer -- [x] Constant folding -- [x] Join Reordering -- [x] Limit Pushdown -- [x] Projection push down -- [x] Predicate push down -- [x] Type coercion -- [x] Parallel query execution - -## SQL Support - -- [x] Projection -- [x] Filter (WHERE) -- [x] Filter post-aggregate (HAVING) -- [x] Limit -- [x] Aggregate -- [x] Common math functions -- [x] cast -- [x] try_cast -- [x] [`VALUES` lists](https://www.postgresql.org/docs/current/queries-values.html) -- Postgres compatible String functions - - [x] ascii - - [x] bit_length - - [x] btrim - - [x] char_length - - [x] character_length - - [x] chr - - [x] concat - - [x] concat_ws - - [x] initcap - - [x] left - - [x] length - - [x] lpad - - [x] ltrim - - [x] octet_length - - [x] regexp_replace - - [x] repeat - - [x] replace - - [x] reverse - - [x] right - - [x] rpad - - [x] rtrim - - [x] split_part - - [x] starts_with - - [x] strpos - - [x] substr - - [x] to_hex - - [x] translate - - [x] trim -- Miscellaneous/Boolean functions - - [x] nullif -- Approximation functions - - [x] approx_distinct -- Common date/time functions - - [ ] Basic date functions - - [ ] Basic time functions - - [x] Basic timestamp functions - - [x] [to_timestamp](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp) - - [x] [to_timestamp_millis](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_millis) - - [x] [to_timestamp_micros](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_micros) - - [x] [to_timestamp_seconds](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_seconds) -- nested functions - - [x] Array of columns -- [x] Schema Queries - - [x] SHOW TABLES - - [x] SHOW COLUMNS - - [x] information_schema.{tables, columns} - - [ ] information_schema other views -- [x] Sorting -- [ ] Nested types -- [ ] Lists -- [x] Subqueries -- [x] Common table expressions -- [x] Set Operations - - [x] UNION ALL - - [x] UNION - - [x] INTERSECT - - [x] INTERSECT ALL - - [x] EXCEPT - - [x] EXCEPT ALL -- [x] Joins - - [x] INNER JOIN - - [x] LEFT JOIN - - [x] RIGHT JOIN - - [x] FULL JOIN - - [x] CROSS JOIN -- [ ] Window - - [x] Empty window - - [x] Common window functions - - [x] Window with 
PARTITION BY clause - - [x] Window with ORDER BY clause - - [ ] Window with FILTER clause - - [ ] [Window with custom WINDOW FRAME](https://github.com/apache/arrow-datafusion/issues/361) - - [ ] UDF and UDAF for window functions - -## Data Sources - -- [x] CSV -- [x] Parquet primitive types -- [ ] Parquet nested types - -## Extensibility - -DataFusion is designed to be extensible at all points. To that end, you can provide your own custom: - -- [x] User Defined Functions (UDFs) -- [x] User Defined Aggregate Functions (UDAFs) -- [x] User Defined Table Source (`TableProvider`) for tables -- [x] User Defined `Optimizer` passes (plan rewrites) -- [x] User Defined `LogicalPlan` nodes -- [x] User Defined `ExecutionPlan` nodes - -## Rust Version Compatbility - -This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. - -# Supported SQL - -This library currently supports many SQL constructs, including - -- `CREATE EXTERNAL TABLE X STORED AS PARQUET LOCATION '...';` to register a table's locations -- `SELECT ... FROM ...` together with any expression -- `ALIAS` to name an expression -- `CAST` to change types, including e.g. `Timestamp(Nanosecond, None)` -- Many mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`. -- `WHERE` to filter -- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population) -- `ORDER BY` together with an expression and optional `ASC` or `DESC` and also optional `NULLS FIRST` or `NULLS LAST` - -## Supported Functions - -DataFusion strives to implement a subset of the [PostgreSQL SQL dialect](https://www.postgresql.org/docs/current/functions.html) where possible. We explicitly choose a single dialect to maximize interoperability with other tools and allow reuse of the PostgreSQL documents and tutorials as much as possible. - -Currently, only a subset of the PostgreSQL dialect is implemented, and we will document any deviations. - -## Schema Metadata / Information Schema Support - -DataFusion supports the showing metadata about the tables available. This information can be accessed using the views of the ISO SQL `information_schema` schema or the DataFusion specific `SHOW TABLES` and `SHOW COLUMNS` commands. - -More information can be found in the [Postgres docs](https://www.postgresql.org/docs/13/infoschema-schema.html)). 
- -To show tables available for use in DataFusion, use the `SHOW TABLES` command or the `information_schema.tables` view: - -```sql -> show tables; -+---------------+--------------------+------------+------------+ -| table_catalog | table_schema | table_name | table_type | -+---------------+--------------------+------------+------------+ -| datafusion | public | t | BASE TABLE | -| datafusion | information_schema | tables | VIEW | -+---------------+--------------------+------------+------------+ - -> select * from information_schema.tables; - -+---------------+--------------------+------------+--------------+ -| table_catalog | table_schema | table_name | table_type | -+---------------+--------------------+------------+--------------+ -| datafusion | public | t | BASE TABLE | -| datafusion | information_schema | TABLES | SYSTEM TABLE | -+---------------+--------------------+------------+--------------+ -``` - -To show the schema of a table in DataFusion, use the `SHOW COLUMNS` command or the or `information_schema.columns` view: - -```sql -> show columns from t; -+---------------+--------------+------------+-------------+-----------+-------------+ -| table_catalog | table_schema | table_name | column_name | data_type | is_nullable | -+---------------+--------------+------------+-------------+-----------+-------------+ -| datafusion | public | t | a | Int32 | NO | -| datafusion | public | t | b | Utf8 | NO | -| datafusion | public | t | c | Float32 | NO | -+---------------+--------------+------------+-------------+-----------+-------------+ - -> select table_name, column_name, ordinal_position, is_nullable, data_type from information_schema.columns; -+------------+-------------+------------------+-------------+-----------+ -| table_name | column_name | ordinal_position | is_nullable | data_type | -+------------+-------------+------------------+-------------+-----------+ -| t | a | 0 | NO | Int32 | -| t | b | 1 | NO | Utf8 | -| t | c | 2 | NO | Float32 | -+------------+-------------+------------------+-------------+-----------+ -``` - -## Supported Data Types - -DataFusion uses Arrow, and thus the Arrow type system, for query -execution. The SQL types from -[sqlparser-rs](https://github.com/ballista-compute/sqlparser-rs/blob/main/src/ast/data_type.rs#L57) -are mapped to Arrow types according to the following table - -| SQL Data Type | Arrow DataType | -| ------------- | --------------------------------- | -| `CHAR` | `Utf8` | -| `VARCHAR` | `Utf8` | -| `UUID` | _Not yet supported_ | -| `CLOB` | _Not yet supported_ | -| `BINARY` | _Not yet supported_ | -| `VARBINARY` | _Not yet supported_ | -| `DECIMAL` | `Float64` | -| `FLOAT` | `Float32` | -| `SMALLINT` | `Int16` | -| `INT` | `Int32` | -| `BIGINT` | `Int64` | -| `REAL` | `Float32` | -| `DOUBLE` | `Float64` | -| `BOOLEAN` | `Boolean` | -| `DATE` | `Date32` | -| `TIME` | `Time64(TimeUnit::Millisecond)` | -| `TIMESTAMP` | `Timestamp(TimeUnit::Nanosecond)` | -| `INTERVAL` | _Not yet supported_ | -| `REGCLASS` | _Not yet supported_ | -| `TEXT` | _Not yet supported_ | -| `BYTEA` | _Not yet supported_ | -| `CUSTOM` | _Not yet supported_ | -| `ARRAY` | _Not yet supported_ | - -# Roadmap +## Roadmap Please see [Roadmap](docs/source/specification/roadmap.md) for information of where the project is headed. -# Architecture Overview +## Architecture Overview There is no formal document describing DataFusion's architecture yet, but the following presentations offer a good overview of its different components and how they interact together. 
- (March 2021): The DataFusion architecture is described in _Query Engine Design and the Rust-Based DataFusion in Apache Arrow_: [recording](https://www.youtube.com/watch?v=K6eCAVEk4kU) (DataFusion content starts [~ 15 minutes in](https://www.youtube.com/watch?v=K6eCAVEk4kU&t=875s)) and [slides](https://www.slideshare.net/influxdata/influxdb-iox-tech-talks-query-engine-design-and-the-rustbased-datafusion-in-apache-arrow-244161934) - (February 2021): How DataFusion is used within the Ballista Project is described in \*Ballista: Distributed Compute with Rust and Apache Arrow: [recording](https://www.youtube.com/watch?v=ZZHQaOap9pQ) -# Developer's guide +## User's guide + +Please see [User Guide](https://arrow.apache.org/datafusion/) for more information about DataFusion. + +## Developer's guide Please see [Developers Guide](DEVELOPERS.md) for information about developing DataFusion. diff --git a/docs/source/index.rst b/docs/source/index.rst index bf6b25096b4b..5109e60338fa 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -55,6 +55,7 @@ Table of content specification/roadmap specification/invariants specification/output-field-name-semantic + specification/quarterly_roadmap .. _toc.readme: diff --git a/docs/source/specification/quarterly_roadmap.md b/docs/source/specification/quarterly_roadmap.md new file mode 100644 index 000000000000..5bb805d7e7f0 --- /dev/null +++ b/docs/source/specification/quarterly_roadmap.md @@ -0,0 +1,72 @@ + + +# Roadmap + +A quarterly roadmap will be published to give the DataFusion community visibility into the priorities of the projects contributors. This roadmap is not binding. + +## 2022 Q1 + +### DataFusion Core + +- Publish official Arrow2 branch +- Implementation of memory manager (i.e. to enable spilling to disk as needed) + +### Benchmarking + +- Inclusion in Db-Benchmark with all quries covered +- All TPCH queries covered + +### Performance Improvements + +- Predicate evaluation +- Improve multi-column comparisons (that can't be vectorized at the moment) +- Null constant support + +### New Features + +- Read JSON as table +- Simplify DDL with Datafusion-Cli +- Add Decimal128 data type and the attendant features such as Arrow Kernel and UDF support +- Add new experimental e-graph based optimizer + +### Ballista + +- Begin work on design documents and plan / priorities for development + +### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib])) + +- Stable S3 support +- Begin design discussions and prototyping of a stream provider + +## Beyond 2022 Q1 + +There is no clear timeline for the below, but community members have expressed interest in working on these topics. + +### DataFusion Core + +- Custom SQL support +- Split DataFusion into multiple crates +- Push based query execution and code generation + +### Ballista + +- Evolve architecture so that it can be deployed in a multi-tenant cloud native environment +- Ensure Ballista is scalable, elastic, and stable for production usage +- Develop distributed ML capabilities diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index 2489f6ba1f10..fc96acc8733c 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -21,6 +21,7 @@ SQL Reference .. 
toctree:: :maxdepth: 2 + sql_status select ddl DataFusion Functions diff --git a/docs/source/user-guide/sql/sql_status.md b/docs/source/user-guide/sql/sql_status.md new file mode 100644 index 000000000000..0df14e58a8be --- /dev/null +++ b/docs/source/user-guide/sql/sql_status.md @@ -0,0 +1,241 @@ + + +# Status + +## General + +- [x] SQL Parser +- [x] SQL Query Planner +- [x] Query Optimizer +- [x] Constant folding +- [x] Join Reordering +- [x] Limit Pushdown +- [x] Projection push down +- [x] Predicate push down +- [x] Type coercion +- [x] Parallel query execution + +## SQL Support + +- [x] Projection +- [x] Filter (WHERE) +- [x] Filter post-aggregate (HAVING) +- [x] Limit +- [x] Aggregate +- [x] Common math functions +- [x] cast +- [x] try_cast +- [x] [`VALUES` lists](https://www.postgresql.org/docs/current/queries-values.html) +- Postgres compatible String functions + - [x] ascii + - [x] bit_length + - [x] btrim + - [x] char_length + - [x] character_length + - [x] chr + - [x] concat + - [x] concat_ws + - [x] initcap + - [x] left + - [x] length + - [x] lpad + - [x] ltrim + - [x] octet_length + - [x] regexp_replace + - [x] repeat + - [x] replace + - [x] reverse + - [x] right + - [x] rpad + - [x] rtrim + - [x] split_part + - [x] starts_with + - [x] strpos + - [x] substr + - [x] to_hex + - [x] translate + - [x] trim +- Miscellaneous/Boolean functions + - [x] nullif +- Approximation functions + - [x] approx_distinct +- Common date/time functions + - [ ] Basic date functions + - [ ] Basic time functions + - [x] Basic timestamp functions + - [x] [to_timestamp](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp) + - [x] [to_timestamp_millis](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_millis) + - [x] [to_timestamp_micros](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_micros) + - [x] [to_timestamp_seconds](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_seconds) +- nested functions + - [x] Array of columns +- [x] Schema Queries + - [x] SHOW TABLES + - [x] SHOW COLUMNS + - [x] information_schema.{tables, columns} + - [ ] information_schema other views +- [x] Sorting +- [ ] Nested types +- [ ] Lists +- [x] Subqueries +- [x] Common table expressions +- [x] Set Operations + - [x] UNION ALL + - [x] UNION + - [x] INTERSECT + - [x] INTERSECT ALL + - [x] EXCEPT + - [x] EXCEPT ALL +- [x] Joins + - [x] INNER JOIN + - [x] LEFT JOIN + - [x] RIGHT JOIN + - [x] FULL JOIN + - [x] CROSS JOIN +- [ ] Window + - [x] Empty window + - [x] Common window functions + - [x] Window with PARTITION BY clause + - [x] Window with ORDER BY clause + - [ ] Window with FILTER clause + - [ ] [Window with custom WINDOW FRAME](https://github.com/apache/arrow-datafusion/issues/361) + - [ ] UDF and UDAF for window functions + +## Data Sources + +- [x] CSV +- [x] Parquet primitive types +- [ ] Parquet nested types + +## Extensibility + +DataFusion is designed to be extensible at all points. To that end, you can provide your own custom: + +- [x] User Defined Functions (UDFs) +- [x] User Defined Aggregate Functions (UDAFs) +- [x] User Defined Table Source (`TableProvider`) for tables +- [x] User Defined `Optimizer` passes (plan rewrites) +- [x] User Defined `LogicalPlan` nodes +- [x] User Defined `ExecutionPlan` nodes + +## Rust Version Compatbility + +This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. 
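As an illustration of the UDF extension point listed above, the sketch below registers a scalar function with an `ExecutionContext`. The function name, body, and types are illustrative assumptions, not something defined by this status list.

```rust
use std::sync::Arc;

use datafusion::arrow::array::{Array, ArrayRef, Float64Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_plan::create_udf;
use datafusion::physical_plan::functions::{make_scalar_function, Volatility};
use datafusion::prelude::ExecutionContext;

/// Register a hypothetical `double` UDF that multiplies a Float64 column by 2.
fn register_double_udf(ctx: &mut ExecutionContext) {
    // Wrap a plain Rust function over Arrow arrays; `make_scalar_function`
    // adapts it to DataFusion's columnar calling convention.
    let double = make_scalar_function(|args: &[ArrayRef]| {
        let input = args[0]
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("double() expects a Float64 argument");
        let doubled: Float64Array = input.iter().map(|v| v.map(|x| x * 2.0)).collect();
        Ok(Arc::new(doubled) as ArrayRef)
    });

    let udf = create_udf(
        "double",                    // name used in SQL and the expression API
        vec![DataType::Float64],     // argument types
        Arc::new(DataType::Float64), // return type
        Volatility::Immutable,       // same input always yields the same output
        double,
    );
    ctx.register_udf(udf);
}
```

Once registered, `double` can be called by name in SQL queries run through that context, or used through the DataFrame/expression API.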
+ +# Supported SQL + +This library currently supports many SQL constructs, including + +- `CREATE EXTERNAL TABLE X STORED AS PARQUET LOCATION '...';` to register a table's locations +- `SELECT ... FROM ...` together with any expression +- `ALIAS` to name an expression +- `CAST` to change types, including e.g. `Timestamp(Nanosecond, None)` +- Many mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`. +- `WHERE` to filter +- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population) +- `ORDER BY` together with an expression and optional `ASC` or `DESC` and also optional `NULLS FIRST` or `NULLS LAST` + +## Supported Functions + +DataFusion strives to implement a subset of the [PostgreSQL SQL dialect](https://www.postgresql.org/docs/current/functions.html) where possible. We explicitly choose a single dialect to maximize interoperability with other tools and allow reuse of the PostgreSQL documents and tutorials as much as possible. + +Currently, only a subset of the PostgreSQL dialect is implemented, and we will document any deviations. + +## Schema Metadata / Information Schema Support + +DataFusion supports the showing metadata about the tables available. This information can be accessed using the views of the ISO SQL `information_schema` schema or the DataFusion specific `SHOW TABLES` and `SHOW COLUMNS` commands. + +More information can be found in the [Postgres docs](https://www.postgresql.org/docs/13/infoschema-schema.html)). + +To show tables available for use in DataFusion, use the `SHOW TABLES` command or the `information_schema.tables` view: + +```sql +> show tables; ++---------------+--------------------+------------+------------+ +| table_catalog | table_schema | table_name | table_type | ++---------------+--------------------+------------+------------+ +| datafusion | public | t | BASE TABLE | +| datafusion | information_schema | tables | VIEW | ++---------------+--------------------+------------+------------+ + +> select * from information_schema.tables; + ++---------------+--------------------+------------+--------------+ +| table_catalog | table_schema | table_name | table_type | ++---------------+--------------------+------------+--------------+ +| datafusion | public | t | BASE TABLE | +| datafusion | information_schema | TABLES | SYSTEM TABLE | ++---------------+--------------------+------------+--------------+ +``` + +To show the schema of a table in DataFusion, use the `SHOW COLUMNS` command or the or `information_schema.columns` view: + +```sql +> show columns from t; ++---------------+--------------+------------+-------------+-----------+-------------+ +| table_catalog | table_schema | table_name | column_name | data_type | is_nullable | ++---------------+--------------+------------+-------------+-----------+-------------+ +| datafusion | public | t | a | Int32 | NO | +| datafusion | public | t | b | Utf8 | NO | +| datafusion | public | t | c | Float32 | NO | ++---------------+--------------+------------+-------------+-----------+-------------+ + +> select table_name, column_name, ordinal_position, is_nullable, data_type from information_schema.columns; ++------------+-------------+------------------+-------------+-----------+ +| table_name | column_name | ordinal_position | is_nullable | data_type | ++------------+-------------+------------------+-------------+-----------+ +| t | a | 0 | NO | Int32 | +| t | b | 1 | NO | Utf8 | +| t | c | 2 | NO | 
Float32 | ++------------+-------------+------------------+-------------+-----------+ +``` + +## Supported Data Types + +DataFusion uses Arrow, and thus the Arrow type system, for query +execution. The SQL types from +[sqlparser-rs](https://github.com/ballista-compute/sqlparser-rs/blob/main/src/ast/data_type.rs#L57) +are mapped to Arrow types according to the following table + +| SQL Data Type | Arrow DataType | +| ------------- | --------------------------------- | +| `CHAR` | `Utf8` | +| `VARCHAR` | `Utf8` | +| `UUID` | _Not yet supported_ | +| `CLOB` | _Not yet supported_ | +| `BINARY` | _Not yet supported_ | +| `VARBINARY` | _Not yet supported_ | +| `DECIMAL` | `Float64` | +| `FLOAT` | `Float32` | +| `SMALLINT` | `Int16` | +| `INT` | `Int32` | +| `BIGINT` | `Int64` | +| `REAL` | `Float32` | +| `DOUBLE` | `Float64` | +| `BOOLEAN` | `Boolean` | +| `DATE` | `Date32` | +| `TIME` | `Time64(TimeUnit::Millisecond)` | +| `TIMESTAMP` | `Timestamp(TimeUnit::Nanosecond)` | +| `INTERVAL` | _Not yet supported_ | +| `REGCLASS` | _Not yet supported_ | +| `TEXT` | _Not yet supported_ | +| `BYTEA` | _Not yet supported_ | +| `CUSTOM` | _Not yet supported_ | +| `ARRAY` | _Not yet supported_ | From 40c29e5cf205f4e6f5073a7e1daf0b37f6daa45d Mon Sep 17 00:00:00 2001 From: Remzi Yang <59198230+HaoYang670@users.noreply.github.com> Date: Mon, 7 Feb 2022 19:59:50 +0800 Subject: [PATCH 35/50] Create built-in scalar functions programmatically (#1734) * create build-in scalar functions programatically Signed-off-by: remzi <13716567376yh@gmail.com> * solve conflict Signed-off-by: remzi <13716567376yh@gmail.com> * fix spelling mistake Signed-off-by: remzi <13716567376yh@gmail.com> * rename to call_fn Signed-off-by: remzi <13716567376yh@gmail.com> --- datafusion/src/logical_plan/expr.rs | 14 ++++++ datafusion/src/logical_plan/mod.rs | 2 +- .../src/optimizer/simplify_expressions.rs | 45 +++++-------------- 3 files changed, 26 insertions(+), 35 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 300c75137d8a..c2763d097e85 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -2235,6 +2235,20 @@ pub fn exprlist_to_fields<'a>( expr.into_iter().map(|e| e.to_field(input_schema)).collect() } +/// Calls a named built in function +/// ``` +/// use datafusion::logical_plan::*; +/// +/// // create the expression sin(x) < 0.2 +/// let expr = call_fn("sin", vec![col("x")]).unwrap().lt(lit(0.2)); +/// ``` +pub fn call_fn(name: impl AsRef, args: Vec) -> Result { + match name.as_ref().parse::() { + Ok(fun) => Ok(Expr::ScalarFunction { fun, args }), + Err(e) => Err(e), + } +} + #[cfg(test)] mod tests { use super::super::{col, lit, when}; diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 22521a1bd1fb..ec1aea6a72a1 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -37,7 +37,7 @@ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, - avg, binary_expr, bit_length, btrim, case, ceil, character_length, chr, col, + avg, binary_expr, bit_length, btrim, call_fn, case, ceil, character_length, chr, col, columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, 
log10, log2, diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index c000bdbc2bea..5f87542491d7 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -735,8 +735,8 @@ mod tests { use super::*; use crate::assert_contains; use crate::logical_plan::{ - and, binary_expr, col, create_udf, lit, lit_timestamp_nano, DFField, Expr, - LogicalPlanBuilder, + and, binary_expr, call_fn, col, create_udf, lit, lit_timestamp_nano, DFField, + Expr, LogicalPlanBuilder, }; use crate::physical_plan::functions::{make_scalar_function, BuiltinScalarFunction}; use crate::physical_plan::udf::ScalarUDF; @@ -1010,46 +1010,29 @@ mod tests { #[test] fn test_const_evaluator_scalar_functions() { // concat("foo", "bar") --> "foobar" - let expr = Expr::ScalarFunction { - args: vec![lit("foo"), lit("bar")], - fun: BuiltinScalarFunction::Concat, - }; + let expr = call_fn("concat", vec![lit("foo"), lit("bar")]).unwrap(); test_evaluate(expr, lit("foobar")); // ensure arguments are also constant folded // concat("foo", concat("bar", "baz")) --> "foobarbaz" - let concat1 = Expr::ScalarFunction { - args: vec![lit("bar"), lit("baz")], - fun: BuiltinScalarFunction::Concat, - }; - let expr = Expr::ScalarFunction { - args: vec![lit("foo"), concat1], - fun: BuiltinScalarFunction::Concat, - }; + let concat1 = call_fn("concat", vec![lit("bar"), lit("baz")]).unwrap(); + let expr = call_fn("concat", vec![lit("foo"), concat1]).unwrap(); test_evaluate(expr, lit("foobarbaz")); // Check non string arguments // to_timestamp("2020-09-08T12:00:00+00:00") --> timestamp(1599566400000000000i64) - let expr = Expr::ScalarFunction { - args: vec![lit("2020-09-08T12:00:00+00:00")], - fun: BuiltinScalarFunction::ToTimestamp, - }; + let expr = + call_fn("to_timestamp", vec![lit("2020-09-08T12:00:00+00:00")]).unwrap(); test_evaluate(expr, lit_timestamp_nano(1599566400000000000i64)); // check that non foldable arguments are folded // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] - let expr = Expr::ScalarFunction { - args: vec![col("a")], - fun: BuiltinScalarFunction::ToTimestamp, - }; + let expr = call_fn("to_timestamp", vec![col("a")]).unwrap(); test_evaluate(expr.clone(), expr); // check that non foldable arguments are folded // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] - let expr = Expr::ScalarFunction { - args: vec![col("a")], - fun: BuiltinScalarFunction::ToTimestamp, - }; + let expr = call_fn("to_timestamp", vec![col("a")]).unwrap(); test_evaluate(expr.clone(), expr); // volatile / stable functions should not be evaluated @@ -1090,10 +1073,7 @@ mod tests { } fn now_expr() -> Expr { - Expr::ScalarFunction { - args: vec![], - fun: BuiltinScalarFunction::Now, - } + call_fn("now", vec![]).unwrap() } fn cast_to_int64_expr(expr: Expr) -> Expr { @@ -1104,10 +1084,7 @@ mod tests { } fn to_timestamp_expr(arg: impl Into) -> Expr { - Expr::ScalarFunction { - args: vec![lit(arg.into())], - fun: BuiltinScalarFunction::ToTimestamp, - } + call_fn("to_timestamp", vec![lit(arg.into())]).unwrap() } #[test] From fe46a1ed9833f2f9ea4c4ccd4d77718e5c371ab1 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 7 Feb 2022 21:27:14 +0800 Subject: [PATCH 36/50] [split/1] split datafusion-common module (#1751) * split datafusion-common module * pyarrow * Update datafusion-common/README.md Co-authored-by: Andy Grove * Update datafusion/Cargo.toml * include publishing Co-authored-by: Andy Grove --- Cargo.toml | 1 + 
datafusion-common/Cargo.toml | 44 +++++++ datafusion-common/README.md | 24 ++++ datafusion-common/src/error.rs | 209 +++++++++++++++++++++++++++++++++ datafusion-common/src/lib.rs | 20 ++++ datafusion/Cargo.toml | 5 +- datafusion/src/error.rs | 182 +--------------------------- datafusion/src/pyarrow.rs | 8 -- 8 files changed, 302 insertions(+), 191 deletions(-) create mode 100644 datafusion-common/Cargo.toml create mode 100644 datafusion-common/README.md create mode 100644 datafusion-common/src/error.rs create mode 100644 datafusion-common/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index ea1acc04e687..81f6bb59f2d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ [workspace] members = [ "datafusion", + "datafusion-common", "datafusion-cli", "datafusion-examples", "benchmarks", diff --git a/datafusion-common/Cargo.toml b/datafusion-common/Cargo.toml new file mode 100644 index 000000000000..9c05d8095caf --- /dev/null +++ b/datafusion-common/Cargo.toml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-common" +description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model" +version = "6.0.0" +homepage = "https://github.com/apache/arrow-datafusion" +repository = "https://github.com/apache/arrow-datafusion" +readme = "README.md" +authors = ["Apache Arrow "] +license = "Apache-2.0" +keywords = [ "arrow", "query", "sql" ] +edition = "2021" +rust-version = "1.58" + +[lib] +name = "datafusion_common" +path = "src/lib.rs" + +[features] +avro = ["avro-rs"] +pyarrow = ["pyo3"] + +[dependencies] +arrow = { version = "8.0.0", features = ["prettyprint"] } +parquet = { version = "8.0.0", features = ["arrow"] } +avro-rs = { version = "0.13", features = ["snappy"], optional = true } +pyo3 = { version = "0.15", optional = true } +sqlparser = "0.13" diff --git a/datafusion-common/README.md b/datafusion-common/README.md new file mode 100644 index 000000000000..8c44d78ef47f --- /dev/null +++ b/datafusion-common/README.md @@ -0,0 +1,24 @@ + + +# DataFusion Common + +This is an internal module for the most fundamental types of [DataFusion][df]. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion-common/src/error.rs b/datafusion-common/src/error.rs new file mode 100644 index 000000000000..ee2e61892fd4 --- /dev/null +++ b/datafusion-common/src/error.rs @@ -0,0 +1,209 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! DataFusion error types + +use std::error; +use std::fmt::{Display, Formatter}; +use std::io; +use std::result; + +use arrow::error::ArrowError; +#[cfg(feature = "avro")] +use avro_rs::Error as AvroError; +use parquet::errors::ParquetError; +#[cfg(feature = "pyarrow")] +use pyo3::exceptions::PyException; +#[cfg(feature = "pyarrow")] +use pyo3::prelude::PyErr; +use sqlparser::parser::ParserError; + +/// Result type for operations that could result in an [DataFusionError] +pub type Result = result::Result; + +/// Error type for generic operations that could result in DataFusionError::External +pub type GenericError = Box; + +/// DataFusion error +#[derive(Debug)] +pub enum DataFusionError { + /// Error returned by arrow. + ArrowError(ArrowError), + /// Wraps an error from the Parquet crate + ParquetError(ParquetError), + /// Wraps an error from the Avro crate + #[cfg(feature = "avro")] + AvroError(AvroError), + /// Error associated to I/O operations and associated traits. + IoError(io::Error), + /// Error returned when SQL is syntactically incorrect. + SQL(ParserError), + /// Error returned on a branch that we know it is possible + /// but to which we still have no implementation for. + /// Often, these errors are tracked in our issue tracker. + NotImplemented(String), + /// Error returned as a consequence of an error in DataFusion. + /// This error should not happen in normal usage of DataFusion. + // DataFusions has internal invariants that we are unable to ask the compiler to check for us. + // This error is raised when one of those invariants is not verified during execution. + Internal(String), + /// This error happens whenever a plan is not valid. Examples include + /// impossible casts, schema inference not possible and non-unique column names. + Plan(String), + /// Error returned during execution of the query. + /// Examples include files not found, errors in parsing certain types. + Execution(String), + /// This error is thrown when a consumer cannot acquire memory from the Memory Manager + /// we can just cancel the execution of the partition. + ResourcesExhausted(String), + /// Errors originating from outside DataFusion's core codebase. 
+ /// For example, a custom S3Error from the crate datafusion-objectstore-s3 + External(GenericError), +} + +impl From for DataFusionError { + fn from(e: io::Error) -> Self { + DataFusionError::IoError(e) + } +} + +impl From for DataFusionError { + fn from(e: ArrowError) -> Self { + DataFusionError::ArrowError(e) + } +} + +#[cfg(feature = "pyarrow")] +impl From for PyErr { + fn from(err: DataFusionError) -> PyErr { + PyException::new_err(err.to_string()) + } +} + +impl From for ArrowError { + fn from(e: DataFusionError) -> Self { + match e { + DataFusionError::ArrowError(e) => e, + DataFusionError::External(e) => ArrowError::ExternalError(e), + other => ArrowError::ExternalError(Box::new(other)), + } + } +} + +impl From for DataFusionError { + fn from(e: ParquetError) -> Self { + DataFusionError::ParquetError(e) + } +} + +#[cfg(feature = "avro")] +impl From for DataFusionError { + fn from(e: AvroError) -> Self { + DataFusionError::AvroError(e) + } +} + +impl From for DataFusionError { + fn from(e: ParserError) -> Self { + DataFusionError::SQL(e) + } +} + +impl From for DataFusionError { + fn from(err: GenericError) -> Self { + DataFusionError::External(err) + } +} + +impl Display for DataFusionError { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match *self { + DataFusionError::ArrowError(ref desc) => write!(f, "Arrow error: {}", desc), + DataFusionError::ParquetError(ref desc) => { + write!(f, "Parquet error: {}", desc) + } + #[cfg(feature = "avro")] + DataFusionError::AvroError(ref desc) => { + write!(f, "Avro error: {}", desc) + } + DataFusionError::IoError(ref desc) => write!(f, "IO error: {}", desc), + DataFusionError::SQL(ref desc) => { + write!(f, "SQL error: {:?}", desc) + } + DataFusionError::NotImplemented(ref desc) => { + write!(f, "This feature is not implemented: {}", desc) + } + DataFusionError::Internal(ref desc) => { + write!(f, "Internal error: {}. This was likely caused by a bug in DataFusion's \ + code and we would welcome that you file an bug report in our issue tracker", desc) + } + DataFusionError::Plan(ref desc) => { + write!(f, "Error during planning: {}", desc) + } + DataFusionError::Execution(ref desc) => { + write!(f, "Execution error: {}", desc) + } + DataFusionError::ResourcesExhausted(ref desc) => { + write!(f, "Resources exhausted: {}", desc) + } + DataFusionError::External(ref desc) => { + write!(f, "External error: {}", desc) + } + } + } +} + +impl error::Error for DataFusionError {} + +#[cfg(test)] +mod test { + use crate::error::DataFusionError; + use arrow::error::ArrowError; + + #[test] + fn arrow_error_to_datafusion() { + let res = return_arrow_error().unwrap_err(); + assert_eq!( + res.to_string(), + "External error: Error during planning: foo" + ); + } + + #[test] + fn datafusion_error_to_arrow() { + let res = return_datafusion_error().unwrap_err(); + assert_eq!(res.to_string(), "Arrow error: Schema error: bar"); + } + + /// Model what happens when implementing SendableRecrordBatchStream: + /// DataFusion code needs to return an ArrowError + #[allow(clippy::try_err)] + fn return_arrow_error() -> arrow::error::Result<()> { + // Expect the '?' to work + let _foo = Err(DataFusionError::Plan("foo".to_string()))?; + Ok(()) + } + + /// Model what happens when using arrow kernels in DataFusion + /// code: need to turn an ArrowError into a DataFusionError + #[allow(clippy::try_err)] + fn return_datafusion_error() -> crate::error::Result<()> { + // Expect the '?' 
to work + let _bar = Err(ArrowError::SchemaError("bar".to_string()))?; + Ok(()) + } +} diff --git a/datafusion-common/src/lib.rs b/datafusion-common/src/lib.rs new file mode 100644 index 000000000000..ac8ef623b0cd --- /dev/null +++ b/datafusion-common/src/lib.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod error; + +pub use error::{DataFusionError, Result}; diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 81e2bb14877b..6df852257e16 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -43,13 +43,14 @@ simd = ["arrow/simd"] crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] regex_expressions = ["regex"] unicode_expressions = ["unicode-segmentation"] -pyarrow = ["pyo3", "arrow/pyarrow"] +pyarrow = ["pyo3", "arrow/pyarrow", "datafusion-common/pyarrow"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] # Used to enable the avro format -avro = ["avro-rs", "num-traits"] +avro = ["avro-rs", "num-traits", "datafusion-common/avro"] [dependencies] +datafusion-common = { path = "../datafusion-common", version = "6.0.0" } ahash = { version = "0.7", default-features = false } hashbrown = { version = "0.12", features = ["raw"] } arrow = { version = "8.0.0", features = ["prettyprint"] } diff --git a/datafusion/src/error.rs b/datafusion/src/error.rs index 248f24350356..c2c80b48781e 100644 --- a/datafusion/src/error.rs +++ b/datafusion/src/error.rs @@ -16,184 +16,4 @@ // under the License. //! DataFusion error types - -use std::error; -use std::fmt::{Display, Formatter}; -use std::io; -use std::result; - -use arrow::error::ArrowError; -#[cfg(feature = "avro")] -use avro_rs::Error as AvroError; -use parquet::errors::ParquetError; -use sqlparser::parser::ParserError; - -/// Result type for operations that could result in an [DataFusionError] -pub type Result = result::Result; - -/// Error type for generic operations that could result in DataFusionError::External -pub type GenericError = Box; - -/// DataFusion error -#[derive(Debug)] -#[allow(missing_docs)] -pub enum DataFusionError { - /// Error returned by arrow. - ArrowError(ArrowError), - /// Wraps an error from the Parquet crate - ParquetError(ParquetError), - /// Wraps an error from the Avro crate - #[cfg(feature = "avro")] - AvroError(AvroError), - /// Error associated to I/O operations and associated traits. - IoError(io::Error), - /// Error returned when SQL is syntactically incorrect. - SQL(ParserError), - /// Error returned on a branch that we know it is possible - /// but to which we still have no implementation for. - /// Often, these errors are tracked in our issue tracker. - NotImplemented(String), - /// Error returned as a consequence of an error in DataFusion. 
- /// This error should not happen in normal usage of DataFusion. - // DataFusions has internal invariants that we are unable to ask the compiler to check for us. - // This error is raised when one of those invariants is not verified during execution. - Internal(String), - /// This error happens whenever a plan is not valid. Examples include - /// impossible casts, schema inference not possible and non-unique column names. - Plan(String), - /// Error returned during execution of the query. - /// Examples include files not found, errors in parsing certain types. - Execution(String), - /// This error is thrown when a consumer cannot acquire memory from the Memory Manager - /// we can just cancel the execution of the partition. - ResourcesExhausted(String), - /// Errors originating from outside DataFusion's core codebase. - /// For example, a custom S3Error from the crate datafusion-objectstore-s3 - External(GenericError), -} - -impl From for DataFusionError { - fn from(e: io::Error) -> Self { - DataFusionError::IoError(e) - } -} - -impl From for DataFusionError { - fn from(e: ArrowError) -> Self { - DataFusionError::ArrowError(e) - } -} - -impl From for ArrowError { - fn from(e: DataFusionError) -> Self { - match e { - DataFusionError::ArrowError(e) => e, - DataFusionError::External(e) => ArrowError::ExternalError(e), - other => ArrowError::ExternalError(Box::new(other)), - } - } -} - -impl From for DataFusionError { - fn from(e: ParquetError) -> Self { - DataFusionError::ParquetError(e) - } -} - -#[cfg(feature = "avro")] -impl From for DataFusionError { - fn from(e: AvroError) -> Self { - DataFusionError::AvroError(e) - } -} - -impl From for DataFusionError { - fn from(e: ParserError) -> Self { - DataFusionError::SQL(e) - } -} - -impl From for DataFusionError { - fn from(err: GenericError) -> Self { - DataFusionError::External(err) - } -} - -impl Display for DataFusionError { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - match *self { - DataFusionError::ArrowError(ref desc) => write!(f, "Arrow error: {}", desc), - DataFusionError::ParquetError(ref desc) => { - write!(f, "Parquet error: {}", desc) - } - #[cfg(feature = "avro")] - DataFusionError::AvroError(ref desc) => { - write!(f, "Avro error: {}", desc) - } - DataFusionError::IoError(ref desc) => write!(f, "IO error: {}", desc), - DataFusionError::SQL(ref desc) => { - write!(f, "SQL error: {:?}", desc) - } - DataFusionError::NotImplemented(ref desc) => { - write!(f, "This feature is not implemented: {}", desc) - } - DataFusionError::Internal(ref desc) => { - write!(f, "Internal error: {}. 
This was likely caused by a bug in DataFusion's \ - code and we would welcome that you file an bug report in our issue tracker", desc) - } - DataFusionError::Plan(ref desc) => { - write!(f, "Error during planning: {}", desc) - } - DataFusionError::Execution(ref desc) => { - write!(f, "Execution error: {}", desc) - } - DataFusionError::ResourcesExhausted(ref desc) => { - write!(f, "Resources exhausted: {}", desc) - } - DataFusionError::External(ref desc) => { - write!(f, "External error: {}", desc) - } - } - } -} - -impl error::Error for DataFusionError {} - -#[cfg(test)] -mod test { - use crate::error::DataFusionError; - use arrow::error::ArrowError; - - #[test] - fn arrow_error_to_datafusion() { - let res = return_arrow_error().unwrap_err(); - assert_eq!( - res.to_string(), - "External error: Error during planning: foo" - ); - } - - #[test] - fn datafusion_error_to_arrow() { - let res = return_datafusion_error().unwrap_err(); - assert_eq!(res.to_string(), "Arrow error: Schema error: bar"); - } - - /// Model what happens when implementing SendableRecrordBatchStream: - /// DataFusion code needs to return an ArrowError - #[allow(clippy::try_err)] - fn return_arrow_error() -> arrow::error::Result<()> { - // Expect the '?' to work - let _foo = Err(DataFusionError::Plan("foo".to_string()))?; - Ok(()) - } - - /// Model what happens when using arrow kernels in DataFusion - /// code: need to turn an ArrowError into a DataFusionError - #[allow(clippy::try_err)] - fn return_datafusion_error() -> crate::error::Result<()> { - // Expect the '?' to work - let _bar = Err(ArrowError::SchemaError("bar".to_string()))?; - Ok(()) - } -} +pub use datafusion_common::{DataFusionError, Result}; diff --git a/datafusion/src/pyarrow.rs b/datafusion/src/pyarrow.rs index d819b2b41154..46eb6b4437b5 100644 --- a/datafusion/src/pyarrow.rs +++ b/datafusion/src/pyarrow.rs @@ -15,21 +15,13 @@ // specific language governing permissions and limitations // under the License. 
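The `From` conversions in the relocated module are what keep `?` ergonomic at the boundary between DataFusion and Arrow code, and the re-export above keeps `datafusion::error::{DataFusionError, Result}` working unchanged. Below is a minimal sketch of both directions, mirroring the unit tests in the new module; the function names are illustrative and the only assumption is a crate that depends on `datafusion-common`:

    use arrow::error::ArrowError;
    use datafusion_common::{DataFusionError, Result};

    // DataFusion code calling an Arrow kernel: `?` converts the ArrowError
    // into a DataFusionError through `From<ArrowError> for DataFusionError`.
    fn arrow_into_datafusion() -> Result<()> {
        let failed: std::result::Result<(), ArrowError> =
            Err(ArrowError::SchemaError("bar".to_string()));
        failed?;
        Ok(())
    }

    // Code that must hand an error back to Arrow (for example a record batch
    // stream implementation): `?` wraps the DataFusionError through
    // `From<DataFusionError> for ArrowError`, producing an ExternalError.
    fn datafusion_into_arrow() -> std::result::Result<(), ArrowError> {
        let failed: Result<()> = Err(DataFusionError::Plan("foo".to_string()));
        failed?;
        Ok(())
    }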
-use pyo3::exceptions::PyException; use pyo3::prelude::*; use pyo3::types::PyList; use crate::arrow::array::ArrayData; use crate::arrow::pyarrow::PyArrowConvert; -use crate::error::DataFusionError; use crate::scalar::ScalarValue; -impl From for PyErr { - fn from(err: DataFusionError) -> PyErr { - PyException::new_err(err.to_string()) - } -} - impl PyArrowConvert for ScalarValue { fn from_pyarrow(value: &PyAny) -> PyResult { let py = value.py(); From d014ff218012fb019e9f625416d4cf35051db107 Mon Sep 17 00:00:00 2001 From: Marko Mikulicic Date: Mon, 7 Feb 2022 14:48:17 +0100 Subject: [PATCH 37/50] fix: Case insensitive unquoted identifiers (#1747) --- datafusion/src/execution/context.rs | 167 ++++++++++++++++++++++++++++ datafusion/src/sql/planner.rs | 18 +-- datafusion/src/sql/utils.rs | 9 ++ 3 files changed, 182 insertions(+), 12 deletions(-) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 96e49c800f48..2f8663ecae01 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -3359,6 +3359,173 @@ mod tests { assert_eq!(result[0].schema().metadata(), result[1].schema().metadata()); } + #[tokio::test] + async fn normalized_column_identifiers() { + // create local execution context + let mut ctx = ExecutionContext::new(); + + // register csv file with the execution context + ctx.register_csv( + "case_insensitive_test", + "tests/example.csv", + CsvReadOptions::new(), + ) + .await + .unwrap(); + + let sql = "SELECT A, b FROM case_insensitive_test"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = "SELECT t.A, b FROM case_insensitive_test AS t"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + // Aliases + + let sql = "SELECT t.A as x, b FROM case_insensitive_test AS t"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| x | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = "SELECT t.A AS X, b FROM case_insensitive_test AS t"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| x | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = r#"SELECT t.A AS "X", b FROM case_insensitive_test AS t"#; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| X | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + // Order by + + let sql = "SELECT t.A AS x, b FROM case_insensitive_test AS t ORDER BY x"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| x | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = "SELECT t.A AS x, b FROM case_insensitive_test AS t ORDER BY X"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = 
vec![ + "+---+---+", + "| x | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = r#"SELECT t.A AS "X", b FROM case_insensitive_test AS t ORDER BY "X""#; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| X | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + // Where + + let sql = "SELECT a, b FROM case_insensitive_test where A IS NOT null"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 2 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + // Group by + + let sql = "SELECT a as x, count(*) as c FROM case_insensitive_test GROUP BY X"; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| x | c |", + "+---+---+", + "| 1 | 1 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let sql = + r#"SELECT a as "X", count(*) as c FROM case_insensitive_test GROUP BY "X""#; + let result = plan_and_collect(&mut ctx, sql) + .await + .expect("ran plan correctly"); + let expected = vec![ + "+---+---+", + "| X | c |", + "+---+---+", + "| 1 | 1 |", + "+---+---+", + ]; + assert_batches_sorted_eq!(expected, &result); + } + struct MyPhysicalPlanner {} #[async_trait] diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 462977274ecb..682b92ba661f 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -36,7 +36,7 @@ use crate::logical_plan::{ use crate::optimizer::utils::exprlist_to_columns; use crate::prelude::JoinType; use crate::scalar::ScalarValue; -use crate::sql::utils::make_decimal_type; +use crate::sql::utils::{make_decimal_type, normalize_ident}; use crate::{ error::{DataFusionError, Result}, physical_plan::udaf::AggregateUDF, @@ -1191,7 +1191,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SelectItem::UnnamedExpr(expr) => self.sql_to_rex(expr, schema), SelectItem::ExprWithAlias { expr, alias } => Ok(Alias( Box::new(self.sql_to_rex(expr, schema)?), - alias.value.clone(), + normalize_ident(alias), )), SelectItem::Wildcard => Ok(Expr::Wildcard), SelectItem::QualifiedWildcard(_) => Err(DataFusionError::NotImplemented( @@ -1392,6 +1392,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Identifier(ref id) => { if id.value.starts_with('@') { + // TODO: figure out if ScalarVariables should be insensitive. let var_names = vec![id.value.clone()]; Ok(Expr::ScalarVariable(var_names)) } else { @@ -1401,7 +1402,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // identifier. (e.g. it is "foo.bar" not foo.bar) Ok(Expr::Column(Column { relation: None, - name: id.value.clone(), + name: normalize_ident(id), })) } } @@ -1418,8 +1419,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::CompoundIdentifier(ids) => { - let mut var_names: Vec<_> = - ids.iter().map(|id| id.value.clone()).collect(); + let mut var_names: Vec<_> = ids.iter().map(normalize_ident).collect(); if &var_names[0][0..1] == "@" { Ok(Expr::ScalarVariable(var_names)) @@ -1639,13 +1639,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // (e.g. 
"foo.bar") for function names yet function.name.to_string() } else { - // if there is a quote style, then don't normalize - // the name, otherwise normalize to lowercase - let ident = &function.name.0[0]; - match ident.quote_style { - Some(_) => ident.value.clone(), - None => ident.value.to_ascii_lowercase(), - } + normalize_ident(&function.name.0[0]) }; // first, scalar built-in diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 0ede5ad8559e..d0cef0f3d376 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -18,6 +18,7 @@ //! SQL Utility Functions use arrow::datatypes::DataType; +use sqlparser::ast::Ident; use crate::logical_plan::{Expr, LogicalPlan}; use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128}; @@ -532,6 +533,14 @@ pub(crate) fn make_decimal_type( } } +// Normalize an identifer to a lowercase string unless the identifier is quoted. +pub(crate) fn normalize_ident(id: &Ident) -> String { + match id.quote_style { + Some(_) => id.value.clone(), + None => id.value.to_ascii_lowercase(), + } +} + #[cfg(test)] mod tests { use super::*; From 2e535f91624f4f10cf4edc24b5c9b86f5a539960 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 7 Feb 2022 22:02:38 +0800 Subject: [PATCH 38/50] move dfschema and column (#1758) --- datafusion-common/src/column.rs | 150 +++++ datafusion-common/src/dfschema.rs | 722 ++++++++++++++++++++++++ datafusion-common/src/lib.rs | 4 + datafusion/src/logical_plan/builder.rs | 48 +- datafusion/src/logical_plan/dfschema.rs | 669 +--------------------- datafusion/src/logical_plan/expr.rs | 174 +----- 6 files changed, 912 insertions(+), 855 deletions(-) create mode 100644 datafusion-common/src/column.rs create mode 100644 datafusion-common/src/dfschema.rs diff --git a/datafusion-common/src/column.rs b/datafusion-common/src/column.rs new file mode 100644 index 000000000000..02faa24b0346 --- /dev/null +++ b/datafusion-common/src/column.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Column + +use crate::{DFSchema, DataFusionError, Result}; +use std::collections::HashSet; +use std::convert::Infallible; +use std::fmt; +use std::str::FromStr; +use std::sync::Arc; + +/// A named reference to a qualified field in a schema. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct Column { + /// relation/table name. + pub relation: Option, + /// field/column name. + pub name: String, +} + +impl Column { + /// Create Column from unqualified name. 
+ pub fn from_name(name: impl Into) -> Self { + Self { + relation: None, + name: name.into(), + } + } + + /// Deserialize a fully qualified name string into a column + pub fn from_qualified_name(flat_name: &str) -> Self { + use sqlparser::tokenizer::Token; + + let dialect = sqlparser::dialect::GenericDialect {}; + let mut tokenizer = sqlparser::tokenizer::Tokenizer::new(&dialect, flat_name); + if let Ok(tokens) = tokenizer.tokenize() { + if let [Token::Word(relation), Token::Period, Token::Word(name)] = + tokens.as_slice() + { + return Column { + relation: Some(relation.value.clone()), + name: name.value.clone(), + }; + } + } + // any expression that's not in the form of `foo.bar` will be treated as unqualified column + // name + Column { + relation: None, + name: String::from(flat_name), + } + } + + /// Serialize column into a flat name string + pub fn flat_name(&self) -> String { + match &self.relation { + Some(r) => format!("{}.{}", r, self.name), + None => self.name.clone(), + } + } + + // Internal implementation of normalize + pub fn normalize_with_schemas( + self, + schemas: &[&Arc], + using_columns: &[HashSet], + ) -> Result { + if self.relation.is_some() { + return Ok(self); + } + + for schema in schemas { + let fields = schema.fields_with_unqualified_name(&self.name); + match fields.len() { + 0 => continue, + 1 => { + return Ok(fields[0].qualified_column()); + } + _ => { + // More than 1 fields in this schema have their names set to self.name. + // + // This should only happen when a JOIN query with USING constraint references + // join columns using unqualified column name. For example: + // + // ```sql + // SELECT id FROM t1 JOIN t2 USING(id) + // ``` + // + // In this case, both `t1.id` and `t2.id` will match unqualified column `id`. + // We will use the relation from the first matched field to normalize self. + + // Compare matched fields with one USING JOIN clause at a time + for using_col in using_columns { + let all_matched = fields + .iter() + .all(|f| using_col.contains(&f.qualified_column())); + // All matched fields belong to the same using column set, in orther words + // the same join clause. We simply pick the qualifer from the first match. + if all_matched { + return Ok(fields[0].qualified_column()); + } + } + } + } + } + + Err(DataFusionError::Plan(format!( + "Column {} not found in provided schemas", + self + ))) + } +} + +impl From<&str> for Column { + fn from(c: &str) -> Self { + Self::from_qualified_name(c) + } +} + +impl FromStr for Column { + type Err = Infallible; + + fn from_str(s: &str) -> std::result::Result { + Ok(s.into()) + } +} + +impl fmt::Display for Column { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match &self.relation { + Some(r) => write!(f, "#{}.{}", r, self.name), + None => write!(f, "#{}", self.name), + } + } +} diff --git a/datafusion-common/src/dfschema.rs b/datafusion-common/src/dfschema.rs new file mode 100644 index 000000000000..46321c313127 --- /dev/null +++ b/datafusion-common/src/dfschema.rs @@ -0,0 +1,722 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! DFSchema is an extended schema struct that DataFusion uses to provide support for +//! fields with optional relation names. + +use std::collections::HashSet; +use std::convert::TryFrom; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::Column; + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use std::fmt::{Display, Formatter}; + +/// A reference-counted reference to a `DFSchema`. +pub type DFSchemaRef = Arc; + +/// DFSchema wraps an Arrow schema and adds relation names +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DFSchema { + /// Fields + fields: Vec, +} + +impl DFSchema { + /// Creates an empty `DFSchema` + pub fn empty() -> Self { + Self { fields: vec![] } + } + + /// Create a new `DFSchema` + pub fn new(fields: Vec) -> Result { + let mut qualified_names = HashSet::new(); + let mut unqualified_names = HashSet::new(); + + for field in &fields { + if let Some(qualifier) = field.qualifier() { + if !qualified_names.insert((qualifier, field.name())) { + return Err(DataFusionError::Plan(format!( + "Schema contains duplicate qualified field name '{}'", + field.qualified_name() + ))); + } + } else if !unqualified_names.insert(field.name()) { + return Err(DataFusionError::Plan(format!( + "Schema contains duplicate unqualified field name '{}'", + field.name() + ))); + } + } + + // check for mix of qualified and unqualified field with same unqualified name + // note that we need to sort the contents of the HashSet first so that errors are + // deterministic + let mut qualified_names = qualified_names + .iter() + .map(|(l, r)| (l.to_owned(), r.to_owned())) + .collect::>(); + qualified_names.sort_by(|a, b| { + let a = format!("{}.{}", a.0, a.1); + let b = format!("{}.{}", b.0, b.1); + a.cmp(&b) + }); + for (qualifier, name) in &qualified_names { + if unqualified_names.contains(name) { + return Err(DataFusionError::Plan(format!( + "Schema contains qualified field name '{}.{}' \ + and unqualified field name '{}' which would be ambiguous", + qualifier, name, name + ))); + } + } + Ok(Self { fields }) + } + + /// Create a `DFSchema` from an Arrow schema + pub fn try_from_qualified_schema(qualifier: &str, schema: &Schema) -> Result { + Self::new( + schema + .fields() + .iter() + .map(|f| DFField::from_qualified(qualifier, f.clone())) + .collect(), + ) + } + + /// Combine two schemas + pub fn join(&self, schema: &DFSchema) -> Result { + let mut fields = self.fields.clone(); + fields.extend_from_slice(schema.fields().as_slice()); + Self::new(fields) + } + + /// Merge a schema into self + pub fn merge(&mut self, other_schema: &DFSchema) { + for field in other_schema.fields() { + // skip duplicate columns + let duplicated_field = match field.qualifier() { + Some(q) => self.field_with_name(Some(q.as_str()), field.name()).is_ok(), + // for unqualifed columns, check as unqualified name + None => self.field_with_unqualified_name(field.name()).is_ok(), + }; + if !duplicated_field { + self.fields.push(field.clone()); + } + } + } + + /// Get a list of fields + pub fn fields(&self) -> &Vec { + &self.fields + } + + /// Returns an immutable 
reference of a specific `Field` instance selected using an + /// offset within the internal `fields` vector + pub fn field(&self, i: usize) -> &DFField { + &self.fields[i] + } + + /// Find the index of the column with the given unqualified name + pub fn index_of(&self, name: &str) -> Result { + for i in 0..self.fields.len() { + if self.fields[i].name() == name { + return Ok(i); + } + } + Err(DataFusionError::Plan(format!( + "No field named '{}'. Valid fields are {}.", + name, + self.get_field_names() + ))) + } + + fn index_of_column_by_name( + &self, + qualifier: Option<&str>, + name: &str, + ) -> Result { + let mut matches = self + .fields + .iter() + .enumerate() + .filter(|(_, field)| match (qualifier, &field.qualifier) { + // field to lookup is qualified. + // current field is qualified and not shared between relations, compare both + // qualifier and name. + (Some(q), Some(field_q)) => q == field_q && field.name() == name, + // field to lookup is qualified but current field is unqualified. + (Some(_), None) => false, + // field to lookup is unqualified, no need to compare qualifier + (None, Some(_)) | (None, None) => field.name() == name, + }) + .map(|(idx, _)| idx); + match matches.next() { + None => Err(DataFusionError::Plan(format!( + "No field named '{}.{}'. Valid fields are {}.", + qualifier.unwrap_or(""), + name, + self.get_field_names() + ))), + Some(idx) => match matches.next() { + None => Ok(idx), + // found more than one matches + Some(_) => Err(DataFusionError::Internal(format!( + "Ambiguous reference to qualified field named '{}.{}'", + qualifier.unwrap_or(""), + name + ))), + }, + } + } + + /// Find the index of the column with the given qualifier and name + pub fn index_of_column(&self, col: &Column) -> Result { + self.index_of_column_by_name(col.relation.as_deref(), &col.name) + } + + /// Find the field with the given name + pub fn field_with_name( + &self, + qualifier: Option<&str>, + name: &str, + ) -> Result<&DFField> { + if let Some(qualifier) = qualifier { + self.field_with_qualified_name(qualifier, name) + } else { + self.field_with_unqualified_name(name) + } + } + + /// Find all fields match the given name + pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> { + self.fields + .iter() + .filter(|field| field.name() == name) + .collect() + } + + /// Find the field with the given name + pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> { + let matches = self.fields_with_unqualified_name(name); + match matches.len() { + 0 => Err(DataFusionError::Plan(format!( + "No field with unqualified name '{}'. 
Valid fields are {}.", + name, + self.get_field_names() + ))), + 1 => Ok(matches[0]), + _ => Err(DataFusionError::Plan(format!( + "Ambiguous reference to field named '{}'", + name + ))), + } + } + + /// Find the field with the given qualified name + pub fn field_with_qualified_name( + &self, + qualifier: &str, + name: &str, + ) -> Result<&DFField> { + let idx = self.index_of_column_by_name(Some(qualifier), name)?; + Ok(self.field(idx)) + } + + /// Find the field with the given qualified column + pub fn field_from_column(&self, column: &Column) -> Result<&DFField> { + match &column.relation { + Some(r) => self.field_with_qualified_name(r, &column.name), + None => self.field_with_unqualified_name(&column.name), + } + } + + /// Check to see if unqualified field names matches field names in Arrow schema + pub fn matches_arrow_schema(&self, arrow_schema: &Schema) -> bool { + self.fields + .iter() + .zip(arrow_schema.fields().iter()) + .all(|(dffield, arrowfield)| dffield.name() == arrowfield.name()) + } + + /// Strip all field qualifier in schema + pub fn strip_qualifiers(self) -> Self { + DFSchema { + fields: self + .fields + .into_iter() + .map(|f| f.strip_qualifier()) + .collect(), + } + } + + /// Replace all field qualifier with new value in schema + pub fn replace_qualifier(self, qualifier: &str) -> Self { + DFSchema { + fields: self + .fields + .into_iter() + .map(|f| { + DFField::new( + Some(qualifier), + f.name(), + f.data_type().to_owned(), + f.is_nullable(), + ) + }) + .collect(), + } + } + + /// Get comma-seperated list of field names for use in error messages + fn get_field_names(&self) -> String { + self.fields + .iter() + .map(|f| match f.qualifier() { + Some(qualifier) => format!("'{}.{}'", qualifier, f.name()), + None => format!("'{}'", f.name()), + }) + .collect::>() + .join(", ") + } +} + +impl From for Schema { + /// Convert DFSchema into a Schema + fn from(df_schema: DFSchema) -> Self { + Schema::new( + df_schema + .fields + .into_iter() + .map(|f| { + if f.qualifier().is_some() { + Field::new( + f.name().as_str(), + f.data_type().to_owned(), + f.is_nullable(), + ) + } else { + f.field + } + }) + .collect(), + ) + } +} + +impl From<&DFSchema> for Schema { + /// Convert DFSchema reference into a Schema + fn from(df_schema: &DFSchema) -> Self { + Schema::new(df_schema.fields.iter().map(|f| f.field.clone()).collect()) + } +} + +/// Create a `DFSchema` from an Arrow schema +impl TryFrom for DFSchema { + type Error = DataFusionError; + fn try_from(schema: Schema) -> std::result::Result { + Self::new( + schema + .fields() + .iter() + .map(|f| DFField::from(f.clone())) + .collect(), + ) + } +} + +impl From for SchemaRef { + fn from(df_schema: DFSchema) -> Self { + SchemaRef::new(df_schema.into()) + } +} + +/// Convenience trait to convert Schema like things to DFSchema and DFSchemaRef with fewer keystrokes +pub trait ToDFSchema +where + Self: Sized, +{ + /// Attempt to create a DSSchema + #[allow(clippy::wrong_self_convention)] + fn to_dfschema(self) -> Result; + + /// Attempt to create a DSSchemaRef + #[allow(clippy::wrong_self_convention)] + fn to_dfschema_ref(self) -> Result { + Ok(Arc::new(self.to_dfschema()?)) + } +} + +impl ToDFSchema for Schema { + #[allow(clippy::wrong_self_convention)] + fn to_dfschema(self) -> Result { + DFSchema::try_from(self) + } +} + +impl ToDFSchema for SchemaRef { + #[allow(clippy::wrong_self_convention)] + fn to_dfschema(self) -> Result { + // Attempt to use the Schema directly if there are no other + // references, otherwise clone + match 
Self::try_unwrap(self) { + Ok(schema) => DFSchema::try_from(schema), + Err(schemaref) => DFSchema::try_from(schemaref.as_ref().clone()), + } + } +} + +impl ToDFSchema for Vec { + fn to_dfschema(self) -> Result { + DFSchema::new(self) + } +} + +impl Display for DFSchema { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "{}", + self.fields + .iter() + .map(|field| field.qualified_name()) + .collect::>() + .join(", ") + ) + } +} + +/// Provides schema information needed by [Expr] methods such as +/// [Expr::nullable] and [Expr::data_type]. +/// +/// Note that this trait is implemented for &[DFSchema] which is +/// widely used in the DataFusion codebase. +pub trait ExprSchema { + /// Is this column reference nullable? + fn nullable(&self, col: &Column) -> Result; + + /// What is the datatype of this column? + fn data_type(&self, col: &Column) -> Result<&DataType>; +} + +// Implement `ExprSchema` for `Arc` +impl> ExprSchema for P { + fn nullable(&self, col: &Column) -> Result { + self.as_ref().nullable(col) + } + + fn data_type(&self, col: &Column) -> Result<&DataType> { + self.as_ref().data_type(col) + } +} + +impl ExprSchema for DFSchema { + fn nullable(&self, col: &Column) -> Result { + Ok(self.field_from_column(col)?.is_nullable()) + } + + fn data_type(&self, col: &Column) -> Result<&DataType> { + Ok(self.field_from_column(col)?.data_type()) + } +} + +/// DFField wraps an Arrow field and adds an optional qualifier +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DFField { + /// Optional qualifier (usually a table or relation name) + qualifier: Option, + /// Arrow field definition + field: Field, +} + +impl DFField { + /// Creates a new `DFField` + pub fn new( + qualifier: Option<&str>, + name: &str, + data_type: DataType, + nullable: bool, + ) -> Self { + DFField { + qualifier: qualifier.map(|s| s.to_owned()), + field: Field::new(name, data_type, nullable), + } + } + + /// Create an unqualified field from an existing Arrow field + pub fn from(field: Field) -> Self { + Self { + qualifier: None, + field, + } + } + + /// Create a qualified field from an existing Arrow field + pub fn from_qualified(qualifier: &str, field: Field) -> Self { + Self { + qualifier: Some(qualifier.to_owned()), + field, + } + } + + /// Returns an immutable reference to the `DFField`'s unqualified name + pub fn name(&self) -> &String { + self.field.name() + } + + /// Returns an immutable reference to the `DFField`'s data-type + pub fn data_type(&self) -> &DataType { + self.field.data_type() + } + + /// Indicates whether this `DFField` supports null values + pub fn is_nullable(&self) -> bool { + self.field.is_nullable() + } + + /// Returns a string to the `DFField`'s qualified name + pub fn qualified_name(&self) -> String { + if let Some(qualifier) = &self.qualifier { + format!("{}.{}", qualifier, self.field.name()) + } else { + self.field.name().to_owned() + } + } + + /// Builds a qualified column based on self + pub fn qualified_column(&self) -> Column { + Column { + relation: self.qualifier.clone(), + name: self.field.name().to_string(), + } + } + + /// Builds an unqualified column based on self + pub fn unqualified_column(&self) -> Column { + Column { + relation: None, + name: self.field.name().to_string(), + } + } + + /// Get the optional qualifier + pub fn qualifier(&self) -> Option<&String> { + self.qualifier.as_ref() + } + + /// Get the arrow field + pub fn field(&self) -> &Field { + &self.field + } + + /// Return field with qualifier stripped + pub fn strip_qualifier(mut self) -> 
Self { + self.qualifier = None; + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::DataType; + + #[test] + fn from_unqualified_field() { + let field = Field::new("c0", DataType::Boolean, true); + let field = DFField::from(field); + assert_eq!("c0", field.name()); + assert_eq!("c0", field.qualified_name()); + } + + #[test] + fn from_qualified_field() { + let field = Field::new("c0", DataType::Boolean, true); + let field = DFField::from_qualified("t1", field); + assert_eq!("c0", field.name()); + assert_eq!("t1.c0", field.qualified_name()); + } + + #[test] + fn from_unqualified_schema() -> Result<()> { + let schema = DFSchema::try_from(test_schema_1())?; + assert_eq!("c0, c1", schema.to_string()); + Ok(()) + } + + #[test] + fn from_qualified_schema() -> Result<()> { + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + assert_eq!("t1.c0, t1.c1", schema.to_string()); + Ok(()) + } + + #[test] + fn from_qualified_schema_into_arrow_schema() -> Result<()> { + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let arrow_schema: Schema = schema.into(); + let expected = "Field { name: \"c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \ + Field { name: \"c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }"; + assert_eq!(expected, arrow_schema.to_string()); + Ok(()) + } + + #[test] + fn join_qualified() -> Result<()> { + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let right = DFSchema::try_from_qualified_schema("t2", &test_schema_1())?; + let join = left.join(&right)?; + assert_eq!("t1.c0, t1.c1, t2.c0, t2.c1", join.to_string()); + // test valid access + assert!(join.field_with_qualified_name("t1", "c0").is_ok()); + assert!(join.field_with_qualified_name("t2", "c0").is_ok()); + // test invalid access + assert!(join.field_with_unqualified_name("c0").is_err()); + assert!(join.field_with_unqualified_name("t1.c0").is_err()); + assert!(join.field_with_unqualified_name("t2.c0").is_err()); + Ok(()) + } + + #[test] + fn join_qualified_duplicate() -> Result<()> { + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let right = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let join = left.join(&right); + assert!(join.is_err()); + assert_eq!( + "Error during planning: Schema contains duplicate \ + qualified field name \'t1.c0\'", + &format!("{}", join.err().unwrap()) + ); + Ok(()) + } + + #[test] + fn join_unqualified_duplicate() -> Result<()> { + let left = DFSchema::try_from(test_schema_1())?; + let right = DFSchema::try_from(test_schema_1())?; + let join = left.join(&right); + assert!(join.is_err()); + assert_eq!( + "Error during planning: Schema contains duplicate \ + unqualified field name \'c0\'", + &format!("{}", join.err().unwrap()) + ); + Ok(()) + } + + #[test] + fn join_mixed() -> Result<()> { + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let right = DFSchema::try_from(test_schema_2())?; + let join = left.join(&right)?; + assert_eq!("t1.c0, t1.c1, c100, c101", join.to_string()); + // test valid access + assert!(join.field_with_qualified_name("t1", "c0").is_ok()); + assert!(join.field_with_unqualified_name("c0").is_ok()); + assert!(join.field_with_unqualified_name("c100").is_ok()); + assert!(join.field_with_name(None, "c100").is_ok()); + // test invalid access + assert!(join.field_with_unqualified_name("t1.c0").is_err()); + 
assert!(join.field_with_unqualified_name("t1.c100").is_err()); + assert!(join.field_with_qualified_name("", "c100").is_err()); + Ok(()) + } + + #[test] + fn join_mixed_duplicate() -> Result<()> { + let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let right = DFSchema::try_from(test_schema_1())?; + let join = left.join(&right); + assert!(join.is_err()); + assert_eq!( + "Error during planning: Schema contains qualified \ + field name \'t1.c0\' and unqualified field name \'c0\' which would be ambiguous", + &format!("{}", join.err().unwrap()) + ); + Ok(()) + } + + #[test] + fn helpful_error_messages() -> Result<()> { + let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; + let expected_help = "Valid fields are \'t1.c0\', \'t1.c1\'."; + assert!(schema + .field_with_qualified_name("x", "y") + .unwrap_err() + .to_string() + .contains(expected_help)); + assert!(schema + .field_with_unqualified_name("y") + .unwrap_err() + .to_string() + .contains(expected_help)); + assert!(schema + .index_of("y") + .unwrap_err() + .to_string() + .contains(expected_help)); + Ok(()) + } + + #[test] + fn into() { + // Demonstrate how to convert back and forth between Schema, SchemaRef, DFSchema, and DFSchemaRef + let arrow_schema = Schema::new(vec![Field::new("c0", DataType::Int64, true)]); + let arrow_schema_ref = Arc::new(arrow_schema.clone()); + + let df_schema = + DFSchema::new(vec![DFField::new(None, "c0", DataType::Int64, true)]).unwrap(); + let df_schema_ref = Arc::new(df_schema.clone()); + + { + let arrow_schema = arrow_schema.clone(); + let arrow_schema_ref = arrow_schema_ref.clone(); + + assert_eq!(df_schema, arrow_schema.to_dfschema().unwrap()); + assert_eq!(df_schema, arrow_schema_ref.to_dfschema().unwrap()); + } + + { + let arrow_schema = arrow_schema.clone(); + let arrow_schema_ref = arrow_schema_ref.clone(); + + assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); + assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); + } + + // Now, consume the refs + assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); + assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); + } + + fn test_schema_1() -> Schema { + Schema::new(vec![ + Field::new("c0", DataType::Boolean, true), + Field::new("c1", DataType::Boolean, true), + ]) + } + + fn test_schema_2() -> Schema { + Schema::new(vec![ + Field::new("c100", DataType::Boolean, true), + Field::new("c101", DataType::Boolean, true), + ]) + } +} diff --git a/datafusion-common/src/lib.rs b/datafusion-common/src/lib.rs index ac8ef623b0cd..11f9bbbb7e82 100644 --- a/datafusion-common/src/lib.rs +++ b/datafusion-common/src/lib.rs @@ -15,6 +15,10 @@ // specific language governing permissions and limitations // under the License. 
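The `DFSchema`, `DFField`, and `Column` types carried over here keep the qualifier handling that the tests above exercise. A short sketch of the lookup behaviour, assuming only the `datafusion-common` exports introduced in this patch (the `qualified_lookup` helper and the two-column schema are made up for illustration):

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{Column, DFSchema, Result};

    fn qualified_lookup() -> Result<()> {
        // Wrap an Arrow schema with a relation qualifier, as a scan of table
        // `t1` would.
        let arrow_schema = Schema::new(vec![
            Field::new("c0", DataType::Boolean, true),
            Field::new("c1", DataType::Boolean, true),
        ]);
        let schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema)?;

        // An unqualified name resolves as long as it is unambiguous ...
        let field = schema.field_with_unqualified_name("c0")?;
        assert_eq!("t1.c0", field.qualified_name());

        // ... and a qualified reference parses and resolves to a field index.
        let col = Column::from_qualified_name("t1.c0");
        assert_eq!(0, schema.index_of_column(&col)?);
        Ok(())
    }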
+mod column; +mod dfschema; mod error; +pub use column::Column; +pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; pub use error::{DataFusionError, Result}; diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 613c8e950c93..d81fa9d2afa6 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -595,6 +595,17 @@ impl LogicalPlanBuilder { self.join_detailed(right, join_type, join_keys, false) } + fn normalize( + plan: &LogicalPlan, + column: impl Into + Clone, + ) -> Result { + let schemas = plan.all_schemas(); + let using_columns = plan.using_columns()?; + column + .into() + .normalize_with_schemas(&schemas, &using_columns) + } + /// Apply a join with on constraint and specified null equality /// If null_equals_null is true then null == null, else null != null pub fn join_detailed( @@ -633,7 +644,10 @@ impl LogicalPlanBuilder { match (l_is_left, l_is_right, r_is_left, r_is_right) { (_, Ok(_), Ok(_), _) => (Ok(r), Ok(l)), (Ok(_), _, _, Ok(_)) => (Ok(l), Ok(r)), - _ => (l.normalize(&self.plan), r.normalize(right)), + _ => ( + Self::normalize(&self.plan, l), + Self::normalize(right, r), + ), } } (Some(lr), None) => { @@ -643,9 +657,12 @@ impl LogicalPlanBuilder { right.schema().field_with_qualified_name(lr, &l.name); match (l_is_left, l_is_right) { - (Ok(_), _) => (Ok(l), r.normalize(right)), - (_, Ok(_)) => (r.normalize(&self.plan), Ok(l)), - _ => (l.normalize(&self.plan), r.normalize(right)), + (Ok(_), _) => (Ok(l), Self::normalize(right, r)), + (_, Ok(_)) => (Self::normalize(&self.plan, r), Ok(l)), + _ => ( + Self::normalize(&self.plan, l), + Self::normalize(right, r), + ), } } (None, Some(rr)) => { @@ -655,22 +672,25 @@ impl LogicalPlanBuilder { right.schema().field_with_qualified_name(rr, &r.name); match (r_is_left, r_is_right) { - (Ok(_), _) => (Ok(r), l.normalize(right)), - (_, Ok(_)) => (l.normalize(&self.plan), Ok(r)), - _ => (l.normalize(&self.plan), r.normalize(right)), + (Ok(_), _) => (Ok(r), Self::normalize(right, l)), + (_, Ok(_)) => (Self::normalize(&self.plan, l), Ok(r)), + _ => ( + Self::normalize(&self.plan, l), + Self::normalize(right, r), + ), } } (None, None) => { let mut swap = false; - let left_key = - l.clone().normalize(&self.plan).or_else(|_| { + let left_key = Self::normalize(&self.plan, l.clone()) + .or_else(|_| { swap = true; - l.normalize(right) + Self::normalize(right, l) }); if swap { - (r.normalize(&self.plan), left_key) + (Self::normalize(&self.plan, r), left_key) } else { - (left_key, r.normalize(right)) + (left_key, Self::normalize(right, r)) } } } @@ -705,11 +725,11 @@ impl LogicalPlanBuilder { let left_keys: Vec = using_keys .clone() .into_iter() - .map(|c| c.into().normalize(&self.plan)) + .map(|c| Self::normalize(&self.plan, c)) .collect::>()?; let right_keys: Vec = using_keys .into_iter() - .map(|c| c.into().normalize(right)) + .map(|c| Self::normalize(right, c)) .collect::>()?; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index 7b6471f64dd7..eb624283ea4f 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -18,671 +18,4 @@ //! DFSchema is an extended schema struct that DataFusion uses to provide support for //! fields with optional relation names. 
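`LogicalPlanBuilder::normalize` above is a thin wrapper over `Column::normalize_with_schemas`, feeding it `plan.all_schemas()` and `plan.using_columns()` so that unqualified join keys pick up the right relation. The following is a hedged sketch of the underlying call with the schemas and (empty) using-column sets supplied directly; the `qualify` helper is invented for illustration:

    use std::collections::HashSet;
    use std::sync::Arc;

    use datafusion_common::{Column, DFSchema, Result};

    // Resolve a possibly-unqualified column reference against one schema,
    // the way the builder does when wiring up join keys. With a schema
    // qualified as `t1`, `Column::from_name("id")` comes back qualified as
    // `t1.id`; a column found in no schema yields a Plan error.
    fn qualify(col: Column, schema: &Arc<DFSchema>) -> Result<Column> {
        let using_columns: Vec<HashSet<Column>> = vec![];
        col.normalize_with_schemas(&[schema], &using_columns)
    }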
-use std::collections::HashSet; -use std::convert::TryFrom; -use std::sync::Arc; - -use crate::error::{DataFusionError, Result}; -use crate::logical_plan::Column; - -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use std::fmt::{Display, Formatter}; - -/// A reference-counted reference to a `DFSchema`. -pub type DFSchemaRef = Arc; - -/// DFSchema wraps an Arrow schema and adds relation names -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DFSchema { - /// Fields - fields: Vec, -} - -impl DFSchema { - /// Creates an empty `DFSchema` - pub fn empty() -> Self { - Self { fields: vec![] } - } - - /// Create a new `DFSchema` - pub fn new(fields: Vec) -> Result { - let mut qualified_names = HashSet::new(); - let mut unqualified_names = HashSet::new(); - - for field in &fields { - if let Some(qualifier) = field.qualifier() { - if !qualified_names.insert((qualifier, field.name())) { - return Err(DataFusionError::Plan(format!( - "Schema contains duplicate qualified field name '{}'", - field.qualified_name() - ))); - } - } else if !unqualified_names.insert(field.name()) { - return Err(DataFusionError::Plan(format!( - "Schema contains duplicate unqualified field name '{}'", - field.name() - ))); - } - } - - // check for mix of qualified and unqualified field with same unqualified name - // note that we need to sort the contents of the HashSet first so that errors are - // deterministic - let mut qualified_names = qualified_names - .iter() - .map(|(l, r)| (l.to_owned(), r.to_owned())) - .collect::>(); - qualified_names.sort_by(|a, b| { - let a = format!("{}.{}", a.0, a.1); - let b = format!("{}.{}", b.0, b.1); - a.cmp(&b) - }); - for (qualifier, name) in &qualified_names { - if unqualified_names.contains(name) { - return Err(DataFusionError::Plan(format!( - "Schema contains qualified field name '{}.{}' \ - and unqualified field name '{}' which would be ambiguous", - qualifier, name, name - ))); - } - } - Ok(Self { fields }) - } - - /// Create a `DFSchema` from an Arrow schema - pub fn try_from_qualified_schema(qualifier: &str, schema: &Schema) -> Result { - Self::new( - schema - .fields() - .iter() - .map(|f| DFField::from_qualified(qualifier, f.clone())) - .collect(), - ) - } - - /// Combine two schemas - pub fn join(&self, schema: &DFSchema) -> Result { - let mut fields = self.fields.clone(); - fields.extend_from_slice(schema.fields().as_slice()); - Self::new(fields) - } - - /// Merge a schema into self - pub fn merge(&mut self, other_schema: &DFSchema) { - for field in other_schema.fields() { - // skip duplicate columns - let duplicated_field = match field.qualifier() { - Some(q) => self.field_with_name(Some(q.as_str()), field.name()).is_ok(), - // for unqualifed columns, check as unqualified name - None => self.field_with_unqualified_name(field.name()).is_ok(), - }; - if !duplicated_field { - self.fields.push(field.clone()); - } - } - } - - /// Get a list of fields - pub fn fields(&self) -> &Vec { - &self.fields - } - - /// Returns an immutable reference of a specific `Field` instance selected using an - /// offset within the internal `fields` vector - pub fn field(&self, i: usize) -> &DFField { - &self.fields[i] - } - - /// Find the index of the column with the given unqualified name - pub fn index_of(&self, name: &str) -> Result { - for i in 0..self.fields.len() { - if self.fields[i].name() == name { - return Ok(i); - } - } - Err(DataFusionError::Plan(format!( - "No field named '{}'. 
Valid fields are {}.", - name, - self.get_field_names() - ))) - } - - fn index_of_column_by_name( - &self, - qualifier: Option<&str>, - name: &str, - ) -> Result { - let mut matches = self - .fields - .iter() - .enumerate() - .filter(|(_, field)| match (qualifier, &field.qualifier) { - // field to lookup is qualified. - // current field is qualified and not shared between relations, compare both - // qualifier and name. - (Some(q), Some(field_q)) => q == field_q && field.name() == name, - // field to lookup is qualified but current field is unqualified. - (Some(_), None) => false, - // field to lookup is unqualified, no need to compare qualifier - (None, Some(_)) | (None, None) => field.name() == name, - }) - .map(|(idx, _)| idx); - match matches.next() { - None => Err(DataFusionError::Plan(format!( - "No field named '{}.{}'. Valid fields are {}.", - qualifier.unwrap_or(""), - name, - self.get_field_names() - ))), - Some(idx) => match matches.next() { - None => Ok(idx), - // found more than one matches - Some(_) => Err(DataFusionError::Internal(format!( - "Ambiguous reference to qualified field named '{}.{}'", - qualifier.unwrap_or(""), - name - ))), - }, - } - } - - /// Find the index of the column with the given qualifier and name - pub fn index_of_column(&self, col: &Column) -> Result { - self.index_of_column_by_name(col.relation.as_deref(), &col.name) - } - - /// Find the field with the given name - pub fn field_with_name( - &self, - qualifier: Option<&str>, - name: &str, - ) -> Result<&DFField> { - if let Some(qualifier) = qualifier { - self.field_with_qualified_name(qualifier, name) - } else { - self.field_with_unqualified_name(name) - } - } - - /// Find all fields match the given name - pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> { - self.fields - .iter() - .filter(|field| field.name() == name) - .collect() - } - - /// Find the field with the given name - pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> { - let matches = self.fields_with_unqualified_name(name); - match matches.len() { - 0 => Err(DataFusionError::Plan(format!( - "No field with unqualified name '{}'. 
Valid fields are {}.", - name, - self.get_field_names() - ))), - 1 => Ok(matches[0]), - _ => Err(DataFusionError::Plan(format!( - "Ambiguous reference to field named '{}'", - name - ))), - } - } - - /// Find the field with the given qualified name - pub fn field_with_qualified_name( - &self, - qualifier: &str, - name: &str, - ) -> Result<&DFField> { - let idx = self.index_of_column_by_name(Some(qualifier), name)?; - Ok(self.field(idx)) - } - - /// Find the field with the given qualified column - pub fn field_from_column(&self, column: &Column) -> Result<&DFField> { - match &column.relation { - Some(r) => self.field_with_qualified_name(r, &column.name), - None => self.field_with_unqualified_name(&column.name), - } - } - - /// Check to see if unqualified field names matches field names in Arrow schema - pub fn matches_arrow_schema(&self, arrow_schema: &Schema) -> bool { - self.fields - .iter() - .zip(arrow_schema.fields().iter()) - .all(|(dffield, arrowfield)| dffield.name() == arrowfield.name()) - } - - /// Strip all field qualifier in schema - pub fn strip_qualifiers(self) -> Self { - DFSchema { - fields: self - .fields - .into_iter() - .map(|f| f.strip_qualifier()) - .collect(), - } - } - - /// Replace all field qualifier with new value in schema - pub fn replace_qualifier(self, qualifier: &str) -> Self { - DFSchema { - fields: self - .fields - .into_iter() - .map(|f| { - DFField::new( - Some(qualifier), - f.name(), - f.data_type().to_owned(), - f.is_nullable(), - ) - }) - .collect(), - } - } - - /// Get comma-seperated list of field names for use in error messages - fn get_field_names(&self) -> String { - self.fields - .iter() - .map(|f| match f.qualifier() { - Some(qualifier) => format!("'{}.{}'", qualifier, f.name()), - None => format!("'{}'", f.name()), - }) - .collect::>() - .join(", ") - } -} - -impl From for Schema { - /// Convert DFSchema into a Schema - fn from(df_schema: DFSchema) -> Self { - Schema::new( - df_schema - .fields - .into_iter() - .map(|f| { - if f.qualifier().is_some() { - Field::new( - f.name().as_str(), - f.data_type().to_owned(), - f.is_nullable(), - ) - } else { - f.field - } - }) - .collect(), - ) - } -} - -impl From<&DFSchema> for Schema { - /// Convert DFSchema reference into a Schema - fn from(df_schema: &DFSchema) -> Self { - Schema::new(df_schema.fields.iter().map(|f| f.field.clone()).collect()) - } -} - -/// Create a `DFSchema` from an Arrow schema -impl TryFrom for DFSchema { - type Error = DataFusionError; - fn try_from(schema: Schema) -> std::result::Result { - Self::new( - schema - .fields() - .iter() - .map(|f| DFField::from(f.clone())) - .collect(), - ) - } -} - -impl From for SchemaRef { - fn from(df_schema: DFSchema) -> Self { - SchemaRef::new(df_schema.into()) - } -} - -/// Convenience trait to convert Schema like things to DFSchema and DFSchemaRef with fewer keystrokes -pub trait ToDFSchema -where - Self: Sized, -{ - /// Attempt to create a DSSchema - #[allow(clippy::wrong_self_convention)] - fn to_dfschema(self) -> Result; - - /// Attempt to create a DSSchemaRef - #[allow(clippy::wrong_self_convention)] - fn to_dfschema_ref(self) -> Result { - Ok(Arc::new(self.to_dfschema()?)) - } -} - -impl ToDFSchema for Schema { - #[allow(clippy::wrong_self_convention)] - fn to_dfschema(self) -> Result { - DFSchema::try_from(self) - } -} - -impl ToDFSchema for SchemaRef { - #[allow(clippy::wrong_self_convention)] - fn to_dfschema(self) -> Result { - // Attempt to use the Schema directly if there are no other - // references, otherwise clone - match 
Self::try_unwrap(self) { - Ok(schema) => DFSchema::try_from(schema), - Err(schemaref) => DFSchema::try_from(schemaref.as_ref().clone()), - } - } -} - -impl ToDFSchema for Vec { - fn to_dfschema(self) -> Result { - DFSchema::new(self) - } -} - -impl Display for DFSchema { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - write!( - f, - "{}", - self.fields - .iter() - .map(|field| field.qualified_name()) - .collect::>() - .join(", ") - ) - } -} - -/// DFField wraps an Arrow field and adds an optional qualifier -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DFField { - /// Optional qualifier (usually a table or relation name) - qualifier: Option, - /// Arrow field definition - field: Field, -} - -impl DFField { - /// Creates a new `DFField` - pub fn new( - qualifier: Option<&str>, - name: &str, - data_type: DataType, - nullable: bool, - ) -> Self { - DFField { - qualifier: qualifier.map(|s| s.to_owned()), - field: Field::new(name, data_type, nullable), - } - } - - /// Create an unqualified field from an existing Arrow field - pub fn from(field: Field) -> Self { - Self { - qualifier: None, - field, - } - } - - /// Create a qualified field from an existing Arrow field - pub fn from_qualified(qualifier: &str, field: Field) -> Self { - Self { - qualifier: Some(qualifier.to_owned()), - field, - } - } - - /// Returns an immutable reference to the `DFField`'s unqualified name - pub fn name(&self) -> &String { - self.field.name() - } - - /// Returns an immutable reference to the `DFField`'s data-type - pub fn data_type(&self) -> &DataType { - self.field.data_type() - } - - /// Indicates whether this `DFField` supports null values - pub fn is_nullable(&self) -> bool { - self.field.is_nullable() - } - - /// Returns a string to the `DFField`'s qualified name - pub fn qualified_name(&self) -> String { - if let Some(qualifier) = &self.qualifier { - format!("{}.{}", qualifier, self.field.name()) - } else { - self.field.name().to_owned() - } - } - - /// Builds a qualified column based on self - pub fn qualified_column(&self) -> Column { - Column { - relation: self.qualifier.clone(), - name: self.field.name().to_string(), - } - } - - /// Builds an unqualified column based on self - pub fn unqualified_column(&self) -> Column { - Column { - relation: None, - name: self.field.name().to_string(), - } - } - - /// Get the optional qualifier - pub fn qualifier(&self) -> Option<&String> { - self.qualifier.as_ref() - } - - /// Get the arrow field - pub fn field(&self) -> &Field { - &self.field - } - - /// Return field with qualifier stripped - pub fn strip_qualifier(mut self) -> Self { - self.qualifier = None; - self - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::datatypes::DataType; - - #[test] - fn from_unqualified_field() { - let field = Field::new("c0", DataType::Boolean, true); - let field = DFField::from(field); - assert_eq!("c0", field.name()); - assert_eq!("c0", field.qualified_name()); - } - - #[test] - fn from_qualified_field() { - let field = Field::new("c0", DataType::Boolean, true); - let field = DFField::from_qualified("t1", field); - assert_eq!("c0", field.name()); - assert_eq!("t1.c0", field.qualified_name()); - } - - #[test] - fn from_unqualified_schema() -> Result<()> { - let schema = DFSchema::try_from(test_schema_1())?; - assert_eq!("c0, c1", schema.to_string()); - Ok(()) - } - - #[test] - fn from_qualified_schema() -> Result<()> { - let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert_eq!("t1.c0, t1.c1", schema.to_string()); - Ok(()) 
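The `DFField` constructors and the `ToDFSchema` conversions above are the pieces most downstream code touches. Below is a minimal sketch of how they fit together, assuming the `datafusion_common` paths this series moves them to; the `t1`, `id`, and `name` identifiers are illustrative only.

```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{DFField, DFSchema, Result, ToDFSchema};

fn main() -> Result<()> {
    // A qualified and an unqualified field.
    let id = DFField::new(Some("t1"), "id", DataType::Int64, false);
    let name = DFField::new(None, "name", DataType::Utf8, true);
    assert_eq!("t1.id", id.qualified_name());
    assert_eq!("name", name.qualified_name());

    // DFSchema displays as a comma-separated list of qualified names.
    let df_schema = DFSchema::new(vec![id, name])?;
    assert_eq!("t1.id, name", df_schema.to_string());

    // ToDFSchema converts a plain Arrow schema into an unqualified DFSchema.
    let arrow_schema = Schema::new(vec![Field::new("c0", DataType::Boolean, true)]);
    let converted = arrow_schema.to_dfschema()?;
    assert_eq!("c0", converted.to_string());
    Ok(())
}
```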
- } - - #[test] - fn from_qualified_schema_into_arrow_schema() -> Result<()> { - let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let arrow_schema: Schema = schema.into(); - let expected = "Field { name: \"c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \ - Field { name: \"c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }"; - assert_eq!(expected, arrow_schema.to_string()); - Ok(()) - } - - #[test] - fn join_qualified() -> Result<()> { - let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let right = DFSchema::try_from_qualified_schema("t2", &test_schema_1())?; - let join = left.join(&right)?; - assert_eq!("t1.c0, t1.c1, t2.c0, t2.c1", join.to_string()); - // test valid access - assert!(join.field_with_qualified_name("t1", "c0").is_ok()); - assert!(join.field_with_qualified_name("t2", "c0").is_ok()); - // test invalid access - assert!(join.field_with_unqualified_name("c0").is_err()); - assert!(join.field_with_unqualified_name("t1.c0").is_err()); - assert!(join.field_with_unqualified_name("t2.c0").is_err()); - Ok(()) - } - - #[test] - fn join_qualified_duplicate() -> Result<()> { - let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let right = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let join = left.join(&right); - assert!(join.is_err()); - assert_eq!( - "Error during planning: Schema contains duplicate \ - qualified field name \'t1.c0\'", - &format!("{}", join.err().unwrap()) - ); - Ok(()) - } - - #[test] - fn join_unqualified_duplicate() -> Result<()> { - let left = DFSchema::try_from(test_schema_1())?; - let right = DFSchema::try_from(test_schema_1())?; - let join = left.join(&right); - assert!(join.is_err()); - assert_eq!( - "Error during planning: Schema contains duplicate \ - unqualified field name \'c0\'", - &format!("{}", join.err().unwrap()) - ); - Ok(()) - } - - #[test] - fn join_mixed() -> Result<()> { - let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let right = DFSchema::try_from(test_schema_2())?; - let join = left.join(&right)?; - assert_eq!("t1.c0, t1.c1, c100, c101", join.to_string()); - // test valid access - assert!(join.field_with_qualified_name("t1", "c0").is_ok()); - assert!(join.field_with_unqualified_name("c0").is_ok()); - assert!(join.field_with_unqualified_name("c100").is_ok()); - assert!(join.field_with_name(None, "c100").is_ok()); - // test invalid access - assert!(join.field_with_unqualified_name("t1.c0").is_err()); - assert!(join.field_with_unqualified_name("t1.c100").is_err()); - assert!(join.field_with_qualified_name("", "c100").is_err()); - Ok(()) - } - - #[test] - fn join_mixed_duplicate() -> Result<()> { - let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let right = DFSchema::try_from(test_schema_1())?; - let join = left.join(&right); - assert!(join.is_err()); - assert_eq!( - "Error during planning: Schema contains qualified \ - field name \'t1.c0\' and unqualified field name \'c0\' which would be ambiguous", - &format!("{}", join.err().unwrap()) - ); - Ok(()) - } - - #[test] - fn helpful_error_messages() -> Result<()> { - let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let expected_help = "Valid fields are \'t1.c0\', \'t1.c1\'."; - assert!(schema - .field_with_qualified_name("x", "y") - .unwrap_err() - .to_string() - .contains(expected_help)); - assert!(schema - 
.field_with_unqualified_name("y") - .unwrap_err() - .to_string() - .contains(expected_help)); - assert!(schema - .index_of("y") - .unwrap_err() - .to_string() - .contains(expected_help)); - Ok(()) - } - - #[test] - fn into() { - // Demonstrate how to convert back and forth between Schema, SchemaRef, DFSchema, and DFSchemaRef - let arrow_schema = Schema::new(vec![Field::new("c0", DataType::Int64, true)]); - let arrow_schema_ref = Arc::new(arrow_schema.clone()); - - let df_schema = - DFSchema::new(vec![DFField::new(None, "c0", DataType::Int64, true)]).unwrap(); - let df_schema_ref = Arc::new(df_schema.clone()); - - { - let arrow_schema = arrow_schema.clone(); - let arrow_schema_ref = arrow_schema_ref.clone(); - - assert_eq!(df_schema, arrow_schema.to_dfschema().unwrap()); - assert_eq!(df_schema, arrow_schema_ref.to_dfschema().unwrap()); - } - - { - let arrow_schema = arrow_schema.clone(); - let arrow_schema_ref = arrow_schema_ref.clone(); - - assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); - assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); - } - - // Now, consume the refs - assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); - assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); - } - - fn test_schema_1() -> Schema { - Schema::new(vec![ - Field::new("c0", DataType::Boolean, true), - Field::new("c1", DataType::Boolean, true), - ]) - } - - fn test_schema_2() -> Schema { - Schema::new(vec![ - Field::new("c100", DataType::Boolean, true), - Field::new("c101", DataType::Boolean, true), - ]) - } -} +pub use datafusion_common::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index c2763d097e85..4b539a814551 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -34,152 +34,14 @@ use crate::physical_plan::{ use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::{compute::can_cast_types, datatypes::DataType}; +pub use datafusion_common::{Column, ExprSchema}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; use std::collections::{HashMap, HashSet}; -use std::convert::Infallible; use std::fmt; use std::hash::{BuildHasher, Hash, Hasher}; use std::ops::Not; -use std::str::FromStr; use std::sync::Arc; -/// A named reference to a qualified field in a schema. -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct Column { - /// relation/table name. - pub relation: Option, - /// field/column name. - pub name: String, -} - -impl Column { - /// Create Column from unqualified name. 
- pub fn from_name(name: impl Into) -> Self { - Self { - relation: None, - name: name.into(), - } - } - - /// Deserialize a fully qualified name string into a column - pub fn from_qualified_name(flat_name: &str) -> Self { - use sqlparser::tokenizer::Token; - - let dialect = sqlparser::dialect::GenericDialect {}; - let mut tokenizer = sqlparser::tokenizer::Tokenizer::new(&dialect, flat_name); - if let Ok(tokens) = tokenizer.tokenize() { - if let [Token::Word(relation), Token::Period, Token::Word(name)] = - tokens.as_slice() - { - return Column { - relation: Some(relation.value.clone()), - name: name.value.clone(), - }; - } - } - // any expression that's not in the form of `foo.bar` will be treated as unqualified column - // name - Column { - relation: None, - name: String::from(flat_name), - } - } - - /// Serialize column into a flat name string - pub fn flat_name(&self) -> String { - match &self.relation { - Some(r) => format!("{}.{}", r, self.name), - None => self.name.clone(), - } - } - - /// Normalizes `self` if is unqualified (has no relation name) - /// with an explicit qualifier from the first matching input - /// schemas. - /// - /// For example, `foo` will be normalized to `t.foo` if there is a - /// column named `foo` in a relation named `t` found in `schemas` - pub fn normalize(self, plan: &LogicalPlan) -> Result { - let schemas = plan.all_schemas(); - let using_columns = plan.using_columns()?; - self.normalize_with_schemas(&schemas, &using_columns) - } - - // Internal implementation of normalize - fn normalize_with_schemas( - self, - schemas: &[&Arc], - using_columns: &[HashSet], - ) -> Result { - if self.relation.is_some() { - return Ok(self); - } - - for schema in schemas { - let fields = schema.fields_with_unqualified_name(&self.name); - match fields.len() { - 0 => continue, - 1 => { - return Ok(fields[0].qualified_column()); - } - _ => { - // More than 1 fields in this schema have their names set to self.name. - // - // This should only happen when a JOIN query with USING constraint references - // join columns using unqualified column name. For example: - // - // ```sql - // SELECT id FROM t1 JOIN t2 USING(id) - // ``` - // - // In this case, both `t1.id` and `t2.id` will match unqualified column `id`. - // We will use the relation from the first matched field to normalize self. - - // Compare matched fields with one USING JOIN clause at a time - for using_col in using_columns { - let all_matched = fields - .iter() - .all(|f| using_col.contains(&f.qualified_column())); - // All matched fields belong to the same using column set, in orther words - // the same join clause. We simply pick the qualifer from the first match. - if all_matched { - return Ok(fields[0].qualified_column()); - } - } - } - } - } - - Err(DataFusionError::Plan(format!( - "Column {} not found in provided schemas", - self - ))) - } -} - -impl From<&str> for Column { - fn from(c: &str) -> Self { - Self::from_qualified_name(c) - } -} - -impl FromStr for Column { - type Err = Infallible; - - fn from_str(s: &str) -> std::result::Result { - Ok(s.into()) - } -} - -impl fmt::Display for Column { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match &self.relation { - Some(r) => write!(f, "#{}.{}", r, self.name), - None => write!(f, "#{}", self.name), - } - } -} - /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS /// int)`. 
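The `Column` helpers removed above (and re-exported from `datafusion_common`) split and re-assemble qualified names. A small sketch of that behavior, using a hypothetical `t1.id` column:

```rust
use datafusion_common::Column;

fn main() {
    // "relation.name" strings are split into relation + column name...
    let qualified = Column::from_qualified_name("t1.id");
    assert_eq!(Some("t1".to_string()), qualified.relation);
    assert_eq!("id", qualified.name);
    assert_eq!("t1.id", qualified.flat_name());
    // ...and Display prefixes the flat name with '#', as seen in plan output.
    assert_eq!("#t1.id", qualified.to_string());

    // Anything that is not of the form `foo.bar` stays unqualified.
    let unqualified = Column::from_name("id");
    assert_eq!(None, unqualified.relation);
    assert_eq!("#id", unqualified.to_string());
}
```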
@@ -392,40 +254,6 @@ impl PartialOrd for Expr { } } -/// Provides schema information needed by [Expr] methods such as -/// [Expr::nullable] and [Expr::data_type]. -/// -/// Note that this trait is implemented for &[DFSchema] which is -/// widely used in the DataFusion codebase. -pub trait ExprSchema { - /// Is this column reference nullable? - fn nullable(&self, col: &Column) -> Result; - - /// What is the datatype of this column? - fn data_type(&self, col: &Column) -> Result<&DataType>; -} - -// Implement `ExprSchema` for `Arc` -impl> ExprSchema for P { - fn nullable(&self, col: &Column) -> Result { - self.as_ref().nullable(col) - } - - fn data_type(&self, col: &Column) -> Result<&DataType> { - self.as_ref().data_type(col) - } -} - -impl ExprSchema for DFSchema { - fn nullable(&self, col: &Column) -> Result { - Ok(self.field_from_column(col)?.is_nullable()) - } - - fn data_type(&self, col: &Column) -> Result<&DataType> { - Ok(self.field_from_column(col)?.data_type()) - } -} - impl Expr { /// Returns the [arrow::datatypes::DataType] of the expression /// based on [ExprSchema] From a39a2231efe7d2ac52e80853df13394026759c98 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 7 Feb 2022 23:20:54 +0800 Subject: [PATCH 39/50] add datafusion-expr module (#1759) --- Cargo.toml | 1 + datafusion-expr/Cargo.toml | 39 ++++ datafusion-expr/README.md | 24 +++ datafusion-expr/src/aggregate_function.rs | 93 ++++++++ datafusion-expr/src/lib.rs | 22 ++ datafusion-expr/src/window_function.rs | 204 ++++++++++++++++++ datafusion/Cargo.toml | 1 + datafusion/src/physical_plan/aggregates.rs | 76 +------ .../src/physical_plan/window_functions.rs | 186 +--------------- 9 files changed, 390 insertions(+), 256 deletions(-) create mode 100644 datafusion-expr/Cargo.toml create mode 100644 datafusion-expr/README.md create mode 100644 datafusion-expr/src/aggregate_function.rs create mode 100644 datafusion-expr/src/lib.rs create mode 100644 datafusion-expr/src/window_function.rs diff --git a/Cargo.toml b/Cargo.toml index 81f6bb59f2d0..f74f53ced323 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "datafusion", "datafusion-common", + "datafusion-expr", "datafusion-cli", "datafusion-examples", "benchmarks", diff --git a/datafusion-expr/Cargo.toml b/datafusion-expr/Cargo.toml new file mode 100644 index 000000000000..c3be893ae87e --- /dev/null +++ b/datafusion-expr/Cargo.toml @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
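For reference, the `ExprSchema` trait relocated above is what lets `Expr::data_type` and `Expr::nullable` stay schema-agnostic, with `DFSchema` as its main implementor. A minimal sketch under that assumption, using the `datafusion_common` re-exports:

```rust
use arrow::datatypes::DataType;
use datafusion_common::{Column, DFField, DFSchema, ExprSchema, Result};

// Works with anything that implements ExprSchema, not just DFSchema.
fn describe(schema: &impl ExprSchema, col: &Column) -> Result<(DataType, bool)> {
    Ok((schema.data_type(col)?.clone(), schema.nullable(col)?))
}

fn main() -> Result<()> {
    let schema =
        DFSchema::new(vec![DFField::new(Some("t"), "id", DataType::Int64, false)])?;
    let (dt, nullable) = describe(&schema, &Column::from_qualified_name("t.id"))?;
    assert_eq!(DataType::Int64, dt);
    assert!(!nullable);
    Ok(())
}
```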
+ +[package] +name = "datafusion-expr" +description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model" +version = "6.0.0" +homepage = "https://github.com/apache/arrow-datafusion" +repository = "https://github.com/apache/arrow-datafusion" +readme = "../README.md" +authors = ["Apache Arrow "] +license = "Apache-2.0" +keywords = [ "arrow", "query", "sql" ] +edition = "2021" +rust-version = "1.58" + +[lib] +name = "datafusion_expr" +path = "src/lib.rs" + +[features] + +[dependencies] +datafusion-common = { path = "../datafusion-common", version = "6.0.0" } +arrow = { version = "8.0.0", features = ["prettyprint"] } diff --git a/datafusion-expr/README.md b/datafusion-expr/README.md new file mode 100644 index 000000000000..25ac79c223c1 --- /dev/null +++ b/datafusion-expr/README.md @@ -0,0 +1,24 @@ + + +# DataFusion Expr + +This is an internal module for fundamental expression types of [DataFusion][df]. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion-expr/src/aggregate_function.rs b/datafusion-expr/src/aggregate_function.rs new file mode 100644 index 000000000000..8f12e88bf1a2 --- /dev/null +++ b/datafusion-expr/src/aggregate_function.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::{DataFusionError, Result}; +use std::{fmt, str::FromStr}; + +/// Enum of all built-in aggregate functions +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum AggregateFunction { + /// count + Count, + /// sum + Sum, + /// min + Min, + /// max + Max, + /// avg + Avg, + /// Approximate aggregate function + ApproxDistinct, + /// array_agg + ArrayAgg, + /// Variance (Sample) + Variance, + /// Variance (Population) + VariancePop, + /// Standard Deviation (Sample) + Stddev, + /// Standard Deviation (Population) + StddevPop, + /// Covariance (Sample) + Covariance, + /// Covariance (Population) + CovariancePop, + /// Correlation + Correlation, + /// Approximate continuous percentile function + ApproxPercentileCont, +} + +impl fmt::Display for AggregateFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // uppercase of the debug. 
+ write!(f, "{}", format!("{:?}", self).to_uppercase()) + } +} + +impl FromStr for AggregateFunction { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + Ok(match name { + "min" => AggregateFunction::Min, + "max" => AggregateFunction::Max, + "count" => AggregateFunction::Count, + "avg" => AggregateFunction::Avg, + "sum" => AggregateFunction::Sum, + "approx_distinct" => AggregateFunction::ApproxDistinct, + "array_agg" => AggregateFunction::ArrayAgg, + "var" => AggregateFunction::Variance, + "var_samp" => AggregateFunction::Variance, + "var_pop" => AggregateFunction::VariancePop, + "stddev" => AggregateFunction::Stddev, + "stddev_samp" => AggregateFunction::Stddev, + "stddev_pop" => AggregateFunction::StddevPop, + "covar" => AggregateFunction::Covariance, + "covar_samp" => AggregateFunction::Covariance, + "covar_pop" => AggregateFunction::CovariancePop, + "corr" => AggregateFunction::Correlation, + "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, + _ => { + return Err(DataFusionError::Plan(format!( + "There is no built-in function named {}", + name + ))); + } + }) + } +} diff --git a/datafusion-expr/src/lib.rs b/datafusion-expr/src/lib.rs new file mode 100644 index 000000000000..b6eaaf7c6659 --- /dev/null +++ b/datafusion-expr/src/lib.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_function; +mod window_function; + +pub use aggregate_function::AggregateFunction; +pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion-expr/src/window_function.rs b/datafusion-expr/src/window_function.rs new file mode 100644 index 000000000000..59523d6540b2 --- /dev/null +++ b/datafusion-expr/src/window_function.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
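The `AggregateFunction` parsing and display logic that now lives in `datafusion-expr` round-trips lowercase SQL names. A small sketch of the behavior shown above; the `var_samp` alias comes straight from the `FromStr` match:

```rust
use datafusion_expr::AggregateFunction;
use std::str::FromStr;

fn main() {
    // FromStr matches the lowercase SQL names, including aliases such as var_samp.
    assert_eq!(
        AggregateFunction::Count,
        AggregateFunction::from_str("count").unwrap()
    );
    assert_eq!(
        AggregateFunction::Variance,
        AggregateFunction::from_str("var_samp").unwrap()
    );
    // Display upper-cases the variant name.
    assert_eq!("COUNT", AggregateFunction::Count.to_string());
    // Unknown names are reported as a planning error.
    assert!(AggregateFunction::from_str("no_such_fn").is_err());
}
```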
+ +use crate::aggregate_function::AggregateFunction; +use datafusion_common::{DataFusionError, Result}; +use std::{fmt, str::FromStr}; + +/// WindowFunction +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum WindowFunction { + /// window function that leverages an aggregate function + AggregateFunction(AggregateFunction), + /// window function that leverages a built-in window function + BuiltInWindowFunction(BuiltInWindowFunction), +} + +impl FromStr for WindowFunction { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + let name = name.to_lowercase(); + if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { + Ok(WindowFunction::AggregateFunction(aggregate)) + } else if let Ok(built_in_function) = + BuiltInWindowFunction::from_str(name.as_str()) + { + Ok(WindowFunction::BuiltInWindowFunction(built_in_function)) + } else { + Err(DataFusionError::Plan(format!( + "There is no window function named {}", + name + ))) + } + } +} + +impl fmt::Display for BuiltInWindowFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + BuiltInWindowFunction::RowNumber => write!(f, "ROW_NUMBER"), + BuiltInWindowFunction::Rank => write!(f, "RANK"), + BuiltInWindowFunction::DenseRank => write!(f, "DENSE_RANK"), + BuiltInWindowFunction::PercentRank => write!(f, "PERCENT_RANK"), + BuiltInWindowFunction::CumeDist => write!(f, "CUME_DIST"), + BuiltInWindowFunction::Ntile => write!(f, "NTILE"), + BuiltInWindowFunction::Lag => write!(f, "LAG"), + BuiltInWindowFunction::Lead => write!(f, "LEAD"), + BuiltInWindowFunction::FirstValue => write!(f, "FIRST_VALUE"), + BuiltInWindowFunction::LastValue => write!(f, "LAST_VALUE"), + BuiltInWindowFunction::NthValue => write!(f, "NTH_VALUE"), + } + } +} + +impl fmt::Display for WindowFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFunction::AggregateFunction(fun) => fun.fmt(f), + WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), + } + } +} + +/// An aggregate function that is part of a built-in window function +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum BuiltInWindowFunction { + /// number of the current row within its partition, counting from 1 + RowNumber, + /// rank of the current row with gaps; same as row_number of its first peer + Rank, + /// ank of the current row without gaps; this function counts peer groups + DenseRank, + /// relative rank of the current row: (rank - 1) / (total rows - 1) + PercentRank, + /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) + CumeDist, + /// integer ranging from 1 to the argument value, dividing the partition as equally as possible + Ntile, + /// returns value evaluated at the row that is offset rows before the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null + Lag, + /// returns value evaluated at the row that is offset rows after the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. 
+ /// If omitted, offset defaults to 1 and default to null + Lead, + /// returns value evaluated at the row that is the first row of the window frame + FirstValue, + /// returns value evaluated at the row that is the last row of the window frame + LastValue, + /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row + NthValue, +} + +impl FromStr for BuiltInWindowFunction { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + Ok(match name.to_uppercase().as_str() { + "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, + "RANK" => BuiltInWindowFunction::Rank, + "DENSE_RANK" => BuiltInWindowFunction::DenseRank, + "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, + "CUME_DIST" => BuiltInWindowFunction::CumeDist, + "NTILE" => BuiltInWindowFunction::Ntile, + "LAG" => BuiltInWindowFunction::Lag, + "LEAD" => BuiltInWindowFunction::Lead, + "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, + "LAST_VALUE" => BuiltInWindowFunction::LastValue, + "NTH_VALUE" => BuiltInWindowFunction::NthValue, + _ => { + return Err(DataFusionError::Plan(format!( + "There is no built-in window function named {}", + name + ))) + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_window_function_case_insensitive() -> Result<()> { + let names = vec![ + "row_number", + "rank", + "dense_rank", + "percent_rank", + "cume_dist", + "ntile", + "lag", + "lead", + "first_value", + "last_value", + "nth_value", + "min", + "max", + "count", + "avg", + "sum", + ]; + for name in names { + let fun = WindowFunction::from_str(name)?; + let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?; + assert_eq!(fun, fun2); + assert_eq!(fun.to_string(), name.to_uppercase()); + } + Ok(()) + } + + #[test] + fn test_window_function_from_str() -> Result<()> { + assert_eq!( + WindowFunction::from_str("max")?, + WindowFunction::AggregateFunction(AggregateFunction::Max) + ); + assert_eq!( + WindowFunction::from_str("min")?, + WindowFunction::AggregateFunction(AggregateFunction::Min) + ); + assert_eq!( + WindowFunction::from_str("avg")?, + WindowFunction::AggregateFunction(AggregateFunction::Avg) + ); + assert_eq!( + WindowFunction::from_str("cume_dist")?, + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist) + ); + assert_eq!( + WindowFunction::from_str("first_value")?, + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue) + ); + assert_eq!( + WindowFunction::from_str("LAST_value")?, + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue) + ); + assert_eq!( + WindowFunction::from_str("LAG")?, + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag) + ); + assert_eq!( + WindowFunction::from_str("LEAD")?, + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead) + ); + Ok(()) + } +} diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 6df852257e16..6092c06141ee 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -51,6 +51,7 @@ avro = ["avro-rs", "num-traits", "datafusion-common/avro"] [dependencies] datafusion-common = { path = "../datafusion-common", version = "6.0.0" } +datafusion-expr = { path = "../datafusion-expr", version = "6.0.0" } ahash = { version = "0.7", default-features = false } hashbrown = { version = "0.12", features = ["raw"] } arrow = { version = "8.0.0", features = ["prettyprint"] } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 
8fc94d386014..a1531d4a7b83 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -38,7 +38,7 @@ use expressions::{ avg_return_type, correlation_return_type, covariance_return_type, stddev_return_type, sum_return_type, variance_return_type, }; -use std::{fmt, str::FromStr, sync::Arc}; +use std::sync::Arc; /// the implementation of an aggregate function pub type AccumulatorFunctionImplementation = @@ -49,79 +49,7 @@ pub type AccumulatorFunctionImplementation = pub type StateTypeFunction = Arc Result>> + Send + Sync>; -/// Enum of all built-in aggregate functions -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] -pub enum AggregateFunction { - /// count - Count, - /// sum - Sum, - /// min - Min, - /// max - Max, - /// avg - Avg, - /// Approximate aggregate function - ApproxDistinct, - /// array_agg - ArrayAgg, - /// Variance (Sample) - Variance, - /// Variance (Population) - VariancePop, - /// Standard Deviation (Sample) - Stddev, - /// Standard Deviation (Population) - StddevPop, - /// Covariance (Sample) - Covariance, - /// Covariance (Population) - CovariancePop, - /// Correlation - Correlation, - /// Approximate continuous percentile function - ApproxPercentileCont, -} - -impl fmt::Display for AggregateFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // uppercase of the debug. - write!(f, "{}", format!("{:?}", self).to_uppercase()) - } -} - -impl FromStr for AggregateFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name { - "min" => AggregateFunction::Min, - "max" => AggregateFunction::Max, - "count" => AggregateFunction::Count, - "avg" => AggregateFunction::Avg, - "sum" => AggregateFunction::Sum, - "approx_distinct" => AggregateFunction::ApproxDistinct, - "array_agg" => AggregateFunction::ArrayAgg, - "var" => AggregateFunction::Variance, - "var_samp" => AggregateFunction::Variance, - "var_pop" => AggregateFunction::VariancePop, - "stddev" => AggregateFunction::Stddev, - "stddev_samp" => AggregateFunction::Stddev, - "stddev_pop" => AggregateFunction::StddevPop, - "covar" => AggregateFunction::Covariance, - "covar_samp" => AggregateFunction::Covariance, - "covar_pop" => AggregateFunction::CovariancePop, - "corr" => AggregateFunction::Correlation, - "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, - _ => { - return Err(DataFusionError::Plan(format!( - "There is no built-in function named {}", - name - ))); - } - }) - } -} +pub use datafusion_expr::AggregateFunction; /// Returns the datatype of the aggregate function. /// This is used to get the returned data type for aggregate expr. 
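Because `physical_plan::aggregates` now re-exports the enum rather than defining it, both import paths name the same type. A hypothetical downstream snippet illustrating that no conversion is needed:

```rust
use datafusion::physical_plan::aggregates::AggregateFunction as ReExported;
use datafusion_expr::AggregateFunction;

fn main() {
    // Both paths resolve to the identical enum, so existing code keeps compiling.
    let f: AggregateFunction = ReExported::Max;
    assert_eq!(AggregateFunction::Max, f);
}
```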
diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index 178a55aa05ec..b8cc96a50490 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -23,130 +23,17 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::functions::{TypeSignature, Volatility}; use crate::physical_plan::{ - aggregates, aggregates::AggregateFunction, functions::Signature, - type_coercion::data_types, windows::find_ranges_in_range, PhysicalExpr, + aggregates, functions::Signature, type_coercion::data_types, + windows::find_ranges_in_range, PhysicalExpr, }; use arrow::array::ArrayRef; use arrow::datatypes::DataType; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; +pub use datafusion_expr::{BuiltInWindowFunction, WindowFunction}; use std::any::Any; use std::ops::Range; use std::sync::Arc; -use std::{fmt, str::FromStr}; - -/// WindowFunction -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum WindowFunction { - /// window function that leverages an aggregate function - AggregateFunction(AggregateFunction), - /// window function that leverages a built-in window function - BuiltInWindowFunction(BuiltInWindowFunction), -} - -impl FromStr for WindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - let name = name.to_lowercase(); - if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { - Ok(WindowFunction::AggregateFunction(aggregate)) - } else if let Ok(built_in_function) = - BuiltInWindowFunction::from_str(name.as_str()) - { - Ok(WindowFunction::BuiltInWindowFunction(built_in_function)) - } else { - Err(DataFusionError::Plan(format!( - "There is no window function named {}", - name - ))) - } - } -} - -impl fmt::Display for BuiltInWindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - BuiltInWindowFunction::RowNumber => write!(f, "ROW_NUMBER"), - BuiltInWindowFunction::Rank => write!(f, "RANK"), - BuiltInWindowFunction::DenseRank => write!(f, "DENSE_RANK"), - BuiltInWindowFunction::PercentRank => write!(f, "PERCENT_RANK"), - BuiltInWindowFunction::CumeDist => write!(f, "CUME_DIST"), - BuiltInWindowFunction::Ntile => write!(f, "NTILE"), - BuiltInWindowFunction::Lag => write!(f, "LAG"), - BuiltInWindowFunction::Lead => write!(f, "LEAD"), - BuiltInWindowFunction::FirstValue => write!(f, "FIRST_VALUE"), - BuiltInWindowFunction::LastValue => write!(f, "LAST_VALUE"), - BuiltInWindowFunction::NthValue => write!(f, "NTH_VALUE"), - } - } -} - -impl fmt::Display for WindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.fmt(f), - WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), - } - } -} - -/// An aggregate function that is part of a built-in window function -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum BuiltInWindowFunction { - /// number of the current row within its partition, counting from 1 - RowNumber, - /// rank of the current row with gaps; same as row_number of its first peer - Rank, - /// ank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, - /// 
returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lead, - /// returns value evaluated at the row that is the first row of the window frame - FirstValue, - /// returns value evaluated at the row that is the last row of the window frame - LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row - NthValue, -} - -impl FromStr for BuiltInWindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name.to_uppercase().as_str() { - "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, - "CUME_DIST" => BuiltInWindowFunction::CumeDist, - "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, - "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, - "LAST_VALUE" => BuiltInWindowFunction::LastValue, - "NTH_VALUE" => BuiltInWindowFunction::NthValue, - _ => { - return Err(DataFusionError::Plan(format!( - "There is no built-in window function named {}", - name - ))) - } - }) - } -} /// Returns the datatype of the window function pub fn return_type( @@ -303,72 +190,7 @@ pub(crate) trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { #[cfg(test)] mod tests { use super::*; - - #[test] - fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "row_number", - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", - "lag", - "lead", - "first_value", - "last_value", - "nth_value", - "min", - "max", - "count", - "avg", - "sum", - ]; - for name in names { - let fun = WindowFunction::from_str(name)?; - let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?; - assert_eq!(fun, fun2); - assert_eq!(fun.to_string(), name.to_uppercase()); - } - Ok(()) - } - - #[test] - fn test_window_function_from_str() -> Result<()> { - assert_eq!( - WindowFunction::from_str("max")?, - WindowFunction::AggregateFunction(AggregateFunction::Max) - ); - assert_eq!( - WindowFunction::from_str("min")?, - WindowFunction::AggregateFunction(AggregateFunction::Min) - ); - assert_eq!( - WindowFunction::from_str("avg")?, - WindowFunction::AggregateFunction(AggregateFunction::Avg) - ); - assert_eq!( - WindowFunction::from_str("cume_dist")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist) - ); - assert_eq!( - WindowFunction::from_str("first_value")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue) - ); - assert_eq!( - WindowFunction::from_str("LAST_value")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue) - ); - assert_eq!( - WindowFunction::from_str("LAG")?, - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag) - ); - assert_eq!( - WindowFunction::from_str("LEAD")?, 
- WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead) - ); - Ok(()) - } + use std::str::FromStr; #[test] fn test_count_return_type() -> Result<()> { From 2ec34cfbc747d6f40bfc1cb3befa21dc080edbf2 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Tue, 8 Feb 2022 00:09:03 +0800 Subject: [PATCH 40/50] move column, dfschema, etc. to common module (#1760) --- datafusion-common/Cargo.toml | 1 + datafusion-common/src/error.rs | 11 - datafusion-common/src/lib.rs | 6 + .../src/pyarrow.rs | 18 +- datafusion-common/src/scalar.rs | 1959 +++++++++++++++++ datafusion/src/lib.rs | 3 - datafusion/src/scalar.rs | 1940 +--------------- 7 files changed, 1986 insertions(+), 1952 deletions(-) rename {datafusion => datafusion-common}/src/pyarrow.rs (92%) create mode 100644 datafusion-common/src/scalar.rs diff --git a/datafusion-common/Cargo.toml b/datafusion-common/Cargo.toml index 9c05d8095caf..350d548f9366 100644 --- a/datafusion-common/Cargo.toml +++ b/datafusion-common/Cargo.toml @@ -42,3 +42,4 @@ parquet = { version = "8.0.0", features = ["arrow"] } avro-rs = { version = "0.13", features = ["snappy"], optional = true } pyo3 = { version = "0.15", optional = true } sqlparser = "0.13" +ordered-float = "2.10" diff --git a/datafusion-common/src/error.rs b/datafusion-common/src/error.rs index ee2e61892fd4..93978db1a1e3 100644 --- a/datafusion-common/src/error.rs +++ b/datafusion-common/src/error.rs @@ -26,10 +26,6 @@ use arrow::error::ArrowError; #[cfg(feature = "avro")] use avro_rs::Error as AvroError; use parquet::errors::ParquetError; -#[cfg(feature = "pyarrow")] -use pyo3::exceptions::PyException; -#[cfg(feature = "pyarrow")] -use pyo3::prelude::PyErr; use sqlparser::parser::ParserError; /// Result type for operations that could result in an [DataFusionError] @@ -87,13 +83,6 @@ impl From for DataFusionError { } } -#[cfg(feature = "pyarrow")] -impl From for PyErr { - fn from(err: DataFusionError) -> PyErr { - PyException::new_err(err.to_string()) - } -} - impl From for ArrowError { fn from(e: DataFusionError) -> Self { match e { diff --git a/datafusion-common/src/lib.rs b/datafusion-common/src/lib.rs index 11f9bbbb7e82..fdcb7d4b5f74 100644 --- a/datafusion-common/src/lib.rs +++ b/datafusion-common/src/lib.rs @@ -18,7 +18,13 @@ mod column; mod dfschema; mod error; +#[cfg(feature = "pyarrow")] +mod pyarrow; +mod scalar; pub use column::Column; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; pub use error::{DataFusionError, Result}; +pub use scalar::{ + ScalarType, ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, +}; diff --git a/datafusion/src/pyarrow.rs b/datafusion-common/src/pyarrow.rs similarity index 92% rename from datafusion/src/pyarrow.rs rename to datafusion-common/src/pyarrow.rs index 46eb6b4437b5..bf10b4551775 100644 --- a/datafusion/src/pyarrow.rs +++ b/datafusion-common/src/pyarrow.rs @@ -15,12 +15,21 @@ // specific language governing permissions and limitations // under the License. -use pyo3::prelude::*; +//! 
PyArrow + +use crate::{DataFusionError, ScalarValue}; +use arrow::array::ArrayData; +use arrow::pyarrow::PyArrowConvert; +use pyo3::exceptions::PyException; +use pyo3::prelude::PyErr; use pyo3::types::PyList; +use pyo3::{FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python}; -use crate::arrow::array::ArrayData; -use crate::arrow::pyarrow::PyArrowConvert; -use crate::scalar::ScalarValue; +impl From for PyErr { + fn from(err: DataFusionError) -> PyErr { + PyException::new_err(err.to_string()) + } +} impl PyArrowConvert for ScalarValue { fn from_pyarrow(value: &PyAny) -> PyResult { @@ -68,7 +77,6 @@ mod tests { use pyo3::prepare_freethreaded_python; use pyo3::py_run; use pyo3::types::PyDict; - use pyo3::Python; fn init_python() { prepare_freethreaded_python(); diff --git a/datafusion-common/src/scalar.rs b/datafusion-common/src/scalar.rs new file mode 100644 index 000000000000..d7c6c6bc710f --- /dev/null +++ b/datafusion-common/src/scalar.rs @@ -0,0 +1,1959 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module provides ScalarValue, an enum that can be used for storage of single elements + +use crate::error::{DataFusionError, Result}; +use arrow::{ + array::*, + compute::kernels::cast::cast, + datatypes::{ + ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit, TimeUnit, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + }, +}; +use ordered_float::OrderedFloat; +use std::cmp::Ordering; +use std::convert::{Infallible, TryInto}; +use std::str::FromStr; +use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; + +// TODO may need to be moved to arrow-rs +/// The max precision and scale for decimal128 +pub const MAX_PRECISION_FOR_DECIMAL128: usize = 38; +pub const MAX_SCALE_FOR_DECIMAL128: usize = 38; + +/// Represents a dynamically typed, nullable single value. +/// This is the single-valued counter-part of arrow’s `Array`. +#[derive(Clone)] +pub enum ScalarValue { + /// true or false value + Boolean(Option), + /// 32bit float + Float32(Option), + /// 64bit float + Float64(Option), + /// 128bit decimal, using the i128 to represent the decimal + Decimal128(Option, usize, usize), + /// signed 8bit int + Int8(Option), + /// signed 16bit int + Int16(Option), + /// signed 32bit int + Int32(Option), + /// signed 64bit int + Int64(Option), + /// unsigned 8bit int + UInt8(Option), + /// unsigned 16bit int + UInt16(Option), + /// unsigned 32bit int + UInt32(Option), + /// unsigned 64bit int + UInt64(Option), + /// utf-8 encoded string. + Utf8(Option), + /// utf-8 encoded string representing a LargeString's arrow type. 
+ LargeUtf8(Option), + /// binary + Binary(Option>), + /// large binary + LargeBinary(Option>), + /// list of nested ScalarValue (boxed to reduce size_of(ScalarValue)) + #[allow(clippy::box_collection)] + List(Option>>, Box), + /// Date stored as a signed 32bit int + Date32(Option), + /// Date stored as a signed 64bit int + Date64(Option), + /// Timestamp Second + TimestampSecond(Option, Option), + /// Timestamp Milliseconds + TimestampMillisecond(Option, Option), + /// Timestamp Microseconds + TimestampMicrosecond(Option, Option), + /// Timestamp Nanoseconds + TimestampNanosecond(Option, Option), + /// Interval with YearMonth unit + IntervalYearMonth(Option), + /// Interval with DayTime unit + IntervalDayTime(Option), + /// Interval with MonthDayNano unit + IntervalMonthDayNano(Option), + /// struct of nested ScalarValue (boxed to reduce size_of(ScalarValue)) + #[allow(clippy::box_collection)] + Struct(Option>>, Box>), +} + +// manual implementation of `PartialEq` that uses OrderedFloat to +// get defined behavior for floating point +impl PartialEq for ScalarValue { + fn eq(&self, other: &Self) -> bool { + use ScalarValue::*; + // This purposely doesn't have a catch-all "(_, _)" so that + // any newly added enum variant will require editing this list + // or else face a compile error + match (self, other) { + (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal128(_, _, _), _) => false, + (Boolean(v1), Boolean(v2)) => v1.eq(v2), + (Boolean(_), _) => false, + (Float32(v1), Float32(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.eq(&v2) + } + (Float32(_), _) => false, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.eq(&v2) + } + (Float64(_), _) => false, + (Int8(v1), Int8(v2)) => v1.eq(v2), + (Int8(_), _) => false, + (Int16(v1), Int16(v2)) => v1.eq(v2), + (Int16(_), _) => false, + (Int32(v1), Int32(v2)) => v1.eq(v2), + (Int32(_), _) => false, + (Int64(v1), Int64(v2)) => v1.eq(v2), + (Int64(_), _) => false, + (UInt8(v1), UInt8(v2)) => v1.eq(v2), + (UInt8(_), _) => false, + (UInt16(v1), UInt16(v2)) => v1.eq(v2), + (UInt16(_), _) => false, + (UInt32(v1), UInt32(v2)) => v1.eq(v2), + (UInt32(_), _) => false, + (UInt64(v1), UInt64(v2)) => v1.eq(v2), + (UInt64(_), _) => false, + (Utf8(v1), Utf8(v2)) => v1.eq(v2), + (Utf8(_), _) => false, + (LargeUtf8(v1), LargeUtf8(v2)) => v1.eq(v2), + (LargeUtf8(_), _) => false, + (Binary(v1), Binary(v2)) => v1.eq(v2), + (Binary(_), _) => false, + (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), + (LargeBinary(_), _) => false, + (List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2), + (List(_, _), _) => false, + (Date32(v1), Date32(v2)) => v1.eq(v2), + (Date32(_), _) => false, + (Date64(v1), Date64(v2)) => v1.eq(v2), + (Date64(_), _) => false, + (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.eq(v2), + (TimestampSecond(_, _), _) => false, + (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => v1.eq(v2), + (TimestampMillisecond(_, _), _) => false, + (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => v1.eq(v2), + (TimestampMicrosecond(_, _), _) => false, + (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.eq(v2), + (TimestampNanosecond(_, _), _) => false, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), + (IntervalYearMonth(_), _) => false, + (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), + (IntervalDayTime(_), _) => false, + (IntervalMonthDayNano(v1), 
IntervalMonthDayNano(v2)) => v1.eq(v2), + (IntervalMonthDayNano(_), _) => false, + (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), + (Struct(_, _), _) => false, + } + } +} + +// manual implementation of `PartialOrd` that uses OrderedFloat to +// get defined behavior for floating point +impl PartialOrd for ScalarValue { + fn partial_cmp(&self, other: &Self) -> Option { + use ScalarValue::*; + // This purposely doesn't have a catch-all "(_, _)" so that + // any newly added enum variant will require editing this list + // or else face a compile error + match (self, other) { + (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal128(_, _, _), _) => None, + (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), + (Boolean(_), _) => None, + (Float32(v1), Float32(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float32(_), _) => None, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float64(_), _) => None, + (Int8(v1), Int8(v2)) => v1.partial_cmp(v2), + (Int8(_), _) => None, + (Int16(v1), Int16(v2)) => v1.partial_cmp(v2), + (Int16(_), _) => None, + (Int32(v1), Int32(v2)) => v1.partial_cmp(v2), + (Int32(_), _) => None, + (Int64(v1), Int64(v2)) => v1.partial_cmp(v2), + (Int64(_), _) => None, + (UInt8(v1), UInt8(v2)) => v1.partial_cmp(v2), + (UInt8(_), _) => None, + (UInt16(v1), UInt16(v2)) => v1.partial_cmp(v2), + (UInt16(_), _) => None, + (UInt32(v1), UInt32(v2)) => v1.partial_cmp(v2), + (UInt32(_), _) => None, + (UInt64(v1), UInt64(v2)) => v1.partial_cmp(v2), + (UInt64(_), _) => None, + (Utf8(v1), Utf8(v2)) => v1.partial_cmp(v2), + (Utf8(_), _) => None, + (LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2), + (LargeUtf8(_), _) => None, + (Binary(v1), Binary(v2)) => v1.partial_cmp(v2), + (Binary(_), _) => None, + (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), + (LargeBinary(_), _) => None, + (List(v1, t1), List(v2, t2)) => { + if t1.eq(t2) { + v1.partial_cmp(v2) + } else { + None + } + } + (List(_, _), _) => None, + (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), + (Date32(_), _) => None, + (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), + (Date64(_), _) => None, + (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.partial_cmp(v2), + (TimestampSecond(_, _), _) => None, + (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampMillisecond(_, _), _) => None, + (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampMicrosecond(_, _), _) => None, + (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => { + v1.partial_cmp(v2) + } + (TimestampNanosecond(_, _), _) => None, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), + (IntervalYearMonth(_), _) => None, + (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.partial_cmp(v2), + (IntervalDayTime(_), _) => None, + (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.partial_cmp(v2), + (IntervalMonthDayNano(_), _) => None, + (Struct(v1, t1), Struct(v2, t2)) => { + if t1.eq(t2) { + v1.partial_cmp(v2) + } else { + None + } + } + (Struct(_, _), _) => None, + } + } +} + +impl Eq for ScalarValue {} + +// manual implementation of `Hash` that uses OrderedFloat to +// get defined behavior for floating point +impl 
std::hash::Hash for ScalarValue { + fn hash(&self, state: &mut H) { + use ScalarValue::*; + match self { + Decimal128(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } + Boolean(v) => v.hash(state), + Float32(v) => { + let v = v.map(OrderedFloat); + v.hash(state) + } + Float64(v) => { + let v = v.map(OrderedFloat); + v.hash(state) + } + Int8(v) => v.hash(state), + Int16(v) => v.hash(state), + Int32(v) => v.hash(state), + Int64(v) => v.hash(state), + UInt8(v) => v.hash(state), + UInt16(v) => v.hash(state), + UInt32(v) => v.hash(state), + UInt64(v) => v.hash(state), + Utf8(v) => v.hash(state), + LargeUtf8(v) => v.hash(state), + Binary(v) => v.hash(state), + LargeBinary(v) => v.hash(state), + List(v, t) => { + v.hash(state); + t.hash(state); + } + Date32(v) => v.hash(state), + Date64(v) => v.hash(state), + TimestampSecond(v, _) => v.hash(state), + TimestampMillisecond(v, _) => v.hash(state), + TimestampMicrosecond(v, _) => v.hash(state), + TimestampNanosecond(v, _) => v.hash(state), + IntervalYearMonth(v) => v.hash(state), + IntervalDayTime(v) => v.hash(state), + IntervalMonthDayNano(v) => v.hash(state), + Struct(v, t) => { + v.hash(state); + t.hash(state); + } + } + } +} + +// return the index into the dictionary values for array@index as well +// as a reference to the dictionary values array. Returns None for the +// index if the array is NULL at index +#[inline] +fn get_dict_value( + array: &ArrayRef, + index: usize, +) -> Result<(&ArrayRef, Option)> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + + // look up the index in the values dictionary + let keys_col = dict_array.keys(); + if !keys_col.is_valid(index) { + return Ok((dict_array.values(), None)); + } + let values_index = keys_col.value(index).to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert index to usize in dictionary of type creating group by value {:?}", + keys_col.data_type() + )) + })?; + + Ok((dict_array.values(), Some(values_index))) +} + +macro_rules! typed_cast_tz { + ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{ + let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + ScalarValue::$SCALAR( + match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), + }, + $TZ.clone(), + ) + }}; +} + +macro_rules! typed_cast { + ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ + let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + ScalarValue::$SCALAR(match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), + }) + }}; +} + +macro_rules! build_list { + ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ + match $VALUES { + // the return on the macro is necessary, to short-circuit and return ArrayRef + None => { + return new_null_array( + &DataType::List(Box::new(Field::new( + "item", + DataType::$SCALAR_TY, + true, + ))), + $SIZE, + ) + } + Some(values) => { + build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values.as_ref(), $SIZE) + } + } + }}; +} + +macro_rules! 
build_timestamp_list { + ($TIME_UNIT:expr, $TIME_ZONE:expr, $VALUES:expr, $SIZE:expr) => {{ + match $VALUES { + // the return on the macro is necessary, to short-circuit and return ArrayRef + None => { + return new_null_array( + &DataType::List(Box::new(Field::new( + "item", + DataType::Timestamp($TIME_UNIT, $TIME_ZONE), + true, + ))), + $SIZE, + ) + } + Some(values) => { + let values = values.as_ref(); + match $TIME_UNIT { + TimeUnit::Second => { + build_values_list_tz!( + TimestampSecondBuilder, + TimestampSecond, + values, + $SIZE + ) + } + TimeUnit::Microsecond => build_values_list_tz!( + TimestampMillisecondBuilder, + TimestampMillisecond, + values, + $SIZE + ), + TimeUnit::Millisecond => build_values_list_tz!( + TimestampMicrosecondBuilder, + TimestampMicrosecond, + values, + $SIZE + ), + TimeUnit::Nanosecond => build_values_list_tz!( + TimestampNanosecondBuilder, + TimestampNanosecond, + values, + $SIZE + ), + } + } + } + }}; +} + +macro_rules! build_values_list { + ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ + let mut builder = ListBuilder::new($VALUE_BUILDER_TY::new($VALUES.len())); + + for _ in 0..$SIZE { + for scalar_value in $VALUES { + match scalar_value { + ScalarValue::$SCALAR_TY(Some(v)) => { + builder.values().append_value(v.clone()).unwrap() + } + ScalarValue::$SCALAR_TY(None) => { + builder.values().append_null().unwrap(); + } + _ => panic!("Incompatible ScalarValue for list"), + }; + } + builder.append(true).unwrap(); + } + + builder.finish() + }}; +} + +macro_rules! build_values_list_tz { + ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ + let mut builder = ListBuilder::new($VALUE_BUILDER_TY::new($VALUES.len())); + + for _ in 0..$SIZE { + for scalar_value in $VALUES { + match scalar_value { + ScalarValue::$SCALAR_TY(Some(v), _) => { + builder.values().append_value(v.clone()).unwrap() + } + ScalarValue::$SCALAR_TY(None, _) => { + builder.values().append_null().unwrap(); + } + _ => panic!("Incompatible ScalarValue for list"), + }; + } + builder.append(true).unwrap(); + } + + builder.finish() + }}; +} + +macro_rules! build_array_from_option { + ($DATA_TYPE:ident, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ + match $EXPR { + Some(value) => Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)), + None => new_null_array(&DataType::$DATA_TYPE, $SIZE), + } + }}; + ($DATA_TYPE:ident, $ENUM:expr, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ + match $EXPR { + Some(value) => Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)), + None => new_null_array(&DataType::$DATA_TYPE($ENUM), $SIZE), + } + }}; + ($DATA_TYPE:ident, $ENUM:expr, $ENUM2:expr, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ + match $EXPR { + Some(value) => { + let array: ArrayRef = Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)); + // Need to call cast to cast to final data type with timezone/extra param + cast(&array, &DataType::$DATA_TYPE($ENUM, $ENUM2)) + .expect("cannot do temporal cast") + } + None => new_null_array(&DataType::$DATA_TYPE($ENUM, $ENUM2), $SIZE), + } + }}; +} + +macro_rules! eq_array_primitive { + ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ + let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + let is_valid = array.is_valid($index); + match $VALUE { + Some(val) => is_valid && &array.value($index) == val, + None => !is_valid, + } + }}; +} + +impl ScalarValue { + /// Create a decimal Scalar from value/precision and scale. 
+ pub fn try_new_decimal128( + value: i128, + precision: usize, + scale: usize, + ) -> Result { + // make sure the precision and scale is valid + if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision { + return Ok(ScalarValue::Decimal128(Some(value), precision, scale)); + } + return Err(DataFusionError::Internal(format!( + "Can not new a decimal type ScalarValue for precision {} and scale {}", + precision, scale + ))); + } + /// Getter for the `DataType` of the value + pub fn get_datatype(&self) -> DataType { + match self { + ScalarValue::Boolean(_) => DataType::Boolean, + ScalarValue::UInt8(_) => DataType::UInt8, + ScalarValue::UInt16(_) => DataType::UInt16, + ScalarValue::UInt32(_) => DataType::UInt32, + ScalarValue::UInt64(_) => DataType::UInt64, + ScalarValue::Int8(_) => DataType::Int8, + ScalarValue::Int16(_) => DataType::Int16, + ScalarValue::Int32(_) => DataType::Int32, + ScalarValue::Int64(_) => DataType::Int64, + ScalarValue::Decimal128(_, precision, scale) => { + DataType::Decimal(*precision, *scale) + } + ScalarValue::TimestampSecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Second, tz_opt.clone()) + } + ScalarValue::TimestampMillisecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Millisecond, tz_opt.clone()) + } + ScalarValue::TimestampMicrosecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Microsecond, tz_opt.clone()) + } + ScalarValue::TimestampNanosecond(_, tz_opt) => { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) + } + ScalarValue::Float32(_) => DataType::Float32, + ScalarValue::Float64(_) => DataType::Float64, + ScalarValue::Utf8(_) => DataType::Utf8, + ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, + ScalarValue::Binary(_) => DataType::Binary, + ScalarValue::LargeBinary(_) => DataType::LargeBinary, + ScalarValue::List(_, data_type) => DataType::List(Box::new(Field::new( + "item", + data_type.as_ref().clone(), + true, + ))), + ScalarValue::Date32(_) => DataType::Date32, + ScalarValue::Date64(_) => DataType::Date64, + ScalarValue::IntervalYearMonth(_) => { + DataType::Interval(IntervalUnit::YearMonth) + } + ScalarValue::IntervalDayTime(_) => DataType::Interval(IntervalUnit::DayTime), + ScalarValue::IntervalMonthDayNano(_) => { + DataType::Interval(IntervalUnit::MonthDayNano) + } + ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), + } + } + + /// Calculate arithmetic negation for a scalar value + pub fn arithmetic_negate(&self) -> Self { + match self { + ScalarValue::Boolean(None) + | ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Float32(None) => self.clone(), + ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(-v)), + ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(-v)), + ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(-v)), + ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(-v)), + ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(-v)), + ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(-v)), + ScalarValue::Decimal128(Some(v), precision, scale) => { + ScalarValue::Decimal128(Some(-v), *precision, *scale) + } + _ => panic!("Cannot run arithmetic negate on scalar value: {:?}", self), + } + } + + /// whether this value is null or not. 
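+    ///
+    /// For example (an illustrative sketch using the same crate path as the
+    /// other doc examples in this file):
+    ///
+    /// ```
+    /// use datafusion_common::ScalarValue;
+    ///
+    /// assert!(ScalarValue::Int32(None).is_null());
+    /// assert!(!ScalarValue::Int32(Some(1)).is_null());
+    /// ```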
+ pub fn is_null(&self) -> bool { + matches!( + *self, + ScalarValue::Boolean(None) + | ScalarValue::UInt8(None) + | ScalarValue::UInt16(None) + | ScalarValue::UInt32(None) + | ScalarValue::UInt64(None) + | ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Float32(None) + | ScalarValue::Float64(None) + | ScalarValue::Date32(None) + | ScalarValue::Date64(None) + | ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::List(None, _) + | ScalarValue::TimestampSecond(None, _) + | ScalarValue::TimestampMillisecond(None, _) + | ScalarValue::TimestampMicrosecond(None, _) + | ScalarValue::TimestampNanosecond(None, _) + | ScalarValue::Struct(None, _) + | ScalarValue::Decimal128(None, _, _) // For decimal type, the value is null means ScalarValue::Decimal128 is null. + ) + } + + /// Converts a scalar value into an 1-row array. + pub fn to_array(&self) -> ArrayRef { + self.to_array_of_size(1) + } + + /// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`] + /// corresponding to those values. For example, + /// + /// Returns an error if the iterator is empty or if the + /// [`ScalarValue`]s are not all the same type + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{ArrayRef, BooleanArray}; + /// + /// let scalars = vec![ + /// ScalarValue::Boolean(Some(true)), + /// ScalarValue::Boolean(None), + /// ScalarValue::Boolean(Some(false)), + /// ]; + /// + /// // Build an Array from the list of ScalarValues + /// let array = ScalarValue::iter_to_array(scalars.into_iter()) + /// .unwrap(); + /// + /// let expected: ArrayRef = std::sync::Arc::new( + /// BooleanArray::from(vec![ + /// Some(true), + /// None, + /// Some(false) + /// ] + /// )); + /// + /// assert_eq!(&array, &expected); + /// ``` + pub fn iter_to_array( + scalars: impl IntoIterator, + ) -> Result { + let mut scalars = scalars.into_iter().peekable(); + + // figure out the type based on the first element + let data_type = match scalars.peek() { + None => { + return Err(DataFusionError::Internal( + "Empty iterator passed to ScalarValue::iter_to_array".to_string(), + )); + } + Some(sv) => sv.get_datatype(), + }; + + /// Creates an array of $ARRAY_TY by unpacking values of + /// SCALAR_TY for primitive types + macro_rules! build_array_primitive { + ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + { + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }) + .collect::>()?; + + Arc::new(array) + } + }}; + } + + macro_rules! build_array_primitive_tz { + ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + { + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v, _) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }) + .collect::>()?; + + Arc::new(array) + } + }}; + } + + /// Creates an array of $ARRAY_TY by unpacking values of + /// SCALAR_TY for "string-like" types. + macro_rules! build_array_string { + ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ + { + let array = scalars + .map(|sv| { + if let ScalarValue::$SCALAR_TY(v) = sv { + Ok(v) + } else { + Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. 
\ + Expected {:?}, got {:?}", + data_type, sv + ))) + } + }) + .collect::>()?; + Arc::new(array) + } + }}; + } + + macro_rules! build_array_list_primitive { + ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ + Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( + scalars.into_iter().map(|x| match x { + ScalarValue::List(xs, _) => xs.map(|x| { + x.iter() + .map(|x| match x { + ScalarValue::$SCALAR_TY(i) => *i, + sv => panic!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ), + }) + .collect::>>() + }), + sv => panic!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected {:?}, got {:?}", + data_type, sv + ), + }), + )) + }}; + } + + macro_rules! build_array_list_string { + ($BUILDER:ident, $SCALAR_TY:ident) => {{ + let mut builder = ListBuilder::new($BUILDER::new(0)); + + for scalar in scalars.into_iter() { + match scalar { + ScalarValue::List(Some(xs), _) => { + let xs = *xs; + for s in xs { + match s { + ScalarValue::$SCALAR_TY(Some(val)) => { + builder.values().append_value(val)?; + } + ScalarValue::$SCALAR_TY(None) => { + builder.values().append_null()?; + } + sv => { + return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected Utf8, got {:?}", + sv + ))) + } + } + } + builder.append(true)?; + } + ScalarValue::List(None, _) => { + builder.append(false)?; + } + sv => { + return Err(DataFusionError::Internal(format!( + "Inconsistent types in ScalarValue::iter_to_array. \ + Expected List, got {:?}", + sv + ))) + } + } + } + + Arc::new(builder.finish()) + }}; + } + + let array: ArrayRef = match &data_type { + DataType::Decimal(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal_array(scalars, precision, scale)?; + Arc::new(decimal_array) + } + DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), + DataType::Float32 => build_array_primitive!(Float32Array, Float32), + DataType::Float64 => build_array_primitive!(Float64Array, Float64), + DataType::Int8 => build_array_primitive!(Int8Array, Int8), + DataType::Int16 => build_array_primitive!(Int16Array, Int16), + DataType::Int32 => build_array_primitive!(Int32Array, Int32), + DataType::Int64 => build_array_primitive!(Int64Array, Int64), + DataType::UInt8 => build_array_primitive!(UInt8Array, UInt8), + DataType::UInt16 => build_array_primitive!(UInt16Array, UInt16), + DataType::UInt32 => build_array_primitive!(UInt32Array, UInt32), + DataType::UInt64 => build_array_primitive!(UInt64Array, UInt64), + DataType::Utf8 => build_array_string!(StringArray, Utf8), + DataType::LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8), + DataType::Binary => build_array_string!(BinaryArray, Binary), + DataType::LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary), + DataType::Date32 => build_array_primitive!(Date32Array, Date32), + DataType::Date64 => build_array_primitive!(Date64Array, Date64), + DataType::Timestamp(TimeUnit::Second, _) => { + build_array_primitive_tz!(TimestampSecondArray, TimestampSecond) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + build_array_primitive_tz!(TimestampMillisecondArray, TimestampMillisecond) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + build_array_primitive_tz!(TimestampMicrosecondArray, TimestampMicrosecond) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + build_array_primitive_tz!(TimestampNanosecondArray, TimestampNanosecond) + } + DataType::Interval(IntervalUnit::DayTime) => { + 
build_array_primitive!(IntervalDayTimeArray, IntervalDayTime) + } + DataType::Interval(IntervalUnit::YearMonth) => { + build_array_primitive!(IntervalYearMonthArray, IntervalYearMonth) + } + DataType::List(fields) if fields.data_type() == &DataType::Int8 => { + build_array_list_primitive!(Int8Type, Int8, i8) + } + DataType::List(fields) if fields.data_type() == &DataType::Int16 => { + build_array_list_primitive!(Int16Type, Int16, i16) + } + DataType::List(fields) if fields.data_type() == &DataType::Int32 => { + build_array_list_primitive!(Int32Type, Int32, i32) + } + DataType::List(fields) if fields.data_type() == &DataType::Int64 => { + build_array_list_primitive!(Int64Type, Int64, i64) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { + build_array_list_primitive!(UInt8Type, UInt8, u8) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { + build_array_list_primitive!(UInt16Type, UInt16, u16) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { + build_array_list_primitive!(UInt32Type, UInt32, u32) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { + build_array_list_primitive!(UInt64Type, UInt64, u64) + } + DataType::List(fields) if fields.data_type() == &DataType::Float32 => { + build_array_list_primitive!(Float32Type, Float32, f32) + } + DataType::List(fields) if fields.data_type() == &DataType::Float64 => { + build_array_list_primitive!(Float64Type, Float64, f64) + } + DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { + build_array_list_string!(StringBuilder, Utf8) + } + DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { + build_array_list_string!(LargeStringBuilder, LargeUtf8) + } + DataType::List(_) => { + // Fallback case handling homogeneous lists with any ScalarValue element type + let list_array = ScalarValue::iter_to_array_list(scalars, &data_type)?; + Arc::new(list_array) + } + DataType::Struct(fields) => { + // Initialize a Vector to store the ScalarValues for each column + let mut columns: Vec> = + (0..fields.len()).map(|_| Vec::new()).collect(); + + // Iterate over scalars to populate the column scalars for each row + for scalar in scalars { + if let ScalarValue::Struct(values, fields) = scalar { + match values { + Some(values) => { + // Push value for each field + for c in 0..columns.len() { + let column = columns.get_mut(c).unwrap(); + column.push(values[c].clone()); + } + } + None => { + // Push NULL of the appropriate type for each field + for c in 0..columns.len() { + let dtype = fields[c].data_type(); + let column = columns.get_mut(c).unwrap(); + column.push(ScalarValue::try_from(dtype)?); + } + } + }; + } else { + return Err(DataFusionError::Internal(format!( + "Expected Struct but found: {}", + scalar + ))); + }; + } + + // Call iter_to_array recursively to convert the scalars for each column into Arrow arrays + let field_values = fields + .iter() + .zip(columns) + .map(|(field, column)| -> Result<(Field, ArrayRef)> { + Ok((field.clone(), Self::iter_to_array(column)?)) + }) + .collect::>>()?; + + Arc::new(StructArray::from(field_values)) + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported creation of {:?} array from ScalarValue {:?}", + data_type, + scalars.peek() + ))); + } + }; + + Ok(array) + } + + fn iter_to_decimal_array( + scalars: impl IntoIterator, + precision: &usize, + scale: &usize, + ) -> Result { + // collect the value as Option + let array = scalars + .into_iter() + .map(|element: 
ScalarValue| match element { + ScalarValue::Decimal128(v1, _, _) => v1, + _ => unreachable!(), + }) + .collect::>>(); + + // build the decimal array using the Decimal Builder + let mut builder = DecimalBuilder::new(array.len(), *precision, *scale); + array.iter().for_each(|element| match element { + None => { + builder.append_null().unwrap(); + } + Some(v) => { + builder.append_value(*v).unwrap(); + } + }); + Ok(builder.finish()) + } + + fn iter_to_array_list( + scalars: impl IntoIterator, + data_type: &DataType, + ) -> Result> { + let mut offsets = Int32Array::builder(0); + if let Err(err) = offsets.append_value(0) { + return Err(DataFusionError::ArrowError(err)); + } + + let mut elements: Vec = Vec::new(); + let mut valid = BooleanBufferBuilder::new(0); + let mut flat_len = 0i32; + for scalar in scalars { + if let ScalarValue::List(values, _) = scalar { + match values { + Some(values) => { + let element_array = ScalarValue::iter_to_array(*values)?; + + // Add new offset index + flat_len += element_array.len() as i32; + if let Err(err) = offsets.append_value(flat_len) { + return Err(DataFusionError::ArrowError(err)); + } + + elements.push(element_array); + + // Element is valid + valid.append(true); + } + None => { + // Repeat previous offset index + if let Err(err) = offsets.append_value(flat_len) { + return Err(DataFusionError::ArrowError(err)); + } + + // Element is null + valid.append(false); + } + } + } else { + return Err(DataFusionError::Internal(format!( + "Expected ScalarValue::List element. Received {:?}", + scalar + ))); + } + } + + // Concatenate element arrays to create single flat array + let element_arrays: Vec<&dyn Array> = + elements.iter().map(|a| a.as_ref()).collect(); + let flat_array = match arrow::compute::concat(&element_arrays) { + Ok(flat_array) => flat_array, + Err(err) => return Err(DataFusionError::ArrowError(err)), + }; + + // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices + let offsets_array = offsets.finish(); + let array_data = ArrayDataBuilder::new(data_type.clone()) + .len(offsets_array.len() - 1) + .null_bit_buffer(valid.finish()) + .add_buffer(offsets_array.data().buffers()[0].clone()) + .add_child_data(flat_array.data().clone()); + + let list_array = ListArray::from(array_data.build()?); + Ok(list_array) + } + + fn build_decimal_array( + value: &Option, + precision: &usize, + scale: &usize, + size: usize, + ) -> DecimalArray { + let mut builder = DecimalBuilder::new(size, *precision, *scale); + match value { + None => { + for _i in 0..size { + builder.append_null().unwrap(); + } + } + Some(v) => { + let v = *v; + for _i in 0..size { + builder.append_value(v).unwrap(); + } + } + }; + builder.finish() + } + + /// Converts a scalar value into an array of `size` rows. 
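+    ///
+    /// For example (an illustrative sketch; the assertion style follows the
+    /// `iter_to_array` doc example above):
+    ///
+    /// ```
+    /// use datafusion_common::ScalarValue;
+    /// use arrow::array::{ArrayRef, Int32Array};
+    ///
+    /// // repeat the scalar value 7 across a 3-row array
+    /// let array = ScalarValue::Int32(Some(7)).to_array_of_size(3);
+    /// let expected: ArrayRef = std::sync::Arc::new(Int32Array::from(vec![7, 7, 7]));
+    /// assert_eq!(&array, &expected);
+    /// ```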
+ pub fn to_array_of_size(&self, size: usize) -> ArrayRef { + match self { + ScalarValue::Decimal128(e, precision, scale) => { + Arc::new(ScalarValue::build_decimal_array(e, precision, scale, size)) + } + ScalarValue::Boolean(e) => { + Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef + } + ScalarValue::Float64(e) => { + build_array_from_option!(Float64, Float64Array, e, size) + } + ScalarValue::Float32(e) => { + build_array_from_option!(Float32, Float32Array, e, size) + } + ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array, e, size), + ScalarValue::Int16(e) => build_array_from_option!(Int16, Int16Array, e, size), + ScalarValue::Int32(e) => build_array_from_option!(Int32, Int32Array, e, size), + ScalarValue::Int64(e) => build_array_from_option!(Int64, Int64Array, e, size), + ScalarValue::UInt8(e) => build_array_from_option!(UInt8, UInt8Array, e, size), + ScalarValue::UInt16(e) => { + build_array_from_option!(UInt16, UInt16Array, e, size) + } + ScalarValue::UInt32(e) => { + build_array_from_option!(UInt32, UInt32Array, e, size) + } + ScalarValue::UInt64(e) => { + build_array_from_option!(UInt64, UInt64Array, e, size) + } + ScalarValue::TimestampSecond(e, tz_opt) => build_array_from_option!( + Timestamp, + TimeUnit::Second, + tz_opt.clone(), + TimestampSecondArray, + e, + size + ), + ScalarValue::TimestampMillisecond(e, tz_opt) => build_array_from_option!( + Timestamp, + TimeUnit::Millisecond, + tz_opt.clone(), + TimestampMillisecondArray, + e, + size + ), + + ScalarValue::TimestampMicrosecond(e, tz_opt) => build_array_from_option!( + Timestamp, + TimeUnit::Microsecond, + tz_opt.clone(), + TimestampMicrosecondArray, + e, + size + ), + ScalarValue::TimestampNanosecond(e, tz_opt) => build_array_from_option!( + Timestamp, + TimeUnit::Nanosecond, + tz_opt.clone(), + TimestampNanosecondArray, + e, + size + ), + ScalarValue::Utf8(e) => match e { + Some(value) => { + Arc::new(StringArray::from_iter_values(repeat(value).take(size))) + } + None => new_null_array(&DataType::Utf8, size), + }, + ScalarValue::LargeUtf8(e) => match e { + Some(value) => { + Arc::new(LargeStringArray::from_iter_values(repeat(value).take(size))) + } + None => new_null_array(&DataType::LargeUtf8, size), + }, + ScalarValue::Binary(e) => match e { + Some(value) => Arc::new( + repeat(Some(value.as_slice())) + .take(size) + .collect::(), + ), + None => { + Arc::new(repeat(None::<&str>).take(size).collect::()) + } + }, + ScalarValue::LargeBinary(e) => match e { + Some(value) => Arc::new( + repeat(Some(value.as_slice())) + .take(size) + .collect::(), + ), + None => Arc::new( + repeat(None::<&str>) + .take(size) + .collect::(), + ), + }, + ScalarValue::List(values, data_type) => Arc::new(match data_type.as_ref() { + DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size), + DataType::Int8 => build_list!(Int8Builder, Int8, values, size), + DataType::Int16 => build_list!(Int16Builder, Int16, values, size), + DataType::Int32 => build_list!(Int32Builder, Int32, values, size), + DataType::Int64 => build_list!(Int64Builder, Int64, values, size), + DataType::UInt8 => build_list!(UInt8Builder, UInt8, values, size), + DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size), + DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size), + DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size), + DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size), + DataType::Float32 => build_list!(Float32Builder, Float32, values, size), + DataType::Float64 => 
build_list!(Float64Builder, Float64, values, size), + DataType::Timestamp(unit, tz) => { + build_timestamp_list!(unit.clone(), tz.clone(), values, size) + } + &DataType::LargeUtf8 => { + build_list!(LargeStringBuilder, LargeUtf8, values, size) + } + _ => ScalarValue::iter_to_array_list( + repeat(self.clone()).take(size), + &DataType::List(Box::new(Field::new( + "item", + data_type.as_ref().clone(), + true, + ))), + ) + .unwrap(), + }), + ScalarValue::Date32(e) => { + build_array_from_option!(Date32, Date32Array, e, size) + } + ScalarValue::Date64(e) => { + build_array_from_option!(Date64, Date64Array, e, size) + } + ScalarValue::IntervalDayTime(e) => build_array_from_option!( + Interval, + IntervalUnit::DayTime, + IntervalDayTimeArray, + e, + size + ), + ScalarValue::IntervalYearMonth(e) => build_array_from_option!( + Interval, + IntervalUnit::YearMonth, + IntervalYearMonthArray, + e, + size + ), + ScalarValue::IntervalMonthDayNano(e) => build_array_from_option!( + Interval, + IntervalUnit::MonthDayNano, + IntervalMonthDayNanoArray, + e, + size + ), + ScalarValue::Struct(values, fields) => match values { + Some(values) => { + let field_values: Vec<_> = fields + .iter() + .zip(values.iter()) + .map(|(field, value)| { + (field.clone(), value.to_array_of_size(size)) + }) + .collect(); + + Arc::new(StructArray::from(field_values)) + } + None => { + let field_values: Vec<_> = fields + .iter() + .map(|field| { + let none_field = Self::try_from(field.data_type()) + .expect("Failed to construct null ScalarValue from Struct field type"); + (field.clone(), none_field.to_array_of_size(size)) + }) + .collect(); + + Arc::new(StructArray::from(field_values)) + } + }, + } + } + + fn get_decimal_value_from_array( + array: &ArrayRef, + index: usize, + precision: &usize, + scale: &usize, + ) -> ScalarValue { + let array = array.as_any().downcast_ref::().unwrap(); + if array.is_null(index) { + ScalarValue::Decimal128(None, *precision, *scale) + } else { + ScalarValue::Decimal128(Some(array.value(index)), *precision, *scale) + } + } + + /// Converts a value in `array` at `index` into a ScalarValue + pub fn try_from_array(array: &ArrayRef, index: usize) -> Result { + // handle NULL value + if !array.is_valid(index) { + return array.data_type().try_into(); + } + + Ok(match array.data_type() { + DataType::Decimal(precision, scale) => { + ScalarValue::get_decimal_value_from_array(array, index, precision, scale) + } + DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), + DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), + DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), + DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64), + DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32), + DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16), + DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8), + DataType::Int64 => typed_cast!(array, index, Int64Array, Int64), + DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), + DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), + DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), + DataType::Binary => typed_cast!(array, index, BinaryArray, Binary), + DataType::LargeBinary => { + typed_cast!(array, index, LargeBinaryArray, LargeBinary) + } + DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), + DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), + DataType::List(nested_type) => { + let 
list_array = + array.as_any().downcast_ref::().ok_or_else(|| { + DataFusionError::Internal( + "Failed to downcast ListArray".to_string(), + ) + })?; + let value = match list_array.is_null(index) { + true => None, + false => { + let nested_array = list_array.value(index); + let scalar_vec = (0..nested_array.len()) + .map(|i| ScalarValue::try_from_array(&nested_array, i)) + .collect::>>()?; + Some(scalar_vec) + } + }; + let value = value.map(Box::new); + let data_type = Box::new(nested_type.data_type().clone()); + ScalarValue::List(value, data_type) + } + DataType::Date32 => { + typed_cast!(array, index, Date32Array, Date32) + } + DataType::Date64 => { + typed_cast!(array, index, Date64Array, Date64) + } + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + typed_cast_tz!( + array, + index, + TimestampSecondArray, + TimestampSecond, + tz_opt + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + typed_cast_tz!( + array, + index, + TimestampMillisecondArray, + TimestampMillisecond, + tz_opt + ) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + typed_cast_tz!( + array, + index, + TimestampMicrosecondArray, + TimestampMicrosecond, + tz_opt + ) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + typed_cast_tz!( + array, + index, + TimestampNanosecondArray, + TimestampNanosecond, + tz_opt + ) + } + DataType::Dictionary(index_type, _) => { + let (values, values_index) = match **index_type { + DataType::Int8 => get_dict_value::(array, index)?, + DataType::Int16 => get_dict_value::(array, index)?, + DataType::Int32 => get_dict_value::(array, index)?, + DataType::Int64 => get_dict_value::(array, index)?, + DataType::UInt8 => get_dict_value::(array, index)?, + DataType::UInt16 => get_dict_value::(array, index)?, + DataType::UInt32 => get_dict_value::(array, index)?, + DataType::UInt64 => get_dict_value::(array, index)?, + _ => { + return Err(DataFusionError::Internal(format!( + "Index type not supported while creating scalar from dictionary: {}", + array.data_type(), + ))); + } + }; + + match values_index { + Some(values_index) => Self::try_from_array(values, values_index)?, + // was null + None => values.data_type().try_into()?, + } + } + DataType::Struct(fields) => { + let array = + array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "Failed to downcast ArrayRef to StructArray".to_string(), + ) + })?; + let mut field_values: Vec = Vec::new(); + for col_index in 0..array.num_columns() { + let col_array = array.column(col_index); + let col_scalar = ScalarValue::try_from_array(col_array, index)?; + field_values.push(col_scalar); + } + Self::Struct(Some(Box::new(field_values)), Box::new(fields.clone())) + } + other => { + return Err(DataFusionError::NotImplemented(format!( + "Can't create a scalar from array of type \"{:?}\"", + other + ))); + } + }) + } + + fn eq_array_decimal( + array: &ArrayRef, + index: usize, + value: &Option, + precision: usize, + scale: usize, + ) -> bool { + let array = array.as_any().downcast_ref::().unwrap(); + if array.precision() != precision || array.scale() != scale { + return false; + } + match value { + None => array.is_null(index), + Some(v) => !array.is_null(index) && array.value(index) == *v, + } + } + + /// Compares a single row of array @ index for equality with self, + /// in an optimized fashion. 
+ /// + /// This method implements an optimized version of: + /// + /// ```text + /// let arr_scalar = Self::try_from_array(array, index).unwrap(); + /// arr_scalar.eq(self) + /// ``` + /// + /// *Performance note*: the arrow compute kernels should be + /// preferred over this function if at all possible as they can be + /// vectorized and are generally much faster. + /// + /// This function has a few narrow usescases such as hash table key + /// comparisons where comparing a single row at a time is necessary. + #[inline] + pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { + if let DataType::Dictionary(key_type, _) = array.data_type() { + return self.eq_array_dictionary(array, index, key_type); + } + + match self { + ScalarValue::Decimal128(v, precision, scale) => { + ScalarValue::eq_array_decimal(array, index, v, *precision, *scale) + } + ScalarValue::Boolean(val) => { + eq_array_primitive!(array, index, BooleanArray, val) + } + ScalarValue::Float32(val) => { + eq_array_primitive!(array, index, Float32Array, val) + } + ScalarValue::Float64(val) => { + eq_array_primitive!(array, index, Float64Array, val) + } + ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val), + ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val), + ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val), + ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val), + ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val), + ScalarValue::UInt16(val) => { + eq_array_primitive!(array, index, UInt16Array, val) + } + ScalarValue::UInt32(val) => { + eq_array_primitive!(array, index, UInt32Array, val) + } + ScalarValue::UInt64(val) => { + eq_array_primitive!(array, index, UInt64Array, val) + } + ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val), + ScalarValue::LargeUtf8(val) => { + eq_array_primitive!(array, index, LargeStringArray, val) + } + ScalarValue::Binary(val) => { + eq_array_primitive!(array, index, BinaryArray, val) + } + ScalarValue::LargeBinary(val) => { + eq_array_primitive!(array, index, LargeBinaryArray, val) + } + ScalarValue::List(_, _) => unimplemented!(), + ScalarValue::Date32(val) => { + eq_array_primitive!(array, index, Date32Array, val) + } + ScalarValue::Date64(val) => { + eq_array_primitive!(array, index, Date64Array, val) + } + ScalarValue::TimestampSecond(val, _) => { + eq_array_primitive!(array, index, TimestampSecondArray, val) + } + ScalarValue::TimestampMillisecond(val, _) => { + eq_array_primitive!(array, index, TimestampMillisecondArray, val) + } + ScalarValue::TimestampMicrosecond(val, _) => { + eq_array_primitive!(array, index, TimestampMicrosecondArray, val) + } + ScalarValue::TimestampNanosecond(val, _) => { + eq_array_primitive!(array, index, TimestampNanosecondArray, val) + } + ScalarValue::IntervalYearMonth(val) => { + eq_array_primitive!(array, index, IntervalYearMonthArray, val) + } + ScalarValue::IntervalDayTime(val) => { + eq_array_primitive!(array, index, IntervalDayTimeArray, val) + } + ScalarValue::IntervalMonthDayNano(val) => { + eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) + } + ScalarValue::Struct(_, _) => unimplemented!(), + } + } + + /// Compares a dictionary array with indexes of type `key_type` + /// with the array @ index for equality with self + fn eq_array_dictionary( + &self, + array: &ArrayRef, + index: usize, + key_type: &DataType, + ) -> bool { + let (values, values_index) = match key_type { + 
DataType::Int8 => get_dict_value::(array, index).unwrap(), + DataType::Int16 => get_dict_value::(array, index).unwrap(), + DataType::Int32 => get_dict_value::(array, index).unwrap(), + DataType::Int64 => get_dict_value::(array, index).unwrap(), + DataType::UInt8 => get_dict_value::(array, index).unwrap(), + DataType::UInt16 => get_dict_value::(array, index).unwrap(), + DataType::UInt32 => get_dict_value::(array, index).unwrap(), + DataType::UInt64 => get_dict_value::(array, index).unwrap(), + _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + }; + + match values_index { + Some(values_index) => self.eq_array(values, values_index), + None => self.is_null(), + } + } +} + +macro_rules! impl_scalar { + ($ty:ty, $scalar:tt) => { + impl From<$ty> for ScalarValue { + fn from(value: $ty) -> Self { + ScalarValue::$scalar(Some(value)) + } + } + + impl From> for ScalarValue { + fn from(value: Option<$ty>) -> Self { + ScalarValue::$scalar(value) + } + } + }; +} + +impl_scalar!(f64, Float64); +impl_scalar!(f32, Float32); +impl_scalar!(i8, Int8); +impl_scalar!(i16, Int16); +impl_scalar!(i32, Int32); +impl_scalar!(i64, Int64); +impl_scalar!(bool, Boolean); +impl_scalar!(u8, UInt8); +impl_scalar!(u16, UInt16); +impl_scalar!(u32, UInt32); +impl_scalar!(u64, UInt64); + +impl From<&str> for ScalarValue { + fn from(value: &str) -> Self { + Some(value).into() + } +} + +impl From> for ScalarValue { + fn from(value: Option<&str>) -> Self { + let value = value.map(|s| s.to_string()); + ScalarValue::Utf8(value) + } +} + +impl FromStr for ScalarValue { + type Err = Infallible; + + fn from_str(s: &str) -> std::result::Result { + Ok(s.into()) + } +} + +impl From> for ScalarValue { + fn from(value: Vec<(&str, ScalarValue)>) -> Self { + let (fields, scalars): (Vec<_>, Vec<_>) = value + .into_iter() + .map(|(name, scalar)| { + (Field::new(name, scalar.get_datatype(), false), scalar) + }) + .unzip(); + + Self::Struct(Some(Box::new(scalars)), Box::new(fields)) + } +} + +macro_rules! 
impl_try_from { + ($SCALAR:ident, $NATIVE:ident) => { + impl TryFrom for $NATIVE { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::$SCALAR(Some(inner_value)) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } + } + }; +} + +impl_try_from!(Int8, i8); +impl_try_from!(Int16, i16); + +// special implementation for i32 because of Date32 +impl TryFrom for i32 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Int32(Some(inner_value)) + | ScalarValue::Date32(Some(inner_value)) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } +} + +// special implementation for i64 because of TimeNanosecond +impl TryFrom for i64 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Int64(Some(inner_value)) + | ScalarValue::Date64(Some(inner_value)) + | ScalarValue::TimestampNanosecond(Some(inner_value), _) + | ScalarValue::TimestampMicrosecond(Some(inner_value), _) + | ScalarValue::TimestampMillisecond(Some(inner_value), _) + | ScalarValue::TimestampSecond(Some(inner_value), _) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } +} + +// special implementation for i128 because of Decimal128 +impl TryFrom for i128 { + type Error = DataFusionError; + + fn try_from(value: ScalarValue) -> Result { + match value { + ScalarValue::Decimal128(Some(inner_value), _, _) => Ok(inner_value), + _ => Err(DataFusionError::Internal(format!( + "Cannot convert {:?} to {}", + value, + std::any::type_name::() + ))), + } + } +} + +impl_try_from!(UInt8, u8); +impl_try_from!(UInt16, u16); +impl_try_from!(UInt32, u32); +impl_try_from!(UInt64, u64); +impl_try_from!(Float32, f32); +impl_try_from!(Float64, f64); +impl_try_from!(Boolean, bool); + +impl TryFrom<&DataType> for ScalarValue { + type Error = DataFusionError; + + /// Create a Null instance of ScalarValue for this datatype + fn try_from(datatype: &DataType) -> Result { + Ok(match datatype { + DataType::Boolean => ScalarValue::Boolean(None), + DataType::Float64 => ScalarValue::Float64(None), + DataType::Float32 => ScalarValue::Float32(None), + DataType::Int8 => ScalarValue::Int8(None), + DataType::Int16 => ScalarValue::Int16(None), + DataType::Int32 => ScalarValue::Int32(None), + DataType::Int64 => ScalarValue::Int64(None), + DataType::UInt8 => ScalarValue::UInt8(None), + DataType::UInt16 => ScalarValue::UInt16(None), + DataType::UInt32 => ScalarValue::UInt32(None), + DataType::UInt64 => ScalarValue::UInt64(None), + DataType::Decimal(precision, scale) => { + ScalarValue::Decimal128(None, *precision, *scale) + } + DataType::Utf8 => ScalarValue::Utf8(None), + DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), + DataType::Date32 => ScalarValue::Date32(None), + DataType::Date64 => ScalarValue::Date64(None), + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + ScalarValue::TimestampSecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + ScalarValue::TimestampMillisecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + ScalarValue::TimestampMicrosecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + 
ScalarValue::TimestampNanosecond(None, tz_opt.clone()) + } + DataType::Dictionary(_index_type, value_type) => { + value_type.as_ref().try_into()? + } + DataType::List(ref nested_type) => { + ScalarValue::List(None, Box::new(nested_type.data_type().clone())) + } + DataType::Struct(fields) => { + ScalarValue::Struct(None, Box::new(fields.clone())) + } + _ => { + return Err(DataFusionError::NotImplemented(format!( + "Can't create a scalar from data_type \"{:?}\"", + datatype + ))); + } + }) + } +} + +macro_rules! format_option { + ($F:expr, $EXPR:expr) => {{ + match $EXPR { + Some(e) => write!($F, "{}", e), + None => write!($F, "NULL"), + } + }}; +} + +impl fmt::Display for ScalarValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ScalarValue::Decimal128(v, p, s) => { + write!(f, "{:?},{:?},{:?}", v, p, s)?; + } + ScalarValue::Boolean(e) => format_option!(f, e)?, + ScalarValue::Float32(e) => format_option!(f, e)?, + ScalarValue::Float64(e) => format_option!(f, e)?, + ScalarValue::Int8(e) => format_option!(f, e)?, + ScalarValue::Int16(e) => format_option!(f, e)?, + ScalarValue::Int32(e) => format_option!(f, e)?, + ScalarValue::Int64(e) => format_option!(f, e)?, + ScalarValue::UInt8(e) => format_option!(f, e)?, + ScalarValue::UInt16(e) => format_option!(f, e)?, + ScalarValue::UInt32(e) => format_option!(f, e)?, + ScalarValue::UInt64(e) => format_option!(f, e)?, + ScalarValue::TimestampSecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampMillisecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampMicrosecond(e, _) => format_option!(f, e)?, + ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?, + ScalarValue::Utf8(e) => format_option!(f, e)?, + ScalarValue::LargeUtf8(e) => format_option!(f, e)?, + ScalarValue::Binary(e) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::LargeBinary(e) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::List(e, _) => match e { + Some(l) => write!( + f, + "{}", + l.iter() + .map(|v| format!("{}", v)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + ScalarValue::Date32(e) => format_option!(f, e)?, + ScalarValue::Date64(e) => format_option!(f, e)?, + ScalarValue::IntervalDayTime(e) => format_option!(f, e)?, + ScalarValue::IntervalYearMonth(e) => format_option!(f, e)?, + ScalarValue::IntervalMonthDayNano(e) => format_option!(f, e)?, + ScalarValue::Struct(e, fields) => match e { + Some(l) => write!( + f, + "{{{}}}", + l.iter() + .zip(fields.iter()) + .map(|(value, field)| format!("{}:{}", field.name(), value)) + .collect::>() + .join(",") + )?, + None => write!(f, "NULL")?, + }, + }; + Ok(()) + } +} + +impl fmt::Debug for ScalarValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({})", self), + ScalarValue::Boolean(_) => write!(f, "Boolean({})", self), + ScalarValue::Float32(_) => write!(f, "Float32({})", self), + ScalarValue::Float64(_) => write!(f, "Float64({})", self), + ScalarValue::Int8(_) => write!(f, "Int8({})", self), + ScalarValue::Int16(_) => write!(f, "Int16({})", self), + ScalarValue::Int32(_) => write!(f, "Int32({})", self), + ScalarValue::Int64(_) => write!(f, "Int64({})", self), + ScalarValue::UInt8(_) => write!(f, "UInt8({})", self), + 
ScalarValue::UInt16(_) => write!(f, "UInt16({})", self), + ScalarValue::UInt32(_) => write!(f, "UInt32({})", self), + ScalarValue::UInt64(_) => write!(f, "UInt64({})", self), + ScalarValue::TimestampSecond(_, tz_opt) => { + write!(f, "TimestampSecond({}, {:?})", self, tz_opt) + } + ScalarValue::TimestampMillisecond(_, tz_opt) => { + write!(f, "TimestampMillisecond({}, {:?})", self, tz_opt) + } + ScalarValue::TimestampMicrosecond(_, tz_opt) => { + write!(f, "TimestampMicrosecond({}, {:?})", self, tz_opt) + } + ScalarValue::TimestampNanosecond(_, tz_opt) => { + write!(f, "TimestampNanosecond({}, {:?})", self, tz_opt) + } + ScalarValue::Utf8(None) => write!(f, "Utf8({})", self), + ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{}\")", self), + ScalarValue::LargeUtf8(None) => write!(f, "LargeUtf8({})", self), + ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{}\")", self), + ScalarValue::Binary(None) => write!(f, "Binary({})", self), + ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{}\")", self), + ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({})", self), + ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{}\")", self), + ScalarValue::List(_, _) => write!(f, "List([{}])", self), + ScalarValue::Date32(_) => write!(f, "Date32(\"{}\")", self), + ScalarValue::Date64(_) => write!(f, "Date64(\"{}\")", self), + ScalarValue::IntervalDayTime(_) => { + write!(f, "IntervalDayTime(\"{}\")", self) + } + ScalarValue::IntervalYearMonth(_) => { + write!(f, "IntervalYearMonth(\"{}\")", self) + } + ScalarValue::IntervalMonthDayNano(_) => { + write!(f, "IntervalMonthDayNano(\"{}\")", self) + } + ScalarValue::Struct(e, fields) => { + // Use Debug representation of field values + match e { + Some(l) => write!( + f, + "Struct({{{}}})", + l.iter() + .zip(fields.iter()) + .map(|(value, field)| format!("{}:{:?}", field.name(), value)) + .collect::>() + .join(",") + ), + None => write!(f, "Struct(NULL)"), + } + } + } + } +} + +/// Trait used to map a NativeTime to a ScalarType. +pub trait ScalarType { + /// returns a scalar from an optional T + fn scalar(r: Option) -> ScalarValue; +} + +impl ScalarType for Float32Type { + fn scalar(r: Option) -> ScalarValue { + ScalarValue::Float32(r) + } +} + +impl ScalarType for TimestampSecondType { + fn scalar(r: Option) -> ScalarValue { + ScalarValue::TimestampSecond(r, None) + } +} + +impl ScalarType for TimestampMillisecondType { + fn scalar(r: Option) -> ScalarValue { + ScalarValue::TimestampMillisecond(r, None) + } +} + +impl ScalarType for TimestampMicrosecondType { + fn scalar(r: Option) -> ScalarValue { + ScalarValue::TimestampMicrosecond(r, None) + } +} + +impl ScalarType for TimestampNanosecondType { + fn scalar(r: Option) -> ScalarValue { + ScalarValue::TimestampNanosecond(r, None) + } +} diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 9442f7e5fe9f..872a1c0a8678 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -224,9 +224,6 @@ pub use parquet; pub(crate) mod field_util; -#[cfg(feature = "pyarrow")] -mod pyarrow; - pub mod from_slice; #[cfg(test)] diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 0f6a5cef63bf..7dc947565728 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -15,1945 +15,19 @@ // specific language governing permissions and limitations // under the License. -//! 
This module provides ScalarValue, an enum that can be used for storage of single elements - -use crate::error::{DataFusionError, Result}; -use arrow::{ - array::*, - compute::kernels::cast::cast, - datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit, TimeUnit, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, - }, -}; -use ordered_float::OrderedFloat; -use std::cmp::Ordering; -use std::convert::{Infallible, TryInto}; -use std::str::FromStr; -use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; - -// TODO may need to be moved to arrow-rs -/// The max precision and scale for decimal128 -pub(crate) const MAX_PRECISION_FOR_DECIMAL128: usize = 38; -pub(crate) const MAX_SCALE_FOR_DECIMAL128: usize = 38; - -/// Represents a dynamically typed, nullable single value. -/// This is the single-valued counter-part of arrow’s `Array`. -#[derive(Clone)] -pub enum ScalarValue { - /// true or false value - Boolean(Option), - /// 32bit float - Float32(Option), - /// 64bit float - Float64(Option), - /// 128bit decimal, using the i128 to represent the decimal - Decimal128(Option, usize, usize), - /// signed 8bit int - Int8(Option), - /// signed 16bit int - Int16(Option), - /// signed 32bit int - Int32(Option), - /// signed 64bit int - Int64(Option), - /// unsigned 8bit int - UInt8(Option), - /// unsigned 16bit int - UInt16(Option), - /// unsigned 32bit int - UInt32(Option), - /// unsigned 64bit int - UInt64(Option), - /// utf-8 encoded string. - Utf8(Option), - /// utf-8 encoded string representing a LargeString's arrow type. - LargeUtf8(Option), - /// binary - Binary(Option>), - /// large binary - LargeBinary(Option>), - /// list of nested ScalarValue (boxed to reduce size_of(ScalarValue)) - #[allow(clippy::box_collection)] - List(Option>>, Box), - /// Date stored as a signed 32bit int - Date32(Option), - /// Date stored as a signed 64bit int - Date64(Option), - /// Timestamp Second - TimestampSecond(Option, Option), - /// Timestamp Milliseconds - TimestampMillisecond(Option, Option), - /// Timestamp Microseconds - TimestampMicrosecond(Option, Option), - /// Timestamp Nanoseconds - TimestampNanosecond(Option, Option), - /// Interval with YearMonth unit - IntervalYearMonth(Option), - /// Interval with DayTime unit - IntervalDayTime(Option), - /// Interval with MonthDayNano unit - IntervalMonthDayNano(Option), - /// struct of nested ScalarValue (boxed to reduce size_of(ScalarValue)) - #[allow(clippy::box_collection)] - Struct(Option>>, Box>), -} - -// manual implementation of `PartialEq` that uses OrderedFloat to -// get defined behavior for floating point -impl PartialEq for ScalarValue { - fn eq(&self, other: &Self) -> bool { - use ScalarValue::*; - // This purposely doesn't have a catch-all "(_, _)" so that - // any newly added enum variant will require editing this list - // or else face a compile error - match (self, other) { - (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { - v1.eq(v2) && p1.eq(p2) && s1.eq(s2) - } - (Decimal128(_, _, _), _) => false, - (Boolean(v1), Boolean(v2)) => v1.eq(v2), - (Boolean(_), _) => false, - (Float32(v1), Float32(v2)) => { - let v1 = v1.map(OrderedFloat); - let v2 = v2.map(OrderedFloat); - v1.eq(&v2) - } - (Float32(_), _) => false, - (Float64(v1), Float64(v2)) => { - let v1 = v1.map(OrderedFloat); - let v2 = v2.map(OrderedFloat); - v1.eq(&v2) - } - (Float64(_), 
_) => false, - (Int8(v1), Int8(v2)) => v1.eq(v2), - (Int8(_), _) => false, - (Int16(v1), Int16(v2)) => v1.eq(v2), - (Int16(_), _) => false, - (Int32(v1), Int32(v2)) => v1.eq(v2), - (Int32(_), _) => false, - (Int64(v1), Int64(v2)) => v1.eq(v2), - (Int64(_), _) => false, - (UInt8(v1), UInt8(v2)) => v1.eq(v2), - (UInt8(_), _) => false, - (UInt16(v1), UInt16(v2)) => v1.eq(v2), - (UInt16(_), _) => false, - (UInt32(v1), UInt32(v2)) => v1.eq(v2), - (UInt32(_), _) => false, - (UInt64(v1), UInt64(v2)) => v1.eq(v2), - (UInt64(_), _) => false, - (Utf8(v1), Utf8(v2)) => v1.eq(v2), - (Utf8(_), _) => false, - (LargeUtf8(v1), LargeUtf8(v2)) => v1.eq(v2), - (LargeUtf8(_), _) => false, - (Binary(v1), Binary(v2)) => v1.eq(v2), - (Binary(_), _) => false, - (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), - (LargeBinary(_), _) => false, - (List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2), - (List(_, _), _) => false, - (Date32(v1), Date32(v2)) => v1.eq(v2), - (Date32(_), _) => false, - (Date64(v1), Date64(v2)) => v1.eq(v2), - (Date64(_), _) => false, - (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.eq(v2), - (TimestampSecond(_, _), _) => false, - (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => v1.eq(v2), - (TimestampMillisecond(_, _), _) => false, - (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => v1.eq(v2), - (TimestampMicrosecond(_, _), _) => false, - (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.eq(v2), - (TimestampNanosecond(_, _), _) => false, - (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), - (IntervalYearMonth(_), _) => false, - (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), - (IntervalDayTime(_), _) => false, - (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), - (IntervalMonthDayNano(_), _) => false, - (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), - (Struct(_, _), _) => false, - } - } -} - -// manual implementation of `PartialOrd` that uses OrderedFloat to -// get defined behavior for floating point -impl PartialOrd for ScalarValue { - fn partial_cmp(&self, other: &Self) -> Option { - use ScalarValue::*; - // This purposely doesn't have a catch-all "(_, _)" so that - // any newly added enum variant will require editing this list - // or else face a compile error - match (self, other) { - (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { - if p1.eq(p2) && s1.eq(s2) { - v1.partial_cmp(v2) - } else { - // Two decimal values can be compared if they have the same precision and scale. 
- None - } - } - (Decimal128(_, _, _), _) => None, - (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), - (Boolean(_), _) => None, - (Float32(v1), Float32(v2)) => { - let v1 = v1.map(OrderedFloat); - let v2 = v2.map(OrderedFloat); - v1.partial_cmp(&v2) - } - (Float32(_), _) => None, - (Float64(v1), Float64(v2)) => { - let v1 = v1.map(OrderedFloat); - let v2 = v2.map(OrderedFloat); - v1.partial_cmp(&v2) - } - (Float64(_), _) => None, - (Int8(v1), Int8(v2)) => v1.partial_cmp(v2), - (Int8(_), _) => None, - (Int16(v1), Int16(v2)) => v1.partial_cmp(v2), - (Int16(_), _) => None, - (Int32(v1), Int32(v2)) => v1.partial_cmp(v2), - (Int32(_), _) => None, - (Int64(v1), Int64(v2)) => v1.partial_cmp(v2), - (Int64(_), _) => None, - (UInt8(v1), UInt8(v2)) => v1.partial_cmp(v2), - (UInt8(_), _) => None, - (UInt16(v1), UInt16(v2)) => v1.partial_cmp(v2), - (UInt16(_), _) => None, - (UInt32(v1), UInt32(v2)) => v1.partial_cmp(v2), - (UInt32(_), _) => None, - (UInt64(v1), UInt64(v2)) => v1.partial_cmp(v2), - (UInt64(_), _) => None, - (Utf8(v1), Utf8(v2)) => v1.partial_cmp(v2), - (Utf8(_), _) => None, - (LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2), - (LargeUtf8(_), _) => None, - (Binary(v1), Binary(v2)) => v1.partial_cmp(v2), - (Binary(_), _) => None, - (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), - (LargeBinary(_), _) => None, - (List(v1, t1), List(v2, t2)) => { - if t1.eq(t2) { - v1.partial_cmp(v2) - } else { - None - } - } - (List(_, _), _) => None, - (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), - (Date32(_), _) => None, - (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), - (Date64(_), _) => None, - (TimestampSecond(v1, _), TimestampSecond(v2, _)) => v1.partial_cmp(v2), - (TimestampSecond(_, _), _) => None, - (TimestampMillisecond(v1, _), TimestampMillisecond(v2, _)) => { - v1.partial_cmp(v2) - } - (TimestampMillisecond(_, _), _) => None, - (TimestampMicrosecond(v1, _), TimestampMicrosecond(v2, _)) => { - v1.partial_cmp(v2) - } - (TimestampMicrosecond(_, _), _) => None, - (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => { - v1.partial_cmp(v2) - } - (TimestampNanosecond(_, _), _) => None, - (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), - (IntervalYearMonth(_), _) => None, - (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.partial_cmp(v2), - (IntervalDayTime(_), _) => None, - (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.partial_cmp(v2), - (IntervalMonthDayNano(_), _) => None, - (Struct(v1, t1), Struct(v2, t2)) => { - if t1.eq(t2) { - v1.partial_cmp(v2) - } else { - None - } - } - (Struct(_, _), _) => None, - } - } -} - -impl Eq for ScalarValue {} - -// manual implementation of `Hash` that uses OrderedFloat to -// get defined behavior for floating point -impl std::hash::Hash for ScalarValue { - fn hash(&self, state: &mut H) { - use ScalarValue::*; - match self { - Decimal128(v, p, s) => { - v.hash(state); - p.hash(state); - s.hash(state) - } - Boolean(v) => v.hash(state), - Float32(v) => { - let v = v.map(OrderedFloat); - v.hash(state) - } - Float64(v) => { - let v = v.map(OrderedFloat); - v.hash(state) - } - Int8(v) => v.hash(state), - Int16(v) => v.hash(state), - Int32(v) => v.hash(state), - Int64(v) => v.hash(state), - UInt8(v) => v.hash(state), - UInt16(v) => v.hash(state), - UInt32(v) => v.hash(state), - UInt64(v) => v.hash(state), - Utf8(v) => v.hash(state), - LargeUtf8(v) => v.hash(state), - Binary(v) => v.hash(state), - LargeBinary(v) => v.hash(state), - List(v, t) => { - v.hash(state); - t.hash(state); - } - Date32(v) => 
v.hash(state), - Date64(v) => v.hash(state), - TimestampSecond(v, _) => v.hash(state), - TimestampMillisecond(v, _) => v.hash(state), - TimestampMicrosecond(v, _) => v.hash(state), - TimestampNanosecond(v, _) => v.hash(state), - IntervalYearMonth(v) => v.hash(state), - IntervalDayTime(v) => v.hash(state), - IntervalMonthDayNano(v) => v.hash(state), - Struct(v, t) => { - v.hash(state); - t.hash(state); - } - } - } -} - -// return the index into the dictionary values for array@index as well -// as a reference to the dictionary values array. Returns None for the -// index if the array is NULL at index -#[inline] -fn get_dict_value( - array: &ArrayRef, - index: usize, -) -> Result<(&ArrayRef, Option)> { - let dict_array = array.as_any().downcast_ref::>().unwrap(); - - // look up the index in the values dictionary - let keys_col = dict_array.keys(); - if !keys_col.is_valid(index) { - return Ok((dict_array.values(), None)); - } - let values_index = keys_col.value(index).to_usize().ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not convert index to usize in dictionary of type creating group by value {:?}", - keys_col.data_type() - )) - })?; - - Ok((dict_array.values(), Some(values_index))) -} - -macro_rules! typed_cast_tz { - ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - ScalarValue::$SCALAR( - match array.is_null($index) { - true => None, - false => Some(array.value($index).into()), - }, - $TZ.clone(), - ) - }}; -} - -macro_rules! typed_cast { - ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - ScalarValue::$SCALAR(match array.is_null($index) { - true => None, - false => Some(array.value($index).into()), - }) - }}; -} - -macro_rules! build_list { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - match $VALUES { - // the return on the macro is necessary, to short-circuit and return ArrayRef - None => { - return new_null_array( - &DataType::List(Box::new(Field::new( - "item", - DataType::$SCALAR_TY, - true, - ))), - $SIZE, - ) - } - Some(values) => { - build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values.as_ref(), $SIZE) - } - } - }}; -} - -macro_rules! build_timestamp_list { - ($TIME_UNIT:expr, $TIME_ZONE:expr, $VALUES:expr, $SIZE:expr) => {{ - match $VALUES { - // the return on the macro is necessary, to short-circuit and return ArrayRef - None => { - return new_null_array( - &DataType::List(Box::new(Field::new( - "item", - DataType::Timestamp($TIME_UNIT, $TIME_ZONE), - true, - ))), - $SIZE, - ) - } - Some(values) => { - let values = values.as_ref(); - match $TIME_UNIT { - TimeUnit::Second => build_values_list_tz!( - TimestampSecondBuilder, - TimestampSecond, - values, - $SIZE - ), - TimeUnit::Microsecond => build_values_list_tz!( - TimestampMillisecondBuilder, - TimestampMillisecond, - values, - $SIZE - ), - TimeUnit::Millisecond => build_values_list_tz!( - TimestampMicrosecondBuilder, - TimestampMicrosecond, - values, - $SIZE - ), - TimeUnit::Nanosecond => build_values_list_tz!( - TimestampNanosecondBuilder, - TimestampNanosecond, - values, - $SIZE - ), - } - } - } - }}; -} - -macro_rules! 
build_values_list { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - let mut builder = ListBuilder::new($VALUE_BUILDER_TY::new($VALUES.len())); - - for _ in 0..$SIZE { - for scalar_value in $VALUES { - match scalar_value { - ScalarValue::$SCALAR_TY(Some(v)) => { - builder.values().append_value(v.clone()).unwrap() - } - ScalarValue::$SCALAR_TY(None) => { - builder.values().append_null().unwrap(); - } - _ => panic!("Incompatible ScalarValue for list"), - }; - } - builder.append(true).unwrap(); - } - - builder.finish() - }}; -} - -macro_rules! build_values_list_tz { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - let mut builder = ListBuilder::new($VALUE_BUILDER_TY::new($VALUES.len())); - - for _ in 0..$SIZE { - for scalar_value in $VALUES { - match scalar_value { - ScalarValue::$SCALAR_TY(Some(v), _) => { - builder.values().append_value(v.clone()).unwrap() - } - ScalarValue::$SCALAR_TY(None, _) => { - builder.values().append_null().unwrap(); - } - _ => panic!("Incompatible ScalarValue for list"), - }; - } - builder.append(true).unwrap(); - } - - builder.finish() - }}; -} - -macro_rules! build_array_from_option { - ($DATA_TYPE:ident, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ - match $EXPR { - Some(value) => Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)), - None => new_null_array(&DataType::$DATA_TYPE, $SIZE), - } - }}; - ($DATA_TYPE:ident, $ENUM:expr, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ - match $EXPR { - Some(value) => Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)), - None => new_null_array(&DataType::$DATA_TYPE($ENUM), $SIZE), - } - }}; - ($DATA_TYPE:ident, $ENUM:expr, $ENUM2:expr, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ - match $EXPR { - Some(value) => { - let array: ArrayRef = Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)); - // Need to call cast to cast to final data type with timezone/extra param - cast(&array, &DataType::$DATA_TYPE($ENUM, $ENUM2)) - .expect("cannot do temporal cast") - } - None => new_null_array(&DataType::$DATA_TYPE($ENUM, $ENUM2), $SIZE), - } - }}; -} - -macro_rules! eq_array_primitive { - ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ - let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); - let is_valid = array.is_valid($index); - match $VALUE { - Some(val) => is_valid && &array.value($index) == val, - None => !is_valid, - } - }}; -} - -impl ScalarValue { - /// Create a decimal Scalar from value/precision and scale. 
- pub fn try_new_decimal128( - value: i128, - precision: usize, - scale: usize, - ) -> Result { - // make sure the precision and scale is valid - if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision { - return Ok(ScalarValue::Decimal128(Some(value), precision, scale)); - } - return Err(DataFusionError::Internal(format!( - "Can not new a decimal type ScalarValue for precision {} and scale {}", - precision, scale - ))); - } - /// Getter for the `DataType` of the value - pub fn get_datatype(&self) -> DataType { - match self { - ScalarValue::Boolean(_) => DataType::Boolean, - ScalarValue::UInt8(_) => DataType::UInt8, - ScalarValue::UInt16(_) => DataType::UInt16, - ScalarValue::UInt32(_) => DataType::UInt32, - ScalarValue::UInt64(_) => DataType::UInt64, - ScalarValue::Int8(_) => DataType::Int8, - ScalarValue::Int16(_) => DataType::Int16, - ScalarValue::Int32(_) => DataType::Int32, - ScalarValue::Int64(_) => DataType::Int64, - ScalarValue::Decimal128(_, precision, scale) => { - DataType::Decimal(*precision, *scale) - } - ScalarValue::TimestampSecond(_, tz_opt) => { - DataType::Timestamp(TimeUnit::Second, tz_opt.clone()) - } - ScalarValue::TimestampMillisecond(_, tz_opt) => { - DataType::Timestamp(TimeUnit::Millisecond, tz_opt.clone()) - } - ScalarValue::TimestampMicrosecond(_, tz_opt) => { - DataType::Timestamp(TimeUnit::Microsecond, tz_opt.clone()) - } - ScalarValue::TimestampNanosecond(_, tz_opt) => { - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) - } - ScalarValue::Float32(_) => DataType::Float32, - ScalarValue::Float64(_) => DataType::Float64, - ScalarValue::Utf8(_) => DataType::Utf8, - ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, - ScalarValue::Binary(_) => DataType::Binary, - ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::List(_, data_type) => DataType::List(Box::new(Field::new( - "item", - data_type.as_ref().clone(), - true, - ))), - ScalarValue::Date32(_) => DataType::Date32, - ScalarValue::Date64(_) => DataType::Date64, - ScalarValue::IntervalYearMonth(_) => { - DataType::Interval(IntervalUnit::YearMonth) - } - ScalarValue::IntervalDayTime(_) => DataType::Interval(IntervalUnit::DayTime), - ScalarValue::IntervalMonthDayNano(_) => { - DataType::Interval(IntervalUnit::MonthDayNano) - } - ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), - } - } - - /// Calculate arithmetic negation for a scalar value - pub fn arithmetic_negate(&self) -> Self { - match self { - ScalarValue::Boolean(None) - | ScalarValue::Int8(None) - | ScalarValue::Int16(None) - | ScalarValue::Int32(None) - | ScalarValue::Int64(None) - | ScalarValue::Float32(None) => self.clone(), - ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(-v)), - ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(-v)), - ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(-v)), - ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(-v)), - ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(-v)), - ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(-v)), - ScalarValue::Decimal128(Some(v), precision, scale) => { - ScalarValue::Decimal128(Some(-v), *precision, *scale) - } - _ => panic!("Cannot run arithmetic negate on scalar value: {:?}", self), - } - } - - /// whether this value is null or not. 
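
As a quick sketch of how the decimal constructor and the accessors above behave (the wrapper function and the literal values are illustrative, not part of this patch), assuming the public `datafusion::scalar::ScalarValue` API shown in this file:

    use datafusion::arrow::datatypes::DataType;
    use datafusion::error::Result;
    use datafusion::scalar::ScalarValue;

    fn decimal_scalar_sketch() -> Result<()> {
        // 123.45 represented as an i128 with precision 10 and scale 2
        let v = ScalarValue::try_new_decimal128(12345, 10, 2)?;
        assert_eq!(v.get_datatype(), DataType::Decimal(10, 2));

        // arithmetic_negate flips the sign and keeps precision/scale
        let neg = v.arithmetic_negate();
        assert_eq!(neg, ScalarValue::Decimal128(Some(-12345), 10, 2));

        // a precision beyond what a 128-bit decimal can hold is rejected
        assert!(ScalarValue::try_new_decimal128(1, 50, 2).is_err());
        Ok(())
    }
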
- pub fn is_null(&self) -> bool { - matches!( - *self, - ScalarValue::Boolean(None) - | ScalarValue::UInt8(None) - | ScalarValue::UInt16(None) - | ScalarValue::UInt32(None) - | ScalarValue::UInt64(None) - | ScalarValue::Int8(None) - | ScalarValue::Int16(None) - | ScalarValue::Int32(None) - | ScalarValue::Int64(None) - | ScalarValue::Float32(None) - | ScalarValue::Float64(None) - | ScalarValue::Date32(None) - | ScalarValue::Date64(None) - | ScalarValue::Utf8(None) - | ScalarValue::LargeUtf8(None) - | ScalarValue::List(None, _) - | ScalarValue::TimestampSecond(None, _) - | ScalarValue::TimestampMillisecond(None, _) - | ScalarValue::TimestampMicrosecond(None, _) - | ScalarValue::TimestampNanosecond(None, _) - | ScalarValue::Struct(None, _) - | ScalarValue::Decimal128(None, _, _) // For decimal type, the value is null means ScalarValue::Decimal128 is null. - ) - } - - /// Converts a scalar value into an 1-row array. - pub fn to_array(&self) -> ArrayRef { - self.to_array_of_size(1) - } - - /// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`] - /// corresponding to those values. For example, - /// - /// Returns an error if the iterator is empty or if the - /// [`ScalarValue`]s are not all the same type - /// - /// Example - /// ``` - /// use datafusion::scalar::ScalarValue; - /// use arrow::array::{ArrayRef, BooleanArray}; - /// - /// let scalars = vec![ - /// ScalarValue::Boolean(Some(true)), - /// ScalarValue::Boolean(None), - /// ScalarValue::Boolean(Some(false)), - /// ]; - /// - /// // Build an Array from the list of ScalarValues - /// let array = ScalarValue::iter_to_array(scalars.into_iter()) - /// .unwrap(); - /// - /// let expected: ArrayRef = std::sync::Arc::new( - /// BooleanArray::from(vec![ - /// Some(true), - /// None, - /// Some(false) - /// ] - /// )); - /// - /// assert_eq!(&array, &expected); - /// ``` - pub fn iter_to_array( - scalars: impl IntoIterator, - ) -> Result { - let mut scalars = scalars.into_iter().peekable(); - - // figure out the type based on the first element - let data_type = match scalars.peek() { - None => { - return Err(DataFusionError::Internal( - "Empty iterator passed to ScalarValue::iter_to_array".to_string(), - )); - } - Some(sv) => sv.get_datatype(), - }; - - /// Creates an array of $ARRAY_TY by unpacking values of - /// SCALAR_TY for primitive types - macro_rules! build_array_primitive { - ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ - { - let array = scalars - .map(|sv| { - if let ScalarValue::$SCALAR_TY(v) = sv { - Ok(v) - } else { - Err(DataFusionError::Internal(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", - data_type, sv - ))) - } - }) - .collect::>()?; - - Arc::new(array) - } - }}; - } - - macro_rules! build_array_primitive_tz { - ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ - { - let array = scalars - .map(|sv| { - if let ScalarValue::$SCALAR_TY(v, _) = sv { - Ok(v) - } else { - Err(DataFusionError::Internal(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", - data_type, sv - ))) - } - }) - .collect::>()?; - - Arc::new(array) - } - }}; - } - - /// Creates an array of $ARRAY_TY by unpacking values of - /// SCALAR_TY for "string-like" types. - macro_rules! build_array_string { - ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{ - { - let array = scalars - .map(|sv| { - if let ScalarValue::$SCALAR_TY(v) = sv { - Ok(v) - } else { - Err(DataFusionError::Internal(format!( - "Inconsistent types in ScalarValue::iter_to_array. 
\ - Expected {:?}, got {:?}", - data_type, sv - ))) - } - }) - .collect::>()?; - Arc::new(array) - } - }}; - } - - macro_rules! build_array_list_primitive { - ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ - Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( - scalars.into_iter().map(|x| match x { - ScalarValue::List(xs, _) => xs.map(|x| { - x.iter() - .map(|x| match x { - ScalarValue::$SCALAR_TY(i) => *i, - sv => panic!("Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", data_type, sv), - }) - .collect::>>() - }), - sv => panic!("Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", data_type, sv), - }), - )) - }}; - } - - macro_rules! build_array_list_string { - ($BUILDER:ident, $SCALAR_TY:ident) => {{ - let mut builder = ListBuilder::new($BUILDER::new(0)); - - for scalar in scalars.into_iter() { - match scalar { - ScalarValue::List(Some(xs), _) => { - let xs = *xs; - for s in xs { - match s { - ScalarValue::$SCALAR_TY(Some(val)) => { - builder.values().append_value(val)?; - } - ScalarValue::$SCALAR_TY(None) => { - builder.values().append_null()?; - } - sv => return Err(DataFusionError::Internal(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected Utf8, got {:?}", - sv - ))), - } - } - builder.append(true)?; - } - ScalarValue::List(None, _) => { - builder.append(false)?; - } - sv => { - return Err(DataFusionError::Internal(format!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected List, got {:?}", - sv - ))) - } - } - } - - Arc::new(builder.finish()) - - }} - } - - let array: ArrayRef = match &data_type { - DataType::Decimal(precision, scale) => { - let decimal_array = - ScalarValue::iter_to_decimal_array(scalars, precision, scale)?; - Arc::new(decimal_array) - } - DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), - DataType::Float32 => build_array_primitive!(Float32Array, Float32), - DataType::Float64 => build_array_primitive!(Float64Array, Float64), - DataType::Int8 => build_array_primitive!(Int8Array, Int8), - DataType::Int16 => build_array_primitive!(Int16Array, Int16), - DataType::Int32 => build_array_primitive!(Int32Array, Int32), - DataType::Int64 => build_array_primitive!(Int64Array, Int64), - DataType::UInt8 => build_array_primitive!(UInt8Array, UInt8), - DataType::UInt16 => build_array_primitive!(UInt16Array, UInt16), - DataType::UInt32 => build_array_primitive!(UInt32Array, UInt32), - DataType::UInt64 => build_array_primitive!(UInt64Array, UInt64), - DataType::Utf8 => build_array_string!(StringArray, Utf8), - DataType::LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8), - DataType::Binary => build_array_string!(BinaryArray, Binary), - DataType::LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary), - DataType::Date32 => build_array_primitive!(Date32Array, Date32), - DataType::Date64 => build_array_primitive!(Date64Array, Date64), - DataType::Timestamp(TimeUnit::Second, _) => { - build_array_primitive_tz!(TimestampSecondArray, TimestampSecond) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - build_array_primitive_tz!(TimestampMillisecondArray, TimestampMillisecond) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - build_array_primitive_tz!(TimestampMicrosecondArray, TimestampMicrosecond) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - build_array_primitive_tz!(TimestampNanosecondArray, TimestampNanosecond) - } - DataType::Interval(IntervalUnit::DayTime) => { - 
build_array_primitive!(IntervalDayTimeArray, IntervalDayTime) - } - DataType::Interval(IntervalUnit::YearMonth) => { - build_array_primitive!(IntervalYearMonthArray, IntervalYearMonth) - } - DataType::List(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!(Int8Type, Int8, i8) - } - DataType::List(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!(Int16Type, Int16, i16) - } - DataType::List(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!(Int32Type, Int32, i32) - } - DataType::List(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!(Int64Type, Int64, i64) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!(UInt8Type, UInt8, u8) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!(UInt16Type, UInt16, u16) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!(UInt32Type, UInt32, u32) - } - DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!(UInt64Type, UInt64, u64) - } - DataType::List(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!(Float32Type, Float32, f32) - } - DataType::List(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!(Float64Type, Float64, f64) - } - DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!(StringBuilder, Utf8) - } - DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!(LargeStringBuilder, LargeUtf8) - } - DataType::List(_) => { - // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = ScalarValue::iter_to_array_list(scalars, &data_type)?; - Arc::new(list_array) - } - DataType::Struct(fields) => { - // Initialize a Vector to store the ScalarValues for each column - let mut columns: Vec> = - (0..fields.len()).map(|_| Vec::new()).collect(); - - // Iterate over scalars to populate the column scalars for each row - for scalar in scalars { - if let ScalarValue::Struct(values, fields) = scalar { - match values { - Some(values) => { - // Push value for each field - for c in 0..columns.len() { - let column = columns.get_mut(c).unwrap(); - column.push(values[c].clone()); - } - } - None => { - // Push NULL of the appropriate type for each field - for c in 0..columns.len() { - let dtype = fields[c].data_type(); - let column = columns.get_mut(c).unwrap(); - column.push(ScalarValue::try_from(dtype)?); - } - } - }; - } else { - return Err(DataFusionError::Internal(format!( - "Expected Struct but found: {}", - scalar - ))); - }; - } - - // Call iter_to_array recursively to convert the scalars for each column into Arrow arrays - let field_values = fields - .iter() - .zip(columns) - .map(|(field, column)| -> Result<(Field, ArrayRef)> { - Ok((field.clone(), Self::iter_to_array(column)?)) - }) - .collect::>>()?; - - Arc::new(StructArray::from(field_values)) - } - _ => { - return Err(DataFusionError::Internal(format!( - "Unsupported creation of {:?} array from ScalarValue {:?}", - data_type, - scalars.peek() - ))); - } - }; - - Ok(array) - } - - fn iter_to_decimal_array( - scalars: impl IntoIterator, - precision: &usize, - scale: &usize, - ) -> Result { - // collect the value as Option - let array = scalars - .into_iter() - .map(|element: 
ScalarValue| match element { - ScalarValue::Decimal128(v1, _, _) => v1, - _ => unreachable!(), - }) - .collect::>>(); - - // build the decimal array using the Decimal Builder - let mut builder = DecimalBuilder::new(array.len(), *precision, *scale); - array.iter().for_each(|element| match element { - None => { - builder.append_null().unwrap(); - } - Some(v) => { - builder.append_value(*v).unwrap(); - } - }); - Ok(builder.finish()) - } - - fn iter_to_array_list( - scalars: impl IntoIterator, - data_type: &DataType, - ) -> Result> { - let mut offsets = Int32Array::builder(0); - if let Err(err) = offsets.append_value(0) { - return Err(DataFusionError::ArrowError(err)); - } - - let mut elements: Vec = Vec::new(); - let mut valid = BooleanBufferBuilder::new(0); - let mut flat_len = 0i32; - for scalar in scalars { - if let ScalarValue::List(values, _) = scalar { - match values { - Some(values) => { - let element_array = ScalarValue::iter_to_array(*values)?; - - // Add new offset index - flat_len += element_array.len() as i32; - if let Err(err) = offsets.append_value(flat_len) { - return Err(DataFusionError::ArrowError(err)); - } - - elements.push(element_array); - - // Element is valid - valid.append(true); - } - None => { - // Repeat previous offset index - if let Err(err) = offsets.append_value(flat_len) { - return Err(DataFusionError::ArrowError(err)); - } - - // Element is null - valid.append(false); - } - } - } else { - return Err(DataFusionError::Internal(format!( - "Expected ScalarValue::List element. Received {:?}", - scalar - ))); - } - } - - // Concatenate element arrays to create single flat array - let element_arrays: Vec<&dyn Array> = - elements.iter().map(|a| a.as_ref()).collect(); - let flat_array = match arrow::compute::concat(&element_arrays) { - Ok(flat_array) => flat_array, - Err(err) => return Err(DataFusionError::ArrowError(err)), - }; - - // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices - let offsets_array = offsets.finish(); - let array_data = ArrayDataBuilder::new(data_type.clone()) - .len(offsets_array.len() - 1) - .null_bit_buffer(valid.finish()) - .add_buffer(offsets_array.data().buffers()[0].clone()) - .add_child_data(flat_array.data().clone()); - - let list_array = ListArray::from(array_data.build()?); - Ok(list_array) - } - - fn build_decimal_array( - value: &Option, - precision: &usize, - scale: &usize, - size: usize, - ) -> DecimalArray { - let mut builder = DecimalBuilder::new(size, *precision, *scale); - match value { - None => { - for _i in 0..size { - builder.append_null().unwrap(); - } - } - Some(v) => { - let v = *v; - for _i in 0..size { - builder.append_value(v).unwrap(); - } - } - }; - builder.finish() - } - - /// Converts a scalar value into an array of `size` rows. 
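
For reference, a minimal sketch of `iter_to_array` on string scalars (the values are arbitrary; this mirrors the boolean doc example above but exercises the `build_array_string!` arm):

    use std::sync::Arc;
    use datafusion::arrow::array::{ArrayRef, StringArray};
    use datafusion::scalar::ScalarValue;

    let scalars = vec![
        ScalarValue::Utf8(Some("a".to_string())),
        ScalarValue::Utf8(None),
        ScalarValue::Utf8(Some("b".to_string())),
    ];
    // All inputs share the Utf8 type, so a single StringArray comes back
    let array = ScalarValue::iter_to_array(scalars).unwrap();

    let expected: ArrayRef =
        Arc::new(StringArray::from(vec![Some("a"), None, Some("b")]));
    assert_eq!(&array, &expected);
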
- pub fn to_array_of_size(&self, size: usize) -> ArrayRef { - match self { - ScalarValue::Decimal128(e, precision, scale) => { - Arc::new(ScalarValue::build_decimal_array(e, precision, scale, size)) - } - ScalarValue::Boolean(e) => { - Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef - } - ScalarValue::Float64(e) => { - build_array_from_option!(Float64, Float64Array, e, size) - } - ScalarValue::Float32(e) => { - build_array_from_option!(Float32, Float32Array, e, size) - } - ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array, e, size), - ScalarValue::Int16(e) => build_array_from_option!(Int16, Int16Array, e, size), - ScalarValue::Int32(e) => build_array_from_option!(Int32, Int32Array, e, size), - ScalarValue::Int64(e) => build_array_from_option!(Int64, Int64Array, e, size), - ScalarValue::UInt8(e) => build_array_from_option!(UInt8, UInt8Array, e, size), - ScalarValue::UInt16(e) => { - build_array_from_option!(UInt16, UInt16Array, e, size) - } - ScalarValue::UInt32(e) => { - build_array_from_option!(UInt32, UInt32Array, e, size) - } - ScalarValue::UInt64(e) => { - build_array_from_option!(UInt64, UInt64Array, e, size) - } - ScalarValue::TimestampSecond(e, tz_opt) => build_array_from_option!( - Timestamp, - TimeUnit::Second, - tz_opt.clone(), - TimestampSecondArray, - e, - size - ), - ScalarValue::TimestampMillisecond(e, tz_opt) => build_array_from_option!( - Timestamp, - TimeUnit::Millisecond, - tz_opt.clone(), - TimestampMillisecondArray, - e, - size - ), - - ScalarValue::TimestampMicrosecond(e, tz_opt) => build_array_from_option!( - Timestamp, - TimeUnit::Microsecond, - tz_opt.clone(), - TimestampMicrosecondArray, - e, - size - ), - ScalarValue::TimestampNanosecond(e, tz_opt) => build_array_from_option!( - Timestamp, - TimeUnit::Nanosecond, - tz_opt.clone(), - TimestampNanosecondArray, - e, - size - ), - ScalarValue::Utf8(e) => match e { - Some(value) => { - Arc::new(StringArray::from_iter_values(repeat(value).take(size))) - } - None => new_null_array(&DataType::Utf8, size), - }, - ScalarValue::LargeUtf8(e) => match e { - Some(value) => { - Arc::new(LargeStringArray::from_iter_values(repeat(value).take(size))) - } - None => new_null_array(&DataType::LargeUtf8, size), - }, - ScalarValue::Binary(e) => match e { - Some(value) => Arc::new( - repeat(Some(value.as_slice())) - .take(size) - .collect::(), - ), - None => { - Arc::new(repeat(None::<&str>).take(size).collect::()) - } - }, - ScalarValue::LargeBinary(e) => match e { - Some(value) => Arc::new( - repeat(Some(value.as_slice())) - .take(size) - .collect::(), - ), - None => Arc::new( - repeat(None::<&str>) - .take(size) - .collect::(), - ), - }, - ScalarValue::List(values, data_type) => Arc::new(match data_type.as_ref() { - DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size), - DataType::Int8 => build_list!(Int8Builder, Int8, values, size), - DataType::Int16 => build_list!(Int16Builder, Int16, values, size), - DataType::Int32 => build_list!(Int32Builder, Int32, values, size), - DataType::Int64 => build_list!(Int64Builder, Int64, values, size), - DataType::UInt8 => build_list!(UInt8Builder, UInt8, values, size), - DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size), - DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size), - DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size), - DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size), - DataType::Float32 => build_list!(Float32Builder, Float32, values, size), - DataType::Float64 => 
build_list!(Float64Builder, Float64, values, size), - DataType::Timestamp(unit, tz) => { - build_timestamp_list!(unit.clone(), tz.clone(), values, size) - } - &DataType::LargeUtf8 => { - build_list!(LargeStringBuilder, LargeUtf8, values, size) - } - _ => ScalarValue::iter_to_array_list( - repeat(self.clone()).take(size), - &DataType::List(Box::new(Field::new( - "item", - data_type.as_ref().clone(), - true, - ))), - ) - .unwrap(), - }), - ScalarValue::Date32(e) => { - build_array_from_option!(Date32, Date32Array, e, size) - } - ScalarValue::Date64(e) => { - build_array_from_option!(Date64, Date64Array, e, size) - } - ScalarValue::IntervalDayTime(e) => build_array_from_option!( - Interval, - IntervalUnit::DayTime, - IntervalDayTimeArray, - e, - size - ), - ScalarValue::IntervalYearMonth(e) => build_array_from_option!( - Interval, - IntervalUnit::YearMonth, - IntervalYearMonthArray, - e, - size - ), - ScalarValue::IntervalMonthDayNano(e) => build_array_from_option!( - Interval, - IntervalUnit::MonthDayNano, - IntervalMonthDayNanoArray, - e, - size - ), - ScalarValue::Struct(values, fields) => match values { - Some(values) => { - let field_values: Vec<_> = fields - .iter() - .zip(values.iter()) - .map(|(field, value)| { - (field.clone(), value.to_array_of_size(size)) - }) - .collect(); - - Arc::new(StructArray::from(field_values)) - } - None => { - let field_values: Vec<_> = fields - .iter() - .map(|field| { - let none_field = Self::try_from(field.data_type()).expect( - "Failed to construct null ScalarValue from Struct field type" - ); - (field.clone(), none_field.to_array_of_size(size)) - }) - .collect(); - - Arc::new(StructArray::from(field_values)) - } - }, - } - } - - fn get_decimal_value_from_array( - array: &ArrayRef, - index: usize, - precision: &usize, - scale: &usize, - ) -> ScalarValue { - let array = array.as_any().downcast_ref::().unwrap(); - if array.is_null(index) { - ScalarValue::Decimal128(None, *precision, *scale) - } else { - ScalarValue::Decimal128(Some(array.value(index)), *precision, *scale) - } - } - - /// Converts a value in `array` at `index` into a ScalarValue - pub fn try_from_array(array: &ArrayRef, index: usize) -> Result { - // handle NULL value - if !array.is_valid(index) { - return array.data_type().try_into(); - } - - Ok(match array.data_type() { - DataType::Decimal(precision, scale) => { - ScalarValue::get_decimal_value_from_array(array, index, precision, scale) - } - DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), - DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), - DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), - DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64), - DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32), - DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16), - DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8), - DataType::Int64 => typed_cast!(array, index, Int64Array, Int64), - DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), - DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), - DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), - DataType::Binary => typed_cast!(array, index, BinaryArray, Binary), - DataType::LargeBinary => { - typed_cast!(array, index, LargeBinaryArray, LargeBinary) - } - DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), - DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), - DataType::List(nested_type) => { - let 
list_array = - array.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal( - "Failed to downcast ListArray".to_string(), - ) - })?; - let value = match list_array.is_null(index) { - true => None, - false => { - let nested_array = list_array.value(index); - let scalar_vec = (0..nested_array.len()) - .map(|i| ScalarValue::try_from_array(&nested_array, i)) - .collect::>>()?; - Some(scalar_vec) - } - }; - let value = value.map(Box::new); - let data_type = Box::new(nested_type.data_type().clone()); - ScalarValue::List(value, data_type) - } - DataType::Date32 => { - typed_cast!(array, index, Date32Array, Date32) - } - DataType::Date64 => { - typed_cast!(array, index, Date64Array, Date64) - } - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampSecondArray, - TimestampSecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampMillisecondArray, - TimestampMillisecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampMicrosecondArray, - TimestampMicrosecond, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - typed_cast_tz!( - array, - index, - TimestampNanosecondArray, - TimestampNanosecond, - tz_opt - ) - } - DataType::Dictionary(index_type, _) => { - let (values, values_index) = match **index_type { - DataType::Int8 => get_dict_value::(array, index)?, - DataType::Int16 => get_dict_value::(array, index)?, - DataType::Int32 => get_dict_value::(array, index)?, - DataType::Int64 => get_dict_value::(array, index)?, - DataType::UInt8 => get_dict_value::(array, index)?, - DataType::UInt16 => get_dict_value::(array, index)?, - DataType::UInt32 => get_dict_value::(array, index)?, - DataType::UInt64 => get_dict_value::(array, index)?, - _ => { - return Err(DataFusionError::Internal(format!( - "Index type not supported while creating scalar from dictionary: {}", - array.data_type(), - ))); - } - }; - - match values_index { - Some(values_index) => Self::try_from_array(values, values_index)?, - // was null - None => values.data_type().try_into()?, - } - } - DataType::Struct(fields) => { - let array = - array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal( - "Failed to downcast ArrayRef to StructArray".to_string(), - ) - })?; - let mut field_values: Vec = Vec::new(); - for col_index in 0..array.num_columns() { - let col_array = array.column(col_index); - let col_scalar = ScalarValue::try_from_array(col_array, index)?; - field_values.push(col_scalar); - } - Self::Struct(Some(Box::new(field_values)), Box::new(fields.clone())) - } - other => { - return Err(DataFusionError::NotImplemented(format!( - "Can't create a scalar from array of type \"{:?}\"", - other - ))); - } - }) - } - - fn eq_array_decimal( - array: &ArrayRef, - index: usize, - value: &Option, - precision: usize, - scale: usize, - ) -> bool { - let array = array.as_any().downcast_ref::().unwrap(); - if array.precision() != precision || array.scale() != scale { - return false; - } - match value { - None => array.is_null(index), - Some(v) => !array.is_null(index) && array.value(index) == *v, - } - } - - /// Compares a single row of array @ index for equality with self, - /// in an optimized fashion. 
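
A short sketch of the row-level conversion and comparison helpers in this file, `try_from_array` and `eq_array` (the array contents are illustrative):

    use std::sync::Arc;
    use datafusion::arrow::array::{ArrayRef, Int32Array};
    use datafusion::scalar::ScalarValue;

    let array: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));

    // pull row 0 back out as a scalar
    let v = ScalarValue::try_from_array(&array, 0).unwrap();
    assert_eq!(v, ScalarValue::Int32(Some(1)));

    // compare a scalar against single rows without materializing them
    assert!(v.eq_array(&array, 0));
    assert!(!v.eq_array(&array, 1)); // row 1 is NULL
    assert!(ScalarValue::Int32(None).eq_array(&array, 1));
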
- /// - /// This method implements an optimized version of: - /// - /// ```text - /// let arr_scalar = Self::try_from_array(array, index).unwrap(); - /// arr_scalar.eq(self) - /// ``` - /// - /// *Performance note*: the arrow compute kernels should be - /// preferred over this function if at all possible as they can be - /// vectorized and are generally much faster. - /// - /// This function has a few narrow usescases such as hash table key - /// comparisons where comparing a single row at a time is necessary. - #[inline] - pub fn eq_array(&self, array: &ArrayRef, index: usize) -> bool { - if let DataType::Dictionary(key_type, _) = array.data_type() { - return self.eq_array_dictionary(array, index, key_type); - } - - match self { - ScalarValue::Decimal128(v, precision, scale) => { - ScalarValue::eq_array_decimal(array, index, v, *precision, *scale) - } - ScalarValue::Boolean(val) => { - eq_array_primitive!(array, index, BooleanArray, val) - } - ScalarValue::Float32(val) => { - eq_array_primitive!(array, index, Float32Array, val) - } - ScalarValue::Float64(val) => { - eq_array_primitive!(array, index, Float64Array, val) - } - ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val), - ScalarValue::Int16(val) => eq_array_primitive!(array, index, Int16Array, val), - ScalarValue::Int32(val) => eq_array_primitive!(array, index, Int32Array, val), - ScalarValue::Int64(val) => eq_array_primitive!(array, index, Int64Array, val), - ScalarValue::UInt8(val) => eq_array_primitive!(array, index, UInt8Array, val), - ScalarValue::UInt16(val) => { - eq_array_primitive!(array, index, UInt16Array, val) - } - ScalarValue::UInt32(val) => { - eq_array_primitive!(array, index, UInt32Array, val) - } - ScalarValue::UInt64(val) => { - eq_array_primitive!(array, index, UInt64Array, val) - } - ScalarValue::Utf8(val) => eq_array_primitive!(array, index, StringArray, val), - ScalarValue::LargeUtf8(val) => { - eq_array_primitive!(array, index, LargeStringArray, val) - } - ScalarValue::Binary(val) => { - eq_array_primitive!(array, index, BinaryArray, val) - } - ScalarValue::LargeBinary(val) => { - eq_array_primitive!(array, index, LargeBinaryArray, val) - } - ScalarValue::List(_, _) => unimplemented!(), - ScalarValue::Date32(val) => { - eq_array_primitive!(array, index, Date32Array, val) - } - ScalarValue::Date64(val) => { - eq_array_primitive!(array, index, Date64Array, val) - } - ScalarValue::TimestampSecond(val, _) => { - eq_array_primitive!(array, index, TimestampSecondArray, val) - } - ScalarValue::TimestampMillisecond(val, _) => { - eq_array_primitive!(array, index, TimestampMillisecondArray, val) - } - ScalarValue::TimestampMicrosecond(val, _) => { - eq_array_primitive!(array, index, TimestampMicrosecondArray, val) - } - ScalarValue::TimestampNanosecond(val, _) => { - eq_array_primitive!(array, index, TimestampNanosecondArray, val) - } - ScalarValue::IntervalYearMonth(val) => { - eq_array_primitive!(array, index, IntervalYearMonthArray, val) - } - ScalarValue::IntervalDayTime(val) => { - eq_array_primitive!(array, index, IntervalDayTimeArray, val) - } - ScalarValue::IntervalMonthDayNano(val) => { - eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val) - } - ScalarValue::Struct(_, _) => unimplemented!(), - } - } - - /// Compares a dictionary array with indexes of type `key_type` - /// with the array @ index for equality with self - fn eq_array_dictionary( - &self, - array: &ArrayRef, - index: usize, - key_type: &DataType, - ) -> bool { - let (values, values_index) = match key_type { - 
DataType::Int8 => get_dict_value::(array, index).unwrap(), - DataType::Int16 => get_dict_value::(array, index).unwrap(), - DataType::Int32 => get_dict_value::(array, index).unwrap(), - DataType::Int64 => get_dict_value::(array, index).unwrap(), - DataType::UInt8 => get_dict_value::(array, index).unwrap(), - DataType::UInt16 => get_dict_value::(array, index).unwrap(), - DataType::UInt32 => get_dict_value::(array, index).unwrap(), - DataType::UInt64 => get_dict_value::(array, index).unwrap(), - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), - }; - - match values_index { - Some(values_index) => self.eq_array(values, values_index), - None => self.is_null(), - } - } -} - -macro_rules! impl_scalar { - ($ty:ty, $scalar:tt) => { - impl From<$ty> for ScalarValue { - fn from(value: $ty) -> Self { - ScalarValue::$scalar(Some(value)) - } - } - - impl From> for ScalarValue { - fn from(value: Option<$ty>) -> Self { - ScalarValue::$scalar(value) - } - } - }; -} - -impl_scalar!(f64, Float64); -impl_scalar!(f32, Float32); -impl_scalar!(i8, Int8); -impl_scalar!(i16, Int16); -impl_scalar!(i32, Int32); -impl_scalar!(i64, Int64); -impl_scalar!(bool, Boolean); -impl_scalar!(u8, UInt8); -impl_scalar!(u16, UInt16); -impl_scalar!(u32, UInt32); -impl_scalar!(u64, UInt64); - -impl From<&str> for ScalarValue { - fn from(value: &str) -> Self { - Some(value).into() - } -} - -impl From> for ScalarValue { - fn from(value: Option<&str>) -> Self { - let value = value.map(|s| s.to_string()); - ScalarValue::Utf8(value) - } -} - -impl FromStr for ScalarValue { - type Err = Infallible; - - fn from_str(s: &str) -> std::result::Result { - Ok(s.into()) - } -} - -impl From> for ScalarValue { - fn from(value: Vec<(&str, ScalarValue)>) -> Self { - let (fields, scalars): (Vec<_>, Vec<_>) = value - .into_iter() - .map(|(name, scalar)| { - (Field::new(name, scalar.get_datatype(), false), scalar) - }) - .unzip(); - - Self::Struct(Some(Box::new(scalars)), Box::new(fields)) - } -} - -macro_rules! 
impl_try_from { - ($SCALAR:ident, $NATIVE:ident) => { - impl TryFrom for $NATIVE { - type Error = DataFusionError; - - fn try_from(value: ScalarValue) -> Result { - match value { - ScalarValue::$SCALAR(Some(inner_value)) => Ok(inner_value), - _ => Err(DataFusionError::Internal(format!( - "Cannot convert {:?} to {}", - value, - std::any::type_name::() - ))), - } - } - } - }; -} - -impl_try_from!(Int8, i8); -impl_try_from!(Int16, i16); - -// special implementation for i32 because of Date32 -impl TryFrom for i32 { - type Error = DataFusionError; - - fn try_from(value: ScalarValue) -> Result { - match value { - ScalarValue::Int32(Some(inner_value)) - | ScalarValue::Date32(Some(inner_value)) => Ok(inner_value), - _ => Err(DataFusionError::Internal(format!( - "Cannot convert {:?} to {}", - value, - std::any::type_name::() - ))), - } - } -} - -// special implementation for i64 because of TimeNanosecond -impl TryFrom for i64 { - type Error = DataFusionError; - - fn try_from(value: ScalarValue) -> Result { - match value { - ScalarValue::Int64(Some(inner_value)) - | ScalarValue::Date64(Some(inner_value)) - | ScalarValue::TimestampNanosecond(Some(inner_value), _) - | ScalarValue::TimestampMicrosecond(Some(inner_value), _) - | ScalarValue::TimestampMillisecond(Some(inner_value), _) - | ScalarValue::TimestampSecond(Some(inner_value), _) => Ok(inner_value), - _ => Err(DataFusionError::Internal(format!( - "Cannot convert {:?} to {}", - value, - std::any::type_name::() - ))), - } - } -} - -// special implementation for i128 because of Decimal128 -impl TryFrom for i128 { - type Error = DataFusionError; - - fn try_from(value: ScalarValue) -> Result { - match value { - ScalarValue::Decimal128(Some(inner_value), _, _) => Ok(inner_value), - _ => Err(DataFusionError::Internal(format!( - "Cannot convert {:?} to {}", - value, - std::any::type_name::() - ))), - } - } -} - -impl_try_from!(UInt8, u8); -impl_try_from!(UInt16, u16); -impl_try_from!(UInt32, u32); -impl_try_from!(UInt64, u64); -impl_try_from!(Float32, f32); -impl_try_from!(Float64, f64); -impl_try_from!(Boolean, bool); - -impl TryFrom<&DataType> for ScalarValue { - type Error = DataFusionError; - - /// Create a Null instance of ScalarValue for this datatype - fn try_from(datatype: &DataType) -> Result { - Ok(match datatype { - DataType::Boolean => ScalarValue::Boolean(None), - DataType::Float64 => ScalarValue::Float64(None), - DataType::Float32 => ScalarValue::Float32(None), - DataType::Int8 => ScalarValue::Int8(None), - DataType::Int16 => ScalarValue::Int16(None), - DataType::Int32 => ScalarValue::Int32(None), - DataType::Int64 => ScalarValue::Int64(None), - DataType::UInt8 => ScalarValue::UInt8(None), - DataType::UInt16 => ScalarValue::UInt16(None), - DataType::UInt32 => ScalarValue::UInt32(None), - DataType::UInt64 => ScalarValue::UInt64(None), - DataType::Decimal(precision, scale) => { - ScalarValue::Decimal128(None, *precision, *scale) - } - DataType::Utf8 => ScalarValue::Utf8(None), - DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), - DataType::Date32 => ScalarValue::Date32(None), - DataType::Date64 => ScalarValue::Date64(None), - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - ScalarValue::TimestampSecond(None, tz_opt.clone()) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - ScalarValue::TimestampMillisecond(None, tz_opt.clone()) - } - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - ScalarValue::TimestampMicrosecond(None, tz_opt.clone()) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - 
ScalarValue::TimestampNanosecond(None, tz_opt.clone()) - } - DataType::Dictionary(_index_type, value_type) => { - value_type.as_ref().try_into()? - } - DataType::List(ref nested_type) => { - ScalarValue::List(None, Box::new(nested_type.data_type().clone())) - } - DataType::Struct(fields) => { - ScalarValue::Struct(None, Box::new(fields.clone())) - } - _ => { - return Err(DataFusionError::NotImplemented(format!( - "Can't create a scalar from data_type \"{:?}\"", - datatype - ))); - } - }) - } -} - -macro_rules! format_option { - ($F:expr, $EXPR:expr) => {{ - match $EXPR { - Some(e) => write!($F, "{}", e), - None => write!($F, "NULL"), - } - }}; -} +//! ScalarValue reimported from datafusion-common -impl fmt::Display for ScalarValue { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ScalarValue::Decimal128(v, p, s) => { - write!(f, "{:?},{:?},{:?}", v, p, s)?; - } - ScalarValue::Boolean(e) => format_option!(f, e)?, - ScalarValue::Float32(e) => format_option!(f, e)?, - ScalarValue::Float64(e) => format_option!(f, e)?, - ScalarValue::Int8(e) => format_option!(f, e)?, - ScalarValue::Int16(e) => format_option!(f, e)?, - ScalarValue::Int32(e) => format_option!(f, e)?, - ScalarValue::Int64(e) => format_option!(f, e)?, - ScalarValue::UInt8(e) => format_option!(f, e)?, - ScalarValue::UInt16(e) => format_option!(f, e)?, - ScalarValue::UInt32(e) => format_option!(f, e)?, - ScalarValue::UInt64(e) => format_option!(f, e)?, - ScalarValue::TimestampSecond(e, _) => format_option!(f, e)?, - ScalarValue::TimestampMillisecond(e, _) => format_option!(f, e)?, - ScalarValue::TimestampMicrosecond(e, _) => format_option!(f, e)?, - ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?, - ScalarValue::Utf8(e) => format_option!(f, e)?, - ScalarValue::LargeUtf8(e) => format_option!(f, e)?, - ScalarValue::Binary(e) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{}", v)) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, - ScalarValue::LargeBinary(e) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{}", v)) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, - ScalarValue::List(e, _) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{}", v)) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, - ScalarValue::Date32(e) => format_option!(f, e)?, - ScalarValue::Date64(e) => format_option!(f, e)?, - ScalarValue::IntervalDayTime(e) => format_option!(f, e)?, - ScalarValue::IntervalYearMonth(e) => format_option!(f, e)?, - ScalarValue::IntervalMonthDayNano(e) => format_option!(f, e)?, - ScalarValue::Struct(e, fields) => match e { - Some(l) => write!( - f, - "{{{}}}", - l.iter() - .zip(fields.iter()) - .map(|(value, field)| format!("{}:{}", field.name(), value)) - .collect::>() - .join(",") - )?, - None => write!(f, "NULL")?, - }, - }; - Ok(()) - } -} - -impl fmt::Debug for ScalarValue { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({})", self), - ScalarValue::Boolean(_) => write!(f, "Boolean({})", self), - ScalarValue::Float32(_) => write!(f, "Float32({})", self), - ScalarValue::Float64(_) => write!(f, "Float64({})", self), - ScalarValue::Int8(_) => write!(f, "Int8({})", self), - ScalarValue::Int16(_) => write!(f, "Int16({})", self), - ScalarValue::Int32(_) => write!(f, "Int32({})", self), - ScalarValue::Int64(_) => write!(f, "Int64({})", self), - ScalarValue::UInt8(_) 
=> write!(f, "UInt8({})", self), - ScalarValue::UInt16(_) => write!(f, "UInt16({})", self), - ScalarValue::UInt32(_) => write!(f, "UInt32({})", self), - ScalarValue::UInt64(_) => write!(f, "UInt64({})", self), - ScalarValue::TimestampSecond(_, tz_opt) => { - write!(f, "TimestampSecond({}, {:?})", self, tz_opt) - } - ScalarValue::TimestampMillisecond(_, tz_opt) => { - write!(f, "TimestampMillisecond({}, {:?})", self, tz_opt) - } - ScalarValue::TimestampMicrosecond(_, tz_opt) => { - write!(f, "TimestampMicrosecond({}, {:?})", self, tz_opt) - } - ScalarValue::TimestampNanosecond(_, tz_opt) => { - write!(f, "TimestampNanosecond({}, {:?})", self, tz_opt) - } - ScalarValue::Utf8(None) => write!(f, "Utf8({})", self), - ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{}\")", self), - ScalarValue::LargeUtf8(None) => write!(f, "LargeUtf8({})", self), - ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{}\")", self), - ScalarValue::Binary(None) => write!(f, "Binary({})", self), - ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{}\")", self), - ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({})", self), - ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{}\")", self), - ScalarValue::List(_, _) => write!(f, "List([{}])", self), - ScalarValue::Date32(_) => write!(f, "Date32(\"{}\")", self), - ScalarValue::Date64(_) => write!(f, "Date64(\"{}\")", self), - ScalarValue::IntervalDayTime(_) => { - write!(f, "IntervalDayTime(\"{}\")", self) - } - ScalarValue::IntervalYearMonth(_) => { - write!(f, "IntervalYearMonth(\"{}\")", self) - } - ScalarValue::IntervalMonthDayNano(_) => { - write!(f, "IntervalMonthDayNano(\"{}\")", self) - } - ScalarValue::Struct(e, fields) => { - // Use Debug representation of field values - match e { - Some(l) => write!( - f, - "Struct({{{}}})", - l.iter() - .zip(fields.iter()) - .map(|(value, field)| format!("{}:{:?}", field.name(), value)) - .collect::>() - .join(",") - ), - None => write!(f, "Struct(NULL)"), - } - } - } - } -} - -/// Trait used to map a NativeTime to a ScalarType. 
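
The Display and Debug implementations above render differently; a small sketch (the values are arbitrary):

    use datafusion::scalar::ScalarValue;

    let s = ScalarValue::Utf8(Some("hello".to_string()));
    // Display prints the bare value, Debug wraps it in the variant name
    assert_eq!(format!("{}", s), "hello");
    assert_eq!(format!("{:?}", s), "Utf8(\"hello\")");

    let n = ScalarValue::Int32(None);
    assert_eq!(format!("{}", n), "NULL");
    assert_eq!(format!("{:?}", n), "Int32(NULL)");
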
-pub trait ScalarType { - /// returns a scalar from an optional T - fn scalar(r: Option) -> ScalarValue; -} - -impl ScalarType for Float32Type { - fn scalar(r: Option) -> ScalarValue { - ScalarValue::Float32(r) - } -} - -impl ScalarType for TimestampSecondType { - fn scalar(r: Option) -> ScalarValue { - ScalarValue::TimestampSecond(r, None) - } -} - -impl ScalarType for TimestampMillisecondType { - fn scalar(r: Option) -> ScalarValue { - ScalarValue::TimestampMillisecond(r, None) - } -} - -impl ScalarType for TimestampMicrosecondType { - fn scalar(r: Option) -> ScalarValue { - ScalarValue::TimestampMicrosecond(r, None) - } -} - -impl ScalarType for TimestampNanosecondType { - fn scalar(r: Option) -> ScalarValue { - ScalarValue::TimestampNanosecond(r, None) - } -} +pub use datafusion_common::{ + ScalarType, ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, +}; #[cfg(test)] mod tests { use super::*; use crate::from_slice::FromSlice; + use arrow::{array::*, datatypes::*}; + use std::cmp::Ordering; + use std::sync::Arc; #[test] fn scalar_decimal_test() { From 09c67d5af32aee107e87b9ddb93226707ccaa4fb Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Tue, 8 Feb 2022 01:36:12 +0800 Subject: [PATCH 41/50] include window frames and operator into datafusion-expr (#1761) --- datafusion-expr/Cargo.toml | 1 + datafusion-expr/src/lib.rs | 4 + datafusion-expr/src/operator.rs | 97 +++++ datafusion-expr/src/window_frame.rs | 381 +++++++++++++++++++ datafusion/src/logical_plan/operators.rs | 83 +--- datafusion/src/logical_plan/window_frames.rs | 363 +----------------- 6 files changed, 487 insertions(+), 442 deletions(-) create mode 100644 datafusion-expr/src/operator.rs create mode 100644 datafusion-expr/src/window_frame.rs diff --git a/datafusion-expr/Cargo.toml b/datafusion-expr/Cargo.toml index c3be893ae87e..73a5fcd36152 100644 --- a/datafusion-expr/Cargo.toml +++ b/datafusion-expr/Cargo.toml @@ -37,3 +37,4 @@ path = "src/lib.rs" [dependencies] datafusion-common = { path = "../datafusion-common", version = "6.0.0" } arrow = { version = "8.0.0", features = ["prettyprint"] } +sqlparser = "0.13" diff --git a/datafusion-expr/src/lib.rs b/datafusion-expr/src/lib.rs index b6eaaf7c6659..13fa93ed6a2e 100644 --- a/datafusion-expr/src/lib.rs +++ b/datafusion-expr/src/lib.rs @@ -16,7 +16,11 @@ // under the License. mod aggregate_function; +mod operator; +mod window_frame; mod window_function; pub use aggregate_function::AggregateFunction; +pub use operator::Operator; +pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion-expr/src/operator.rs b/datafusion-expr/src/operator.rs new file mode 100644 index 000000000000..e6b7e35a0a5e --- /dev/null +++ b/datafusion-expr/src/operator.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt; + +/// Operators applied to expressions +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum Operator { + /// Expressions are equal + Eq, + /// Expressions are not equal + NotEq, + /// Left side is smaller than right side + Lt, + /// Left side is smaller or equal to right side + LtEq, + /// Left side is greater than right side + Gt, + /// Left side is greater or equal to right side + GtEq, + /// Addition + Plus, + /// Subtraction + Minus, + /// Multiplication operator, like `*` + Multiply, + /// Division operator, like `/` + Divide, + /// Remainder operator, like `%` + Modulo, + /// Logical AND, like `&&` + And, + /// Logical OR, like `||` + Or, + /// Matches a wildcard pattern + Like, + /// Does not match a wildcard pattern + NotLike, + /// IS DISTINCT FROM + IsDistinctFrom, + /// IS NOT DISTINCT FROM + IsNotDistinctFrom, + /// Case sensitive regex match + RegexMatch, + /// Case insensitive regex match + RegexIMatch, + /// Case sensitive regex not match + RegexNotMatch, + /// Case insensitive regex not match + RegexNotIMatch, + /// Bitwise and, like `&` + BitwiseAnd, +} + +impl fmt::Display for Operator { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let display = match &self { + Operator::Eq => "=", + Operator::NotEq => "!=", + Operator::Lt => "<", + Operator::LtEq => "<=", + Operator::Gt => ">", + Operator::GtEq => ">=", + Operator::Plus => "+", + Operator::Minus => "-", + Operator::Multiply => "*", + Operator::Divide => "/", + Operator::Modulo => "%", + Operator::And => "AND", + Operator::Or => "OR", + Operator::Like => "LIKE", + Operator::NotLike => "NOT LIKE", + Operator::RegexMatch => "~", + Operator::RegexIMatch => "~*", + Operator::RegexNotMatch => "!~", + Operator::RegexNotIMatch => "!~*", + Operator::IsDistinctFrom => "IS DISTINCT FROM", + Operator::IsNotDistinctFrom => "IS NOT DISTINCT FROM", + Operator::BitwiseAnd => "&", + }; + write!(f, "{}", display) + } +} diff --git a/datafusion-expr/src/window_frame.rs b/datafusion-expr/src/window_frame.rs new file mode 100644 index 000000000000..ba65a5088b61 --- /dev/null +++ b/datafusion-expr/src/window_frame.rs @@ -0,0 +1,381 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Window frame +//! +//! The frame-spec determines which output rows are read by an aggregate window function. The frame-spec consists of four parts: +//! - A frame type - either ROWS, RANGE or GROUPS, +//! - A starting frame boundary, +//! - An ending frame boundary, +//! - An EXCLUDE clause. 
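
The frame-spec described above maps onto the `WindowFrame` type defined below. A sketch of the conversion from a `sqlparser` frame clause (the clause is built by hand here; in practice the SQL planner supplies it):

    use std::convert::TryFrom;
    use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
    use sqlparser::ast;

    // ROWS BETWEEN 2 PRECEDING AND CURRENT ROW
    let sql_frame = ast::WindowFrame {
        units: ast::WindowFrameUnits::Rows,
        start_bound: ast::WindowFrameBound::Preceding(Some(2)),
        end_bound: None, // an omitted end bound defaults to CURRENT ROW
    };
    let frame = WindowFrame::try_from(sql_frame).unwrap();
    assert_eq!(frame.units, WindowFrameUnits::Rows);
    assert_eq!(frame.start_bound, WindowFrameBound::Preceding(Some(2)));
    assert_eq!(frame.end_bound, WindowFrameBound::CurrentRow);
    assert_eq!(frame.to_string(), "ROWS BETWEEN 2 PRECEDING AND CURRENT ROW");
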
+ +use datafusion_common::{DataFusionError, Result}; +use sqlparser::ast; +use std::cmp::Ordering; +use std::convert::{From, TryFrom}; +use std::fmt; +use std::hash::{Hash, Hasher}; + +/// The frame-spec determines which output rows are read by an aggregate window function. +/// +/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the +/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to +/// CURRENT ROW. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +pub struct WindowFrame { + /// A frame type - either ROWS, RANGE or GROUPS + pub units: WindowFrameUnits, + /// A starting frame boundary + pub start_bound: WindowFrameBound, + /// An ending frame boundary + pub end_bound: WindowFrameBound, +} + +impl fmt::Display for WindowFrame { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} BETWEEN {} AND {}", + self.units, self.start_bound, self.end_bound + )?; + Ok(()) + } +} + +impl TryFrom for WindowFrame { + type Error = DataFusionError; + + fn try_from(value: ast::WindowFrame) -> Result { + let start_bound = value.start_bound.into(); + let end_bound = value + .end_bound + .map(WindowFrameBound::from) + .unwrap_or(WindowFrameBound::CurrentRow); + + if let WindowFrameBound::Following(None) = start_bound { + Err(DataFusionError::Execution( + "Invalid window frame: start bound cannot be unbounded following" + .to_owned(), + )) + } else if let WindowFrameBound::Preceding(None) = end_bound { + Err(DataFusionError::Execution( + "Invalid window frame: end bound cannot be unbounded preceding" + .to_owned(), + )) + } else if start_bound > end_bound { + Err(DataFusionError::Execution(format!( + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + start_bound, end_bound + ))) + } else { + let units = value.units.into(); + if units == WindowFrameUnits::Range { + for bound in &[start_bound, end_bound] { + match bound { + WindowFrameBound::Preceding(Some(v)) + | WindowFrameBound::Following(Some(v)) + if *v > 0 => + { + Err(DataFusionError::NotImplemented(format!( + "With WindowFrameUnits={}, the bound cannot be {} PRECEDING or FOLLOWING at the moment", + units, v + ))) + } + _ => Ok(()), + }?; + } + } + Ok(Self { + units, + start_bound, + end_bound, + }) + } + } +} + +impl Default for WindowFrame { + fn default() -> Self { + WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(None), + end_bound: WindowFrameBound::CurrentRow, + } + } +} + +/// There are five ways to describe starting and ending frame boundaries: +/// +/// 1. UNBOUNDED PRECEDING +/// 2. PRECEDING +/// 3. CURRENT ROW +/// 4. FOLLOWING +/// 5. UNBOUNDED FOLLOWING +/// +/// in this implementation we'll only allow to be u64 (i.e. no dynamic boundary) +#[derive(Debug, Clone, Copy, Eq)] +pub enum WindowFrameBound { + /// 1. UNBOUNDED PRECEDING + /// The frame boundary is the first row in the partition. + /// + /// 2. PRECEDING + /// must be a non-negative constant numeric expression. The boundary is a row that + /// is "units" prior to the current row. + Preceding(Option), + /// 3. The current row. + /// + /// For RANGE and GROUPS frame types, peers of the current row are also + /// included in the frame, unless specifically excluded by the EXCLUDE clause. + /// This is true regardless of whether CURRENT ROW is used as the starting or ending frame + /// boundary. + CurrentRow, + /// 4. 
This is the same as " PRECEDING" except that the boundary is units after the + /// current rather than before the current row. + /// + /// 5. UNBOUNDED FOLLOWING + /// The frame boundary is the last row in the partition. + Following(Option), +} + +impl From for WindowFrameBound { + fn from(value: ast::WindowFrameBound) -> Self { + match value { + ast::WindowFrameBound::Preceding(v) => Self::Preceding(v), + ast::WindowFrameBound::Following(v) => Self::Following(v), + ast::WindowFrameBound::CurrentRow => Self::CurrentRow, + } + } +} + +impl fmt::Display for WindowFrameBound { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), + WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"), + WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"), + WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n), + WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n), + } + } +} + +impl PartialEq for WindowFrameBound { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl PartialOrd for WindowFrameBound { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for WindowFrameBound { + fn cmp(&self, other: &Self) -> Ordering { + self.get_rank().cmp(&other.get_rank()) + } +} + +impl Hash for WindowFrameBound { + fn hash(&self, state: &mut H) { + self.get_rank().hash(state) + } +} + +impl WindowFrameBound { + /// get the rank of this window frame bound. + /// + /// the rank is a tuple of (u8, u64) because we'll firstly compare the kind and then the value + /// which requires special handling e.g. with preceding the larger the value the smaller the + /// rank and also for 0 preceding / following it is the same as current row + fn get_rank(&self) -> (u8, u64) { + match self { + WindowFrameBound::Preceding(None) => (0, 0), + WindowFrameBound::Following(None) => (4, 0), + WindowFrameBound::Preceding(Some(0)) + | WindowFrameBound::CurrentRow + | WindowFrameBound::Following(Some(0)) => (2, 0), + WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), + WindowFrameBound::Following(Some(v)) => (3, *v), + } + } +} + +/// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the +/// starting and ending boundaries of the frame are measured. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +pub enum WindowFrameUnits { + /// The ROWS frame type means that the starting and ending boundaries for the frame are + /// determined by counting individual rows relative to the current row. + Rows, + /// The RANGE frame type requires that the ORDER BY clause of the window have exactly one + /// term. Call that term "X". With the RANGE frame type, the elements of the frame are + /// determined by computing the value of expression X for all rows in the partition and framing + /// those rows for which the value of X is within a certain range of the value of X for the + /// current row. + Range, + /// The GROUPS frame type means that the starting and ending boundaries are determine + /// by counting "groups" relative to the current group. A "group" is a set of rows that all have + /// equivalent values for all all terms of the window ORDER BY clause. 
+ Groups, +} + +impl fmt::Display for WindowFrameUnits { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match self { + WindowFrameUnits::Rows => "ROWS", + WindowFrameUnits::Range => "RANGE", + WindowFrameUnits::Groups => "GROUPS", + }) + } +} + +impl From for WindowFrameUnits { + fn from(value: ast::WindowFrameUnits) -> Self { + match value { + ast::WindowFrameUnits::Range => Self::Range, + ast::WindowFrameUnits::Groups => Self::Groups, + ast::WindowFrameUnits::Rows => Self::Rows, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_window_frame_creation() -> Result<()> { + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Following(None), + end_bound: None, + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound cannot be unbounded following" + .to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Preceding(None), + end_bound: Some(ast::WindowFrameBound::Preceding(None)), + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: end bound cannot be unbounded preceding" + .to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Preceding(Some(1)), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))), + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Range, + start_bound: ast::WindowFrameBound::Preceding(Some(2)), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), + }; + let result = WindowFrame::try_from(window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "This feature is not implemented: With WindowFrameUnits=RANGE, the bound cannot be 2 PRECEDING or FOLLOWING at the moment".to_owned() + ); + + let window_frame = ast::WindowFrame { + units: ast::WindowFrameUnits::Rows, + start_bound: ast::WindowFrameBound::Preceding(Some(2)), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), + }; + let result = WindowFrame::try_from(window_frame); + assert!(result.is_ok()); + Ok(()) + } + + #[test] + fn test_eq() { + assert_eq!( + WindowFrameBound::Preceding(Some(0)), + WindowFrameBound::CurrentRow + ); + assert_eq!( + WindowFrameBound::CurrentRow, + WindowFrameBound::Following(Some(0)) + ); + assert_eq!( + WindowFrameBound::Following(Some(2)), + WindowFrameBound::Following(Some(2)) + ); + assert_eq!( + WindowFrameBound::Following(None), + WindowFrameBound::Following(None) + ); + assert_eq!( + WindowFrameBound::Preceding(Some(2)), + WindowFrameBound::Preceding(Some(2)) + ); + assert_eq!( + WindowFrameBound::Preceding(None), + WindowFrameBound::Preceding(None) + ); + } + + #[test] + fn test_ord() { + assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow); + // ! yes this is correct! 
+ assert!( + WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1)) + ); + assert!( + WindowFrameBound::Preceding(Some(u64::MAX)) + < WindowFrameBound::Preceding(Some(u64::MAX - 1)) + ); + assert!( + WindowFrameBound::Preceding(None) + < WindowFrameBound::Preceding(Some(1000000)) + ); + assert!( + WindowFrameBound::Preceding(None) + < WindowFrameBound::Preceding(Some(u64::MAX)) + ); + assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0))); + assert!( + WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1)) + ); + assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1))); + assert!( + WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2)) + ); + assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None)); + assert!( + WindowFrameBound::Following(Some(u64::MAX)) + < WindowFrameBound::Following(None) + ); + } +} diff --git a/datafusion/src/logical_plan/operators.rs b/datafusion/src/logical_plan/operators.rs index 14ccab0537bd..813f7e0aac70 100644 --- a/datafusion/src/logical_plan/operators.rs +++ b/datafusion/src/logical_plan/operators.rs @@ -15,88 +15,9 @@ // specific language governing permissions and limitations // under the License. -use std::{fmt, ops}; - use super::{binary_expr, Expr}; - -/// Operators applied to expressions -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] -pub enum Operator { - /// Expressions are equal - Eq, - /// Expressions are not equal - NotEq, - /// Left side is smaller than right side - Lt, - /// Left side is smaller or equal to right side - LtEq, - /// Left side is greater than right side - Gt, - /// Left side is greater or equal to right side - GtEq, - /// Addition - Plus, - /// Subtraction - Minus, - /// Multiplication operator, like `*` - Multiply, - /// Division operator, like `/` - Divide, - /// Remainder operator, like `%` - Modulo, - /// Logical AND, like `&&` - And, - /// Logical OR, like `||` - Or, - /// Matches a wildcard pattern - Like, - /// Does not match a wildcard pattern - NotLike, - /// IS DISTINCT FROM - IsDistinctFrom, - /// IS NOT DISTINCT FROM - IsNotDistinctFrom, - /// Case sensitive regex match - RegexMatch, - /// Case insensitive regex match - RegexIMatch, - /// Case sensitive regex not match - RegexNotMatch, - /// Case insensitive regex not match - RegexNotIMatch, - /// Bitwise and, like `&` - BitwiseAnd, -} - -impl fmt::Display for Operator { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let display = match &self { - Operator::Eq => "=", - Operator::NotEq => "!=", - Operator::Lt => "<", - Operator::LtEq => "<=", - Operator::Gt => ">", - Operator::GtEq => ">=", - Operator::Plus => "+", - Operator::Minus => "-", - Operator::Multiply => "*", - Operator::Divide => "/", - Operator::Modulo => "%", - Operator::And => "AND", - Operator::Or => "OR", - Operator::Like => "LIKE", - Operator::NotLike => "NOT LIKE", - Operator::RegexMatch => "~", - Operator::RegexIMatch => "~*", - Operator::RegexNotMatch => "!~", - Operator::RegexNotIMatch => "!~*", - Operator::IsDistinctFrom => "IS DISTINCT FROM", - Operator::IsNotDistinctFrom => "IS NOT DISTINCT FROM", - Operator::BitwiseAnd => "&", - }; - write!(f, "{}", display) - } -} +pub use datafusion_expr::Operator; +use std::ops; impl ops::Add for Expr { type Output = Self; diff --git a/datafusion/src/logical_plan/window_frames.rs b/datafusion/src/logical_plan/window_frames.rs index 42e0a7e87c05..519582089db4 100644 --- 
a/datafusion/src/logical_plan/window_frames.rs +++ b/datafusion/src/logical_plan/window_frames.rs @@ -15,365 +15,6 @@ // specific language governing permissions and limitations // under the License. -//! Window frame -//! -//! The frame-spec determines which output rows are read by an aggregate window function. The frame-spec consists of four parts: -//! - A frame type - either ROWS, RANGE or GROUPS, -//! - A starting frame boundary, -//! - An ending frame boundary, -//! - An EXCLUDE clause. +//! Window frame types, reimported from datafusion_expr -use crate::error::{DataFusionError, Result}; -use sqlparser::ast; -use std::cmp::Ordering; -use std::convert::{From, TryFrom}; -use std::fmt; -use std::hash::{Hash, Hasher}; - -/// The frame-spec determines which output rows are read by an aggregate window function. -/// -/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the -/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to -/// CURRENT ROW. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] -pub struct WindowFrame { - /// A frame type - either ROWS, RANGE or GROUPS - pub units: WindowFrameUnits, - /// A starting frame boundary - pub start_bound: WindowFrameBound, - /// An ending frame boundary - pub end_bound: WindowFrameBound, -} - -impl fmt::Display for WindowFrame { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{} BETWEEN {} AND {}", - self.units, self.start_bound, self.end_bound - )?; - Ok(()) - } -} - -impl TryFrom for WindowFrame { - type Error = DataFusionError; - - fn try_from(value: ast::WindowFrame) -> Result { - let start_bound = value.start_bound.into(); - let end_bound = value - .end_bound - .map(WindowFrameBound::from) - .unwrap_or(WindowFrameBound::CurrentRow); - - if let WindowFrameBound::Following(None) = start_bound { - Err(DataFusionError::Execution( - "Invalid window frame: start bound cannot be unbounded following" - .to_owned(), - )) - } else if let WindowFrameBound::Preceding(None) = end_bound { - Err(DataFusionError::Execution( - "Invalid window frame: end bound cannot be unbounded preceding" - .to_owned(), - )) - } else if start_bound > end_bound { - Err(DataFusionError::Execution(format!( - "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", - start_bound, end_bound - ))) - } else { - let units = value.units.into(); - if units == WindowFrameUnits::Range { - for bound in &[start_bound, end_bound] { - match bound { - WindowFrameBound::Preceding(Some(v)) - | WindowFrameBound::Following(Some(v)) - if *v > 0 => - { - Err(DataFusionError::NotImplemented(format!( - "With WindowFrameUnits={}, the bound cannot be {} PRECEDING or FOLLOWING at the moment", - units, v - ))) - } - _ => Ok(()), - }?; - } - } - Ok(Self { - units, - start_bound, - end_bound, - }) - } - } -} - -impl Default for WindowFrame { - fn default() -> Self { - WindowFrame { - units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(None), - end_bound: WindowFrameBound::CurrentRow, - } - } -} - -/// There are five ways to describe starting and ending frame boundaries: -/// -/// 1. UNBOUNDED PRECEDING -/// 2. PRECEDING -/// 3. CURRENT ROW -/// 4. FOLLOWING -/// 5. UNBOUNDED FOLLOWING -/// -/// in this implementation we'll only allow to be u64 (i.e. no dynamic boundary) -#[derive(Debug, Clone, Copy, Eq)] -pub enum WindowFrameBound { - /// 1. UNBOUNDED PRECEDING - /// The frame boundary is the first row in the partition. - /// - /// 2. 
PRECEDING - /// must be a non-negative constant numeric expression. The boundary is a row that - /// is "units" prior to the current row. - Preceding(Option), - /// 3. The current row. - /// - /// For RANGE and GROUPS frame types, peers of the current row are also - /// included in the frame, unless specifically excluded by the EXCLUDE clause. - /// This is true regardless of whether CURRENT ROW is used as the starting or ending frame - /// boundary. - CurrentRow, - /// 4. This is the same as " PRECEDING" except that the boundary is units after the - /// current rather than before the current row. - /// - /// 5. UNBOUNDED FOLLOWING - /// The frame boundary is the last row in the partition. - Following(Option), -} - -impl From for WindowFrameBound { - fn from(value: ast::WindowFrameBound) -> Self { - match value { - ast::WindowFrameBound::Preceding(v) => Self::Preceding(v), - ast::WindowFrameBound::Following(v) => Self::Following(v), - ast::WindowFrameBound::CurrentRow => Self::CurrentRow, - } - } -} - -impl fmt::Display for WindowFrameBound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), - WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"), - WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"), - WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n), - WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n), - } - } -} - -impl PartialEq for WindowFrameBound { - fn eq(&self, other: &Self) -> bool { - self.cmp(other) == Ordering::Equal - } -} - -impl PartialOrd for WindowFrameBound { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for WindowFrameBound { - fn cmp(&self, other: &Self) -> Ordering { - self.get_rank().cmp(&other.get_rank()) - } -} - -impl Hash for WindowFrameBound { - fn hash(&self, state: &mut H) { - self.get_rank().hash(state) - } -} - -impl WindowFrameBound { - /// get the rank of this window frame bound. - /// - /// the rank is a tuple of (u8, u64) because we'll firstly compare the kind and then the value - /// which requires special handling e.g. with preceding the larger the value the smaller the - /// rank and also for 0 preceding / following it is the same as current row - fn get_rank(&self) -> (u8, u64) { - match self { - WindowFrameBound::Preceding(None) => (0, 0), - WindowFrameBound::Following(None) => (4, 0), - WindowFrameBound::Preceding(Some(0)) - | WindowFrameBound::CurrentRow - | WindowFrameBound::Following(Some(0)) => (2, 0), - WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), - WindowFrameBound::Following(Some(v)) => (3, *v), - } - } -} - -/// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the -/// starting and ending boundaries of the frame are measured. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] -pub enum WindowFrameUnits { - /// The ROWS frame type means that the starting and ending boundaries for the frame are - /// determined by counting individual rows relative to the current row. - Rows, - /// The RANGE frame type requires that the ORDER BY clause of the window have exactly one - /// term. Call that term "X". With the RANGE frame type, the elements of the frame are - /// determined by computing the value of expression X for all rows in the partition and framing - /// those rows for which the value of X is within a certain range of the value of X for the - /// current row. 
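For readers skimming this removal: the `Default` and `Display` impls being moved to `datafusion_expr` encode the usual SQL default frame. A tiny sketch of what they yield at runtime, assuming the (approximate) re-export path that remains in the core crate after this patch; behaviour is unchanged by the move:

```rust
// Path is approximate: after this patch the type is re-exported from datafusion_expr.
use datafusion::logical_plan::window_frames::WindowFrame;

fn main() {
    // The default frame used when a window expression does not spell one out.
    let frame = WindowFrame::default();
    assert_eq!(
        frame.to_string(),
        "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
    );
}
```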
- Range, - /// The GROUPS frame type means that the starting and ending boundaries are determine - /// by counting "groups" relative to the current group. A "group" is a set of rows that all have - /// equivalent values for all all terms of the window ORDER BY clause. - Groups, -} - -impl fmt::Display for WindowFrameUnits { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(match self { - WindowFrameUnits::Rows => "ROWS", - WindowFrameUnits::Range => "RANGE", - WindowFrameUnits::Groups => "GROUPS", - }) - } -} - -impl From for WindowFrameUnits { - fn from(value: ast::WindowFrameUnits) -> Self { - match value { - ast::WindowFrameUnits::Range => Self::Range, - ast::WindowFrameUnits::Groups => Self::Groups, - ast::WindowFrameUnits::Rows => Self::Rows, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_window_frame_creation() -> Result<()> { - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Following(None), - end_bound: None, - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(None), - end_bound: Some(ast::WindowFrameBound::Preceding(None)), - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(Some(1)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))), - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(Some(2)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "This feature is not implemented: With WindowFrameUnits=RANGE, the bound cannot be 2 PRECEDING or FOLLOWING at the moment".to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Rows, - start_bound: ast::WindowFrameBound::Preceding(Some(2)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), - }; - let result = WindowFrame::try_from(window_frame); - assert!(result.is_ok()); - Ok(()) - } - - #[test] - fn test_eq() { - assert_eq!( - WindowFrameBound::Preceding(Some(0)), - WindowFrameBound::CurrentRow - ); - assert_eq!( - WindowFrameBound::CurrentRow, - WindowFrameBound::Following(Some(0)) - ); - assert_eq!( - WindowFrameBound::Following(Some(2)), - WindowFrameBound::Following(Some(2)) - ); - assert_eq!( - WindowFrameBound::Following(None), - WindowFrameBound::Following(None) - ); - assert_eq!( - WindowFrameBound::Preceding(Some(2)), - WindowFrameBound::Preceding(Some(2)) - ); - assert_eq!( - WindowFrameBound::Preceding(None), - WindowFrameBound::Preceding(None) - ); - } - - #[test] - fn test_ord() { - assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow); - 
// ! yes this is correct! - assert!( - WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1)) - ); - assert!( - WindowFrameBound::Preceding(Some(u64::MAX)) - < WindowFrameBound::Preceding(Some(u64::MAX - 1)) - ); - assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(1000000)) - ); - assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(u64::MAX)) - ); - assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0))); - assert!( - WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1)) - ); - assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1))); - assert!( - WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2)) - ); - assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None)); - assert!( - WindowFrameBound::Following(Some(u64::MAX)) - < WindowFrameBound::Following(None) - ); - } -} +pub use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; From 3c39c72c4e2c801fd8b6bbea0536392505168b34 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Tue, 8 Feb 2022 19:36:30 +0800 Subject: [PATCH 42/50] move signature, type signature, and volatility to split module (#1763) --- datafusion-expr/src/lib.rs | 2 + datafusion-expr/src/signature.rs | 116 ++++++++++++++++++++++ datafusion/src/physical_plan/functions.rs | 98 +----------------- 3 files changed, 119 insertions(+), 97 deletions(-) create mode 100644 datafusion-expr/src/signature.rs diff --git a/datafusion-expr/src/lib.rs b/datafusion-expr/src/lib.rs index 13fa93ed6a2e..d2b10b404a8b 100644 --- a/datafusion-expr/src/lib.rs +++ b/datafusion-expr/src/lib.rs @@ -17,10 +17,12 @@ mod aggregate_function; mod operator; +mod signature; mod window_frame; mod window_function; pub use aggregate_function::AggregateFunction; pub use operator::Operator; +pub use signature::{Signature, TypeSignature, Volatility}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion-expr/src/signature.rs b/datafusion-expr/src/signature.rs new file mode 100644 index 000000000000..5c27f422c105 --- /dev/null +++ b/datafusion-expr/src/signature.rs @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; + +///A function's volatility, which defines the functions eligibility for certain optimizations +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +pub enum Volatility { + /// Immutable - An immutable function will always return the same output when given the same input. An example of this is [BuiltinScalarFunction::Cos]. 
+ Immutable, + /// Stable - A stable function may return different values given the same input accross different queries but must return the same value for a given input within a query. An example of this is [BuiltinScalarFunction::Now]. + Stable, + /// Volatile - A volatile function may change the return value from evaluation to evaluation. Mutiple invocations of a volatile function may return different results when used in the same query. An example of this is [BuiltinScalarFunction::Random]. + Volatile, +} + +/// A function's type signature, which defines the function's supported argument types. +#[derive(Debug, Clone, PartialEq, Hash)] +pub enum TypeSignature { + /// arbitrary number of arguments of an common type out of a list of valid types + // A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` + Variadic(Vec), + /// arbitrary number of arguments of an arbitrary but equal type + // A function such as `array` is `VariadicEqual` + // The first argument decides the type used for coercion + VariadicEqual, + /// fixed number of arguments of an arbitrary but equal type out of a list of valid types + // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` + // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` + Uniform(usize, Vec), + /// exact number of arguments of an exact type + Exact(Vec), + /// fixed number of arguments of arbitrary types + Any(usize), + /// One of a list of signatures + OneOf(Vec), +} + +///The Signature of a function defines its supported input types as well as its volatility. +#[derive(Debug, Clone, PartialEq, Hash)] +pub struct Signature { + /// type_signature - The types that the function accepts. See [TypeSignature] for more information. + pub type_signature: TypeSignature, + /// volatility - The volatility of the function. See [Volatility] for more information. + pub volatility: Volatility, +} + +impl Signature { + /// new - Creates a new Signature from any type signature and the volatility. + pub fn new(type_signature: TypeSignature, volatility: Volatility) -> Self { + Signature { + type_signature, + volatility, + } + } + /// variadic - Creates a variadic signature that represents an arbitrary number of arguments all from a type in common_types. + pub fn variadic(common_types: Vec, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::Variadic(common_types), + volatility, + } + } + /// variadic_equal - Creates a variadic signature that represents an arbitrary number of arguments of the same type. + pub fn variadic_equal(volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::VariadicEqual, + volatility, + } + } + /// uniform - Creates a function with a fixed number of arguments of the same type, which must be from valid_types. + pub fn uniform( + arg_count: usize, + valid_types: Vec, + volatility: Volatility, + ) -> Self { + Self { + type_signature: TypeSignature::Uniform(arg_count, valid_types), + volatility, + } + } + /// exact - Creates a signture which must match the types in exact_types in order. 
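As a quick illustration of the constructors in this new module (the remaining ones, `exact`, `any` and `one_of`, follow just below), here is how a few hypothetical scalar functions might declare their signatures; the function names are made up for the example and only the imports shown in this patch are assumed:

```rust
use arrow::datatypes::DataType;
use datafusion_expr::{Signature, Volatility};

fn main() {
    // a cos-like function: exactly one f64 or f32 argument; same input always
    // yields the same output, so the optimizer may constant-fold it
    let cos_like = Signature::uniform(
        1,
        vec![DataType::Float64, DataType::Float32],
        Volatility::Immutable,
    );

    // a concat-like function: any number of string arguments of a common type
    let concat_like = Signature::variadic(
        vec![DataType::Utf8, DataType::LargeUtf8],
        Volatility::Immutable,
    );

    // a now-like function: no arguments, stable within a single query
    let now_like = Signature::exact(vec![], Volatility::Stable);

    // a random-like function: no arguments, may differ on every evaluation,
    // so it must never be constant-folded
    let random_like = Signature::exact(vec![], Volatility::Volatile);

    println!("{:?}", (cos_like, concat_like, now_like, random_like));
}
```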
+ pub fn exact(exact_types: Vec, volatility: Volatility) -> Self { + Signature { + type_signature: TypeSignature::Exact(exact_types), + volatility, + } + } + /// any - Creates a signature which can a be made of any type but of a specified number + pub fn any(arg_count: usize, volatility: Volatility) -> Self { + Signature { + type_signature: TypeSignature::Any(arg_count), + volatility, + } + } + /// one_of Creates a signature which can match any of the [TypeSignature]s which are passed in. + pub fn one_of(type_signatures: Vec, volatility: Volatility) -> Self { + Signature { + type_signature: TypeSignature::OneOf(type_signatures), + volatility, + } + } +} diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 7d7cda75e867..af157c02fc38 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -56,103 +56,7 @@ use fmt::{Debug, Formatter}; use std::convert::From; use std::{any::Any, fmt, str::FromStr, sync::Arc}; -/// A function's type signature, which defines the function's supported argument types. -#[derive(Debug, Clone, PartialEq, Hash)] -pub enum TypeSignature { - /// arbitrary number of arguments of an common type out of a list of valid types - // A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` - Variadic(Vec), - /// arbitrary number of arguments of an arbitrary but equal type - // A function such as `array` is `VariadicEqual` - // The first argument decides the type used for coercion - VariadicEqual, - /// fixed number of arguments of an arbitrary but equal type out of a list of valid types - // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` - // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])` - Uniform(usize, Vec), - /// exact number of arguments of an exact type - Exact(Vec), - /// fixed number of arguments of arbitrary types - Any(usize), - /// One of a list of signatures - OneOf(Vec), -} - -///The Signature of a function defines its supported input types as well as its volatility. -#[derive(Debug, Clone, PartialEq, Hash)] -pub struct Signature { - /// type_signature - The types that the function accepts. See [TypeSignature] for more information. - pub type_signature: TypeSignature, - /// volatility - The volatility of the function. See [Volatility] for more information. - pub volatility: Volatility, -} - -impl Signature { - /// new - Creates a new Signature from any type signature and the volatility. - pub fn new(type_signature: TypeSignature, volatility: Volatility) -> Self { - Signature { - type_signature, - volatility, - } - } - /// variadic - Creates a variadic signature that represents an arbitrary number of arguments all from a type in common_types. - pub fn variadic(common_types: Vec, volatility: Volatility) -> Self { - Self { - type_signature: TypeSignature::Variadic(common_types), - volatility, - } - } - /// variadic_equal - Creates a variadic signature that represents an arbitrary number of arguments of the same type. - pub fn variadic_equal(volatility: Volatility) -> Self { - Self { - type_signature: TypeSignature::VariadicEqual, - volatility, - } - } - /// uniform - Creates a function with a fixed number of arguments of the same type, which must be from valid_types. 
- pub fn uniform( - arg_count: usize, - valid_types: Vec, - volatility: Volatility, - ) -> Self { - Self { - type_signature: TypeSignature::Uniform(arg_count, valid_types), - volatility, - } - } - /// exact - Creates a signture which must match the types in exact_types in order. - pub fn exact(exact_types: Vec, volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::Exact(exact_types), - volatility, - } - } - /// any - Creates a signature which can a be made of any type but of a specified number - pub fn any(arg_count: usize, volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::Any(arg_count), - volatility, - } - } - /// one_of Creates a signature which can match any of the [TypeSignature]s which are passed in. - pub fn one_of(type_signatures: Vec, volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::OneOf(type_signatures), - volatility, - } - } -} - -///A function's volatility, which defines the functions eligibility for certain optimizations -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum Volatility { - /// Immutable - An immutable function will always return the same output when given the same input. An example of this is [BuiltinScalarFunction::Cos]. - Immutable, - /// Stable - A stable function may return different values given the same input accross different queries but must return the same value for a given input within a query. An example of this is [BuiltinScalarFunction::Now]. - Stable, - /// Volatile - A volatile function may change the return value from evaluation to evaluation. Mutiple invocations of a volatile function may return different results when used in the same query. An example of this is [BuiltinScalarFunction::Random]. 
- Volatile, -} +pub use datafusion_expr::{Signature, TypeSignature, Volatility}; /// Scalar function /// From 86dcb0992fcfb3b88ff4d94e94cd998c1ef9786e Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Tue, 8 Feb 2022 20:48:25 +0800 Subject: [PATCH 43/50] [split/10] split up expr for rewriting, visiting, and simplification traits (#1774) * split up expr for rewriting, visiting, and simplification * add docs --- datafusion/src/datasource/listing/helpers.rs | 2 +- datafusion/src/logical_plan/expr.rs | 764 +----------------- datafusion/src/logical_plan/expr_rewriter.rs | 591 ++++++++++++++ datafusion/src/logical_plan/expr_simplier.rs | 97 +++ datafusion/src/logical_plan/expr_visitor.rs | 176 ++++ datafusion/src/logical_plan/mod.rs | 21 +- .../src/optimizer/common_subexpr_eliminate.rs | 4 +- .../src/optimizer/simplify_expressions.rs | 5 +- datafusion/src/optimizer/utils.rs | 6 +- datafusion/src/sql/utils.rs | 1 + datafusion/tests/simplification.rs | 1 + 11 files changed, 892 insertions(+), 776 deletions(-) create mode 100644 datafusion/src/logical_plan/expr_rewriter.rs create mode 100644 datafusion/src/logical_plan/expr_simplier.rs create mode 100644 datafusion/src/logical_plan/expr_visitor.rs diff --git a/datafusion/src/datasource/listing/helpers.rs b/datafusion/src/datasource/listing/helpers.rs index 912179c36f06..8ff821082906 100644 --- a/datafusion/src/datasource/listing/helpers.rs +++ b/datafusion/src/datasource/listing/helpers.rs @@ -37,7 +37,7 @@ use log::debug; use crate::{ error::Result, execution::context::ExecutionContext, - logical_plan::{self, Expr, ExpressionVisitor, Recursion}, + logical_plan::{self, Expr, ExprVisitable, ExpressionVisitor, Recursion}, physical_plan::functions::Volatility, scalar::ScalarValue, }; diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 4b539a814551..69da346aee8d 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -20,12 +20,8 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; -use crate::execution::context::ExecutionProps; use crate::field_util::get_indexed_field; -use crate::logical_plan::{ - plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan, -}; -use crate::optimizer::simplify_expressions::{ConstEvaluator, Simplifier}; +use crate::logical_plan::{window_frames, DFField, DFSchema}; use crate::physical_plan::functions::Volatility; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, @@ -36,7 +32,7 @@ use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::{compute::can_cast_types, datatypes::DataType}; pub use datafusion_common::{Column, ExprSchema}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::fmt; use std::hash::{BuildHasher, Hash, Hasher}; use std::ops::Not; @@ -557,348 +553,6 @@ impl Expr { nulls_first, } } - - /// Performs a depth first walk of an expression and - /// its children, calling [`ExpressionVisitor::pre_visit`] and - /// `visitor.post_visit`. - /// - /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to - /// separate expression algorithms from the structure of the - /// `Expr` tree and make it easier to add new types of expressions - /// and algorithms that walk the tree. 
- /// - /// For an expression tree such as - /// ```text - /// BinaryExpr (GT) - /// left: Column("foo") - /// right: Column("bar") - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(BinaryExpr(GT)) - /// pre_visit(Column("foo")) - /// pre_visit(Column("bar")) - /// post_visit(Column("bar")) - /// post_visit(Column("bar")) - /// post_visit(BinaryExpr(GT)) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If `Recursion::Stop` is returned on a call to pre_visit, no - /// children of that expression are visited, nor is post_visit - /// called on that expression - /// - pub fn accept(&self, visitor: V) -> Result { - let visitor = match visitor.pre_visit(self)? { - Recursion::Continue(visitor) => visitor, - // If the recursion should stop, do not visit children - Recursion::Stop(visitor) => return Ok(visitor), - }; - - // recurse (and cover all expression types) - let visitor = match self { - Expr::Alias(expr, _) - | Expr::Not(expr) - | Expr::IsNotNull(expr) - | Expr::IsNull(expr) - | Expr::Negative(expr) - | Expr::Cast { expr, .. } - | Expr::TryCast { expr, .. } - | Expr::Sort { expr, .. } - | Expr::GetIndexedField { expr, .. } => expr.accept(visitor), - Expr::Column(_) - | Expr::ScalarVariable(_) - | Expr::Literal(_) - | Expr::Wildcard => Ok(visitor), - Expr::BinaryExpr { left, right, .. } => { - let visitor = left.accept(visitor)?; - right.accept(visitor) - } - Expr::Between { - expr, low, high, .. - } => { - let visitor = expr.accept(visitor)?; - let visitor = low.accept(visitor)?; - high.accept(visitor) - } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let visitor = if let Some(expr) = expr.as_ref() { - expr.accept(visitor) - } else { - Ok(visitor) - }?; - let visitor = when_then_expr.iter().try_fold( - visitor, - |visitor, (when, then)| { - let visitor = when.accept(visitor)?; - then.accept(visitor) - }, - )?; - if let Some(else_expr) = else_expr.as_ref() { - else_expr.accept(visitor) - } else { - Ok(visitor) - } - } - Expr::ScalarFunction { args, .. } - | Expr::ScalarUDF { args, .. } - | Expr::AggregateFunction { args, .. } - | Expr::AggregateUDF { args, .. } => args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::WindowFunction { - args, - partition_by, - order_by, - .. - } => { - let visitor = args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - let visitor = partition_by - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - let visitor = order_by - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; - Ok(visitor) - } - Expr::InList { expr, list, .. } => { - let visitor = expr.accept(visitor)?; - list.iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)) - } - }?; - - visitor.post_visit(self) - } - - /// Performs a depth first walk of an expression and its children - /// to rewrite an expression, consuming `self` producing a new - /// [`Expr`]. - /// - /// Implements a modified version of the [visitor - /// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to - /// separate algorithms from the structure of the `Expr` tree and - /// make it easier to write new, efficient expression - /// transformation algorithms. 
- /// - /// For an expression tree such as - /// ```text - /// BinaryExpr (GT) - /// left: Column("foo") - /// right: Column("bar") - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(BinaryExpr(GT)) - /// pre_visit(Column("foo")) - /// mutatate(Column("foo")) - /// pre_visit(Column("bar")) - /// mutate(Column("bar")) - /// mutate(BinaryExpr(GT)) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`false`] is returned on a call to pre_visit, no - /// children of that expression are visited, nor is mutate - /// called on that expression - /// - pub fn rewrite(self, rewriter: &mut R) -> Result - where - R: ExprRewriter, - { - let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, - }; - - // recurse into all sub expressions(and cover all expression types) - let expr = match self { - Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), - Expr::Column(_) => self.clone(), - Expr::ScalarVariable(names) => Expr::ScalarVariable(names), - Expr::Literal(value) => Expr::Literal(value), - Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr { - left: rewrite_boxed(left, rewriter)?, - op, - right: rewrite_boxed(right, rewriter)?, - }, - Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?), - Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?), - Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?), - Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?), - Expr::Between { - expr, - low, - high, - negated, - } => Expr::Between { - expr: rewrite_boxed(expr, rewriter)?, - low: rewrite_boxed(low, rewriter)?, - high: rewrite_boxed(high, rewriter)?, - negated, - }, - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let expr = rewrite_option_box(expr, rewriter)?; - let when_then_expr = when_then_expr - .into_iter() - .map(|(when, then)| { - Ok(( - rewrite_boxed(when, rewriter)?, - rewrite_boxed(then, rewriter)?, - )) - }) - .collect::>>()?; - - let else_expr = rewrite_option_box(else_expr, rewriter)?; - - Expr::Case { - expr, - when_then_expr, - else_expr, - } - } - Expr::Cast { expr, data_type } => Expr::Cast { - expr: rewrite_boxed(expr, rewriter)?, - data_type, - }, - Expr::TryCast { expr, data_type } => Expr::TryCast { - expr: rewrite_boxed(expr, rewriter)?, - data_type, - }, - Expr::Sort { - expr, - asc, - nulls_first, - } => Expr::Sort { - expr: rewrite_boxed(expr, rewriter)?, - asc, - nulls_first, - }, - Expr::ScalarFunction { args, fun } => Expr::ScalarFunction { - args: rewrite_vec(args, rewriter)?, - fun, - }, - Expr::ScalarUDF { args, fun } => Expr::ScalarUDF { - args: rewrite_vec(args, rewriter)?, - fun, - }, - Expr::WindowFunction { - args, - fun, - partition_by, - order_by, - window_frame, - } => Expr::WindowFunction { - args: rewrite_vec(args, rewriter)?, - fun, - partition_by: rewrite_vec(partition_by, rewriter)?, - order_by: rewrite_vec(order_by, rewriter)?, - window_frame, - }, - Expr::AggregateFunction { - args, - fun, - distinct, - } => Expr::AggregateFunction { - args: rewrite_vec(args, rewriter)?, - fun, - distinct, - }, - Expr::AggregateUDF { args, fun } => Expr::AggregateUDF { - args: rewrite_vec(args, rewriter)?, - fun, - }, - Expr::InList { - expr, - list, - negated, - } => Expr::InList { - expr: 
rewrite_boxed(expr, rewriter)?, - list: rewrite_vec(list, rewriter)?, - negated, - }, - Expr::Wildcard => Expr::Wildcard, - Expr::GetIndexedField { expr, key } => Expr::GetIndexedField { - expr: rewrite_boxed(expr, rewriter)?, - key, - }, - }; - - // now rewrite this expression itself - if need_mutate { - rewriter.mutate(expr) - } else { - Ok(expr) - } - } - - /// Simplifies this [`Expr`]`s as much as possible, evaluating - /// constants and applying algebraic simplifications - /// - /// # Example: - /// `b > 2 AND b > 2` - /// can be written to - /// `b > 2` - /// - /// ``` - /// use datafusion::logical_plan::*; - /// use datafusion::error::Result; - /// use datafusion::execution::context::ExecutionProps; - /// - /// /// Simple implementation that provides `Simplifier` the information it needs - /// #[derive(Default)] - /// struct Info { - /// execution_props: ExecutionProps, - /// }; - /// - /// impl SimplifyInfo for Info { - /// fn is_boolean_type(&self, expr: &Expr) -> Result { - /// Ok(false) - /// } - /// fn nullable(&self, expr: &Expr) -> Result { - /// Ok(true) - /// } - /// fn execution_props(&self) -> &ExecutionProps { - /// &self.execution_props - /// } - /// } - /// - /// // b < 2 - /// let b_lt_2 = col("b").gt(lit(2)); - /// - /// // (b < 2) OR (b < 2) - /// let expr = b_lt_2.clone().or(b_lt_2.clone()); - /// - /// // (b < 2) OR (b < 2) --> (b < 2) - /// let expr = expr.simplify(&Info::default()).unwrap(); - /// assert_eq!(expr, b_lt_2); - /// ``` - pub fn simplify(self, info: &S) -> Result { - let mut rewriter = Simplifier::new(info); - let mut const_evaluator = ConstEvaluator::new(info.execution_props()); - - // TODO iterate until no changes are made during rewrite - // (evaluating constants can enable new simplifications and - // simplifications can enable new constant evaluation) - // https://github.com/apache/arrow-datafusion/issues/1160 - self.rewrite(&mut const_evaluator)?.rewrite(&mut rewriter) - } } impl Not for Expr { @@ -936,103 +590,6 @@ impl std::fmt::Display for Expr { } } -#[allow(clippy::boxed_local)] -fn rewrite_boxed(boxed_expr: Box, rewriter: &mut R) -> Result> -where - R: ExprRewriter, -{ - // TODO: It might be possible to avoid an allocation (the - // Box::new) below by reusing the box. - let expr: Expr = *boxed_expr; - let rewritten_expr = expr.rewrite(rewriter)?; - Ok(Box::new(rewritten_expr)) -} - -fn rewrite_option_box( - option_box: Option>, - rewriter: &mut R, -) -> Result>> -where - R: ExprRewriter, -{ - option_box - .map(|expr| rewrite_boxed(expr, rewriter)) - .transpose() -} - -/// rewrite a `Vec` of `Expr`s with the rewriter -fn rewrite_vec(v: Vec, rewriter: &mut R) -> Result> -where - R: ExprRewriter, -{ - v.into_iter().map(|expr| expr.rewrite(rewriter)).collect() -} - -/// Controls how the visitor recursion should proceed. -pub enum Recursion { - /// Attempt to visit all the children, recursively, of this expression. - Continue(V), - /// Do not visit the children of this expression, though the walk - /// of parents of this expression will not be affected - Stop(V), -} - -/// Encode the traversal of an expression tree. When passed to -/// `Expr::accept`, `ExpressionVisitor::visit` is invoked -/// recursively on all nodes of an expression tree. See the comments -/// on `Expr::accept` for details on its use -pub trait ExpressionVisitor: Sized { - /// Invoked before any children of `expr` are visisted. - fn pre_visit(self, expr: &Expr) -> Result>; - - /// Invoked after all children of `expr` are visited. Default - /// implementation does nothing. 
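The `ExpressionVisitor` trait shown here (this patch moves it, together with `Expr::accept`, into the new `expr_visitor.rs` behind an `ExprVisitable` trait) is easiest to grasp with a concrete visitor. A minimal sketch that counts column references; the import paths follow the post-split layout seen in the `helpers.rs` hunk above and are approximate:

```rust
use datafusion::error::Result;
use datafusion::logical_plan::{col, lit, Expr, ExprVisitable, ExpressionVisitor, Recursion};

/// Counts how many `Expr::Column` nodes appear in an expression tree.
#[derive(Default)]
struct ColumnCounter {
    count: usize,
}

impl ExpressionVisitor for ColumnCounter {
    fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
        if matches!(expr, Expr::Column(_)) {
            self.count += 1;
        }
        // keep walking into the children of this expression
        Ok(Recursion::Continue(self))
    }
    // post_visit keeps its default no-op implementation
}

fn main() -> Result<()> {
    // #state = Utf8("CO") AND #id > Int32(5) references two columns
    let expr = col("state").eq(lit("CO")).and(col("id").gt(lit(5)));
    let visitor = expr.accept(ColumnCounter::default())?;
    assert_eq!(visitor.count, 2);
    Ok(())
}
```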
- fn post_visit(self, _expr: &Expr) -> Result { - Ok(self) - } -} - -/// Controls how the [ExprRewriter] recursion should proceed. -pub enum RewriteRecursion { - /// Continue rewrite / visit this expression. - Continue, - /// Call [mutate()] immediately and return. - Mutate, - /// Do not rewrite / visit the children of this expression. - Stop, - /// Keep recursive but skip mutate on this expression - Skip, -} - -/// Trait for potentially recursively rewriting an [`Expr`] expression -/// tree. When passed to `Expr::rewrite`, `ExpressionVisitor::mutate` is -/// invoked recursively on all nodes of an expression tree. See the -/// comments on `Expr::rewrite` for details on its use -pub trait ExprRewriter: Sized { - /// Invoked before any children of `expr` are rewritten / - /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) - } - - /// Invoked after all children of `expr` have been mutated and - /// returns a potentially modified expr. - fn mutate(&mut self, expr: Expr) -> Result; -} - -/// The information necessary to apply algebraic simplification to an -/// [Expr]. See [SimplifyContext] for one implementation -pub trait SimplifyInfo { - /// returns true if this Expr has boolean type - fn is_boolean_type(&self, expr: &Expr) -> Result; - - /// returns true of this expr is nullable (could possibly be NULL) - fn nullable(&self, expr: &Expr) -> Result; - - /// Returns details needed for partial expression evaluation - fn execution_props(&self) -> &ExecutionProps; -} - /// Helper struct for building [Expr::Case] pub struct CaseBuilder { expr: Option>, @@ -1201,183 +758,6 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { } } -/// Recursively replace all Column expressions in a given expression tree with Column expressions -/// provided by the hash map argument. -pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - struct ColumnReplacer<'a> { - replace_map: &'a HashMap<&'a Column, &'a Column>, - } - - impl<'a> ExprRewriter for ColumnReplacer<'a> { - fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::Column(c) = &expr { - match self.replace_map.get(c) { - Some(new_c) => Ok(Expr::Column((*new_c).to_owned())), - None => Ok(expr), - } - } else { - Ok(expr) - } - } - } - - e.rewrite(&mut ColumnReplacer { replace_map }) -} - -/// Recursively call [`Column::normalize`] on all Column expressions -/// in the `expr` expression tree. -pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - normalize_col_with_schemas(expr, &plan.all_schemas(), &plan.using_columns()?) -} - -/// Recursively call [`Column::normalize`] on all Column expressions -/// in the `expr` expression tree. 
-fn normalize_col_with_schemas( - expr: Expr, - schemas: &[&Arc], - using_columns: &[HashSet], -) -> Result { - struct ColumnNormalizer<'a> { - schemas: &'a [&'a Arc], - using_columns: &'a [HashSet], - } - - impl<'a> ExprRewriter for ColumnNormalizer<'a> { - fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::Column(c) = expr { - Ok(Expr::Column(c.normalize_with_schemas( - self.schemas, - self.using_columns, - )?)) - } else { - Ok(expr) - } - } - } - - expr.rewrite(&mut ColumnNormalizer { - schemas, - using_columns, - }) -} - -/// Recursively normalize all Column expressions in a list of expression trees -pub fn normalize_cols( - exprs: impl IntoIterator>, - plan: &LogicalPlan, -) -> Result> { - exprs - .into_iter() - .map(|e| normalize_col(e.into(), plan)) - .collect() -} - -/// Rewrite sort on aggregate expressions to sort on the column of aggregate output -/// For example, `max(x)` is written to `col("MAX(x)")` -pub fn rewrite_sort_cols_by_aggs( - exprs: impl IntoIterator>, - plan: &LogicalPlan, -) -> Result> { - exprs - .into_iter() - .map(|e| { - let expr = e.into(); - match expr { - Expr::Sort { - expr, - asc, - nulls_first, - } => { - let sort = Expr::Sort { - expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?), - asc, - nulls_first, - }; - Ok(sort) - } - expr => Ok(expr), - } - }) - .collect() -} - -fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { - match plan { - LogicalPlan::Aggregate(Aggregate { - input, aggr_expr, .. - }) => { - struct Rewriter<'a> { - plan: &'a LogicalPlan, - input: &'a LogicalPlan, - aggr_expr: &'a Vec, - } - - impl<'a> ExprRewriter for Rewriter<'a> { - fn mutate(&mut self, expr: Expr) -> Result { - let normalized_expr = normalize_col(expr.clone(), self.plan); - if normalized_expr.is_err() { - // The expr is not based on Aggregate plan output. Skip it. - return Ok(expr); - } - let normalized_expr = normalized_expr.unwrap(); - if let Some(found_agg) = - self.aggr_expr.iter().find(|a| (**a) == normalized_expr) - { - let agg = normalize_col(found_agg.clone(), self.plan)?; - let col = Expr::Column( - agg.to_field(self.input.schema()) - .map(|f| f.qualified_column())?, - ); - Ok(col) - } else { - Ok(expr) - } - } - } - - expr.rewrite(&mut Rewriter { - plan, - input, - aggr_expr, - }) - } - LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]), - _ => Ok(expr), - } -} - -/// Recursively 'unnormalize' (remove all qualifiers) from an -/// expression tree. -/// -/// For example, if there were expressions like `foo.bar` this would -/// rewrite it to just `bar`. 
-pub fn unnormalize_col(expr: Expr) -> Expr { - struct RemoveQualifier {} - - impl ExprRewriter for RemoveQualifier { - fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::Column(col) = expr { - //let Column { relation: _, name } = col; - Ok(Expr::Column(Column { - relation: None, - name: col.name, - })) - } else { - Ok(expr) - } - } - } - - expr.rewrite(&mut RemoveQualifier {}) - .expect("Unnormalize is infallable") -} - -/// Recursively un-normalize all Column expressions in a list of expression trees -#[inline] -pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { - exprs.into_iter().map(unnormalize_col).collect() -} - /// Recursively un-alias an expressions #[inline] pub fn unalias(expr: Expr) -> Expr { @@ -2114,24 +1494,6 @@ mod tests { assert_eq!(expr, expected); } - #[test] - fn rewriter_visit() { - let mut rewriter = RecordingRewriter::default(); - col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); - - assert_eq!( - rewriter.v, - vec![ - "Previsited #state = Utf8(\"CO\")", - "Previsited #state", - "Mutated #state", - "Previsited Utf8(\"CO\")", - "Mutated Utf8(\"CO\")", - "Mutated #state = Utf8(\"CO\")" - ] - ) - } - #[test] fn filter_is_null_and_is_not_null() { let col_null = col("col1"); @@ -2143,128 +1505,6 @@ mod tests { ); } - #[derive(Default)] - struct RecordingRewriter { - v: Vec, - } - impl ExprRewriter for RecordingRewriter { - fn mutate(&mut self, expr: Expr) -> Result { - self.v.push(format!("Mutated {:?}", expr)); - Ok(expr) - } - - fn pre_visit(&mut self, expr: &Expr) -> Result { - self.v.push(format!("Previsited {:?}", expr)); - Ok(RewriteRecursion::Continue) - } - } - - #[test] - fn rewriter_rewrite() { - let mut rewriter = FooBarRewriter {}; - - // rewrites "foo" --> "bar" - let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap(); - assert_eq!(rewritten, col("state").eq(lit("bar"))); - - // doesn't wrewrite - let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap(); - assert_eq!(rewritten, col("state").eq(lit("baz"))); - } - - /// rewrites all "foo" string literals to "bar" - struct FooBarRewriter {} - impl ExprRewriter for FooBarRewriter { - fn mutate(&mut self, expr: Expr) -> Result { - match expr { - Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { - let utf8_val = if utf8_val == "foo" { - "bar".to_string() - } else { - utf8_val - }; - Ok(lit(utf8_val)) - } - // otherwise, return the expression unchanged - expr => Ok(expr), - } - } - } - - #[test] - fn normalize_cols() { - let expr = col("a") + col("b") + col("c"); - - // Schemas with some matching and some non matching cols - let schema_a = - DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")]) - .unwrap(); - let schema_c = - DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")]) - .unwrap(); - let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); - // non matching - let schema_f = - DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")]) - .unwrap(); - let schemas = vec![schema_c, schema_f, schema_b, schema_a] - .into_iter() - .map(Arc::new) - .collect::>(); - let schemas = schemas.iter().collect::>(); - - let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); - assert_eq!( - normalized_expr, - col("tableA.a") + col("tableB.b") + col("tableC.c") - ); - } - - #[test] - fn normalize_cols_priority() { - let expr = col("a") + col("b"); - // Schemas with multiple matches for column a, first takes priority - let schema_a = 
DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); - let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); - let schema_a2 = DFSchema::new(vec![make_field("tableA2", "a")]).unwrap(); - let schemas = vec![schema_a2, schema_b, schema_a] - .into_iter() - .map(Arc::new) - .collect::>(); - let schemas = schemas.iter().collect::>(); - - let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); - assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b")); - } - - #[test] - fn normalize_cols_non_exist() { - // test normalizing columns when the name doesn't exist - let expr = col("a") + col("b"); - let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); - let schemas = vec![schema_a].into_iter().map(Arc::new).collect::>(); - let schemas = schemas.iter().collect::>(); - - let error = normalize_col_with_schemas(expr, &schemas, &[]) - .unwrap_err() - .to_string(); - assert_eq!( - error, - "Error during planning: Column #b not found in provided schemas" - ); - } - - #[test] - fn unnormalize_cols() { - let expr = col("tableA.a") + col("tableB.b"); - let unnormalized_expr = unnormalize_col(expr); - assert_eq!(unnormalized_expr, col("a") + col("b")); - } - - fn make_field(relation: &str, column: &str) -> DFField { - DFField::new(Some(relation), column, DataType::Int8, false) - } - #[test] fn test_not() { assert_eq!(lit(1).not(), !lit(1)); diff --git a/datafusion/src/logical_plan/expr_rewriter.rs b/datafusion/src/logical_plan/expr_rewriter.rs new file mode 100644 index 000000000000..d452dcd4c426 --- /dev/null +++ b/datafusion/src/logical_plan/expr_rewriter.rs @@ -0,0 +1,591 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression rewriter + +use super::Expr; +use crate::logical_plan::plan::Aggregate; +use crate::logical_plan::DFSchema; +use crate::logical_plan::LogicalPlan; +use datafusion_common::Column; +use datafusion_common::Result; +use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Arc; + +/// Controls how the [ExprRewriter] recursion should proceed. +pub enum RewriteRecursion { + /// Continue rewrite / visit this expression. + Continue, + /// Call [mutate()] immediately and return. + Mutate, + /// Do not rewrite / visit the children of this expression. + Stop, + /// Keep recursive but skip mutate on this expression + Skip, +} + +/// Trait for potentially recursively rewriting an [`Expr`] expression +/// tree. When passed to `Expr::rewrite`, `ExpressionVisitor::mutate` is +/// invoked recursively on all nodes of an expression tree. See the +/// comments on `Expr::rewrite` for details on its use +pub trait ExprRewriter: Sized { + /// Invoked before any children of `expr` are rewritten / + /// visited. 
Default implementation returns `Ok(RewriteRecursion::Continue)` + fn pre_visit(&mut self, _expr: &E) -> Result { + Ok(RewriteRecursion::Continue) + } + + /// Invoked after all children of `expr` have been mutated and + /// returns a potentially modified expr. + fn mutate(&mut self, expr: E) -> Result; +} + +/// a trait for marking types that are rewritable by [ExprRewriter] +pub trait ExprRewritable: Sized { + /// rewrite the expression tree using the given [ExprRewriter] + fn rewrite>(self, rewriter: &mut R) -> Result; +} + +impl ExprRewritable for Expr { + /// Performs a depth first walk of an expression and its children + /// to rewrite an expression, consuming `self` producing a new + /// [`Expr`]. + /// + /// Implements a modified version of the [visitor + /// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to + /// separate algorithms from the structure of the `Expr` tree and + /// make it easier to write new, efficient expression + /// transformation algorithms. + /// + /// For an expression tree such as + /// ```text + /// BinaryExpr (GT) + /// left: Column("foo") + /// right: Column("bar") + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(BinaryExpr(GT)) + /// pre_visit(Column("foo")) + /// mutatate(Column("foo")) + /// pre_visit(Column("bar")) + /// mutate(Column("bar")) + /// mutate(BinaryExpr(GT)) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`false`] is returned on a call to pre_visit, no + /// children of that expression are visited, nor is mutate + /// called on that expression + /// + fn rewrite(self, rewriter: &mut R) -> Result + where + R: ExprRewriter, + { + let need_mutate = match rewriter.pre_visit(&self)? { + RewriteRecursion::Mutate => return rewriter.mutate(self), + RewriteRecursion::Stop => return Ok(self), + RewriteRecursion::Continue => true, + RewriteRecursion::Skip => false, + }; + + // recurse into all sub expressions(and cover all expression types) + let expr = match self { + Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), + Expr::Column(_) => self.clone(), + Expr::ScalarVariable(names) => Expr::ScalarVariable(names), + Expr::Literal(value) => Expr::Literal(value), + Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr { + left: rewrite_boxed(left, rewriter)?, + op, + right: rewrite_boxed(right, rewriter)?, + }, + Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?), + Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?), + Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?), + Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?), + Expr::Between { + expr, + low, + high, + negated, + } => Expr::Between { + expr: rewrite_boxed(expr, rewriter)?, + low: rewrite_boxed(low, rewriter)?, + high: rewrite_boxed(high, rewriter)?, + negated, + }, + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let expr = rewrite_option_box(expr, rewriter)?; + let when_then_expr = when_then_expr + .into_iter() + .map(|(when, then)| { + Ok(( + rewrite_boxed(when, rewriter)?, + rewrite_boxed(then, rewriter)?, + )) + }) + .collect::>>()?; + + let else_expr = rewrite_option_box(else_expr, rewriter)?; + + Expr::Case { + expr, + when_then_expr, + else_expr, + } + } + Expr::Cast { expr, data_type } => Expr::Cast { + expr: rewrite_boxed(expr, rewriter)?, + data_type, + }, + Expr::TryCast { expr, data_type } => Expr::TryCast { + expr: rewrite_boxed(expr, rewriter)?, + 
data_type, + }, + Expr::Sort { + expr, + asc, + nulls_first, + } => Expr::Sort { + expr: rewrite_boxed(expr, rewriter)?, + asc, + nulls_first, + }, + Expr::ScalarFunction { args, fun } => Expr::ScalarFunction { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::ScalarUDF { args, fun } => Expr::ScalarUDF { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::WindowFunction { + args, + fun, + partition_by, + order_by, + window_frame, + } => Expr::WindowFunction { + args: rewrite_vec(args, rewriter)?, + fun, + partition_by: rewrite_vec(partition_by, rewriter)?, + order_by: rewrite_vec(order_by, rewriter)?, + window_frame, + }, + Expr::AggregateFunction { + args, + fun, + distinct, + } => Expr::AggregateFunction { + args: rewrite_vec(args, rewriter)?, + fun, + distinct, + }, + Expr::AggregateUDF { args, fun } => Expr::AggregateUDF { + args: rewrite_vec(args, rewriter)?, + fun, + }, + Expr::InList { + expr, + list, + negated, + } => Expr::InList { + expr: rewrite_boxed(expr, rewriter)?, + list: rewrite_vec(list, rewriter)?, + negated, + }, + Expr::Wildcard => Expr::Wildcard, + Expr::GetIndexedField { expr, key } => Expr::GetIndexedField { + expr: rewrite_boxed(expr, rewriter)?, + key, + }, + }; + + // now rewrite this expression itself + if need_mutate { + rewriter.mutate(expr) + } else { + Ok(expr) + } + } +} + +#[allow(clippy::boxed_local)] +fn rewrite_boxed(boxed_expr: Box, rewriter: &mut R) -> Result> +where + R: ExprRewriter, +{ + // TODO: It might be possible to avoid an allocation (the + // Box::new) below by reusing the box. + let expr: Expr = *boxed_expr; + let rewritten_expr = expr.rewrite(rewriter)?; + Ok(Box::new(rewritten_expr)) +} + +fn rewrite_option_box( + option_box: Option>, + rewriter: &mut R, +) -> Result>> +where + R: ExprRewriter, +{ + option_box + .map(|expr| rewrite_boxed(expr, rewriter)) + .transpose() +} + +/// rewrite a `Vec` of `Expr`s with the rewriter +fn rewrite_vec(v: Vec, rewriter: &mut R) -> Result> +where + R: ExprRewriter, +{ + v.into_iter().map(|expr| expr.rewrite(rewriter)).collect() +} + +/// Rewrite sort on aggregate expressions to sort on the column of aggregate output +/// For example, `max(x)` is written to `col("MAX(x)")` +pub fn rewrite_sort_cols_by_aggs( + exprs: impl IntoIterator>, + plan: &LogicalPlan, +) -> Result> { + exprs + .into_iter() + .map(|e| { + let expr = e.into(); + match expr { + Expr::Sort { + expr, + asc, + nulls_first, + } => { + let sort = Expr::Sort { + expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?), + asc, + nulls_first, + }; + Ok(sort) + } + expr => Ok(expr), + } + }) + .collect() +} + +fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Aggregate(Aggregate { + input, aggr_expr, .. + }) => { + struct Rewriter<'a> { + plan: &'a LogicalPlan, + input: &'a LogicalPlan, + aggr_expr: &'a Vec, + } + + impl<'a> ExprRewriter for Rewriter<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + let normalized_expr = normalize_col(expr.clone(), self.plan); + if normalized_expr.is_err() { + // The expr is not based on Aggregate plan output. Skip it. 
+ return Ok(expr); + } + let normalized_expr = normalized_expr.unwrap(); + if let Some(found_agg) = + self.aggr_expr.iter().find(|a| (**a) == normalized_expr) + { + let agg = normalize_col(found_agg.clone(), self.plan)?; + let col = Expr::Column( + agg.to_field(self.input.schema()) + .map(|f| f.qualified_column())?, + ); + Ok(col) + } else { + Ok(expr) + } + } + } + + expr.rewrite(&mut Rewriter { + plan, + input, + aggr_expr, + }) + } + LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]), + _ => Ok(expr), + } +} + +/// Recursively call [`Column::normalize`] on all Column expressions +/// in the `expr` expression tree. +pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { + normalize_col_with_schemas(expr, &plan.all_schemas(), &plan.using_columns()?) +} + +/// Recursively call [`Column::normalize`] on all Column expressions +/// in the `expr` expression tree. +fn normalize_col_with_schemas( + expr: Expr, + schemas: &[&Arc], + using_columns: &[HashSet], +) -> Result { + struct ColumnNormalizer<'a> { + schemas: &'a [&'a Arc], + using_columns: &'a [HashSet], + } + + impl<'a> ExprRewriter for ColumnNormalizer<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Column(c) = expr { + Ok(Expr::Column(c.normalize_with_schemas( + self.schemas, + self.using_columns, + )?)) + } else { + Ok(expr) + } + } + } + + expr.rewrite(&mut ColumnNormalizer { + schemas, + using_columns, + }) +} + +/// Recursively normalize all Column expressions in a list of expression trees +pub fn normalize_cols( + exprs: impl IntoIterator>, + plan: &LogicalPlan, +) -> Result> { + exprs + .into_iter() + .map(|e| normalize_col(e.into(), plan)) + .collect() +} + +/// Recursively replace all Column expressions in a given expression tree with Column expressions +/// provided by the hash map argument. +pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { + struct ColumnReplacer<'a> { + replace_map: &'a HashMap<&'a Column, &'a Column>, + } + + impl<'a> ExprRewriter for ColumnReplacer<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Column(c) = &expr { + match self.replace_map.get(c) { + Some(new_c) => Ok(Expr::Column((*new_c).to_owned())), + None => Ok(expr), + } + } else { + Ok(expr) + } + } + } + + e.rewrite(&mut ColumnReplacer { replace_map }) +} + +/// Recursively 'unnormalize' (remove all qualifiers) from an +/// expression tree. +/// +/// For example, if there were expressions like `foo.bar` this would +/// rewrite it to just `bar`. 
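+/// (Concretely: `unnormalize_col(col("foo.bar"))` yields the same expression
+/// as `col("bar")`; see the `unnormalize_cols` test below.)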
+pub fn unnormalize_col(expr: Expr) -> Expr { + struct RemoveQualifier {} + + impl ExprRewriter for RemoveQualifier { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Column(col) = expr { + //let Column { relation: _, name } = col; + Ok(Expr::Column(Column { + relation: None, + name: col.name, + })) + } else { + Ok(expr) + } + } + } + + expr.rewrite(&mut RemoveQualifier {}) + .expect("Unnormalize is infallable") +} + +/// Recursively un-normalize all Column expressions in a list of expression trees +#[inline] +pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { + exprs.into_iter().map(unnormalize_col).collect() +} + +#[cfg(test)] +mod test { + use super::*; + use crate::logical_plan::DFField; + use crate::prelude::{col, lit}; + use arrow::datatypes::DataType; + use datafusion_common::ScalarValue; + + #[derive(Default)] + struct RecordingRewriter { + v: Vec, + } + impl ExprRewriter for RecordingRewriter { + fn mutate(&mut self, expr: Expr) -> Result { + self.v.push(format!("Mutated {:?}", expr)); + Ok(expr) + } + + fn pre_visit(&mut self, expr: &Expr) -> Result { + self.v.push(format!("Previsited {:?}", expr)); + Ok(RewriteRecursion::Continue) + } + } + + #[test] + fn rewriter_rewrite() { + let mut rewriter = FooBarRewriter {}; + + // rewrites "foo" --> "bar" + let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap(); + assert_eq!(rewritten, col("state").eq(lit("bar"))); + + // doesn't wrewrite + let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap(); + assert_eq!(rewritten, col("state").eq(lit("baz"))); + } + + /// rewrites all "foo" string literals to "bar" + struct FooBarRewriter {} + impl ExprRewriter for FooBarRewriter { + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { + let utf8_val = if utf8_val == "foo" { + "bar".to_string() + } else { + utf8_val + }; + Ok(lit(utf8_val)) + } + // otherwise, return the expression unchanged + expr => Ok(expr), + } + } + } + + #[test] + fn normalize_cols() { + let expr = col("a") + col("b") + col("c"); + + // Schemas with some matching and some non matching cols + let schema_a = + DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")]) + .unwrap(); + let schema_c = + DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")]) + .unwrap(); + let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); + // non matching + let schema_f = + DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")]) + .unwrap(); + let schemas = vec![schema_c, schema_f, schema_b, schema_a] + .into_iter() + .map(Arc::new) + .collect::>(); + let schemas = schemas.iter().collect::>(); + + let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); + assert_eq!( + normalized_expr, + col("tableA.a") + col("tableB.b") + col("tableC.c") + ); + } + + #[test] + fn normalize_cols_priority() { + let expr = col("a") + col("b"); + // Schemas with multiple matches for column a, first takes priority + let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); + let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); + let schema_a2 = DFSchema::new(vec![make_field("tableA2", "a")]).unwrap(); + let schemas = vec![schema_a2, schema_b, schema_a] + .into_iter() + .map(Arc::new) + .collect::>(); + let schemas = schemas.iter().collect::>(); + + let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); + assert_eq!(normalized_expr, 
col("tableA2.a") + col("tableB.b")); + } + + #[test] + fn normalize_cols_non_exist() { + // test normalizing columns when the name doesn't exist + let expr = col("a") + col("b"); + let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); + let schemas = vec![schema_a].into_iter().map(Arc::new).collect::>(); + let schemas = schemas.iter().collect::>(); + + let error = normalize_col_with_schemas(expr, &schemas, &[]) + .unwrap_err() + .to_string(); + assert_eq!( + error, + "Error during planning: Column #b not found in provided schemas" + ); + } + + #[test] + fn unnormalize_cols() { + let expr = col("tableA.a") + col("tableB.b"); + let unnormalized_expr = unnormalize_col(expr); + assert_eq!(unnormalized_expr, col("a") + col("b")); + } + + fn make_field(relation: &str, column: &str) -> DFField { + DFField::new(Some(relation), column, DataType::Int8, false) + } + + #[test] + fn rewriter_visit() { + let mut rewriter = RecordingRewriter::default(); + col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap(); + + assert_eq!( + rewriter.v, + vec![ + "Previsited #state = Utf8(\"CO\")", + "Previsited #state", + "Mutated #state", + "Previsited Utf8(\"CO\")", + "Mutated Utf8(\"CO\")", + "Mutated #state = Utf8(\"CO\")" + ] + ) + } +} diff --git a/datafusion/src/logical_plan/expr_simplier.rs b/datafusion/src/logical_plan/expr_simplier.rs new file mode 100644 index 000000000000..06e58566f8a2 --- /dev/null +++ b/datafusion/src/logical_plan/expr_simplier.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression simplifier + +use super::Expr; +use super::ExprRewritable; +use crate::execution::context::ExecutionProps; +use crate::optimizer::simplify_expressions::{ConstEvaluator, Simplifier}; +use datafusion_common::Result; + +/// The information necessary to apply algebraic simplification to an +/// [Expr]. 
See [SimplifyContext] for one implementation +pub trait SimplifyInfo { + /// returns true if this Expr has boolean type + fn is_boolean_type(&self, expr: &Expr) -> Result; + + /// returns true of this expr is nullable (could possibly be NULL) + fn nullable(&self, expr: &Expr) -> Result; + + /// Returns details needed for partial expression evaluation + fn execution_props(&self) -> &ExecutionProps; +} + +/// trait for types that can be simplified +pub trait ExprSimplifiable: Sized { + /// simplify this trait object using the given SimplifyInfo + fn simplify(self, info: &S) -> Result; +} + +impl ExprSimplifiable for Expr { + /// Simplifies this [`Expr`]`s as much as possible, evaluating + /// constants and applying algebraic simplifications + /// + /// # Example: + /// `b > 2 AND b > 2` + /// can be written to + /// `b > 2` + /// + /// ``` + /// use datafusion::logical_plan::*; + /// use datafusion::error::Result; + /// use datafusion::execution::context::ExecutionProps; + /// + /// /// Simple implementation that provides `Simplifier` the information it needs + /// #[derive(Default)] + /// struct Info { + /// execution_props: ExecutionProps, + /// }; + /// + /// impl SimplifyInfo for Info { + /// fn is_boolean_type(&self, expr: &Expr) -> Result { + /// Ok(false) + /// } + /// fn nullable(&self, expr: &Expr) -> Result { + /// Ok(true) + /// } + /// fn execution_props(&self) -> &ExecutionProps { + /// &self.execution_props + /// } + /// } + /// + /// // b < 2 + /// let b_lt_2 = col("b").gt(lit(2)); + /// + /// // (b < 2) OR (b < 2) + /// let expr = b_lt_2.clone().or(b_lt_2.clone()); + /// + /// // (b < 2) OR (b < 2) --> (b < 2) + /// let expr = expr.simplify(&Info::default()).unwrap(); + /// assert_eq!(expr, b_lt_2); + /// ``` + fn simplify(self, info: &S) -> Result { + let mut rewriter = Simplifier::new(info); + let mut const_evaluator = ConstEvaluator::new(info.execution_props()); + + // TODO iterate until no changes are made during rewrite + // (evaluating constants can enable new simplifications and + // simplifications can enable new constant evaluation) + // https://github.com/apache/arrow-datafusion/issues/1160 + self.rewrite(&mut const_evaluator)?.rewrite(&mut rewriter) + } +} diff --git a/datafusion/src/logical_plan/expr_visitor.rs b/datafusion/src/logical_plan/expr_visitor.rs new file mode 100644 index 000000000000..26084fb95f0b --- /dev/null +++ b/datafusion/src/logical_plan/expr_visitor.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression visitor + +use super::Expr; +use datafusion_common::Result; + +/// Controls how the visitor recursion should proceed. +pub enum Recursion { + /// Attempt to visit all the children, recursively, of this expression. 
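// For instance, a visitor that only gathers information returns
// `Recursion::Continue(self)` from every `pre_visit` call. A minimal sketch
// (not part of the patch; it assumes the `Expr` and `Result` types of this
// crate, plus the `ExpressionVisitor` and `ExprVisitable::accept` items
// defined further down in this file):
//
//     #[derive(Default)]
//     struct ColumnNames {
//         names: Vec<String>,
//     }
//
//     impl ExpressionVisitor for ColumnNames {
//         fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
//             if let Expr::Column(c) = expr {
//                 self.names.push(c.name.clone());
//             }
//             Ok(Recursion::Continue(self))
//         }
//     }
//
// `col("a").eq(col("b")).accept(ColumnNames::default())?.names` then holds
// `["a", "b"]`, since `accept` walks the left operand before the right one.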
+ Continue(V), + /// Do not visit the children of this expression, though the walk + /// of parents of this expression will not be affected + Stop(V), +} + +/// Encode the traversal of an expression tree. When passed to +/// `Expr::accept`, `ExpressionVisitor::visit` is invoked +/// recursively on all nodes of an expression tree. See the comments +/// on `Expr::accept` for details on its use +pub trait ExpressionVisitor: Sized { + /// Invoked before any children of `expr` are visisted. + fn pre_visit(self, expr: &E) -> Result> + where + Self: ExpressionVisitor; + + /// Invoked after all children of `expr` are visited. Default + /// implementation does nothing. + fn post_visit(self, _expr: &E) -> Result { + Ok(self) + } +} + +/// trait for types that can be visited by [`ExpressionVisitor`] +pub trait ExprVisitable: Sized { + /// accept a visitor, calling `visit` on all children of this + fn accept>(&self, visitor: V) -> Result; +} + +impl ExprVisitable for Expr { + /// Performs a depth first walk of an expression and + /// its children, calling [`ExpressionVisitor::pre_visit`] and + /// `visitor.post_visit`. + /// + /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to + /// separate expression algorithms from the structure of the + /// `Expr` tree and make it easier to add new types of expressions + /// and algorithms that walk the tree. + /// + /// For an expression tree such as + /// ```text + /// BinaryExpr (GT) + /// left: Column("foo") + /// right: Column("bar") + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(BinaryExpr(GT)) + /// pre_visit(Column("foo")) + /// pre_visit(Column("bar")) + /// post_visit(Column("bar")) + /// post_visit(Column("bar")) + /// post_visit(BinaryExpr(GT)) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If `Recursion::Stop` is returned on a call to pre_visit, no + /// children of that expression are visited, nor is post_visit + /// called on that expression + /// + fn accept(&self, visitor: V) -> Result { + let visitor = match visitor.pre_visit(self)? { + Recursion::Continue(visitor) => visitor, + // If the recursion should stop, do not visit children + Recursion::Stop(visitor) => return Ok(visitor), + }; + + // recurse (and cover all expression types) + let visitor = match self { + Expr::Alias(expr, _) + | Expr::Not(expr) + | Expr::IsNotNull(expr) + | Expr::IsNull(expr) + | Expr::Negative(expr) + | Expr::Cast { expr, .. } + | Expr::TryCast { expr, .. } + | Expr::Sort { expr, .. } + | Expr::GetIndexedField { expr, .. } => expr.accept(visitor), + Expr::Column(_) + | Expr::ScalarVariable(_) + | Expr::Literal(_) + | Expr::Wildcard => Ok(visitor), + Expr::BinaryExpr { left, right, .. } => { + let visitor = left.accept(visitor)?; + right.accept(visitor) + } + Expr::Between { + expr, low, high, .. + } => { + let visitor = expr.accept(visitor)?; + let visitor = low.accept(visitor)?; + high.accept(visitor) + } + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let visitor = if let Some(expr) = expr.as_ref() { + expr.accept(visitor) + } else { + Ok(visitor) + }?; + let visitor = when_then_expr.iter().try_fold( + visitor, + |visitor, (when, then)| { + let visitor = when.accept(visitor)?; + then.accept(visitor) + }, + )?; + if let Some(else_expr) = else_expr.as_ref() { + else_expr.accept(visitor) + } else { + Ok(visitor) + } + } + Expr::ScalarFunction { args, .. } + | Expr::ScalarUDF { args, .. 
} + | Expr::AggregateFunction { args, .. } + | Expr::AggregateUDF { args, .. } => args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::WindowFunction { + args, + partition_by, + order_by, + .. + } => { + let visitor = args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + let visitor = partition_by + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + let visitor = order_by + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + Ok(visitor) + } + Expr::InList { expr, list, .. } => { + let visitor = expr.accept(visitor)?; + list.iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor)) + } + }?; + + visitor.post_visit(self) + } +} diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index ec1aea6a72a1..085775a2eb8c 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -25,6 +25,9 @@ pub(crate) mod builder; mod dfschema; mod display; mod expr; +mod expr_rewriter; +mod expr_simplier; +mod expr_visitor; mod extension; mod operators; pub mod plan; @@ -41,14 +44,18 @@ pub use expr::{ columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, - lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length, - or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse, - rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, - signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, - translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, - Column, Expr, ExprRewriter, ExprSchema, ExpressionVisitor, Literal, Recursion, - RewriteRecursion, SimplifyInfo, + lower, lpad, ltrim, max, md5, min, now, octet_length, or, random, regexp_match, + regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, + sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, + to_hex, translate, trim, trunc, unalias, upper, when, Column, Expr, ExprSchema, + Literal, }; +pub use expr_rewriter::{ + normalize_col, normalize_cols, replace_col, rewrite_sort_cols_by_aggs, + unnormalize_col, unnormalize_cols, ExprRewritable, ExprRewriter, RewriteRecursion, +}; +pub use expr_simplier::{ExprSimplifiable, SimplifyInfo}; +pub use expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; pub use plan::{ diff --git a/datafusion/src/optimizer/common_subexpr_eliminate.rs b/datafusion/src/optimizer/common_subexpr_eliminate.rs index 947073409d05..5c2219b3d99a 100644 --- a/datafusion/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/src/optimizer/common_subexpr_eliminate.rs @@ -23,8 +23,8 @@ use crate::logical_plan::plan::{Filter, Projection, Window}; use crate::logical_plan::{ col, plan::{Aggregate, Sort}, - DFField, DFSchema, Expr, ExprRewriter, ExpressionVisitor, LogicalPlan, Recursion, - RewriteRecursion, + DFField, DFSchema, Expr, ExprRewritable, ExprRewriter, ExprVisitable, + ExpressionVisitor, LogicalPlan, Recursion, RewriteRecursion, }; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 
5f87542491d7..f8f3df44b673 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -24,8 +24,8 @@ use arrow::record_batch::RecordBatch; use crate::error::DataFusionError; use crate::execution::context::ExecutionProps; use crate::logical_plan::{ - lit, DFSchema, DFSchemaRef, Expr, ExprRewriter, LogicalPlan, RewriteRecursion, - SimplifyInfo, + lit, DFSchema, DFSchemaRef, Expr, ExprRewritable, ExprRewriter, ExprSimplifiable, + LogicalPlan, RewriteRecursion, SimplifyInfo, }; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; @@ -252,6 +252,7 @@ impl SimplifyExpressions { /// /// ``` /// # use datafusion::prelude::*; +/// # use datafusion::logical_plan::ExprRewritable; /// # use datafusion::optimizer::simplify_expressions::ConstEvaluator; /// # use datafusion::execution::context::ExecutionProps; /// diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index f7ab836b398c..41d1e4bca03b 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -22,9 +22,11 @@ use crate::execution::context::ExecutionProps; use crate::logical_plan::plan::{ Aggregate, Analyze, Extension, Filter, Join, Projection, Sort, Window, }; + use crate::logical_plan::{ - build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, Limit, LogicalPlan, - LogicalPlanBuilder, Operator, Partitioning, Recursion, Repartition, Union, Values, + build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, ExprVisitable, + Limit, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, Recursion, + Repartition, Union, Values, }; use crate::prelude::lit; use crate::scalar::ScalarValue; diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index d0cef0f3d376..cbe40d6dc51d 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -20,6 +20,7 @@ use arrow::datatypes::DataType; use sqlparser::ast::Ident; +use crate::logical_plan::ExprVisitable; use crate::logical_plan::{Expr, LogicalPlan}; use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128}; use crate::{ diff --git a/datafusion/tests/simplification.rs b/datafusion/tests/simplification.rs index 5edf43f5ccb2..0ce8e7685b83 100644 --- a/datafusion/tests/simplification.rs +++ b/datafusion/tests/simplification.rs @@ -18,6 +18,7 @@ //! This program demonstrates the DataFusion expression simplification API. use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::logical_plan::ExprSimplifiable; use datafusion::{ error::Result, execution::context::ExecutionProps, From 4b682732ba49c66e3cb010db2b3fddd34a299c87 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Tue, 8 Feb 2022 21:47:23 +0800 Subject: [PATCH 44/50] move built-in scalar functions (#1764) --- datafusion-expr/src/built_in_function.rs | 330 ++++++++++++++++++++++ datafusion-expr/src/lib.rs | 2 + datafusion/src/physical_plan/functions.rs | 311 +------------------- 3 files changed, 334 insertions(+), 309 deletions(-) create mode 100644 datafusion-expr/src/built_in_function.rs diff --git a/datafusion-expr/src/built_in_function.rs b/datafusion-expr/src/built_in_function.rs new file mode 100644 index 000000000000..0d5ee9792ecb --- /dev/null +++ b/datafusion-expr/src/built_in_function.rs @@ -0,0 +1,330 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Built-in functions + +use crate::Volatility; +use datafusion_common::{DataFusionError, Result}; +use std::fmt; +use std::str::FromStr; + +/// Enum of all built-in scalar functions +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum BuiltinScalarFunction { + // math functions + /// abs + Abs, + /// acos + Acos, + /// asin + Asin, + /// atan + Atan, + /// ceil + Ceil, + /// cos + Cos, + /// Digest + Digest, + /// exp + Exp, + /// floor + Floor, + /// ln, Natural logarithm + Ln, + /// log, same as log10 + Log, + /// log10 + Log10, + /// log2 + Log2, + /// round + Round, + /// signum + Signum, + /// sin + Sin, + /// sqrt + Sqrt, + /// tan + Tan, + /// trunc + Trunc, + + // string functions + /// construct an array from columns + Array, + /// ascii + Ascii, + /// bit_length + BitLength, + /// btrim + Btrim, + /// character_length + CharacterLength, + /// chr + Chr, + /// concat + Concat, + /// concat_ws + ConcatWithSeparator, + /// date_part + DatePart, + /// date_trunc + DateTrunc, + /// initcap + InitCap, + /// left + Left, + /// lpad + Lpad, + /// lower + Lower, + /// ltrim + Ltrim, + /// md5 + MD5, + /// nullif + NullIf, + /// octet_length + OctetLength, + /// random + Random, + /// regexp_replace + RegexpReplace, + /// repeat + Repeat, + /// replace + Replace, + /// reverse + Reverse, + /// right + Right, + /// rpad + Rpad, + /// rtrim + Rtrim, + /// sha224 + SHA224, + /// sha256 + SHA256, + /// sha384 + SHA384, + /// Sha512 + SHA512, + /// split_part + SplitPart, + /// starts_with + StartsWith, + /// strpos + Strpos, + /// substr + Substr, + /// to_hex + ToHex, + /// to_timestamp + ToTimestamp, + /// to_timestamp_millis + ToTimestampMillis, + /// to_timestamp_micros + ToTimestampMicros, + /// to_timestamp_seconds + ToTimestampSeconds, + ///now + Now, + /// translate + Translate, + /// trim + Trim, + /// upper + Upper, + /// regexp_match + RegexpMatch, +} + +impl BuiltinScalarFunction { + /// an allowlist of functions to take zero arguments, so that they will get special treatment + /// while executing. + pub fn supports_zero_argument(&self) -> bool { + matches!( + self, + BuiltinScalarFunction::Random | BuiltinScalarFunction::Now + ) + } + /// Returns the [Volatility] of the builtin function. 
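+ /// For example, `BuiltinScalarFunction::Abs` is `Volatility::Immutable`,
+ /// `Now` is `Volatility::Stable`, and `Random` is `Volatility::Volatile`.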
+ pub fn volatility(&self) -> Volatility { + match self { + //Immutable scalar builtins + BuiltinScalarFunction::Abs => Volatility::Immutable, + BuiltinScalarFunction::Acos => Volatility::Immutable, + BuiltinScalarFunction::Asin => Volatility::Immutable, + BuiltinScalarFunction::Atan => Volatility::Immutable, + BuiltinScalarFunction::Ceil => Volatility::Immutable, + BuiltinScalarFunction::Cos => Volatility::Immutable, + BuiltinScalarFunction::Exp => Volatility::Immutable, + BuiltinScalarFunction::Floor => Volatility::Immutable, + BuiltinScalarFunction::Ln => Volatility::Immutable, + BuiltinScalarFunction::Log => Volatility::Immutable, + BuiltinScalarFunction::Log10 => Volatility::Immutable, + BuiltinScalarFunction::Log2 => Volatility::Immutable, + BuiltinScalarFunction::Round => Volatility::Immutable, + BuiltinScalarFunction::Signum => Volatility::Immutable, + BuiltinScalarFunction::Sin => Volatility::Immutable, + BuiltinScalarFunction::Sqrt => Volatility::Immutable, + BuiltinScalarFunction::Tan => Volatility::Immutable, + BuiltinScalarFunction::Trunc => Volatility::Immutable, + BuiltinScalarFunction::Array => Volatility::Immutable, + BuiltinScalarFunction::Ascii => Volatility::Immutable, + BuiltinScalarFunction::BitLength => Volatility::Immutable, + BuiltinScalarFunction::Btrim => Volatility::Immutable, + BuiltinScalarFunction::CharacterLength => Volatility::Immutable, + BuiltinScalarFunction::Chr => Volatility::Immutable, + BuiltinScalarFunction::Concat => Volatility::Immutable, + BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, + BuiltinScalarFunction::DatePart => Volatility::Immutable, + BuiltinScalarFunction::DateTrunc => Volatility::Immutable, + BuiltinScalarFunction::InitCap => Volatility::Immutable, + BuiltinScalarFunction::Left => Volatility::Immutable, + BuiltinScalarFunction::Lpad => Volatility::Immutable, + BuiltinScalarFunction::Lower => Volatility::Immutable, + BuiltinScalarFunction::Ltrim => Volatility::Immutable, + BuiltinScalarFunction::MD5 => Volatility::Immutable, + BuiltinScalarFunction::NullIf => Volatility::Immutable, + BuiltinScalarFunction::OctetLength => Volatility::Immutable, + BuiltinScalarFunction::RegexpReplace => Volatility::Immutable, + BuiltinScalarFunction::Repeat => Volatility::Immutable, + BuiltinScalarFunction::Replace => Volatility::Immutable, + BuiltinScalarFunction::Reverse => Volatility::Immutable, + BuiltinScalarFunction::Right => Volatility::Immutable, + BuiltinScalarFunction::Rpad => Volatility::Immutable, + BuiltinScalarFunction::Rtrim => Volatility::Immutable, + BuiltinScalarFunction::SHA224 => Volatility::Immutable, + BuiltinScalarFunction::SHA256 => Volatility::Immutable, + BuiltinScalarFunction::SHA384 => Volatility::Immutable, + BuiltinScalarFunction::SHA512 => Volatility::Immutable, + BuiltinScalarFunction::Digest => Volatility::Immutable, + BuiltinScalarFunction::SplitPart => Volatility::Immutable, + BuiltinScalarFunction::StartsWith => Volatility::Immutable, + BuiltinScalarFunction::Strpos => Volatility::Immutable, + BuiltinScalarFunction::Substr => Volatility::Immutable, + BuiltinScalarFunction::ToHex => Volatility::Immutable, + BuiltinScalarFunction::ToTimestamp => Volatility::Immutable, + BuiltinScalarFunction::ToTimestampMillis => Volatility::Immutable, + BuiltinScalarFunction::ToTimestampMicros => Volatility::Immutable, + BuiltinScalarFunction::ToTimestampSeconds => Volatility::Immutable, + BuiltinScalarFunction::Translate => Volatility::Immutable, + BuiltinScalarFunction::Trim => Volatility::Immutable, + 
BuiltinScalarFunction::Upper => Volatility::Immutable, + BuiltinScalarFunction::RegexpMatch => Volatility::Immutable, + + //Stable builtin functions + BuiltinScalarFunction::Now => Volatility::Stable, + + //Volatile builtin functions + BuiltinScalarFunction::Random => Volatility::Volatile, + } + } +} + +impl fmt::Display for BuiltinScalarFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // lowercase of the debug. + write!(f, "{}", format!("{:?}", self).to_lowercase()) + } +} + +impl FromStr for BuiltinScalarFunction { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + Ok(match name { + // math functions + "abs" => BuiltinScalarFunction::Abs, + "acos" => BuiltinScalarFunction::Acos, + "asin" => BuiltinScalarFunction::Asin, + "atan" => BuiltinScalarFunction::Atan, + "ceil" => BuiltinScalarFunction::Ceil, + "cos" => BuiltinScalarFunction::Cos, + "exp" => BuiltinScalarFunction::Exp, + "floor" => BuiltinScalarFunction::Floor, + "ln" => BuiltinScalarFunction::Ln, + "log" => BuiltinScalarFunction::Log, + "log10" => BuiltinScalarFunction::Log10, + "log2" => BuiltinScalarFunction::Log2, + "round" => BuiltinScalarFunction::Round, + "signum" => BuiltinScalarFunction::Signum, + "sin" => BuiltinScalarFunction::Sin, + "sqrt" => BuiltinScalarFunction::Sqrt, + "tan" => BuiltinScalarFunction::Tan, + "trunc" => BuiltinScalarFunction::Trunc, + + // string functions + "array" => BuiltinScalarFunction::Array, + "ascii" => BuiltinScalarFunction::Ascii, + "bit_length" => BuiltinScalarFunction::BitLength, + "btrim" => BuiltinScalarFunction::Btrim, + "char_length" => BuiltinScalarFunction::CharacterLength, + "character_length" => BuiltinScalarFunction::CharacterLength, + "concat" => BuiltinScalarFunction::Concat, + "concat_ws" => BuiltinScalarFunction::ConcatWithSeparator, + "chr" => BuiltinScalarFunction::Chr, + "date_part" | "datepart" => BuiltinScalarFunction::DatePart, + "date_trunc" | "datetrunc" => BuiltinScalarFunction::DateTrunc, + "initcap" => BuiltinScalarFunction::InitCap, + "left" => BuiltinScalarFunction::Left, + "length" => BuiltinScalarFunction::CharacterLength, + "lower" => BuiltinScalarFunction::Lower, + "lpad" => BuiltinScalarFunction::Lpad, + "ltrim" => BuiltinScalarFunction::Ltrim, + "md5" => BuiltinScalarFunction::MD5, + "nullif" => BuiltinScalarFunction::NullIf, + "octet_length" => BuiltinScalarFunction::OctetLength, + "random" => BuiltinScalarFunction::Random, + "regexp_replace" => BuiltinScalarFunction::RegexpReplace, + "repeat" => BuiltinScalarFunction::Repeat, + "replace" => BuiltinScalarFunction::Replace, + "reverse" => BuiltinScalarFunction::Reverse, + "right" => BuiltinScalarFunction::Right, + "rpad" => BuiltinScalarFunction::Rpad, + "rtrim" => BuiltinScalarFunction::Rtrim, + "sha224" => BuiltinScalarFunction::SHA224, + "sha256" => BuiltinScalarFunction::SHA256, + "sha384" => BuiltinScalarFunction::SHA384, + "sha512" => BuiltinScalarFunction::SHA512, + "digest" => BuiltinScalarFunction::Digest, + "split_part" => BuiltinScalarFunction::SplitPart, + "starts_with" => BuiltinScalarFunction::StartsWith, + "strpos" => BuiltinScalarFunction::Strpos, + "substr" => BuiltinScalarFunction::Substr, + "to_hex" => BuiltinScalarFunction::ToHex, + "to_timestamp" => BuiltinScalarFunction::ToTimestamp, + "to_timestamp_millis" => BuiltinScalarFunction::ToTimestampMillis, + "to_timestamp_micros" => BuiltinScalarFunction::ToTimestampMicros, + "to_timestamp_seconds" => BuiltinScalarFunction::ToTimestampSeconds, + "now" => BuiltinScalarFunction::Now, + "translate" => 
BuiltinScalarFunction::Translate, + "trim" => BuiltinScalarFunction::Trim, + "upper" => BuiltinScalarFunction::Upper, + "regexp_match" => BuiltinScalarFunction::RegexpMatch, + _ => { + return Err(DataFusionError::Plan(format!( + "There is no built-in function named {}", + name + ))) + } + }) + } +} diff --git a/datafusion-expr/src/lib.rs b/datafusion-expr/src/lib.rs index d2b10b404a8b..7dcddc39c4dc 100644 --- a/datafusion-expr/src/lib.rs +++ b/datafusion-expr/src/lib.rs @@ -16,12 +16,14 @@ // under the License. mod aggregate_function; +mod built_in_function; mod operator; mod signature; mod window_frame; mod window_function; pub use aggregate_function::AggregateFunction; +pub use built_in_function::BuiltinScalarFunction; pub use operator::Operator; pub use signature::{Signature, TypeSignature, Volatility}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index af157c02fc38..9582eecce33e 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -54,9 +54,9 @@ use arrow::{ }; use fmt::{Debug, Formatter}; use std::convert::From; -use std::{any::Any, fmt, str::FromStr, sync::Arc}; +use std::{any::Any, fmt, sync::Arc}; -pub use datafusion_expr::{Signature, TypeSignature, Volatility}; +pub use datafusion_expr::{BuiltinScalarFunction, Signature, TypeSignature, Volatility}; /// Scalar function /// @@ -73,313 +73,6 @@ pub type ScalarFunctionImplementation = pub type ReturnTypeFunction = Arc Result> + Send + Sync>; -/// Enum of all built-in scalar functions -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum BuiltinScalarFunction { - // math functions - /// abs - Abs, - /// acos - Acos, - /// asin - Asin, - /// atan - Atan, - /// ceil - Ceil, - /// cos - Cos, - /// Digest - Digest, - /// exp - Exp, - /// floor - Floor, - /// ln, Natural logarithm - Ln, - /// log, same as log10 - Log, - /// log10 - Log10, - /// log2 - Log2, - /// round - Round, - /// signum - Signum, - /// sin - Sin, - /// sqrt - Sqrt, - /// tan - Tan, - /// trunc - Trunc, - - // string functions - /// construct an array from columns - Array, - /// ascii - Ascii, - /// bit_length - BitLength, - /// btrim - Btrim, - /// character_length - CharacterLength, - /// chr - Chr, - /// concat - Concat, - /// concat_ws - ConcatWithSeparator, - /// date_part - DatePart, - /// date_trunc - DateTrunc, - /// initcap - InitCap, - /// left - Left, - /// lpad - Lpad, - /// lower - Lower, - /// ltrim - Ltrim, - /// md5 - MD5, - /// nullif - NullIf, - /// octet_length - OctetLength, - /// random - Random, - /// regexp_replace - RegexpReplace, - /// repeat - Repeat, - /// replace - Replace, - /// reverse - Reverse, - /// right - Right, - /// rpad - Rpad, - /// rtrim - Rtrim, - /// sha224 - SHA224, - /// sha256 - SHA256, - /// sha384 - SHA384, - /// Sha512 - SHA512, - /// split_part - SplitPart, - /// starts_with - StartsWith, - /// strpos - Strpos, - /// substr - Substr, - /// to_hex - ToHex, - /// to_timestamp - ToTimestamp, - /// to_timestamp_millis - ToTimestampMillis, - /// to_timestamp_micros - ToTimestampMicros, - /// to_timestamp_seconds - ToTimestampSeconds, - ///now - Now, - /// translate - Translate, - /// trim - Trim, - /// upper - Upper, - /// regexp_match - RegexpMatch, -} - -impl BuiltinScalarFunction { - /// an allowlist of functions to take zero arguments, so that they will get special treatment - /// while executing. 
- fn supports_zero_argument(&self) -> bool { - matches!( - self, - BuiltinScalarFunction::Random | BuiltinScalarFunction::Now - ) - } - /// Returns the [Volatility] of the builtin function. - pub fn volatility(&self) -> Volatility { - match self { - //Immutable scalar builtins - BuiltinScalarFunction::Abs => Volatility::Immutable, - BuiltinScalarFunction::Acos => Volatility::Immutable, - BuiltinScalarFunction::Asin => Volatility::Immutable, - BuiltinScalarFunction::Atan => Volatility::Immutable, - BuiltinScalarFunction::Ceil => Volatility::Immutable, - BuiltinScalarFunction::Cos => Volatility::Immutable, - BuiltinScalarFunction::Exp => Volatility::Immutable, - BuiltinScalarFunction::Floor => Volatility::Immutable, - BuiltinScalarFunction::Ln => Volatility::Immutable, - BuiltinScalarFunction::Log => Volatility::Immutable, - BuiltinScalarFunction::Log10 => Volatility::Immutable, - BuiltinScalarFunction::Log2 => Volatility::Immutable, - BuiltinScalarFunction::Round => Volatility::Immutable, - BuiltinScalarFunction::Signum => Volatility::Immutable, - BuiltinScalarFunction::Sin => Volatility::Immutable, - BuiltinScalarFunction::Sqrt => Volatility::Immutable, - BuiltinScalarFunction::Tan => Volatility::Immutable, - BuiltinScalarFunction::Trunc => Volatility::Immutable, - BuiltinScalarFunction::Array => Volatility::Immutable, - BuiltinScalarFunction::Ascii => Volatility::Immutable, - BuiltinScalarFunction::BitLength => Volatility::Immutable, - BuiltinScalarFunction::Btrim => Volatility::Immutable, - BuiltinScalarFunction::CharacterLength => Volatility::Immutable, - BuiltinScalarFunction::Chr => Volatility::Immutable, - BuiltinScalarFunction::Concat => Volatility::Immutable, - BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, - BuiltinScalarFunction::DatePart => Volatility::Immutable, - BuiltinScalarFunction::DateTrunc => Volatility::Immutable, - BuiltinScalarFunction::InitCap => Volatility::Immutable, - BuiltinScalarFunction::Left => Volatility::Immutable, - BuiltinScalarFunction::Lpad => Volatility::Immutable, - BuiltinScalarFunction::Lower => Volatility::Immutable, - BuiltinScalarFunction::Ltrim => Volatility::Immutable, - BuiltinScalarFunction::MD5 => Volatility::Immutable, - BuiltinScalarFunction::NullIf => Volatility::Immutable, - BuiltinScalarFunction::OctetLength => Volatility::Immutable, - BuiltinScalarFunction::RegexpReplace => Volatility::Immutable, - BuiltinScalarFunction::Repeat => Volatility::Immutable, - BuiltinScalarFunction::Replace => Volatility::Immutable, - BuiltinScalarFunction::Reverse => Volatility::Immutable, - BuiltinScalarFunction::Right => Volatility::Immutable, - BuiltinScalarFunction::Rpad => Volatility::Immutable, - BuiltinScalarFunction::Rtrim => Volatility::Immutable, - BuiltinScalarFunction::SHA224 => Volatility::Immutable, - BuiltinScalarFunction::SHA256 => Volatility::Immutable, - BuiltinScalarFunction::SHA384 => Volatility::Immutable, - BuiltinScalarFunction::SHA512 => Volatility::Immutable, - BuiltinScalarFunction::Digest => Volatility::Immutable, - BuiltinScalarFunction::SplitPart => Volatility::Immutable, - BuiltinScalarFunction::StartsWith => Volatility::Immutable, - BuiltinScalarFunction::Strpos => Volatility::Immutable, - BuiltinScalarFunction::Substr => Volatility::Immutable, - BuiltinScalarFunction::ToHex => Volatility::Immutable, - BuiltinScalarFunction::ToTimestamp => Volatility::Immutable, - BuiltinScalarFunction::ToTimestampMillis => Volatility::Immutable, - BuiltinScalarFunction::ToTimestampMicros => Volatility::Immutable, - 
BuiltinScalarFunction::ToTimestampSeconds => Volatility::Immutable, - BuiltinScalarFunction::Translate => Volatility::Immutable, - BuiltinScalarFunction::Trim => Volatility::Immutable, - BuiltinScalarFunction::Upper => Volatility::Immutable, - BuiltinScalarFunction::RegexpMatch => Volatility::Immutable, - - //Stable builtin functions - BuiltinScalarFunction::Now => Volatility::Stable, - - //Volatile builtin functions - BuiltinScalarFunction::Random => Volatility::Volatile, - } - } -} - -impl fmt::Display for BuiltinScalarFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // lowercase of the debug. - write!(f, "{}", format!("{:?}", self).to_lowercase()) - } -} - -impl FromStr for BuiltinScalarFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name { - // math functions - "abs" => BuiltinScalarFunction::Abs, - "acos" => BuiltinScalarFunction::Acos, - "asin" => BuiltinScalarFunction::Asin, - "atan" => BuiltinScalarFunction::Atan, - "ceil" => BuiltinScalarFunction::Ceil, - "cos" => BuiltinScalarFunction::Cos, - "exp" => BuiltinScalarFunction::Exp, - "floor" => BuiltinScalarFunction::Floor, - "ln" => BuiltinScalarFunction::Ln, - "log" => BuiltinScalarFunction::Log, - "log10" => BuiltinScalarFunction::Log10, - "log2" => BuiltinScalarFunction::Log2, - "round" => BuiltinScalarFunction::Round, - "signum" => BuiltinScalarFunction::Signum, - "sin" => BuiltinScalarFunction::Sin, - "sqrt" => BuiltinScalarFunction::Sqrt, - "tan" => BuiltinScalarFunction::Tan, - "trunc" => BuiltinScalarFunction::Trunc, - - // string functions - "array" => BuiltinScalarFunction::Array, - "ascii" => BuiltinScalarFunction::Ascii, - "bit_length" => BuiltinScalarFunction::BitLength, - "btrim" => BuiltinScalarFunction::Btrim, - "char_length" => BuiltinScalarFunction::CharacterLength, - "character_length" => BuiltinScalarFunction::CharacterLength, - "concat" => BuiltinScalarFunction::Concat, - "concat_ws" => BuiltinScalarFunction::ConcatWithSeparator, - "chr" => BuiltinScalarFunction::Chr, - "date_part" | "datepart" => BuiltinScalarFunction::DatePart, - "date_trunc" | "datetrunc" => BuiltinScalarFunction::DateTrunc, - "initcap" => BuiltinScalarFunction::InitCap, - "left" => BuiltinScalarFunction::Left, - "length" => BuiltinScalarFunction::CharacterLength, - "lower" => BuiltinScalarFunction::Lower, - "lpad" => BuiltinScalarFunction::Lpad, - "ltrim" => BuiltinScalarFunction::Ltrim, - "md5" => BuiltinScalarFunction::MD5, - "nullif" => BuiltinScalarFunction::NullIf, - "octet_length" => BuiltinScalarFunction::OctetLength, - "random" => BuiltinScalarFunction::Random, - "regexp_replace" => BuiltinScalarFunction::RegexpReplace, - "repeat" => BuiltinScalarFunction::Repeat, - "replace" => BuiltinScalarFunction::Replace, - "reverse" => BuiltinScalarFunction::Reverse, - "right" => BuiltinScalarFunction::Right, - "rpad" => BuiltinScalarFunction::Rpad, - "rtrim" => BuiltinScalarFunction::Rtrim, - "sha224" => BuiltinScalarFunction::SHA224, - "sha256" => BuiltinScalarFunction::SHA256, - "sha384" => BuiltinScalarFunction::SHA384, - "sha512" => BuiltinScalarFunction::SHA512, - "digest" => BuiltinScalarFunction::Digest, - "split_part" => BuiltinScalarFunction::SplitPart, - "starts_with" => BuiltinScalarFunction::StartsWith, - "strpos" => BuiltinScalarFunction::Strpos, - "substr" => BuiltinScalarFunction::Substr, - "to_hex" => BuiltinScalarFunction::ToHex, - "to_timestamp" => BuiltinScalarFunction::ToTimestamp, - "to_timestamp_millis" => BuiltinScalarFunction::ToTimestampMillis, - 
"to_timestamp_micros" => BuiltinScalarFunction::ToTimestampMicros, - "to_timestamp_seconds" => BuiltinScalarFunction::ToTimestampSeconds, - "now" => BuiltinScalarFunction::Now, - "translate" => BuiltinScalarFunction::Translate, - "trim" => BuiltinScalarFunction::Trim, - "upper" => BuiltinScalarFunction::Upper, - "regexp_match" => BuiltinScalarFunction::RegexpMatch, - _ => { - return Err(DataFusionError::Plan(format!( - "There is no built-in function named {}", - name - ))) - } - }) - } -} - macro_rules! make_utf8_to_return_type { ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { fn $FUNC(arg_type: &DataType, name: &str) -> Result { From f2615afe4e42e67a98705fe5a774122c2751f710 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Tue, 8 Feb 2022 23:06:03 +0800 Subject: [PATCH 45/50] split expr type and null info to be expr-schemable (#1784) --- datafusion/src/logical_plan/builder.rs | 1 + datafusion/src/logical_plan/expr.rs | 202 +-------------- datafusion/src/logical_plan/expr_rewriter.rs | 1 + datafusion/src/logical_plan/expr_schema.rs | 231 ++++++++++++++++++ datafusion/src/logical_plan/mod.rs | 2 + .../src/optimizer/common_subexpr_eliminate.rs | 2 +- .../src/optimizer/simplify_expressions.rs | 8 +- .../optimizer/single_distinct_to_groupby.rs | 1 + datafusion/tests/simplification.rs | 1 + 9 files changed, 245 insertions(+), 204 deletions(-) create mode 100644 datafusion/src/logical_plan/expr_schema.rs diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index d81fa9d2afa6..a722238059f5 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -25,6 +25,7 @@ use crate::datasource::{ MemTable, TableProvider, }; use crate::error::{DataFusionError, Result}; +use crate::logical_plan::expr_schema::ExprSchemable; use crate::logical_plan::plan::{ Aggregate, Analyze, EmptyRelation, Explain, Filter, Join, Projection, Sort, TableScan, ToStringifiedPlan, Union, Window, diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 69da346aee8d..f19e9d8d6a35 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -20,16 +20,13 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; -use crate::field_util::get_indexed_field; +use crate::logical_plan::ExprSchemable; use crate::logical_plan::{window_frames, DFField, DFSchema}; use crate::physical_plan::functions::Volatility; -use crate::physical_plan::{ - aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, - window_functions, -}; +use crate::physical_plan::{aggregates, functions, udf::ScalarUDF, window_functions}; use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; -use arrow::{compute::can_cast_types, datatypes::DataType}; +use arrow::datatypes::DataType; pub use datafusion_common::{Column, ExprSchema}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; use std::collections::HashSet; @@ -251,151 +248,6 @@ impl PartialOrd for Expr { } impl Expr { - /// Returns the [arrow::datatypes::DataType] of the expression - /// based on [ExprSchema] - /// - /// Note: [DFSchema] implements [ExprSchema]. - /// - /// # Errors - /// - /// This function errors when it is not possible to compute its - /// [arrow::datatypes::DataType]. This happens when e.g. 
the - /// expression refers to a column that does not exist in the - /// schema, or when the expression is incorrectly typed - /// (e.g. `[utf8] + [bool]`). - pub fn get_type(&self, schema: &S) -> Result { - match self { - Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => { - expr.get_type(schema) - } - Expr::Column(c) => Ok(schema.data_type(c)?.clone()), - Expr::ScalarVariable(_) => Ok(DataType::Utf8), - Expr::Literal(l) => Ok(l.get_datatype()), - Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema), - Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => { - Ok(data_type.clone()) - } - Expr::ScalarUDF { fun, args } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) - } - Expr::ScalarFunction { fun, args } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - functions::return_type(fun, &data_types) - } - Expr::WindowFunction { fun, args, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - window_functions::return_type(fun, &data_types) - } - Expr::AggregateFunction { fun, args, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - aggregates::return_type(fun, &data_types) - } - Expr::AggregateUDF { fun, args, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) - } - Expr::Not(_) - | Expr::IsNull(_) - | Expr::Between { .. } - | Expr::InList { .. } - | Expr::IsNotNull(_) => Ok(DataType::Boolean), - Expr::BinaryExpr { - ref left, - ref right, - ref op, - } => binary_operator_data_type( - &left.get_type(schema)?, - op, - &right.get_type(schema)?, - ), - Expr::Wildcard => Err(DataFusionError::Internal( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), - Expr::GetIndexedField { ref expr, key } => { - let data_type = expr.get_type(schema)?; - - get_indexed_field(&data_type, key).map(|x| x.data_type().clone()) - } - } - } - - /// Returns the nullability of the expression based on [ExprSchema]. - /// - /// Note: [DFSchema] implements [ExprSchema]. - /// - /// # Errors - /// - /// This function errors when it is not possible to compute its - /// nullability. This happens when the expression refers to a - /// column that does not exist in the schema. - pub fn nullable(&self, input_schema: &S) -> Result { - match self { - Expr::Alias(expr, _) - | Expr::Not(expr) - | Expr::Negative(expr) - | Expr::Sort { expr, .. } - | Expr::Between { expr, .. } - | Expr::InList { expr, .. } => expr.nullable(input_schema), - Expr::Column(c) => input_schema.nullable(c), - Expr::Literal(value) => Ok(value.is_null()), - Expr::Case { - when_then_expr, - else_expr, - .. - } => { - // this expression is nullable if any of the input expressions are nullable - let then_nullable = when_then_expr - .iter() - .map(|(_, t)| t.nullable(input_schema)) - .collect::>>()?; - if then_nullable.contains(&true) { - Ok(true) - } else if let Some(e) = else_expr { - e.nullable(input_schema) - } else { - Ok(false) - } - } - Expr::Cast { expr, .. } => expr.nullable(input_schema), - Expr::ScalarVariable(_) - | Expr::TryCast { .. } - | Expr::ScalarFunction { .. } - | Expr::ScalarUDF { .. } - | Expr::WindowFunction { .. } - | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. 
} => Ok(true), - Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false), - Expr::BinaryExpr { - ref left, - ref right, - .. - } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), - Expr::Wildcard => Err(DataFusionError::Internal( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), - Expr::GetIndexedField { ref expr, key } => { - let data_type = expr.get_type(input_schema)?; - get_indexed_field(&data_type, key).map(|x| x.is_nullable()) - } - } - } - /// Returns the name of this expression based on [crate::logical_plan::DFSchema]. /// /// This represents how a column with this expression is named when no alias is chosen @@ -403,54 +255,6 @@ impl Expr { create_name(self, input_schema) } - /// Returns a [arrow::datatypes::Field] compatible with this expression. - pub fn to_field(&self, input_schema: &DFSchema) -> Result { - match self { - Expr::Column(c) => Ok(DFField::new( - c.relation.as_deref(), - &c.name, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - )), - _ => Ok(DFField::new( - None, - &self.name(input_schema)?, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - )), - } - } - - /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. - /// - /// # Errors - /// - /// This function errors when it is impossible to cast the - /// expression to the target [arrow::datatypes::DataType]. - pub fn cast_to( - self, - cast_to_type: &DataType, - schema: &S, - ) -> Result { - // TODO(kszucs): most of the operations do not validate the type correctness - // like all of the binary expressions below. Perhaps Expr should track the - // type of the expression? - let this_type = self.get_type(schema)?; - if this_type == *cast_to_type { - Ok(self) - } else if can_cast_types(&this_type, cast_to_type) { - Ok(Expr::Cast { - expr: Box::new(self), - data_type: cast_to_type.clone(), - }) - } else { - Err(DataFusionError::Plan(format!( - "Cannot automatically convert {:?} to {:?}", - this_type, cast_to_type - ))) - } - } - /// Return `self == other` pub fn eq(self, other: Expr) -> Expr { binary_expr(self, Operator::Eq, other) diff --git a/datafusion/src/logical_plan/expr_rewriter.rs b/datafusion/src/logical_plan/expr_rewriter.rs index d452dcd4c426..5062d5fce7ad 100644 --- a/datafusion/src/logical_plan/expr_rewriter.rs +++ b/datafusion/src/logical_plan/expr_rewriter.rs @@ -20,6 +20,7 @@ use super::Expr; use crate::logical_plan::plan::Aggregate; use crate::logical_plan::DFSchema; +use crate::logical_plan::ExprSchemable; use crate::logical_plan::LogicalPlan; use datafusion_common::Column; use datafusion_common::Result; diff --git a/datafusion/src/logical_plan/expr_schema.rs b/datafusion/src/logical_plan/expr_schema.rs new file mode 100644 index 000000000000..2e44c72415c9 --- /dev/null +++ b/datafusion/src/logical_plan/expr_schema.rs @@ -0,0 +1,231 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::Expr; +use crate::field_util::get_indexed_field; +use crate::physical_plan::{ + aggregates, expressions::binary_operator_data_type, functions, window_functions, +}; +use arrow::compute::can_cast_types; +use arrow::datatypes::DataType; +use datafusion_common::{DFField, DFSchema, DataFusionError, ExprSchema, Result}; + +/// trait to allow expr to typable with respect to a schema +pub trait ExprSchemable { + /// given a schema, return the type of the expr + fn get_type(&self, schema: &S) -> Result; + + /// given a schema, return the nullability of the expr + fn nullable(&self, input_schema: &S) -> Result; + + /// convert to a field with respect to a schema + fn to_field(&self, input_schema: &DFSchema) -> Result; + + /// cast to a type with respect to a schema + fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result; +} + +impl ExprSchemable for Expr { + /// Returns the [arrow::datatypes::DataType] of the expression + /// based on [ExprSchema] + /// + /// Note: [DFSchema] implements [ExprSchema]. + /// + /// # Errors + /// + /// This function errors when it is not possible to compute its + /// [arrow::datatypes::DataType]. This happens when e.g. the + /// expression refers to a column that does not exist in the + /// schema, or when the expression is incorrectly typed + /// (e.g. `[utf8] + [bool]`). + fn get_type(&self, schema: &S) -> Result { + match self { + Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => { + expr.get_type(schema) + } + Expr::Column(c) => Ok(schema.data_type(c)?.clone()), + Expr::ScalarVariable(_) => Ok(DataType::Utf8), + Expr::Literal(l) => Ok(l.get_datatype()), + Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema), + Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => { + Ok(data_type.clone()) + } + Expr::ScalarUDF { fun, args } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + Ok((fun.return_type)(&data_types)?.as_ref().clone()) + } + Expr::ScalarFunction { fun, args } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + functions::return_type(fun, &data_types) + } + Expr::WindowFunction { fun, args, .. } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + window_functions::return_type(fun, &data_types) + } + Expr::AggregateFunction { fun, args, .. } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + aggregates::return_type(fun, &data_types) + } + Expr::AggregateUDF { fun, args, .. } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + Ok((fun.return_type)(&data_types)?.as_ref().clone()) + } + Expr::Not(_) + | Expr::IsNull(_) + | Expr::Between { .. } + | Expr::InList { .. 
} + | Expr::IsNotNull(_) => Ok(DataType::Boolean), + Expr::BinaryExpr { + ref left, + ref right, + ref op, + } => binary_operator_data_type( + &left.get_type(schema)?, + op, + &right.get_type(schema)?, + ), + Expr::Wildcard => Err(DataFusionError::Internal( + "Wildcard expressions are not valid in a logical query plan".to_owned(), + )), + Expr::GetIndexedField { ref expr, key } => { + let data_type = expr.get_type(schema)?; + + get_indexed_field(&data_type, key).map(|x| x.data_type().clone()) + } + } + } + + /// Returns the nullability of the expression based on [ExprSchema]. + /// + /// Note: [DFSchema] implements [ExprSchema]. + /// + /// # Errors + /// + /// This function errors when it is not possible to compute its + /// nullability. This happens when the expression refers to a + /// column that does not exist in the schema. + fn nullable(&self, input_schema: &S) -> Result { + match self { + Expr::Alias(expr, _) + | Expr::Not(expr) + | Expr::Negative(expr) + | Expr::Sort { expr, .. } + | Expr::Between { expr, .. } + | Expr::InList { expr, .. } => expr.nullable(input_schema), + Expr::Column(c) => input_schema.nullable(c), + Expr::Literal(value) => Ok(value.is_null()), + Expr::Case { + when_then_expr, + else_expr, + .. + } => { + // this expression is nullable if any of the input expressions are nullable + let then_nullable = when_then_expr + .iter() + .map(|(_, t)| t.nullable(input_schema)) + .collect::>>()?; + if then_nullable.contains(&true) { + Ok(true) + } else if let Some(e) = else_expr { + e.nullable(input_schema) + } else { + Ok(false) + } + } + Expr::Cast { expr, .. } => expr.nullable(input_schema), + Expr::ScalarVariable(_) + | Expr::TryCast { .. } + | Expr::ScalarFunction { .. } + | Expr::ScalarUDF { .. } + | Expr::WindowFunction { .. } + | Expr::AggregateFunction { .. } + | Expr::AggregateUDF { .. } => Ok(true), + Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false), + Expr::BinaryExpr { + ref left, + ref right, + .. + } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), + Expr::Wildcard => Err(DataFusionError::Internal( + "Wildcard expressions are not valid in a logical query plan".to_owned(), + )), + Expr::GetIndexedField { ref expr, key } => { + let data_type = expr.get_type(input_schema)?; + get_indexed_field(&data_type, key).map(|x| x.is_nullable()) + } + } + } + + /// Returns a [arrow::datatypes::Field] compatible with this expression. + fn to_field(&self, input_schema: &DFSchema) -> Result { + match self { + Expr::Column(c) => Ok(DFField::new( + c.relation.as_deref(), + &c.name, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + )), + _ => Ok(DFField::new( + None, + &self.name(input_schema)?, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + )), + } + } + + /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. + /// + /// # Errors + /// + /// This function errors when it is impossible to cast the + /// expression to the target [arrow::datatypes::DataType]. + fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result { + // TODO(kszucs): most of the operations do not validate the type correctness + // like all of the binary expressions below. Perhaps Expr should track the + // type of the expression? 
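// A sketch of the trait in use (assuming the `DFSchema` / `DFField`
// constructors exercised by the expr_rewriter tests):
//
//     let schema = DFSchema::new(vec![
//         DFField::new(Some("t"), "a", DataType::Int32, false),
//     ])?;
//     assert_eq!(col("t.a").get_type(&schema)?, DataType::Int32);
//     // a differing target type wraps the expression in `Expr::Cast`
//     let widened = col("t.a").cast_to(&DataType::Int64, &schema)?;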
+ let this_type = self.get_type(schema)?; + if this_type == *cast_to_type { + Ok(self) + } else if can_cast_types(&this_type, cast_to_type) { + Ok(Expr::Cast { + expr: Box::new(self), + data_type: cast_to_type.clone(), + }) + } else { + Err(DataFusionError::Plan(format!( + "Cannot automatically convert {:?} to {:?}", + this_type, cast_to_type + ))) + } + } +} diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 085775a2eb8c..f2ecb0f76278 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -26,6 +26,7 @@ mod dfschema; mod display; mod expr; mod expr_rewriter; +mod expr_schema; mod expr_simplier; mod expr_visitor; mod extension; @@ -54,6 +55,7 @@ pub use expr_rewriter::{ normalize_col, normalize_cols, replace_col, rewrite_sort_cols_by_aggs, unnormalize_col, unnormalize_cols, ExprRewritable, ExprRewriter, RewriteRecursion, }; +pub use expr_schema::ExprSchemable; pub use expr_simplier::{ExprSimplifiable, SimplifyInfo}; pub use expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; pub use extension::UserDefinedLogicalNode; diff --git a/datafusion/src/optimizer/common_subexpr_eliminate.rs b/datafusion/src/optimizer/common_subexpr_eliminate.rs index 5c2219b3d99a..2ed45be25bc1 100644 --- a/datafusion/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/src/optimizer/common_subexpr_eliminate.rs @@ -23,7 +23,7 @@ use crate::logical_plan::plan::{Filter, Projection, Window}; use crate::logical_plan::{ col, plan::{Aggregate, Sort}, - DFField, DFSchema, Expr, ExprRewritable, ExprRewriter, ExprVisitable, + DFField, DFSchema, Expr, ExprRewritable, ExprRewriter, ExprSchemable, ExprVisitable, ExpressionVisitor, LogicalPlan, Recursion, RewriteRecursion, }; use crate::optimizer::optimizer::OptimizerRule; diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index f8f3df44b673..4e9709bd9b5f 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -17,12 +17,9 @@ //! 
Simplify expressions optimizer rule -use arrow::array::new_null_array; -use arrow::datatypes::{DataType, Field, Schema}; -use arrow::record_batch::RecordBatch; - use crate::error::DataFusionError; use crate::execution::context::ExecutionProps; +use crate::logical_plan::ExprSchemable; use crate::logical_plan::{ lit, DFSchema, DFSchemaRef, Expr, ExprRewritable, ExprRewriter, ExprSimplifiable, LogicalPlan, RewriteRecursion, SimplifyInfo, @@ -33,6 +30,9 @@ use crate::physical_plan::functions::Volatility; use crate::physical_plan::planner::create_physical_expr; use crate::scalar::ScalarValue; use crate::{error::Result, logical_plan::Operator}; +use arrow::array::new_null_array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; /// Provides simplification information based on schema and properties struct SimplifyContext<'a, 'b> { diff --git a/datafusion/src/optimizer/single_distinct_to_groupby.rs b/datafusion/src/optimizer/single_distinct_to_groupby.rs index 02a24e214495..2e0bd5ff0549 100644 --- a/datafusion/src/optimizer/single_distinct_to_groupby.rs +++ b/datafusion/src/optimizer/single_distinct_to_groupby.rs @@ -20,6 +20,7 @@ use crate::error::Result; use crate::execution::context::ExecutionProps; use crate::logical_plan::plan::{Aggregate, Projection}; +use crate::logical_plan::ExprSchemable; use crate::logical_plan::{col, columnize_expr, DFSchema, Expr, LogicalPlan}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; diff --git a/datafusion/tests/simplification.rs b/datafusion/tests/simplification.rs index 0ce8e7685b83..fe5f5e254b52 100644 --- a/datafusion/tests/simplification.rs +++ b/datafusion/tests/simplification.rs @@ -18,6 +18,7 @@ //! This program demonstrates the DataFusion expression simplification API. 
use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::logical_plan::ExprSchemable; use datafusion::logical_plan::ExprSimplifiable; use datafusion::{ error::Result, From e8c198b9fac6cd8822b950b9f71898e47965488d Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Wed, 9 Feb 2022 00:16:08 +0300 Subject: [PATCH 46/50] rewrite predicates before pushing to union inputs (#1781) --- datafusion/src/optimizer/filter_push_down.rs | 54 ++++++++++++++++++-- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index ababb52020d7..78911313efaf 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -16,9 +16,9 @@ use crate::datasource::datasource::TableProviderFilterPushDown; use crate::execution::context::ExecutionProps; -use crate::logical_plan::plan::{Aggregate, Filter, Join, Projection}; +use crate::logical_plan::plan::{Aggregate, Filter, Join, Projection, Union}; use crate::logical_plan::{ - and, replace_col, Column, CrossJoin, JoinType, Limit, LogicalPlan, TableScan, + and, col, replace_col, Column, CrossJoin, JoinType, Limit, LogicalPlan, TableScan, }; use crate::logical_plan::{DFSchema, Expr}; use crate::optimizer::optimizer::OptimizerRule; @@ -394,8 +394,29 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { // sort is filter-commutable push_down(&state, plan) } - LogicalPlan::Union(_) => { - // union all is filter-commutable + LogicalPlan::Union(Union { + inputs: _, + schema, + alias: _, + }) => { + // union changing all qualifiers while building logical plan so we need + // to rewrite filters to push unqualified columns to inputs + let projection = schema + .fields() + .iter() + .map(|field| (field.qualified_name(), col(field.name()))) + .collect::>(); + + // rewriting predicate expressions using unqualified names as replacements + if !projection.is_empty() { + for (predicate, columns) in state.filters.iter_mut() { + *predicate = rewrite(predicate, &projection)?; + + columns.clear(); + utils::expr_to_columns(predicate, columns)?; + } + } + push_down(&state, plan) } LogicalPlan::Limit(Limit { input, .. }) => { @@ -574,7 +595,9 @@ fn rewrite(expr: &Expr, projection: &HashMap) -> Result { mod tests { use super::*; use crate::datasource::TableProvider; - use crate::logical_plan::{lit, sum, DFSchema, Expr, LogicalPlanBuilder, Operator}; + use crate::logical_plan::{ + lit, sum, union_with_alias, DFSchema, Expr, LogicalPlanBuilder, Operator, + }; use crate::physical_plan::ExecutionPlan; use crate::test::*; use crate::{logical_plan::col, prelude::JoinType}; @@ -901,6 +924,27 @@ mod tests { Ok(()) } + #[test] + fn union_all_with_alias() -> Result<()> { + let table_scan = test_table_scan()?; + let union = + union_with_alias(table_scan.clone(), table_scan, Some("t".to_string()))?; + + let plan = LogicalPlanBuilder::from(union) + .filter(col("t.a").eq(lit(1i64)))? 
+ .build()?; + + // filter appears below Union without relation qualifier + let expected = "\ + Union\ + \n Filter: #a = Int64(1)\ + \n TableScan: test projection=None\ + \n Filter: #a = Int64(1)\ + \n TableScan: test projection=None"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + /// verifies that filters with the same columns are correctly placed #[test] fn filter_2_breaks_limits() -> Result<()> { From ed9b04995906aba95f21837cae973d931d346d83 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Wed, 9 Feb 2022 08:16:02 +0800 Subject: [PATCH 47/50] move accumulator and columnar value (#1765) --- datafusion-expr/src/accumulator.rs | 44 +++++++++++++++++ datafusion-expr/src/columnar_value.rs | 60 +++++++++++++++++++++++ datafusion-expr/src/lib.rs | 4 ++ datafusion/src/physical_plan/functions.rs | 18 ++----- datafusion/src/physical_plan/mod.rs | 52 +------------------- 5 files changed, 113 insertions(+), 65 deletions(-) create mode 100644 datafusion-expr/src/accumulator.rs create mode 100644 datafusion-expr/src/columnar_value.rs diff --git a/datafusion-expr/src/accumulator.rs b/datafusion-expr/src/accumulator.rs new file mode 100644 index 000000000000..599bd363fb61 --- /dev/null +++ b/datafusion-expr/src/accumulator.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use datafusion_common::{Result, ScalarValue}; +use std::fmt::Debug; + +/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and +/// generically accumulates values. +/// +/// An accumulator knows how to: +/// * update its state from inputs via `update_batch` +/// * convert its internal state to a vector of scalar values +/// * update its state from multiple accumulators' states via `merge_batch` +/// * compute the final value from its internal state via `evaluate` +pub trait Accumulator: Send + Sync + Debug { + /// Returns the state of the accumulator at the end of the accumulation. + // in the case of an average on which we track `sum` and `n`, this function should return a vector + // of two values, sum and n. + fn state(&self) -> Result>; + + /// updates the accumulator's state from a vector of arrays. + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; + + /// updates the accumulator's state from a vector of states. + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>; + + /// returns its value based on its current state. 
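+    /// For an average accumulator, for example, this would be `sum / n`
+    /// returned as a single `ScalarValue` (illustrative).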
+ fn evaluate(&self) -> Result; +} diff --git a/datafusion-expr/src/columnar_value.rs b/datafusion-expr/src/columnar_value.rs new file mode 100644 index 000000000000..5e6959d751f8 --- /dev/null +++ b/datafusion-expr/src/columnar_value.rs @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::array::NullArray; +use arrow::datatypes::DataType; +use arrow::record_batch::RecordBatch; +use datafusion_common::ScalarValue; +use std::sync::Arc; + +/// Represents the result from an expression +#[derive(Clone)] +pub enum ColumnarValue { + /// Array of values + Array(ArrayRef), + /// A single value + Scalar(ScalarValue), +} + +impl ColumnarValue { + pub fn data_type(&self) -> DataType { + match self { + ColumnarValue::Array(array_value) => array_value.data_type().clone(), + ColumnarValue::Scalar(scalar_value) => scalar_value.get_datatype(), + } + } + + /// Convert a columnar value into an ArrayRef + pub fn into_array(self, num_rows: usize) -> ArrayRef { + match self { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows), + } + } +} + +/// null columnar values are implemented as a null array in order to pass batch +/// num_rows +pub type NullColumnarValue = ColumnarValue; + +impl From<&RecordBatch> for NullColumnarValue { + fn from(batch: &RecordBatch) -> Self { + let num_rows = batch.num_rows(); + ColumnarValue::Array(Arc::new(NullArray::new(num_rows))) + } +} diff --git a/datafusion-expr/src/lib.rs b/datafusion-expr/src/lib.rs index 7dcddc39c4dc..2491fcf73ca9 100644 --- a/datafusion-expr/src/lib.rs +++ b/datafusion-expr/src/lib.rs @@ -15,15 +15,19 @@ // specific language governing permissions and limitations // under the License. 
+mod accumulator; mod aggregate_function; mod built_in_function; +mod columnar_value; mod operator; mod signature; mod window_frame; mod window_function; +pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; +pub use columnar_value::{ColumnarValue, NullColumnarValue}; pub use operator::Operator; pub use signature::{Signature, TypeSignature, Volatility}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 9582eecce33e..bf0aee9e6aa0 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -46,18 +46,17 @@ use crate::{ scalar::ScalarValue, }; use arrow::{ - array::{ArrayRef, NullArray}, + array::ArrayRef, compute::kernels::length::{bit_length, length}, datatypes::TimeUnit, datatypes::{DataType, Field, Int32Type, Int64Type, Schema}, record_batch::RecordBatch, }; +pub use datafusion_expr::NullColumnarValue; +pub use datafusion_expr::{BuiltinScalarFunction, Signature, TypeSignature, Volatility}; use fmt::{Debug, Formatter}; -use std::convert::From; use std::{any::Any, fmt, sync::Arc}; -pub use datafusion_expr::{BuiltinScalarFunction, Signature, TypeSignature, Volatility}; - /// Scalar function /// /// The Fn param is the wrapped function but be aware that the function will @@ -1206,17 +1205,6 @@ impl fmt::Display for ScalarFunctionExpr { } } -/// null columnar values are implemented as a null array in order to pass batch -/// num_rows -type NullColumnarValue = ColumnarValue; - -impl From<&RecordBatch> for NullColumnarValue { - fn from(batch: &RecordBatch) -> Self { - let num_rows = batch.num_rows(); - ColumnarValue::Array(Arc::new(NullArray::new(num_rows))) - } -} - impl PhysicalExpr for ScalarFunctionExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index ac70f2f90ae2..38a19db1347a 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -35,6 +35,8 @@ use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use async_trait::async_trait; +pub use datafusion_expr::Accumulator; +pub use datafusion_expr::ColumnarValue; pub use display::DisplayFormatType; use futures::stream::Stream; use std::fmt; @@ -419,32 +421,6 @@ pub enum Distribution { HashPartitioned(Vec>), } -/// Represents the result from an expression -#[derive(Clone)] -pub enum ColumnarValue { - /// Array of values - Array(ArrayRef), - /// A single value - Scalar(ScalarValue), -} - -impl ColumnarValue { - fn data_type(&self) -> DataType { - match self { - ColumnarValue::Array(array_value) => array_value.data_type().clone(), - ColumnarValue::Scalar(scalar_value) => scalar_value.get_datatype(), - } - } - - /// Convert a columnar value into an ArrayRef - pub fn into_array(self, num_rows: usize) -> ArrayRef { - match self { - ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows), - } - } -} - /// Expression that can be evaluated against a RecordBatch /// A Physical expression knows its type, nullability and how to evaluate itself. 
pub trait PhysicalExpr: Send + Sync + Display + Debug { @@ -578,30 +554,6 @@ pub trait WindowExpr: Send + Sync + Debug { } } -/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and -/// generically accumulates values. -/// -/// An accumulator knows how to: -/// * update its state from inputs via `update_batch` -/// * convert its internal state to a vector of scalar values -/// * update its state from multiple accumulators' states via `merge_batch` -/// * compute the final value from its internal state via `evaluate` -pub trait Accumulator: Send + Sync + Debug { - /// Returns the state of the accumulator at the end of the accumulation. - // in the case of an average on which we track `sum` and `n`, this function should return a vector - // of two values, sum and n. - fn state(&self) -> Result>; - - /// updates the accumulator's state from a vector of arrays. - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; - - /// updates the accumulator's state from a vector of states. - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>; - - /// returns its value based on its current state. - fn evaluate(&self) -> Result; -} - /// Applies an optional projection to a [`SchemaRef`], returning the /// projected schema /// From 014e5e90d623befd9f3e179b02864a6b8bcab568 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Wed, 9 Feb 2022 12:50:38 +0800 Subject: [PATCH 48/50] move accumulator and columnar value (#1762) --- datafusion-expr/Cargo.toml | 1 + datafusion-expr/src/expr.rs | 698 ++++++++++++++++++ datafusion-expr/src/expr_fn.rs | 32 + datafusion-expr/src/function.rs | 46 ++ datafusion-expr/src/lib.rs | 15 + datafusion-expr/src/literal.rs | 138 ++++ datafusion-expr/src/operator.rs | 43 ++ datafusion-expr/src/udaf.rs | 92 +++ datafusion-expr/src/udf.rs | 93 +++ datafusion/src/execution/dataframe_impl.rs | 4 +- datafusion/src/logical_plan/expr.rs | 811 +-------------------- datafusion/src/logical_plan/mod.rs | 3 +- datafusion/src/logical_plan/operators.rs | 42 -- datafusion/src/physical_plan/aggregates.rs | 11 +- datafusion/src/physical_plan/udaf.rs | 83 +-- datafusion/src/physical_plan/udf.rs | 85 +-- datafusion/src/sql/planner.rs | 3 +- 17 files changed, 1187 insertions(+), 1013 deletions(-) create mode 100644 datafusion-expr/src/expr.rs create mode 100644 datafusion-expr/src/expr_fn.rs create mode 100644 datafusion-expr/src/function.rs create mode 100644 datafusion-expr/src/literal.rs create mode 100644 datafusion-expr/src/udaf.rs create mode 100644 datafusion-expr/src/udf.rs diff --git a/datafusion-expr/Cargo.toml b/datafusion-expr/Cargo.toml index 73a5fcd36152..a6dad528b6b7 100644 --- a/datafusion-expr/Cargo.toml +++ b/datafusion-expr/Cargo.toml @@ -38,3 +38,4 @@ path = "src/lib.rs" datafusion-common = { path = "../datafusion-common", version = "6.0.0" } arrow = { version = "8.0.0", features = ["prettyprint"] } sqlparser = "0.13" +ahash = { version = "0.7", default-features = false } diff --git a/datafusion-expr/src/expr.rs b/datafusion-expr/src/expr.rs new file mode 100644 index 000000000000..f26f1dfa9746 --- /dev/null +++ b/datafusion-expr/src/expr.rs @@ -0,0 +1,698 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregate_function; +use crate::built_in_function; +use crate::expr_fn::binary_expr; +use crate::window_frame; +use crate::window_function; +use crate::AggregateUDF; +use crate::Operator; +use crate::ScalarUDF; +use arrow::datatypes::DataType; +use datafusion_common::Column; +use datafusion_common::{DFSchema, Result}; +use datafusion_common::{DataFusionError, ScalarValue}; +use std::fmt; +use std::hash::{BuildHasher, Hash, Hasher}; +use std::ops::Not; +use std::sync::Arc; + +/// `Expr` is a central struct of DataFusion's query API, and +/// represent logical expressions such as `A + 1`, or `CAST(c1 AS +/// int)`. +/// +/// An `Expr` can compute its [DataType](arrow::datatypes::DataType) +/// and nullability, and has functions for building up complex +/// expressions. +/// +/// # Examples +/// +/// ## Create an expression `c1` referring to column named "c1" +/// ``` +/// # use datafusion_common::Column; +/// # use datafusion_expr::{lit, col, Expr}; +/// let expr = col("c1"); +/// assert_eq!(expr, Expr::Column(Column::from_name("c1"))); +/// ``` +/// +/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together +/// ``` +/// # use datafusion_expr::{lit, col, Operator, Expr}; +/// let expr = col("c1") + col("c2"); +/// +/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); +/// if let Expr::BinaryExpr { left, right, op } = expr { +/// assert_eq!(*left, col("c1")); +/// assert_eq!(*right, col("c2")); +/// assert_eq!(op, Operator::Plus); +/// } +/// ``` +/// +/// ## Create expression `c1 = 42` to compare the value in column "c1" to the literal value `42` +/// ``` +/// # use datafusion_common::ScalarValue; +/// # use datafusion_expr::{lit, col, Operator, Expr}; +/// let expr = col("c1").eq(lit(42_i32)); +/// +/// assert!(matches!(expr, Expr::BinaryExpr { .. } )); +/// if let Expr::BinaryExpr { left, right, op } = expr { +/// assert_eq!(*left, col("c1")); +/// let scalar = ScalarValue::Int32(Some(42)); +/// assert_eq!(*right, Expr::Literal(scalar)); +/// assert_eq!(op, Operator::Eq); +/// } +/// ``` +#[derive(Clone, PartialEq, Hash)] +pub enum Expr { + /// An expression with a specific name. + Alias(Box, String), + /// A named reference to a qualified filed in a schema. + Column(Column), + /// A named reference to a variable in a registry. + ScalarVariable(Vec), + /// A constant value. + Literal(ScalarValue), + /// A binary expression such as "age > 21" + BinaryExpr { + /// Left-hand side of the expression + left: Box, + /// The comparison operator + op: Operator, + /// Right-hand side of the expression + right: Box, + }, + /// Negation of an expression. The expression's type must be a boolean to make sense. + Not(Box), + /// Whether an expression is not Null. This expression is never null. + IsNotNull(Box), + /// Whether an expression is Null. This expression is never null. 
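+    /// For example, `col("a").is_null()` is true only for rows where column `a`
+    /// is NULL (the column name here is illustrative).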
+ IsNull(Box), + /// arithmetic negation of an expression, the operand must be of a signed numeric data type + Negative(Box), + /// Returns the field of a [`ListArray`] or [`StructArray`] by key + GetIndexedField { + /// the expression to take the field from + expr: Box, + /// The name of the field to take + key: ScalarValue, + }, + /// Whether an expression is between a given range. + Between { + /// The value to compare + expr: Box, + /// Whether the expression is negated + negated: bool, + /// The low end of the range + low: Box, + /// The high end of the range + high: Box, + }, + /// The CASE expression is similar to a series of nested if/else and there are two forms that + /// can be used. The first form consists of a series of boolean "when" expressions with + /// corresponding "then" expressions, and an optional "else" expression. + /// + /// CASE WHEN condition THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + /// + /// The second form uses a base expression and then a series of "when" clauses that match on a + /// literal value. + /// + /// CASE expression + /// WHEN value THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + Case { + /// Optional base expression that can be compared to literal values in the "when" expressions + expr: Option>, + /// One or more when/then expressions + when_then_expr: Vec<(Box, Box)>, + /// Optional "else" expression + else_expr: Option>, + }, + /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. + /// This expression is guaranteed to have a fixed type. + Cast { + /// The expression being cast + expr: Box, + /// The `DataType` the expression will yield + data_type: DataType, + }, + /// Casts the expression to a given type and will return a null value if the expression cannot be cast. + /// This expression is guaranteed to have a fixed type. + TryCast { + /// The expression being cast + expr: Box, + /// The `DataType` the expression will yield + data_type: DataType, + }, + /// A sort expression, that can be used to sort values. + Sort { + /// The expression to sort on + expr: Box, + /// The direction of the sort + asc: bool, + /// Whether to put Nulls before all other data values + nulls_first: bool, + }, + /// Represents the call of a built-in scalar function with a set of arguments. + ScalarFunction { + /// The function + fun: built_in_function::BuiltinScalarFunction, + /// List of expressions to feed to the functions as arguments + args: Vec, + }, + /// Represents the call of a user-defined scalar function with arguments. + ScalarUDF { + /// The function + fun: Arc, + /// List of expressions to feed to the functions as arguments + args: Vec, + }, + /// Represents the call of an aggregate built-in function with arguments. + AggregateFunction { + /// Name of the function + fun: aggregate_function::AggregateFunction, + /// List of expressions to feed to the functions as arguments + args: Vec, + /// Whether this is a DISTINCT aggregation or not + distinct: bool, + }, + /// Represents the call of a window function with arguments. 
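+    /// For example, `ROW_NUMBER() OVER (PARTITION BY c1 ORDER BY c2)` is planned
+    /// as this variant (illustrative SQL).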
+ WindowFunction { + /// Name of the function + fun: window_function::WindowFunction, + /// List of expressions to feed to the functions as arguments + args: Vec, + /// List of partition by expressions + partition_by: Vec, + /// List of order by expressions + order_by: Vec, + /// Window frame + window_frame: Option, + }, + /// aggregate function + AggregateUDF { + /// The function + fun: Arc, + /// List of expressions to feed to the functions as arguments + args: Vec, + }, + /// Returns whether the list contains the expr value. + InList { + /// The expression to compare + expr: Box, + /// A list of values to compare against + list: Vec, + /// Whether the expression is negated + negated: bool, + }, + /// Represents a reference to all fields in a schema. + Wildcard, +} + +/// Fixed seed for the hashing so that Ords are consistent across runs +const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0); + +impl PartialOrd for Expr { + fn partial_cmp(&self, other: &Self) -> Option { + let mut hasher = SEED.build_hasher(); + self.hash(&mut hasher); + let s = hasher.finish(); + + let mut hasher = SEED.build_hasher(); + other.hash(&mut hasher); + let o = hasher.finish(); + + Some(s.cmp(&o)) + } +} + +impl Expr { + /// Returns the name of this expression based on [crate::logical_plan::DFSchema]. + /// + /// This represents how a column with this expression is named when no alias is chosen + pub fn name(&self, input_schema: &DFSchema) -> Result { + create_name(self, input_schema) + } + + /// Return `self == other` + pub fn eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::Eq, other) + } + + /// Return `self != other` + pub fn not_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::NotEq, other) + } + + /// Return `self > other` + pub fn gt(self, other: Expr) -> Expr { + binary_expr(self, Operator::Gt, other) + } + + /// Return `self >= other` + pub fn gt_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::GtEq, other) + } + + /// Return `self < other` + pub fn lt(self, other: Expr) -> Expr { + binary_expr(self, Operator::Lt, other) + } + + /// Return `self <= other` + pub fn lt_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::LtEq, other) + } + + /// Return `self && other` + pub fn and(self, other: Expr) -> Expr { + binary_expr(self, Operator::And, other) + } + + /// Return `self || other` + pub fn or(self, other: Expr) -> Expr { + binary_expr(self, Operator::Or, other) + } + + /// Return `!self` + #[allow(clippy::should_implement_trait)] + pub fn not(self) -> Expr { + !self + } + + /// Calculate the modulus of two expressions. 
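+    /// For example, `col("a").modulus(lit(3))` builds `#a % Int32(3)`
+    /// (illustrative names).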
+ /// Return `self % other` + pub fn modulus(self, other: Expr) -> Expr { + binary_expr(self, Operator::Modulo, other) + } + + /// Return `self LIKE other` + pub fn like(self, other: Expr) -> Expr { + binary_expr(self, Operator::Like, other) + } + + /// Return `self NOT LIKE other` + pub fn not_like(self, other: Expr) -> Expr { + binary_expr(self, Operator::NotLike, other) + } + + /// Return `self AS name` alias expression + pub fn alias(self, name: &str) -> Expr { + Expr::Alias(Box::new(self), name.to_owned()) + } + + /// Return `self IN ` if `negated` is false, otherwise + /// return `self NOT IN `.a + pub fn in_list(self, list: Vec, negated: bool) -> Expr { + Expr::InList { + expr: Box::new(self), + list, + negated, + } + } + + /// Return `IsNull(Box(self)) + #[allow(clippy::wrong_self_convention)] + pub fn is_null(self) -> Expr { + Expr::IsNull(Box::new(self)) + } + + /// Return `IsNotNull(Box(self)) + #[allow(clippy::wrong_self_convention)] + pub fn is_not_null(self) -> Expr { + Expr::IsNotNull(Box::new(self)) + } + + /// Create a sort expression from an existing expression. + /// + /// ``` + /// # use datafusion_expr::col; + /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST + /// ``` + pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { + Expr::Sort { + expr: Box::new(self), + asc, + nulls_first, + } + } +} + +impl Not for Expr { + type Output = Self; + + fn not(self) -> Self::Output { + Expr::Not(Box::new(self)) + } +} + +impl std::fmt::Display for Expr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Expr::BinaryExpr { + ref left, + ref right, + ref op, + } => write!(f, "{} {} {}", left, op, right), + Expr::AggregateFunction { + /// Name of the function + ref fun, + /// List of expressions to feed to the functions as arguments + ref args, + /// Whether this is a DISTINCT aggregation or not + ref distinct, + } => fmt_function(f, &fun.to_string(), *distinct, args, true), + Expr::ScalarFunction { + /// Name of the function + ref fun, + /// List of expressions to feed to the functions as arguments + ref args, + } => fmt_function(f, &fun.to_string(), false, args, true), + _ => write!(f, "{:?}", self), + } + } +} + +impl fmt::Debug for Expr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias), + Expr::Column(c) => write!(f, "{}", c), + Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")), + Expr::Literal(v) => write!(f, "{:?}", v), + Expr::Case { + expr, + when_then_expr, + else_expr, + .. 
+ } => { + write!(f, "CASE ")?; + if let Some(e) = expr { + write!(f, "{:?} ", e)?; + } + for (w, t) in when_then_expr { + write!(f, "WHEN {:?} THEN {:?} ", w, t)?; + } + if let Some(e) = else_expr { + write!(f, "ELSE {:?} ", e)?; + } + write!(f, "END") + } + Expr::Cast { expr, data_type } => { + write!(f, "CAST({:?} AS {:?})", expr, data_type) + } + Expr::TryCast { expr, data_type } => { + write!(f, "TRY_CAST({:?} AS {:?})", expr, data_type) + } + Expr::Not(expr) => write!(f, "NOT {:?}", expr), + Expr::Negative(expr) => write!(f, "(- {:?})", expr), + Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr), + Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr), + Expr::BinaryExpr { left, op, right } => { + write!(f, "{:?} {} {:?}", left, op, right) + } + Expr::Sort { + expr, + asc, + nulls_first, + } => { + if *asc { + write!(f, "{:?} ASC", expr)?; + } else { + write!(f, "{:?} DESC", expr)?; + } + if *nulls_first { + write!(f, " NULLS FIRST") + } else { + write!(f, " NULLS LAST") + } + } + Expr::ScalarFunction { fun, args, .. } => { + fmt_function(f, &fun.to_string(), false, args, false) + } + Expr::ScalarUDF { fun, ref args, .. } => { + fmt_function(f, &fun.name, false, args, false) + } + Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + } => { + fmt_function(f, &fun.to_string(), false, args, false)?; + if !partition_by.is_empty() { + write!(f, " PARTITION BY {:?}", partition_by)?; + } + if !order_by.is_empty() { + write!(f, " ORDER BY {:?}", order_by)?; + } + if let Some(window_frame) = window_frame { + write!( + f, + " {} BETWEEN {} AND {}", + window_frame.units, + window_frame.start_bound, + window_frame.end_bound + )?; + } + Ok(()) + } + Expr::AggregateFunction { + fun, + distinct, + ref args, + .. + } => fmt_function(f, &fun.to_string(), *distinct, args, true), + Expr::AggregateUDF { fun, ref args, .. } => { + fmt_function(f, &fun.name, false, args, false) + } + Expr::Between { + expr, + negated, + low, + high, + } => { + if *negated { + write!(f, "{:?} NOT BETWEEN {:?} AND {:?}", expr, low, high) + } else { + write!(f, "{:?} BETWEEN {:?} AND {:?}", expr, low, high) + } + } + Expr::InList { + expr, + list, + negated, + } => { + if *negated { + write!(f, "{:?} NOT IN ({:?})", expr, list) + } else { + write!(f, "{:?} IN ({:?})", expr, list) + } + } + Expr::Wildcard => write!(f, "*"), + Expr::GetIndexedField { ref expr, key } => { + write!(f, "({:?})[{}]", expr, key) + } + } + } +} + +fn fmt_function( + f: &mut fmt::Formatter, + fun: &str, + distinct: bool, + args: &[Expr], + display: bool, +) -> fmt::Result { + let args: Vec = match display { + true => args.iter().map(|arg| format!("{}", arg)).collect(), + false => args.iter().map(|arg| format!("{:?}", arg)).collect(), + }; + + // let args: Vec = args.iter().map(|arg| format!("{:?}", arg)).collect(); + let distinct_str = match distinct { + true => "DISTINCT ", + false => "", + }; + write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) +} + +fn create_function_name( + fun: &str, + distinct: bool, + args: &[Expr], + input_schema: &DFSchema, +) -> Result { + let names: Vec = args + .iter() + .map(|e| create_name(e, input_schema)) + .collect::>()?; + let distinct_str = match distinct { + true => "DISTINCT ", + false => "", + }; + Ok(format!("{}({}{})", fun, distinct_str, names.join(","))) +} + +/// Returns a readable name of an expression based on the input schema. +/// This function recursively transverses the expression for names such as "CAST(a > 2)". 
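+/// For example, `col("a") + lit(1)` is named `a + Int32(1)` (a sketch; the exact
+/// rendering depends on the literal's type).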
+fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { + match e { + Expr::Alias(_, name) => Ok(name.clone()), + Expr::Column(c) => Ok(c.flat_name()), + Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")), + Expr::Literal(value) => Ok(format!("{:?}", value)), + Expr::BinaryExpr { left, op, right } => { + let left = create_name(left, input_schema)?; + let right = create_name(right, input_schema)?; + Ok(format!("{} {} {}", left, op, right)) + } + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let mut name = "CASE ".to_string(); + if let Some(e) = expr { + let e = create_name(e, input_schema)?; + name += &format!("{} ", e); + } + for (w, t) in when_then_expr { + let when = create_name(w, input_schema)?; + let then = create_name(t, input_schema)?; + name += &format!("WHEN {} THEN {} ", when, then); + } + if let Some(e) = else_expr { + let e = create_name(e, input_schema)?; + name += &format!("ELSE {} ", e); + } + name += "END"; + Ok(name) + } + Expr::Cast { expr, data_type } => { + let expr = create_name(expr, input_schema)?; + Ok(format!("CAST({} AS {:?})", expr, data_type)) + } + Expr::TryCast { expr, data_type } => { + let expr = create_name(expr, input_schema)?; + Ok(format!("TRY_CAST({} AS {:?})", expr, data_type)) + } + Expr::Not(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("NOT {}", expr)) + } + Expr::Negative(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("(- {})", expr)) + } + Expr::IsNull(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("{} IS NULL", expr)) + } + Expr::IsNotNull(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("{} IS NOT NULL", expr)) + } + Expr::GetIndexedField { expr, key } => { + let expr = create_name(expr, input_schema)?; + Ok(format!("{}[{}]", expr, key)) + } + Expr::ScalarFunction { fun, args, .. } => { + create_function_name(&fun.to_string(), false, args, input_schema) + } + Expr::ScalarUDF { fun, args, .. } => { + create_function_name(&fun.name, false, args, input_schema) + } + Expr::WindowFunction { + fun, + args, + window_frame, + partition_by, + order_by, + } => { + let mut parts: Vec = vec![create_function_name( + &fun.to_string(), + false, + args, + input_schema, + )?]; + if !partition_by.is_empty() { + parts.push(format!("PARTITION BY {:?}", partition_by)); + } + if !order_by.is_empty() { + parts.push(format!("ORDER BY {:?}", order_by)); + } + if let Some(window_frame) = window_frame { + parts.push(format!("{}", window_frame)); + } + Ok(parts.join(" ")) + } + Expr::AggregateFunction { + fun, + distinct, + args, + .. 
+ } => create_function_name(&fun.to_string(), *distinct, args, input_schema), + Expr::AggregateUDF { fun, args } => { + let mut names = Vec::with_capacity(args.len()); + for e in args { + names.push(create_name(e, input_schema)?); + } + Ok(format!("{}({})", fun.name, names.join(","))) + } + Expr::InList { + expr, + list, + negated, + } => { + let expr = create_name(expr, input_schema)?; + let list = list.iter().map(|expr| create_name(expr, input_schema)); + if *negated { + Ok(format!("{} NOT IN ({:?})", expr, list)) + } else { + Ok(format!("{} IN ({:?})", expr, list)) + } + } + Expr::Between { + expr, + negated, + low, + high, + } => { + let expr = create_name(expr, input_schema)?; + let low = create_name(low, input_schema)?; + let high = create_name(high, input_schema)?; + if *negated { + Ok(format!("{} NOT BETWEEN {} AND {}", expr, low, high)) + } else { + Ok(format!("{} BETWEEN {} AND {}", expr, low, high)) + } + } + Expr::Sort { .. } => Err(DataFusionError::Internal( + "Create name does not support sort expression".to_string(), + )), + Expr::Wildcard => Err(DataFusionError::Internal( + "Create name does not support wildcard".to_string(), + )), + } +} diff --git a/datafusion-expr/src/expr_fn.rs b/datafusion-expr/src/expr_fn.rs new file mode 100644 index 000000000000..469a82d0ff24 --- /dev/null +++ b/datafusion-expr/src/expr_fn.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{Expr, Operator}; + +/// Create a column expression based on a qualified or unqualified column name +pub fn col(ident: &str) -> Expr { + Expr::Column(ident.into()) +} + +/// return a new expression l r +pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr { + Expr::BinaryExpr { + left: Box::new(l), + op, + right: Box::new(r), + } +} diff --git a/datafusion-expr/src/function.rs b/datafusion-expr/src/function.rs new file mode 100644 index 000000000000..2bacd6ae6227 --- /dev/null +++ b/datafusion-expr/src/function.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::Accumulator; +use crate::ColumnarValue; +use arrow::datatypes::DataType; +use datafusion_common::Result; +use std::sync::Arc; + +/// Scalar function +/// +/// The Fn param is the wrapped function but be aware that the function will +/// be passed with the slice / vec of columnar values (either scalar or array) +/// with the exception of zero param function, where a singular element vec +/// will be passed. In that case the single element is a null array to indicate +/// the batch's row count (so that the generative zero-argument function can know +/// the result array size). +pub type ScalarFunctionImplementation = + Arc Result + Send + Sync>; + +/// A function's return type +pub type ReturnTypeFunction = + Arc Result> + Send + Sync>; + +/// the implementation of an aggregate function +pub type AccumulatorFunctionImplementation = + Arc Result> + Send + Sync>; + +/// This signature corresponds to which types an aggregator serializes +/// its state, given its return datatype. +pub type StateTypeFunction = + Arc Result>> + Send + Sync>; diff --git a/datafusion-expr/src/lib.rs b/datafusion-expr/src/lib.rs index 2491fcf73ca9..709fa634d52d 100644 --- a/datafusion-expr/src/lib.rs +++ b/datafusion-expr/src/lib.rs @@ -19,8 +19,14 @@ mod accumulator; mod aggregate_function; mod built_in_function; mod columnar_value; +pub mod expr; +pub mod expr_fn; +mod function; +mod literal; mod operator; mod signature; +mod udaf; +mod udf; mod window_frame; mod window_function; @@ -28,7 +34,16 @@ pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; pub use columnar_value::{ColumnarValue, NullColumnarValue}; +pub use expr::Expr; +pub use expr_fn::col; +pub use function::{ + AccumulatorFunctionImplementation, ReturnTypeFunction, ScalarFunctionImplementation, + StateTypeFunction, +}; +pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; pub use operator::Operator; pub use signature::{Signature, TypeSignature, Volatility}; +pub use udaf::AggregateUDF; +pub use udf::ScalarUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion-expr/src/literal.rs b/datafusion-expr/src/literal.rs new file mode 100644 index 000000000000..02c75af69573 --- /dev/null +++ b/datafusion-expr/src/literal.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::Expr; +use datafusion_common::ScalarValue; + +/// Create a literal expression +pub fn lit(n: T) -> Expr { + n.lit() +} + +/// Create a literal timestamp expression +pub fn lit_timestamp_nano(n: T) -> Expr { + n.lit_timestamp_nano() +} + +/// Trait for converting a type to a [`Literal`] literal expression. +pub trait Literal { + /// convert the value to a Literal expression + fn lit(&self) -> Expr; +} + +/// Trait for converting a type to a literal timestamp +pub trait TimestampLiteral { + fn lit_timestamp_nano(&self) -> Expr; +} + +impl Literal for &str { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + } +} + +impl Literal for String { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) + } +} + +impl Literal for Vec { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + } +} + +impl Literal for &[u8] { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + } +} + +impl Literal for ScalarValue { + fn lit(&self) -> Expr { + Expr::Literal(self.clone()) + } +} + +macro_rules! make_literal { + ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { + #[doc = $DOC] + impl Literal for $TYPE { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) + } + } + }; +} + +macro_rules! make_timestamp_literal { + ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { + #[doc = $DOC] + impl TimestampLiteral for $TYPE { + fn lit_timestamp_nano(&self) -> Expr { + Expr::Literal(ScalarValue::TimestampNanosecond( + Some((self.clone()).into()), + None, + )) + } + } + }; +} + +make_literal!(bool, Boolean, "literal expression containing a bool"); +make_literal!(f32, Float32, "literal expression containing an f32"); +make_literal!(f64, Float64, "literal expression containing an f64"); +make_literal!(i8, Int8, "literal expression containing an i8"); +make_literal!(i16, Int16, "literal expression containing an i16"); +make_literal!(i32, Int32, "literal expression containing an i32"); +make_literal!(i64, Int64, "literal expression containing an i64"); +make_literal!(u8, UInt8, "literal expression containing a u8"); +make_literal!(u16, UInt16, "literal expression containing a u16"); +make_literal!(u32, UInt32, "literal expression containing a u32"); +make_literal!(u64, UInt64, "literal expression containing a u64"); + +make_timestamp_literal!(i8, Int8, "literal expression containing an i8"); +make_timestamp_literal!(i16, Int16, "literal expression containing an i16"); +make_timestamp_literal!(i32, Int32, "literal expression containing an i32"); +make_timestamp_literal!(i64, Int64, "literal expression containing an i64"); +make_timestamp_literal!(u8, UInt8, "literal expression containing a u8"); +make_timestamp_literal!(u16, UInt16, "literal expression containing a u16"); +make_timestamp_literal!(u32, UInt32, "literal expression containing a u32"); + +#[cfg(test)] +mod test { + use super::*; + use crate::expr_fn::col; + use datafusion_common::ScalarValue; + + #[test] + fn test_lit_timestamp_nano() { + let expr = col("time").eq(lit_timestamp_nano(10)); // 10 is an implicit i32 + let expected = + col("time").eq(lit(ScalarValue::TimestampNanosecond(Some(10), None))); + assert_eq!(expr, expected); + + let i: i64 = 10; + let expr = col("time").eq(lit_timestamp_nano(i)); + assert_eq!(expr, expected); + + let i: u32 = 10; + let expr = col("time").eq(lit_timestamp_nano(i)); + assert_eq!(expr, expected); + } +} diff --git a/datafusion-expr/src/operator.rs 
b/datafusion-expr/src/operator.rs index e6b7e35a0a5e..a1cad76cdd97 100644 --- a/datafusion-expr/src/operator.rs +++ b/datafusion-expr/src/operator.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. +use crate::expr_fn::binary_expr; +use crate::Expr; use std::fmt; +use std::ops; /// Operators applied to expressions #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -95,3 +98,43 @@ impl fmt::Display for Operator { write!(f, "{}", display) } } + +impl ops::Add for Expr { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + binary_expr(self, Operator::Plus, rhs) + } +} + +impl ops::Sub for Expr { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + binary_expr(self, Operator::Minus, rhs) + } +} + +impl ops::Mul for Expr { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + binary_expr(self, Operator::Multiply, rhs) + } +} + +impl ops::Div for Expr { + type Output = Self; + + fn div(self, rhs: Self) -> Self { + binary_expr(self, Operator::Divide, rhs) + } +} + +impl ops::Rem for Expr { + type Output = Self; + + fn rem(self, rhs: Self) -> Self { + binary_expr(self, Operator::Modulo, rhs) + } +} diff --git a/datafusion-expr/src/udaf.rs b/datafusion-expr/src/udaf.rs new file mode 100644 index 000000000000..a39d58b622f3 --- /dev/null +++ b/datafusion-expr/src/udaf.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains functions and structs supporting user-defined aggregate functions. + +use crate::Expr; +use crate::{ + AccumulatorFunctionImplementation, ReturnTypeFunction, Signature, StateTypeFunction, +}; +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; + +/// Logical representation of a user-defined aggregate function (UDAF) +/// A UDAF is different from a UDF in that it is stateful across batches. 
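+/// For example, a user-defined average keeps `sum` and `count` in its
+/// accumulator state from batch to batch (illustrative).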
+#[derive(Clone)] +pub struct AggregateUDF { + /// name + pub name: String, + /// signature + pub signature: Signature, + /// Return type + pub return_type: ReturnTypeFunction, + /// actual implementation + pub accumulator: AccumulatorFunctionImplementation, + /// the accumulator's state's description as a function of the return type + pub state_type: StateTypeFunction, +} + +impl Debug for AggregateUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl PartialEq for AggregateUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.signature == other.signature + } +} + +impl std::hash::Hash for AggregateUDF { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.signature.hash(state); + } +} + +impl AggregateUDF { + /// Create a new AggregateUDF + pub fn new( + name: &str, + signature: &Signature, + return_type: &ReturnTypeFunction, + accumulator: &AccumulatorFunctionImplementation, + state_type: &StateTypeFunction, + ) -> Self { + Self { + name: name.to_owned(), + signature: signature.clone(), + return_type: return_type.clone(), + accumulator: accumulator.clone(), + state_type: state_type.clone(), + } + } + + /// creates a logical expression with a call of the UDAF + /// This utility allows using the UDAF without requiring access to the registry. + pub fn call(&self, args: Vec) -> Expr { + Expr::AggregateUDF { + fun: Arc::new(self.clone()), + args, + } + } +} diff --git a/datafusion-expr/src/udf.rs b/datafusion-expr/src/udf.rs new file mode 100644 index 000000000000..79a17a4a2b4b --- /dev/null +++ b/datafusion-expr/src/udf.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! UDF support + +use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use std::fmt; +use std::fmt::Debug; +use std::fmt::Formatter; +use std::sync::Arc; + +/// Logical representation of a UDF. +#[derive(Clone)] +pub struct ScalarUDF { + /// name + pub name: String, + /// signature + pub signature: Signature, + /// Return type + pub return_type: ReturnTypeFunction, + /// actual implementation + /// + /// The fn param is the wrapped function but be aware that the function will + /// be passed with the slice / vec of columnar values (either scalar or array) + /// with the exception of zero param function, where a singular element vec + /// will be passed. In that case the single element is a null array to indicate + /// the batch's row count (so that the generative zero-argument function can know + /// the result array size). 
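+    ///
+    /// For example, an `add_one` style UDF receives a single `ColumnarValue`
+    /// argument and returns an array of the same length (illustrative).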
+ pub fun: ScalarFunctionImplementation, +} + +impl Debug for ScalarUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl PartialEq for ScalarUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.signature == other.signature + } +} + +impl std::hash::Hash for ScalarUDF { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.signature.hash(state); + } +} + +impl ScalarUDF { + /// Create a new ScalarUDF + pub fn new( + name: &str, + signature: &Signature, + return_type: &ReturnTypeFunction, + fun: &ScalarFunctionImplementation, + ) -> Self { + Self { + name: name.to_owned(), + signature: signature.clone(), + return_type: return_type.clone(), + fun: fun.clone(), + } + } + + /// creates a logical expression with a call of the UDF + /// This utility allows using the UDF without requiring access to the registry. + pub fn call(&self, args: Vec) -> Expr { + Expr::ScalarUDF { + fun: Arc::new(self.clone()), + args, + } + } +} diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 3fcaa28af973..0e3cc61f3b5a 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -321,12 +321,12 @@ mod tests { use super::*; use crate::execution::options::CsvReadOptions; - use crate::physical_plan::functions::ScalarFunctionImplementation; - use crate::physical_plan::functions::Volatility; use crate::physical_plan::{window_functions, ColumnarValue}; use crate::{assert_batches_sorted_eq, execution::context::ExecutionContext}; use crate::{logical_plan::*, test_util}; use arrow::datatypes::DataType; + use datafusion_expr::ScalarFunctionImplementation; + use datafusion_expr::Volatility; #[tokio::test] async fn select_columns() -> Result<()> { diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index f19e9d8d6a35..de052983f770 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -21,379 +21,22 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; use crate::logical_plan::ExprSchemable; -use crate::logical_plan::{window_frames, DFField, DFSchema}; -use crate::physical_plan::functions::Volatility; -use crate::physical_plan::{aggregates, functions, udf::ScalarUDF, window_functions}; -use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; -use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; +use crate::logical_plan::{DFField, DFSchema}; +use crate::physical_plan::udaf::AggregateUDF; +use crate::physical_plan::{aggregates, functions, udf::ScalarUDF}; use arrow::datatypes::DataType; pub use datafusion_common::{Column, ExprSchema}; -use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +pub use datafusion_expr::expr_fn::col; +use datafusion_expr::AccumulatorFunctionImplementation; +pub use datafusion_expr::Expr; +use datafusion_expr::StateTypeFunction; +pub use datafusion_expr::{lit, lit_timestamp_nano, Literal}; +use datafusion_expr::{ + ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility, +}; use std::collections::HashSet; -use std::fmt; -use std::hash::{BuildHasher, Hash, Hasher}; -use std::ops::Not; use std::sync::Arc; -/// `Expr` is a central struct of DataFusion's query API, and -/// represent logical expressions such as `A + 1`, or `CAST(c1 AS -/// int)`. 
-/// -/// An `Expr` can compute its [DataType](arrow::datatypes::DataType) -/// and nullability, and has functions for building up complex -/// expressions. -/// -/// # Examples -/// -/// ## Create an expression `c1` referring to column named "c1" -/// ``` -/// # use datafusion::logical_plan::*; -/// let expr = col("c1"); -/// assert_eq!(expr, Expr::Column(Column::from_name("c1"))); -/// ``` -/// -/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together -/// ``` -/// # use datafusion::logical_plan::*; -/// let expr = col("c1") + col("c2"); -/// -/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); -/// if let Expr::BinaryExpr { left, right, op } = expr { -/// assert_eq!(*left, col("c1")); -/// assert_eq!(*right, col("c2")); -/// assert_eq!(op, Operator::Plus); -/// } -/// ``` -/// -/// ## Create expression `c1 = 42` to compare the value in coumn "c1" to the literal value `42` -/// ``` -/// # use datafusion::logical_plan::*; -/// # use datafusion::scalar::*; -/// let expr = col("c1").eq(lit(42)); -/// -/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); -/// if let Expr::BinaryExpr { left, right, op } = expr { -/// assert_eq!(*left, col("c1")); -/// let scalar = ScalarValue::Int32(Some(42)); -/// assert_eq!(*right, Expr::Literal(scalar)); -/// assert_eq!(op, Operator::Eq); -/// } -/// ``` -#[derive(Clone, PartialEq, Hash)] -pub enum Expr { - /// An expression with a specific name. - Alias(Box, String), - /// A named reference to a qualified filed in a schema. - Column(Column), - /// A named reference to a variable in a registry. - ScalarVariable(Vec), - /// A constant value. - Literal(ScalarValue), - /// A binary expression such as "age > 21" - BinaryExpr { - /// Left-hand side of the expression - left: Box, - /// The comparison operator - op: Operator, - /// Right-hand side of the expression - right: Box, - }, - /// Negation of an expression. The expression's type must be a boolean to make sense. - Not(Box), - /// Whether an expression is not Null. This expression is never null. - IsNotNull(Box), - /// Whether an expression is Null. This expression is never null. - IsNull(Box), - /// arithmetic negation of an expression, the operand must be of a signed numeric data type - Negative(Box), - /// Returns the field of a [`ListArray`] or [`StructArray`] by key - GetIndexedField { - /// the expression to take the field from - expr: Box, - /// The name of the field to take - key: ScalarValue, - }, - /// Whether an expression is between a given range. - Between { - /// The value to compare - expr: Box, - /// Whether the expression is negated - negated: bool, - /// The low end of the range - low: Box, - /// The high end of the range - high: Box, - }, - /// The CASE expression is similar to a series of nested if/else and there are two forms that - /// can be used. The first form consists of a series of boolean "when" expressions with - /// corresponding "then" expressions, and an optional "else" expression. - /// - /// CASE WHEN condition THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - /// - /// The second form uses a base expression and then a series of "when" clauses that match on a - /// literal value. - /// - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] 
- /// [ELSE result] - /// END - Case { - /// Optional base expression that can be compared to literal values in the "when" expressions - expr: Option>, - /// One or more when/then expressions - when_then_expr: Vec<(Box, Box)>, - /// Optional "else" expression - else_expr: Option>, - }, - /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. - /// This expression is guaranteed to have a fixed type. - Cast { - /// The expression being cast - expr: Box, - /// The `DataType` the expression will yield - data_type: DataType, - }, - /// Casts the expression to a given type and will return a null value if the expression cannot be cast. - /// This expression is guaranteed to have a fixed type. - TryCast { - /// The expression being cast - expr: Box, - /// The `DataType` the expression will yield - data_type: DataType, - }, - /// A sort expression, that can be used to sort values. - Sort { - /// The expression to sort on - expr: Box, - /// The direction of the sort - asc: bool, - /// Whether to put Nulls before all other data values - nulls_first: bool, - }, - /// Represents the call of a built-in scalar function with a set of arguments. - ScalarFunction { - /// The function - fun: functions::BuiltinScalarFunction, - /// List of expressions to feed to the functions as arguments - args: Vec, - }, - /// Represents the call of a user-defined scalar function with arguments. - ScalarUDF { - /// The function - fun: Arc, - /// List of expressions to feed to the functions as arguments - args: Vec, - }, - /// Represents the call of an aggregate built-in function with arguments. - AggregateFunction { - /// Name of the function - fun: aggregates::AggregateFunction, - /// List of expressions to feed to the functions as arguments - args: Vec, - /// Whether this is a DISTINCT aggregation or not - distinct: bool, - }, - /// Represents the call of a window function with arguments. - WindowFunction { - /// Name of the function - fun: window_functions::WindowFunction, - /// List of expressions to feed to the functions as arguments - args: Vec, - /// List of partition by expressions - partition_by: Vec, - /// List of order by expressions - order_by: Vec, - /// Window frame - window_frame: Option, - }, - /// aggregate function - AggregateUDF { - /// The function - fun: Arc, - /// List of expressions to feed to the functions as arguments - args: Vec, - }, - /// Returns whether the list contains the expr value. - InList { - /// The expression to compare - expr: Box, - /// A list of values to compare against - list: Vec, - /// Whether the expression is negated - negated: bool, - }, - /// Represents a reference to all fields in a schema. - Wildcard, -} - -/// Fixed seed for the hashing so that Ords are consistent across runs -const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0); - -impl PartialOrd for Expr { - fn partial_cmp(&self, other: &Self) -> Option { - let mut hasher = SEED.build_hasher(); - self.hash(&mut hasher); - let s = hasher.finish(); - - let mut hasher = SEED.build_hasher(); - other.hash(&mut hasher); - let o = hasher.finish(); - - Some(s.cmp(&o)) - } -} - -impl Expr { - /// Returns the name of this expression based on [crate::logical_plan::DFSchema]. 
- /// - /// This represents how a column with this expression is named when no alias is chosen - pub fn name(&self, input_schema: &DFSchema) -> Result { - create_name(self, input_schema) - } - - /// Return `self == other` - pub fn eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::Eq, other) - } - - /// Return `self != other` - pub fn not_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::NotEq, other) - } - - /// Return `self > other` - pub fn gt(self, other: Expr) -> Expr { - binary_expr(self, Operator::Gt, other) - } - - /// Return `self >= other` - pub fn gt_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::GtEq, other) - } - - /// Return `self < other` - pub fn lt(self, other: Expr) -> Expr { - binary_expr(self, Operator::Lt, other) - } - - /// Return `self <= other` - pub fn lt_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::LtEq, other) - } - - /// Return `self && other` - pub fn and(self, other: Expr) -> Expr { - binary_expr(self, Operator::And, other) - } - - /// Return `self || other` - pub fn or(self, other: Expr) -> Expr { - binary_expr(self, Operator::Or, other) - } - - /// Return `!self` - #[allow(clippy::should_implement_trait)] - pub fn not(self) -> Expr { - !self - } - - /// Calculate the modulus of two expressions. - /// Return `self % other` - pub fn modulus(self, other: Expr) -> Expr { - binary_expr(self, Operator::Modulo, other) - } - - /// Return `self LIKE other` - pub fn like(self, other: Expr) -> Expr { - binary_expr(self, Operator::Like, other) - } - - /// Return `self NOT LIKE other` - pub fn not_like(self, other: Expr) -> Expr { - binary_expr(self, Operator::NotLike, other) - } - - /// Return `self AS name` alias expression - pub fn alias(self, name: &str) -> Expr { - Expr::Alias(Box::new(self), name.to_owned()) - } - - /// Return `self IN ` if `negated` is false, otherwise - /// return `self NOT IN `.a - pub fn in_list(self, list: Vec, negated: bool) -> Expr { - Expr::InList { - expr: Box::new(self), - list, - negated, - } - } - - /// Return `IsNull(Box(self)) - #[allow(clippy::wrong_self_convention)] - pub fn is_null(self) -> Expr { - Expr::IsNull(Box::new(self)) - } - - /// Return `IsNotNull(Box(self)) - #[allow(clippy::wrong_self_convention)] - pub fn is_not_null(self) -> Expr { - Expr::IsNotNull(Box::new(self)) - } - - /// Create a sort expression from an existing expression. 
- /// - /// ``` - /// # use datafusion::logical_plan::col; - /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST - /// ``` - pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { - Expr::Sort { - expr: Box::new(self), - asc, - nulls_first, - } - } -} - -impl Not for Expr { - type Output = Self; - - fn not(self) -> Self::Output { - Expr::Not(Box::new(self)) - } -} - -impl std::fmt::Display for Expr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Expr::BinaryExpr { - ref left, - ref right, - ref op, - } => write!(f, "{} {} {}", left, op, right), - Expr::AggregateFunction { - /// Name of the function - ref fun, - /// List of expressions to feed to the functions as arguments - ref args, - /// Whether this is a DISTINCT aggregation or not - ref distinct, - } => fmt_function(f, &fun.to_string(), *distinct, args, true), - Expr::ScalarFunction { - /// Name of the function - ref fun, - /// List of expressions to feed to the functions as arguments - ref args, - } => fmt_function(f, &fun.to_string(), false, args, true), - _ => write!(f, "{:?}", self), - } - } -} - /// Helper struct for building [Expr::Case] pub struct CaseBuilder { expr: Option>, @@ -484,15 +127,6 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { } } -/// return a new expression l r -pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr { - Expr::BinaryExpr { - left: Box::new(l), - op, - right: Box::new(r), - } -} - /// return a new expression with a logical AND pub fn and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr { @@ -525,11 +159,6 @@ pub fn or(left: Expr, right: Expr) -> Expr { } } -/// Create a column expression based on a qualified or unqualified column name -pub fn col(ident: &str) -> Expr { - Expr::Column(ident.into()) -} - /// Convert an expression into Column expression if it's already provided as input plan. /// /// For example, it rewrites: @@ -634,102 +263,6 @@ pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { } } -/// Trait for converting a type to a [`Literal`] literal expression. -pub trait Literal { - /// convert the value to a Literal expression - fn lit(&self) -> Expr; -} - -/// Trait for converting a type to a literal timestamp -pub trait TimestampLiteral { - fn lit_timestamp_nano(&self) -> Expr; -} - -impl Literal for &str { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) - } -} - -impl Literal for String { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned()))) - } -} - -impl Literal for Vec { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) - } -} - -impl Literal for &[u8] { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) - } -} - -impl Literal for ScalarValue { - fn lit(&self) -> Expr { - Expr::Literal(self.clone()) - } -} - -macro_rules! make_literal { - ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { - #[doc = $DOC] - impl Literal for $TYPE { - fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) - } - } - }; -} - -macro_rules! 
make_timestamp_literal { - ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { - #[doc = $DOC] - impl TimestampLiteral for $TYPE { - fn lit_timestamp_nano(&self) -> Expr { - Expr::Literal(ScalarValue::TimestampNanosecond( - Some((self.clone()).into()), - None, - )) - } - } - }; -} - -make_literal!(bool, Boolean, "literal expression containing a bool"); -make_literal!(f32, Float32, "literal expression containing an f32"); -make_literal!(f64, Float64, "literal expression containing an f64"); -make_literal!(i8, Int8, "literal expression containing an i8"); -make_literal!(i16, Int16, "literal expression containing an i16"); -make_literal!(i32, Int32, "literal expression containing an i32"); -make_literal!(i64, Int64, "literal expression containing an i64"); -make_literal!(u8, UInt8, "literal expression containing a u8"); -make_literal!(u16, UInt16, "literal expression containing a u16"); -make_literal!(u32, UInt32, "literal expression containing a u32"); -make_literal!(u64, UInt64, "literal expression containing a u64"); - -make_timestamp_literal!(i8, Int8, "literal expression containing an i8"); -make_timestamp_literal!(i16, Int16, "literal expression containing an i16"); -make_timestamp_literal!(i32, Int32, "literal expression containing an i32"); -make_timestamp_literal!(i64, Int64, "literal expression containing an i64"); -make_timestamp_literal!(u8, UInt8, "literal expression containing a u8"); -make_timestamp_literal!(u16, UInt16, "literal expression containing a u16"); -make_timestamp_literal!(u32, UInt32, "literal expression containing a u32"); - -/// Create a literal expression -pub fn lit(n: T) -> Expr { - n.lit() -} - -/// Create a literal timestamp expression -pub fn lit_timestamp_nano(n: T) -> Expr { - n.lit_timestamp_nano() -} - /// Concatenates the text representations of all the arguments. NULL arguments are ignored. pub fn concat(args: &[Expr]) -> Expr { Expr::ScalarFunction { @@ -934,311 +467,6 @@ pub fn create_udaf( ) } -fn fmt_function( - f: &mut fmt::Formatter, - fun: &str, - distinct: bool, - args: &[Expr], - display: bool, -) -> fmt::Result { - let args: Vec = match display { - true => args.iter().map(|arg| format!("{}", arg)).collect(), - false => args.iter().map(|arg| format!("{:?}", arg)).collect(), - }; - - // let args: Vec = args.iter().map(|arg| format!("{:?}", arg)).collect(); - let distinct_str = match distinct { - true => "DISTINCT ", - false => "", - }; - write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) -} - -impl fmt::Debug for Expr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias), - Expr::Column(c) => write!(f, "{}", c), - Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")), - Expr::Literal(v) => write!(f, "{:?}", v), - Expr::Case { - expr, - when_then_expr, - else_expr, - .. 
- } => { - write!(f, "CASE ")?; - if let Some(e) = expr { - write!(f, "{:?} ", e)?; - } - for (w, t) in when_then_expr { - write!(f, "WHEN {:?} THEN {:?} ", w, t)?; - } - if let Some(e) = else_expr { - write!(f, "ELSE {:?} ", e)?; - } - write!(f, "END") - } - Expr::Cast { expr, data_type } => { - write!(f, "CAST({:?} AS {:?})", expr, data_type) - } - Expr::TryCast { expr, data_type } => { - write!(f, "TRY_CAST({:?} AS {:?})", expr, data_type) - } - Expr::Not(expr) => write!(f, "NOT {:?}", expr), - Expr::Negative(expr) => write!(f, "(- {:?})", expr), - Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr), - Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr), - Expr::BinaryExpr { left, op, right } => { - write!(f, "{:?} {} {:?}", left, op, right) - } - Expr::Sort { - expr, - asc, - nulls_first, - } => { - if *asc { - write!(f, "{:?} ASC", expr)?; - } else { - write!(f, "{:?} DESC", expr)?; - } - if *nulls_first { - write!(f, " NULLS FIRST") - } else { - write!(f, " NULLS LAST") - } - } - Expr::ScalarFunction { fun, args, .. } => { - fmt_function(f, &fun.to_string(), false, args, false) - } - Expr::ScalarUDF { fun, ref args, .. } => { - fmt_function(f, &fun.name, false, args, false) - } - Expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - } => { - fmt_function(f, &fun.to_string(), false, args, false)?; - if !partition_by.is_empty() { - write!(f, " PARTITION BY {:?}", partition_by)?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY {:?}", order_by)?; - } - if let Some(window_frame) = window_frame { - write!( - f, - " {} BETWEEN {} AND {}", - window_frame.units, - window_frame.start_bound, - window_frame.end_bound - )?; - } - Ok(()) - } - Expr::AggregateFunction { - fun, - distinct, - ref args, - .. - } => fmt_function(f, &fun.to_string(), *distinct, args, true), - Expr::AggregateUDF { fun, ref args, .. } => { - fmt_function(f, &fun.name, false, args, false) - } - Expr::Between { - expr, - negated, - low, - high, - } => { - if *negated { - write!(f, "{:?} NOT BETWEEN {:?} AND {:?}", expr, low, high) - } else { - write!(f, "{:?} BETWEEN {:?} AND {:?}", expr, low, high) - } - } - Expr::InList { - expr, - list, - negated, - } => { - if *negated { - write!(f, "{:?} NOT IN ({:?})", expr, list) - } else { - write!(f, "{:?} IN ({:?})", expr, list) - } - } - Expr::Wildcard => write!(f, "*"), - Expr::GetIndexedField { ref expr, key } => { - write!(f, "({:?})[{}]", expr, key) - } - } - } -} - -fn create_function_name( - fun: &str, - distinct: bool, - args: &[Expr], - input_schema: &DFSchema, -) -> Result { - let names: Vec = args - .iter() - .map(|e| create_name(e, input_schema)) - .collect::>()?; - let distinct_str = match distinct { - true => "DISTINCT ", - false => "", - }; - Ok(format!("{}({}{})", fun, distinct_str, names.join(","))) -} - -/// Returns a readable name of an expression based on the input schema. -/// This function recursively transverses the expression for names such as "CAST(a > 2)". 
-fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { - match e { - Expr::Alias(_, name) => Ok(name.clone()), - Expr::Column(c) => Ok(c.flat_name()), - Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")), - Expr::Literal(value) => Ok(format!("{:?}", value)), - Expr::BinaryExpr { left, op, right } => { - let left = create_name(left, input_schema)?; - let right = create_name(right, input_schema)?; - Ok(format!("{} {} {}", left, op, right)) - } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let mut name = "CASE ".to_string(); - if let Some(e) = expr { - let e = create_name(e, input_schema)?; - name += &format!("{} ", e); - } - for (w, t) in when_then_expr { - let when = create_name(w, input_schema)?; - let then = create_name(t, input_schema)?; - name += &format!("WHEN {} THEN {} ", when, then); - } - if let Some(e) = else_expr { - let e = create_name(e, input_schema)?; - name += &format!("ELSE {} ", e); - } - name += "END"; - Ok(name) - } - Expr::Cast { expr, data_type } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("CAST({} AS {:?})", expr, data_type)) - } - Expr::TryCast { expr, data_type } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("TRY_CAST({} AS {:?})", expr, data_type)) - } - Expr::Not(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("NOT {}", expr)) - } - Expr::Negative(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("(- {})", expr)) - } - Expr::IsNull(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{} IS NULL", expr)) - } - Expr::IsNotNull(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{} IS NOT NULL", expr)) - } - Expr::GetIndexedField { expr, key } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{}[{}]", expr, key)) - } - Expr::ScalarFunction { fun, args, .. } => { - create_function_name(&fun.to_string(), false, args, input_schema) - } - Expr::ScalarUDF { fun, args, .. } => { - create_function_name(&fun.name, false, args, input_schema) - } - Expr::WindowFunction { - fun, - args, - window_frame, - partition_by, - order_by, - } => { - let mut parts: Vec = vec![create_function_name( - &fun.to_string(), - false, - args, - input_schema, - )?]; - if !partition_by.is_empty() { - parts.push(format!("PARTITION BY {:?}", partition_by)); - } - if !order_by.is_empty() { - parts.push(format!("ORDER BY {:?}", order_by)); - } - if let Some(window_frame) = window_frame { - parts.push(format!("{}", window_frame)); - } - Ok(parts.join(" ")) - } - Expr::AggregateFunction { - fun, - distinct, - args, - .. 
- } => create_function_name(&fun.to_string(), *distinct, args, input_schema), - Expr::AggregateUDF { fun, args } => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e, input_schema)?); - } - Ok(format!("{}({})", fun.name, names.join(","))) - } - Expr::InList { - expr, - list, - negated, - } => { - let expr = create_name(expr, input_schema)?; - let list = list.iter().map(|expr| create_name(expr, input_schema)); - if *negated { - Ok(format!("{} NOT IN ({:?})", expr, list)) - } else { - Ok(format!("{} IN ({:?})", expr, list)) - } - } - Expr::Between { - expr, - negated, - low, - high, - } => { - let expr = create_name(expr, input_schema)?; - let low = create_name(low, input_schema)?; - let high = create_name(high, input_schema)?; - if *negated { - Ok(format!("{} NOT BETWEEN {} AND {}", expr, low, high)) - } else { - Ok(format!("{} BETWEEN {} AND {}", expr, low, high)) - } - } - Expr::Sort { .. } => Err(DataFusionError::Internal( - "Create name does not support sort expression".to_string(), - )), - Expr::Wildcard => Err(DataFusionError::Internal( - "Create name does not support wildcard".to_string(), - )), - } -} - /// Create field meta-data from an expression, for use in a result set schema pub fn exprlist_to_fields<'a>( expr: impl IntoIterator, @@ -1265,6 +493,7 @@ pub fn call_fn(name: impl AsRef, args: Vec) -> Result { mod tests { use super::super::{col, lit, when}; use super::*; + use datafusion_expr::expr_fn::binary_expr; #[test] fn case_when_same_literal_then_types() -> Result<()> { @@ -1282,22 +511,6 @@ mod tests { assert!(maybe_expr.is_err()); } - #[test] - fn test_lit_timestamp_nano() { - let expr = col("time").eq(lit_timestamp_nano(10)); // 10 is an implicit i32 - let expected = - col("time").eq(lit(ScalarValue::TimestampNanosecond(Some(10), None))); - assert_eq!(expr, expected); - - let i: i64 = 10; - let expr = col("time").eq(lit_timestamp_nano(i)); - assert_eq!(expr, expected); - - let i: u32 = 10; - let expr = col("time").eq(lit_timestamp_nano(i)); - assert_eq!(expr, expected); - } - #[test] fn filter_is_null_and_is_not_null() { let col_null = col("col1"); diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index f2ecb0f76278..24d6723210c7 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -37,11 +37,12 @@ pub mod window_frames; pub use builder::{ build_join_schema, union_with_alias, LogicalPlanBuilder, UNNAMED_TABLE, }; +pub use datafusion_expr::expr_fn::binary_expr; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, - avg, binary_expr, bit_length, btrim, call_fn, case, ceil, character_length, chr, col, + avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, col, columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, diff --git a/datafusion/src/logical_plan/operators.rs b/datafusion/src/logical_plan/operators.rs index 813f7e0aac70..2f129284fa71 100644 --- a/datafusion/src/logical_plan/operators.rs +++ b/datafusion/src/logical_plan/operators.rs @@ -15,49 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use super::{binary_expr, Expr}; pub use datafusion_expr::Operator; -use std::ops; - -impl ops::Add for Expr { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - binary_expr(self, Operator::Plus, rhs) - } -} - -impl ops::Sub for Expr { - type Output = Self; - - fn sub(self, rhs: Self) -> Self { - binary_expr(self, Operator::Minus, rhs) - } -} - -impl ops::Mul for Expr { - type Output = Self; - - fn mul(self, rhs: Self) -> Self { - binary_expr(self, Operator::Multiply, rhs) - } -} - -impl ops::Div for Expr { - type Output = Self; - - fn div(self, rhs: Self) -> Self { - binary_expr(self, Operator::Divide, rhs) - } -} - -impl ops::Rem for Expr { - type Output = Self; - - fn rem(self, rhs: Self) -> Self { - binary_expr(self, Operator::Modulo, rhs) - } -} #[cfg(test)] mod tests { diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index a1531d4a7b83..10096504bcb4 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -28,7 +28,7 @@ use super::{ functions::{Signature, TypeSignature, Volatility}, - Accumulator, AggregateExpr, PhysicalExpr, + AggregateExpr, PhysicalExpr, }; use crate::error::{DataFusionError, Result}; use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_types}; @@ -40,15 +40,6 @@ use expressions::{ }; use std::sync::Arc; -/// the implementation of an aggregate function -pub type AccumulatorFunctionImplementation = - Arc Result> + Send + Sync>; - -/// This signature corresponds to which types an aggregator serializes -/// its state, given its return datatype. -pub type StateTypeFunction = - Arc Result>> + Send + Sync>; - pub use datafusion_expr::AggregateFunction; /// Returns the datatype of the aggregate function. diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion/src/physical_plan/udaf.rs index 0de696d61172..71e7e0657596 100644 --- a/datafusion/src/physical_plan/udaf.rs +++ b/datafusion/src/physical_plan/udaf.rs @@ -17,7 +17,7 @@ //! This module contains functions and structs supporting user-defined aggregate functions. -use fmt::{Debug, Formatter}; +use fmt::Debug; use std::any::Any; use std::fmt; @@ -26,85 +26,14 @@ use arrow::{ datatypes::{DataType, Schema}, }; -use crate::physical_plan::PhysicalExpr; -use crate::{error::Result, logical_plan::Expr}; - use super::{ - aggregates::AccumulatorFunctionImplementation, - aggregates::StateTypeFunction, - expressions::format_state_name, - functions::{ReturnTypeFunction, Signature}, - type_coercion::coerce, - Accumulator, AggregateExpr, + expressions::format_state_name, type_coercion::coerce, Accumulator, AggregateExpr, }; -use std::sync::Arc; - -/// Logical representation of a user-defined aggregate function (UDAF) -/// A UDAF is different from a UDF in that it is stateful across batches. 
-#[derive(Clone)] -pub struct AggregateUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, - /// actual implementation - pub accumulator: AccumulatorFunctionImplementation, - /// the accumulator's state's description as a function of the return type - pub state_type: StateTypeFunction, -} - -impl Debug for AggregateUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("AggregateUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } -} - -impl PartialEq for AggregateUDF { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature - } -} - -impl std::hash::Hash for AggregateUDF { - fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); - } -} - -impl AggregateUDF { - /// Create a new AggregateUDF - pub fn new( - name: &str, - signature: &Signature, - return_type: &ReturnTypeFunction, - accumulator: &AccumulatorFunctionImplementation, - state_type: &StateTypeFunction, - ) -> Self { - Self { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - accumulator: accumulator.clone(), - state_type: state_type.clone(), - } - } +use crate::error::Result; +use crate::physical_plan::PhysicalExpr; +pub use datafusion_expr::AggregateUDF; - /// creates a logical expression with a call of the UDAF - /// This utility allows using the UDAF without requiring access to the registry. - pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateUDF { - fun: Arc::new(self.clone()), - args, - } - } -} +use std::sync::Arc; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. /// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index 7355746a368b..58e66da48a7d 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -17,91 +17,16 @@ //! UDF support -use fmt::{Debug, Formatter}; -use std::fmt; - +use super::type_coercion::coerce; +use crate::error::Result; +use crate::physical_plan::functions::ScalarFunctionExpr; +use crate::physical_plan::PhysicalExpr; use arrow::datatypes::Schema; -use crate::error::Result; -use crate::{logical_plan::Expr, physical_plan::PhysicalExpr}; +pub use datafusion_expr::ScalarUDF; -use super::{ - functions::{ - ReturnTypeFunction, ScalarFunctionExpr, ScalarFunctionImplementation, Signature, - }, - type_coercion::coerce, -}; use std::sync::Arc; -/// Logical representation of a UDF. -#[derive(Clone)] -pub struct ScalarUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, - /// actual implementation - /// - /// The fn param is the wrapped function but be aware that the function will - /// be passed with the slice / vec of columnar values (either scalar or array) - /// with the exception of zero param function, where a singular element vec - /// will be passed. In that case the single element is a null array to indicate - /// the batch's row count (so that the generative zero-argument function can know - /// the result array size). 
- pub fun: ScalarFunctionImplementation, -} - -impl Debug for ScalarUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } -} - -impl PartialEq for ScalarUDF { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature - } -} - -impl std::hash::Hash for ScalarUDF { - fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); - } -} - -impl ScalarUDF { - /// Create a new ScalarUDF - pub fn new( - name: &str, - signature: &Signature, - return_type: &ReturnTypeFunction, - fun: &ScalarFunctionImplementation, - ) -> Self { - Self { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - fun: fun.clone(), - } - } - - /// creates a logical expression with a call of the UDF - /// This utility allows using the UDF without requiring access to the registry. - pub fn call(&self, args: Vec) -> Expr { - Expr::ScalarUDF { - fun: Arc::new(self.clone()), - args, - } - } -} - /// Create a physical expression of the UDF. /// This function errors when `args`' can't be coerced to a valid argument type of the UDF. pub fn create_physical_expr( diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 682b92ba661f..2e417c75f3f0 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -2172,11 +2172,10 @@ pub fn convert_data_type(sql_type: &SQLDataType) -> Result { #[cfg(test)] mod tests { - use functions::ScalarFunctionImplementation; - use crate::datasource::empty::EmptyTable; use crate::physical_plan::functions::Volatility; use crate::{logical_plan::create_udf, sql::parser::DFParser}; + use datafusion_expr::ScalarFunctionImplementation; use super::*; From b2cfe2bae284a7cda51a9c9fb987c420aa1d9f39 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Wed, 9 Feb 2022 12:40:19 +0100 Subject: [PATCH 49/50] fix bad data type in test_try_cast_decimal_to_decimal --- datafusion/src/physical_plan/expressions/try_cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/src/physical_plan/expressions/try_cast.rs b/datafusion/src/physical_plan/expressions/try_cast.rs index a2e74bbac798..0e5c5e81ea94 100644 --- a/datafusion/src/physical_plan/expressions/try_cast.rs +++ b/datafusion/src/physical_plan/expressions/try_cast.rs @@ -279,7 +279,7 @@ mod tests { // decimal to i8 generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal(10, 0), + DataType::Decimal(10, 3), Int8Array, DataType::Int8, vec![ From 79107650f5d04165aef6009a46d0289e615b34aa Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Sat, 12 Feb 2022 09:27:08 +0100 Subject: [PATCH 50/50] added projections for avro columns --- .../src/avro_to_arrow/arrow_array_reader.rs | 2 ++ datafusion/src/avro_to_arrow/reader.rs | 23 ++++++++----------- datafusion/src/avro_to_arrow/schema.rs | 1 - 3 files changed, 12 insertions(+), 14 deletions(-) delete mode 100644 datafusion/src/avro_to_arrow/schema.rs diff --git a/datafusion/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/src/avro_to_arrow/arrow_array_reader.rs index 0fd50e9b2c1f..8667c77fc9a8 100644 --- a/datafusion/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/src/avro_to_arrow/arrow_array_reader.rs @@ -38,6 +38,7 @@ impl<'a, R: Read> AvroBatchReader { avro_schemas: Vec, codec: Option, file_marker: [u8; 16], + projection: Option>, ) -> Result { let reader = AvroReader::new( 
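+            // stream the Avro blocks, decompress them, and decode rows using the
+            // Avro schemas, the Arrow fields, and the new optional per-column
+            // projection mask passed in above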
read::Decompressor::new( @@ -46,6 +47,7 @@ impl<'a, R: Read> AvroBatchReader { ), avro_schemas, schema.fields.clone(), + projection, ); Ok(Self { reader, schema }) } diff --git a/datafusion/src/avro_to_arrow/reader.rs b/datafusion/src/avro_to_arrow/reader.rs index 7cb640e60560..a7a8e9549dfb 100644 --- a/datafusion/src/avro_to_arrow/reader.rs +++ b/datafusion/src/avro_to_arrow/reader.rs @@ -108,22 +108,16 @@ impl ReaderBuilder { // check if schema should be inferred source.seek(SeekFrom::Start(0))?; - let (mut avro_schemas, mut schema, codec, file_marker) = + let (avro_schemas, schema, codec, file_marker) = read::read_metadata(&mut source)?; - if let Some(proj) = self.projection { - let mut indices: Vec = schema + + let projection = self.projection.map(|proj| { + schema .fields .iter() - .filter(|f| !proj.contains(&f.name)) - .enumerate() - .map(|(i, _)| i) - .collect(); - indices.sort_by(|i1, i2| i2.cmp(i1)); - for i in indices { - avro_schemas.remove(i); - schema.fields.remove(i); - } - } + .map(|f| proj.contains(&f.name)) + .collect::>() + }); Reader::try_new( source, @@ -132,6 +126,7 @@ impl ReaderBuilder { avro_schemas, codec, file_marker, + projection, ) } } @@ -155,6 +150,7 @@ impl<'a, R: Read> Reader { avro_schemas: Vec, codec: Option, file_marker: [u8; 16], + projection: Option>, ) -> Result { Ok(Self { array_reader: AvroBatchReader::try_new( @@ -163,6 +159,7 @@ impl<'a, R: Read> Reader { avro_schemas, codec, file_marker, + projection, )?, schema, batch_size, diff --git a/datafusion/src/avro_to_arrow/schema.rs b/datafusion/src/avro_to_arrow/schema.rs deleted file mode 100644 index 8b137891791f..000000000000 --- a/datafusion/src/avro_to_arrow/schema.rs +++ /dev/null @@ -1 +0,0 @@ -