From acb00f509c193ea90aecc7486eee7c6e9fe1a8b3 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Fri, 14 Aug 2020 20:58:24 +0300 Subject: [PATCH] fix(chart-data-api): assert referenced columns are present in datasource (#10451) * fix(chart-data-api): assert requested columns are present in datasource * add filter tests * add column_names to AnnotationDatasource * add assertion for simple metrics * lint --- superset/charts/schemas.py | 2 +- superset/common/query_context.py | 16 ++++++ superset/connectors/sqla/models.py | 17 +++++- superset/utils/core.py | 57 ++++++++++++++++--- superset/viz.py | 18 ++++++ tests/core_tests.py | 19 +++++++ tests/query_context_tests.py | 91 ++++++++++++++++++++++++++---- 7 files changed, 196 insertions(+), 24 deletions(-) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 34fa4d8c5..1fba09c9a 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -797,7 +797,7 @@ class ChartDataQueryObjectSchema(Schema): deprecated=True, ) having_filters = fields.List( - fields.Dict(), + fields.Nested(ChartDataFilterSchema), description="HAVING filters to be added to legacy Druid datasource queries. 
" "This field is deprecated and should be passed to `extras` " "as `having_druid`.", diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 0d33f9c4a..d2cecaecd 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -22,6 +22,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union import numpy as np import pandas as pd +from flask_babel import gettext as _ from superset import app, cache, db, security_manager from superset.common.query_object import QueryObject @@ -235,6 +236,21 @@ class QueryContext: if query_obj and not is_loaded: try: + invalid_columns = [ + col + for col in query_obj.columns + + query_obj.groupby + + [flt["col"] for flt in query_obj.filter] + + utils.get_column_names_from_metrics(query_obj.metrics) + if col not in self.datasource.column_names + ] + if invalid_columns: + raise QueryObjectValidationError( + _( + "Columns missing in datasource: %(invalid_columns)s", + invalid_columns=invalid_columns, + ) + ) query_result = self.get_query_result(query_obj) status = query_result["status"] query = query_result["query"] diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index cfc807d1f..97336d4b1 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -90,6 +90,19 @@ class AnnotationDatasource(BaseDatasource): cache_timeout = 0 changed_on = None type = "annotation" + column_names = [ + "created_on", + "changed_on", + "id", + "start_dttm", + "end_dttm", + "layer_id", + "short_descr", + "long_descr", + "json_metadata", + "created_by_fk", + "changed_by_fk", + ] def query(self, query_obj: QueryObjectDict) -> QueryResult: error_message = None @@ -721,7 +734,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at expression_type = metric.get("expressionType") label = utils.get_metric_name(metric) - if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]: + if 
expression_type == utils.AdhocMetricExpressionType.SIMPLE: column_name = metric["column"].get("column_name") table_column = columns_by_name.get(column_name) if table_column: @@ -729,7 +742,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at else: sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) - elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]: + elif expression_type == utils.AdhocMetricExpressionType.SQL: sqla_metric = literal_column(metric.get("sqlExpression")) else: return None diff --git a/superset/utils/core.py b/superset/utils/core.py index 3f998ca15..aa3a10a7b 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -42,6 +42,7 @@ from types import TracebackType from typing import ( Any, Callable, + cast, Dict, Iterable, Iterator, @@ -102,7 +103,6 @@ logging.getLogger("MARKDOWN").setLevel(logging.INFO) logger = logging.getLogger(__name__) DTTM_ALIAS = "__timestamp" -ADHOC_METRIC_EXPRESSION_TYPES = {"SIMPLE": "SIMPLE", "SQL": "SQL"} JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1 @@ -1038,20 +1038,23 @@ def backend() -> str: def is_adhoc_metric(metric: Metric) -> bool: + if not isinstance(metric, dict): + return False + metric = cast(Dict[str, Any], metric) return bool( - isinstance(metric, dict) - and ( + ( ( - metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"] - and metric["column"] - and metric["aggregate"] + metric.get("expressionType") == AdhocMetricExpressionType.SIMPLE + and metric.get("column") + and cast(Dict[str, Any], metric["column"]).get("column_name") + and metric.get("aggregate") ) or ( - metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SQL"] - and metric["sqlExpression"] + metric.get("expressionType") == AdhocMetricExpressionType.SQL + and metric.get("sqlExpression") ) ) - and metric["label"] + and metric.get("label") ) @@ -1398,6 +1401,37 @@ def 
get_form_data_token(form_data: Dict[str, Any]) -> str:
     return form_data.get("token") or "token_" + uuid.uuid4().hex[:8]
 
 
+def get_column_name_from_metric(metric: Metric) -> Optional[str]:
+    """
+    Extract the column that a metric is referencing. If the metric isn't
+    a simple metric, always returns `None`.
+
+    :param metric: Ad-hoc metric
+    :return: column name if simple metric, otherwise None
+    """
+    if is_adhoc_metric(metric):
+        metric = cast(Dict[str, Any], metric)
+        if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE:
+            return cast(Dict[str, Any], metric["column"])["column_name"]
+    return None
+
+
+def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
+    """
+    Extract the columns that a list of metrics are referencing. Excludes all
+    SQL metrics.
+
+    :param metrics: list of metrics
+    :return: list of column names referenced by simple metrics
+    """
+    columns: List[str] = []
+    for metric in metrics:
+        column_name = get_column_name_from_metric(metric)
+        if column_name:
+            columns.append(column_name)
+    return columns
+
+
 class LenientEnum(Enum):
     """Enums that do not raise ValueError when value is invalid"""
 
@@ -1523,3 +1557,8 @@ class PostProcessingContributionOrientation(str, Enum):
 
     ROW = "row"
     COLUMN = "column"
+
+
+class AdhocMetricExpressionType(str, Enum):
+    SIMPLE = "SIMPLE"
+    SQL = "SQL"
diff --git a/superset/viz.py b/superset/viz.py
index 34054cb40..14eedf0ba 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -481,6 +481,24 @@ class BaseViz:
 
         if query_obj and not is_loaded:
             try:
+                invalid_columns = [
+                    col
+                    for col in (query_obj.get("columns") or [])
+                    + (query_obj.get("groupby") or [])
+                    + utils.get_column_names_from_metrics(
+                        cast(
+                            List[Union[str, Dict[str, Any]]], query_obj.get("metrics"),
+                        )
+                    )
+                    if col not in self.datasource.column_names
+                ]
+                if invalid_columns:
+                    raise QueryObjectValidationError(
+                        _(
+                            "Columns missing in datasource: %(invalid_columns)s",
+                            invalid_columns=invalid_columns,
+                        )
+                    )
                 df = 
self.get_df(query_obj) if self.status != utils.QueryStatus.FAILED: stats_logger.incr("loaded_from_source") diff --git a/tests/core_tests.py b/tests/core_tests.py index 4f2d1bfbe..d625860f2 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -1202,6 +1202,25 @@ class TestCore(SupersetTestCase): database.extra = json.dumps(extra) self.assertEqual(database.explore_database_id, explore_database.id) + def test_get_column_names_from_metric(self): + simple_metric = { + "expressionType": utils.AdhocMetricExpressionType.SIMPLE.value, + "column": {"column_name": "my_col"}, + "aggregate": "SUM", + "label": "My Simple Label", + } + assert utils.get_column_name_from_metric(simple_metric) == "my_col" + + sql_metric = { + "expressionType": utils.AdhocMetricExpressionType.SQL.value, + "sqlExpression": "SUM(my_label)", + "label": "My SQL Label", + } + assert utils.get_column_name_from_metric(sql_metric) is None + assert utils.get_column_names_from_metrics([simple_metric, sql_metric]) == [ + "my_col" + ] + if __name__ == "__main__": unittest.main() diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py index f816bcd59..0b0230f14 100644 --- a/tests/query_context_tests.py +++ b/tests/query_context_tests.py @@ -17,11 +17,12 @@ import tests.test_app from superset import db from superset.charts.schemas import ChartDataQueryContextSchema -from superset.common.query_context import QueryContext from superset.connectors.connector_registry import ConnectorRegistry from superset.utils.core import ( + AdhocMetricExpressionType, ChartDataResultFormat, ChartDataResultType, + FilterOperator, TimeRangeEndpoint, ) from tests.base_tests import SupersetTestCase @@ -75,7 +76,7 @@ class TestQueryContext(SupersetTestCase): payload = get_query_context(table.name, table.id, table.type) # construct baseline cache_key - query_context = QueryContext(**payload) + query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] cache_key_original = 
query_context.cache_key(query_object) @@ -92,7 +93,7 @@ class TestQueryContext(SupersetTestCase): db.session.commit() # create new QueryContext with unchanged attributes and extract new cache_key - query_context = QueryContext(**payload) + query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] cache_key_new = query_context.cache_key(query_object) @@ -108,20 +109,20 @@ class TestQueryContext(SupersetTestCase): ) # construct baseline cache_key from query_context with post processing operation - query_context = QueryContext(**payload) + query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] cache_key_original = query_context.cache_key(query_object) # ensure added None post_processing operation doesn't change cache_key payload["queries"][0]["post_processing"].append(None) - query_context = QueryContext(**payload) + query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] cache_key_with_null = query_context.cache_key(query_object) self.assertEqual(cache_key_original, cache_key_with_null) # ensure query without post processing operation is different payload["queries"][0].pop("post_processing") - query_context = QueryContext(**payload) + query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] cache_key_without_post_processing = query_context.cache_key(query_object) self.assertNotEqual(cache_key_original, cache_key_without_post_processing) @@ -136,7 +137,7 @@ class TestQueryContext(SupersetTestCase): table = self.get_table_by_name(table_name) payload = get_query_context(table.name, table.id, table.type) del payload["queries"][0]["extras"]["time_range_endpoints"] - query_context = QueryContext(**payload) + query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] extras = query_object.to_dict()["extras"] self.assertTrue("time_range_endpoints" in extras) @@ 
-155,8 +156,8 @@ class TestQueryContext(SupersetTestCase): table = self.get_table_by_name(table_name) payload = get_query_context(table.name, table.id, table.type) payload["queries"][0]["granularity_sqla"] = "timecol" - payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"} - query_context = QueryContext(**payload) + payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}] + query_context = ChartDataQueryContextSchema().load(payload) self.assertEqual(len(query_context.queries), 1) query_object = query_context.queries[0] self.assertEqual(query_object.granularity, "timecol") @@ -172,13 +173,79 @@ class TestQueryContext(SupersetTestCase): payload = get_query_context(table.name, table.id, table.type) payload["result_format"] = ChartDataResultFormat.CSV.value payload["queries"][0]["row_limit"] = 10 - query_context = QueryContext(**payload) + query_context = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() self.assertEqual(len(responses), 1) data = responses[0]["data"] self.assertIn("name,sum__num\n", data) self.assertEqual(len(data.split("\n")), 12) + def test_sql_injection_via_groupby(self): + """ + Ensure that calling invalid columns names in groupby are caught + """ + self.login(username="admin") + table_name = "birth_names" + table = self.get_table_by_name(table_name) + payload = get_query_context(table.name, table.id, table.type) + payload["queries"][0]["groupby"] = ["currentDatabase()"] + query_context = ChartDataQueryContextSchema().load(payload) + query_payload = query_context.get_payload() + assert query_payload[0].get("error") is not None + + def test_sql_injection_via_columns(self): + """ + Ensure that calling invalid columns names in columns are caught + """ + self.login(username="admin") + table_name = "birth_names" + table = self.get_table_by_name(table_name) + payload = get_query_context(table.name, table.id, table.type) + payload["queries"][0]["groupby"] = [] + 
payload["queries"][0]["metrics"] = []
+        payload["queries"][0]["columns"] = ["*, 'extra'"]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None
+
+    def test_sql_injection_via_filters(self):
+        """
+        Ensure that references to invalid column names in filters are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = ["name"]
+        payload["queries"][0]["metrics"] = []
+        payload["queries"][0]["filters"] = [
+            {"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
+        ]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None
+
+    def test_sql_injection_via_metrics(self):
+        """
+        Ensure that references to invalid column names in metrics are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = ["name"]
+        payload["queries"][0]["metrics"] = [
+            {
+                "expressionType": AdhocMetricExpressionType.SIMPLE.value,
+                "column": {"column_name": "invalid_col"},
+                "aggregate": "SUM",
+                "label": "My Simple Label",
+            }
+        ]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None
+
     def test_samples_response_type(self):
         """
         Ensure that samples result type works
@@ -189,7 +256,7 @@
         payload = get_query_context(table.name, table.id, table.type)
         payload["result_type"] = ChartDataResultType.SAMPLES.value
         payload["queries"][0]["row_limit"] = 5
-        query_context = QueryContext(**payload)
+        query_context = 
ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() self.assertEqual(len(responses), 1) data = responses[0]["data"] @@ -206,7 +273,7 @@ class TestQueryContext(SupersetTestCase): table = self.get_table_by_name(table_name) payload = get_query_context(table.name, table.id, table.type) payload["result_type"] = ChartDataResultType.QUERY.value - query_context = QueryContext(**payload) + query_context = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() self.assertEqual(len(responses), 1) response = responses[0]