fix(chart-data-api): assert referenced columns are present in datasource (#10451)
* fix(chart-data-api): assert requested columns are present in datasource
* add filter tests
* add column_names to AnnotationDatasource
* add assertion for simple metrics
* lint
This commit is contained in:
parent
6c09b938fe
commit
acb00f509c
|
|
@ -797,7 +797,7 @@ class ChartDataQueryObjectSchema(Schema):
|
|||
deprecated=True,
|
||||
)
|
||||
having_filters = fields.List(
|
||||
fields.Dict(),
|
||||
fields.Nested(ChartDataFilterSchema),
|
||||
description="HAVING filters to be added to legacy Druid datasource queries. "
|
||||
"This field is deprecated and should be passed to `extras` "
|
||||
"as `having_druid`.",
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union
|
|||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from flask_babel import gettext as _
|
||||
|
||||
from superset import app, cache, db, security_manager
|
||||
from superset.common.query_object import QueryObject
|
||||
|
|
@ -235,6 +236,21 @@ class QueryContext:
|
|||
|
||||
if query_obj and not is_loaded:
|
||||
try:
|
||||
invalid_columns = [
|
||||
col
|
||||
for col in query_obj.columns
|
||||
+ query_obj.groupby
|
||||
+ [flt["col"] for flt in query_obj.filter]
|
||||
+ utils.get_column_names_from_metrics(query_obj.metrics)
|
||||
if col not in self.datasource.column_names
|
||||
]
|
||||
if invalid_columns:
|
||||
raise QueryObjectValidationError(
|
||||
_(
|
||||
"Columns missing in datasource: %(invalid_columns)s",
|
||||
invalid_columns=invalid_columns,
|
||||
)
|
||||
)
|
||||
query_result = self.get_query_result(query_obj)
|
||||
status = query_result["status"]
|
||||
query = query_result["query"]
|
||||
|
|
|
|||
|
|
@ -90,6 +90,19 @@ class AnnotationDatasource(BaseDatasource):
|
|||
cache_timeout = 0
|
||||
changed_on = None
|
||||
type = "annotation"
|
||||
column_names = [
|
||||
"created_on",
|
||||
"changed_on",
|
||||
"id",
|
||||
"start_dttm",
|
||||
"end_dttm",
|
||||
"layer_id",
|
||||
"short_descr",
|
||||
"long_descr",
|
||||
"json_metadata",
|
||||
"created_by_fk",
|
||||
"changed_by_fk",
|
||||
]
|
||||
|
||||
def query(self, query_obj: QueryObjectDict) -> QueryResult:
|
||||
error_message = None
|
||||
|
|
@ -721,7 +734,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
expression_type = metric.get("expressionType")
|
||||
label = utils.get_metric_name(metric)
|
||||
|
||||
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
|
||||
if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
|
||||
column_name = metric["column"].get("column_name")
|
||||
table_column = columns_by_name.get(column_name)
|
||||
if table_column:
|
||||
|
|
@ -729,7 +742,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
else:
|
||||
sqla_column = column(column_name)
|
||||
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
|
||||
elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]:
|
||||
elif expression_type == utils.AdhocMetricExpressionType.SQL:
|
||||
sqla_metric = literal_column(metric.get("sqlExpression"))
|
||||
else:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from types import TracebackType
|
|||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
|
|
@ -102,7 +103,6 @@ logging.getLogger("MARKDOWN").setLevel(logging.INFO)
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
DTTM_ALIAS = "__timestamp"
|
||||
ADHOC_METRIC_EXPRESSION_TYPES = {"SIMPLE": "SIMPLE", "SQL": "SQL"}
|
||||
|
||||
JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1
|
||||
|
||||
|
|
@ -1038,20 +1038,23 @@ def backend() -> str:
|
|||
|
||||
|
||||
def is_adhoc_metric(metric: Metric) -> bool:
    """
    Check whether a metric definition is a valid ad-hoc metric.

    A valid ad-hoc metric is a dict with a label and either a SIMPLE
    expression (a column with a `column_name` plus an aggregate) or a SQL
    expression (`sqlExpression`).

    :param metric: metric definition to inspect (may also be a saved-metric
        name, i.e. a plain string, in which case this returns False)
    :return: True if `metric` is a well-formed ad-hoc metric
    """
    # Saved metrics are referenced by name (str); only dicts can be ad-hoc.
    if not isinstance(metric, dict):
        return False
    metric = cast(Dict[str, Any], metric)
    return bool(
        (
            (
                # SIMPLE: needs a column with a name and an aggregate.
                metric.get("expressionType") == AdhocMetricExpressionType.SIMPLE
                and metric.get("column")
                and cast(Dict[str, Any], metric["column"]).get("column_name")
                and metric.get("aggregate")
            )
            or (
                # SQL: needs a free-form SQL expression.
                metric.get("expressionType") == AdhocMetricExpressionType.SQL
                and metric.get("sqlExpression")
            )
        )
        # Both flavors require a label.
        and metric.get("label")
    )
