Skip to content

Commit

Permalink
[fix] Fixing SQL parsing issue (#7374)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored May 2, 2019
1 parent ee78fd7 commit fb627ba
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 19 deletions.
39 changes: 20 additions & 19 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging

import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
from sqlparse.tokens import Keyword, Name

RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
Expand Down Expand Up @@ -75,32 +75,32 @@ def get_statements(self):
return statements

@staticmethod
def __get_full_name(identifier):
if len(identifier.tokens) > 2 and identifier.tokens[1].value == '.':
return '{}.{}'.format(identifier.tokens[0].value,
identifier.tokens[2].value)
return identifier.get_real_name()
def __get_full_name(tlist: TokenList):
if len(tlist.tokens) > 2 and tlist.tokens[1].value == '.':
return '{}.{}'.format(tlist.tokens[0].value,
tlist.tokens[2].value)
return tlist.get_real_name()

@staticmethod
def __is_identifier(token):
def __is_identifier(token: Token):
return isinstance(token, (IdentifierList, Identifier))

def __process_identifier(self, identifier):
def __process_tokenlist(self, tlist: TokenList):
# exclude subselects
if '(' not in str(identifier):
table_name = self.__get_full_name(identifier)
if '(' not in str(tlist):
table_name = self.__get_full_name(tlist)
if table_name and not table_name.startswith(CTE_PREFIX):
self._table_names.add(table_name)
return

# store aliases
if hasattr(identifier, 'get_alias'):
self._alias_names.add(identifier.get_alias())
if hasattr(identifier, 'tokens'):
# some aliases are not parsed properly
if identifier.tokens[0].ttype == Name:
self._alias_names.add(identifier.tokens[0].value)
self.__extract_from_token(identifier)
if tlist.has_alias():
self._alias_names.add(tlist.get_alias())

# some aliases are not parsed properly
if tlist.tokens[0].ttype == Name:
self._alias_names.add(tlist.tokens[0].value)
self.__extract_from_token(tlist)

def as_create_table(self, table_name, overwrite=False):
"""Reformats the query into the create table as query.
Expand Down Expand Up @@ -144,10 +144,11 @@ def __extract_from_token(self, token, depth=0):

if table_name_preceding_token:
if isinstance(item, Identifier):
self.__process_identifier(item)
self.__process_tokenlist(item)
elif isinstance(item, IdentifierList):
for token in item.get_identifiers():
self.__process_identifier(token)
if isinstance(token, TokenList):
self.__process_tokenlist(token)
elif isinstance(item, IdentifierList):
for token in item.tokens:
if not self.__is_identifier(token):
Expand Down
9 changes: 9 additions & 0 deletions tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,12 @@ def test_messy_breakdown_statements(self):
'SELECT * FROM ab_user LIMIT 1',
]
self.assertEquals(statements, expected)

def test_identifier_list_with_keyword_as_alias(self):
query = """
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
self.assertEquals({'foo'}, self.extract_tables(query))

0 comments on commit fb627ba

Please sign in to comment.