feat(chart-data-api): ignore unknown fields on QueryObject (#12118)

This commit is contained in:
Ville Brofeldt 2020-12-18 14:32:55 +02:00 committed by GitHub
parent 6621557af2
commit 1a5f61b133
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 3 deletions

View File

@ -17,7 +17,7 @@
from typing import Any, Dict
from flask_babel import gettext as _
from marshmallow import fields, post_load, Schema, validate
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
from marshmallow.validate import Length, Range
from superset.common.query_context import QueryContext
@ -857,6 +857,9 @@ class AnnotationLayerSchema(Schema):
class ChartDataQueryObjectSchema(Schema):
class Meta: # pylint: disable=too-few-public-methods
unknown = EXCLUDE
annotation_layers = fields.List(
fields.Nested(AnnotationLayerSchema),
description="Annotation layers to apply to chart",

View File

@ -198,7 +198,7 @@ class TestQueryContext(SupersetTestCase):
def test_sql_injection_via_columns(self):
"""
Ensure that calling invalid columns names in columns are caught
Ensure that calling invalid column names in columns are caught
"""
self.login(username="admin")
table_name = "birth_names"
@ -213,7 +213,7 @@ class TestQueryContext(SupersetTestCase):
def test_sql_injection_via_metrics(self):
"""
Ensure that calling invalid columns names in filters are caught
Ensure that calling invalid column names in filters are caught
"""
self.login(username="admin")
table_name = "birth_names"
@ -266,3 +266,22 @@ class TestQueryContext(SupersetTestCase):
self.assertEqual(len(response), 2)
self.assertEqual(response["language"], "sql")
self.assertIn("SELECT", response["query"])
def test_query_object_unknown_fields(self):
"""
Ensure that query objects with unknown fields don't raise an Exception and
have an identical cache key as one without the unknown field
"""
self.maxDiff = None
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
orig_cache_key = responses["queries"][0]["cache_key"]
payload["queries"][0]["foo"] = "bar"
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
new_cache_key = responses["queries"][0]["cache_key"]
self.assertEqual(orig_cache_key, new_cache_key)