Skip to content

Commit

Permalink
Small improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Jan 16, 2024
1 parent c831099 commit 01fe20e
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 22 deletions.
108 changes: 95 additions & 13 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
# under the License.
import logging
import re
from collections.abc import Iterator
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Optional
from urllib import parse

import sqlparse
from sqlalchemy import and_
from sqlglot import exp, parse_one
from sqlglot.optimizer.scope import traverse_scope
from sqlglot import exp, parse, parse_one
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
Expand Down Expand Up @@ -181,7 +181,7 @@ def __str__(self) -> str:
"""

return ".".join(
parse.quote(part, safe="").replace(".", "%2E")
urllib.parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
Expand All @@ -191,11 +191,17 @@ def __eq__(self, __o: object) -> bool:


class ParsedQuery:
def __init__(self, sql_statement: str, strip_comments: bool = False):
def __init__(
self,
sql_statement: str,
strip_comments: bool = False,
dialect: Optional[str] = None,
):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)

self.sql: str = sql_statement
self.dialect = dialect
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
self._limit: Optional[int] = None
Expand All @@ -208,15 +214,91 @@ def __init__(self, sql_statement: str, strip_comments: bool = False):
@property
def tables(self) -> set[Table]:
    """
    Return the tables referenced by the query.

    The set is extracted on first access and stored in ``self._tables``.
    Because the cache check is "is the stored set empty?", a query that
    references no tables is re-extracted on every access.
    """
    if not self._tables:
        # Single extraction path: it honours ``self.dialect`` and handles
        # SQL that sqlglot cannot parse, unlike a bare ``parse_one`` call.
        self._tables = self._extract_tables_from_sql()
    return self._tables

def _extract_tables_from_sql(self) -> set[Table]:
    """
    Extract all table references in a query.

    Note: this uses sqlglot, since it's better at catching more edge cases.

    :returns: the tables referenced by every statement in ``self.sql``; an
        empty set when the SQL cannot be parsed with ``self.dialect``
    """
    try:
        statements = parse(self.sql, dialect=self.dialect)
    except Exception:  # pylint: disable=broad-exception-caught
        logger.warning("Unable to parse SQL (%s): %s", self.dialect, self.sql)
        return set()

    return {
        table
        for statement in statements
        # Filter *before* the inner loop so that falsy entries (e.g. an empty
        # statement) are never handed to the per-statement extractor.
        if statement
        for table in self._extract_tables_from_statement(statement)
    }

def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
    """
    Extract all table references in a single statement.

    Please note that this is not trivial; consider the following queries:

        DESCRIBE some_table;
        SHOW PARTITIONS FROM some_table;
        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;

    See the unit tests for other tricky cases.

    :param statement: a parsed sqlglot expression; a falsy value yields an
        empty set
    :returns: the set of ``Table`` objects (table, schema, catalog) referenced
        by the statement, with empty schema/catalog normalized to ``None``
    """
    sources: Iterable[exp.Table]

    if isinstance(statement, exp.Describe):
        # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
        # query for all tables.
        sources = statement.find_all(exp.Table)
    elif isinstance(statement, exp.Command):
        # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
        # `SELECT` statement in order to extract tables.
        literal = statement.find(exp.Literal)
        if not literal:
            return set()

        pseudo_sql = f"SELECT {literal.this}"
        try:
            pseudo_query = parse_one(pseudo_sql, dialect=self.dialect)
        except Exception:  # pylint: disable=broad-exception-caught
            # The command's payload is arbitrary text and may not form a valid
            # `SELECT`; treat it like any other unparseable SQL instead of
            # letting the error escape from the `tables` property.
            logger.warning("Unable to parse SQL (%s): %s", self.dialect, pseudo_sql)
            return set()
        sources = pseudo_query.find_all(exp.Table)
    elif statement:
        sources = []
        for scope in traverse_scope(statement):
            # CTEs in the parent scope look like tables (and are represented by
            # exp.Table objects), but should not be considered as such;
            # otherwise a user with access to table `foo` could access any table
            # with a query like this:
            #
            #   WITH foo AS (SELECT * FROM bar) SELECT * FROM foo
            #
            # The set only depends on the scope, so build it once per scope
            # rather than once per source.
            parent_sources = scope.parent.sources if scope.parent else {}
            ctes_in_scope = {
                name
                for name, parent_scope in parent_sources.items()
                if isinstance(parent_scope, Scope)
                and parent_scope.scope_type == ScopeType.CTE
            }
            sources.extend(
                source
                for source in scope.sources.values()
                if isinstance(source, exp.Table)
                and source.name not in ctes_in_scope
            )
    else:
        return set()

    return {
        Table(
            source.name,
            source.db or None,
            source.catalog or None,
        )
        for source in sources
    }

@property
def limit(self) -> Optional[int]:
return self._limit
Expand Down
38 changes: 29 additions & 9 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
)


def extract_tables(query: str, dialect: Optional[str] = None) -> set[Table]:
    """
    Helper function to extract tables referenced in a query.

    :param query: the SQL text to inspect
    :param dialect: optional dialect name forwarded to ``ParsedQuery``;
        defaults to ``None`` so existing single-argument calls keep working
    :returns: the ``tables`` property of the resulting ``ParsedQuery``
    """
    return ParsedQuery(query, dialect=dialect).tables


def test_table() -> None:
Expand Down Expand Up @@ -268,14 +268,16 @@ def test_extract_tables_illdefined() -> None:
assert extract_tables("SELECT * FROM schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname..") == set()
assert extract_tables("SELECT * FROM catalogname..tbname") == set()
assert extract_tables("SELECT * FROM catalogname..tbname") == {
Table(table="tbname", schema=None, catalog="catalogname")
}


def test_extract_tables_show_tables_from() -> None:
"""
Test ``SHOW TABLES FROM``.
"""
assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set()
assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == set()


def test_extract_tables_show_columns_from() -> None:
Expand Down Expand Up @@ -316,7 +318,7 @@ def test_extract_tables_where_subquery() -> None:
"""
SELECT name
FROM t1
WHERE regionkey EXISTS (SELECT regionkey FROM t2)
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
"""
)
== {Table("t1"), Table("t2")}
Expand Down Expand Up @@ -531,6 +533,18 @@ def test_extract_tables_reusing_aliases() -> None:
== {Table("src")}
)

# weird query with circular dependency
assert (
extract_tables(
"""
with src as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from src) a
"""
)
== set()
)


def test_extract_tables_multistatement() -> None:
"""
Expand Down Expand Up @@ -670,7 +684,8 @@ def test_extract_tables_nested_select() -> None:
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
"""
""",
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
Expand All @@ -681,7 +696,8 @@ def test_extract_tables_nested_select() -> None:
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
"""
""",
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
Expand Down Expand Up @@ -1803,13 +1819,17 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
assert extract_table_references(
sql,
"trino",
) == {Table(table="other_table", schema=None, catalog=None)}
) == {
Table(table="table", schema=None, catalog=None),
Table(table="other_table", schema=None, catalog=None),
}
logger.warning.assert_called_once()

logger = mocker.patch("superset.migrations.shared.utils.logger")
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
assert extract_table_references(sql, "trino", show_warning=False) == {
Table(table="other_table", schema=None, catalog=None)
Table(table="table", schema=None, catalog=None),
Table(table="other_table", schema=None, catalog=None),
}
logger.warning.assert_not_called()

Expand Down

0 comments on commit 01fe20e

Please sign in to comment.