From fb627ba3769dfeb8f79718790a17a91873239383 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Wed, 1 May 2019 22:07:01 -0700 Subject: [PATCH] [fix] Fixing SQL parsing issue (#7374) --- superset/sql_parse.py | 39 ++++++++++++++++++++------------------- tests/sql_parse_tests.py | 9 +++++++++ 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 2f65392558f9e..662f6c326229b 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -18,7 +18,7 @@ import logging import sqlparse -from sqlparse.sql import Identifier, IdentifierList +from sqlparse.sql import Identifier, IdentifierList, Token, TokenList from sqlparse.tokens import Keyword, Name RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'} @@ -75,32 +75,32 @@ def get_statements(self): return statements @staticmethod - def __get_full_name(identifier): - if len(identifier.tokens) > 2 and identifier.tokens[1].value == '.': - return '{}.{}'.format(identifier.tokens[0].value, - identifier.tokens[2].value) - return identifier.get_real_name() + def __get_full_name(tlist: TokenList): + if len(tlist.tokens) > 2 and tlist.tokens[1].value == '.': + return '{}.{}'.format(tlist.tokens[0].value, + tlist.tokens[2].value) + return tlist.get_real_name() @staticmethod - def __is_identifier(token): + def __is_identifier(token: Token): return isinstance(token, (IdentifierList, Identifier)) - def __process_identifier(self, identifier): + def __process_tokenlist(self, tlist: TokenList): # exclude subselects - if '(' not in str(identifier): - table_name = self.__get_full_name(identifier) + if '(' not in str(tlist): + table_name = self.__get_full_name(tlist) if table_name and not table_name.startswith(CTE_PREFIX): self._table_names.add(table_name) return # store aliases - if hasattr(identifier, 'get_alias'): - self._alias_names.add(identifier.get_alias()) - if hasattr(identifier, 'tokens'): - # some aliases are not parsed properly - if identifier.tokens[0].ttype == Name: - self._alias_names.add(identifier.tokens[0].value) - self.__extract_from_token(identifier) + if tlist.has_alias(): + self._alias_names.add(tlist.get_alias()) + + # some aliases are not parsed properly + if tlist.tokens[0].ttype == Name: + self._alias_names.add(tlist.tokens[0].value) + self.__extract_from_token(tlist) def as_create_table(self, table_name, overwrite=False): """Reformats the query into the create table as query. @@ -144,10 +144,11 @@ def __extract_from_token(self, token, depth=0): if table_name_preceding_token: if isinstance(item, Identifier): - self.__process_identifier(item) + self.__process_tokenlist(item) elif isinstance(item, IdentifierList): for token in item.get_identifiers(): - self.__process_identifier(token) + if isinstance(token, TokenList): + self.__process_tokenlist(token) elif isinstance(item, IdentifierList): for token in item.tokens: if not self.__is_identifier(token): diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py index 56959397fa9c5..7096147b5615f 100644 --- a/tests/sql_parse_tests.py +++ b/tests/sql_parse_tests.py @@ -462,3 +462,12 @@ def test_messy_breakdown_statements(self): 'SELECT * FROM ab_user LIMIT 1', ] self.assertEquals(statements, expected) + + def test_identifier_list_with_keyword_as_alias(self): + query = """ + WITH + f AS (SELECT * FROM foo), + match AS (SELECT * FROM f) + SELECT * FROM match + """ + self.assertEquals({'foo'}, self.extract_tables(query))