diff --git a/src/ast/value.rs b/src/ast/value.rs index 2afdfaeae..583866849 100644 --- a/src/ast/value.rs +++ b/src/ast/value.rs @@ -56,6 +56,8 @@ pub enum Value { }, /// `NULL` value Null, + /// `?` or `$` Prepared statement arg placeholder + Placeholder(String), } impl fmt::Display for Value { @@ -108,6 +110,7 @@ impl fmt::Display for Value { Ok(()) } Value::Null => write!(f, "NULL"), + Value::Placeholder(v) => write!(f, "{}", v), } } } diff --git a/src/parser.rs b/src/parser.rs index 5fc3846dc..9b2cf0d50 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -445,6 +445,10 @@ impl<'a> Parser<'a> { self.expect_token(&Token::RParen)?; Ok(expr) } + Token::Placeholder(_) => { + self.prev_token(); + Ok(Expr::Value(self.parse_value()?)) + } unexpected => self.expected("an expression:", unexpected), }?; @@ -1966,6 +1970,7 @@ impl<'a> Parser<'a> { Token::SingleQuotedString(ref s) => Ok(Value::SingleQuotedString(s.to_string())), Token::NationalStringLiteral(ref s) => Ok(Value::NationalStringLiteral(s.to_string())), Token::HexStringLiteral(ref s) => Ok(Value::HexStringLiteral(s.to_string())), + Token::Placeholder(ref s) => Ok(Value::Placeholder(s.to_string())), unexpected => self.expected("a value", unexpected), } } diff --git a/src/tokenizer.rs b/src/tokenizer.rs index d04e1d8f7..6b5b91f43 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -124,6 +124,8 @@ pub enum Token { PGSquareRoot, /// `||/` , a cube root math operator in PostgreSQL PGCubeRoot, + /// `?` or `$` , a prepared statement arg placeholder + Placeholder(String), } impl fmt::Display for Token { @@ -176,6 +178,7 @@ impl fmt::Display for Token { Token::ShiftRight => f.write_str(">>"), Token::PGSquareRoot => f.write_str("|/"), Token::PGCubeRoot => f.write_str("||/"), + Token::Placeholder(ref s) => write!(f, "{}", s), } } } @@ -304,6 +307,7 @@ impl<'a> Tokenizer<'a> { Token::Word(w) if w.quote_style != None => self.col += w.value.len() as u64 + 2, Token::Number(s, _) => self.col += s.len() as u64, Token::SingleQuotedString(s) => self.col += s.len() as u64, + Token::Placeholder(s) => self.col += s.len() as u64, _ => self.col += 1, } @@ -550,6 +554,15 @@ impl<'a> Tokenizer<'a> { '~' => self.consume_and_return(chars, Token::Tilde), '#' => self.consume_and_return(chars, Token::Sharp), '@' => self.consume_and_return(chars, Token::AtSign), + '?' => self.consume_and_return(chars, Token::Placeholder(String::from("?"))), + '$' => { + chars.next(); + let s = peeking_take_while( + chars, + |ch| matches!(ch, '0'..='9' | 'A'..='Z' | 'a'..='z'), + ); + Ok(Some(Token::Placeholder(String::from("$") + &s))) + } other => self.consume_and_return(chars, Token::Char(other)), }, None => Ok(None), @@ -616,7 +629,7 @@ impl<'a> Tokenizer<'a> { 'r' => s.push('\r'), 't' => s.push('\t'), 'Z' => s.push('\x1a'), - x => s.push(x) + x => s.push(x), } } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 4739f146d..91c97c5c0 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -24,8 +24,12 @@ use test_utils::{all_dialects, expr_from_projection, join, number, only, table, use matches::assert_matches; use sqlparser::ast::*; -use sqlparser::dialect::{keywords::ALL_KEYWORDS, SQLiteDialect}; +use sqlparser::dialect::{ + keywords::ALL_KEYWORDS, AnsiDialect, GenericDialect, MsSqlDialect, PostgreSqlDialect, + SQLiteDialect, SnowflakeDialect, +}; use sqlparser::parser::{Parser, ParserError}; +use sqlparser::test_utils::TestedDialects; #[test] fn parse_insert_values() { @@ -2715,10 +2719,10 @@ fn parse_scalar_subqueries() { assert_matches!( verified_expr(sql), Expr::BinaryOp { - op: BinaryOperator::Plus, .. - //left: box Subquery { .. }, - //right: box Subquery { .. }, - } + op: BinaryOperator::Plus, + .. //left: box Subquery { .. }, + //right: box Subquery { .. }, + } ); } @@ -3557,6 +3561,42 @@ fn parse_rolling_window() { ); } +#[test] +fn test_placeholder() { + let sql = "SELECT * FROM student WHERE id = ?"; + let ast = verified_only_select(sql); + assert_eq!( + ast.selection, + Some(Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("id"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Placeholder("?".into()))) + }) + ); + + let dialects = TestedDialects { + dialects: vec![ + Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiDialect {}), + Box::new(SnowflakeDialect {}), + // Note: `$` is the starting word for the HiveDialect identifier + // Box::new(sqlparser::dialect::HiveDialect {}), + ], + }; + let sql = "SELECT * FROM student WHERE id = $Id1"; + let ast = dialects.verified_only_select(sql); + assert_eq!( + ast.selection, + Some(Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("id"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Placeholder("$Id1".into()))) + }) + ); +} + fn parse_sql_statements(sql: &str) -> Result, ParserError> { all_dialects().parse_sql_statements(sql) }