diff --git a/crates/pgsrv/src/codec.rs b/crates/pgsrv/src/codec.rs index 938458ab8..ad167db5f 100644 --- a/crates/pgsrv/src/codec.rs +++ b/crates/pgsrv/src/codec.rs @@ -148,6 +148,80 @@ impl PgCodec { }) } + fn decode_parse(buf: &mut Cursor<'_>) -> Result { + let name = buf.read_cstring()?.to_string(); + let sql = buf.read_cstring()?.to_string(); + let num_params = buf.get_i16() as usize; + let mut param_types = Vec::with_capacity(num_params); + for _ in 0..num_params { + param_types.push(buf.get_i32()); + } + Ok(FrontendMessage::Parse { + name, + sql, + param_types, + }) + } + + fn decode_bind(buf: &mut Cursor<'_>) -> Result { + let portal = buf.read_cstring()?.to_string(); + let statement = buf.read_cstring()?.to_string(); + + let num_params = buf.get_i16() as usize; + let mut param_formats = Vec::with_capacity(num_params); + for _ in 0..num_params { + param_formats.push(buf.get_i16()); + } + + let num_values = buf.get_i16() as usize; // must match num_params + let mut param_values = Vec::with_capacity(num_values); + for _ in 0..num_values { + let len = buf.get_i32(); + if len == -1 { + param_values.push(None); + } else { + let mut val = vec![0; len as usize]; + buf.copy_to_slice(&mut val); + param_values.push(Some(val)); + } + } + + let num_params = buf.get_i16() as usize; + let mut result_formats = Vec::with_capacity(num_params); + for _ in 0..num_params { + result_formats.push(buf.get_i16()); + } + + Ok(FrontendMessage::Bind { + portal, + statement, + param_formats, + param_values, + result_formats, + }) + } + + fn decode_describe(buf: &mut Cursor<'_>) -> Result { + let object_type = buf.get_u8().try_into()?; + let name = buf.read_cstring()?.to_string(); + + Ok(FrontendMessage::Describe { object_type, name }) + } + + fn decode_execute(buf: &mut Cursor<'_>) -> Result { + let portal = buf.read_cstring()?.to_string(); + let max_rows = buf.get_i32(); + Ok(FrontendMessage::Execute { portal, max_rows }) + } + + fn decode_sync(_buf: &mut Cursor<'_>) -> Result { + Ok(FrontendMessage::Sync) + } + + fn decode_terminate(_buf: &mut Cursor<'_>) -> Result { + Ok(FrontendMessage::Terminate) + } + fn encode_scalar_as_text(scalar: ScalarValue, buf: &mut BytesMut) -> Result<()> { if scalar.is_null() { buf.put_i32(-1); @@ -187,6 +261,9 @@ impl Encoder for PgCodec { BackendMessage::DataRow(_, _) => b'D', BackendMessage::ErrorResponse(_) => b'E', BackendMessage::NoticeResponse(_) => b'N', + BackendMessage::ParseComplete => b'1', + BackendMessage::BindComplete => b'2', + BackendMessage::NoData => b'n', }; dst.put_u8(byte); @@ -198,6 +275,9 @@ impl Encoder for PgCodec { BackendMessage::AuthenticationOk => dst.put_i32(0), BackendMessage::AuthenticationCleartextPassword => dst.put_i32(3), BackendMessage::EmptyQueryResponse => (), + BackendMessage::ParseComplete => (), + BackendMessage::BindComplete => (), + BackendMessage::NoData => (), BackendMessage::ParameterStatus { key, val } => { dst.put_cstring(&key); dst.put_cstring(&val); @@ -284,8 +364,8 @@ impl Decoder for PgCodec { let msg_len = i32::from_be_bytes(src[1..5].try_into().unwrap()) as usize; // Not enough bytes to read the full message yet. - if src.len() < msg_len { - src.reserve(msg_len - src.len()); + if src.len() < msg_len + 1 { + src.reserve(msg_len + 1 - src.len()); return Ok(None); } @@ -296,6 +376,13 @@ impl Decoder for PgCodec { let msg = match msg_type { b'Q' => Self::decode_query(&mut buf)?, b'p' => Self::decode_password(&mut buf)?, + b'P' => Self::decode_parse(&mut buf)?, + b'B' => Self::decode_bind(&mut buf)?, + b'D' => Self::decode_describe(&mut buf)?, + b'E' => Self::decode_execute(&mut buf)?, + b'S' => Self::decode_sync(&mut buf)?, + // X - Terminate + b'X' => return Ok(None), other => return Err(PgSrvError::InvalidMsgType(other)), }; diff --git a/crates/pgsrv/src/errors.rs b/crates/pgsrv/src/errors.rs index 1b02a88e2..4df4f13a9 100644 --- a/crates/pgsrv/src/errors.rs +++ b/crates/pgsrv/src/errors.rs @@ -21,6 +21,9 @@ pub enum PgSrvError { #[error("missing null byte")] MissingNullByte, + #[error("unexpected describe object type: {0}")] + UnexpectedDescribeObjectType(u8), + /// We've received an unexpected message identifier from the frontend. /// Includes the char representation to allow for easy cross referencing /// with the Postgres message format documentation. diff --git a/crates/pgsrv/src/handler.rs b/crates/pgsrv/src/handler.rs index a59e22148..bbb023d9a 100644 --- a/crates/pgsrv/src/handler.rs +++ b/crates/pgsrv/src/handler.rs @@ -1,11 +1,12 @@ use crate::codec::{FramedConn, PgCodec}; use crate::errors::{PgSrvError, Result}; use crate::messages::{ - BackendMessage, ErrorResponse, FieldDescription, FrontendMessage, StartupMessage, - TransactionStatus, + BackendMessage, DescribeObjectType, ErrorResponse, FieldDescription, FrontendMessage, + StartupMessage, TransactionStatus, }; use datafusion::physical_plan::SendableRecordBatchStream; use futures::StreamExt; +use sqlexec::logical_plan::LogicalPlan; use sqlexec::{ engine::Engine, executor::{ExecutionResult, Executor}, @@ -13,7 +14,7 @@ use sqlexec::{ }; use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tracing::trace; +use tracing::{trace, warn}; /// Default parameters to send to the frontend on startup. Existing postgres /// drivers may expect these in the server response on startup. @@ -130,6 +131,8 @@ impl Handler { struct ClientSession { conn: FramedConn, session: Session, // TODO: Make this a trait for stubbability? + + error_state: bool, } impl ClientSession @@ -137,31 +140,84 @@ where C: AsyncRead + AsyncWrite + Unpin, { fn new(session: Session, conn: FramedConn) -> Self { - ClientSession { session, conn } + ClientSession { + session, + conn, + error_state: false, + } } async fn run(mut self) -> Result<()> { self.ready_for_query().await?; loop { let msg = self.conn.read().await?; - - match msg { - Some(FrontendMessage::Query { sql }) => self.query(sql).await?, - Some(other) => { - self.conn - .send( - ErrorResponse::feature_not_supported(format!( - "unsupported frontend message: {:?}", - other - )) - .into(), - ) - .await?; - self.ready_for_query().await?; + tracing::debug!(?msg, "received message"); + // If we're in an error state, we should only process Sync messages. + // Until this is received, we should discard all incoming messages + if self.error_state { + match msg { + Some(FrontendMessage::Sync) => { + self.clear_error(); + self.ready_for_query().await?; + continue; + } + Some(other) => { + tracing::warn!(?other, "discarding message"); + } + None => { + tracing::debug!("connection closed"); + return Ok(()); + } } - None => { - trace!("connection closed"); - return Ok(()); + } else { + // Execute messages as normal if not in an error state. + match msg { + Some(FrontendMessage::Query { sql }) => self.query(sql).await?, + Some(FrontendMessage::Parse { + name, + sql, + param_types, + }) => self.parse(name, sql, param_types).await?, + Some(FrontendMessage::Bind { + portal, + statement, + param_formats, + param_values, + result_formats, + }) => { + self.bind( + portal, + statement, + param_formats, + param_values, + result_formats, + ) + .await? + } + Some(FrontendMessage::Describe { object_type, name }) => { + self.describe(object_type, name).await? + } + Some(FrontendMessage::Execute { portal, max_rows }) => { + self.execute(portal, max_rows).await? + } + Some(FrontendMessage::Sync) => self.sync().await?, + Some(other) => { + warn!(?other, "unsupported frontend message"); + self.conn + .send( + ErrorResponse::feature_not_supported(format!( + "unsupported frontend message: {:?}", + other + )) + .into(), + ) + .await?; + self.ready_for_query().await?; + } + None => { + trace!("connection closed"); + return Ok(()); + } } } } @@ -199,20 +255,7 @@ where } }; - match result { - ExecutionResult::Query { stream } => { - Self::stream_batch(conn, stream).await?; - Self::command_complete(conn, "SELECT").await? - } - ExecutionResult::Begin => Self::command_complete(conn, "BEGIN").await?, - ExecutionResult::Commit => Self::command_complete(conn, "COMMIT").await?, - ExecutionResult::Rollback => Self::command_complete(conn, "ROLLBACK").await?, - ExecutionResult::WriteSuccess => Self::command_complete(conn, "INSERT").await?, - ExecutionResult::CreateTable => { - Self::command_complete(conn, "CREATE_TABLE").await? - } - ExecutionResult::SetLocal => Self::command_complete(conn, "SET").await?, - } + Self::send_result(conn, result).await?; } if num_statements == 0 { @@ -222,6 +265,191 @@ where self.ready_for_query().await } + /// Parse the provided SQL statement and store it in the session. + async fn parse(&mut self, name: String, sql: String, param_types: Vec) -> Result<()> { + let session = &mut self.session; + let conn = &mut self.conn; + + // an empty name selectss the unnamed prepared statement + let name = if name.is_empty() { None } else { Some(name) }; + + trace!(?name, %sql, ?param_types, "received parse"); + + session.create_prepared_statement(name, sql, param_types)?; + + conn.send(BackendMessage::ParseComplete).await?; + + Ok(()) + } + + async fn bind( + &mut self, + portal: String, + statement: String, + param_formats: Vec, + param_values: Vec>>, + result_formats: Vec, + ) -> Result<()> { + let portal_name = if portal.is_empty() { + None + } else { + Some(portal) + }; + let statement_name = if statement.is_empty() { + None + } else { + Some(statement) + }; + + // param_formats can be empty, in which case all parameters (if any) are assumed to be text + // or it may have one entry, in which case all parameters are assumed to be of that format + // or it may have one entry per parameter, in which case each parameter is assumed to be of that format + // each code must be 0 (text) or 1 (binary) + let param_formats = if param_formats.is_empty() { + if param_values.is_empty() { + vec![] + } else { + vec![0] + } + } else if param_formats.len() == 1 { + vec![param_formats[0]; param_values.len()] + } else { + param_formats + }; + + trace!(?portal_name, ?statement_name, ?param_formats, ?param_values, ?result_formats, "received bind"); + + let session = &mut self.session; + let conn = &mut self.conn; + + session.bind_prepared_statement( + portal_name, + statement_name, + param_formats, + param_values, + result_formats, + )?; + + conn.send(BackendMessage::BindComplete).await?; + + Ok(()) + } + + async fn describe(&mut self, object_type: DescribeObjectType, name: String) -> Result<()> { + let session = &mut self.session; + let conn = &mut self.conn; + + let name = if name.is_empty() { None } else { Some(name) }; + + trace!(?name, ?object_type, "received describe"); + + match object_type { + DescribeObjectType::Statement => match session.get_prepared_statement(&name) { + Some(statement) => { + statement.describe(); + todo!("return statement describe response"); + } + None => { + self.conn + .send( + ErrorResponse::error_internal(format!( + "unknown prepared statement: {:?}", + name + )) + .into(), + ) + .await?; + } + }, + DescribeObjectType::Portal => { + // Describe (portal variant) returns a RowDescription message describing the rows + // that will be returned. If the portal contains a query that returns no rows, then + // a NoData message is returned instead. + match session.get_portal(&name) { + Some(portal) => { + match &portal.plan { + LogicalPlan::Ddl(_) => { + self.conn.send(BackendMessage::NoData).await?; + } + LogicalPlan::Write(_) => { + todo!("return portal describe response for Write"); + } + LogicalPlan::Query(df_plan) => { + let schema = df_plan.schema(); + let fields: Vec<_> = schema + .fields() + .iter() + .map(|field| FieldDescription::new_named(field.name())) + .collect(); + conn.send(BackendMessage::RowDescription(fields)).await?; + } + LogicalPlan::Transaction(_) => { + todo!("return portal describe response for Transaction"); + } + LogicalPlan::Runtime => { + todo!("return portal describe response for Runtime"); + } + } + } + None => { + self.conn + .send( + ErrorResponse::error_internal(format!( + "unknown portal: {:?}", + name + )) + .into(), + ) + .await?; + } + } + } + } + + Ok(()) + } + + async fn execute(&mut self, portal: String, max_rows: i32) -> Result<()> { + let portal_name = if portal.is_empty() { + None + } else { + Some(portal) + }; + + let session = &mut self.session; + let conn = &mut self.conn; + + trace!(?portal_name, ?max_rows, "received execute"); + + let result = session.execute_portal(&portal_name, max_rows).await?; + Self::send_result(conn, result).await?; + + Ok(()) + } + + async fn sync(&mut self) -> Result<()> { + trace!("received sync"); + + self.ready_for_query().await + } + + async fn send_result(conn: &mut FramedConn, result: ExecutionResult) -> Result<()> { + match result { + ExecutionResult::Query { stream } => { + Self::stream_batch(conn, stream).await?; + Self::command_complete(conn, "SELECT").await? + } + ExecutionResult::Begin => Self::command_complete(conn, "BEGIN").await?, + ExecutionResult::Commit => Self::command_complete(conn, "COMMIT").await?, + ExecutionResult::Rollback => Self::command_complete(conn, "ROLLBACK").await?, + ExecutionResult::WriteSuccess => Self::command_complete(conn, "INSERT").await?, + ExecutionResult::CreateTable => Self::command_complete(conn, "CREATE TABLE").await?, + ExecutionResult::SetLocal => Self::command_complete(conn, "SET").await?, + ExecutionResult::DropTables => Self::command_complete(conn, "DROP TABLE").await?, + } + Ok(()) + } + async fn stream_batch( conn: &mut FramedConn, mut stream: SendableRecordBatchStream, @@ -249,4 +477,12 @@ where conn.send(BackendMessage::CommandComplete { tag: tag.into() }) .await } + + fn set_error(&mut self) { + self.error_state = true; + } + + fn clear_error(&mut self) { + self.error_state = false; + } } diff --git a/crates/pgsrv/src/messages.rs b/crates/pgsrv/src/messages.rs index ad11d7e23..eea5e5a54 100644 --- a/crates/pgsrv/src/messages.rs +++ b/crates/pgsrv/src/messages.rs @@ -1,6 +1,8 @@ use datafusion::arrow::record_batch::RecordBatch; use std::collections::HashMap; +use crate::errors::PgSrvError; + /// Version number (v3.0) used during normal frontend startup. pub const VERSION_V3: i32 = 0x30000; /// Version number used to request a cancellation. @@ -31,6 +33,41 @@ pub enum FrontendMessage { Query { sql: String }, /// An encrypted or unencrypted password. PasswordMessage { password: String }, + /// An extended query parse message. + Parse { + /// The name of the prepared statement. An empty string denotes the unnamed prepared statement. + name: String, + /// The query string to be parsed. + sql: String, + /// The object IDs of the parameter data types. Placing a zero here is equivalent to leaving the type unspecified. + param_types: Vec, + }, + Bind { + /// The name of the destination portal (an empty string selects the unnamed portal). + portal: String, + /// The name of the source prepared statement (an empty string selects the unnamed prepared statement). + statement: String, + /// The parameter format codes. Each must presently be zero (text) or one (binary). + param_formats: Vec, + /// The parameter values, in the format indicated by the associated format code. n is the above length. + param_values: Vec>>, + /// The result-column format codes. Each must presently be zero (text) or one (binary). + result_formats: Vec, + }, + Describe { + /// The kind of item to describe: 'S' to describe a prepared statement; or 'P' to describe a portal. + object_type: DescribeObjectType, + /// The name of the item to describe (an empty string selects the unnamed prepared statement or portal). + name: String, + }, + Execute { + /// The name of the portal to execute (an empty string selects the unnamed portal). + portal: String, + /// The maximum number of rows to return, if portal contains a query that returns rows (ignored otherwise). Zero denotes "no limit". + max_rows: i32, + }, + Sync, + Terminate, } #[derive(Debug)] @@ -52,6 +89,9 @@ pub enum BackendMessage { CommandComplete { tag: String }, RowDescription(Vec), DataRow(RecordBatch, usize), + ParseComplete, + BindComplete, + NoData, } impl From for BackendMessage { @@ -200,3 +240,31 @@ impl FieldDescription { } } } + +#[derive(Debug)] +#[repr(u8)] +pub enum DescribeObjectType { + Statement = b'S', + Portal = b'P', +} + +impl std::fmt::Display for DescribeObjectType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DescribeObjectType::Statement => write!(f, "Statement"), + DescribeObjectType::Portal => write!(f, "Portal"), + } + } +} + +impl TryFrom for DescribeObjectType { + type Error = PgSrvError; + + fn try_from(value: u8) -> Result { + match value { + b'S' => Ok(DescribeObjectType::Statement), + b'P' => Ok(DescribeObjectType::Portal), + _ => Err(PgSrvError::UnexpectedDescribeObjectType(value)), + } + } +} diff --git a/crates/sqlexec/src/catalog.rs b/crates/sqlexec/src/catalog.rs index 2cc544856..6fe28fe4a 100644 --- a/crates/sqlexec/src/catalog.rs +++ b/crates/sqlexec/src/catalog.rs @@ -175,4 +175,9 @@ impl SchemaProvider for SchemaCatalog { let tables = self.tables.read(); tables.contains_key(name) } + + fn deregister_table(&self, name: &str) -> DfResult>> { + let mut tables = self.tables.write(); + Ok(tables.remove(name)) + } } diff --git a/crates/sqlexec/src/executor.rs b/crates/sqlexec/src/executor.rs index 4900a7f5c..da987e652 100644 --- a/crates/sqlexec/src/executor.rs +++ b/crates/sqlexec/src/executor.rs @@ -24,6 +24,8 @@ pub enum ExecutionResult { CreateTable, /// A client local variable was set. SetLocal, + /// Tables dropped. + DropTables, } impl fmt::Debug for ExecutionResult { @@ -36,6 +38,7 @@ impl fmt::Debug for ExecutionResult { ExecutionResult::WriteSuccess => write!(f, "write success"), ExecutionResult::CreateTable => write!(f, "create table"), ExecutionResult::SetLocal => write!(f, "set local"), + ExecutionResult::DropTables => write!(f, "drop tables"), } } } @@ -107,6 +110,12 @@ impl<'a> Executor<'a> { let stream = self.session.execute_physical(physical)?; Ok(ExecutionResult::Query { stream }) } + LogicalPlan::Runtime => { + // TODO: We'll want to: + // 1. Actually do something here. + // 2. Probably return a different variant for global SET statements. + Ok(ExecutionResult::SetLocal) + } other => Err(internal!("unimplemented logical plan: {:?}", other)), } } diff --git a/crates/sqlexec/src/extended.rs b/crates/sqlexec/src/extended.rs new file mode 100644 index 000000000..60273251f --- /dev/null +++ b/crates/sqlexec/src/extended.rs @@ -0,0 +1,56 @@ +use crate::{ + errors::Result, + logical_plan::LogicalPlan, +}; + +// A prepared statement. +// This is contains the SQL statements that will later be turned into a +// portal when a Bind message is received. +#[derive(Debug)] +pub struct PreparedStatement { + pub sql: String, + pub param_types: Vec, +} + +impl PreparedStatement { + pub fn new(sql: String, param_types: Vec) -> Self { + // TODO: parse the SQL for placeholders + Self { sql, param_types } + } + + /// The Describe message statement variant returns a ParameterDescription message describing + /// the parameters needed by the statement, followed by a RowDescription message describing the + /// rows that will be returned when the statement is eventually executed. + /// If the statement will not return rows, then a NoData message is returned. + pub fn describe(&self) { + // since bind has not been issued, the formats to be used for returned columns are not yet + // known. In this case, the backend will assume the default format (text) for all columns. + todo!("describe statement") + } +} + +/// A Portal is the result of a prepared statement that has been bound with the Bind message. +/// The portal is a readied execution plan that can be executed using an Execute message. +#[derive(Debug)] +pub struct Portal { + pub plan: LogicalPlan, + pub param_formats: Vec, + pub param_values: Vec>>, + pub result_formats: Vec, +} + +impl Portal { + pub fn new( + plan: LogicalPlan, + param_formats: Vec, + param_values: Vec>>, + result_formats: Vec, + ) -> Result { + Ok(Self { + plan, + param_formats, + param_values, + result_formats, + }) + } +} diff --git a/crates/sqlexec/src/lib.rs b/crates/sqlexec/src/lib.rs index 63c3a30ff..b168810fc 100644 --- a/crates/sqlexec/src/lib.rs +++ b/crates/sqlexec/src/lib.rs @@ -3,6 +3,7 @@ pub mod datasource; pub mod engine; pub mod errors; pub mod executor; +pub mod extended; pub mod logical_plan; pub mod session; diff --git a/crates/sqlexec/src/logical_plan.rs b/crates/sqlexec/src/logical_plan.rs index b55665046..70c5279bb 100644 --- a/crates/sqlexec/src/logical_plan.rs +++ b/crates/sqlexec/src/logical_plan.rs @@ -1,9 +1,9 @@ use crate::errors::{internal, ExecError}; use datafusion::arrow::datatypes::Field; use datafusion::logical_plan::LogicalPlan as DfLogicalPlan; -use datafusion::sql::sqlparser::ast; +use datafusion::sql::sqlparser::ast::{self}; -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum LogicalPlan { /// DDL plans. Ddl(DdlPlan), @@ -14,6 +14,9 @@ pub enum LogicalPlan { Query(DfLogicalPlan), /// Plans related to transaction management. Transaction(TransactionPlan), + /// Plans related to altering the state or runtime of the session. + // TODO: Actually implement this. This would correspond to "SET ..." and "SET SESSION ..." statements. + Runtime, } impl From for LogicalPlan { @@ -22,7 +25,7 @@ impl From for LogicalPlan { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum WritePlan { Insert(Insert), } @@ -33,7 +36,7 @@ impl From for LogicalPlan { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Insert { pub table_name: String, pub columns: Vec, @@ -44,12 +47,13 @@ pub struct Insert { /// /// Note that while datafusion has some support for DDL, it's very much focused /// on working with "external" data that won't be modified like parquet files. -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum DdlPlan { CreateSchema(CreateSchema), CreateTable(CreateTable), CreateExternalTable(CreateExternalTable), CreateTableAs(CreateTableAs), + DropTable(DropTable), } impl From for LogicalPlan { @@ -58,20 +62,20 @@ impl From for LogicalPlan { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct CreateSchema { pub schema_name: String, pub if_not_exists: bool, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct CreateTable { pub table_name: String, pub if_not_exists: bool, pub columns: Vec, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum FileType { Parquet, } @@ -86,20 +90,26 @@ impl TryFrom for FileType { } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct CreateExternalTable { pub table_name: String, pub location: String, pub file_type: FileType, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct CreateTableAs { pub table_name: String, pub source: DfLogicalPlan, } -#[derive(Debug)] +#[derive(Clone, Debug)] +pub struct DropTable { + pub names: Vec, + pub if_exists: bool, +} + +#[derive(Clone, Debug)] pub enum TransactionPlan { Begin, Commit, diff --git a/crates/sqlexec/src/session.rs b/crates/sqlexec/src/session.rs index d5faa5131..8aa582430 100644 --- a/crates/sqlexec/src/session.rs +++ b/crates/sqlexec/src/session.rs @@ -1,6 +1,8 @@ use crate::catalog::{DatabaseCatalog, DEFAULT_SCHEMA}; use crate::datasource::MemTable; use crate::errors::{internal, Result}; +use crate::executor::ExecutionResult; +use crate::extended::{Portal, PreparedStatement}; use crate::logical_plan::*; use datafusion::arrow::datatypes::{Field, Schema}; use datafusion::catalog::catalog::CatalogList; @@ -15,9 +17,12 @@ use datafusion::physical_plan::{ SendableRecordBatchStream, }; use datafusion::sql::planner::{convert_data_type, SqlToRel}; -use datafusion::sql::sqlparser::ast; +use datafusion::sql::sqlparser::ast::{self, ObjectType}; +use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; +use datafusion::sql::sqlparser::parser::Parser; use datafusion::sql::{ResolvedTableReference, TableReference}; use futures::StreamExt; +use hashbrown::hash_map::Entry; use std::sync::Arc; use tracing::debug; @@ -32,6 +37,12 @@ pub struct Session { /// The concretely typed "GlareDB" catalog. catalog: Arc, // TODO: Transaction context goes here. + + // prepared statements + unnamed_statement: Option, + named_statements: hashbrown::HashMap, + unnamed_portal: Option, + named_portals: hashbrown::HashMap, } impl Session { @@ -52,7 +63,14 @@ impl Session { let mut state = SessionState::with_config_rt(config, runtime); state.catalog_list = catalog.clone(); - Session { state, catalog } + Session { + state, + catalog, + unnamed_statement: None, + named_statements: hashbrown::HashMap::new(), + unnamed_portal: None, + named_portals: hashbrown::HashMap::new(), + } } pub(crate) fn plan_sql(&self, statement: ast::Statement) -> Result { @@ -167,7 +185,28 @@ impl Session { .into()) } - stmt => Err(internal!("unsupported sql statement: {}", stmt)), + // Drop tables + ast::Statement::Drop { + object_type: ObjectType::Table, + if_exists, + names, + .. + } => { + let names = names + .into_iter() + .map(|name| name.to_string()) + .collect::>(); + + Ok(DdlPlan::DropTable(DropTable { if_exists, names }).into()) + } + + // "SET ...", "SET SESSION ...", "SET LOCAL ..." + // TODO: Actually plan this. + ast::Statement::SetVariable { .. } | ast::Statement::SetRole { .. } => { + Ok(LogicalPlan::Runtime) + } + + stmt => Err(internal!("unsupported sql statement: {:?}", stmt)), } } @@ -320,4 +359,172 @@ impl Session { .ok_or_else(|| internal!("missing schema: {}", resolved.schema))?; Ok(schema) } + + /// Store the prepared statement in the current session. + /// It will later be readied for execution by using `bind_prepared_statement`. + pub fn create_prepared_statement( + &mut self, + name: Option, + sql: String, + params: Vec, + ) -> Result<()> { + match name { + None => { + // Store the unnamed prepared statement. + // This will persist until the session is dropped or another unnamed prepared statement is created + self.unnamed_statement = Some(PreparedStatement::new(sql, params)); + } + Some(name) => { + // Named prepared statements must be explicitly closed before being redefined + match self.named_statements.entry(name) { + Entry::Occupied(ent) => { + return Err(internal!( + "prepared statement already exists: {}", + ent.key() + )) + } + Entry::Vacant(ent) => { + ent.insert(PreparedStatement::new(sql, params)); + } + } + } + } + + Ok(()) + } + + pub fn get_prepared_statement(&self, name: &Option) -> Option<&PreparedStatement> { + match name { + None => self.unnamed_statement.as_ref(), + Some(name) => self.named_statements.get(name), + } + } + + pub fn get_portal(&self, portal_name: &Option) -> Option<&Portal> { + match portal_name { + None => self.unnamed_portal.as_ref(), + Some(name) => self.named_portals.get(name), + } + } + + /// Bind the parameters of a prepared statement to the given values. + /// If successful, the bound statement will create a portal which can be used to execute the statement. + pub fn bind_prepared_statement( + &mut self, + portal_name: Option, + statement_name: Option, + param_formats: Vec, + param_values: Vec>>, + result_formats: Vec, + ) -> Result<()> { + let statement = match statement_name { + None => self + .unnamed_statement + .as_mut() + .ok_or_else(|| internal!("no unnamed prepared statement"))?, + Some(name) => self + .named_statements + .get_mut(&name) + .ok_or_else(|| internal!("no prepared statement named: {}", name))?, + }; + + let statements = Parser::parse_sql(&PostgreSqlDialect {}, &statement.sql)? + .into_iter() + .collect::>(); + + // a portal can only be bound to a single statement + if statements.len() != 1 { + return Err(internal!("portal can only be bound to a single statement")); + } + + let statement = statements.into_iter().next().unwrap(); + + match portal_name { + None => { + // Store the unnamed portal. + // This will persist until the session is dropped or another unnamed portal is created + let plan = self.plan_sql(statement)?; + self.unnamed_portal = Some(Portal::new( + plan, + param_formats, + param_values, + result_formats, + )?); + } + Some(name) => { + // Named portals must be explicitly closed before being redefined + match self.named_portals.entry(name) { + Entry::Occupied(ent) => { + return Err(internal!("portal already exists: {}", ent.key())) + } + Entry::Vacant(_ent) => { + todo!("plan named portal"); + // let plan = self.plan_sql(statement)?; + // ent.insert(Portal::new(plan, param_formats, param_values, result_formats)?); + } + } + } + } + + Ok(()) + } + + pub async fn drop_table(&self, plan: DropTable) -> Result<()> { + debug!(?plan, "drop table"); + for name in plan.names { + let resolved = self.resolve_table_name(&name); + let schema = self.get_schema_for_reference(&resolved)?; + + schema.deregister_table(resolved.table)?; + } + + Ok(()) + } + + pub async fn execute_portal( + &mut self, + portal_name: &Option, + _max_rows: i32, + ) -> Result { + // TODO: respect max_rows + let portal = match portal_name { + None => self + .unnamed_portal + .as_mut() + .ok_or_else(|| internal!("no unnamed portal"))?, + Some(name) => self + .named_portals + .get_mut(name) + .ok_or_else(|| internal!("no portal named: {}", name))?, + }; + + match portal.plan.clone() { + LogicalPlan::Ddl(DdlPlan::CreateTable(plan)) => { + self.create_table(plan)?; + Ok(ExecutionResult::CreateTable) + } + LogicalPlan::Ddl(DdlPlan::CreateExternalTable(plan)) => { + self.create_external_table(plan).await?; + Ok(ExecutionResult::CreateTable) + } + LogicalPlan::Ddl(DdlPlan::CreateTableAs(plan)) => { + self.create_table_as(plan).await?; + Ok(ExecutionResult::CreateTable) + } + LogicalPlan::Ddl(DdlPlan::DropTable(plan)) => { + self.drop_table(plan).await?; + Ok(ExecutionResult::DropTables) + } + LogicalPlan::Write(WritePlan::Insert(plan)) => { + self.insert(plan).await?; + Ok(ExecutionResult::WriteSuccess) + } + LogicalPlan::Query(plan) => { + let physical = self.create_physical_plan(plan).await?; + let stream = self.execute_physical(physical)?; + Ok(ExecutionResult::Query { stream }) + } + other => Err(internal!("unimplemented logical plan: {:?}", other)), + } + } }