|
||||
|
||||
|
||||
|
|
@ -1398,6 +1401,37 @@ def get_form_data_token(form_data: Dict[str, Any]) -> str:
|
|||
return form_data.get("token") or "token_" + uuid.uuid4().hex[:8]
|
||||
|
||||
|
||||
def get_column_name_from_metric(metric: Metric) -> Optional[str]:
    """
    Extract the column that a metric is referencing. If the metric isn't
    a simple ad-hoc metric, always returns `None`.

    :param metric: metric definition (ad-hoc metric dict or saved-metric name)
    :return: column name if simple ad-hoc metric, otherwise None
    """
    # Saved metrics (strings) and malformed dicts are not ad-hoc metrics.
    if not is_adhoc_metric(metric):
        return None
    metric = cast(Dict[str, Any], metric)
    # Only SIMPLE metrics reference a datasource column; SQL metrics embed
    # arbitrary expressions and are excluded here.
    if metric["expressionType"] != AdhocMetricExpressionType.SIMPLE:
        return None
    return cast(Dict[str, Any], metric["column"])["column_name"]
|
||||
|
||||
|
||||
def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
    """
    Extract the columns that a list of metrics are referencing. Excludes all
    SQL metrics.

    :param metrics: metric definitions (ad-hoc metric dicts or saved-metric
        names)
    :return: column names referenced by the simple ad-hoc metrics
    """
    columns: List[str] = []
    for metric in metrics:
        column_name = get_column_name_from_metric(metric)
        if column_name:
            columns.append(column_name)
    return columns
|
||||
|
||||
|
||||
class LenientEnum(Enum):
|
||||
"""Enums that do not raise ValueError when value is invalid"""
|
||||
|
||||
|
|
@ -1523,3 +1557,8 @@ class PostProcessingContributionOrientation(str, Enum):
|
|||
|
||||
ROW = "row"
|
||||
COLUMN = "column"
|
||||
|
||||
|
||||
class AdhocMetricExpressionType(str, Enum):
    """Expression types for ad-hoc metrics: SIMPLE pairs a column with an
    aggregate; SQL is a free-form SQL expression."""

    SIMPLE = "SIMPLE"
    SQL = "SQL"
