Skip to content

Commit

Permalink
chore(sqla): refactor query utils (#21811)
Browse files Browse the repository at this point in the history
Co-authored-by: Ville Brofeldt <[email protected]>
  • Loading branch information
villebro and villebro authored Oct 17, 2022
1 parent 7a7181a commit 52d33b0
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 8 deletions.
26 changes: 19 additions & 7 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
DatasetInvalidPermissionEvaluationException,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
from superset.jinja_context import (
Expand Down Expand Up @@ -655,19 +656,19 @@ def _process_sql_expression(
expression: Optional[str],
database_id: int,
schema: str,
template_processor: Optional[BaseTemplateProcessor],
template_processor: Optional[BaseTemplateProcessor] = None,
) -> Optional[str]:
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
expression = validate_adhoc_subquery(
expression,
database_id,
schema,
)
try:
expression = validate_adhoc_subquery(
expression,
database_id,
schema,
)
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
except (QueryClauseValidationException, SupersetSecurityException) as ex:
raise QueryObjectValidationError(ex.message) from ex
return expression

Expand Down Expand Up @@ -1672,6 +1673,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
msg=ex.message,
)
) from ex
where = _process_sql_expression(
expression=where,
database_id=self.database_id,
schema=self.schema,
)
where_clause_and += [self.text(where)]
having = extras.get("having")
if having:
Expand All @@ -1684,7 +1690,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
msg=ex.message,
)
) from ex
having = _process_sql_expression(
expression=having,
database_id=self.database_id,
schema=self.schema,
)
having_clause_and += [self.text(having)]

if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
if granularity:
Expand Down
50 changes: 50 additions & 0 deletions tests/integration_tests/charts/data/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,3 +1098,53 @@ def test_chart_cache_timeout_chart_not_found(

rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
assert rv.json["result"][0]["cache_timeout"] == 1010


@pytest.mark.parametrize(
"status_code,extras",
[
(200, {"where": "1 = 1"}),
(200, {"having": "count(*) > 0"}),
(400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
(400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
],
)
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_subquery_not_allowed(
test_client,
login_as_admin,
physical_dataset,
physical_query_context,
status_code,
extras,
):
physical_query_context["queries"][0]["extras"] = extras
rv = test_client.post(CHART_DATA_URI, json=physical_query_context)

assert rv.status_code == status_code


@pytest.mark.parametrize(
"status_code,extras",
[
(200, {"where": "1 = 1"}),
(200, {"having": "count(*) > 0"}),
(200, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
(200, {"having": "count(*) > (select count(*) from physical_dataset)"}),
],
)
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_subquery_allowed(
test_client,
login_as_admin,
physical_dataset,
physical_query_context,
status_code,
extras,
):
physical_query_context["queries"][0]["extras"] = extras
rv = test_client.post(CHART_DATA_URI, json=physical_query_context)

assert rv.status_code == status_code
2 changes: 1 addition & 1 deletion tests/integration_tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_adhoc_metrics_and_calc_columns(self):
)
db.session.commit()

with pytest.raises(SupersetSecurityException):
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**base_query_obj)
# Cleanup
db.session.delete(table)
Expand Down

0 comments on commit 52d33b0

Please sign in to comment.