diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index be6cee9885aa..2f6266f29d6a 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -741,6 +741,16 @@ fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; }; + if let Expr::Between(Between { + expr, + negated: _, + low, + high, + }) = &mut expr + { + rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; + rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; + } Ok(Transformed::Yes(expr)) }) } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 88a041a66145..f4ffea06b7a1 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3894,6 +3894,40 @@ Projection: person.id, person.age prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } +#[test] +fn test_prepare_statement_infer_types_from_between_predicate() { + let sql = "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2"; + + let expected_plan = r#" +Projection: person.id, person.age + Filter: person.age BETWEEN $1 AND $2 + TableScan: person + "# + .trim(); + + let expected_dt = "[Int32]"; + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + ("$1".to_string(), Some(DataType::Int32)), + ("$2".to_string(), Some(DataType::Int32)), + ]); + assert_eq!(actual_types, expected_types); + + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; + let expected_plan = r#" +Projection: person.id, person.age + Filter: person.age BETWEEN Int32(10) AND Int32(30) + TableScan: person + "# + .trim(); + let plan = plan.replace_params_with_values(¶m_values).unwrap(); + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); +} + #[test] fn test_prepare_statement_infer_types_subquery() { let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)";