|
||||
|
|
|
|||
|
|
@ -481,6 +481,24 @@ class BaseViz:
|
|||
|
||||
if query_obj and not is_loaded:
|
||||
try:
|
||||
invalid_columns = [
|
||||
col
|
||||
for col in (query_obj.get("columns") or [])
|
||||
+ (query_obj.get("groupby") or [])
|
||||
+ utils.get_column_names_from_metrics(
|
||||
cast(
|
||||
List[Union[str, Dict[str, Any]]], query_obj.get("metrics"),
|
||||
)
|
||||
)
|
||||
if col not in self.datasource.column_names
|
||||
]
|
||||
if invalid_columns:
|
||||
raise QueryObjectValidationError(
|
||||
_(
|
||||
"Columns missing in datasource: %(invalid_columns)s",
|
||||
invalid_columns=invalid_columns,
|
||||
)
|
||||
)
|
||||
df = self.get_df(query_obj)
|
||||
if self.status != utils.QueryStatus.FAILED:
|
||||
stats_logger.incr("loaded_from_source")
|
||||
|
|
|
|||
|
|
@ -1202,6 +1202,25 @@ class TestCore(SupersetTestCase):
|
|||
database.extra = json.dumps(extra)
|
||||
self.assertEqual(database.explore_database_id, explore_database.id)
|
||||
|
||||
def test_get_column_names_from_metric(self):
    """Column extraction should return the column of a simple ad-hoc
    metric and skip SQL metrics entirely."""
    adhoc_simple = {
        "expressionType": utils.AdhocMetricExpressionType.SIMPLE.value,
        "column": {"column_name": "my_col"},
        "aggregate": "SUM",
        "label": "My Simple Label",
    }
    adhoc_sql = {
        "expressionType": utils.AdhocMetricExpressionType.SQL.value,
        "sqlExpression": "SUM(my_label)",
        "label": "My SQL Label",
    }

    assert utils.get_column_name_from_metric(adhoc_simple) == "my_col"
    assert utils.get_column_name_from_metric(adhoc_sql) is None
    assert utils.get_column_names_from_metrics([adhoc_simple, adhoc_sql]) == [
        "my_col"
    ]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -17,11 +17,12 @@
|
|||
import tests.test_app
|
||||
from superset import db
|
||||
from superset.charts.schemas import ChartDataQueryContextSchema
|
||||
from superset.common.query_context import QueryContext
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.utils.core import (
|
||||
AdhocMetricExpressionType,
|
||||
ChartDataResultFormat,
|
||||
ChartDataResultType,
|
||||
FilterOperator,
|
||||
TimeRangeEndpoint,
|
||||
)
|
||||
from tests.base_tests import SupersetTestCase
|
||||
|
|
@ -75,7 +76,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
payload = get_query_context(table.name, table.id, table.type)
|
||||
|
||||
# construct baseline cache_key
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key_original = query_context.cache_key(query_object)
|
||||
|
||||
|
|
@ -92,7 +93,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
db.session.commit()
|
||||
|
||||
# create new QueryContext with unchanged attributes and extract new cache_key
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key_new = query_context.cache_key(query_object)
|
||||
|
||||
|
|
@ -108,20 +109,20 @@ class TestQueryContext(SupersetTestCase):
|
|||
)
|
||||
|
||||
# construct baseline cache_key from query_context with post processing operation
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key_original = query_context.cache_key(query_object)
|
||||
|
||||
# ensure added None post_processing operation doesn't change cache_key
|
||||
payload["queries"][0]["post_processing"].append(None)
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key_with_null = query_context.cache_key(query_object)
|
||||
self.assertEqual(cache_key_original, cache_key_with_null)
|
||||
|
||||
# ensure query without post processing operation is different
|
||||
payload["queries"][0].pop("post_processing")
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key_without_post_processing = query_context.cache_key(query_object)
|
||||
self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
|
||||
|
|
@ -136,7 +137,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
table = self.get_table_by_name(table_name)
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
del payload["queries"][0]["extras"]["time_range_endpoints"]
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
extras = query_object.to_dict()["extras"]
|
||||
self.assertTrue("time_range_endpoints" in extras)
|
||||
|
|
@ -155,8 +156,8 @@ class TestQueryContext(SupersetTestCase):
|
|||
table = self.get_table_by_name(table_name)
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["queries"][0]["granularity_sqla"] = "timecol"
|
||||
payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
|
||||
query_context = QueryContext(**payload)
|
||||
payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
self.assertEqual(len(query_context.queries), 1)
|
||||
query_object = query_context.queries[0]
|
||||
self.assertEqual(query_object.granularity, "timecol")
|
||||
|
|
@ -172,13 +173,79 @@ class TestQueryContext(SupersetTestCase):
|
|||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["result_format"] = ChartDataResultFormat.CSV.value
|
||||
payload["queries"][0]["row_limit"] = 10
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
self.assertEqual(len(responses), 1)
|
||||
data = responses[0]["data"]
|
||||
self.assertIn("name,sum__num\n", data)
|
||||
self.assertEqual(len(data.split("\n")), 12)
|
||||
|
||||
def test_sql_injection_via_groupby(self):
    """
    Ensure that invalid column names in groupby are caught
    """
    self.login(username="admin")
    table_name = "birth_names"
    table = self.get_table_by_name(table_name)
    payload = get_query_context(table.name, table.id, table.type)
    # "currentDatabase()" is not a column of birth_names — the datasource
    # column check should reject it before any SQL is executed.
    payload["queries"][0]["groupby"] = ["currentDatabase()"]
    query_context = ChartDataQueryContextSchema().load(payload)
    query_payload = query_context.get_payload()
    assert query_payload[0].get("error") is not None
