feat: add support for comments in adhoc clauses (#19248)

* feat: add support for comments in adhoc clauses

* sanitize remaining freeform clauses

* sanitize adhoc having in frontend

* address review comment
This commit is contained in:
Ville Brofeldt 2022-03-19 00:08:06 +02:00 committed by GitHub
parent 97abc28a1f
commit f341025d80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 109 additions and 60 deletions

View File

@ -23,6 +23,14 @@ import { QueryObjectFilterClause } from './types/Query';
import { isSimpleAdhocFilter } from './types/Filter';
import convertFilter from './convertFilter';
/**
 * Wraps a freeform SQL clause in parentheses.
 *
 * If the clause contains `--` anywhere, a newline is appended before the
 * closing parenthesis so that a `--` line comment ends before the `)`.
 *
 * @param clause - the raw freeform SQL expression
 * @returns the clause enclosed in `(` … `)`, newline-terminated when it
 *   contains `--`
 */
function sanitizeClause(clause: string): string {
  const body = clause.includes('--') ? `${clause}\n` : clause;
  return `(${body})`;
}
/** Logic formerly in viz.py's process_query_filters */
export default function processFilters(
formData: Partial<QueryFormData>,
@ -60,9 +68,9 @@ export default function processFilters(
});
// some filter-related fields need to go in `extras`
extras.having = freeformHaving.map(exp => `(${exp})`).join(' AND ');
extras.having = freeformHaving.map(sanitizeClause).join(' AND ');
extras.having_druid = simpleHaving;
extras.where = freeformWhere.map(exp => `(${exp})`).join(' AND ');
extras.where = freeformWhere.map(sanitizeClause).join(' AND ');
return {
filters: simpleWhere,

View File

@ -132,12 +132,12 @@ describe('processFilters', () => {
{
expressionType: 'SQL',
clause: 'WHERE',
sqlExpression: 'tea = "jasmine"',
sqlExpression: "tea = 'jasmine'",
},
{
expressionType: 'SQL',
clause: 'WHERE',
sqlExpression: 'cup = "large"',
sqlExpression: "cup = 'large' -- comment",
},
{
expressionType: 'SQL',
@ -147,13 +147,13 @@ describe('processFilters', () => {
{
expressionType: 'SQL',
clause: 'HAVING',
sqlExpression: 'waitTime <= 180',
sqlExpression: 'waitTime <= 180 -- comment',
},
],
}),
).toEqual({
extras: {
having: '(ice = 25 OR ice = 50) AND (waitTime <= 180)',
having: '(ice = 25 OR ice = 50) AND (waitTime <= 180 -- comment\n)',
having_druid: [
{
col: 'sweetness',
@ -166,7 +166,7 @@ describe('processFilters', () => {
val: '50',
},
],
where: '(tea = "jasmine") AND (cup = "large")',
where: "(tea = 'jasmine') AND (cup = 'large' -- comment\n)",
},
filters: [
{

View File

@ -30,7 +30,7 @@ from superset.exceptions import (
QueryClauseValidationException,
QueryObjectValidationError,
)
from superset.sql_parse import validate_filter_clause
from superset.sql_parse import sanitize_clause
from superset.superset_typing import Column, Metric, OrderBy
from superset.utils import pandas_postprocessing
from superset.utils.core import (
@ -272,7 +272,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
try:
self._validate_there_are_no_missing_series()
self._validate_no_have_duplicate_labels()
self._validate_filters()
self._sanitize_filters()
return None
except QueryObjectValidationError as ex:
if raise_exceptions:
@ -291,12 +291,14 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
)
)
def _validate_filters(self) -> None:
def _sanitize_filters(self) -> None:
for param in ("where", "having"):
clause = self.extras.get(param)
if clause:
try:
validate_filter_clause(clause)
sanitized_clause = sanitize_clause(clause)
if sanitized_clause != clause:
self.extras[param] = sanitized_clause
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex

View File

@ -82,7 +82,10 @@ from superset.connectors.sqla.utils import (
)
from superset.datasets.models import Dataset as NewDataset
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import (
QueryClauseValidationException,
QueryObjectValidationError,
)
from superset.jinja_context import (
BaseTemplateProcessor,
ExtraCache,
@ -96,7 +99,7 @@ from superset.models.helpers import (
clone_model,
QueryResult,
)
from superset.sql_parse import ParsedQuery
from superset.sql_parse import ParsedQuery, sanitize_clause
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
@ -887,6 +890,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
tp = self.get_template_processor()
expression = tp.process_template(cast(str, metric["sqlExpression"]))
validate_adhoc_subquery(expression)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sqla_metric = literal_column(expression)
else:
raise QueryObjectValidationError("Adhoc metric expressionType is invalid")
@ -912,6 +919,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
expression = template_processor.process_template(expression)
if expression:
validate_adhoc_subquery(expression)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sqla_metric = literal_column(expression)
return self.make_sqla_column_compatible(sqla_metric, label)
@ -1353,7 +1364,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
where = extras.get("where")
if where:
try:
where = template_processor.process_template(where)
where = template_processor.process_template(f"({where})")
except TemplateError as ex:
raise QueryObjectValidationError(
_(
@ -1361,11 +1372,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message,
)
) from ex
where_clause_and += [self.text(f"({where})")]
where_clause_and += [self.text(where)]
having = extras.get("having")
if having:
try:
having = template_processor.process_template(having)
having = template_processor.process_template(f"({having})")
except TemplateError as ex:
raise QueryObjectValidationError(
_(
@ -1373,7 +1384,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message,
)
) from ex
having_clause_and += [self.text(f"({having})")]
having_clause_and += [self.text(having)]
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
if granularity:

View File

@ -32,6 +32,7 @@ from sqlparse.sql import (
Where,
)
from sqlparse.tokens import (
Comment,
CTE,
DDL,
DML,
@ -441,25 +442,35 @@ class ParsedQuery:
return str_res
def validate_filter_clause(clause: str) -> None:
if sqlparse.format(clause, strip_comments=True) != sqlparse.format(clause):
raise QueryClauseValidationException("Filter clause contains comment")
def sanitize_clause(clause: str) -> str:
# clause = sqlparse.format(clause, strip_comments=True)
statements = sqlparse.parse(clause)
if len(statements) != 1:
raise QueryClauseValidationException("Filter clause contains multiple queries")
raise QueryClauseValidationException("Clause contains multiple statements")
open_parens = 0
previous_token = None
for token in statements[0]:
if token.value == "/" and previous_token and previous_token.value == "*":
raise QueryClauseValidationException("Closing unopened multiline comment")
if token.value == "*" and previous_token and previous_token.value == "/":
raise QueryClauseValidationException("Unclosed multiline comment")
if token.value in (")", "("):
open_parens += 1 if token.value == "(" else -1
if open_parens < 0:
raise QueryClauseValidationException(
"Closing unclosed parenthesis in filter clause"
)
previous_token = token
if open_parens > 0:
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")
if previous_token and previous_token.ttype in Comment:
if previous_token.value[-1] != "\n":
clause = f"{clause}\n"
return clause
class InsertRLSState(str, Enum):
"""

View File

@ -98,6 +98,7 @@ from superset.exceptions import (
SupersetException,
SupersetTimeoutException,
)
from superset.sql_parse import sanitize_clause
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
@ -1366,10 +1367,12 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name
}
)
elif expression_type == "SQL":
sql_expression = adhoc_filter.get("sqlExpression")
sql_expression = sanitize_clause(sql_expression)
if clause == "WHERE":
sql_where_filters.append(adhoc_filter.get("sqlExpression"))
sql_where_filters.append(sql_expression)
elif clause == "HAVING":
sql_having_filters.append(adhoc_filter.get("sqlExpression"))
sql_having_filters.append(sql_expression)
form_data["where"] = " AND ".join(
["({})".format(sql) for sql in sql_where_filters]
)

View File

@ -62,14 +62,13 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
CacheLoadError,
NullValueException,
QueryClauseValidationException,
QueryObjectValidationError,
SpatialException,
SupersetSecurityException,
)
from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.sql_parse import validate_filter_clause
from superset.sql_parse import sanitize_clause
from superset.superset_typing import (
Column,
Metric,
@ -391,10 +390,9 @@ class BaseViz: # pylint: disable=too-many-public-methods
for param in ("where", "having"):
clause = self.form_data.get(param)
if clause:
try:
validate_filter_clause(clause)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sanitized_clause = sanitize_clause(clause)
if sanitized_clause != clause:
self.form_data[param] = sanitized_clause
# extras are used to query elements specific to a datasource type
# for instance the extra where clause that applies only to Tables

View File

@ -465,6 +465,28 @@ class TestPostChartDataApi(BaseTestChartDataApi):
assert rv.status_code == 400
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_where_parameter_including_comment___200(self):
self.query_context_payload["queries"][0]["filters"] = []
self.query_context_payload["queries"][0]["extras"]["where"] = "1 = 1 -- abc"
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
assert rv.status_code == 200
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_orderby_parameter_with_second_query__400(self):
self.query_context_payload["queries"][0]["filters"] = []
self.query_context_payload["queries"][0]["orderby"] = [
[
{"expressionType": "SQL", "sqlExpression": "sum__num; select 1, 1",},
True,
],
]
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
assert rv.status_code == 400
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_invalid_having_parameter_closing_and_comment__400(self):
self.query_context_payload["queries"][0]["filters"] = []

View File

@ -30,9 +30,9 @@ from superset.sql_parse import (
insert_rls,
matches_table_name,
ParsedQuery,
sanitize_clause,
strip_comments_from_sql,
Table,
validate_filter_clause,
)
@ -1142,52 +1142,46 @@ def test_strip_comments_from_sql() -> None:
)
def test_validate_filter_clause_valid():
def test_sanitize_clause_valid():
# regular clauses
assert validate_filter_clause("col = 1") is None
assert validate_filter_clause("1=\t\n1") is None
assert validate_filter_clause("(col = 1)") is None
assert validate_filter_clause("(col1 = 1) AND (col2 = 2)") is None
assert sanitize_clause("col = 1") == "col = 1"
assert sanitize_clause("1=\t\n1") == "1=\t\n1"
assert sanitize_clause("(col = 1)") == "(col = 1)"
assert sanitize_clause("(col1 = 1) AND (col2 = 2)") == "(col1 = 1) AND (col2 = 2)"
assert sanitize_clause("col = 'abc' -- comment") == "col = 'abc' -- comment\n"
# Valid literal values that appear to be invalid
assert validate_filter_clause("col = 'col1 = 1) AND (col2 = 2'") is None
assert validate_filter_clause("col = 'select 1; select 2'") is None
assert validate_filter_clause("col = 'abc -- comment'") is None
# Valid literal values that could be flagged as invalid by a naive query parser
assert (
sanitize_clause("col = 'col1 = 1) AND (col2 = 2'")
== "col = 'col1 = 1) AND (col2 = 2'"
)
assert sanitize_clause("col = 'select 1; select 2'") == "col = 'select 1; select 2'"
assert sanitize_clause("col = 'abc -- comment'") == "col = 'abc -- comment'"
def test_validate_filter_clause_closing_unclosed():
def test_sanitize_clause_closing_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("col1 = 1) AND (col2 = 2)")
sanitize_clause("col1 = 1) AND (col2 = 2)")
def test_validate_filter_clause_unclosed():
def test_sanitize_clause_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(col1 = 1) AND (col2 = 2")
sanitize_clause("(col1 = 1) AND (col2 = 2")
def test_validate_filter_clause_closing_and_unclosed():
def test_sanitize_clause_closing_and_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("col1 = 1) AND (col2 = 2")
sanitize_clause("col1 = 1) AND (col2 = 2")
def test_validate_filter_clause_closing_and_unclosed_nested():
def test_sanitize_clause_closing_and_unclosed_nested():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(col1 = 1)) AND ((col2 = 2)")
sanitize_clause("(col1 = 1)) AND ((col2 = 2)")
def test_validate_filter_clause_multiple():
def test_sanitize_clause_multiple():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("TRUE; SELECT 1")
def test_validate_filter_clause_comment():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("1 = 1 -- comment")
def test_validate_filter_clause_subquery_comment():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(1 = 1 -- comment\n)")
sanitize_clause("TRUE; SELECT 1")
def test_sqlparse_issue_652():