Skip to content

Commit

Permalink
[MINOR]: Update create_window_expr to refer only input schema (#8945)
Browse files Browse the repository at this point in the history
* create_window_expr now receives physical input schema

* Resolve linter errors

* Match argument signature for some window functions

* Remove physical input_schema
  • Loading branch information
mustafasrepo authored Jan 24, 2024
1 parent 19ca7d2 commit b5db718
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 64 deletions.
7 changes: 1 addition & 6 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ use datafusion_expr::expr::{
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::utils::exprlist_to_fields;
use datafusion_expr::{
DescribeTable, DmlStatement, RecursiveQuery, ScalarFunctionDefinition,
StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
Expand Down Expand Up @@ -720,16 +719,12 @@ impl DefaultPhysicalPlanner {
}

let logical_input_schema = input.schema();
// Extend the schema to include the window expression fields, because builtin window functions derive their data types from the incoming schema
let mut window_fields = logical_input_schema.fields().clone();
window_fields.extend_from_slice(&exprlist_to_fields(window_expr.iter(), input)?);
let extended_schema = &DFSchema::new_with_metadata(window_fields, HashMap::new())?;
let window_expr = window_expr
.iter()
.map(|e| {
create_window_expr(
e,
extended_schema,
logical_input_schema,
session_state.execution_props(),
)
})
Expand Down
84 changes: 27 additions & 57 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use arrow_schema::{Field, Schema};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::windows::{
Expand All @@ -38,7 +37,6 @@ use datafusion_expr::{
};
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use itertools::Itertools;
use test_utils::add_empty_batches;

use hashbrown::HashMap;
Expand Down Expand Up @@ -229,14 +227,14 @@ fn get_random_function(
rng: &mut StdRng,
is_linear: bool,
) -> (WindowFunctionDefinition, Vec<Arc<dyn PhysicalExpr>>, String) {
let mut args = if is_linear {
let arg = if is_linear {
// In the linear test, the version of the plan that uses WindowAggExec has SortExecs inserted so that it
// generates the same result as the BoundedWindowAggExec version, which doesn't use any SortExec.
// To make the result independent of the table order, we should use column `a` in the window function
// (given that we do not use ROWS for the window frame; ROWS also introduces a dependency on the table order).
vec![col("a", schema).unwrap()]
col("a", schema).unwrap()
} else {
vec![col("x", schema).unwrap()]
col("x", schema).unwrap()
};
let mut window_fn_map = HashMap::new();
// Each HashMap value is a tuple: the first element is the WindowFunction, the second is the additional arguments
Expand All @@ -245,28 +243,28 @@ fn get_random_function(
"sum",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![],
vec![arg.clone()],
),
);
window_fn_map.insert(
"count",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
vec![],
vec![arg.clone()],
),
);
window_fn_map.insert(
"min",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![],
vec![arg.clone()],
),
);
window_fn_map.insert(
"max",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![],
vec![arg.clone()],
),
);
if !is_linear {
Expand Down Expand Up @@ -307,6 +305,7 @@ fn get_random_function(
BuiltInWindowFunction::Lead,
),
vec![
arg.clone(),
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))),
],
Expand All @@ -319,6 +318,7 @@ fn get_random_function(
BuiltInWindowFunction::Lag,
),
vec![
arg.clone(),
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))),
],
Expand All @@ -331,7 +331,7 @@ fn get_random_function(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::FirstValue,
),
vec![],
vec![arg.clone()],
),
);
window_fn_map.insert(
Expand All @@ -340,7 +340,7 @@ fn get_random_function(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::LastValue,
),
vec![],
vec![arg.clone()],
),
);
window_fn_map.insert(
Expand All @@ -349,23 +349,26 @@ fn get_random_function(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::NthValue,
),
vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))],
vec![
arg.clone(),
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
],
),
);

let rand_fn_idx = rng.gen_range(0..window_fn_map.len());
let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
let (window_fn, new_args) = window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
let (window_fn, args) = window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
let mut args = args.clone();
if let WindowFunctionDefinition::AggregateFunction(f) = window_fn {
let a = args[0].clone();
let dt = a.data_type(schema.as_ref()).unwrap();
let sig = f.signature();
let coerced = coerce_types(f, &[dt], &sig).unwrap();
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
}

for new_arg in new_args {
args.push(new_arg.clone());
if !args.is_empty() {
// Do type coercion for the first argument
let a = args[0].clone();
let dt = a.data_type(schema.as_ref()).unwrap();
let sig = f.signature();
let coerced = coerce_types(f, &[dt], &sig).unwrap();
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
}
}

(window_fn.clone(), args, fn_name.to_string())
Expand Down Expand Up @@ -534,39 +537,6 @@ async fn run_window_test(
exec1 = Arc::new(SortExec::new(sort_keys.clone(), exec1)) as _;
}

// The schema needs to be enriched before calling `create_window_expr`.
// The reason is that the data types of window expressions are derived from the schema.
// The DataFusion code enriches the schema in the physical planner, and this test copies the same behavior manually.
// Also, a bunch of functions don't require input arguments, so we just send an empty vec for such functions.
let data_types = if [
"row_number",
"rank",
"dense_rank",
"percent_rank",
"ntile",
"cume_dist",
]
.contains(&fn_name.as_str())
{
vec![]
} else {
args.iter()
.map(|e| e.clone().as_ref().data_type(&schema))
.collect::<Result<Vec<_>>>()?
};
let window_expr_return_type = window_fn.return_type(&data_types)?;
let mut window_fields = schema
.fields()
.iter()
.map(|f| f.as_ref().clone())
.collect_vec();
window_fields.extend_from_slice(&[Field::new(
&fn_name,
window_expr_return_type,
true,
)]);
let extended_schema = Arc::new(Schema::new(window_fields));

let usual_window_exec = Arc::new(
WindowAggExec::try_new(
vec![create_window_expr(
Expand All @@ -576,7 +546,7 @@ async fn run_window_test(
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
&extended_schema,
schema.as_ref(),
)
.unwrap()],
exec1,
Expand All @@ -598,7 +568,7 @@ async fn run_window_test(
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
extended_schema.as_ref(),
schema.as_ref(),
)
.unwrap()],
exec2,
Expand Down
9 changes: 8 additions & 1 deletion datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,14 @@ fn create_built_in_window_expr(
input_schema: &Schema,
name: String,
) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
let data_type = input_schema.field_with_name(&name)?.data_type();
// need to get the types into an owned vec for some reason
let input_types: Vec<_> = args
.iter()
.map(|arg| arg.data_type(input_schema))
.collect::<Result<_>>()?;

// figure out the output type
let data_type = &fun.return_type(&input_types)?;
Ok(match fun {
BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, data_type)),
BuiltInWindowFunction::Rank => Arc::new(rank(name, data_type)),
Expand Down

0 comments on commit b5db718

Please sign in to comment.