|
||||
|
||||
def test_sql_injection_via_columns(self):
    """
    Ensure that invalid column names in columns are caught
    """
    self.login(username="admin")
    table_name = "birth_names"
    table = self.get_table_by_name(table_name)
    payload = get_query_context(table.name, table.id, table.type)
    # Clear groupby/metrics so `columns` is the only column source, then
    # inject a value that is not a real datasource column.
    payload["queries"][0]["groupby"] = []
    payload["queries"][0]["metrics"] = []
    payload["queries"][0]["columns"] = ["*, 'extra'"]
    query_context = ChartDataQueryContextSchema().load(payload)
    query_payload = query_context.get_payload()
    assert query_payload[0].get("error") is not None
|
||||
|
||||
def test_sql_injection_via_filters(self):
    """
    Ensure that invalid column names in filters are caught
    """
    self.login(username="admin")
    table_name = "birth_names"
    table = self.get_table_by_name(table_name)
    payload = get_query_context(table.name, table.id, table.type)
    payload["queries"][0]["groupby"] = ["name"]
    payload["queries"][0]["metrics"] = []
    # "*" is not a column of birth_names — the filter column must also be
    # validated against the datasource.
    payload["queries"][0]["filters"] = [
        {"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
    ]
    query_context = ChartDataQueryContextSchema().load(payload)
    query_payload = query_context.get_payload()
    assert query_payload[0].get("error") is not None
|
||||
|
||||
def test_sql_injection_via_metrics(self):
    """
    Ensure that invalid column names in metrics are caught
    """
    self.login(username="admin")
    table_name = "birth_names"
    table = self.get_table_by_name(table_name)
    payload = get_query_context(table.name, table.id, table.type)
    payload["queries"][0]["groupby"] = ["name"]
    # A simple ad-hoc metric referencing a column that does not exist in
    # the datasource must be rejected.
    payload["queries"][0]["metrics"] = [
        {
            "expressionType": AdhocMetricExpressionType.SIMPLE.value,
            "column": {"column_name": "invalid_col"},
            "aggregate": "SUM",
            "label": "My Simple Label",
        }
    ]
    query_context = ChartDataQueryContextSchema().load(payload)
    query_payload = query_context.get_payload()
    assert query_payload[0].get("error") is not None
|
||||
|
||||
def test_samples_response_type(self):
|
||||
"""
|
||||
Ensure that samples result type works
|
||||
|
|
@ -189,7 +256,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["result_type"] = ChartDataResultType.SAMPLES.value
|
||||
payload["queries"][0]["row_limit"] = 5
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
self.assertEqual(len(responses), 1)
|
||||
data = responses[0]["data"]
|
||||
|
|
@ -206,7 +273,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
table = self.get_table_by_name(table_name)
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["result_type"] = ChartDataResultType.QUERY.value
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
self.assertEqual(len(responses), 1)
|
||||
response = responses[0]
|
||||
|
|
|
|||
Loading…
Reference in New Issue