From 535ddf99f51c4643714f87ddb91a96b3c10d74f8 Mon Sep 17 00:00:00 2001
From: gamife <gamife9886@gmail.com>
Date: Fri, 11 Feb 2022 10:36:37 +0800
Subject: [PATCH] feat: add arg placeholder

---
 src/ast/value.rs          |  3 +++
 src/parser.rs             |  5 +++++
 src/tokenizer.rs          | 13 +++++++++++++
 tests/sqlparser_common.rs | 40 ++++++++++++++++++++++++++++++++++++++-
 4 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/src/ast/value.rs b/src/ast/value.rs
index 0742fbd05..a4df8e8fe 100644
--- a/src/ast/value.rs
+++ b/src/ast/value.rs
@@ -59,6 +59,8 @@ pub enum Value {
     },
     /// `NULL` value
     Null,
+    /// `?` or `$` Prepared statement arg placeholder
+    Placeholder(String),
 }
 
 impl fmt::Display for Value {
@@ -111,6 +113,7 @@ impl fmt::Display for Value {
                 Ok(())
             }
             Value::Null => write!(f, "NULL"),
+            Value::Placeholder(v) => write!(f, "{}", v),
         }
     }
 }
diff --git a/src/parser.rs b/src/parser.rs
index 17051327d..37775e4e9 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -504,6 +504,10 @@ impl<'a> Parser<'a> {
                 self.expect_token(&Token::RParen)?;
                 Ok(expr)
             }
+            Token::Placeholder(_) => {
+                self.prev_token();
+                Ok(Expr::Value(self.parse_value()?))
+            }
             unexpected => self.expected("an expression:", unexpected),
         }?;
 
@@ -2234,6 +2238,7 @@ impl<'a> Parser<'a> {
             Token::SingleQuotedString(ref s) => Ok(Value::SingleQuotedString(s.to_string())),
             Token::NationalStringLiteral(ref s) => Ok(Value::NationalStringLiteral(s.to_string())),
             Token::HexStringLiteral(ref s) => Ok(Value::HexStringLiteral(s.to_string())),
+            Token::Placeholder(ref s) => Ok(Value::Placeholder(s.to_string())),
             unexpected => self.expected("a value", unexpected),
         }
     }
diff --git a/src/tokenizer.rs b/src/tokenizer.rs
index 05b530216..4fb962555 100644
--- a/src/tokenizer.rs
+++ b/src/tokenizer.rs
@@ -139,6 +139,8 @@ pub enum Token {
     PGSquareRoot,
     /// `||/` , a cube root math operator in PostgreSQL
     PGCubeRoot,
+    /// `?` or `$` , a prepared statement arg placeholder
+    Placeholder(String),
 }
 
 impl fmt::Display for Token {
@@ -194,6 +196,7 @@ impl fmt::Display for Token {
             Token::ShiftRight => f.write_str(">>"),
             Token::PGSquareRoot => f.write_str("|/"),
             Token::PGCubeRoot => f.write_str("||/"),
+            Token::Placeholder(ref s) => write!(f, "{}", s),
         }
     }
 }
@@ -337,6 +340,7 @@ impl<'a> Tokenizer<'a> {
                 Token::Word(w) if w.quote_style != None => self.col += w.value.len() as u64 + 2,
                 Token::Number(s, _) => self.col += s.len() as u64,
                 Token::SingleQuotedString(s) => self.col += s.len() as u64,
+                Token::Placeholder(s) => self.col += s.len() as u64,
                 _ => self.col += 1,
             }
 
@@ -598,6 +602,15 @@ impl<'a> Tokenizer<'a> {
                 }
                 '#' => self.consume_and_return(chars, Token::Sharp),
                 '@' => self.consume_and_return(chars, Token::AtSign),
+                '?' => self.consume_and_return(chars, Token::Placeholder(String::from("?"))),
+                '$' => {
+                    chars.next();
+                    let s = peeking_take_while(
+                        chars,
+                        |ch| matches!(ch, '0'..='9' | 'A'..='Z' | 'a'..='z'),
+                    );
+                    Ok(Some(Token::Placeholder(String::from("$") + &s)))
+                }
                 other => self.consume_and_return(chars, Token::Char(other)),
             },
             None => Ok(None),
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index a6ba19509..e61024de3 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -22,7 +22,9 @@
 mod test_utils;
 use matches::assert_matches;
 use sqlparser::ast::*;
-use sqlparser::dialect::{GenericDialect, PostgreSqlDialect, SQLiteDialect};
+use sqlparser::dialect::{
+    AnsiDialect, GenericDialect, MsSqlDialect, PostgreSqlDialect, SQLiteDialect, SnowflakeDialect,
+};
 use sqlparser::keywords::ALL_KEYWORDS;
 use sqlparser::parser::{Parser, ParserError};
 use test_utils::{
@@ -4160,6 +4162,42 @@ fn test_revoke() {
     }
 }
 
+#[test]
+fn test_placeholder() {
+    let sql = "SELECT * FROM student WHERE id = ?";
+    let ast = verified_only_select(sql);
+    assert_eq!(
+        ast.selection,
+        Some(Expr::BinaryOp {
+            left: Box::new(Expr::Identifier(Ident::new("id"))),
+            op: BinaryOperator::Eq,
+            right: Box::new(Expr::Value(Value::Placeholder("?".into())))
+        })
+    );
+
+    let dialects = TestedDialects {
+        dialects: vec![
+            Box::new(GenericDialect {}),
+            Box::new(PostgreSqlDialect {}),
+            Box::new(MsSqlDialect {}),
+            Box::new(AnsiDialect {}),
+            Box::new(SnowflakeDialect {}),
+            // Note: `$` is the starting word for the HiveDialect identifier
+            // Box::new(sqlparser::dialect::HiveDialect {}),
+        ],
+    };
+    let sql = "SELECT * FROM student WHERE id = $Id1";
+    let ast = dialects.verified_only_select(sql);
+    assert_eq!(
+        ast.selection,
+        Some(Expr::BinaryOp {
+            left: Box::new(Expr::Identifier(Ident::new("id"))),
+            op: BinaryOperator::Eq,
+            right: Box::new(Expr::Value(Value::Placeholder("$Id1".into())))
+        })
+    );
+}
+
 #[test]
 fn all_keywords_sorted() {
     // assert!(ALL_KEYWORDS.is_sorted())