Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for CALL statements #1063

Merged
merged 2 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,7 @@ pub enum Statement {
file_format: Option<FileFormat>,
source: Box<Query>,
},
Call(Function),
Copy {
/// The source of 'COPY TO', or the target of 'COPY FROM'
source: CopySource,
Expand Down Expand Up @@ -1715,7 +1716,9 @@ pub enum Statement {
///
/// Note: this is a PostgreSQL-specific statement,
/// but may also compatible with other SQL.
Discard { object_type: DiscardObject },
Discard {
object_type: DiscardObject,
},
lovasoa marked this conversation as resolved.
Show resolved Hide resolved
/// SET `[ SESSION | LOCAL ]` ROLE role_name. Examples: [ANSI][1], [Postgresql][2], [MySQL][3], and [Oracle][4].
///
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#set-role-statement
Expand Down Expand Up @@ -1747,7 +1750,10 @@ pub enum Statement {
///
/// Note: this is a PostgreSQL-specific statements
/// `SET TIME ZONE <value>` is an alias for `SET timezone TO <value>` in PostgreSQL
SetTimeZone { local: bool, value: Expr },
SetTimeZone {
local: bool,
value: Expr,
},
/// SET NAMES 'charset_name' [COLLATE 'collation_name']
///
/// Note: this is a MySQL-specific statement.
Expand All @@ -1762,13 +1768,17 @@ pub enum Statement {
/// SHOW FUNCTIONS
///
/// Note: this is a Presto-specific statement.
ShowFunctions { filter: Option<ShowStatementFilter> },
ShowFunctions {
filter: Option<ShowStatementFilter>,
},
/// ```sql
/// SHOW <variable>
/// ```
///
/// Note: this is a PostgreSQL-specific statement.
ShowVariable { variable: Vec<Ident> },
ShowVariable {
variable: Vec<Ident>,
},
/// SHOW VARIABLES
///
/// Note: this is a MySQL-specific statement.
Expand Down Expand Up @@ -1806,11 +1816,15 @@ pub enum Statement {
/// SHOW COLLATION
///
/// Note: this is a MySQL-specific statement.
ShowCollation { filter: Option<ShowStatementFilter> },
ShowCollation {
filter: Option<ShowStatementFilter>,
},
/// USE
///
/// Note: This is a MySQL-specific statement.
Use { db_name: Ident },
Use {
db_name: Ident,
},
/// `START [ TRANSACTION | WORK ] | START TRANSACTION } ...`
/// If `begin` is false.
///
Expand Down Expand Up @@ -1838,7 +1852,9 @@ pub enum Statement {
if_exists: bool,
},
/// `COMMIT [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]`
Commit { chain: bool },
Commit {
chain: bool,
},
/// `ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] [ TO [ SAVEPOINT ] savepoint_name ]`
Rollback {
chain: bool,
Expand Down Expand Up @@ -1934,11 +1950,17 @@ pub enum Statement {
/// `DEALLOCATE [ PREPARE ] { name | ALL }`
///
/// Note: this is a PostgreSQL-specific statement.
Deallocate { name: Ident, prepare: bool },
Deallocate {
name: Ident,
prepare: bool,
},
/// `EXECUTE name [ ( parameter [, ...] ) ]`
///
/// Note: this is a PostgreSQL-specific statement.
Execute { name: Ident, parameters: Vec<Expr> },
Execute {
name: Ident,
parameters: Vec<Expr>,
},
/// `PREPARE name [ ( data_type [, ...] ) ] AS statement`
///
/// Note: this is a PostgreSQL-specific statement.
Expand Down Expand Up @@ -1979,9 +2001,13 @@ pub enum Statement {
format: Option<AnalyzeFormat>,
},
/// SAVEPOINT -- define a new savepoint within the current transaction
Savepoint { name: Ident },
Savepoint {
name: Ident,
},
/// RELEASE \[ SAVEPOINT \] savepoint_name
ReleaseSavepoint { name: Ident },
ReleaseSavepoint {
name: Ident,
},
// MERGE INTO statement, based on Snowflake. See <https://docs.snowflake.com/en/sql-reference/sql/merge.html>
Merge {
// optional INTO keyword
Expand Down Expand Up @@ -2303,6 +2329,8 @@ impl fmt::Display for Statement {
Ok(())
}

Statement::Call(function) => write!(f, "CALL {function}"),

Statement::Copy {
source,
to,
Expand Down
27 changes: 27 additions & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ impl<'a> Parser<'a> {
Keyword::UNCACHE => Ok(self.parse_uncache_table()?),
Keyword::UPDATE => Ok(self.parse_update()?),
Keyword::ALTER => Ok(self.parse_alter()?),
Keyword::CALL => Ok(self.parse_call()?),
Keyword::COPY => Ok(self.parse_copy()?),
Keyword::CLOSE => Ok(self.parse_close()?),
Keyword::SET => Ok(self.parse_set()?),
Expand Down Expand Up @@ -4773,6 +4774,32 @@ impl<'a> Parser<'a> {
})
}

/// Parse a `CALL procedure_name(arg1, arg2, ...)`
/// or `CALL procedure_name` statement
pub fn parse_call(&mut self) -> Result<Statement, ParserError> {
let object_name = self.parse_object_name()?;
if self.peek_token().token == Token::LParen {
match self.parse_function(object_name)? {
Expr::Function(f) => Ok(Statement::Call(f)),
other => parser_err!(
format!("Expected a simple procedure call but found: {other}"),
self.peek_token().location
),
}
} else {
Ok(Statement::Call(Function {
name: object_name,
args: vec![],
over: None,
distinct: false,
filter: None,
null_treatment: None,
special: true,
order_by: vec![],
}))
}
}

/// Parse a copy statement
pub fn parse_copy(&mut self) -> Result<Statement, ParserError> {
let source;
Expand Down
8 changes: 8 additions & 0 deletions tests/sqlparser_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,14 @@ fn parse_hex_string_introducer() {
)
}

#[test]
fn parse_call() {
mysql().verified_stmt("CALL my_procedure()");
mysql().verified_stmt("CALL my_procedure(1, 'a')");
mysql().verified_stmt("CALL my_procedure(1, 'a', @my_var)");
mysql().verified_stmt("CALL my_procedure");
}

lovasoa marked this conversation as resolved.
Show resolved Hide resolved
#[test]
fn parse_string_introducers() {
mysql().verified_stmt("SELECT _binary 'abc'");
Expand Down
Loading