chore(sql): clean up invalid filter clause exception types (#17702)

* chore(sql): clean up invalid filter clause exception types

* fix lint

* rename exception
This commit is contained in:
Ville Brofeldt 2021-12-09 17:49:32 +02:00 committed by GitHub
parent 1af99eabc2
commit 3a42071e0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 130 additions and 2 deletions

View File

@ -25,7 +25,11 @@ from flask_babel import gettext as _
from pandas import DataFrame from pandas import DataFrame
from superset.common.chart_data import ChartDataResultType from superset.common.chart_data import ChartDataResultType
from superset.exceptions import QueryObjectValidationError from superset.exceptions import (
QueryClauseValidationException,
QueryObjectValidationError,
)
from superset.sql_parse import validate_filter_clause
from superset.typing import Column, Metric, OrderBy from superset.typing import Column, Metric, OrderBy
from superset.utils import pandas_postprocessing from superset.utils import pandas_postprocessing
from superset.utils.core import ( from superset.utils.core import (
@ -267,6 +271,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
try: try:
self._validate_there_are_no_missing_series() self._validate_there_are_no_missing_series()
self._validate_no_have_duplicate_labels() self._validate_no_have_duplicate_labels()
self._validate_filters()
return None return None
except QueryObjectValidationError as ex: except QueryObjectValidationError as ex:
if raise_exceptions: if raise_exceptions:
@ -285,6 +290,15 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
) )
) )
def _validate_filters(self) -> None:
for param in ("where", "having"):
clause = self.extras.get(param)
if clause:
try:
validate_filter_clause(clause)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
def _validate_there_are_no_missing_series(self) -> None: def _validate_there_are_no_missing_series(self) -> None:
missing_series = [col for col in self.series_columns if col not in self.columns] missing_series = [col for col in self.series_columns if col not in self.columns]
if missing_series: if missing_series:

View File

@ -194,6 +194,10 @@ class CacheLoadError(SupersetException):
status = 404 status = 404
class QueryClauseValidationException(SupersetException):
status = 400
class DashboardImportException(SupersetException): class DashboardImportException(SupersetException):
pass pass

View File

@ -32,6 +32,8 @@ from sqlparse.sql import (
from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt from sqlparse.utils import imt
from superset.exceptions import QueryClauseValidationException
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"} RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
ON_KEYWORD = "ON" ON_KEYWORD = "ON"
PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"} PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
@ -378,3 +380,23 @@ class ParsedQuery:
for i in statement.tokens: for i in statement.tokens:
str_res += str(i.value) str_res += str(i.value)
return str_res return str_res
def validate_filter_clause(clause: str) -> None:
if sqlparse.format(clause, strip_comments=True) != sqlparse.format(clause):
raise QueryClauseValidationException("Filter clause contains comment")
statements = sqlparse.parse(clause)
if len(statements) != 1:
raise QueryClauseValidationException("Filter clause contains multiple queries")
open_parens = 0
for token in statements[0]:
if token.value in (")", "("):
open_parens += 1 if token.value == "(" else -1
if open_parens < 0:
raise QueryClauseValidationException(
"Closing unclosed parenthesis in filter clause"
)
if open_parens > 0:
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")

View File

@ -62,12 +62,14 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import ( from superset.exceptions import (
CacheLoadError, CacheLoadError,
NullValueException, NullValueException,
QueryClauseValidationException,
QueryObjectValidationError, QueryObjectValidationError,
SpatialException, SpatialException,
SupersetSecurityException, SupersetSecurityException,
) )
from superset.extensions import cache_manager, security_manager from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult from superset.models.helpers import QueryResult
from superset.sql_parse import validate_filter_clause
from superset.typing import Column, Metric, QueryObjectDict, VizData, VizPayload from superset.typing import Column, Metric, QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils, csv from superset.utils import core as utils, csv
from superset.utils.cache import set_and_log_cache from superset.utils.cache import set_and_log_cache
@ -373,6 +375,15 @@ class BaseViz: # pylint: disable=too-many-public-methods
self.from_dttm = from_dttm self.from_dttm = from_dttm
self.to_dttm = to_dttm self.to_dttm = to_dttm
# validate sql filters
for param in ("where", "having"):
clause = self.form_data.get(param)
if clause:
try:
validate_filter_clause(clause)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
# extras are used to query elements specific to a datasource type # extras are used to query elements specific to a datasource type
# for instance the extra where clause that applies only to Tables # for instance the extra where clause that applies only to Tables
extras = { extras = {

View File

@ -425,6 +425,28 @@ class TestPostChartDataApi(BaseTestChartDataApi):
assert rv.status_code == 400 assert rv.status_code == 400
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_invalid_where_parameter_closing_unclosed__400(self):
self.query_context_payload["queries"][0]["filters"] = []
self.query_context_payload["queries"][0]["extras"][
"where"
] = "state = 'CA') OR (state = 'NY'"
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
assert rv.status_code == 400
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_invalid_having_parameter_closing_and_comment__400(self):
self.query_context_payload["queries"][0]["filters"] = []
self.query_context_payload["queries"][0]["extras"][
"having"
] = "COUNT(1) = 0) UNION ALL SELECT 'abc', 1--comment"
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
assert rv.status_code == 400
def test_with_invalid_datasource__400(self): def test_with_invalid_datasource__400(self):
self.query_context_payload["datasource"] = "abc" self.query_context_payload["datasource"] = "abc"

View File

@ -20,9 +20,16 @@
import unittest import unittest
from typing import Set from typing import Set
import pytest
import sqlparse import sqlparse
from superset.sql_parse import ParsedQuery, strip_comments_from_sql, Table from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
ParsedQuery,
strip_comments_from_sql,
Table,
validate_filter_clause,
)
def extract_tables(query: str) -> Set[Table]: def extract_tables(query: str) -> Set[Table]:
@ -1144,3 +1151,51 @@ def test_strip_comments_from_sql() -> None:
strip_comments_from_sql("SELECT '--abc' as abc, col2 FROM table1\n") strip_comments_from_sql("SELECT '--abc' as abc, col2 FROM table1\n")
== "SELECT '--abc' as abc, col2 FROM table1" == "SELECT '--abc' as abc, col2 FROM table1"
) )
def test_validate_filter_clause_valid():
# regular clauses
assert validate_filter_clause("col = 1") is None
assert validate_filter_clause("1=\t\n1") is None
assert validate_filter_clause("(col = 1)") is None
assert validate_filter_clause("(col1 = 1) AND (col2 = 2)") is None
# Valid literal values that appear to be invalid
assert validate_filter_clause("col = 'col1 = 1) AND (col2 = 2'") is None
assert validate_filter_clause("col = 'select 1; select 2'") is None
assert validate_filter_clause("col = 'abc -- comment'") is None
def test_validate_filter_clause_closing_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("col1 = 1) AND (col2 = 2)")
def test_validate_filter_clause_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(col1 = 1) AND (col2 = 2")
def test_validate_filter_clause_closing_and_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("col1 = 1) AND (col2 = 2")
def test_validate_filter_clause_closing_and_unclosed_nested():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(col1 = 1)) AND ((col2 = 2)")
def test_validate_filter_clause_multiple():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("TRUE; SELECT 1")
def test_validate_filter_clause_comment():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("1 = 1 -- comment")
def test_validate_filter_clause_subquery_comment():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(1 = 1 -- comment\n)")