diff --git a/requirements/base.txt b/requirements/base.txt index 98d2a8094eee3..b198a5c9cee6d 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -141,7 +141,9 @@ geographiclib==1.52 geopy==2.2.0 # via apache-superset greenlet==2.0.2 - # via shillelagh + # via + # shillelagh + # sqlalchemy gunicorn==21.2.0 # via apache-superset hashids==1.3.1 @@ -155,7 +157,10 @@ idna==3.2 # email-validator # requests importlib-metadata==6.6.0 - # via apache-superset + # via + # apache-superset + # flask + # shillelagh importlib-resources==5.12.0 # via limits isodate==0.6.0 @@ -327,6 +332,8 @@ sqlalchemy-utils==0.38.3 # via # apache-superset # flask-appbuilder +sqlglot==20.8.0 + # via apache-superset sqlparse==0.4.4 # via apache-superset sshtunnel==0.4.0 @@ -376,7 +383,9 @@ wtforms-json==0.3.5 xlsxwriter==3.0.7 # via apache-superset zipp==3.15.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/requirements/testing.txt b/requirements/testing.txt index 3bf3c78d03728..b40497c8fc130 100644 --- a/requirements/testing.txt +++ b/requirements/testing.txt @@ -24,10 +24,6 @@ db-dtypes==1.1.1 # via pandas-gbq docker==6.1.1 # via -r requirements/testing.in -exceptiongroup==1.1.1 - # via pytest -ephem==4.1.4 - # via lunarcalendar flask-testing==0.8.1 # via -r requirements/testing.in fonttools==4.39.4 @@ -121,6 +117,8 @@ pyee==9.0.4 # via playwright pyfakefs==5.2.2 # via -r requirements/testing.in +pyhive[presto]==0.7.0 + # via apache-superset pytest==7.3.1 # via # -r requirements/testing.in diff --git a/setup.py b/setup.py index fd2a9f8c80cd6..cb02a7f49095a 100644 --- a/setup.py +++ b/setup.py @@ -125,6 +125,7 @@ def get_git_sha() -> str: "slack_sdk>=3.19.0, <4", "sqlalchemy>=1.4, <2", "sqlalchemy-utils>=0.38.3, <0.39", + "sqlglot>=20,<21", "sqlparse>=0.4.4, <0.5", "tabulate>=0.8.9, <0.9", "typing-extensions>=4, <5", diff --git 
a/superset/commands/dataset/duplicate.py b/superset/commands/dataset/duplicate.py index 0ae47c35bca4d..850290422e1c5 100644 --- a/superset/commands/dataset/duplicate.py +++ b/superset/commands/dataset/duplicate.py @@ -70,7 +70,10 @@ def run(self) -> Model: table.normalize_columns = self._base_model.normalize_columns table.always_filter_main_dttm = self._base_model.always_filter_main_dttm table.is_sqllab_view = True - table.sql = ParsedQuery(self._base_model.sql).stripped() + table.sql = ParsedQuery( + self._base_model.sql, + engine=database.db_engine_spec.engine, + ).stripped() db.session.add(table) cols = [] for config_ in self._base_model.columns: diff --git a/superset/commands/sql_lab/export.py b/superset/commands/sql_lab/export.py index 1b9b0e03442fa..aa6050f27f9ae 100644 --- a/superset/commands/sql_lab/export.py +++ b/superset/commands/sql_lab/export.py @@ -115,7 +115,10 @@ def run( limit = None else: sql = self._query.executed_sql - limit = ParsedQuery(sql).limit + limit = ParsedQuery( + sql, + engine=self._query.database.db_engine_spec.engine, + ).limit if limit is not None and self._query.limiting_factor in { LimitingFactor.QUERY, LimitingFactor.DROPDOWN, diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 624eb2ce5a530..08dc923c21b27 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1457,7 +1457,7 @@ def get_from_clause( return self.get_sqla_table(), None from_sql = self.get_rendered_sql(template_processor) - parsed_query = ParsedQuery(from_sql) + parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine) if not ( parsed_query.is_unknown() or self.db_engine_spec.is_readonly_query(parsed_query) diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 66594084c82d5..688be53515040 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: 
SqlaTable) -> list[ResultSetColumnType]: sql = dataset.get_template_processor().process_template( dataset.sql, **dataset.template_params_dict ) - parsed_query = ParsedQuery(sql) + parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine) if not db_engine_spec.is_readonly_query(parsed_query): raise SupersetSecurityException( SupersetError( diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 48e44064acfdf..3b8bb2bd33292 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -899,7 +899,7 @@ def apply_limit_to_sql( return database.compile_sqla_query(qry) if cls.limit_method == LimitMethod.FORCE_LIMIT: - parsed_query = sql_parse.ParsedQuery(sql) + parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) sql = parsed_query.set_or_update_query_limit(limit, force=force) return sql @@ -980,7 +980,7 @@ def get_limit_from_sql(cls, sql: str) -> int | None: :param sql: SQL query :return: Value of limit clause in query """ - parsed_query = sql_parse.ParsedQuery(sql) + parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) return parsed_query.limit @classmethod @@ -992,7 +992,7 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str: :param limit: New limit to insert/replace into query :return: Query with new limit """ - parsed_query = sql_parse.ParsedQuery(sql) + parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) return parsed_query.set_or_update_query_limit(limit) @classmethod @@ -1487,7 +1487,7 @@ def process_statement(cls, statement: str, database: Database) -> str: :param database: Database instance :return: Dictionary with different costs """ - parsed_query = ParsedQuery(statement) + parsed_query = ParsedQuery(statement, engine=cls.engine) sql = parsed_query.stripped() sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"] mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"] @@ -1522,7 +1522,7 @@ def estimate_query_cost( "Database does not support cost 
estimation" ) - parsed_query = sql_parse.ParsedQuery(sql) + parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) statements = parsed_query.get_statements() costs = [] @@ -1583,7 +1583,7 @@ def execute( # pylint: disable=unused-argument :return: """ if not cls.allows_sql_comments: - query = sql_parse.strip_comments_from_sql(query) + query = sql_parse.strip_comments_from_sql(query, engine=cls.engine) if cls.arraysize: cursor.arraysize = cls.arraysize diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 8e7ed0bf7d061..a8d834276e60c 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -435,7 +435,7 @@ def estimate_query_cost( if not cls.get_allow_cost_estimate(extra): raise SupersetException("Database does not support cost estimation") - parsed_query = sql_parse.ParsedQuery(sql) + parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) statements = parsed_query.get_statements() costs = [] for statement in statements: diff --git a/superset/models/helpers.py b/superset/models/helpers.py index a6d879b785e48..9c8e83147e9ed 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1093,7 +1093,7 @@ def get_from_clause( """ from_sql = self.get_rendered_sql(template_processor) - parsed_query = ParsedQuery(from_sql) + parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine) if not ( parsed_query.is_unknown() or self.db_engine_spec.is_readonly_query(parsed_query) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index ca530ff8b97f5..a0e9fa6b6eb0a 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -183,7 +183,7 @@ def username(self) -> str: @property def sql_tables(self) -> list[Table]: - return list(ParsedQuery(self.sql).tables) + return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables) @property def columns(self) -> list["TableColumn"]: @@ -427,7 +427,9 @@ def url(self) -> str: 
@property def sql_tables(self) -> list[Table]: - return list(ParsedQuery(self.sql).tables) + return list( + ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables + ) @property def last_run_humanized(self) -> str: diff --git a/superset/security/manager.py b/superset/security/manager.py index 618a7e2808a46..7e1b697840ac0 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1876,7 +1876,10 @@ def raise_for_access( default_schema = database.get_default_schema_for_query(query) tables = { Table(table_.table, table_.schema or default_schema) - for table_ in sql_parse.ParsedQuery(query.sql).tables + for table_ in sql_parse.ParsedQuery( + query.sql, + engine=database.db_engine_spec.engine, + ).tables } elif table: tables = {table} diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 1029ff402ca3c..1b883a77cfbbc 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -199,7 +199,7 @@ def execute_sql_statement( database: Database = query.database db_engine_spec = database.db_engine_spec - parsed_query = ParsedQuery(sql_statement) + parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine) if is_feature_enabled("RLS_IN_SQLLAB"): # There are two ways to insert RLS: either replacing the table with a subquery # that has the RLS, or appending the RLS to the ``WHERE`` clause. 
The former is @@ -219,7 +219,8 @@ def execute_sql_statement( database.id, query.schema, ) - ) + ), + engine=db_engine_spec.engine, ) sql = parsed_query.stripped() @@ -409,7 +410,11 @@ def execute_sql_statements( ) # Breaking down into multiple statements - parsed_query = ParsedQuery(rendered_query, strip_comments=True) + parsed_query = ParsedQuery( + rendered_query, + strip_comments=True, + engine=db_engine_spec.engine, + ) if not db_engine_spec.run_multiple_statements_as_one: statements = parsed_query.get_statements() logger.info( diff --git a/superset/sql_parse.py b/superset/sql_parse.py index cecd673276976..07704171dee3d 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -14,15 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +# pylint: disable=too-many-lines + import logging import re -from collections.abc import Iterator +import urllib.parse +from collections.abc import Iterable, Iterator from dataclasses import dataclass from typing import Any, cast, Optional -from urllib import parse import sqlparse from sqlalchemy import and_ +from sqlglot import exp, parse, parse_one +from sqlglot.dialects import Dialects +from sqlglot.errors import ParseError +from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope from sqlparse import keywords from sqlparse.lexer import Lexer from sqlparse.sql import ( @@ -53,7 +60,7 @@ try: from sqloxide import parse_sql as sqloxide_parse -except: # pylint: disable=bare-except +except (ImportError, ModuleNotFoundError): sqloxide_parse = None RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"} @@ -72,6 +79,59 @@ lex.set_SQL_REGEX(sqlparser_sql_regex) +# mapping between DB engine specs and sqlglot dialects +SQLGLOT_DIALECTS = { + "ascend": Dialects.HIVE, + "awsathena": Dialects.PRESTO, + "bigquery": Dialects.BIGQUERY, + "clickhouse": Dialects.CLICKHOUSE, + "clickhousedb": Dialects.CLICKHOUSE, + 
"cockroachdb": Dialects.POSTGRES, + # "crate": ??? + # "databend": ??? + "databricks": Dialects.DATABRICKS, + # "db2": ??? + # "dremio": ??? + "drill": Dialects.DRILL, + # "druid": ??? + "duckdb": Dialects.DUCKDB, + # "dynamodb": ??? + # "elasticsearch": ??? + # "exa": ??? + # "firebird": ??? + # "firebolt": ??? + "gsheets": Dialects.SQLITE, + "hana": Dialects.POSTGRES, + "hive": Dialects.HIVE, + # "ibmi": ??? + # "impala": ??? + # "kustokql": ??? + # "kylin": ??? + # "mssql": ??? + "mysql": Dialects.MYSQL, + "netezza": Dialects.POSTGRES, + # "ocient": ??? + # "odelasticsearch": ??? + "oracle": Dialects.ORACLE, + # "pinot": ??? + "postgresql": Dialects.POSTGRES, + "presto": Dialects.PRESTO, + "pydoris": Dialects.DORIS, + "redshift": Dialects.REDSHIFT, + # "risingwave": ??? + # "rockset": ??? + "shillelagh": Dialects.SQLITE, + "snowflake": Dialects.SNOWFLAKE, + # "solr": ??? + "sqlite": Dialects.SQLITE, + "starrocks": Dialects.STARROCKS, + "superset": Dialects.SQLITE, + "teradatasql": Dialects.TERADATA, + "trino": Dialects.TRINO, + "vertica": Dialects.POSTGRES, +} + + class CtasMethod(StrEnum): TABLE = "TABLE" VIEW = "VIEW" @@ -150,7 +210,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]: return cte, remainder -def strip_comments_from_sql(statement: str) -> str: +def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str: """ Strips comments from a SQL statement, does a simple test first to avoid always instantiating the expensive ParsedQuery constructor @@ -160,7 +220,11 @@ def strip_comments_from_sql(statement: str) -> str: :param statement: A string with the SQL statement :return: SQL statement without comments """ - return ParsedQuery(statement).strip_comments() if "--" in statement else statement + return ( + ParsedQuery(statement, engine=engine).strip_comments() + if "--" in statement + else statement + ) @dataclass(eq=True, frozen=True) @@ -179,7 +243,7 @@ def __str__(self) -> str: """ return ".".join( - 
parse.quote(part, safe="").replace(".", "%2E") + urllib.parse.quote(part, safe="").replace(".", "%2E") for part in [self.catalog, self.schema, self.table] if part ) @@ -189,11 +253,17 @@ def __eq__(self, __o: object) -> bool: class ParsedQuery: - def __init__(self, sql_statement: str, strip_comments: bool = False): + def __init__( + self, + sql_statement: str, + strip_comments: bool = False, + engine: Optional[str] = None, + ): if strip_comments: sql_statement = sqlparse.format(sql_statement, strip_comments=True) self.sql: str = sql_statement + self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None self._tables: set[Table] = set() self._alias_names: set[str] = set() self._limit: Optional[int] = None @@ -206,14 +276,94 @@ def __init__(self, sql_statement: str, strip_comments: bool = False): @property def tables(self) -> set[Table]: if not self._tables: - for statement in self._parsed: - self._extract_from_token(statement) - - self._tables = { - table for table in self._tables if str(table) not in self._alias_names - } + self._tables = self._extract_tables_from_sql() return self._tables + def _extract_tables_from_sql(self) -> set[Table]: + """ + Extract all table references in a query. + + Note: this uses sqlglot, since it's better at catching more edge cases. + """ + try: + statements = parse(self.sql, dialect=self._dialect) + except ParseError: + logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql) + return set() + + return { + table + for statement in statements + for table in self._extract_tables_from_statement(statement) + if statement + } + + def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]: + """ + Extract all table references in a single statement. 
+ + Please note that this is not trivial; consider the following queries: + + DESCRIBE some_table; + SHOW PARTITIONS FROM some_table; + WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name; + + See the unit tests for other tricky cases. + """ + sources: Iterable[exp.Table] + + if isinstance(statement, exp.Describe): + # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly + # query for all tables. + sources = statement.find_all(exp.Table) + elif isinstance(statement, exp.Command): + # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a + # `SELECT` statement in order to extract tables. + literal = statement.find(exp.Literal) + if not literal: + return set() + + pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect) + sources = pseudo_query.find_all(exp.Table) + else: + sources = [ + source + for scope in traverse_scope(statement) + for source in scope.sources.values() + if isinstance(source, exp.Table) and not self._is_cte(source, scope) + ] + + return { + Table( + source.name, + source.db if source.db != "" else None, + source.catalog if source.catalog != "" else None, + ) + for source in sources + } + + def _is_cte(self, source: exp.Table, scope: Scope) -> bool: + """ + Is the source a CTE? 
+ + CTEs in the parent scope look like tables (and are represented by + exp.Table objects), but should not be considered as such; + otherwise a user with access to table `foo` could access any table + with a query like this: + + WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo + + """ + parent_sources = scope.parent.sources if scope.parent else {} + ctes_in_scope = { + name + for name, parent_scope in parent_sources.items() + if isinstance(parent_scope, Scope) + and parent_scope.scope_type == ScopeType.CTE + } + + return source.name in ctes_in_scope + @property def limit(self) -> Optional[int]: return self._limit @@ -393,28 +543,6 @@ def get_table(tlist: TokenList) -> Optional[Table]: def _is_identifier(token: Token) -> bool: return isinstance(token, (IdentifierList, Identifier)) - def _process_tokenlist(self, token_list: TokenList) -> None: - """ - Add table names to table set - - :param token_list: TokenList to be processed - """ - # exclude subselects - if "(" not in str(token_list): - table = self.get_table(token_list) - if table and not table.table.startswith(CTE_PREFIX): - self._tables.add(table) - return - - # store aliases - if token_list.has_alias(): - self._alias_names.add(token_list.get_alias()) - - # some aliases are not parsed properly - if token_list.tokens[0].ttype == Name: - self._alias_names.add(token_list.tokens[0].value) - self._extract_from_token(token_list) - def as_create_table( self, table_name: str, @@ -441,50 +569,6 @@ def as_create_table( exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}" return exec_sql - def _extract_from_token(self, token: Token) -> None: - """ - store a list of subtokens and store lists of - subtoken list. - - It extracts and from :param token: and loops - through all subtokens recursively. It finds table_name_preceding_token and - passes and to self._process_tokenlist to populate - self._tables. - - :param token: instance of Token or child class, e.g. 
TokenList, to be processed - """ - if not hasattr(token, "tokens"): - return - - table_name_preceding_token = False - - for item in token.tokens: - if item.is_group and ( - not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis) - ): - self._extract_from_token(item) - - if item.ttype in Keyword and ( - item.normalized in PRECEDES_TABLE_NAME - or item.normalized.endswith(" JOIN") - ): - table_name_preceding_token = True - continue - - if item.ttype in Keyword: - table_name_preceding_token = False - continue - if table_name_preceding_token: - if isinstance(item, Identifier): - self._process_tokenlist(item) - elif isinstance(item, IdentifierList): - for token2 in item.get_identifiers(): - if isinstance(token2, TokenList): - self._process_tokenlist(token2) - elif isinstance(item, IdentifierList): - if any(not self._is_identifier(token2) for token2 in item.tokens): - self._extract_from_token(item) - def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str: """Returns the query with the specified limit. 
@@ -881,7 +965,7 @@ def insert_rls_in_predicate( # mapping between sqloxide and SQLAlchemy dialects -SQLOXITE_DIALECTS = { +SQLOXIDE_DIALECTS = { "ansi": {"trino", "trinonative", "presto"}, "hive": {"hive", "databricks"}, "ms": {"mssql"}, @@ -914,7 +998,7 @@ def extract_table_references( tree = None if sqloxide_parse: - for dialect, sqla_dialects in SQLOXITE_DIALECTS.items(): + for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items(): if sqla_dialect in sqla_dialects: break sql_text = RE_JINJA_BLOCK.sub(" ", sql_text) diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index c01b9386718ca..fed1ff3bfae62 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -50,7 +50,7 @@ def validate_statement( ) -> Optional[SQLValidationAnnotation]: # pylint: disable=too-many-locals db_engine_spec = database.db_engine_spec - parsed_query = ParsedQuery(statement) + parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine) sql = parsed_query.stripped() # Hook to allow environment-specific mutation (usually comments) to the SQL @@ -154,7 +154,7 @@ def validate( For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE VALIDATE) SELECT 1 FROM default.mytable. 
""" - parsed_query = ParsedQuery(sql) + parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine) statements = parsed_query.get_statements() logger.info("Validating %i statement(s)", len(statements)) diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py index f4c1c26c6eb4e..5597bcb086d7c 100644 --- a/superset/sqllab/query_render.py +++ b/superset/sqllab/query_render.py @@ -58,7 +58,11 @@ def render(self, execution_context: SqlJsonExecutionContext) -> str: database=query_model.database, query=query_model ) - parsed_query = ParsedQuery(query_model.sql, strip_comments=True) + parsed_query = ParsedQuery( + query_model.sql, + strip_comments=True, + engine=query_model.database.db_engine_spec.engine, + ) rendered_query = sql_template_processor.process_template( parsed_query.stripped(), **execution_context.template_params ) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index efd883810147e..f650b77734f36 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -40,11 +40,11 @@ ) -def extract_tables(query: str) -> set[Table]: +def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]: """ Helper function to extract tables referenced in a query. 
""" - return ParsedQuery(query).tables + return ParsedQuery(query, engine=engine).tables def test_table() -> None: @@ -96,8 +96,13 @@ def test_extract_tables() -> None: Table("left_table") } - # reverse select - assert extract_tables("FROM t1 SELECT field") == {Table("t1")} + assert extract_tables( + "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;" + ) == {Table("forbidden_table")} + + assert extract_tables( + "select * from (select * from forbidden_table) forbidden_table" + ) == {Table("forbidden_table")} def test_extract_tables_subselect() -> None: @@ -263,14 +268,16 @@ def test_extract_tables_illdefined() -> None: assert extract_tables("SELECT * FROM schemaname.") == set() assert extract_tables("SELECT * FROM catalogname.schemaname.") == set() assert extract_tables("SELECT * FROM catalogname..") == set() - assert extract_tables("SELECT * FROM catalogname..tbname") == set() + assert extract_tables("SELECT * FROM catalogname..tbname") == { + Table(table="tbname", schema=None, catalog="catalogname") + } def test_extract_tables_show_tables_from() -> None: """ Test ``SHOW TABLES FROM``. 
""" - assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set() + assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == set() def test_extract_tables_show_columns_from() -> None: @@ -311,7 +318,7 @@ def test_extract_tables_where_subquery() -> None: """ SELECT name FROM t1 -WHERE regionkey EXISTS (SELECT regionkey FROM t2) +WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey); """ ) == {Table("t1"), Table("t2")} @@ -526,6 +533,18 @@ def test_extract_tables_reusing_aliases() -> None: == {Table("src")} ) + # weird query with circular dependency + assert ( + extract_tables( + """ +with src as ( select key from q2 where key = '5'), +q2 as ( select key from src where key = '5') +select * from (select key from src) a +""" + ) + == set() + ) + def test_extract_tables_multistatement() -> None: """ @@ -665,7 +684,8 @@ def test_extract_tables_nested_select() -> None: select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME) from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA like "%bi%"),0x7e))); -""" +""", + "mysql", ) == {Table("COLUMNS", "INFORMATION_SCHEMA")} ) @@ -676,7 +696,8 @@ def test_extract_tables_nested_select() -> None: select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME) from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME="bi_achievement_daily"),0x7e))); -""" +""", + "mysql", ) == {Table("COLUMNS", "INFORMATION_SCHEMA")} ) @@ -1306,6 +1327,14 @@ def test_sqlparse_issue_652(): "(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)", True, ), + ( + "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;", + True, + ), + ( + "SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table", + True, + ), ], ) def test_has_table_query(sql: str, expected: bool) -> None: @@ -1790,13 +1819,17 @@ def test_extract_table_references(mocker: MockerFixture) -> None: assert extract_table_references( sql, "trino", - ) == {Table(table="other_table", schema=None, 
catalog=None)} + ) == { + Table(table="table", schema=None, catalog=None), + Table(table="other_table", schema=None, catalog=None), + } logger.warning.assert_called_once() logger = mocker.patch("superset.migrations.shared.utils.logger") sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table" assert extract_table_references(sql, "trino", show_warning=False) == { - Table(table="other_table", schema=None, catalog=None) + Table(table="table", schema=None, catalog=None), + Table(table="other_table", schema=None, catalog=None), } logger.warning.assert_not_called()