From 51ad63426c3ad3e227b3393cbc377621f98f8f89 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Wed, 27 Mar 2024 08:12:25 +1300 Subject: [PATCH] fix: Leverage actual database for rendering Jinjarized SQL (#27646) --- superset/models/sql_lab.py | 2 +- superset/security/manager.py | 4 +--- superset/sql_parse.py | 15 ++++++++------- tests/unit_tests/sql_parse_tests.py | 6 +++++- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 2d7384a74e471..f22d774e884a8 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -73,7 +73,7 @@ def sql_tables(self) -> list[Table]: return list( extract_tables_from_jinja_sql( self.sql, # type: ignore - self.database.db_engine_spec.engine, # type: ignore + self.database, # type: ignore ) ) except SupersetSecurityException: diff --git a/superset/security/manager.py b/superset/security/manager.py index 2833e886456f5..e5a32e97a711c 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1963,9 +1963,7 @@ def raise_for_access( default_schema = database.get_default_schema_for_query(query) tables = { Table(table_.table, table_.schema or default_schema) - for table_ in extract_tables_from_jinja_sql( - query.sql, database.db_engine_spec.engine - ) + for table_ in extract_tables_from_jinja_sql(query.sql, database) } elif table: tables = {table} diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 58bca48a6e2df..6df5dbc089e4c 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -23,8 +23,7 @@ import urllib.parse from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Any, cast -from unittest.mock import Mock +from typing import Any, cast, TYPE_CHECKING import sqlparse from flask_babel import gettext as __ @@ -71,6 +70,9 @@ except (ImportError, ModuleNotFoundError): sqloxide_parse = None +if TYPE_CHECKING: + from superset.models.core import Database + RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"} ON_KEYWORD = "ON" PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"} @@ -1054,7 +1056,7 @@ def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]: } -def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]: +def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]: """ Extract all table references in the Jinjafied SQL statement. @@ -1067,7 +1069,7 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta SQLGlot. :param sql: The Jinjafied SQL statement - :param engine: The associated database engine + :param database: The database associated with the SQL statement :returns: The set of tables referenced in the SQL statement :raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement """ @@ -1076,8 +1078,7 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta get_template_processor, ) - # Mock the required database as the processor signature is exposed publically. - processor = get_template_processor(database=Mock(backend=engine)) + processor = get_template_processor(database) template = processor.env.parse(sql) tables = set() @@ -1107,6 +1108,6 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta tables | ParsedQuery( sql_statement=processor.process_template(template), - engine=engine, + engine=database.db_engine_spec.engine, ).tables ) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 81ea0e5a7a3f7..dab5dbf9c709c 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, redefined-outer-name, too-many-lines from typing import Optional +from unittest.mock import Mock import pytest import sqlparse @@ -1912,6 +1913,9 @@ def test_extract_tables_from_jinja_sql( expected: set[Table], ) -> None: assert ( - extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine) + extract_tables_from_jinja_sql( + sql=sql.format(engine=engine, macro=macro), + database=Mock(), + ) == expected )