chore(sql): clean up invalid filter clause exception types (#17702)
* chore(sql): clean up invalid filter clause exception types
* fix lint
* rename exception
This commit is contained in:
parent
1af99eabc2
commit
3a42071e0f
|
|
@ -25,7 +25,11 @@ from flask_babel import gettext as _
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
|
|
||||||
from superset.common.chart_data import ChartDataResultType
|
from superset.common.chart_data import ChartDataResultType
|
||||||
from superset.exceptions import QueryObjectValidationError
|
from superset.exceptions import (
|
||||||
|
QueryClauseValidationException,
|
||||||
|
QueryObjectValidationError,
|
||||||
|
)
|
||||||
|
from superset.sql_parse import validate_filter_clause
|
||||||
from superset.typing import Column, Metric, OrderBy
|
from superset.typing import Column, Metric, OrderBy
|
||||||
from superset.utils import pandas_postprocessing
|
from superset.utils import pandas_postprocessing
|
||||||
from superset.utils.core import (
|
from superset.utils.core import (
|
||||||
|
|
@ -267,6 +271,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||||
try:
|
try:
|
||||||
self._validate_there_are_no_missing_series()
|
self._validate_there_are_no_missing_series()
|
||||||
self._validate_no_have_duplicate_labels()
|
self._validate_no_have_duplicate_labels()
|
||||||
|
self._validate_filters()
|
||||||
return None
|
return None
|
||||||
except QueryObjectValidationError as ex:
|
except QueryObjectValidationError as ex:
|
||||||
if raise_exceptions:
|
if raise_exceptions:
|
||||||
|
|
@ -285,6 +290,15 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _validate_filters(self) -> None:
    """
    Validate the free-form SQL filter clauses stored in ``self.extras``.

    The ``where`` and ``having`` entries are each passed to
    ``validate_filter_clause``; missing or empty entries are ignored.

    :raises QueryObjectValidationError: if a clause is rejected; the message of
        the underlying ``QueryClauseValidationException`` is reused and the
        original exception is chained as the cause
    """
    for extras_key in ("where", "having"):
        sql_clause = self.extras.get(extras_key)
        if not sql_clause:
            # nothing to validate for this key
            continue
        try:
            validate_filter_clause(sql_clause)
        except QueryClauseValidationException as ex:
            # translate the parser-level error into the query object's own error type
            raise QueryObjectValidationError(ex.message) from ex
|
||||||
|
|
||||||
def _validate_there_are_no_missing_series(self) -> None:
|
def _validate_there_are_no_missing_series(self) -> None:
|
||||||
missing_series = [col for col in self.series_columns if col not in self.columns]
|
missing_series = [col for col in self.series_columns if col not in self.columns]
|
||||||
if missing_series:
|
if missing_series:
|
||||||
|
|
|
||||||
|
|
@ -194,6 +194,10 @@ class CacheLoadError(SupersetException):
|
||||||
status = 404
|
status = 404
|
||||||
|
|
||||||
|
|
||||||
|
class QueryClauseValidationException(SupersetException):
    """Error type for a SQL filter clause that fails validation."""

    # NOTE(review): presumably the HTTP status code reported for this error
    # (400 = bad request) -- confirm against how SupersetException uses `status`.
    status = 400
|
||||||
|
|
||||||
|
|
||||||
class DashboardImportException(SupersetException):
|
class DashboardImportException(SupersetException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,8 @@ from sqlparse.sql import (
|
||||||
from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace
|
from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace
|
||||||
from sqlparse.utils import imt
|
from sqlparse.utils import imt
|
||||||
|
|
||||||
|
from superset.exceptions import QueryClauseValidationException
|
||||||
|
|
||||||
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
|
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
|
||||||
ON_KEYWORD = "ON"
|
ON_KEYWORD = "ON"
|
||||||
PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
|
PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
|
||||||
|
|
@ -378,3 +380,23 @@ class ParsedQuery:
|
||||||
for i in statement.tokens:
|
for i in statement.tokens:
|
||||||
str_res += str(i.value)
|
str_res += str(i.value)
|
||||||
return str_res
|
return str_res
|
||||||
|
|
||||||
|
|
||||||
|
def validate_filter_clause(clause: str) -> None:
    """
    Validate a free-form SQL filter clause.

    The clause is rejected when it contains a SQL comment, when it does not
    parse into exactly one statement, or when its parentheses are unbalanced.

    :param clause: SQL filter clause to validate
    :raises QueryClauseValidationException: if any of the checks above fails
    """
    # Stripping comments must be a no-op; any difference means a comment is present.
    if sqlparse.format(clause, strip_comments=True) != sqlparse.format(clause):
        raise QueryClauseValidationException("Filter clause contains comment")

    statements = sqlparse.parse(clause)
    if len(statements) != 1:
        raise QueryClauseValidationException("Filter clause contains multiple queries")

    open_parens = 0
    # Walk the leaf tokens instead of only the statement's direct children:
    # sqlparse nests tokens inside grouped nodes (e.g. CASE ... END), so a stray
    # "(" or ")" captured by such a group would otherwise never be counted.
    # String literals stay single tokens (quotes included), so parentheses
    # inside literals are still ignored.
    for token in statements[0].flatten():
        if token.value == "(":
            open_parens += 1
        elif token.value == ")":
            open_parens -= 1
            if open_parens < 0:
                raise QueryClauseValidationException(
                    "Closing unclosed parenthesis in filter clause"
                )
    if open_parens > 0:
        raise QueryClauseValidationException("Unclosed parenthesis in filter clause")
|
||||||
|
|
|
||||||
|
|
@ -62,12 +62,14 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import (
|
from superset.exceptions import (
|
||||||
CacheLoadError,
|
CacheLoadError,
|
||||||
NullValueException,
|
NullValueException,
|
||||||
|
QueryClauseValidationException,
|
||||||
QueryObjectValidationError,
|
QueryObjectValidationError,
|
||||||
SpatialException,
|
SpatialException,
|
||||||
SupersetSecurityException,
|
SupersetSecurityException,
|
||||||
)
|
)
|
||||||
from superset.extensions import cache_manager, security_manager
|
from superset.extensions import cache_manager, security_manager
|
||||||
from superset.models.helpers import QueryResult
|
from superset.models.helpers import QueryResult
|
||||||
|
from superset.sql_parse import validate_filter_clause
|
||||||
from superset.typing import Column, Metric, QueryObjectDict, VizData, VizPayload
|
from superset.typing import Column, Metric, QueryObjectDict, VizData, VizPayload
|
||||||
from superset.utils import core as utils, csv
|
from superset.utils import core as utils, csv
|
||||||
from superset.utils.cache import set_and_log_cache
|
from superset.utils.cache import set_and_log_cache
|
||||||
|
|
@ -373,6 +375,15 @@ class BaseViz: # pylint: disable=too-many-public-methods
|
||||||
self.from_dttm = from_dttm
|
self.from_dttm = from_dttm
|
||||||
self.to_dttm = to_dttm
|
self.to_dttm = to_dttm
|
||||||
|
|
||||||
|
# validate sql filters
|
||||||
|
for param in ("where", "having"):
|
||||||
|
clause = self.form_data.get(param)
|
||||||
|
if clause:
|
||||||
|
try:
|
||||||
|
validate_filter_clause(clause)
|
||||||
|
except QueryClauseValidationException as ex:
|
||||||
|
raise QueryObjectValidationError(ex.message) from ex
|
||||||
|
|
||||||
# extras are used to query elements specific to a datasource type
|
# extras are used to query elements specific to a datasource type
|
||||||
# for instance the extra where clause that applies only to Tables
|
# for instance the extra where clause that applies only to Tables
|
||||||
extras = {
|
extras = {
|
||||||
|
|
|
||||||
|
|
@ -425,6 +425,28 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
||||||
|
|
||||||
assert rv.status_code == 400
|
assert rv.status_code == 400
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_invalid_where_parameter_closing_unclosed__400(self):
    """A ``where`` extra that closes an unopened parenthesis yields a 400."""
    first_query = self.query_context_payload["queries"][0]
    first_query["filters"] = []
    first_query["extras"]["where"] = "state = 'CA') OR (state = 'NY'"

    rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")

    assert rv.status_code == 400
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_invalid_having_parameter_closing_and_comment__400(self):
    """A ``having`` extra with a stray ``)`` and a trailing comment yields a 400."""
    first_query = self.query_context_payload["queries"][0]
    first_query["filters"] = []
    first_query["extras"]["having"] = (
        "COUNT(1) = 0) UNION ALL SELECT 'abc', 1--comment"
    )

    rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")

    assert rv.status_code == 400
|
||||||
|
|
||||||
def test_with_invalid_datasource__400(self):
|
def test_with_invalid_datasource__400(self):
|
||||||
self.query_context_payload["datasource"] = "abc"
|
self.query_context_payload["datasource"] = "abc"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,16 @@
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Set
|
from typing import Set
|
||||||
|
|
||||||
|
import pytest
|
||||||
import sqlparse
|
import sqlparse
|
||||||
|
|
||||||
from superset.sql_parse import ParsedQuery, strip_comments_from_sql, Table
|
from superset.exceptions import QueryClauseValidationException
|
||||||
|
from superset.sql_parse import (
|
||||||
|
ParsedQuery,
|
||||||
|
strip_comments_from_sql,
|
||||||
|
Table,
|
||||||
|
validate_filter_clause,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_tables(query: str) -> Set[Table]:
|
def extract_tables(query: str) -> Set[Table]:
|
||||||
|
|
@ -1144,3 +1151,51 @@ def test_strip_comments_from_sql() -> None:
|
||||||
strip_comments_from_sql("SELECT '--abc' as abc, col2 FROM table1\n")
|
strip_comments_from_sql("SELECT '--abc' as abc, col2 FROM table1\n")
|
||||||
== "SELECT '--abc' as abc, col2 FROM table1"
|
== "SELECT '--abc' as abc, col2 FROM table1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_filter_clause_valid():
    """Well-formed clauses are accepted, i.e. the validator returns ``None``."""
    accepted_clauses = (
        # regular clauses
        "col = 1",
        "1=\t\n1",
        "(col = 1)",
        "(col1 = 1) AND (col2 = 2)",
        # Valid literal values that appear to be invalid
        "col = 'col1 = 1) AND (col2 = 2'",
        "col = 'select 1; select 2'",
        "col = 'abc -- comment'",
    )
    for accepted in accepted_clauses:
        assert validate_filter_clause(accepted) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_filter_clause_closing_unclosed():
    """A ``)`` that appears before any ``(`` is rejected."""
    bad_clause = "col1 = 1) AND (col2 = 2)"
    with pytest.raises(QueryClauseValidationException):
        validate_filter_clause(bad_clause)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_filter_clause_unclosed():
    """A ``(`` that is never closed is rejected."""
    bad_clause = "(col1 = 1) AND (col2 = 2"
    with pytest.raises(QueryClauseValidationException):
        validate_filter_clause(bad_clause)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_filter_clause_closing_and_unclosed():
    """Equal numbers of ``(`` and ``)`` in the wrong order are still rejected."""
    bad_clause = "col1 = 1) AND (col2 = 2"
    with pytest.raises(QueryClauseValidationException):
        validate_filter_clause(bad_clause)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_filter_clause_closing_and_unclosed_nested():
    """Misordered parentheses around otherwise balanced groups are rejected."""
    bad_clause = "(col1 = 1)) AND ((col2 = 2)"
    with pytest.raises(QueryClauseValidationException):
        validate_filter_clause(bad_clause)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_filter_clause_multiple():
    """A clause holding more than one statement is rejected."""
    bad_clause = "TRUE; SELECT 1"
    with pytest.raises(QueryClauseValidationException):
        validate_filter_clause(bad_clause)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_filter_clause_comment():
    """A clause with a trailing ``--`` comment is rejected."""
    bad_clause = "1 = 1 -- comment"
    with pytest.raises(QueryClauseValidationException):
        validate_filter_clause(bad_clause)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_filter_clause_subquery_comment():
    """A comment placed inside a parenthesised group is rejected as well."""
    bad_clause = "(1 = 1 -- comment\n)"
    with pytest.raises(QueryClauseValidationException):
        validate_filter_clause(bad_clause)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue