diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 18e5de48c10d6..84a6753f22861 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -38,7 +38,8 @@ ) from superset.models.core import Database from superset.result_set import SupersetResultSet -from superset.sql_parse import ParsedQuery, Table +from superset.sql.parse import SQLScript +from superset.sql_parse import Table from superset.superset_typing import ResultSetColumnType if TYPE_CHECKING: @@ -105,8 +106,8 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]: sql = dataset.get_template_processor().process_template( dataset.sql, **dataset.template_params_dict ) - parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine) - if not db_engine_spec.is_readonly_query(parsed_query): + parsed_script = SQLScript(sql, engine=db_engine_spec.engine) + if parsed_script.has_mutation(): raise SupersetSecurityException( SupersetError( error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, @@ -114,8 +115,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]: level=ErrorLevel.ERROR, ) ) - statements = parsed_query.get_statements() - if len(statements) > 1: + if len(parsed_script.statements) > 1: raise SupersetSecurityException( SupersetError( error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, @@ -127,7 +127,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]: dataset.database, dataset.catalog, dataset.schema, - statements[0], + sql, ) diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index a8bae9b64d4a7..b9745cf8f7c62 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -1981,6 +1981,7 @@ def test_gets_owned_created_favorited_by_me_filter(self): self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) + data["result"].sort(key=lambda x: x["datasource_id"]) assert data["result"][0]["slice_name"] == "name0" assert data["result"][0]["datasource_id"] == 1 diff --git a/tests/unit_tests/connectors/sqla/utils_test.py b/tests/unit_tests/connectors/sqla/utils_test.py index 75d5a1fe32914..0da3ab7e95a9d 100644 --- a/tests/unit_tests/connectors/sqla/utils_test.py +++ b/tests/unit_tests/connectors/sqla/utils_test.py @@ -15,9 +15,14 @@ # specific language governing permissions and limitations # under the License. +import pytest from pytest_mock import MockerFixture -from superset.connectors.sqla.utils import get_columns_description +from superset.connectors.sqla.utils import ( + get_columns_description, + get_virtual_table_metadata, +) +from superset.exceptions import SupersetSecurityException # Returns column descriptions when given valid database, catalog, schema, and query @@ -89,3 +94,46 @@ def test_returns_column_descriptions(mocker: MockerFixture) -> None: "is_dttm": False, }, ] + + +def test_get_virtual_table_metadata(mocker: MockerFixture) -> None: + """ + Test the `get_virtual_table_metadata` function. + """ + mocker.patch( + "superset.connectors.sqla.utils.get_columns_description", + return_value=[{"name": "one", "type": "INTEGER"}], + ) + dataset = mocker.MagicMock( + sql="with source as ( select 1 as one ) select * from source", + ) + dataset.database.db_engine_spec.engine = "postgresql" + dataset.get_template_processor().process_template.return_value = dataset.sql + + assert get_virtual_table_metadata(dataset) == [{"name": "one", "type": "INTEGER"}] + + +def test_get_virtual_table_metadata_mutating(mocker: MockerFixture) -> None: + """ + Test the `get_virtual_table_metadata` function with mutating SQL. + """ + dataset = mocker.MagicMock(sql="DROP TABLE sample_data") + dataset.database.db_engine_spec.engine = "postgresql" + dataset.get_template_processor().process_template.return_value = dataset.sql + + with pytest.raises(SupersetSecurityException) as excinfo: + get_virtual_table_metadata(dataset) + assert str(excinfo.value) == "Only `SELECT` statements are allowed" + + +def test_get_virtual_table_metadata_multiple(mocker: MockerFixture) -> None: + """ + Test the `get_virtual_table_metadata` function with multiple statements. + """ + dataset = mocker.MagicMock(sql="SELECT 1; SELECT 2") + dataset.database.db_engine_spec.engine = "postgresql" + dataset.get_template_processor().process_template.return_value = dataset.sql + + with pytest.raises(SupersetSecurityException) as excinfo: + get_virtual_table_metadata(dataset) + assert str(excinfo.value) == "Only single queries supported"