Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Fixing SQL parsing issue #7374

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging

import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
from sqlparse.tokens import Keyword, Name

RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
Expand Down Expand Up @@ -75,32 +75,32 @@ def get_statements(self):
return statements

@staticmethod
def __get_full_name(identifier):
if len(identifier.tokens) > 2 and identifier.tokens[1].value == '.':
return '{}.{}'.format(identifier.tokens[0].value,
identifier.tokens[2].value)
return identifier.get_real_name()
def __get_full_name(tlist: TokenList):
if len(tlist.tokens) > 2 and tlist.tokens[1].value == '.':
return '{}.{}'.format(tlist.tokens[0].value,
tlist.tokens[2].value)
return tlist.get_real_name()

@staticmethod
def __is_identifier(token):
def __is_identifier(token: Token):
return isinstance(token, (IdentifierList, Identifier))

def __process_identifier(self, identifier):
def __process_tokenlist(self, tlist: TokenList):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling this identifiers is somewhat misleading as in reality it's processing TokenList objects.

# exclude subselects
if '(' not in str(identifier):
table_name = self.__get_full_name(identifier)
if '(' not in str(tlist):
table_name = self.__get_full_name(tlist)
if table_name and not table_name.startswith(CTE_PREFIX):
self._table_names.add(table_name)
return

# store aliases
if hasattr(identifier, 'get_alias'):
self._alias_names.add(identifier.get_alias())
if hasattr(identifier, 'tokens'):
# some aliases are not parsed properly
if identifier.tokens[0].ttype == Name:
self._alias_names.add(identifier.tokens[0].value)
self.__extract_from_token(identifier)
if tlist.has_alias():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By construction there's no need to check whether the get_alias attribute exists as this method is defined in the TokenList class. Note an alias can be None and thus if there is no alias we shouldn't add it.

self._alias_names.add(tlist.get_alias())

# some aliases are not parsed properly
if tlist.tokens[0].ttype == Name:
self._alias_names.add(tlist.tokens[0].value)
self.__extract_from_token(tlist)

def as_create_table(self, table_name, overwrite=False):
"""Reformats the query into the create table as query.
Expand Down Expand Up @@ -144,10 +144,11 @@ def __extract_from_token(self, token, depth=0):

if table_name_preceding_token:
if isinstance(item, Identifier):
self.__process_identifier(item)
self.__process_tokenlist(item)
elif isinstance(item, IdentifierList):
for token in item.get_identifiers():
self.__process_identifier(token)
if isinstance(token, TokenList):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no guarantee that the "identifiers" associated withe get_indentifiers (possibly poorly named) are of type Identifier and thus the additional check to verify that the token is an instance of TokenList is required.

self.__process_tokenlist(token)
elif isinstance(item, IdentifierList):
for token in item.tokens:
if not self.__is_identifier(token):
Expand Down
9 changes: 9 additions & 0 deletions tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,12 @@ def test_messy_breakdown_statements(self):
'SELECT * FROM ab_user LIMIT 1',
]
self.assertEquals(statements, expected)

def test_identifier_list_with_keyword_as_alias(self):
query = """
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
self.assertEquals({'foo'}, self.extract_tables(query))