diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 5de166c86..f85ad00a4 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -17,7 +17,7 @@ from typing import Any, Dict from flask_babel import gettext as _ -from marshmallow import fields, post_load, Schema, validate +from marshmallow import EXCLUDE, fields, post_load, Schema, validate from marshmallow.validate import Length, Range from superset.common.query_context import QueryContext @@ -857,6 +857,9 @@ class AnnotationLayerSchema(Schema): class ChartDataQueryObjectSchema(Schema): + class Meta: # pylint: disable=too-few-public-methods + unknown = EXCLUDE + annotation_layers = fields.List( fields.Nested(AnnotationLayerSchema), description="Annotation layers to apply to chart", diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py index 5bccf0750..a376fa280 100644 --- a/tests/query_context_tests.py +++ b/tests/query_context_tests.py @@ -198,7 +198,7 @@ class TestQueryContext(SupersetTestCase): def test_sql_injection_via_columns(self): """ - Ensure that calling invalid columns names in columns are caught + Ensure that calling invalid column names in columns are caught """ self.login(username="admin") table_name = "birth_names" @@ -213,7 +213,7 @@ class TestQueryContext(SupersetTestCase): def test_sql_injection_via_metrics(self): """ - Ensure that calling invalid columns names in filters are caught + Ensure that calling invalid column names in filters are caught """ self.login(username="admin") table_name = "birth_names" @@ -266,3 +266,22 @@ class TestQueryContext(SupersetTestCase): self.assertEqual(len(response), 2) self.assertEqual(response["language"], "sql") self.assertIn("SELECT", response["query"]) + + def test_query_object_unknown_fields(self): + """ + Ensure that query objects with unknown fields don't raise an Exception and + have an identical cache key as one without the unknown field + """ + self.maxDiff = None + self.login(username="admin") + table_name = "birth_names" + table = self.get_table_by_name(table_name) + payload = get_query_context(table.name, table.id, table.type) + query_context = ChartDataQueryContextSchema().load(payload) + responses = query_context.get_payload() + orig_cache_key = responses["queries"][0]["cache_key"] + payload["queries"][0]["foo"] = "bar" + query_context = ChartDataQueryContextSchema().load(payload) + responses = query_context.get_payload() + new_cache_key = responses["queries"][0]["cache_key"] + self.assertEqual(orig_cache_key, new_cache_key)