Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: basic support for executing prepared statements #13242

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 72 additions & 4 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ use crate::{
logical_expr::{
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable,
DropView, LogicalPlan, LogicalPlanBuilder, SetVariable, TableType, UNNAMED_TABLE,
DropView, Execute, LogicalPlan, LogicalPlanBuilder, Prepare, SetVariable,
TableType, UNNAMED_TABLE,
},
physical_expr::PhysicalExpr,
physical_plan::ExecutionPlan,
Expand All @@ -54,9 +55,9 @@ use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use datafusion_common::{
config::{ConfigExtension, TableOptions},
exec_err, not_impl_err, plan_datafusion_err, plan_err,
exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err,
tree_node::{TreeNodeRecursion, TreeNodeVisitor},
DFSchema, SchemaReference, TableReference,
DFSchema, ParamValues, ScalarValue, SchemaReference, TableReference,
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
Expand Down Expand Up @@ -687,7 +688,31 @@ impl SessionContext {
LogicalPlan::Statement(Statement::SetVariable(stmt)) => {
self.set_variable(stmt).await
}

LogicalPlan::Prepare(Prepare {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am worried that LogicalPlan::Prepare seems to be runnable even if we run sql_with_options and disable statements.

pub async fn sql_with_options(
&self,
sql: &str,
options: SQLOptions,
) -> Result<DataFrame> {
let plan = self.state().create_logical_plan(sql).await?;
options.verify_plan(&plan)?;
self.execute_logical_plan(plan).await
}

I realize that this PR does not change if Prepare is a statement or not, but it is the first that actually makes Prepare do something

Could you please add a test like this (but for PREPARE statements) and make sure the statement can't be executed if statements are disabled?

async fn unsupported_statement_returns_error() {
let ctx = SessionContext::new();
ctx.sql("CREATE TABLE test (x int)").await.unwrap();
let options = SQLOptions::new().with_allow_statements(false);
let sql = "set datafusion.execution.batch_size = 5";
let df = ctx.sql_with_options(sql, options).await;
assert_eq!(
df.unwrap_err().strip_backtrace(),
"Error during planning: Statement not supported: SetVariable"
);
let options = options.with_allow_statements(true);
ctx.sql_with_options(sql, options).await.unwrap();
}

Perhaps (as a follow on PR) we can make LogicalPlan::Prepare a statement https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Statement.html

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporarily added PREPARE to BadPlanVisitor and added a test for this in 4729766
I will make LogicalPlan::Prepare a statement in a follow-up PR.

name,
input,
data_types,
}) => {
// The number of parameters must match the specified data types length.
if !data_types.is_empty() {
let param_names = input.get_parameter_names()?;
if param_names.len() != data_types.len() {
return plan_err!(
"Prepare specifies {} data types but query has {} parameters",
data_types.len(),
param_names.len()
);
}
}
// Store the unoptimized plan into the session state. Although storing the
// optimized plan or the physical plan would be more efficient, doing so is
// not currently feasible. This is because `now()` would be optimized to a
// constant value, causing each EXECUTE to yield the same result, which is
// incorrect behavior.
self.state.write().store_prepared(name, data_types, input)?;
self.return_empty_dataframe()
}
LogicalPlan::Execute(execute) => self.execute_prepared(execute),
plan => Ok(DataFrame::new(self.state(), plan)),
}
}
Expand Down Expand Up @@ -1088,6 +1113,49 @@ impl SessionContext {
}
}

