From 9518509f1dc34f95d22e4f6f9b9e38eddf29b287 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Tue, 5 Mar 2024 17:44:51 +0000 Subject: [PATCH] fix: improve explore REST api validations (#27395) --- superset/commands/explore/get.py | 11 +++++++++-- tests/integration_tests/explore/api_tests.py | 20 +++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/superset/commands/explore/get.py b/superset/commands/explore/get.py index 9d715bd63dc5f..a0ff176109266 100644 --- a/superset/commands/explore/get.py +++ b/superset/commands/explore/get.py @@ -37,6 +37,7 @@ from superset.exceptions import SupersetException from superset.explore.exceptions import WrongEndpointError from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError +from superset.extensions import security_manager from superset.utils import core as utils from superset.views.utils import ( get_datasource_info, @@ -61,7 +62,6 @@ def __init__( # pylint: disable=too-many-locals,too-many-branches,too-many-statements def run(self) -> Optional[dict[str, Any]]: initial_form_data = {} - if self._permalink_key is not None: command = GetExplorePermalinkCommand(self._permalink_key) permalink_value = command.run() @@ -110,12 +110,19 @@ def run(self) -> Optional[dict[str, Any]]: self._datasource_type = SqlaTable.type datasource: Optional[BaseDatasource] = None + if self._datasource_id is not None: with contextlib.suppress(DatasourceNotFound): datasource = DatasourceDAO.get_datasource( cast(str, self._datasource_type), self._datasource_id ) - datasource_name = datasource.name if datasource else _("[Missing Dataset]") + + datasource_name = _("[Missing Dataset]") + + if datasource: + datasource_name = datasource.name + security_manager.can_access_datasource(datasource) + viz_type = form_data.get("viz_type") if not viz_type and datasource and datasource.default_endpoint: raise WrongEndpointError(redirect=datasource.default_endpoint) diff --git a/tests/integration_tests/explore/api_tests.py b/tests/integration_tests/explore/api_tests.py index c0b7f5fcd41d7..6d33f1c91676e 100644 --- a/tests/integration_tests/explore/api_tests.py +++ b/tests/integration_tests/explore/api_tests.py @@ -197,7 +197,7 @@ def test_get_from_permalink_unknown_key(test_client, login_as_admin): @patch("superset.security.SupersetSecurityManager.can_access_datasource") -def test_get_dataset_access_denied( +def test_get_dataset_access_denied_with_form_data_key( mock_can_access_datasource, test_client, login_as_admin, dataset ): message = "Dataset access denied" @@ -214,6 +214,24 @@ def test_get_dataset_access_denied( assert data["message"] == message +@patch("superset.security.SupersetSecurityManager.can_access_datasource") +def test_get_dataset_access_denied( + mock_can_access_datasource, test_client, login_as_admin, dataset +): + message = "Dataset access denied" + mock_can_access_datasource.side_effect = DatasetAccessDeniedError( + message=message, datasource_id=dataset.id, datasource_type=dataset.type + ) + resp = test_client.get( + f"api/v1/explore/?datasource_id={dataset.id}&datasource_type={dataset.type}" + ) + data = json.loads(resp.data.decode("utf-8")) + assert resp.status_code == 403 + assert data["datasource_id"] == dataset.id + assert data["datasource_type"] == dataset.type + assert data["message"] == message + + @patch("superset.daos.datasource.DatasourceDAO.get_datasource") def test_wrong_endpoint(mock_get_datasource, test_client, login_as_admin, dataset): dataset.default_endpoint = "another_endpoint"