Skip to content

Commit

Permalink
Extend infer_placeholder_types to support BETWEEN predicates (#7703)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrelmartins authored Oct 1, 2023
1 parent 14cdf72 commit f959127
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
10 changes: 10 additions & 0 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,16 @@ fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result<Expr> {
rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
};
if let Expr::Between(Between {
expr,
negated: _,
low,
high,
}) = &mut expr
{
rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
}
Ok(Transformed::Yes(expr))
})
}
Expand Down
34 changes: 34 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3894,6 +3894,40 @@ Projection: person.id, person.age
prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
}

#[test]
fn test_prepare_statement_infer_types_from_between_predicate() {
let sql = "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2";

let expected_plan = r#"
Projection: person.id, person.age
Filter: person.age BETWEEN $1 AND $2
TableScan: person
"#
.trim();

let expected_dt = "[Int32]";
let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);

let actual_types = plan.get_parameter_types().unwrap();
let expected_types = HashMap::from([
("$1".to_string(), Some(DataType::Int32)),
("$2".to_string(), Some(DataType::Int32)),
]);
assert_eq!(actual_types, expected_types);

// replace params with values
let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))];
let expected_plan = r#"
Projection: person.id, person.age
Filter: person.age BETWEEN Int32(10) AND Int32(30)
TableScan: person
"#
.trim();
let plan = plan.replace_params_with_values(&param_values).unwrap();

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
}

#[test]
fn test_prepare_statement_infer_types_subquery() {
let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)";
Expand Down

0 comments on commit f959127

Please sign in to comment.