fix(chart-data-api): assert referenced columns are present in datasource (#10451)

* fix(chart-data-api): assert requested columns are present in datasource

* add filter tests

* add column_names to AnnotationDatasource

* add assertion for simple metrics

* lint
Ville Brofeldt, 2020-08-14 20:58:24 +03:00, committed by GitHub
commit acb00f509c (parent 6c09b938fe)
7 changed files with 196 additions and 24 deletions


@@ -797,7 +797,7 @@ class ChartDataQueryObjectSchema(Schema):
deprecated=True,
)
having_filters = fields.List(
fields.Dict(),
fields.Nested(ChartDataFilterSchema),
description="HAVING filters to be added to legacy Druid datasource queries. "
"This field is deprecated and should be passed to `extras` "
"as `having_druid`.",


@@ -22,6 +22,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union
import numpy as np
import pandas as pd
from flask_babel import gettext as _
from superset import app, cache, db, security_manager
from superset.common.query_object import QueryObject
@@ -235,6 +236,21 @@ class QueryContext:
if query_obj and not is_loaded:
try:
invalid_columns = [
col
for col in query_obj.columns
+ query_obj.groupby
+ [flt["col"] for flt in query_obj.filter]
+ utils.get_column_names_from_metrics(query_obj.metrics)
if col not in self.datasource.column_names
]
if invalid_columns:
raise QueryObjectValidationError(
_(
"Columns missing in datasource: %(invalid_columns)s",
invalid_columns=invalid_columns,
)
)
query_result = self.get_query_result(query_obj)
status = query_result["status"]
query = query_result["query"]
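
The guard above runs before any SQL is generated: every column referenced in `columns`, `groupby`, filter clauses, or simple adhoc metrics must exist on the datasource, otherwise a `QueryObjectValidationError` is raised, which (as the new tests below assert) surfaces as an `error` entry in the chart data payload. A minimal self-contained sketch of the idea, using plain lists instead of Superset's QueryObject and datasource classes (`validate_columns` is a hypothetical name):

    from typing import Any, Dict, List

    class QueryObjectValidationError(Exception):
        """Stand-in for superset.exceptions.QueryObjectValidationError."""

    def validate_columns(
        referenced: List[str], datasource_columns: List[str]
    ) -> Dict[str, Any]:
        # Mirror of the guard above: collect referenced columns the datasource
        # does not know about and fail before any SQL is built.
        try:
            invalid = [col for col in referenced if col not in datasource_columns]
            if invalid:
                raise QueryObjectValidationError(
                    f"Columns missing in datasource: {invalid}"
                )
            return {"status": "success", "error": None}
        except QueryObjectValidationError as ex:
            return {"status": "failed", "error": str(ex)}

    print(validate_columns(["currentDatabase()"], ["name", "num"]))
    # {'status': 'failed', 'error': "Columns missing in datasource: ['currentDatabase()']"}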


@@ -90,6 +90,19 @@ class AnnotationDatasource(BaseDatasource):
cache_timeout = 0
changed_on = None
type = "annotation"
column_names = [
"created_on",
"changed_on",
"id",
"start_dttm",
"end_dttm",
"layer_id",
"short_descr",
"long_descr",
"json_metadata",
"created_by_fk",
"changed_by_fk",
]
def query(self, query_obj: QueryObjectDict) -> QueryResult:
error_message = None
@@ -721,7 +734,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-attributes
expression_type = metric.get("expressionType")
label = utils.get_metric_name(metric)
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
column_name = metric["column"].get("column_name")
table_column = columns_by_name.get(column_name)
if table_column:
@@ -729,7 +742,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-attributes
else:
sqla_column = column(column_name)
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]:
elif expression_type == utils.AdhocMetricExpressionType.SQL:
sqla_metric = literal_column(metric.get("sqlExpression"))
else:
return None
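
Once a simple metric's column has passed validation, the aggregate expression is built by looking up the aggregate name and applying it to the column, as in the hunk above. A rough sketch of that path with plain SQLAlchemy (`sqla_aggregations` here is a hypothetical stand-in for Superset's mapping, not its actual definition):

    from sqlalchemy import column, func

    # Hypothetical subset of an aggregate-name -> SQLAlchemy function mapping.
    sqla_aggregations = {"SUM": func.sum, "AVG": func.avg, "COUNT": func.count}

    metric = {
        "expressionType": "SIMPLE",
        "column": {"column_name": "num"},
        "aggregate": "SUM",
    }
    sqla_metric = sqla_aggregations[metric["aggregate"]](
        column(metric["column"]["column_name"])
    )
    print(sqla_metric)  # sum(num)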


@@ -42,6 +42,7 @@ from types import TracebackType
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
Iterator,
@@ -102,7 +103,6 @@ logging.getLogger("MARKDOWN").setLevel(logging.INFO)
logger = logging.getLogger(__name__)
DTTM_ALIAS = "__timestamp"
ADHOC_METRIC_EXPRESSION_TYPES = {"SIMPLE": "SIMPLE", "SQL": "SQL"}
JS_MAX_INTEGER = 9007199254740991  # Largest int JavaScript can handle 2^53-1
@@ -1038,20 +1038,23 @@ def backend() -> str:
def is_adhoc_metric(metric: Metric) -> bool:
if not isinstance(metric, dict):
return False
metric = cast(Dict[str, Any], metric)
return bool(
isinstance(metric, dict)
and (
(
(
metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]
and metric["column"]
and metric["aggregate"]
metric.get("expressionType") == AdhocMetricExpressionType.SIMPLE
and metric.get("column")
and cast(Dict[str, Any], metric["column"]).get("column_name")
and metric.get("aggregate")
)
or (
metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SQL"]
and metric["sqlExpression"]
metric.get("expressionType") == AdhocMetricExpressionType.SQL
and metric.get("sqlExpression")
)
)
and metric["label"]
and metric.get("label")
)
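
The rewritten predicate no longer assumes the keys are present (`.get` instead of indexing) and, for SIMPLE metrics, additionally requires a non-empty `column_name`. For example, assuming the helper is imported from `superset.utils.core` (metric shapes taken from the tests below):

    from superset.utils.core import is_adhoc_metric

    assert is_adhoc_metric({
        "expressionType": "SIMPLE",
        "column": {"column_name": "my_col"},
        "aggregate": "SUM",
        "label": "My Simple Label",
    })
    assert not is_adhoc_metric({
        "expressionType": "SIMPLE",
        "column": {"column_name": ""},  # truthy dict, but empty column_name
        "aggregate": "SUM",
        "label": "broken",
    })
    assert not is_adhoc_metric("count")  # saved metrics are plain strings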
@@ -1398,6 +1401,37 @@ def get_form_data_token(form_data: Dict[str, Any]) -> str:
return form_data.get("token") or "token_" + uuid.uuid4().hex[:8]
def get_column_name_from_metric(metric: Metric) -> Optional[str]:
"""
Extract the column that a metric references. If the metric isn't
a simple adhoc metric, return `None`.
:param metric: Ad-hoc metric
:return: column name if simple metric, otherwise None
"""
if is_adhoc_metric(metric):
metric = cast(Dict[str, Any], metric)
if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE:
return cast(Dict[str, Any], metric["column"])["column_name"]
return None
def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
"""
Extract the columns that a list of metrics references. Excludes all
SQL metrics.
:param metrics: list of metrics (adhoc or saved)
:return: column names referenced by simple adhoc metrics
"""
columns: List[str] = []
for metric in metrics:
column_name = get_column_name_from_metric(metric)
if column_name:
columns.append(column_name)
return columns
class LenientEnum(Enum):
"""Enums that do not raise ValueError when value is invalid"""
@@ -1523,3 +1557,8 @@ class PostProcessingContributionOrientation(str, Enum):
ROW = "row"
COLUMN = "column"
class AdhocMetricExpressionType(str, Enum):
SIMPLE = "SIMPLE"
SQL = "SQL"


@@ -481,6 +481,24 @@ class BaseViz:
if query_obj and not is_loaded:
try:
invalid_columns = [
col
for col in (query_obj.get("columns") or [])
+ (query_obj.get("groupby") or [])
+ utils.get_column_names_from_metrics(
cast(
List[Union[str, Dict[str, Any]]], query_obj.get("metrics"),
)
)
if col not in self.datasource.column_names
]
if invalid_columns:
raise QueryObjectValidationError(
_(
"Columns missing in datasource: %(invalid_columns)s",
invalid_columns=invalid_columns,
)
)
df = self.get_df(query_obj)
if self.status != utils.QueryStatus.FAILED:
stats_logger.incr("loaded_from_source")


@@ -1202,6 +1202,25 @@ class TestCore(SupersetTestCase):
database.extra = json.dumps(extra)
self.assertEqual(database.explore_database_id, explore_database.id)
def test_get_column_names_from_metric(self):
simple_metric = {
"expressionType": utils.AdhocMetricExpressionType.SIMPLE.value,
"column": {"column_name": "my_col"},
"aggregate": "SUM",
"label": "My Simple Label",
}
assert utils.get_column_name_from_metric(simple_metric) == "my_col"
sql_metric = {
"expressionType": utils.AdhocMetricExpressionType.SQL.value,
"sqlExpression": "SUM(my_label)",
"label": "My SQL Label",
}
assert utils.get_column_name_from_metric(sql_metric) is None
assert utils.get_column_names_from_metrics([simple_metric, sql_metric]) == [
"my_col"
]
if __name__ == "__main__":
unittest.main()


@@ -17,11 +17,12 @@
import tests.test_app
from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.query_context import QueryContext
from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils.core import (
AdhocMetricExpressionType,
ChartDataResultFormat,
ChartDataResultType,
FilterOperator,
TimeRangeEndpoint,
)
from tests.base_tests import SupersetTestCase
@@ -75,7 +76,7 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context(table.name, table.id, table.type)
# construct baseline cache_key
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.cache_key(query_object)
@@ -92,7 +93,7 @@ class TestQueryContext(SupersetTestCase):
db.session.commit()
# create new QueryContext with unchanged attributes and extract new cache_key
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_new = query_context.cache_key(query_object)
@@ -108,20 +109,20 @@
)
# construct baseline cache_key from query_context with post processing operation
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.cache_key(query_object)
# ensure added None post_processing operation doesn't change cache_key
payload["queries"][0]["post_processing"].append(None)
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_with_null = query_context.cache_key(query_object)
self.assertEqual(cache_key_original, cache_key_with_null)
# ensure query without post processing operation is different
payload["queries"][0].pop("post_processing")
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_without_post_processing = query_context.cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
@@ -136,7 +137,7 @@ class TestQueryContext(SupersetTestCase):
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
del payload["queries"][0]["extras"]["time_range_endpoints"]
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
extras = query_object.to_dict()["extras"]
self.assertTrue("time_range_endpoints" in extras)
@@ -155,8 +156,8 @@ class TestQueryContext(SupersetTestCase):
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["granularity_sqla"] = "timecol"
payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
query_context = QueryContext(**payload)
payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
query_context = ChartDataQueryContextSchema().load(payload)
self.assertEqual(len(query_context.queries), 1)
query_object = query_context.queries[0]
self.assertEqual(query_object.granularity, "timecol")
@@ -172,13 +173,79 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context(table.name, table.id, table.type)
payload["result_format"] = ChartDataResultFormat.CSV.value
payload["queries"][0]["row_limit"] = 10
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
data = responses[0]["data"]
self.assertIn("name,sum__num\n", data)
self.assertEqual(len(data.split("\n")), 12)
def test_sql_injection_via_groupby(self):
"""
Ensure that invalid column names in groupby are caught
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["groupby"] = ["currentDatabase()"]
query_context = ChartDataQueryContextSchema().load(payload)
query_payload = query_context.get_payload()
assert query_payload[0].get("error") is not None
def test_sql_injection_via_columns(self):
"""
Ensure that invalid column names in columns are caught
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["groupby"] = []
payload["queries"][0]["metrics"] = []
payload["queries"][0]["columns"] = ["*, 'extra'"]
query_context = ChartDataQueryContextSchema().load(payload)
query_payload = query_context.get_payload()
assert query_payload[0].get("error") is not None
def test_sql_injection_via_filters(self):
"""
Ensure that invalid column names in filters are caught
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["groupby"] = ["name"]
payload["queries"][0]["metrics"] = []
payload["queries"][0]["filters"] = [
{"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
]
query_context = ChartDataQueryContextSchema().load(payload)
query_payload = query_context.get_payload()
assert query_payload[0].get("error") is not None
def test_sql_injection_via_metrics(self):
"""
Ensure that invalid column names in metrics are caught
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["groupby"] = ["name"]
payload["queries"][0]["metrics"] = [
{
"expressionType": AdhocMetricExpressionType.SIMPLE.value,
"column": {"column_name": "invalid_col"},
"aggregate": "SUM",
"label": "My Simple Label",
}
]
query_context = ChartDataQueryContextSchema().load(payload)
query_payload = query_context.get_payload()
assert query_payload[0].get("error") is not None
def test_samples_response_type(self):
"""
Ensure that samples result type works
@@ -189,7 +256,7 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context(table.name, table.id, table.type)
payload["result_type"] = ChartDataResultType.SAMPLES.value
payload["queries"][0]["row_limit"] = 5
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
data = responses[0]["data"]
@@ -206,7 +273,7 @@ class TestQueryContext(SupersetTestCase):
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["result_type"] = ChartDataResultType.QUERY.value
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
response = responses[0]