Skip to content

Commit

Permalink
feat: basic support for executing prepared statements (#13242)
Browse files Browse the repository at this point in the history
* feat: basic support for executing prepared statements

* Improve execute_prepared

* Fix tests

* Update doc

* Add test

* Add issue test

* Respect allow_statements option
  • Loading branch information
jonahgao authored Nov 7, 2024
1 parent ba094e7 commit 34d9d3a
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 59 deletions.
84 changes: 80 additions & 4 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ use crate::{
logical_expr::{
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable,
DropView, LogicalPlan, LogicalPlanBuilder, SetVariable, TableType, UNNAMED_TABLE,
DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable,
TableType, UNNAMED_TABLE,
},
physical_expr::PhysicalExpr,
physical_plan::ExecutionPlan,
Expand All @@ -54,9 +55,9 @@ use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use datafusion_common::{
config::{ConfigExtension, TableOptions},
exec_err, not_impl_err, plan_datafusion_err, plan_err,
exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err,
tree_node::{TreeNodeRecursion, TreeNodeVisitor},
DFSchema, SchemaReference, TableReference,
DFSchema, ParamValues, ScalarValue, SchemaReference, TableReference,
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
Expand Down Expand Up @@ -687,7 +688,31 @@ impl SessionContext {
LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
self.set_variable(stmt).await
}

LogicalPlan::Prepare(Prepare {
name,
input,
data_types,
}) => {
// The number of parameters must match the specified data types length.
if !data_types.is_empty() {
let param_names = input.get_parameter_names()?;
if param_names.len() != data_types.len() {
return plan_err!(
"Prepare specifies {} data types but query has {} parameters",
data_types.len(),
param_names.len()
);
}
}
// Store the unoptimized plan into the session state. Although storing the
// optimized plan or the physical plan would be more efficient, doing so is
// not currently feasible. This is because `now()` would be optimized to a
// constant value, causing each EXECUTE to yield the same result, which is
// incorrect behavior.
self.state.write().store_prepared(name, data_types, input)?;
self.return_empty_dataframe()
}
LogicalPlan::Execute(execute) => self.execute_prepared(execute),
plan => Ok(DataFrame::new(self.state(), plan)),
}
}
Expand Down Expand Up @@ -1088,6 +1113,49 @@ impl SessionContext {
}
}

fn execute_prepared(&self, execute: Execute) -> Result<DataFrame> {
let Execute {
name, parameters, ..
} = execute;
let prepared = self.state.read().get_prepared(&name).ok_or_else(|| {
exec_datafusion_err!("Prepared statement '{}' does not exist", name)
})?;

// Only allow literals as parameters for now.
let mut params: Vec<ScalarValue> = parameters
.into_iter()
.map(|e| match e {
Expr::Literal(scalar) => Ok(scalar),
_ => not_impl_err!("Unsupported parameter type: {}", e),
})
.collect::<Result<_>>()?;

// If the prepared statement provides data types, cast the params to those types.
if !prepared.data_types.is_empty() {
if params.len() != prepared.data_types.len() {
return exec_err!(
"Prepared statement '{}' expects {} parameters, but {} provided",
name,
prepared.data_types.len(),
params.len()
);
}
params = params
.into_iter()
.zip(prepared.data_types.iter())
.map(|(e, dt)| e.cast_to(dt))
.collect::<Result<_>>()?;
}

let params = ParamValues::List(params);
let plan = prepared
.plan
.as_ref()
.clone()
.replace_params_with_values(&params)?;
Ok(DataFrame::new(self.state(), plan))
}

/// Registers a variable provider within this context.
pub fn register_variable(
&self,
Expand Down Expand Up @@ -1705,6 +1773,14 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> {
LogicalPlan::Statement(stmt) if !self.options.allow_statements => {
plan_err!("Statement not supported: {}", stmt.name())
}
// TODO: Implement PREPARE as a LogicalPlan::Statement
LogicalPlan::Prepare(_) if !self.options.allow_statements => {
plan_err!("Statement not supported: PREPARE")
}
// TODO: Implement EXECUTE as a LogicalPlan::Statement
LogicalPlan::Execute(_) if !self.options.allow_statements => {
plan_err!("Statement not supported: EXECUTE")
}
_ => Ok(TreeNodeRecursion::Continue),
}
}
Expand Down
38 changes: 37 additions & 1 deletion datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::tree_node::TreeNode;
use datafusion_common::{
config_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError,
config_err, exec_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError,
ResolvedTableReference, TableReference,
};
use datafusion_execution::config::SessionConfig;
Expand Down Expand Up @@ -171,6 +171,9 @@ pub struct SessionState {
/// It will be invoked on `CREATE FUNCTION` statements.
/// thus, changing dialect o PostgreSql is required
function_factory: Option<Arc<dyn FunctionFactory>>,
/// Cache logical plans of prepared statements for later execution.
/// Key is the prepared statement name.
prepared_plans: HashMap<String, Arc<PreparedPlan>>,
}

impl Debug for SessionState {
Expand All @@ -197,6 +200,7 @@ impl Debug for SessionState {
.field("scalar_functions", &self.scalar_functions)
.field("aggregate_functions", &self.aggregate_functions)
.field("window_functions", &self.window_functions)
.field("prepared_plans", &self.prepared_plans)
.finish()
}
}
Expand Down Expand Up @@ -906,6 +910,29 @@ impl SessionState {
let udtf = self.table_functions.remove(name);
Ok(udtf.map(|x| x.function().clone()))
}

