From 6a47c7d4d32ca928fd3aca7f4f875135bc891fbc Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sat, 5 Feb 2022 10:22:14 +0200 Subject: [PATCH] lint + cleanup --- superset/connectors/sqla/models.py | 26 ++++++----- superset/db_engine_specs/base.py | 20 +++++---- superset/db_engine_specs/mssql.py | 2 - tests/unit_tests/db_engine_specs/test_base.py | 43 +++++++++++++++++++ .../unit_tests/db_engine_specs/test_mssql.py | 16 ++++--- 5 files changed, 80 insertions(+), 27 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 41f29f7ad74c0..ca1d4bc57a022 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -77,7 +77,7 @@ get_physical_table_metadata, get_virtual_table_metadata, ) -from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression +from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression from superset.exceptions import QueryObjectValidationError from superset.jinja_context import ( BaseTemplateProcessor, @@ -103,7 +103,6 @@ logger = logging.getLogger(__name__) VIRTUAL_TABLE_ALIAS = "virtual_table" -CTE_ALIAS = "__cte" class SqlaQuery(NamedTuple): @@ -129,12 +128,6 @@ class MetadataResult: modified: List[str] = field(default_factory=list) -def _apply_cte(sql: str, cte: Optional[str]) -> str: - if cte: - sql = f"{cte}{sql}" - return sql - - class AnnotationDatasource(BaseDatasource): """Dummy object so we can query annotations using 'Viz' objects just like regular datasources. @@ -570,6 +563,19 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho def __repr__(self) -> str: return self.name + @staticmethod + def _apply_cte(sql: str, cte: Optional[str]) -> str: + """ + Append a CTE before the SELECT statement if defined + + :param sql: SELECT statement + :param cte: CTE statement + :return: + """ + if cte: + sql = f"{cte}\n{sql}" + return sql + @property def db_engine_spec(self) -> Type[BaseEngineSpec]: return self.database.db_engine_spec @@ -762,7 +768,7 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: engine = self.database.get_sqla_engine() sql = qry.compile(engine, compile_kwargs={"literal_binds": True}) - sql = _apply_cte(sql, cte) + sql = self._apply_cte(sql, cte) sql = self.mutate_query_from_config(sql) df = pd.read_sql_query(sql=sql, con=engine) @@ -784,7 +790,7 @@ def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) - sql = _apply_cte(sql, sqlaq.cte) + sql = self._apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) sql = self.mutate_query_from_config(sql) return QueryStringExtended( diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 0c0d0662ad92f..764f3fde70580 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -54,7 +54,7 @@ from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause from sqlalchemy.types import TypeEngine -from sqlparse.tokens import CTE, Keyword +from sqlparse.tokens import CTE from typing_extensions import TypedDict from superset import security_manager, sql_parse @@ -81,6 +81,9 @@ logger = logging.getLogger() +CTE_ALIAS = "__cte" + + class TimeGrain(NamedTuple): name: str # TODO: redundant field, remove label: str @@ -679,19 +682,18 @@ def get_cte_query(cls, sql: str) -> Optional[str]: """ if not cls.allows_cte_in_subquery: - p = sqlparse.parse(sql)[0] + stmt = sqlparse.parse(sql)[0] # The first meaningful token for CTE will be with WITH - idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) - if not (tok and tok.ttype == CTE): + idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True) + if not (token and token.ttype == CTE): return None - idx, tok = p.token_next(idx) - idx = p.token_index(tok) + 1 + idx, token = stmt.token_next(idx) + idx = stmt.token_index(token) + 1 # extract rest of the SQLs after CTE - remainder = "".join(str(tok) for tok in p.tokens[idx:]).strip() - __query = "WITH " + tok.value + ",\n__cte AS (\n" + remainder + "\n)" - return __query + remainder = "".join(str(token) for token in stmt.tokens[idx:]).strip() + return f"WITH {token.value},\n{CTE_ALIAS} AS (\n{remainder}\n)" return None diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 0d73a7b80d280..e5c66e046a082 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -24,8 +24,6 @@ from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod from superset.errors import SupersetErrorType from superset.utils import core as utils -import sqlparse -from sqlparse.tokens import Keyword, CTE logger = logging.getLogger(__name__) diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index d822f50de9d8a..4dc27c0928f99 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -16,7 +16,11 @@ # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access +from textwrap import dedent + +import pytest from flask.ctx import AppContext +from sqlalchemy.types import TypeEngine def test_get_text_clause_with_colon(app_context: AppContext) -> None: @@ -56,3 +60,42 @@ def test_parse_sql_multi_statement(app_context: AppContext) -> None: "SELECT foo FROM tbl1", "SELECT bar FROM tbl2", ] + + +@pytest.mark.parametrize( + "original,expected", + [ + ( + dedent( + """ +with currency as +( +select 'INR' as cur +) +select * from currency +""" + ), + None, + ), + ("SELECT 1 as cnt", None,), + ( + dedent( + """ +select 'INR' as cur +union +select 'AUD' as cur +union +select 'USD' as cur +""" + ), + None, + ), + ], +) +def test_cte_query_parsing( + app_context: AppContext, original: TypeEngine, expected: str +) -> None: + from superset.db_engine_specs.base import BaseEngineSpec + + actual = BaseEngineSpec.get_cte_query(original) + assert actual == expected diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index ebb8a2d332f94..250b8158fa320 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -186,20 +186,24 @@ def test_column_datatype_to_string( ( dedent( """ -with currency as -( +with currency as ( select 'INR' as cur +), +currency_2 as ( +select 'EUR' as cur ) -select * from currency +select * from currency union all select * from currency_2 """ ), dedent( - """WITH currency as -( + """WITH currency as ( select 'INR' as cur ), +currency_2 as ( +select 'EUR' as cur +), __cte AS ( -select * from currency +select * from currency union all select * from currency_2 )""" ), ),