From 535ddf99f51c4643714f87ddb91a96b3c10d74f8 Mon Sep 17 00:00:00 2001 From: gamife <gamife9886@gmail.com> Date: Fri, 11 Feb 2022 10:36:37 +0800 Subject: [PATCH] feat: add arg placeholder --- src/ast/value.rs | 3 +++ src/parser.rs | 5 +++++ src/tokenizer.rs | 13 +++++++++++++ tests/sqlparser_common.rs | 40 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 60 insertions(+), 1 deletion(-) diff --git a/src/ast/value.rs b/src/ast/value.rs index 0742fbd05..a4df8e8fe 100644 --- a/src/ast/value.rs +++ b/src/ast/value.rs @@ -59,6 +59,8 @@ pub enum Value { }, /// `NULL` value Null, + /// `?` or `$` Prepared statement arg placeholder + Placeholder(String), } impl fmt::Display for Value { @@ -111,6 +113,7 @@ impl fmt::Display for Value { Ok(()) } Value::Null => write!(f, "NULL"), + Value::Placeholder(v) => write!(f, "{}", v), } } } diff --git a/src/parser.rs b/src/parser.rs index 17051327d..37775e4e9 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -504,6 +504,10 @@ impl<'a> Parser<'a> { self.expect_token(&Token::RParen)?; Ok(expr) } + Token::Placeholder(_) => { + self.prev_token(); + Ok(Expr::Value(self.parse_value()?)) + } unexpected => self.expected("an expression:", unexpected), }?; @@ -2234,6 +2238,7 @@ impl<'a> Parser<'a> { Token::SingleQuotedString(ref s) => Ok(Value::SingleQuotedString(s.to_string())), Token::NationalStringLiteral(ref s) => Ok(Value::NationalStringLiteral(s.to_string())), Token::HexStringLiteral(ref s) => Ok(Value::HexStringLiteral(s.to_string())), + Token::Placeholder(ref s) => Ok(Value::Placeholder(s.to_string())), unexpected => self.expected("a value", unexpected), } } diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 05b530216..4fb962555 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -139,6 +139,8 @@ pub enum Token { PGSquareRoot, /// `||/` , a cube root math operator in PostgreSQL PGCubeRoot, + /// `?` or `$` , a prepared statement arg placeholder + Placeholder(String), } impl fmt::Display for Token { @@ -194,6 +196,7 @@ impl fmt::Display for Token { Token::ShiftRight => f.write_str(">>"), Token::PGSquareRoot => f.write_str("|/"), Token::PGCubeRoot => f.write_str("||/"), + Token::Placeholder(ref s) => write!(f, "{}", s), } } } @@ -337,6 +340,7 @@ impl<'a> Tokenizer<'a> { Token::Word(w) if w.quote_style != None => self.col += w.value.len() as u64 + 2, Token::Number(s, _) => self.col += s.len() as u64, Token::SingleQuotedString(s) => self.col += s.len() as u64, + Token::Placeholder(s) => self.col += s.len() as u64, _ => self.col += 1, } @@ -598,6 +602,15 @@ impl<'a> Tokenizer<'a> { } '#' => self.consume_and_return(chars, Token::Sharp), '@' => self.consume_and_return(chars, Token::AtSign), + '?' => self.consume_and_return(chars, Token::Placeholder(String::from("?"))), + '$' => { + chars.next(); + let s = peeking_take_while( + chars, + |ch| matches!(ch, '0'..='9' | 'A'..='Z' | 'a'..='z'), + ); + Ok(Some(Token::Placeholder(String::from("$") + &s))) + } other => self.consume_and_return(chars, Token::Char(other)), }, None => Ok(None), diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index a6ba19509..e61024de3 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -22,7 +22,9 @@ mod test_utils; use matches::assert_matches; use sqlparser::ast::*; -use sqlparser::dialect::{GenericDialect, PostgreSqlDialect, SQLiteDialect}; +use sqlparser::dialect::{ + AnsiDialect, GenericDialect, MsSqlDialect, PostgreSqlDialect, SQLiteDialect, SnowflakeDialect, +}; use sqlparser::keywords::ALL_KEYWORDS; use sqlparser::parser::{Parser, ParserError}; use test_utils::{ @@ -4160,6 +4162,42 @@ fn test_revoke() { } } +#[test] +fn test_placeholder() { + let sql = "SELECT * FROM student WHERE id = ?"; + let ast = verified_only_select(sql); + assert_eq!( + ast.selection, + Some(Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("id"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Placeholder("?".into()))) + }) + ); + + let dialects = TestedDialects { + dialects: vec![ + Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiDialect {}), + Box::new(SnowflakeDialect {}), + // Note: `$` is the starting word for the HiveDialect identifier + // Box::new(sqlparser::dialect::HiveDialect {}), + ], + }; + let sql = "SELECT * FROM student WHERE id = $Id1"; + let ast = dialects.verified_only_select(sql); + assert_eq!( + ast.selection, + Some(Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("id"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Placeholder("$Id1".into()))) + }) + ); +} + #[test] fn all_keywords_sorted() { // assert!(ALL_KEYWORDS.is_sorted())