From 73947a5f021128cfccd47293ca65aa5c4e83f598 Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Wed, 20 Nov 2024 05:14:28 +0800 Subject: [PATCH] Add support for PostgreSQL `UNLISTEN` syntax and Add support for Postgres `LOAD extension` expr (#1531) Co-authored-by: Ifeanyi Ubah --- src/ast/mod.rs | 11 ++++++ src/dialect/mod.rs | 9 +---- src/dialect/postgresql.rs | 12 +++--- src/keywords.rs | 1 + src/parser/mod.rs | 22 +++++++++-- tests/sqlparser_common.rs | 77 +++++++++++++++++++++++++++++++++++++-- tests/sqlparser_duckdb.rs | 14 ------- 7 files changed, 113 insertions(+), 33 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 89e70bdd4..9185c9df4 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -3340,6 +3340,13 @@ pub enum Statement { /// See Postgres LISTEN { channel: Ident }, /// ```sql + /// UNLISTEN + /// ``` + /// stop listening for a notification + /// + /// See Postgres + UNLISTEN { channel: Ident }, + /// ```sql /// NOTIFY channel [ , payload ] /// ``` /// send a notification event together with an optional “payload” string to channel @@ -4948,6 +4955,10 @@ impl fmt::Display for Statement { write!(f, "LISTEN {channel}")?; Ok(()) } + Statement::UNLISTEN { channel } => { + write!(f, "UNLISTEN {channel}")?; + Ok(()) + } Statement::NOTIFY { channel, payload } => { write!(f, "NOTIFY {channel}")?; if let Some(payload) = payload { diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 39ea98c69..985cad749 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -633,13 +633,8 @@ pub trait Dialect: Debug + Any { false } - /// Returns true if the dialect supports the `LISTEN` statement - fn supports_listen(&self) -> bool { - false - } - - /// Returns true if the dialect supports the `NOTIFY` statement - fn supports_notify(&self) -> bool { + /// Returns true if the dialect supports the `LISTEN`, `UNLISTEN` and `NOTIFY` statements + fn supports_listen_notify(&self) -> bool { false } diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index 5af1ab853..559586e3f 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -191,12 +191,9 @@ impl Dialect for PostgreSqlDialect { } /// see - fn supports_listen(&self) -> bool { - true - } - + /// see /// see - fn supports_notify(&self) -> bool { + fn supports_listen_notify(&self) -> bool { true } @@ -209,6 +206,11 @@ impl Dialect for PostgreSqlDialect { fn supports_comment_on(&self) -> bool { true } + + /// See + fn supports_load_extension(&self) -> bool { + true + } } pub fn parse_create(parser: &mut Parser) -> Option> { diff --git a/src/keywords.rs b/src/keywords.rs index 29115a0d2..fc2a2927c 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -799,6 +799,7 @@ define_keywords!( UNION, UNIQUE, UNKNOWN, + UNLISTEN, UNLOAD, UNLOCK, UNLOGGED, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 35ad95803..35c763e93 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -532,10 +532,11 @@ impl<'a> Parser<'a> { Keyword::EXECUTE | Keyword::EXEC => self.parse_execute(), Keyword::PREPARE => self.parse_prepare(), Keyword::MERGE => self.parse_merge(), - // `LISTEN` and `NOTIFY` are Postgres-specific + // `LISTEN`, `UNLISTEN` and `NOTIFY` are Postgres-specific // syntaxes. They are used for Postgres statement. - Keyword::LISTEN if self.dialect.supports_listen() => self.parse_listen(), - Keyword::NOTIFY if self.dialect.supports_notify() => self.parse_notify(), + Keyword::LISTEN if self.dialect.supports_listen_notify() => self.parse_listen(), + Keyword::UNLISTEN if self.dialect.supports_listen_notify() => self.parse_unlisten(), + Keyword::NOTIFY if self.dialect.supports_listen_notify() => self.parse_notify(), // `PRAGMA` is sqlite specific https://www.sqlite.org/pragma.html Keyword::PRAGMA => self.parse_pragma(), Keyword::UNLOAD => self.parse_unload(), @@ -999,6 +1000,21 @@ impl<'a> Parser<'a> { Ok(Statement::LISTEN { channel }) } + pub fn parse_unlisten(&mut self) -> Result { + let channel = if self.consume_token(&Token::Mul) { + Ident::new(Expr::Wildcard.to_string()) + } else { + match self.parse_identifier(false) { + Ok(expr) => expr, + _ => { + self.prev_token(); + return self.expected("wildcard or identifier", self.peek_token()); + } + } + }; + Ok(Statement::UNLISTEN { channel }) + } + pub fn parse_notify(&mut self) -> Result { let channel = self.parse_identifier(false)?; let payload = if self.consume_token(&Token::Comma) { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index ecdca6b1b..3d9ba5da2 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11595,7 +11595,7 @@ fn test_show_dbs_schemas_tables_views() { #[test] fn parse_listen_channel() { - let dialects = all_dialects_where(|d| d.supports_listen()); + let dialects = all_dialects_where(|d| d.supports_listen_notify()); match dialects.verified_stmt("LISTEN test1") { Statement::LISTEN { channel } => { @@ -11609,7 +11609,7 @@ fn parse_listen_channel() { ParserError::ParserError("Expected: identifier, found: *".to_string()) ); - let dialects = all_dialects_where(|d| !d.supports_listen()); + let dialects = all_dialects_where(|d| !d.supports_listen_notify()); assert_eq!( dialects.parse_sql_statements("LISTEN test1").unwrap_err(), @@ -11617,9 +11617,40 @@ fn parse_listen_channel() { ); } +#[test] +fn parse_unlisten_channel() { + let dialects = all_dialects_where(|d| d.supports_listen_notify()); + + match dialects.verified_stmt("UNLISTEN test1") { + Statement::UNLISTEN { channel } => { + assert_eq!(Ident::new("test1"), channel); + } + _ => unreachable!(), + }; + + match dialects.verified_stmt("UNLISTEN *") { + Statement::UNLISTEN { channel } => { + assert_eq!(Ident::new("*"), channel); + } + _ => unreachable!(), + }; + + assert_eq!( + dialects.parse_sql_statements("UNLISTEN +").unwrap_err(), + ParserError::ParserError("Expected: wildcard or identifier, found: +".to_string()) + ); + + let dialects = all_dialects_where(|d| !d.supports_listen_notify()); + + assert_eq!( + dialects.parse_sql_statements("UNLISTEN test1").unwrap_err(), + ParserError::ParserError("Expected: an SQL statement, found: UNLISTEN".to_string()) + ); +} + #[test] fn parse_notify_channel() { - let dialects = all_dialects_where(|d| d.supports_notify()); + let dialects = all_dialects_where(|d| d.supports_listen_notify()); match dialects.verified_stmt("NOTIFY test1") { Statement::NOTIFY { channel, payload } => { @@ -11655,7 +11686,7 @@ fn parse_notify_channel() { "NOTIFY test1", "NOTIFY test1, 'this is a test notification'", ]; - let dialects = all_dialects_where(|d| !d.supports_notify()); + let dialects = all_dialects_where(|d| !d.supports_listen_notify()); for &sql in &sql_statements { assert_eq!( @@ -11864,6 +11895,44 @@ fn parse_load_data() { ); } +#[test] +fn test_load_extension() { + let dialects = all_dialects_where(|d| d.supports_load_extension()); + let not_supports_load_extension_dialects = all_dialects_where(|d| !d.supports_load_extension()); + let sql = "LOAD my_extension"; + + match dialects.verified_stmt(sql) { + Statement::Load { extension_name } => { + assert_eq!(Ident::new("my_extension"), extension_name); + } + _ => unreachable!(), + }; + + assert_eq!( + not_supports_load_extension_dialects + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError( + "Expected: `DATA` or an extension name after `LOAD`, found: my_extension".to_string() + ) + ); + + let sql = "LOAD 'filename'"; + + match dialects.verified_stmt(sql) { + Statement::Load { extension_name } => { + assert_eq!( + Ident { + value: "filename".to_string(), + quote_style: Some('\'') + }, + extension_name + ); + } + _ => unreachable!(), + }; +} + #[test] fn test_select_top() { let dialects = all_dialects_where(|d| d.supports_top_before_distinct()); diff --git a/tests/sqlparser_duckdb.rs b/tests/sqlparser_duckdb.rs index d68f37713..a2db5c282 100644 --- a/tests/sqlparser_duckdb.rs +++ b/tests/sqlparser_duckdb.rs @@ -359,20 +359,6 @@ fn test_duckdb_install() { ); } -#[test] -fn test_duckdb_load_extension() { - let stmt = duckdb().verified_stmt("LOAD my_extension"); - assert_eq!( - Statement::Load { - extension_name: Ident { - value: "my_extension".to_string(), - quote_style: None - } - }, - stmt - ); -} - #[test] fn test_duckdb_struct_literal() { //struct literal syntax https://duckdb.org/docs/sql/data_types/struct#creating-structs