Add QueryContext deserialization test (#9778)

* Add QueryContext deserialization test

* deserialize using marshmallow and assert error dict
This commit is contained in:
Ville Brofeldt 2020-05-11 14:10:14 +03:00 committed by GitHub
parent 24db9ab088
commit 42b10aecae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 2 deletions

View File

@ -490,6 +490,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
validate=validate.OneOf(
choices=(
"aggregate",
"cum",
"geodetic_parse",
"geohash_decode",
"geohash_encode",
@ -501,8 +502,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
),
example="aggregate",
)
options = fields.Nested(
ChartDataPostProcessingOperationOptionsSchema,
options = fields.Dict(
description="Options specifying how to perform the operation. Please refer "
"to the respective post processing operation option schemas. "
"For example, `ChartDataPostProcessingOperationOptions` specifies "

View File

@ -17,6 +17,7 @@
from typing import Any, Dict, List, Optional
from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.query_context import QueryContext
from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils.core import TimeRangeEndpoint
@ -26,6 +27,46 @@ from tests.test_app import app
class QueryContextTests(SupersetTestCase):
def test_schema_deserialization(self):
"""
Ensure that the deserialized QueryContext contains all required fields.
"""
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(
table.name, table.id, table.type, add_postprocessing_operations=True
)
query_context, errors = ChartDataQueryContextSchema().load(payload)
self.assertDictEqual(errors, {})
self.assertEqual(len(query_context.queries), len(payload["queries"]))
for query_idx, query in enumerate(query_context.queries):
payload_query = payload["queries"][query_idx]
# check basic properies
self.assertEqual(query.extras, payload_query["extras"])
self.assertEqual(query.filter, payload_query["filters"])
self.assertEqual(query.groupby, payload_query["groupby"])
# metrics are mutated during creation
for metric_idx, metric in enumerate(query.metrics):
payload_metric = payload_query["metrics"][metric_idx]
payload_metric = (
payload_metric
if "expressionType" in payload_metric
else payload_metric["label"]
)
self.assertEqual(metric, payload_metric)
self.assertEqual(query.orderby, payload_query["orderby"])
self.assertEqual(query.time_range, payload_query["time_range"])
# check post processing operation properties
for post_proc_idx, post_proc in enumerate(query.post_processing):
payload_post_proc = payload_query["post_processing"][post_proc_idx]
self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
self.assertEqual(post_proc["options"], payload_post_proc["options"])
def test_cache_key_changes_when_datasource_is_updated(self):
self.login(username="admin")
table_name = "birth_names"