From 2ade3db3101679df2e0e678d5a4cbaf1f165a4b0 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 22 Mar 2024 13:39:28 +1300 Subject: [PATCH] fix(sql_parse): Ensure table extraction handles Jinja templating (#27470) --- superset/commands/sql_lab/execute.py | 4 +- superset/jinja_context.py | 10 ++--- superset/models/sql_lab.py | 40 +++++++++++++------ superset/security/manager.py | 13 ++---- superset/sql_parse.py | 60 ++++++++++++++++++++++++++++ superset/sqllab/query_render.py | 3 +- tests/unit_tests/sql_parse_tests.py | 41 +++++++++++++++++++ 7 files changed, 141 insertions(+), 30 deletions(-) diff --git a/superset/commands/sql_lab/execute.py b/superset/commands/sql_lab/execute.py index 5d955571d8441..533264fb28a4b 100644 --- a/superset/commands/sql_lab/execute.py +++ b/superset/commands/sql_lab/execute.py @@ -144,11 +144,13 @@ def _run_sql_json_exec_from_scratch(self) -> SqlJsonExecutionStatus: try: logger.info("Triggering query_id: %i", query.id) + # Necessary to check access before rendering the Jinjafied query as the + # some Jinja macros execute statements upon rendering. + self._validate_access(query) self._execution_context.set_query(query) rendered_query = self._sql_query_render.render(self._execution_context) validate_rendered_query = copy.copy(query) validate_rendered_query.sql = rendered_query - self._validate_access(validate_rendered_query) self._set_query_limit_if_required(rendered_query) self._query_dao.update( query, {"limit": self._execution_context.query.limit} diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 2990953bae52e..0ee7667811f0d 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -24,7 +24,7 @@ import dateutil from flask import current_app, has_request_context, request from flask_babel import gettext as _ -from jinja2 import DebugUndefined +from jinja2 import DebugUndefined, Environment from jinja2.sandbox import SandboxedEnvironment from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql.expression import bindparam @@ -479,11 +479,11 @@ def __init__( self._applied_filters = applied_filters self._removed_filters = removed_filters self._context: dict[str, Any] = {} - self._env = SandboxedEnvironment(undefined=DebugUndefined) + self.env: Environment = SandboxedEnvironment(undefined=DebugUndefined) self.set_context(**kwargs) # custom filters - self._env.filters["where_in"] = WhereInMacro(database.get_dialect()) + self.env.filters["where_in"] = WhereInMacro(database.get_dialect()) def set_context(self, **kwargs: Any) -> None: self._context.update(kwargs) @@ -496,7 +496,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str: >>> process_template(sql) "SELECT '2017-01-01T00:00:00'" """ - template = self._env.from_string(sql) + template = self.env.from_string(sql) kwargs.update(self._context) context = validate_template_context(self.engine, kwargs) @@ -643,7 +643,7 @@ class TrinoTemplateProcessor(PrestoTemplateProcessor): engine = "trino" def process_template(self, sql: str, **kwargs: Any) -> str: - template = self._env.from_string(sql) + template = self.env.from_string(sql) kwargs.update(self._context) # Backwards compatibility if migrating from Presto. diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index f4724d6dabb69..2d7384a74e471 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -46,6 +46,7 @@ from sqlalchemy.sql.elements import ColumnElement, literal_column from superset import security_manager +from superset.exceptions import SupersetSecurityException from superset.jinja_context import BaseTemplateProcessor, get_template_processor from superset.models.helpers import ( AuditMixinNullable, @@ -53,7 +54,7 @@ ExtraJSONMixin, ImportExportMixin, ) -from superset.sql_parse import CtasMethod, ParsedQuery, Table +from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table from superset.sqllab.limiting_factor import LimitingFactor from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label @@ -65,8 +66,25 @@ logger = logging.getLogger(__name__) +class SqlTablesMixin: # pylint: disable=too-few-public-methods + @property + def sql_tables(self) -> list[Table]: + try: + return list( + extract_tables_from_jinja_sql( + self.sql, # type: ignore + self.database.db_engine_spec.engine, # type: ignore + ) + ) + except SupersetSecurityException: + return [] + + class Query( - ExtraJSONMixin, ExploreMixin, Model + SqlTablesMixin, + ExtraJSONMixin, + ExploreMixin, + Model, ): # pylint: disable=abstract-method,too-many-public-methods """ORM model for SQL query @@ -181,10 +199,6 @@ def database_name(self) -> str: def username(self) -> str: return self.user.username - @property - def sql_tables(self) -> list[Table]: - return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables) - @property def columns(self) -> list["TableColumn"]: from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel @@ -355,7 +369,13 @@ def adhoc_column_to_sqla( return self.make_sqla_column_compatible(sqla_column, label) -class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model): +class SavedQuery( + SqlTablesMixin, + AuditMixinNullable, + ExtraJSONMixin, + ImportExportMixin, + Model, +): """ORM model for SQL query""" __tablename__ = "saved_query" @@ -425,12 +445,6 @@ def sqlalchemy_uri(self) -> URL: def url(self) -> str: return f"/sqllab?savedQueryId={self.id}" - @property - def sql_tables(self) -> list[Table]: - return list( - ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables - ) - @property def last_run_humanized(self) -> str: return naturaltime(datetime.now() - self.changed_on) diff --git a/superset/security/manager.py b/superset/security/manager.py index a5324314334ac..2833e886456f5 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -52,14 +52,12 @@ from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import Query as SqlaQuery -from superset import sql_parse from superset.constants import RouteMethod from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( DatasetInvalidPermissionEvaluationException, SupersetSecurityException, ) -from superset.jinja_context import get_template_processor from superset.security.guest_token import ( GuestToken, GuestTokenResources, @@ -68,6 +66,7 @@ GuestTokenUser, GuestUser, ) +from superset.sql_parse import extract_tables_from_jinja_sql from superset.superset_typing import Metric from superset.utils.core import ( DatasourceName, @@ -1961,16 +1960,12 @@ def raise_for_access( return if query: - # make sure the quuery is valid SQL by rendering any Jinja - processor = get_template_processor(database=query.database) - rendered_sql = processor.process_template(query.sql) default_schema = database.get_default_schema_for_query(query) tables = { Table(table_.table, table_.schema or default_schema) - for table_ in sql_parse.ParsedQuery( - rendered_sql, - engine=database.db_engine_spec.engine, - ).tables + for table_ in extract_tables_from_jinja_sql( + query.sql, database.db_engine_spec.engine + ) } elif table: tables = {table} diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 9367d3c59f6d3..f721f456d0933 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -25,10 +25,12 @@ from collections.abc import Iterable, Iterator from dataclasses import dataclass from typing import Any, cast +from unittest.mock import Mock import sqlglot import sqlparse from flask_babel import gettext as __ +from jinja2 import nodes from sqlalchemy import and_ from sqlglot import exp, parse, parse_one from sqlglot.dialects.dialect import Dialect, Dialects @@ -1232,3 +1234,61 @@ def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]: Table(*[part["value"] for part in table["name"][::-1]]) for table in find_nodes_by_key(tree, "Table") } + + +def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]: + """ + Extract all table references in the Jinjafied SQL statement. + + Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL + statement may represent invalid SQL which is non-parsable by SQLGlot. + + Firstly, we extract any tables referenced within the confines of specific Jinja + macros. Secondly, we replace these non-SQL Jinja calls with a pseudo-benign SQL + expression to help ensure that the resulting SQL statements are parsable by + SQLGlot. + + :param sql: The Jinjafied SQL statement + :param engine: The associated database engine + :returns: The set of tables referenced in the SQL statement + :raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement + """ + + from superset.jinja_context import ( # pylint: disable=import-outside-toplevel + get_template_processor, + ) + + # Mock the required database as the processor signature is exposed publically. + processor = get_template_processor(database=Mock(backend=engine)) + template = processor.env.parse(sql) + + tables = set() + + for node in template.find_all(nodes.Call): + if isinstance(node.node, nodes.Getattr) and node.node.attr in ( + "latest_partition", + "latest_sub_partition", + ): + # Extract the table referenced in the macro. + tables.add( + Table( + *[ + remove_quotes(part) + for part in node.args[0].value.split(".")[::-1] + if len(node.args) == 1 + ] + ) + ) + + # Replace the potentially problematic Jinja macro with some benign SQL. + node.__class__ = nodes.TemplateData + node.fields = nodes.TemplateData.fields + node.data = "NULL" + + return ( + tables + | ParsedQuery( + sql_statement=processor.process_template(template), + engine=engine, + ).tables + ) diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py index 5597bcb086d7c..caf9a3cb2b206 100644 --- a/superset/sqllab/query_render.py +++ b/superset/sqllab/query_render.py @@ -79,8 +79,7 @@ def _validate( sql_template_processor: BaseTemplateProcessor, ) -> None: if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"): - # pylint: disable=protected-access - syntax_tree = sql_template_processor._env.parse(rendered_query) + syntax_tree = sql_template_processor.env.parse(rendered_query) undefined_parameters = find_undeclared_variables(syntax_tree) if undefined_parameters: self._raise_undefined_parameter_exception( diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 973a4e379316c..aa4171e763fe8 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -32,6 +32,7 @@ from superset.sql_parse import ( add_table_name, extract_table_references, + extract_tables_from_jinja_sql, get_rls_for_table, has_table_query, insert_rls_as_subquery, @@ -1909,3 +1910,43 @@ def test_sqlstatement() -> None: statement = SQLStatement("SET a=1") assert statement.get_settings() == {"a": "1"} + + +@pytest.mark.parametrize( + "engine", + [ + "hive", + "presto", + "trino", + ], +) +@pytest.mark.parametrize( + "macro", + [ + "latest_partition('foo.bar')", + "latest_sub_partition('foo.bar', baz='qux')", + ], +) +@pytest.mark.parametrize( + "sql,expected", + [ + ( + "SELECT '{{{{ {engine}.{macro} }}}}'", + {Table(table="bar", schema="foo")}, + ), + ( + "SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'", + {Table(table="bar", schema="foo"), Table(table="baz", schema="foo")}, + ), + ], +) +def test_extract_tables_from_jinja_sql( + engine: str, + macro: str, + sql: str, + expected: set[Table], +) -> None: + assert ( + extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine) + == expected + )