Skip to content

Commit

Permalink
Move logical expression type-coercion code from physical-expr crate to `expr` crate (#2257)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored Apr 18, 2022
1 parent c91efc2 commit 5f0b61b
Show file tree
Hide file tree
Showing 48 changed files with 1,923 additions and 1,867 deletions.
2 changes: 1 addition & 1 deletion ballista/rust/core/src/serde/physical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use datafusion::logical_plan::FunctionRegistry;

use datafusion::physical_plan::file_format::FileScanConfig;

use datafusion::physical_plan::window_functions::WindowFunction;
use datafusion::logical_expr::window_function::WindowFunction;

use datafusion::physical_plan::{
expressions::{
Expand Down
3 changes: 2 additions & 1 deletion ballista/rust/core/src/serde/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1064,10 +1064,11 @@ mod roundtrip_tests {

use datafusion::arrow::array::ArrayRef;
use datafusion::execution::context::ExecutionProps;
use datafusion::logical_expr::{BuiltinScalarFunction, Volatility};
use datafusion::logical_plan::create_udf;
use datafusion::physical_plan::functions;
use datafusion::physical_plan::functions::{
make_scalar_function, BuiltinScalarFunction, ScalarFunctionExpr, Volatility,
make_scalar_function, ScalarFunctionExpr,
};
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::{
Expand Down
3 changes: 2 additions & 1 deletion ballista/rust/core/src/serde/physical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ use datafusion::physical_plan::{AggregateExpr, PhysicalExpr};

use crate::serde::{protobuf, BallistaError};

use datafusion::physical_plan::functions::{BuiltinScalarFunction, ScalarFunctionExpr};
use datafusion::logical_expr::BuiltinScalarFunction;
use datafusion::physical_plan::functions::ScalarFunctionExpr;

impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> {
type Error = BallistaError;
Expand Down
3 changes: 1 addition & 2 deletions datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ use datafusion::arrow::{
};

use datafusion::from_slice::FromSlice;
use datafusion::physical_plan::functions::Volatility;
use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator};
use datafusion::{prelude::*, scalar::ScalarValue};
use datafusion::{logical_expr::Volatility, prelude::*, scalar::ScalarValue};
use std::sync::Arc;

// create local session context with an in-memory table
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use datafusion::{
datatypes::DataType,
record_batch::RecordBatch,
},
physical_plan::functions::Volatility,
logical_expr::Volatility,
};

use datafusion::from_slice::FromSlice;
Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,12 +648,14 @@ mod tests {

use super::*;
use crate::execution::options::CsvReadOptions;
use crate::physical_plan::{window_functions, ColumnarValue};
use crate::physical_plan::ColumnarValue;
use crate::{assert_batches_sorted_eq, execution::context::SessionContext};
use crate::{logical_plan::*, test_util};
use arrow::datatypes::DataType;
use datafusion_expr::ScalarFunctionImplementation;
use datafusion_expr::Volatility;
use datafusion_expr::{
BuiltInWindowFunction, ScalarFunctionImplementation, WindowFunction,
};

#[tokio::test]
async fn select_columns() -> Result<()> {
Expand Down Expand Up @@ -693,9 +695,7 @@ mod tests {
// build plan using Table API
let t = test_table().await?;
let first_row = Expr::WindowFunction {
fun: window_functions::WindowFunction::BuiltInWindowFunction(
window_functions::BuiltInWindowFunction::FirstValue,
),
fun: WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
args: vec![col("aggregate_test_100.c1")],
partition_by: vec![col("aggregate_test_100.c2")],
order_by: vec![],
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ use crate::{
error::Result,
execution::context::SessionContext,
logical_plan::{self, Expr, ExprVisitable, ExpressionVisitor, Recursion},
physical_plan::functions::Volatility,
scalar::ScalarValue,
};

use super::{PartitionedFile, PartitionedFileStream};
use datafusion_data_access::{object_store::ObjectStore, FileMeta, SizedFile};
use datafusion_expr::Volatility;

const FILE_SIZE_COLUMN_NAME: &str = "_df_part_file_size_";
const FILE_PATH_COLUMN_NAME: &str = "_df_part_file_path_";
Expand Down
3 changes: 2 additions & 1 deletion datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1603,7 +1603,7 @@ mod tests {
use crate::execution::context::QueryPlanner;
use crate::from_slice::FromSlice;
use crate::logical_plan::{binary_expr, lit, Operator};
use crate::physical_plan::functions::{make_scalar_function, Volatility};
use crate::physical_plan::functions::make_scalar_function;
use crate::test;
use crate::variable::VarType;
use crate::{
Expand All @@ -1622,6 +1622,7 @@ mod tests {
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_expr::Volatility;
use std::fs::File;
use std::sync::Weak;
use std::thread::{self, JoinHandle};
Expand Down
3 changes: 2 additions & 1 deletion datafusion/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,9 @@ pub mod variable;
pub use arrow;
pub use parquet;

// re-export object store dependencies
// re-export DataFusion crates
pub use datafusion_data_access;
pub use datafusion_expr as logical_expr;

#[cfg(feature = "row")]
pub mod row;
Expand Down
11 changes: 5 additions & 6 deletions datafusion/core/src/logical_plan/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
// under the License.

use super::Expr;
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions, window_functions,
};
use crate::logical_expr::{aggregate_function, function, window_function};
use arrow::compute::can_cast_types;
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, DataFusionError, ExprSchema, Result};
use datafusion_expr::binary_rule::binary_operator_data_type;
use datafusion_expr::field_util::get_indexed_field;

/// trait to allow expr to typable with respect to a schema
Expand Down Expand Up @@ -76,21 +75,21 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
functions::return_type(fun, &data_types)
function::return_type(fun, &data_types)
}
Expr::WindowFunction { fun, args, .. } => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
window_functions::return_type(fun, &data_types)
window_function::return_type(fun, &data_types)
}
Expr::AggregateFunction { fun, args, .. } => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
aggregates::return_type(fun, &data_types)
aggregate_function::return_type(fun, &data_types)
}
Expr::AggregateUDF { fun, args, .. } => {
let data_types = args
Expand Down
5 changes: 3 additions & 2 deletions datafusion/core/src/optimizer/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ use crate::logical_plan::{
};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use crate::physical_plan::functions::Volatility;
use crate::physical_plan::planner::create_physical_expr;
use crate::scalar::ScalarValue;
use crate::{error::Result, logical_plan::Operator};
use arrow::array::new_null_array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_expr::Volatility;

/// Provides simplification information based on schema and properties
pub(crate) struct SimplifyContext<'a, 'b> {
Expand Down Expand Up @@ -735,14 +735,15 @@ mod tests {

use arrow::array::{ArrayRef, Int32Array};
use chrono::{DateTime, TimeZone, Utc};
use datafusion_expr::BuiltinScalarFunction;

use super::*;
use crate::assert_contains;
use crate::logical_plan::{
and, binary_expr, call_fn, col, create_udf, lit, lit_timestamp_nano, DFField,
Expr, LogicalPlanBuilder,
};
use crate::physical_plan::functions::{make_scalar_function, BuiltinScalarFunction};
use crate::physical_plan::functions::make_scalar_function;
use crate::physical_plan::udf::ScalarUDF;

#[test]
Expand Down
99 changes: 1 addition & 98 deletions datafusion/core/src/physical_plan/aggregate_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,101 +17,4 @@

//! Support the coercion rule for aggregate function.

pub use datafusion_physical_expr::coercion_rule::aggregate_rule::{
coerce_exprs, coerce_types,
};

#[cfg(test)]
mod tests {
use super::*;
use crate::physical_plan::aggregates;
use arrow::datatypes::DataType;
use datafusion_expr::AggregateFunction;

#[test]
fn test_aggregate_coerce_types() {
// test input args with error number input types
let fun = AggregateFunction::Min;
let input_types = vec![DataType::Int64, DataType::Int32];
let signature = aggregates::signature(&fun);
let result = coerce_types(&fun, &input_types, &signature);
assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", result.unwrap_err().to_string());

// test input args is invalid data type for sum or avg
let fun = AggregateFunction::Sum;
let input_types = vec![DataType::Utf8];
let signature = aggregates::signature(&fun);
let result = coerce_types(&fun, &input_types, &signature);
assert_eq!(
"Error during planning: The function Sum does not support inputs of type Utf8.",
result.unwrap_err().to_string()
);
let fun = AggregateFunction::Avg;
let signature = aggregates::signature(&fun);
let result = coerce_types(&fun, &input_types, &signature);
assert_eq!(
"Error during planning: The function Avg does not support inputs of type Utf8.",
result.unwrap_err().to_string()
);

// test count, array_agg, approx_distinct, min, max.
// the coerced types is same with input types
let funs = vec![
AggregateFunction::Count,
AggregateFunction::ArrayAgg,
AggregateFunction::ApproxDistinct,
AggregateFunction::Min,
AggregateFunction::Max,
];
let input_types = vec![
vec![DataType::Int32],
vec![DataType::Decimal(10, 2)],
vec![DataType::Utf8],
];
for fun in funs {
for input_type in &input_types {
let signature = aggregates::signature(&fun);
let result = coerce_types(&fun, input_type, &signature);
assert_eq!(*input_type, result.unwrap());
}
}
// test sum, avg
let funs = vec![AggregateFunction::Sum, AggregateFunction::Avg];
let input_types = vec![
vec![DataType::Int32],
vec![DataType::Float32],
vec![DataType::Decimal(20, 3)],
];
for fun in funs {
for input_type in &input_types {
let signature = aggregates::signature(&fun);
let result = coerce_types(&fun, input_type, &signature);
assert_eq!(*input_type, result.unwrap());
}
}

// ApproxPercentileCont input types
let input_types = vec![
vec![DataType::Int8, DataType::Float64],
vec![DataType::Int16, DataType::Float64],
vec![DataType::Int32, DataType::Float64],
vec![DataType::Int64, DataType::Float64],
vec![DataType::UInt8, DataType::Float64],
vec![DataType::UInt16, DataType::Float64],
vec![DataType::UInt32, DataType::Float64],
vec![DataType::UInt64, DataType::Float64],
vec![DataType::Float32, DataType::Float64],
vec![DataType::Float64, DataType::Float64],
];
for input_type in &input_types {
let signature =
aggregates::signature(&AggregateFunction::ApproxPercentileCont);
let result = coerce_types(
&AggregateFunction::ApproxPercentileCont,
input_type,
&signature,
);
assert_eq!(*input_type, result.unwrap());
}
}
}
pub use datafusion_physical_expr::coercion_rule::aggregate_rule::coerce_exprs;
Loading

0 comments on commit 5f0b61b

Please sign in to comment.