Skip to content

Commit

Permalink
fix(dataset): use sqlglot for DML check (#31024)
Browse files Browse the repository at this point in the history
(cherry picked from commit 832fed1)
  • Loading branch information
betodealmeida authored and sadpandajoe committed Dec 4, 2024
1 parent 6e092dd commit 5e72994
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 7 deletions.
12 changes: 6 additions & 6 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery, Table
from superset.sql.parse import SQLScript
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType

if TYPE_CHECKING:
Expand Down Expand Up @@ -105,17 +106,16 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
if not db_engine_spec.is_readonly_query(parsed_query):
parsed_script = SQLScript(sql, engine=db_engine_spec.engine)
if parsed_script.has_mutation():
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
message=_("Only `SELECT` statements are allowed"),
level=ErrorLevel.ERROR,
)
)
statements = parsed_query.get_statements()
if len(statements) > 1:
if len(parsed_script.statements) > 1:
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
Expand All @@ -127,7 +127,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
dataset.database,
dataset.catalog,
dataset.schema,
statements[0],
sql,
)


Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1981,6 +1981,7 @@ def test_gets_owned_created_favorited_by_me_filter(self):
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))

data["result"].sort(key=lambda x: x["datasource_id"])
assert data["result"][0]["slice_name"] == "name0"
assert data["result"][0]["datasource_id"] == 1

Expand Down
50 changes: 49 additions & 1 deletion tests/unit_tests/connectors/sqla/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@
# specific language governing permissions and limitations
# under the License.

import pytest
from pytest_mock import MockerFixture

from superset.connectors.sqla.utils import get_columns_description
from superset.connectors.sqla.utils import (
get_columns_description,
get_virtual_table_metadata,
)
from superset.exceptions import SupersetSecurityException


# Returns column descriptions when given valid database, catalog, schema, and query
Expand Down Expand Up @@ -89,3 +94,46 @@ def test_returns_column_descriptions(mocker: MockerFixture) -> None:
"is_dttm": False,
},
]


def test_get_virtual_table_metadata(mocker: MockerFixture) -> None:
"""
Test the `get_virtual_table_metadata` function.
"""
mocker.patch(
"superset.connectors.sqla.utils.get_columns_description",
return_value=[{"name": "one", "type": "INTEGER"}],
)
dataset = mocker.MagicMock(
sql="with source as ( select 1 as one ) select * from source",
)
dataset.database.db_engine_spec.engine = "postgresql"
dataset.get_template_processor().process_template.return_value = dataset.sql

assert get_virtual_table_metadata(dataset) == [{"name": "one", "type": "INTEGER"}]


def test_get_virtual_table_metadata_mutating(mocker: MockerFixture) -> None:
"""
Test the `get_virtual_table_metadata` function with mutating SQL.
"""
dataset = mocker.MagicMock(sql="DROP TABLE sample_data")
dataset.database.db_engine_spec.engine = "postgresql"
dataset.get_template_processor().process_template.return_value = dataset.sql

with pytest.raises(SupersetSecurityException) as excinfo:
get_virtual_table_metadata(dataset)
assert str(excinfo.value) == "Only `SELECT` statements are allowed"


def test_get_virtual_table_metadata_multiple(mocker: MockerFixture) -> None:
"""
Test the `get_virtual_table_metadata` function with multiple statements.
"""
dataset = mocker.MagicMock(sql="SELECT 1; SELECT 2")
dataset.database.db_engine_spec.engine = "postgresql"
dataset.get_template_processor().process_template.return_value = dataset.sql

with pytest.raises(SupersetSecurityException) as excinfo:
get_virtual_table_metadata(dataset)
assert str(excinfo.value) == "Only single queries supported"

0 comments on commit 5e72994

Please sign in to comment.