Skip to content

Commit

Permalink
lint + cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro committed Feb 5, 2022
1 parent 16c3a3f commit 6a47c7d
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 27 deletions.
26 changes: 16 additions & 10 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
get_physical_table_metadata,
get_virtual_table_metadata,
)
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
from superset.exceptions import QueryObjectValidationError
from superset.jinja_context import (
BaseTemplateProcessor,
Expand All @@ -103,7 +103,6 @@
logger = logging.getLogger(__name__)

VIRTUAL_TABLE_ALIAS = "virtual_table"
CTE_ALIAS = "__cte"


class SqlaQuery(NamedTuple):
Expand All @@ -129,12 +128,6 @@ class MetadataResult:
modified: List[str] = field(default_factory=list)


def _apply_cte(sql: str, cte: Optional[str]) -> str:
if cte:
sql = f"{cte}{sql}"
return sql


class AnnotationDatasource(BaseDatasource):
"""Dummy object so we can query annotations using 'Viz' objects just like
regular datasources.
Expand Down Expand Up @@ -570,6 +563,19 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def __repr__(self) -> str:
return self.name

@staticmethod
def _apply_cte(sql: str, cte: Optional[str]) -> str:
"""
Append a CTE before the SELECT statement if defined
:param sql: SELECT statement
:param cte: CTE statement
:return:
"""
if cte:
sql = f"{cte}\n{sql}"
return sql

@property
def db_engine_spec(self) -> Type[BaseEngineSpec]:
return self.database.db_engine_spec
Expand Down Expand Up @@ -762,7 +768,7 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:

engine = self.database.get_sqla_engine()
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = _apply_cte(sql, cte)
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)

df = pd.read_sql_query(sql=sql, con=engine)
Expand All @@ -784,7 +790,7 @@ def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = _apply_cte(sql, sqlaq.cte)
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
Expand Down
20 changes: 11 additions & 9 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import TypeEngine
from sqlparse.tokens import CTE, Keyword
from sqlparse.tokens import CTE
from typing_extensions import TypedDict

from superset import security_manager, sql_parse
Expand All @@ -81,6 +81,9 @@
logger = logging.getLogger()


CTE_ALIAS = "__cte"


class TimeGrain(NamedTuple):
name: str # TODO: redundant field, remove
label: str
Expand Down Expand Up @@ -679,19 +682,18 @@ def get_cte_query(cls, sql: str) -> Optional[str]:
"""
if not cls.allows_cte_in_subquery:
p = sqlparse.parse(sql)[0]
stmt = sqlparse.parse(sql)[0]

# The first meaningful token for CTE will be with WITH
idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
if not (tok and tok.ttype == CTE):
idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True)
if not (token and token.ttype == CTE):
return None
idx, tok = p.token_next(idx)
idx = p.token_index(tok) + 1
idx, token = stmt.token_next(idx)
idx = stmt.token_index(token) + 1

# extract rest of the SQLs after CTE
remainder = "".join(str(tok) for tok in p.tokens[idx:]).strip()
__query = "WITH " + tok.value + ",\n__cte AS (\n" + remainder + "\n)"
return __query
remainder = "".join(str(token) for token in stmt.tokens[idx:]).strip()
return f"WITH {token.value},\n{CTE_ALIAS} AS (\n{remainder}\n)"

return None

Expand Down
2 changes: 0 additions & 2 deletions superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.errors import SupersetErrorType
from superset.utils import core as utils
import sqlparse
from sqlparse.tokens import Keyword, CTE

logger = logging.getLogger(__name__)

Expand Down
43 changes: 43 additions & 0 deletions tests/unit_tests/db_engine_specs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access

from textwrap import dedent

import pytest
from flask.ctx import AppContext
from sqlalchemy.types import TypeEngine


def test_get_text_clause_with_colon(app_context: AppContext) -> None:
Expand Down Expand Up @@ -56,3 +60,42 @@ def test_parse_sql_multi_statement(app_context: AppContext) -> None:
"SELECT foo FROM tbl1",
"SELECT bar FROM tbl2",
]


@pytest.mark.parametrize(
"original,expected",
[
(
dedent(
"""
with currency as
(
select 'INR' as cur
)
select * from currency
"""
),
None,
),
("SELECT 1 as cnt", None,),
(
dedent(
"""
select 'INR' as cur
union
select 'AUD' as cur
union
select 'USD' as cur
"""
),
None,
),
],
)
def test_cte_query_parsing(
app_context: AppContext, original: TypeEngine, expected: str
) -> None:
from superset.db_engine_specs.base import BaseEngineSpec

actual = BaseEngineSpec.get_cte_query(original)
assert actual == expected
16 changes: 10 additions & 6 deletions tests/unit_tests/db_engine_specs/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,20 +186,24 @@ def test_column_datatype_to_string(
(
dedent(
"""
with currency as
(
with currency as (
select 'INR' as cur
),
currency_2 as (
select 'EUR' as cur
)
select * from currency
select * from currency union all select * from currency_2
"""
),
dedent(
"""WITH currency as
(
"""WITH currency as (
select 'INR' as cur
),
currency_2 as (
select 'EUR' as cur
),
__cte AS (
select * from currency
select * from currency union all select * from currency_2
)"""
),
),
Expand Down

0 comments on commit 6a47c7d

Please sign in to comment.