Skip to content

Commit

Permalink
feat(sqlparse): improve table parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Jan 12, 2024
1 parent e36c014 commit c831099
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 79 deletions.
15 changes: 12 additions & 3 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ geographiclib==1.52
geopy==2.2.0
# via apache-superset
greenlet==2.0.2
# via shillelagh
# via
# shillelagh
# sqlalchemy
gunicorn==21.2.0
# via apache-superset
hashids==1.3.1
Expand All @@ -155,7 +157,10 @@ idna==3.2
# email-validator
# requests
importlib-metadata==6.6.0
# via apache-superset
# via
# apache-superset
# flask
# shillelagh
importlib-resources==5.12.0
# via limits
isodate==0.6.0
Expand Down Expand Up @@ -327,6 +332,8 @@ sqlalchemy-utils==0.38.3
# via
# apache-superset
# flask-appbuilder
sqlglot==20.8.0
# via apache-superset
sqlparse==0.4.4
# via apache-superset
sshtunnel==0.4.0
Expand Down Expand Up @@ -376,7 +383,9 @@ wtforms-json==0.3.5
xlsxwriter==3.0.7
# via apache-superset
zipp==3.15.0
# via importlib-metadata
# via
# importlib-metadata
# importlib-resources

# The following packages are considered to be unsafe in a requirements file:
# setuptools
6 changes: 2 additions & 4 deletions requirements/testing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ db-dtypes==1.1.1
# via pandas-gbq
docker==6.1.1
# via -r requirements/testing.in
exceptiongroup==1.1.1
# via pytest
ephem==4.1.4
# via lunarcalendar
flask-testing==0.8.1
# via -r requirements/testing.in
fonttools==4.39.4
Expand Down Expand Up @@ -121,6 +117,8 @@ pyee==9.0.4
# via playwright
pyfakefs==5.2.2
# via -r requirements/testing.in
pyhive[presto]==0.7.0
# via apache-superset
pytest==7.3.1
# via
# -r requirements/testing.in
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def get_git_sha() -> str:
"slack_sdk>=3.19.0, <4",
"sqlalchemy>=1.4, <2",
"sqlalchemy-utils>=0.38.3, <0.39",
"sqlglot>=20,<21",
"sqlparse>=0.4.4, <0.5",
"tabulate>=0.8.9, <0.9",
"typing-extensions>=4, <5",
Expand Down
77 changes: 7 additions & 70 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import sqlparse
from sqlalchemy import and_
from sqlglot import exp, parse_one
from sqlglot.optimizer.scope import traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
Expand Down Expand Up @@ -206,12 +208,13 @@ def __init__(self, sql_statement: str, strip_comments: bool = False):
@property
def tables(self) -> set[Table]:
if not self._tables:
for statement in self._parsed:
self._extract_from_token(statement)

self._tables = {
table for table in self._tables if str(table) not in self._alias_names
Table(source.name, source.db if source.db != "" else None)
for scope in traverse_scope(parse_one(self.sql))
for source in scope.sources.values()
if isinstance(source, exp.Table)
}

return self._tables

@property
Expand Down Expand Up @@ -393,28 +396,6 @@ def get_table(tlist: TokenList) -> Optional[Table]:
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))

def _process_tokenlist(self, token_list: TokenList) -> None:
"""
Add table names to table set
:param token_list: TokenList to be processed
"""
# exclude subselects
if "(" not in str(token_list):
table = self.get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return

# store aliases
if token_list.has_alias():
self._alias_names.add(token_list.get_alias())

# some aliases are not parsed properly
if token_list.tokens[0].ttype == Name:
self._alias_names.add(token_list.tokens[0].value)
self._extract_from_token(token_list)

def as_create_table(
self,
table_name: str,
Expand All @@ -441,50 +422,6 @@ def as_create_table(
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
return exec_sql

def _extract_from_token(self, token: Token) -> None:
"""
<Identifier> store a list of subtokens and <IdentifierList> store lists of
subtoken list.
It extracts <IdentifierList> and <Identifier> from :param token: and loops
through all subtokens recursively. It finds table_name_preceding_token and
passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
self._tables.
:param token: instance of Token or child class, e.g. TokenList, to be processed
"""
if not hasattr(token, "tokens"):
return

table_name_preceding_token = False

for item in token.tokens:
if item.is_group and (
not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
):
self._extract_from_token(item)

if item.ttype in Keyword and (
item.normalized in PRECEDES_TABLE_NAME
or item.normalized.endswith(" JOIN")
):
table_name_preceding_token = True
continue

if item.ttype in Keyword:
table_name_preceding_token = False
continue
if table_name_preceding_token:
if isinstance(item, Identifier):
self._process_tokenlist(item)
elif isinstance(item, IdentifierList):
for token2 in item.get_identifiers():
if isinstance(token2, TokenList):
self._process_tokenlist(token2)
elif isinstance(item, IdentifierList):
if any(not self._is_identifier(token2) for token2 in item.tokens):
self._extract_from_token(item)

def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
"""Returns the query with the specified limit.
Expand Down
17 changes: 15 additions & 2 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,13 @@ def test_extract_tables() -> None:
Table("left_table")
}

# reverse select
assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
assert extract_tables(
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
) == {Table("forbidden_table")}

assert extract_tables(
"select * from (select * from forbidden_table) forbidden_table"
) == {Table("forbidden_table")}


def test_extract_tables_subselect() -> None:
Expand Down Expand Up @@ -1306,6 +1311,14 @@ def test_sqlparse_issue_652():
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
True,
),
(
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
True,
),
(
"SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
True,
),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
Expand Down

0 comments on commit c831099

Please sign in to comment.