chore(sqla): refactor query utils (#21811)

Co-authored-by: Ville Brofeldt <ville.brofeldt@apple.com>
This commit is contained in:
Ville Brofeldt 2022-10-17 10:40:42 +01:00 committed by GitHub
parent 7a7181a244
commit 52d33b05fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 8 deletions

View File

@ -95,6 +95,7 @@ from superset.exceptions import (
DatasetInvalidPermissionEvaluationException,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
from superset.jinja_context import (
@ -655,19 +656,19 @@ def _process_sql_expression(
expression: Optional[str],
database_id: int,
schema: str,
template_processor: Optional[BaseTemplateProcessor],
template_processor: Optional[BaseTemplateProcessor] = None,
) -> Optional[str]:
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
expression = validate_adhoc_subquery(
expression,
database_id,
schema,
)
try:
expression = validate_adhoc_subquery(
expression,
database_id,
schema,
)
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
except (QueryClauseValidationException, SupersetSecurityException) as ex:
raise QueryObjectValidationError(ex.message) from ex
return expression
@ -1672,6 +1673,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message,
)
) from ex
where = _process_sql_expression(
expression=where,
database_id=self.database_id,
schema=self.schema,
)
where_clause_and += [self.text(where)]
having = extras.get("having")
if having:
@ -1684,7 +1690,13 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message,
)
) from ex
having = _process_sql_expression(
expression=having,
database_id=self.database_id,
schema=self.schema,
)
having_clause_and += [self.text(having)]
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
if granularity:

View File

@ -1098,3 +1098,53 @@ def test_chart_cache_timeout_chart_not_found(
rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
assert rv.json["result"][0]["cache_timeout"] == 1010
@pytest.mark.parametrize(
"status_code,extras",
[
(200, {"where": "1 = 1"}),
(200, {"having": "count(*) > 0"}),
(400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
(400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
],
)
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_subquery_not_allowed(
test_client,
login_as_admin,
physical_dataset,
physical_query_context,
status_code,
extras,
):
physical_query_context["queries"][0]["extras"] = extras
rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
assert rv.status_code == status_code
@pytest.mark.parametrize(
"status_code,extras",
[
(200, {"where": "1 = 1"}),
(200, {"having": "count(*) > 0"}),
(200, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
(200, {"having": "count(*) > (select count(*) from physical_dataset)"}),
],
)
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_subquery_allowed(
test_client,
login_as_admin,
physical_dataset,
physical_query_context,
status_code,
extras,
):
physical_query_context["queries"][0]["extras"] = extras
rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
assert rv.status_code == status_code

View File

@ -262,7 +262,7 @@ class TestDatabaseModel(SupersetTestCase):
)
db.session.commit()
with pytest.raises(SupersetSecurityException):
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**base_query_obj)
# Cleanup
db.session.delete(table)