/// Store the logical plan and the parameter types of a prepared statement.
pub(crate) fn store_prepared(
&mut self,
name: String,
data_types: Vec<DataType>,
plan: Arc<LogicalPlan>,
) -> datafusion_common::Result<()> {
match self.prepared_plans.entry(name) {
Entry::Vacant(e) => {
e.insert(Arc::new(PreparedPlan { data_types, plan }));
Ok(())
}
Entry::Occupied(e) => {
exec_err!("Prepared statement '{}' already exists", e.key())
}
}
}

/// Get the prepared plan with the given name.
pub(crate) fn get_prepared(&self, name: &str) -> Option<Arc<PreparedPlan>> {
self.prepared_plans.get(name).map(Arc::clone)
}
}

/// A builder to be used for building [`SessionState`]'s. Defaults will
Expand Down Expand Up @@ -1327,6 +1354,7 @@ impl SessionStateBuilder {
table_factories: table_factories.unwrap_or_default(),
runtime_env,
function_factory,
prepared_plans: HashMap::new(),
};

if let Some(file_formats) = file_formats {
Expand Down Expand Up @@ -1876,6 +1904,14 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> {
}
}

#[derive(Debug)]
pub(crate) struct PreparedPlan {
/// Data types of the parameters
pub(crate) data_types: Vec<DataType>,
/// The prepared logical plan
pub(crate) plan: Arc<LogicalPlan>,
}

#[cfg(test)]
mod tests {
use super::{SessionContextProvider, SessionStateBuilder};
Expand Down
30 changes: 4 additions & 26 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ async fn test_named_query_parameters() -> Result<()> {
let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?;

// sql to statement then to logical plan with parameters
// c1 defined as UINT32, c2 defined as UInt64
let results = ctx
.sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo")
.await?
Expand Down Expand Up @@ -106,9 +105,9 @@ async fn test_prepare_statement() -> Result<()> {
let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?;

// sql to statement then to prepare logical plan with parameters
// c1 defined as UINT32, c2 defined as UInt64 but the params are Int32 and Float64
let dataframe =
ctx.sql("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1").await?;
let dataframe = ctx
.sql("SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1")
.await?;

// prepare logical plan to logical plan without parameters
let param_values = vec![ScalarValue::Int32(Some(3)), ScalarValue::Float64(Some(0.0))];
Expand Down Expand Up @@ -156,7 +155,7 @@ async fn prepared_statement_type_coercion() -> Result<()> {
("unsigned", Arc::new(unsigned_ints) as ArrayRef),
])?;
ctx.register_batch("test", batch)?;
let results = ctx.sql("PREPARE my_plan(BIGINT, INT, TEXT) AS SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3")
let results = ctx.sql("SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3")
.await?
.with_param_values(vec![
ScalarValue::from(1_i64),
Expand All @@ -176,27 +175,6 @@ async fn prepared_statement_type_coercion() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn prepared_statement_invalid_types() -> Result<()> {
let ctx = SessionContext::new();
let signed_ints: Int32Array = vec![-1, 0, 1].into();
let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
let batch = RecordBatch::try_from_iter(vec![
("signed", Arc::new(signed_ints) as ArrayRef),
("unsigned", Arc::new(unsigned_ints) as ArrayRef),
])?;
ctx.register_batch("test", batch)?;
let results = ctx
.sql("PREPARE my_plan(INT) AS SELECT signed FROM test WHERE signed = $1")
.await?
.with_param_values(vec![ScalarValue::from("1")]);
assert_eq!(
results.unwrap_err().strip_backtrace(),
"Error during planning: Expected parameter of type Int32, got Utf8 at index 0"
);
Ok(())
}

#[tokio::test]
async fn test_parameter_type_coercion() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
24 changes: 24 additions & 0 deletions datafusion/core/tests/sql/sql_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,30 @@ async fn unsupported_statement_returns_error() {
ctx.sql_with_options(sql, options).await.unwrap();
}

// Disallow PREPARE and EXECUTE statements if `allow_statements` is false
#[tokio::test]
async fn disable_prepare_and_execute_statement() {
let ctx = SessionContext::new();

let prepare_sql = "PREPARE plan(INT) AS SELECT $1";
let execute_sql = "EXECUTE plan(1)";
let options = SQLOptions::new().with_allow_statements(false);
let df = ctx.sql_with_options(prepare_sql, options).await;
assert_eq!(
df.unwrap_err().strip_backtrace(),
"Error during planning: Statement not supported: PREPARE"
);
let df = ctx.sql_with_options(execute_sql, options).await;
assert_eq!(
df.unwrap_err().strip_backtrace(),
"Error during planning: Statement not supported: EXECUTE"
);

let options = options.with_allow_statements(true);
ctx.sql_with_options(prepare_sql, options).await.unwrap();
ctx.sql_with_options(execute_sql, options).await.unwrap();
}

#[tokio::test]
async fn empty_statement_returns_error() {
let ctx = SessionContext::new();
Expand Down
16 changes: 16 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1440,6 +1440,22 @@ impl LogicalPlan {
.map(|res| res.data)
}

/// Walk the logical plan, find any `Placeholder` tokens, and return a set of their names.
pub fn get_parameter_names(&self) -> Result<HashSet<String>> {
let mut param_names = HashSet::new();
self.apply_with_subqueries(|plan| {
plan.apply_expressions(|expr| {
expr.apply(|expr| {
if let Expr::Placeholder(Placeholder { id, .. }) = expr {
param_names.insert(id.clone());
}
Ok(TreeNodeRecursion::Continue)
})
})
})
.map(|_| param_names)
}

/// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes
pub fn get_parameter_types(
&self,
Expand Down
Loading

0 comments on commit 34d9d3a

Please sign in to comment.