From 51a2627b7caee10565cb40a154bab19d2ed63e9c Mon Sep 17 00:00:00 2001 From: Eyal Leshem Date: Sun, 16 Aug 2020 19:12:40 +0300 Subject: [PATCH] Add support for sql variable inside query in snowflake and mysql dialect see : https://docs.snowflake.com/en/sql-reference/session-variables.html https://dev.mysql.com/doc/refman/8.0/en/user-variables.html --- src/ast/mod.rs | 5 +++++ src/parser.rs | 10 +++++++++ src/tokenizer.rs | 8 ++++++++ tests/sqlparser_mysql.rs | 33 ++++++++++++++++++++++++++++++ tests/sqlparser_snowflake.rs | 39 ++++++++++++++++++++++++++++++++++++ 5 files changed, 95 insertions(+) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 7b143349b..e3404637c 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -235,6 +235,10 @@ pub enum Expr { Subquery(Box), /// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)` ListAgg(ListAgg), + /// Embed variable inside a query is supported by some databases: + /// - Mysql: https://dev.mysql.com/doc/refman/8.0/en/user-variables.html + /// - Snowflake: https://docs.snowflake.com/en/sql-reference/session-variables.html + SqlVariable { prefix: char, name: Ident }, } impl fmt::Display for Expr { @@ -315,6 +319,7 @@ impl fmt::Display for Expr { Expr::Exists(s) => write!(f, "EXISTS ({})", s), Expr::Subquery(s) => write!(f, "({})", s), Expr::ListAgg(listagg) => write!(f, "{}", listagg), + Expr::SqlVariable { prefix, name } => write!(f, "{}{}", prefix, name), } } } diff --git a/src/parser.rs b/src/parser.rs index 431984a19..c40a7e0aa 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -312,6 +312,16 @@ impl<'a> Parser<'a> { self.expect_token(&Token::RParen)?; Ok(expr) } + Token::Dollar if dialect_of!(self is SnowflakeDialect) => { + // Snowflake user defined variables starts with $ + let name = self.parse_identifier()?; + Ok(Expr::SqlVariable { prefix: '$', name }) + } + Token::At if dialect_of!(self is MySqlDialect) => { + // Mysql user defined variables starts with @ + let name = self.parse_identifier()?; + Ok(Expr::SqlVariable { prefix: '@', name }) + } unexpected => self.expected("an expression", unexpected), }?; diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 644066989..b47d6b5b6 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -101,6 +101,10 @@ pub enum Token { RBrace, /// Right Arrow `=>` RArrow, + /// Dollar sign `$` + Dollar, + /// At sign `@` + At, } impl fmt::Display for Token { @@ -142,6 +146,8 @@ impl fmt::Display for Token { Token::LBrace => f.write_str("{"), Token::RBrace => f.write_str("}"), Token::RArrow => f.write_str("=>"), + Token::Dollar => f.write_str("$"), + Token::At => f.write_str("@"), } } } @@ -448,6 +454,8 @@ impl<'a> Tokenizer<'a> { '^' => self.consume_and_return(chars, Token::Caret), '{' => self.consume_and_return(chars, Token::LBrace), '}' => self.consume_and_return(chars, Token::RBrace), + '$' => self.consume_and_return(chars, Token::Dollar), + '@' => self.consume_and_return(chars, Token::At), other => self.consume_and_return(chars, Token::Char(other)), }, None => Ok(None), diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index c0fc8c8ba..7e7ab6d09 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -152,6 +152,39 @@ fn parse_quote_identifiers() { } } +#[test] +fn test_query_with_variable_name() { + let sql = "SELECT @var1"; + let select = mysql().verified_only_select(sql); + + assert_eq!( + only(select.projection), + SelectItem::UnnamedExpr(Expr::SqlVariable { + prefix: '@', + name: Ident::new("var1") + },) + ); + + let sql = "SELECT c1 FROM t1 WHERE num BETWEEN @min AND @max"; + let select = mysql().verified_only_select(sql); + + assert_eq!( + select.selection.unwrap(), + Expr::Between { + expr: Box::new(Expr::Identifier("num".into())), + low: Box::new(Expr::SqlVariable { + prefix: '@', + name: Ident::new("min") + }), + high: Box::new(Expr::SqlVariable { + prefix: '@', + name: Ident::new("max") + }), + negated: false, + } + ); +} + fn mysql() -> TestedDialects { TestedDialects { dialects: vec![Box::new(MySqlDialect {})], diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index 086d57264..f06342740 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -24,6 +24,45 @@ fn test_snowflake_create_table() { } } +#[test] +fn test_query_with_variable_name() { + let sql = "SELECT $var1"; + let select = snowflake().verified_only_select(sql); + + assert_eq!( + only(select.projection), + SelectItem::UnnamedExpr(Expr::SqlVariable { + prefix: '$', + name: Ident::new("var1") + },) + ); + + let sql = "SELECT c1 FROM t1 WHERE num BETWEEN $min AND $max"; + let select = snowflake().verified_only_select(sql); + + assert_eq!( + select.selection.unwrap(), + Expr::Between { + expr: Box::new(Expr::Identifier("num".into())), + low: Box::new(Expr::SqlVariable { + prefix: '$', + name: Ident::new("min") + }), + high: Box::new(Expr::SqlVariable { + prefix: '$', + name: Ident::new("max") + }), + negated: false, + } + ); +} + +fn snowflake() -> TestedDialects { + TestedDialects { + dialects: vec![Box::new(SnowflakeDialect {})], + } +} + fn snowflake_and_generic() -> TestedDialects { TestedDialects { dialects: vec![Box::new(SnowflakeDialect {}), Box::new(GenericDialect {})],