fn execute_prepared(&self, execute: Execute) -> Result<DataFrame> {
let Execute {
name, parameters, ..
} = execute;
let prepared = self.state.read().get_prepared(&name).ok_or_else(|| {
exec_datafusion_err!("Prepared statement '{}' does not exist", name)
})?;

// Only allow literals as parameters for now.
let mut params: Vec<ScalarValue> = parameters
.into_iter()
.map(|e| match e {
Expr::Literal(scalar) => Ok(scalar),
_ => not_impl_err!("Unsupported data type for parameter: {}", e),
})
.collect::<Result<_>>()?;

// If the prepared statement provides data types, cast the params to those types.
if !prepared.data_types.is_empty() {
if params.len() != prepared.data_types.len() {
return exec_err!(
"Prepared statement '{}' expects {} parameters, but {} provided",
name,
prepared.data_types.len(),
params.len()
);
}
params = params
.into_iter()
.zip(prepared.data_types.iter())
.map(|(e, dt)| e.cast_to(dt))
.collect::<Result<_>>()?;
}

let params = ParamValues::List(params);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

let plan = prepared
.plan
.as_ref()
.clone()
.replace_params_with_values(&params)?;
Ok(DataFrame::new(self.state(), plan))
}

/// Registers a variable provider within this context.
pub fn register_variable(
&self,
Expand Down
38 changes: 37 additions & 1 deletion datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::tree_node::TreeNode;
use datafusion_common::{
config_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError,
config_err, exec_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError,
ResolvedTableReference, TableReference,
};
use datafusion_execution::config::SessionConfig;
Expand Down Expand Up @@ -171,6 +171,9 @@ pub struct SessionState {
/// It will be invoked on `CREATE FUNCTION` statements.
/// thus, changing dialect o PostgreSql is required
function_factory: Option<Arc<dyn FunctionFactory>>,
/// Cache logical plans of prepared statements for later execution.
/// Key is the prepared statement name.
prepared_plans: HashMap<String, Arc<PreparedPlan>>,
}

impl Debug for SessionState {
Expand All @@ -197,6 +200,7 @@ impl Debug for SessionState {
.field("scalar_functions", &self.scalar_functions)
.field("aggregate_functions", &self.aggregate_functions)
.field("window_functions", &self.window_functions)
.field("prepared_plans", &self.prepared_plans)
.finish()
}
}
Expand Down Expand Up @@ -906,6 +910,29 @@ impl SessionState {
let udtf = self.table_functions.remove(name);
Ok(udtf.map(|x| x.function().clone()))
}

/// Store the logical plan and the parameter types of a prepared statement.
pub(crate) fn store_prepared(
&mut self,
name: String,
data_types: Vec<DataType>,
plan: Arc<LogicalPlan>,
) -> datafusion_common::Result<()> {
match self.prepared_plans.entry(name) {
Entry::Vacant(e) => {
e.insert(Arc::new(PreparedPlan { data_types, plan }));
Ok(())
}
Entry::Occupied(e) => {
exec_err!("Prepared statement '{}' already exists", e.key())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this errors if an existing prepared statement should we add some way to erase / clear prepared plans? Now there is no way to avoid accumulating them over time

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan to implement DEALLOCATE in next PR.

}
}
}

/// Get the prepared plan with the given name.
pub(crate) fn get_prepared(&self, name: &str) -> Option<Arc<PreparedPlan>> {
self.prepared_plans.get(name).map(Arc::clone)
}
}

/// A builder to be used for building [`SessionState`]'s. Defaults will
Expand Down Expand Up @@ -1327,6 +1354,7 @@ impl SessionStateBuilder {
table_factories: table_factories.unwrap_or_default(),
runtime_env,
function_factory,
prepared_plans: HashMap::new(),
};

if let Some(file_formats) = file_formats {
Expand Down Expand Up @@ -1876,6 +1904,14 @@ impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> {
}
}

#[derive(Debug)]
pub(crate) struct PreparedPlan {
/// Data types of the parameters
pub(crate) data_types: Vec<DataType>,
/// The prepared logical plan
pub(crate) plan: Arc<LogicalPlan>,
}

