Skip to content

Commit

Permalink
Pass engine to ParsedQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Jan 19, 2024
1 parent 01fe20e commit b984b9c
Show file tree
Hide file tree
Showing 14 changed files with 116 additions and 34 deletions.
5 changes: 4 additions & 1 deletion superset/commands/dataset/duplicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def run(self) -> Model:
table.normalize_columns = self._base_model.normalize_columns
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
table.is_sqllab_view = True
table.sql = ParsedQuery(self._base_model.sql).stripped()
table.sql = ParsedQuery(
self._base_model.sql,
engine=database.db_engine_spec.engine,
).stripped()
db.session.add(table)
cols = []
for config_ in self._base_model.columns:
Expand Down
5 changes: 4 additions & 1 deletion superset/commands/sql_lab/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ def run(
limit = None
else:
sql = self._query.executed_sql
limit = ParsedQuery(sql).limit
limit = ParsedQuery(
sql,
engine=self._query.database.db_engine_spec.engine,
).limit
if limit is not None and self._query.limiting_factor in {
LimitingFactor.QUERY,
LimitingFactor.DROPDOWN,
Expand Down
2 changes: 1 addition & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,7 +1458,7 @@ def get_from_clause(
return self.get_sqla_table(), None

from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
Expand Down
2 changes: 1 addition & 1 deletion superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
parsed_query = ParsedQuery(sql)
parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
if not db_engine_spec.is_readonly_query(parsed_query):
raise SupersetSecurityException(
SupersetError(
Expand Down
12 changes: 6 additions & 6 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@ def apply_limit_to_sql(
return database.compile_sqla_query(qry)

if cls.limit_method == LimitMethod.FORCE_LIMIT:
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
sql = parsed_query.set_or_update_query_limit(limit, force=force)

return sql
Expand Down Expand Up @@ -981,7 +981,7 @@ def get_limit_from_sql(cls, sql: str) -> int | None:
:param sql: SQL query
:return: Value of limit clause in query
"""
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
return parsed_query.limit

@classmethod
Expand All @@ -993,7 +993,7 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str:
:param limit: New limit to insert/replace into query
:return: Query with new limit
"""
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)

Check warning on line 996 in superset/db_engine_specs/base.py

View check run for this annotation

Codecov / codecov/patch

superset/db_engine_specs/base.py#L996

Added line #L996 was not covered by tests
return parsed_query.set_or_update_query_limit(limit)

@classmethod
Expand Down Expand Up @@ -1490,7 +1490,7 @@ def process_statement(cls, statement: str, database: Database) -> str:
:param database: Database instance
:return: The processed SQL statement
"""
parsed_query = ParsedQuery(statement)
parsed_query = ParsedQuery(statement, engine=cls.engine)

Check warning on line 1493 in superset/db_engine_specs/base.py

View check run for this annotation

Codecov / codecov/patch

superset/db_engine_specs/base.py#L1493

Added line #L1493 was not covered by tests
sql = parsed_query.stripped()
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
Expand Down Expand Up @@ -1525,7 +1525,7 @@ def estimate_query_cost(
"Database does not support cost estimation"
)

parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)

Check warning on line 1528 in superset/db_engine_specs/base.py

View check run for this annotation

Codecov / codecov/patch

superset/db_engine_specs/base.py#L1528

Added line #L1528 was not covered by tests
statements = parsed_query.get_statements()

costs = []
Expand Down Expand Up @@ -1586,7 +1586,7 @@ def execute( # pylint: disable=unused-argument
:return:
"""
if not cls.allows_sql_comments:
query = sql_parse.strip_comments_from_sql(query)
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)

if cls.arraysize:
cursor.arraysize = cls.arraysize
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def estimate_query_cost(
if not cls.get_allow_cost_estimate(extra):
raise SupersetException("Database does not support cost estimation")

parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)

Check warning on line 438 in superset/db_engine_specs/bigquery.py

View check run for this annotation

Codecov / codecov/patch

superset/db_engine_specs/bigquery.py#L438

Added line #L438 was not covered by tests
statements = parsed_query.get_statements()
costs = []
for statement in statements:
Expand Down
2 changes: 1 addition & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ def get_from_clause(
"""

from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)

Check warning on line 1097 in superset/models/helpers.py

View check run for this annotation

Codecov / codecov/patch

superset/models/helpers.py#L1097

Added line #L1097 was not covered by tests
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
Expand Down
6 changes: 4 additions & 2 deletions superset/models/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def username(self) -> str:

@property
def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql).tables)
return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)

@property
def columns(self) -> list["TableColumn"]:
Expand Down Expand Up @@ -427,7 +427,9 @@ def url(self) -> str:

@property
def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql).tables)
return list(
ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
)

@property
def last_run_humanized(self) -> str:
Expand Down
5 changes: 4 additions & 1 deletion superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1861,7 +1861,10 @@ def raise_for_access(
default_schema = database.get_default_schema_for_query(query)
tables = {
Table(table_.table, table_.schema or default_schema)
for table_ in sql_parse.ParsedQuery(query.sql).tables
for table_ in sql_parse.ParsedQuery(
query.sql,
engine=database.db_engine_spec.engine,
).tables
}
elif table:
tables = {table}
Expand Down
11 changes: 8 additions & 3 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
database: Database = query.database
db_engine_spec = database.db_engine_spec

parsed_query = ParsedQuery(sql_statement)
parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
if is_feature_enabled("RLS_IN_SQLLAB"):
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
Expand All @@ -228,7 +228,8 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
database.id,
query.schema,
)
)
),
engine=db_engine_spec.engine,
)

sql = parsed_query.stripped()
Expand Down Expand Up @@ -419,7 +420,11 @@ def execute_sql_statements(
)

# Breaking down into multiple statements
parsed_query = ParsedQuery(rendered_query, strip_comments=True)
parsed_query = ParsedQuery(
rendered_query,
strip_comments=True,

Check warning on line 425 in superset/sql_lab.py

View check run for this annotation

Codecov / codecov/patch

superset/sql_lab.py#L424-L425

Added lines #L424 - L425 were not covered by tests
engine=db_engine_spec.engine,
)
if not db_engine_spec.run_multiple_statements_as_one:
statements = parsed_query.get_statements()
logger.info(
Expand Down
84 changes: 73 additions & 11 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=too-many-lines

import logging
import re
import urllib.parse
Expand All @@ -24,6 +27,8 @@
import sqlparse
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects import Dialects
from sqlglot.errors import ParseError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
Expand Down Expand Up @@ -55,7 +60,7 @@

try:
from sqloxide import parse_sql as sqloxide_parse
except: # pylint: disable=bare-except
except (ImportError, ModuleNotFoundError):

Check warning on line 63 in superset/sql_parse.py

View check run for this annotation

Codecov / codecov/patch

superset/sql_parse.py#L63

Added line #L63 was not covered by tests
sqloxide_parse = None

RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
Expand All @@ -74,6 +79,59 @@
lex.set_SQL_REGEX(sqlparser_sql_regex)


# Mapping between DB engine spec names (the ``engine`` string handed to
# ``ParsedQuery``) and the sqlglot dialect used to parse SQL for that engine.
#
# ``ParsedQuery`` resolves its dialect with ``SQLGLOT_DIALECTS.get(engine)``, so an
# engine that has no entry here resolves to ``None`` rather than raising.
#
# Several engines are mapped to the dialect of a different database (e.g.
# ``awsathena`` -> PRESTO, ``gsheets``/``shillelagh``/``superset`` -> SQLITE,
# ``cockroachdb``/``hana``/``netezza``/``vertica`` -> POSTGRES).
#
# Entries commented out with ``???`` are engines for which no dialect has been
# chosen -- presumably because no suitable sqlglot dialect was identified; confirm
# before filling them in.
SQLGLOT_DIALECTS = {
"ascend": Dialects.HIVE,
"awsathena": Dialects.PRESTO,
"bigquery": Dialects.BIGQUERY,
"clickhouse": Dialects.CLICKHOUSE,
"clickhousedb": Dialects.CLICKHOUSE,
"cockroachdb": Dialects.POSTGRES,
# "crate": ???
# "databend": ???
"databricks": Dialects.DATABRICKS,
# "db2": ???
# "dremio": ???
"drill": Dialects.DRILL,
# "druid": ???
"duckdb": Dialects.DUCKDB,
# "dynamodb": ???
# "elasticsearch": ???
# "exa": ???
# "firebird": ???
# "firebolt": ???
"gsheets": Dialects.SQLITE,
"hana": Dialects.POSTGRES,
"hive": Dialects.HIVE,
# "ibmi": ???
# "impala": ???
# "kustokql": ???
# "kylin": ???
# "mssql": ???
"mysql": Dialects.MYSQL,
"netezza": Dialects.POSTGRES,
# "ocient": ???
# "odelasticsearch": ???
"oracle": Dialects.ORACLE,
# "pinot": ???
"postgresql": Dialects.POSTGRES,
"presto": Dialects.PRESTO,
"pydoris": Dialects.DORIS,
"redshift": Dialects.REDSHIFT,
# "risingwave": ???
# "rockset": ???
"shillelagh": Dialects.SQLITE,
"snowflake": Dialects.SNOWFLAKE,
# "solr": ???
"sqlite": Dialects.SQLITE,
"starrocks": Dialects.STARROCKS,
"superset": Dialects.SQLITE,
"teradatasql": Dialects.TERADATA,
"trino": Dialects.TRINO,
"vertica": Dialects.POSTGRES,
}


class CtasMethod(StrEnum):
TABLE = "TABLE"
VIEW = "VIEW"
Expand Down Expand Up @@ -152,7 +210,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
return cte, remainder


def strip_comments_from_sql(statement: str) -> str:
def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
to avoid always instantiating the expensive ParsedQuery constructor
Expand All @@ -162,7 +220,11 @@ def strip_comments_from_sql(statement: str) -> str:
:param statement: A string with the SQL statement
:return: SQL statement without comments
"""
return ParsedQuery(statement).strip_comments() if "--" in statement else statement
return (
ParsedQuery(statement, engine=engine).strip_comments()
if "--" in statement
else statement
)


@dataclass(eq=True, frozen=True)
Expand Down Expand Up @@ -195,13 +257,13 @@ def __init__(
self,
sql_statement: str,
strip_comments: bool = False,
dialect: Optional[str] = None,
engine: Optional[str] = None,
):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)

self.sql: str = sql_statement
self.dialect = dialect
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
self._limit: Optional[int] = None
Expand All @@ -224,9 +286,9 @@ def _extract_tables_from_sql(self) -> set[Table]:
Note: this uses sqlglot, since it's better at catching more edge cases.
"""
try:
statements = parse(self.sql, dialect=self.dialect)
except Exception: # pylint: disable=broad-exception-caught
logger.warning("Unable to parse SQL (%s): %s", self.dialect, self.sql)
statements = parse(self.sql, dialect=self._dialect)
except ParseError:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
return set()

return {
Expand Down Expand Up @@ -261,7 +323,7 @@ def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table
if not literal:
return set()

Check warning on line 324 in superset/sql_parse.py

View check run for this annotation

Codecov / codecov/patch

superset/sql_parse.py#L324

Added line #L324 was not covered by tests

pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self.dialect)
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
sources = pseudo_query.find_all(exp.Table)
elif statement:
sources = []
Expand Down Expand Up @@ -900,7 +962,7 @@ def insert_rls_in_predicate(


# mapping between sqloxide and SQLAlchemy dialects
SQLOXITE_DIALECTS = {
SQLOXIDE_DIALECTS = {
"ansi": {"trino", "trinonative", "presto"},
"hive": {"hive", "databricks"},
"ms": {"mssql"},
Expand Down Expand Up @@ -933,7 +995,7 @@ def extract_table_references(
tree = None

if sqloxide_parse:
for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
if sqla_dialect in sqla_dialects:
break
sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
Expand Down
4 changes: 2 additions & 2 deletions superset/sql_validators/presto_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def validate_statement(
) -> Optional[SQLValidationAnnotation]:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(statement)
parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
sql = parsed_query.stripped()

# Hook to allow environment-specific mutation (usually comments) to the SQL
Expand Down Expand Up @@ -154,7 +154,7 @@ def validate(
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
VALIDATE) SELECT 1 FROM default.mytable".
"""
parsed_query = ParsedQuery(sql)
parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
statements = parsed_query.get_statements()

logger.info("Validating %i statement(s)", len(statements))
Expand Down
6 changes: 5 additions & 1 deletion superset/sqllab/query_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ def render(self, execution_context: SqlJsonExecutionContext) -> str:
database=query_model.database, query=query_model
)

parsed_query = ParsedQuery(query_model.sql, strip_comments=True)
parsed_query = ParsedQuery(
query_model.sql,
strip_comments=True,
engine=query_model.database.db_engine_spec.engine,
)
rendered_query = sql_template_processor.process_template(
parsed_query.stripped(), **execution_context.template_params
)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
)


def extract_tables(query: str, dialect: Optional[str] = None) -> set[Table]:
def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
"""
Helper function to extract tables referenced in a query.
"""
return ParsedQuery(query, dialect=dialect).tables
return ParsedQuery(query, engine=engine).tables


def test_table() -> None:
Expand Down

0 comments on commit b984b9c

Please sign in to comment.