diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 7b89ab8f0e2cb..c85afc9460f12 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -28,7 +28,7 @@ from sqlalchemy import and_ from sqlglot import exp, parse, parse_one from sqlglot.dialects import Dialects -from sqlglot.errors import ParseError +from sqlglot.errors import SqlglotError from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope from sqlparse import keywords from sqlparse.lexer import Lexer @@ -287,7 +287,7 @@ def _extract_tables_from_sql(self) -> set[Table]: """ try: statements = parse(self.stripped(), dialect=self._dialect) - except ParseError: + except SqlglotError: logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql) return set() @@ -319,12 +319,17 @@ def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table elif isinstance(statement, exp.Command): # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a # `SELECT` statetement in order to extract tables. - literal = statement.find(exp.Literal) - if not literal: + if not (literal := statement.find(exp.Literal)): return set() - pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect) - sources = pseudo_query.find_all(exp.Table) + try: + pseudo_query = parse_one( + f"SELECT {literal.this}", + dialect=self._dialect, + ) + sources = pseudo_query.find_all(exp.Table) + except SqlglotError: + return set() else: sources = [ source diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index f05e16ae85fd0..2fd23f7e8e4f2 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -271,6 +271,7 @@ def test_extract_tables_illdefined() -> None: assert extract_tables("SELECT * FROM catalogname..tbname") == { Table(table="tbname", schema=None, catalog="catalogname") } + assert extract_tables('SELECT * FROM "tbname') == set() def test_extract_tables_show_tables_from() -> None: @@ -558,6 +559,10 @@ def test_extract_tables_multistatement() -> None: Table("t1"), Table("t2"), } + assert extract_tables( + "ADD JAR file:///hive.jar; SELECT * FROM t1;", + engine="hive", + ) == {Table("t1")} def test_extract_tables_complex() -> None: @@ -1815,10 +1820,7 @@ def test_extract_table_references(mocker: MockerFixture) -> None: # test falling back to sqlparse logger = mocker.patch("superset.sql_parse.logger") sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table" - assert extract_table_references( - sql, - "trino", - ) == { + assert extract_table_references(sql, "trino") == { Table(table="table", schema=None, catalog=None), Table(table="other_table", schema=None, catalog=None), }