#[cfg(test)]
mod tests {
use super::{SessionContextProvider, SessionStateBuilder};
Expand Down
30 changes: 4 additions & 26 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ async fn test_named_query_parameters() -> Result<()> {
let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?;

// sql to statement then to logical plan with parameters
// c1 defined as UINT32, c2 defined as UInt64
let results = ctx
.sql("SELECT c1, c2 FROM test WHERE c1 > $coo AND c1 < $foo")
.await?
Expand Down Expand Up @@ -106,9 +105,9 @@ async fn test_prepare_statement() -> Result<()> {
let ctx = create_ctx_with_partition(&tmp_dir, partition_count).await?;

// sql to statement then to prepare logical plan with parameters
// c1 defined as UINT32, c2 defined as UInt64 but the params are Int32 and Float64
let dataframe =
ctx.sql("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1").await?;
Copy link
Member Author

@jonahgao jonahgao Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PREPARE now returns an empty DataFrame and can't call with_param_values on it. I think this change is reasonable because the result of PREPARE is not a real relation, as it doesn't contain valid rows. Its result is more like DDL.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree -- PREPARE is a statement in my mind (and has no results)

let dataframe = ctx
.sql("SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1")
.await?;

// prepare logical plan to logical plan without parameters
let param_values = vec![ScalarValue::Int32(Some(3)), ScalarValue::Float64(Some(0.0))];
Expand Down Expand Up @@ -156,7 +155,7 @@ async fn prepared_statement_type_coercion() -> Result<()> {
("unsigned", Arc::new(unsigned_ints) as ArrayRef),
])?;
ctx.register_batch("test", batch)?;
let results = ctx.sql("PREPARE my_plan(BIGINT, INT, TEXT) AS SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3")
let results = ctx.sql("SELECT signed, unsigned FROM test WHERE $1 >= signed AND signed <= $2 AND unsigned = $3")
.await?
.with_param_values(vec![
ScalarValue::from(1_i64),
Expand All @@ -176,27 +175,6 @@ async fn prepared_statement_type_coercion() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn prepared_statement_invalid_types() -> Result<()> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this test because now parameters will be cast to the target type.
This is also the behavior of PostgreSQL.

psql=> prepare p(int) as select 100+$1;
PREPARE

psql=> execute p('100');
 ?column?
----------
      200
(1 row)

psql=> execute p(20.12);
 ?column?
----------
      120
(1 row)

let ctx = SessionContext::new();
let signed_ints: Int32Array = vec![-1, 0, 1].into();
let unsigned_ints: UInt64Array = vec![1, 2, 3].into();
let batch = RecordBatch::try_from_iter(vec![
("signed", Arc::new(signed_ints) as ArrayRef),
("unsigned", Arc::new(unsigned_ints) as ArrayRef),
])?;
ctx.register_batch("test", batch)?;
let results = ctx
.sql("PREPARE my_plan(INT) AS SELECT signed FROM test WHERE signed = $1")
.await?
.with_param_values(vec![ScalarValue::from("1")]);
assert_eq!(
results.unwrap_err().strip_backtrace(),
"Error during planning: Expected parameter of type Int32, got Utf8 at index 0"
);
Ok(())
}

#[tokio::test]
async fn test_parameter_type_coercion() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
16 changes: 16 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,22 @@ impl LogicalPlan {
.map(|res| res.data)
}

/// Walk the logical plan, find any `Placeholder` tokens, and return a set of their names.
pub fn get_parameter_names(&self) -> Result<HashSet<String>> {
let mut param_names = HashSet::new();
self.apply_with_subqueries(|plan| {
plan.apply_expressions(|expr| {
expr.apply(|expr| {
if let Expr::Placeholder(Placeholder { id, .. }) = expr {
param_names.insert(id.clone());
}
Ok(TreeNodeRecursion::Continue)
})
})
})
.map(|_| param_names)
}

/// Walk the logical plan, find any `Placeholder` tokens, and return a map of their IDs and DataTypes
pub fn get_parameter_types(
&self,
Expand Down
Loading
Loading