diff --git a/crates/pgrepr/src/error.rs b/crates/pgrepr/src/error.rs
index 0d3294486..49a03749c 100644
--- a/crates/pgrepr/src/error.rs
+++ b/crates/pgrepr/src/error.rs
@@ -12,6 +12,18 @@ pub enum PgReprError {
     #[error(transparent)]
     Io(#[from] std::io::Error),
 
+    #[error(transparent)]
+    Utf8Error(#[from] std::str::Utf8Error),
+
+    #[error("Binary read unimplemented.")]
+    BinaryReadUnimplemented,
+
+    #[error("Failed to parse: {0}")]
+    ParseError(Box<dyn std::error::Error + Send + Sync>),
+
+    #[error("Unsupported pg type for decoding: {0}")]
+    UnsupportedPgTypeForDecode(tokio_postgres::types::Type),
+
     #[error("arrow type '{0}' not supported")]
     UnsupportedArrowType(datafusion::arrow::datatypes::DataType),
diff --git a/crates/pgrepr/src/lib.rs b/crates/pgrepr/src/lib.rs
index 03eec2ac5..9de3b3b39 100644
--- a/crates/pgrepr/src/lib.rs
+++ b/crates/pgrepr/src/lib.rs
@@ -3,4 +3,5 @@ pub mod format;
 pub mod oid;
 pub mod types;
 
+mod reader;
 mod writer;
diff --git a/crates/pgrepr/src/reader.rs b/crates/pgrepr/src/reader.rs
new file mode 100644
index 000000000..7c1ca3c02
--- /dev/null
+++ b/crates/pgrepr/src/reader.rs
@@ -0,0 +1,98 @@
+use crate::error::{PgReprError, Result};
+use std::str::FromStr;
+
+/// Reader defines the interface for the different kinds of values that can be
+/// decoded as a postgres type.
+pub(crate) trait Reader {
+    fn read_bool(buf: &[u8]) -> Result<bool>;
+
+    fn read_int2(buf: &[u8]) -> Result<i16>;
+    fn read_int4(buf: &[u8]) -> Result<i32>;
+    fn read_int8(buf: &[u8]) -> Result<i64>;
+    fn read_float4(buf: &[u8]) -> Result<f32>;
+    fn read_float8(buf: &[u8]) -> Result<f64>;
+
+    fn read_text(buf: &[u8]) -> Result<String>;
+}
+
+#[derive(Debug)]
+pub(crate) struct TextReader;
+
+impl TextReader {
+    fn parse<E: std::error::Error + Send + Sync + 'static, T: FromStr<Err = E>>(
+        buf: &[u8],
+    ) -> Result<T> {
+        std::str::from_utf8(buf)?
+            .parse::<T>()
+            .map_err(|e| PgReprError::ParseError(Box::new(e)))
+    }
+}
+
+impl Reader for TextReader {
+    fn read_bool(buf: &[u8]) -> Result<bool> {
+        Self::parse::<_, SqlBool>(buf).map(|b| b.0)
+    }
+
+    fn read_int2(buf: &[u8]) -> Result<i16> {
+        Self::parse(buf)
+    }
+
+    fn read_int4(buf: &[u8]) -> Result<i32> {
+        Self::parse(buf)
+    }
+
+    fn read_int8(buf: &[u8]) -> Result<i64> {
+        Self::parse(buf)
+    }
+
+    fn read_float4(buf: &[u8]) -> Result<f32> {
+        Self::parse(buf)
+    }
+
+    fn read_float8(buf: &[u8]) -> Result<f64> {
+        Self::parse(buf)
+    }
+
+    fn read_text(buf: &[u8]) -> Result<String> {
+        Self::parse(buf)
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("String was not 't', 'true', 'f', or 'false'")]
+struct ParseSqlBoolError;
+
+struct SqlBool(bool);
+
+impl FromStr for SqlBool {
+    type Err = ParseSqlBoolError;
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "true" | "t" => Ok(SqlBool(true)),
+            "false" | "f" => Ok(SqlBool(false)),
+            _ => Err(ParseSqlBoolError),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn read_sql_bool() {
+        let v = TextReader::read_bool("t".as_bytes()).unwrap();
+        assert!(v);
+
+        let v = TextReader::read_bool("true".as_bytes()).unwrap();
+        assert!(v);
+
+        let v = TextReader::read_bool("f".as_bytes()).unwrap();
+        assert!(!v);
+
+        let v = TextReader::read_bool("false".as_bytes()).unwrap();
+        assert!(!v);
+
+        let _ = TextReader::read_bool("none".as_bytes()).unwrap_err();
+    }
+}
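Review note: `TextReader` leans entirely on UTF-8 decoding plus `FromStr`, with bools special-cased through `SqlBool` so postgres's 't'/'f' shorthand is accepted. A minimal, self-contained sketch of that parse pattern, using a plain boxed error rather than the crate's `Result` alias:

```rust
use std::str::FromStr;

/// Standalone stand-in for `TextReader::parse`: decode the buffer as UTF-8,
/// then defer to `FromStr` for the target Rust type.
fn parse<T>(buf: &[u8]) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
where
    T: FromStr,
    T::Err: std::error::Error + Send + Sync + 'static,
{
    // Postgres text format sends values as UTF-8 strings, e.g. "1" for int8.
    Ok(std::str::from_utf8(buf)?.parse::<T>()?)
}

fn main() {
    assert_eq!(1i64, parse::<i64>(b"1").unwrap());
    assert_eq!(0.5f64, parse::<f64>(b"0.5").unwrap());
    assert!(parse::<i16>(b"abc").is_err());
}
```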
diff --git a/crates/pgrepr/src/types.rs b/crates/pgrepr/src/types.rs
index 248d83d6f..232f914e6 100644
--- a/crates/pgrepr/src/types.rs
+++ b/crates/pgrepr/src/types.rs
@@ -10,6 +10,7 @@ use tokio_postgres::types::Type as PgType;
 
 use crate::error::{PgReprError, Result};
 use crate::format::Format;
+use crate::reader::{Reader, TextReader};
 use crate::writer::{BinaryWriter, TextWriter, Writer};
 
 /// Returns a compatible postgres type for the arrow datatype.
@@ -78,6 +79,42 @@ pub fn encode_array_value(
     Ok(())
 }
 
+/// Decodes a scalar value using the provided format and pg type.
+pub fn decode_scalar_value(
+    buf: Option<&[u8]>,
+    format: Format,
+    pg_type: &PgType,
+) -> Result<ScalarValue> {
+    match buf {
+        Some(buf) => match format {
+            Format::Text => decode_not_null_value::<TextReader>(buf, pg_type),
+            Format::Binary => Err(PgReprError::BinaryReadUnimplemented),
+        },
+        None => Ok(ScalarValue::Null),
+    }
+}
+
+/// Decodes a non-null value from the buffer of the given pg type.
+// TODO: We currently get the pg type from the inferred arrow type during
+// parameter type resolution. This can lead to a loss of type information or
+// may result in an inexact conversion (e.g. arrow doesn't have a JSON type,
+// but postgres does).
+//
+// We may be able to work around this by inferring pg types directly instead
+// of needing to infer arrow types, then convert those into postgres types.
+fn decode_not_null_value<R: Reader>(buf: &[u8], pg_type: &PgType) -> Result<ScalarValue> {
+    Ok(match pg_type {
+        &PgType::BOOL => R::read_bool(buf)?.into(),
+        &PgType::INT2 => R::read_int2(buf)?.into(),
+        &PgType::INT4 => R::read_int4(buf)?.into(),
+        &PgType::INT8 => R::read_int8(buf)?.into(),
+        &PgType::FLOAT4 => R::read_float4(buf)?.into(),
+        &PgType::FLOAT8 => R::read_float8(buf)?.into(),
+        &PgType::TEXT => ScalarValue::Utf8(Some(R::read_text(buf)?)),
+        other => return Err(PgReprError::UnsupportedPgTypeForDecode(other.clone())),
+    })
+}
+
 /// Per writer implementation for encoding non-null array values.
 fn encode_array_not_null_value(
     buf: &mut BytesMut,
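Review note: `decode_scalar_value` splits on NULL-ness first, then wire format, then pg type. A self-contained sketch of that dispatch shape; `Scalar` and the string type names are local stand-ins for datafusion's `ScalarValue` and `tokio_postgres::types::Type`:

```rust
// Local stand-in for ScalarValue, just enough to show the dispatch.
#[derive(Debug, PartialEq)]
enum Scalar {
    Null,
    Int8(i64),
    Utf8(String),
}

fn decode_text(buf: Option<&[u8]>, type_name: &str) -> Result<Scalar, String> {
    // An absent buffer decodes to NULL regardless of the declared type.
    let buf = match buf {
        Some(buf) => buf,
        None => return Ok(Scalar::Null),
    };
    let s = std::str::from_utf8(buf).map_err(|e| e.to_string())?;
    match type_name {
        "int8" => s
            .parse()
            .map(Scalar::Int8)
            .map_err(|e: std::num::ParseIntError| e.to_string()),
        "text" => Ok(Scalar::Utf8(s.to_string())),
        other => Err(format!("unsupported pg type for decoding: {other}")),
    }
}

fn main() {
    assert_eq!(Scalar::Int8(42), decode_text(Some(b"42"), "int8").unwrap());
    assert_eq!(Scalar::Null, decode_text(None, "int8").unwrap());
    assert!(decode_text(Some(b"{}"), "jsonb").is_err());
}
```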
diff --git a/crates/pgsrv/src/handler.rs b/crates/pgsrv/src/handler.rs
index d902e05b5..e5f084ccb 100644
--- a/crates/pgsrv/src/handler.rs
+++ b/crates/pgsrv/src/handler.rs
@@ -7,8 +7,10 @@ use crate::messages::{
 use crate::proxy::{ProxyKey, GLAREDB_DATABASE_ID_KEY, GLAREDB_USER_ID_KEY};
 use crate::ssl::{Connection, SslConfig};
 use datafusion::physical_plan::SendableRecordBatchStream;
+use datafusion::scalar::ScalarValue;
 use futures::StreamExt;
 use pgrepr::format::Format;
+use pgrepr::types::decode_scalar_value;
 use sqlexec::context::{OutputFields, Portal, PreparedStatement};
 use sqlexec::{
     engine::Engine,
@@ -369,13 +371,9 @@ where
         );
 
         // Bind...
-        if let Err(e) = session.bind_statement(
-            UNNAMED,
-            &UNNAMED,
-            Vec::new(),
-            Vec::new(),
-            all_text_formats(num_fields),
-        ) {
+        if let Err(e) =
+            session.bind_statement(UNNAMED, &UNNAMED, Vec::new(), all_text_formats(num_fields))
+        {
             self.send_error(e.into()).await?;
             return self.ready_for_query().await;
         }
@@ -462,7 +460,14 @@ where
             Err(e) => return self.send_error(e.into()).await,
         };
 
-        // TODO: Check and parse param formats and values.
+        // Read scalars for query parameters.
+        let scalars = match stmt.input_parameters() {
+            Some(types) => match decode_param_scalars(param_formats, param_values, types) {
+                Ok(scalars) => scalars,
+                Err(e) => return self.send_error(e).await,
+            },
+            None => Vec::new(), // Would only happen with an empty query.
+        };
 
         // Extend out the result formats.
         let result_formats = match extend_formats(
@@ -473,13 +478,10 @@ where
             Err(e) => return self.send_error(e).await,
         };
 
-        match self.session.bind_statement(
-            portal,
-            &statement,
-            param_formats,
-            param_values,
-            result_formats,
-        ) {
+        match self
+            .session
+            .bind_statement(portal, &statement, scalars, result_formats)
+        {
             Ok(_) => self.conn.send(BackendMessage::BindComplete).await,
             Err(e) => self.send_error(e.into()).await,
         }
@@ -648,6 +650,55 @@ where
     }
 }
 
+/// Decodes inputs for a prepared query into the appropriate scalar values.
+fn decode_param_scalars(
+    param_formats: Vec<Format>,
+    param_values: Vec<Option<Vec<u8>>>,
+    types: &HashMap<String, Option<PgType>>,
+) -> Result<Vec<ScalarValue>, ErrorResponse> {
+    let param_formats = extend_formats(param_formats, param_values.len())?;
+
+    if param_values.len() != types.len() {
+        return Err(ErrorResponse::error_internal(format!(
+            "Invalid number of values provided. Expected: {}, got: {}",
+            types.len(),
+            param_values.len(),
+        )));
+    }
+
+    let mut scalars = Vec::with_capacity(param_values.len());
+    for (idx, (val, format)) in param_values
+        .into_iter()
+        .zip(param_formats.into_iter())
+        .enumerate()
+    {
+        // Parameter types keyed by '$n'.
+        let str_id = format!("${}", idx + 1);
+
+        let typ = types.get(&str_id).ok_or_else(|| {
+            ErrorResponse::error_internal(format!(
+                "Missing type for param value at index {}, input types: {:?}",
+                idx, types
+            ))
+        })?;
+
+        match typ {
+            Some(typ) => {
+                let scalar = decode_scalar_value(val.as_deref(), format, typ)?;
+                scalars.push(scalar);
+            }
+            None => {
+                return Err(ErrorResponse::error_internal(format!(
+                    "Unknown type at index {}, input types: {:?}",
+                    idx, types
+                )))
+            }
+        }
+    }
+
+    Ok(scalars)
+}
+
 /// Parse a sql string, returning an error response if failed to parse.
 fn parse_sql(sql: &str) -> Result<VecDeque<StatementWithExtensions>, ErrorResponse> {
     parser::parse_sql(sql).map_err(|e| ErrorResponse::error(SqlState::SyntaxError, e.to_string()))
@@ -684,3 +735,88 @@ fn get_encoding_state(portal: &Portal) -> Vec<(PgType, Format)> {
             .collect(),
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn decode_params_success() {
+        // Success test cases for decoding params.
+
+        struct TestCase {
+            values: Vec<Option<Vec<u8>>>,
+            types: Vec<(&'static str, Option<PgType>)>,
+            expected: Vec<ScalarValue>,
+        }
+
+        let test_cases = vec![
+            // No params.
+            TestCase {
+                values: Vec::new(),
+                types: Vec::new(),
+                expected: Vec::new(),
+            },
+            // One param of type int64.
+            TestCase {
+                values: vec![Some(vec![49])],
+                types: vec![("$1", Some(PgType::INT8))],
+                expected: vec![ScalarValue::Int64(Some(1))],
+            },
+            // Two params of type string.
+            TestCase {
+                values: vec![Some(vec![49, 48]), Some(vec![50, 48])],
+                types: vec![("$1", Some(PgType::TEXT)), ("$2", Some(PgType::TEXT))],
+                expected: vec![
+                    ScalarValue::Utf8(Some("10".to_string())),
+                    ScalarValue::Utf8(Some("20".to_string())),
+                ],
+            },
+        ];
+
+        for test_case in test_cases {
+            let types: HashMap<_, _> = test_case
+                .types
+                .into_iter()
+                .map(|(k, v)| (k.to_string(), v))
+                .collect();
+
+            let scalars = decode_param_scalars(Vec::new(), test_case.values, &types).unwrap();
+            assert_eq!(test_case.expected, scalars);
+        }
+    }
+
+    #[test]
+    fn decode_params_fail() {
+        // Failure test cases for decoding params (all cases should result in
+        // an error).
+
+        struct TestCase {
+            values: Vec<Option<Vec<u8>>>,
+            types: Vec<(&'static str, Option<PgType>)>,
+        }
+
+        let test_cases = vec![
+            // Params provided, none expected.
+            TestCase {
+                values: vec![Some(vec![49])],
+                types: Vec::new(),
+            },
+            // No params provided, one expected.
+            TestCase {
+                values: Vec::new(),
+                types: vec![("$1", Some(PgType::INT8))],
+            },
+        ];
+
+        for test_case in test_cases {
+            let types: HashMap<_, _> = test_case
+                .types
+                .into_iter()
+                .map(|(k, v)| (k.to_string(), v))
+                .collect();
+
+            decode_param_scalars(Vec::new(), test_case.values, &types).unwrap_err();
+        }
+    }
+}
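Review note: the subtle part of `decode_param_scalars` is the `$n` keying: Bind delivers values positionally, while the inferred types are keyed "$1", "$2", and so on, so index `i` maps to key `$(i+1)`. A self-contained sketch of that zip (string type names stand in for the real type enum; the error text is illustrative):

```rust
use std::collections::HashMap;

// Pair each positional value with its "$n"-keyed type, failing on a miss.
fn zip_params<'a>(
    values: &'a [&'a str],
    types: &HashMap<String, &'static str>,
) -> Result<Vec<(&'a str, &'static str)>, String> {
    values
        .iter()
        .enumerate()
        .map(|(idx, val)| {
            let id = format!("${}", idx + 1);
            types
                .get(&id)
                .map(|typ| (*val, *typ))
                .ok_or_else(|| format!("missing type for param value at index {idx}"))
        })
        .collect()
}

fn main() {
    let types = HashMap::from([("$1".to_string(), "int8")]);
    // On the wire the text value "1" is the single byte 49 (ASCII '1'),
    // which is why the unit tests above use `vec![49]`.
    assert_eq!(vec![("1", "int8")], zip_params(&["1"], &types).unwrap());
    assert!(zip_params(&["1", "2"], &types).is_err());
}
```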
diff --git a/crates/pgsrv/src/messages.rs b/crates/pgsrv/src/messages.rs
index 378c84b9d..acc880477 100644
--- a/crates/pgsrv/src/messages.rs
+++ b/crates/pgsrv/src/messages.rs
@@ -1,4 +1,5 @@
 use datafusion::arrow::record_batch::RecordBatch;
+use pgrepr::error::PgReprError;
 use pgrepr::format::Format;
 use sqlexec::errors::ExecError;
 use std::collections::HashMap;
@@ -256,6 +257,13 @@ impl From<&PgSrvError> for ErrorResponse {
     }
 }
 
+impl From<PgReprError> for ErrorResponse {
+    fn from(e: PgReprError) -> Self {
+        // TODO: Actually set appropriate codes.
+        ErrorResponse::error_internal(e.to_string())
+    }
+}
+
 #[derive(Debug)]
 pub enum NoticeSeverity {
     Warning,
diff --git a/crates/sqlexec/src/context.rs b/crates/sqlexec/src/context.rs
index fed2fa673..0c7b21019 100644
--- a/crates/sqlexec/src/context.rs
+++ b/crates/sqlexec/src/context.rs
@@ -15,6 +15,7 @@ use datafusion::error::{DataFusionError, Result as DataFusionResult};
 use datafusion::execution::context::{SessionConfig, SessionState, TaskContext};
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::logical_expr::{AggregateUDF, ScalarUDF, TableSource};
+use datafusion::scalar::ScalarValue;
 use datafusion::sql::planner::ContextProvider;
 use datafusion::sql::TableReference;
 use datasource_common::ssh::SshTunnelAccess;
@@ -299,7 +300,7 @@ impl SessionContext {
         &mut self,
         name: String,
         stmt: Option<StatementWithExtensions>,
-        params: Vec<i32>,
+        _params: Vec<i32>, // TODO: We can use these for providing types for parameters.
     ) -> Result<()> {
         // Refresh the cached catalog state if necessary
         if *self.get_session_vars().force_catalog_refresh.value() {
@@ -317,12 +318,6 @@ impl SessionContext {
             self.metastore_catalog.swap_state(new_state);
         }
 
-        if !params.is_empty() {
-            return Err(ExecError::UnsupportedFeature(
-                "prepared statements with parameters",
-            ));
-        }
-
         // Unnamed (empty string) prepared statements can be overwritten
         // whenever. Named prepared statements must be explicitly removed before
         // being used again.
@@ -348,6 +343,7 @@ impl SessionContext {
         &mut self,
         portal_name: String,
         stmt_name: &str,
+        params: Vec<ScalarValue>,
         result_formats: Vec<Format>,
     ) -> Result<()> {
         // Unnamed portals can be overwritten, named portals need to be
@@ -359,11 +355,16 @@ impl SessionContext {
             ));
         }
 
-        let stmt = match self.prepared.get(stmt_name) {
+        let mut stmt = match self.prepared.get(stmt_name) {
             Some(prepared) => prepared.clone(),
             None => return Err(ExecError::UnknownPreparedStatement(stmt_name.to_string())),
         };
 
+        // Replace placeholders if necessary.
+        if let Some(plan) = &mut stmt.plan {
+            plan.replace_placeholders(params)?;
+        }
+
         assert_eq!(
             result_formats.len(),
             stmt.output_fields().map(|f| f.len()).unwrap_or(0)
@@ -545,6 +546,8 @@ pub struct PreparedStatement {
     /// The logical plan for the statement. Is `Some` if the statement is
     /// `Some`.
     pub(crate) plan: Option<LogicalPlan>,
+    /// Parameter data types.
+    pub(crate) parameter_types: Option<HashMap<String, Option<PgType>>>,
     /// The output schema of the statement if it produces an output.
     pub(crate) output_schema: Option<ArrowSchema>,
     /// Output postgres types.
@@ -571,9 +574,21 @@ impl PreparedStatement {
             None => Vec::new(),
         };
 
+        // Convert inferred arrow types for parameters into their associated
+        // pg type.
+        let parameter_types: HashMap<_, _> = plan
+            .get_parameter_types()?
+            .into_iter()
+            .map(|(id, arrow_type)| {
+                let typ = arrow_type.map(|typ| arrow_to_pg_type(&typ, None));
+                (id, typ)
+            })
+            .collect();
+
         Ok(PreparedStatement {
             stmt: Some(inner),
             plan: Some(plan),
+            parameter_types: Some(parameter_types),
             output_schema: schema,
             output_pg_types: pg_types,
         })
@@ -582,6 +597,7 @@ impl PreparedStatement {
         Ok(PreparedStatement {
             stmt: None,
             plan: None,
+            parameter_types: None,
            output_schema: None,
             output_pg_types: Vec::new(),
         })
@@ -597,6 +613,12 @@ impl PreparedStatement {
             result_formats: None,
         })
     }
+
+    /// Returns the types of the input parameters. Input parameters are keyed
+    /// as "$n" starting at "$1".
+    pub fn input_parameters(&self) -> Option<&HashMap<String, Option<PgType>>> {
+        self.parameter_types.as_ref()
+    }
 }
 
 #[derive(Debug, Clone)]
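Review note: `parameter_types` stores pg types converted from whatever arrow types datafusion inferred; a placeholder whose type could not be resolved stays `None` and is rejected later at bind time. A stand-in sketch of that map conversion (`arrow_to_pg` and the string type names are hypothetical simplifications of pgrepr's `arrow_to_pg_type` and the real type enums):

```rust
use std::collections::HashMap;

// Hypothetical stand-in for pgrepr's `arrow_to_pg_type`.
fn arrow_to_pg(arrow: &str) -> &'static str {
    match arrow {
        "Int64" => "int8",
        "Utf8" => "text",
        _ => "unknown",
    }
}

// Mirrors the "$n" -> pg type map built when preparing a statement; `None`
// means datafusion could not resolve the placeholder's type.
fn convert(params: HashMap<String, Option<String>>) -> HashMap<String, Option<&'static str>> {
    params
        .into_iter()
        .map(|(id, typ)| (id, typ.as_deref().map(arrow_to_pg)))
        .collect()
}

fn main() {
    let inferred = HashMap::from([("$1".to_string(), Some("Int64".to_string()))]);
    assert_eq!(Some(&Some("int8")), convert(inferred).get("$1"));
}
```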
diff --git a/crates/sqlexec/src/logical_plan.rs b/crates/sqlexec/src/logical_plan.rs
index b6eaa5477..cd92ce482 100644
--- a/crates/sqlexec/src/logical_plan.rs
+++ b/crates/sqlexec/src/logical_plan.rs
@@ -1,8 +1,10 @@
 use crate::errors::{internal, Result};
 use datafusion::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
 use datafusion::logical_expr::LogicalPlan as DfLogicalPlan;
+use datafusion::scalar::ScalarValue;
 use datafusion::sql::sqlparser::ast;
 use metastore::types::catalog::{ConnectionOptions, TableOptions};
+use std::collections::HashMap;
 
 #[derive(Clone, Debug)]
 pub enum LogicalPlan {
@@ -42,6 +44,29 @@ impl LogicalPlan {
             _ => None,
         }
     }
+
+    /// Get parameter types for the logical plan.
+    ///
+    /// Note this will only try to get the parameters if the plan is a
+    /// datafusion logical plan. Possible support for other plans may come
+    /// later.
+    pub fn get_parameter_types(&self) -> Result<HashMap<String, Option<DataType>>> {
+        Ok(match self {
+            LogicalPlan::Query(plan) => plan.get_parameter_types()?,
+            _ => HashMap::new(),
+        })
+    }
+
+    /// Replace placeholders in this plan with the provided scalars.
+    ///
+    /// Note this currently only replaces placeholders for datafusion plans.
+    pub fn replace_placeholders(&mut self, scalars: Vec<ScalarValue>) -> Result<()> {
+        if let LogicalPlan::Query(plan) = self {
+            *plan = plan.replace_params_with_values(&scalars)?;
+        }
+
+        Ok(())
+    }
 }
 
 impl From<DfLogicalPlan> for LogicalPlan {
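Review note: `replace_placeholders` defers to datafusion's `replace_params_with_values`, which walks the plan and substitutes each `Placeholder` node with its bound literal. A toy, self-contained version of the same idea on a small expression tree:

```rust
// Minimal expression tree standing in for a datafusion logical plan.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Literal(i64),
    Placeholder(String),
    Add(Box<Expr>, Box<Expr>),
}

fn replace(expr: Expr, params: &[i64]) -> Result<Expr, String> {
    Ok(match expr {
        Expr::Placeholder(id) => {
            // Placeholders are 1-indexed: "$1" binds params[0].
            let n: usize = id
                .strip_prefix('$')
                .and_then(|n| n.parse().ok())
                .ok_or_else(|| format!("bad placeholder: {id}"))?;
            let val = n
                .checked_sub(1)
                .and_then(|i| params.get(i))
                .ok_or_else(|| format!("no value bound for ${n}"))?;
            Expr::Literal(*val)
        }
        Expr::Add(l, r) => Expr::Add(
            Box::new(replace(*l, params)?),
            Box::new(replace(*r, params)?),
        ),
        lit => lit,
    })
}

fn main() {
    // "select $1 + 1", bound with $1 = 41.
    let expr = Expr::Add(
        Box::new(Expr::Placeholder("$1".into())),
        Box::new(Expr::Literal(1)),
    );
    assert_eq!(
        Expr::Add(Box::new(Expr::Literal(41)), Box::new(Expr::Literal(1))),
        replace(expr, &[41]).unwrap()
    );
}
```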
diff --git a/crates/sqlexec/src/session.rs b/crates/sqlexec/src/session.rs
index de26f54a7..82a42161c 100644
--- a/crates/sqlexec/src/session.rs
+++ b/crates/sqlexec/src/session.rs
@@ -10,6 +10,7 @@ use datafusion::logical_expr::LogicalPlan as DfLogicalPlan;
 use datafusion::physical_plan::{
     execute_stream, memory::MemoryStream, ExecutionPlan, SendableRecordBatchStream,
 };
+use datafusion::scalar::ScalarValue;
 use metastore::session::SessionCatalog;
 use pgrepr::format::Format;
 use std::fmt;
@@ -270,18 +271,11 @@ impl Session {
         &mut self,
         portal_name: String,
         stmt_name: &str,
-        param_formats: Vec<Format>,
-        param_values: Vec<Option<Vec<u8>>>,
+        params: Vec<ScalarValue>,
         result_formats: Vec<Format>,
     ) -> Result<()> {
-        // We don't currently support parameters. We're already erroring on
-        // attempting to prepare statements with parameters, so this is just
-        // ensuring that we're not missing anything right now.
-        assert_eq!(0, param_formats.len());
-        assert_eq!(0, param_values.len());
-
         self.ctx
-            .bind_statement(portal_name, stmt_name, result_formats)
+            .bind_statement(portal_name, stmt_name, params, result_formats)
     }
 
     async fn execute_inner(&mut self, plan: LogicalPlan) -> Result<ExecutionResult> {
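Review note: the pgprototest file below drives this whole path over the wire. Each `Bind {"values": [...]}` entry is sent as a length-prefixed text-format value, so by the time it reaches `decode_param_scalars` the handler sees raw UTF-8 bytes (shape shown below, a one-line illustration):

```rust
fn main() {
    // Bind {"values": ["1"]} arrives as one non-null text-format value:
    // the single ASCII byte 0x31 ('1'). A NULL parameter would be `None`.
    let param_values: Vec<Option<Vec<u8>>> = vec![Some(b"1".to_vec())];
    assert_eq!(vec![Some(vec![49u8])], param_values);
}
```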
diff --git a/testdata/pgprototest/extended_params.pt b/testdata/pgprototest/extended_params.pt
new file mode 100644
index 000000000..5926c3003
--- /dev/null
+++ b/testdata/pgprototest/extended_params.pt
@@ -0,0 +1,176 @@
+# Extended query protocol with parameters.
+
+
+# No context for type. Note that this fails with datafusion planning.
+# -> ErrorResponse {"fields":["ERROR","ERROR","XX000","Error during planning: Placeholder type could not be resolved"]}
+
+# send
+# Parse {"query": "select $1"}
+# Bind {"values": ["4"]}
+# Execute
+# Sync
+# ----
+
+# until
+# ReadyForQuery
+# ----
+# ParseComplete
+# BindComplete
+# DataRow {"fields":["4"]}
+# CommandComplete {"tag":"SELECT 1"}
+# ReadyForQuery {"status":"I"}
+
+
+# Type provided. Fails for us, missing type.
+# -> ErrorResponse {"fields":["ERROR","ERROR","XX000","Missing type for param value at index 0, input types: {}"]}
+
+# send
+# Parse {"query": "select $1::text"}
+# Bind {"values": ["5"]}
+# Execute
+# Sync
+# ----
+#
+# until
+# ReadyForQuery
+# ----
+# ParseComplete
+# BindComplete
+# DataRow {"fields":["5"]}
+# CommandComplete {"tag":"SELECT 1"}
+# ReadyForQuery {"status":"I"}
+
+
+# In binary expression (add).
+
+send
+Parse {"query": "select $1 + 1"}
+Bind {"values": ["1"]}
+Execute
+Sync
+----
+
+until
+ReadyForQuery
+----
+ParseComplete
+BindComplete
+DataRow {"fields":["2"]}
+CommandComplete {"tag":"SELECT 1"}
+ReadyForQuery {"status":"I"}
+
+
+# In where clause.
+
+send
+Parse {"query": "select * from (select * from (values (1, 2), (3, 4)) as _) as sub(a, b) where a > $1"}
+Bind {"values": ["2"]}
+Execute
+Sync
+----
+
+until
+ReadyForQuery
+----
+ParseComplete
+BindComplete
+DataRow {"fields":["3","4"]}
+CommandComplete {"tag":"SELECT 1"}
+ReadyForQuery {"status":"I"}
+
+
+# String in where clause.
+
+send
+Parse {"query": "select * from (select * from (values ('10', '20'), ('30', '40')) as _) as sub(a, b) where a = $1"}
+Bind {"values": ["10"]}
+Execute
+Sync
+----
+
+until
+ReadyForQuery
+----
+ParseComplete
+BindComplete
+DataRow {"fields":["10","20"]}
+CommandComplete {"tag":"SELECT 1"}
+ReadyForQuery {"status":"I"}
+
+
+# String parameter in scalar function. Fails for us (datafusion).
+# -> ErrorResponse {"fields":["ERROR","ERROR","XX000","Error during planning: Placeholder type could not be resolved"]}
+
+# send
+# Parse {"query": "select reverse($1)"}
+# Bind {"values": ["hello"]}
+# Execute
+# Sync
+# ----
+#
+# until
+# ReadyForQuery
+# ----
+# ParseComplete
+# BindComplete
+# DataRow {"fields":["olleh"]}
+# CommandComplete {"tag":"SELECT 1"}
+# ReadyForQuery {"status":"I"}
+
+
+# Multiple parameters.
+
+send
+Parse {"query": "select ($1 + 1) >= $2"}
+Bind {"values": ["1", "2"]}
+Execute
+Sync
+----
+
+until
+ReadyForQuery
+----
+ParseComplete
+BindComplete
+DataRow {"fields":["t"]}
+CommandComplete {"tag":"SELECT 1"}
+ReadyForQuery {"status":"I"}
+
+
+# Boolean 'IS' checks. Fails for us, missing type.
+# -> ErrorResponse {"fields":["ERROR","ERROR","XX000","Missing type for param value at index 0, input types: {}"]}
+
+# send
+# Parse {"query": "select $1 is true, $2 is true, $3 is false, $4 is false"}
+# Bind {"values": ["t", "true", "f", "false"]}
+# Execute
+# Sync
+# ----
+#
+# until
+# ReadyForQuery
+# ----
+# ParseComplete
+# BindComplete
+# DataRow {"fields":["t","t","t","t"]}
+# CommandComplete {"tag":"SELECT 1"}
+# ReadyForQuery {"status":"I"}
+
+
+# Float sanity check.
+
+send
+Parse {"query": "select $1 > 0.1"}
+Bind {"values": ["0.2"]}
+Execute
+Sync
+----
+
+until
+ReadyForQuery
+----
+ParseComplete
+BindComplete
+DataRow {"fields":["t"]}
+CommandComplete {"tag":"SELECT 1"}
+ReadyForQuery {"status":"I"}