Skip to content

Commit

Permalink
fix(sql_parse): Ensure table extraction handles Jinja templating (apa…
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored Mar 22, 2024
1 parent adad749 commit 2ade3db
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 30 deletions.
4 changes: 3 additions & 1 deletion superset/commands/sql_lab/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,13 @@ def _run_sql_json_exec_from_scratch(self) -> SqlJsonExecutionStatus:
try:
logger.info("Triggering query_id: %i", query.id)

# Necessary to check access before rendering the Jinjafied query as the
# some Jinja macros execute statements upon rendering.
self._validate_access(query)
self._execution_context.set_query(query)
rendered_query = self._sql_query_render.render(self._execution_context)
validate_rendered_query = copy.copy(query)
validate_rendered_query.sql = rendered_query
self._validate_access(validate_rendered_query)
self._set_query_limit_if_required(rendered_query)
self._query_dao.update(
query, {"limit": self._execution_context.query.limit}
Expand Down
10 changes: 5 additions & 5 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import dateutil
from flask import current_app, has_request_context, request
from flask_babel import gettext as _
from jinja2 import DebugUndefined
from jinja2 import DebugUndefined, Environment
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.expression import bindparam
Expand Down Expand Up @@ -479,11 +479,11 @@ def __init__(
self._applied_filters = applied_filters
self._removed_filters = removed_filters
self._context: dict[str, Any] = {}
self._env = SandboxedEnvironment(undefined=DebugUndefined)
self.env: Environment = SandboxedEnvironment(undefined=DebugUndefined)
self.set_context(**kwargs)

# custom filters
self._env.filters["where_in"] = WhereInMacro(database.get_dialect())
self.env.filters["where_in"] = WhereInMacro(database.get_dialect())

def set_context(self, **kwargs: Any) -> None:
self._context.update(kwargs)
Expand All @@ -496,7 +496,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str:
>>> process_template(sql)
"SELECT '2017-01-01T00:00:00'"
"""
template = self._env.from_string(sql)
template = self.env.from_string(sql)
kwargs.update(self._context)

context = validate_template_context(self.engine, kwargs)
Expand Down Expand Up @@ -643,7 +643,7 @@ class TrinoTemplateProcessor(PrestoTemplateProcessor):
engine = "trino"

def process_template(self, sql: str, **kwargs: Any) -> str:
template = self._env.from_string(sql)
template = self.env.from_string(sql)
kwargs.update(self._context)

# Backwards compatibility if migrating from Presto.
Expand Down
40 changes: 27 additions & 13 deletions superset/models/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@
from sqlalchemy.sql.elements import ColumnElement, literal_column

from superset import security_manager
from superset.exceptions import SupersetSecurityException
from superset.jinja_context import BaseTemplateProcessor, get_template_processor
from superset.models.helpers import (
AuditMixinNullable,
ExploreMixin,
ExtraJSONMixin,
ImportExportMixin,
)
from superset.sql_parse import CtasMethod, ParsedQuery, Table
from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table
from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label

Expand All @@ -65,8 +66,25 @@
logger = logging.getLogger(__name__)


class SqlTablesMixin: # pylint: disable=too-few-public-methods
@property
def sql_tables(self) -> list[Table]:
try:
return list(
extract_tables_from_jinja_sql(
self.sql, # type: ignore
self.database.db_engine_spec.engine, # type: ignore
)
)
except SupersetSecurityException:
return []


class Query(
ExtraJSONMixin, ExploreMixin, Model
SqlTablesMixin,
ExtraJSONMixin,
ExploreMixin,
Model,
): # pylint: disable=abstract-method,too-many-public-methods
"""ORM model for SQL query
Expand Down Expand Up @@ -181,10 +199,6 @@ def database_name(self) -> str:
def username(self) -> str:
return self.user.username

@property
def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)

@property
def columns(self) -> list["TableColumn"]:
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
Expand Down Expand Up @@ -355,7 +369,13 @@ def adhoc_column_to_sqla(
return self.make_sqla_column_compatible(sqla_column, label)


class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
class SavedQuery(
SqlTablesMixin,
AuditMixinNullable,
ExtraJSONMixin,
ImportExportMixin,
Model,
):
"""ORM model for SQL query"""

__tablename__ = "saved_query"
Expand Down Expand Up @@ -425,12 +445,6 @@ def sqlalchemy_uri(self) -> URL:
def url(self) -> str:
return f"/sqllab?savedQueryId={self.id}"

@property
def sql_tables(self) -> list[Table]:
return list(
ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
)

@property
def last_run_humanized(self) -> str:
return naturaltime(datetime.now() - self.changed_on)
Expand Down
13 changes: 4 additions & 9 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,12 @@
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query as SqlaQuery

from superset import sql_parse
from superset.constants import RouteMethod
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
DatasetInvalidPermissionEvaluationException,
SupersetSecurityException,
)
from superset.jinja_context import get_template_processor
from superset.security.guest_token import (
GuestToken,
GuestTokenResources,
Expand All @@ -68,6 +66,7 @@
GuestTokenUser,
GuestUser,
)
from superset.sql_parse import extract_tables_from_jinja_sql
from superset.superset_typing import Metric
from superset.utils.core import (
DatasourceName,
Expand Down Expand Up @@ -1961,16 +1960,12 @@ def raise_for_access(
return

if query:
# make sure the quuery is valid SQL by rendering any Jinja
processor = get_template_processor(database=query.database)
rendered_sql = processor.process_template(query.sql)
default_schema = database.get_default_schema_for_query(query)
tables = {
Table(table_.table, table_.schema or default_schema)
for table_ in sql_parse.ParsedQuery(
rendered_sql,
engine=database.db_engine_spec.engine,
).tables
for table_ in extract_tables_from_jinja_sql(
query.sql, database.db_engine_spec.engine
)
}
elif table:
tables = {table}
Expand Down
60 changes: 60 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast
from unittest.mock import Mock

import sqlglot
import sqlparse
from flask_babel import gettext as __
from jinja2 import nodes
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects.dialect import Dialect, Dialects
Expand Down Expand Up @@ -1232,3 +1234,61 @@ def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]:
Table(*[part["value"] for part in table["name"][::-1]])
for table in find_nodes_by_key(tree, "Table")
}


def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]:
"""
Extract all table references in the Jinjafied SQL statement.
Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL
statement may represent invalid SQL which is non-parsable by SQLGlot.
Firstly, we extract any tables referenced within the confines of specific Jinja
macros. Secondly, we replace these non-SQL Jinja calls with a pseudo-benign SQL
expression to help ensure that the resulting SQL statements are parsable by
SQLGlot.
:param sql: The Jinjafied SQL statement
:param engine: The associated database engine
:returns: The set of tables referenced in the SQL statement
:raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement
"""

from superset.jinja_context import ( # pylint: disable=import-outside-toplevel
get_template_processor,
)

# Mock the required database as the processor signature is exposed publically.
processor = get_template_processor(database=Mock(backend=engine))
template = processor.env.parse(sql)

tables = set()

for node in template.find_all(nodes.Call):
if isinstance(node.node, nodes.Getattr) and node.node.attr in (
"latest_partition",
"latest_sub_partition",
):
# Extract the table referenced in the macro.
tables.add(
Table(
*[
remove_quotes(part)
for part in node.args[0].value.split(".")[::-1]
if len(node.args) == 1
]
)
)

# Replace the potentially problematic Jinja macro with some benign SQL.
node.__class__ = nodes.TemplateData
node.fields = nodes.TemplateData.fields
node.data = "NULL"

return (
tables
| ParsedQuery(
sql_statement=processor.process_template(template),
engine=engine,
).tables
)
3 changes: 1 addition & 2 deletions superset/sqllab/query_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def _validate(
sql_template_processor: BaseTemplateProcessor,
) -> None:
if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
# pylint: disable=protected-access
syntax_tree = sql_template_processor._env.parse(rendered_query)
syntax_tree = sql_template_processor.env.parse(rendered_query)
undefined_parameters = find_undeclared_variables(syntax_tree)
if undefined_parameters:
self._raise_undefined_parameter_exception(
Expand Down
41 changes: 41 additions & 0 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from superset.sql_parse import (
add_table_name,
extract_table_references,
extract_tables_from_jinja_sql,
get_rls_for_table,
has_table_query,
insert_rls_as_subquery,
Expand Down Expand Up @@ -1909,3 +1910,43 @@ def test_sqlstatement() -> None:

statement = SQLStatement("SET a=1")
assert statement.get_settings() == {"a": "1"}


@pytest.mark.parametrize(
"engine",
[
"hive",
"presto",
"trino",
],
)
@pytest.mark.parametrize(
"macro",
[
"latest_partition('foo.bar')",
"latest_sub_partition('foo.bar', baz='qux')",
],
)
@pytest.mark.parametrize(
"sql,expected",
[
(
"SELECT '{{{{ {engine}.{macro} }}}}'",
{Table(table="bar", schema="foo")},
),
(
"SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'",
{Table(table="bar", schema="foo"), Table(table="baz", schema="foo")},
),
],
)
def test_extract_tables_from_jinja_sql(
engine: str,
macro: str,
sql: str,
expected: set[Table],
) -> None:
assert (
extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine)
== expected
)

0 comments on commit 2ade3db

Please sign in to comment.