feat: add support for comments in adhoc clauses (#19248)
* feat: add support for comments in adhoc clauses * sanitize remaining freeform clauses * sanitize adhoc having in frontend * address review comment
This commit is contained in:
parent
97abc28a1f
commit
f341025d80
|
|
@ -23,6 +23,14 @@ import { QueryObjectFilterClause } from './types/Query';
|
|||
import { isSimpleAdhocFilter } from './types/Filter';
|
||||
import convertFilter from './convertFilter';
|
||||
|
||||
function sanitizeClause(clause: string): string {
|
||||
let sanitizedClause = clause;
|
||||
if (clause.includes('--')) {
|
||||
sanitizedClause = `${clause}\n`;
|
||||
}
|
||||
return `(${sanitizedClause})`;
|
||||
}
|
||||
|
||||
/** Logic formerly in viz.py's process_query_filters */
|
||||
export default function processFilters(
|
||||
formData: Partial<QueryFormData>,
|
||||
|
|
@ -60,9 +68,9 @@ export default function processFilters(
|
|||
});
|
||||
|
||||
// some filter-related fields need to go in `extras`
|
||||
extras.having = freeformHaving.map(exp => `(${exp})`).join(' AND ');
|
||||
extras.having = freeformHaving.map(sanitizeClause).join(' AND ');
|
||||
extras.having_druid = simpleHaving;
|
||||
extras.where = freeformWhere.map(exp => `(${exp})`).join(' AND ');
|
||||
extras.where = freeformWhere.map(sanitizeClause).join(' AND ');
|
||||
|
||||
return {
|
||||
filters: simpleWhere,
|
||||
|
|
|
|||
|
|
@ -132,12 +132,12 @@ describe('processFilters', () => {
|
|||
{
|
||||
expressionType: 'SQL',
|
||||
clause: 'WHERE',
|
||||
sqlExpression: 'tea = "jasmine"',
|
||||
sqlExpression: "tea = 'jasmine'",
|
||||
},
|
||||
{
|
||||
expressionType: 'SQL',
|
||||
clause: 'WHERE',
|
||||
sqlExpression: 'cup = "large"',
|
||||
sqlExpression: "cup = 'large' -- comment",
|
||||
},
|
||||
{
|
||||
expressionType: 'SQL',
|
||||
|
|
@ -147,13 +147,13 @@ describe('processFilters', () => {
|
|||
{
|
||||
expressionType: 'SQL',
|
||||
clause: 'HAVING',
|
||||
sqlExpression: 'waitTime <= 180',
|
||||
sqlExpression: 'waitTime <= 180 -- comment',
|
||||
},
|
||||
],
|
||||
}),
|
||||
).toEqual({
|
||||
extras: {
|
||||
having: '(ice = 25 OR ice = 50) AND (waitTime <= 180)',
|
||||
having: '(ice = 25 OR ice = 50) AND (waitTime <= 180 -- comment\n)',
|
||||
having_druid: [
|
||||
{
|
||||
col: 'sweetness',
|
||||
|
|
@ -166,7 +166,7 @@ describe('processFilters', () => {
|
|||
val: '50',
|
||||
},
|
||||
],
|
||||
where: '(tea = "jasmine") AND (cup = "large")',
|
||||
where: "(tea = 'jasmine') AND (cup = 'large' -- comment\n)",
|
||||
},
|
||||
filters: [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from superset.exceptions import (
|
|||
QueryClauseValidationException,
|
||||
QueryObjectValidationError,
|
||||
)
|
||||
from superset.sql_parse import validate_filter_clause
|
||||
from superset.sql_parse import sanitize_clause
|
||||
from superset.superset_typing import Column, Metric, OrderBy
|
||||
from superset.utils import pandas_postprocessing
|
||||
from superset.utils.core import (
|
||||
|
|
@ -272,7 +272,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
try:
|
||||
self._validate_there_are_no_missing_series()
|
||||
self._validate_no_have_duplicate_labels()
|
||||
self._validate_filters()
|
||||
self._sanitize_filters()
|
||||
return None
|
||||
except QueryObjectValidationError as ex:
|
||||
if raise_exceptions:
|
||||
|
|
@ -291,12 +291,14 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
)
|
||||
)
|
||||
|
||||
def _validate_filters(self) -> None:
|
||||
def _sanitize_filters(self) -> None:
|
||||
for param in ("where", "having"):
|
||||
clause = self.extras.get(param)
|
||||
if clause:
|
||||
try:
|
||||
validate_filter_clause(clause)
|
||||
sanitized_clause = sanitize_clause(clause)
|
||||
if sanitized_clause != clause:
|
||||
self.extras[param] = sanitized_clause
|
||||
except QueryClauseValidationException as ex:
|
||||
raise QueryObjectValidationError(ex.message) from ex
|
||||
|
||||
|
|
|
|||
|
|
@ -82,7 +82,10 @@ from superset.connectors.sqla.utils import (
|
|||
)
|
||||
from superset.datasets.models import Dataset as NewDataset
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import (
|
||||
QueryClauseValidationException,
|
||||
QueryObjectValidationError,
|
||||
)
|
||||
from superset.jinja_context import (
|
||||
BaseTemplateProcessor,
|
||||
ExtraCache,
|
||||
|
|
@ -96,7 +99,7 @@ from superset.models.helpers import (
|
|||
clone_model,
|
||||
QueryResult,
|
||||
)
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.sql_parse import ParsedQuery, sanitize_clause
|
||||
from superset.superset_typing import (
|
||||
AdhocColumn,
|
||||
AdhocMetric,
|
||||
|
|
@ -887,6 +890,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
tp = self.get_template_processor()
|
||||
expression = tp.process_template(cast(str, metric["sqlExpression"]))
|
||||
validate_adhoc_subquery(expression)
|
||||
try:
|
||||
expression = sanitize_clause(expression)
|
||||
except QueryClauseValidationException as ex:
|
||||
raise QueryObjectValidationError(ex.message) from ex
|
||||
sqla_metric = literal_column(expression)
|
||||
else:
|
||||
raise QueryObjectValidationError("Adhoc metric expressionType is invalid")
|
||||
|
|
@ -912,6 +919,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
expression = template_processor.process_template(expression)
|
||||
if expression:
|
||||
validate_adhoc_subquery(expression)
|
||||
try:
|
||||
expression = sanitize_clause(expression)
|
||||
except QueryClauseValidationException as ex:
|
||||
raise QueryObjectValidationError(ex.message) from ex
|
||||
sqla_metric = literal_column(expression)
|
||||
return self.make_sqla_column_compatible(sqla_metric, label)
|
||||
|
||||
|
|
@ -1353,7 +1364,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
where = extras.get("where")
|
||||
if where:
|
||||
try:
|
||||
where = template_processor.process_template(where)
|
||||
where = template_processor.process_template(f"({where})")
|
||||
except TemplateError as ex:
|
||||
raise QueryObjectValidationError(
|
||||
_(
|
||||
|
|
@ -1361,11 +1372,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
msg=ex.message,
|
||||
)
|
||||
) from ex
|
||||
where_clause_and += [self.text(f"({where})")]
|
||||
where_clause_and += [self.text(where)]
|
||||
having = extras.get("having")
|
||||
if having:
|
||||
try:
|
||||
having = template_processor.process_template(having)
|
||||
having = template_processor.process_template(f"({having})")
|
||||
except TemplateError as ex:
|
||||
raise QueryObjectValidationError(
|
||||
_(
|
||||
|
|
@ -1373,7 +1384,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
msg=ex.message,
|
||||
)
|
||||
) from ex
|
||||
having_clause_and += [self.text(f"({having})")]
|
||||
having_clause_and += [self.text(having)]
|
||||
if apply_fetch_values_predicate and self.fetch_values_predicate:
|
||||
qry = qry.where(self.get_fetch_values_predicate())
|
||||
if granularity:
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from sqlparse.sql import (
|
|||
Where,
|
||||
)
|
||||
from sqlparse.tokens import (
|
||||
Comment,
|
||||
CTE,
|
||||
DDL,
|
||||
DML,
|
||||
|
|
@ -441,25 +442,35 @@ class ParsedQuery:
|
|||
return str_res
|
||||
|
||||
|
||||
def validate_filter_clause(clause: str) -> None:
|
||||
if sqlparse.format(clause, strip_comments=True) != sqlparse.format(clause):
|
||||
raise QueryClauseValidationException("Filter clause contains comment")
|
||||
|
||||
def sanitize_clause(clause: str) -> str:
|
||||
# clause = sqlparse.format(clause, strip_comments=True)
|
||||
statements = sqlparse.parse(clause)
|
||||
if len(statements) != 1:
|
||||
raise QueryClauseValidationException("Filter clause contains multiple queries")
|
||||
raise QueryClauseValidationException("Clause contains multiple statements")
|
||||
open_parens = 0
|
||||
|
||||
previous_token = None
|
||||
for token in statements[0]:
|
||||
if token.value == "/" and previous_token and previous_token.value == "*":
|
||||
raise QueryClauseValidationException("Closing unopened multiline comment")
|
||||
if token.value == "*" and previous_token and previous_token.value == "/":
|
||||
raise QueryClauseValidationException("Unclosed multiline comment")
|
||||
if token.value in (")", "("):
|
||||
open_parens += 1 if token.value == "(" else -1
|
||||
if open_parens < 0:
|
||||
raise QueryClauseValidationException(
|
||||
"Closing unclosed parenthesis in filter clause"
|
||||
)
|
||||
previous_token = token
|
||||
if open_parens > 0:
|
||||
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")
|
||||
|
||||
if previous_token and previous_token.ttype in Comment:
|
||||
if previous_token.value[-1] != "\n":
|
||||
clause = f"{clause}\n"
|
||||
|
||||
return clause
|
||||
|
||||
|
||||
class InsertRLSState(str, Enum):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ from superset.exceptions import (
|
|||
SupersetException,
|
||||
SupersetTimeoutException,
|
||||
)
|
||||
from superset.sql_parse import sanitize_clause
|
||||
from superset.superset_typing import (
|
||||
AdhocColumn,
|
||||
AdhocMetric,
|
||||
|
|
@ -1366,10 +1367,12 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name
|
|||
}
|
||||
)
|
||||
elif expression_type == "SQL":
|
||||
sql_expression = adhoc_filter.get("sqlExpression")
|
||||
sql_expression = sanitize_clause(sql_expression)
|
||||
if clause == "WHERE":
|
||||
sql_where_filters.append(adhoc_filter.get("sqlExpression"))
|
||||
sql_where_filters.append(sql_expression)
|
||||
elif clause == "HAVING":
|
||||
sql_having_filters.append(adhoc_filter.get("sqlExpression"))
|
||||
sql_having_filters.append(sql_expression)
|
||||
form_data["where"] = " AND ".join(
|
||||
["({})".format(sql) for sql in sql_where_filters]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -62,14 +62,13 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
|||
from superset.exceptions import (
|
||||
CacheLoadError,
|
||||
NullValueException,
|
||||
QueryClauseValidationException,
|
||||
QueryObjectValidationError,
|
||||
SpatialException,
|
||||
SupersetSecurityException,
|
||||
)
|
||||
from superset.extensions import cache_manager, security_manager
|
||||
from superset.models.helpers import QueryResult
|
||||
from superset.sql_parse import validate_filter_clause
|
||||
from superset.sql_parse import sanitize_clause
|
||||
from superset.superset_typing import (
|
||||
Column,
|
||||
Metric,
|
||||
|
|
@ -391,10 +390,9 @@ class BaseViz: # pylint: disable=too-many-public-methods
|
|||
for param in ("where", "having"):
|
||||
clause = self.form_data.get(param)
|
||||
if clause:
|
||||
try:
|
||||
validate_filter_clause(clause)
|
||||
except QueryClauseValidationException as ex:
|
||||
raise QueryObjectValidationError(ex.message) from ex
|
||||
sanitized_clause = sanitize_clause(clause)
|
||||
if sanitized_clause != clause:
|
||||
self.form_data[param] = sanitized_clause
|
||||
|
||||
# extras are used to query elements specific to a datasource type
|
||||
# for instance the extra where clause that applies only to Tables
|
||||
|
|
|
|||
|
|
@ -465,6 +465,28 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
|
||||
assert rv.status_code == 400
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_with_where_parameter_including_comment___200(self):
|
||||
self.query_context_payload["queries"][0]["filters"] = []
|
||||
self.query_context_payload["queries"][0]["extras"]["where"] = "1 = 1 -- abc"
|
||||
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
|
||||
assert rv.status_code == 200
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_with_orderby_parameter_with_second_query__400(self):
|
||||
self.query_context_payload["queries"][0]["filters"] = []
|
||||
self.query_context_payload["queries"][0]["orderby"] = [
|
||||
[
|
||||
{"expressionType": "SQL", "sqlExpression": "sum__num; select 1, 1",},
|
||||
True,
|
||||
],
|
||||
]
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
|
||||
assert rv.status_code == 400
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_with_invalid_having_parameter_closing_and_comment__400(self):
|
||||
self.query_context_payload["queries"][0]["filters"] = []
|
||||
|
|
|
|||
|
|
@ -30,9 +30,9 @@ from superset.sql_parse import (
|
|||
insert_rls,
|
||||
matches_table_name,
|
||||
ParsedQuery,
|
||||
sanitize_clause,
|
||||
strip_comments_from_sql,
|
||||
Table,
|
||||
validate_filter_clause,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1142,52 +1142,46 @@ def test_strip_comments_from_sql() -> None:
|
|||
)
|
||||
|
||||
|
||||
def test_validate_filter_clause_valid():
|
||||
def test_sanitize_clause_valid():
|
||||
# regular clauses
|
||||
assert validate_filter_clause("col = 1") is None
|
||||
assert validate_filter_clause("1=\t\n1") is None
|
||||
assert validate_filter_clause("(col = 1)") is None
|
||||
assert validate_filter_clause("(col1 = 1) AND (col2 = 2)") is None
|
||||
assert sanitize_clause("col = 1") == "col = 1"
|
||||
assert sanitize_clause("1=\t\n1") == "1=\t\n1"
|
||||
assert sanitize_clause("(col = 1)") == "(col = 1)"
|
||||
assert sanitize_clause("(col1 = 1) AND (col2 = 2)") == "(col1 = 1) AND (col2 = 2)"
|
||||
assert sanitize_clause("col = 'abc' -- comment") == "col = 'abc' -- comment\n"
|
||||
|
||||
# Valid literal values that appear to be invalid
|
||||
assert validate_filter_clause("col = 'col1 = 1) AND (col2 = 2'") is None
|
||||
assert validate_filter_clause("col = 'select 1; select 2'") is None
|
||||
assert validate_filter_clause("col = 'abc -- comment'") is None
|
||||
# Valid literal values that at could be flagged as invalid by a naive query parser
|
||||
assert (
|
||||
sanitize_clause("col = 'col1 = 1) AND (col2 = 2'")
|
||||
== "col = 'col1 = 1) AND (col2 = 2'"
|
||||
)
|
||||
assert sanitize_clause("col = 'select 1; select 2'") == "col = 'select 1; select 2'"
|
||||
assert sanitize_clause("col = 'abc -- comment'") == "col = 'abc -- comment'"
|
||||
|
||||
|
||||
def test_validate_filter_clause_closing_unclosed():
|
||||
def test_sanitize_clause_closing_unclosed():
|
||||
with pytest.raises(QueryClauseValidationException):
|
||||
validate_filter_clause("col1 = 1) AND (col2 = 2)")
|
||||
sanitize_clause("col1 = 1) AND (col2 = 2)")
|
||||
|
||||
|
||||
def test_validate_filter_clause_unclosed():
|
||||
def test_sanitize_clause_unclosed():
|
||||
with pytest.raises(QueryClauseValidationException):
|
||||
validate_filter_clause("(col1 = 1) AND (col2 = 2")
|
||||
sanitize_clause("(col1 = 1) AND (col2 = 2")
|
||||
|
||||
|
||||
def test_validate_filter_clause_closing_and_unclosed():
|
||||
def test_sanitize_clause_closing_and_unclosed():
|
||||
with pytest.raises(QueryClauseValidationException):
|
||||
validate_filter_clause("col1 = 1) AND (col2 = 2")
|
||||
sanitize_clause("col1 = 1) AND (col2 = 2")
|
||||
|
||||
|
||||
def test_validate_filter_clause_closing_and_unclosed_nested():
|
||||
def test_sanitize_clause_closing_and_unclosed_nested():
|
||||
with pytest.raises(QueryClauseValidationException):
|
||||
validate_filter_clause("(col1 = 1)) AND ((col2 = 2)")
|
||||
sanitize_clause("(col1 = 1)) AND ((col2 = 2)")
|
||||
|
||||
|
||||
def test_validate_filter_clause_multiple():
|
||||
def test_sanitize_clause_multiple():
|
||||
with pytest.raises(QueryClauseValidationException):
|
||||
validate_filter_clause("TRUE; SELECT 1")
|
||||
|
||||
|
||||
def test_validate_filter_clause_comment():
|
||||
with pytest.raises(QueryClauseValidationException):
|
||||
validate_filter_clause("1 = 1 -- comment")
|
||||
|
||||
|
||||
def test_validate_filter_clause_subquery_comment():
|
||||
with pytest.raises(QueryClauseValidationException):
|
||||
validate_filter_clause("(1 = 1 -- comment\n)")
|
||||
sanitize_clause("TRUE; SELECT 1")
|
||||
|
||||
|
||||
def test_sqlparse_issue_652():
|
||||
|
|
|
|||
Loading…
Reference in New Issue