chore: Improve chart data API + schemas + tests (#9599)

* Make all fields optional in QueryObject and fix the having_druid schema (see the example payload below)

* fix: datasource type sql to table

* lint

* Add missing fields

* Refactor tests

* Linting

* Refactor query context fixtures

* Add typing to test func
Ville Brofeldt 2020-04-23 14:30:48 +03:00 committed by GitHub
parent 76764acfc1
commit a6cedaaa87
11 changed files with 423 additions and 272 deletions
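
Taken together, these changes mean that `api/v1/chart/data` now rejects unknown enum values at the schema level, expects `datasource.type` to be `table` rather than `sql`, and validates `extras` as a nested schema in which `having_druid` is a list of filter objects. The payload below is a minimal sketch of a request the updated `ChartDataQueryContextSchema` would accept; it mirrors the new `birth_names` test fixture, and the datasource id is illustrative.

# Hypothetical request body for POST api/v1/chart/data (datasource id made up).
payload = {
    "datasource": {"id": 3, "type": "table"},  # "sql" is no longer a valid type
    "queries": [
        {
            "granularity": "ds",
            "groupby": ["name"],
            "metrics": [{"label": "sum__num"}],
            "filters": [{"col": "gender", "op": "==", "val": "boy"}],
            "row_limit": 100,
            "time_range": "100 years ago : now",
            "order_desc": True,
            # `extras` is now a nested schema rather than a free-form dict;
            # `having_druid` takes a list of filter objects instead of a string.
            "extras": {
                "time_grain_sqla": "P1D",
                "time_range_endpoints": ["INCLUSIVE", "EXCLUSIVE"],
                "having_druid": [{"col": "count", "op": ">=", "val": "10"}],
            },
        }
    ],
}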

View File

@ -16,7 +16,7 @@
# under the License.
from typing import Any, Dict, Union
from marshmallow import fields, post_load, Schema, ValidationError
from marshmallow import fields, post_load, Schema, validate, ValidationError
from marshmallow.validate import Length
from superset.common.query_context import QueryContext
@ -77,13 +77,15 @@ class ChartDataAdhocMetricSchema(Schema):
expressionType = fields.String(
description="Simple or SQL metric",
required=True,
enum=["SIMPLE", "SQL"],
validate=validate.OneOf(choices=("SIMPLE", "SQL")),
example="SQL",
)
aggregate = fields.String(
description="Aggregation operator. Only required for simple expression types.",
required=False,
enum=["AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM"],
validate=validate.OneOf(
choices=("AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM")
),
)
column = fields.Nested(ChartDataColumnSchema)
sqlExpression = fields.String(
@ -178,28 +180,30 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
)
rolling_type = fields.String(
description="Type of rolling window. Any numpy function will work.",
enum=[
"average",
"argmin",
"argmax",
"cumsum",
"cumprod",
"max",
"mean",
"median",
"nansum",
"nanmin",
"nanmax",
"nanmean",
"nanmedian",
"min",
"percentile",
"prod",
"product",
"std",
"sum",
"var",
],
validate=validate.OneOf(
choices=(
"average",
"argmin",
"argmax",
"cumsum",
"cumprod",
"max",
"mean",
"median",
"nansum",
"nanmin",
"nanmax",
"nanmean",
"nanmedian",
"min",
"percentile",
"prod",
"product",
"std",
"sum",
"var",
)
),
required=True,
example="percentile",
)
@ -225,23 +229,25 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
"additional parameters to `rolling_type_options`. For instance, "
"to use `gaussian`, the parameter `std` needs to be provided.",
required=False,
enum=[
"boxcar",
"triang",
"blackman",
"hamming",
"bartlett",
"parzen",
"bohman",
"blackmanharris",
"nuttall",
"barthann",
"kaiser",
"gaussian",
"general_gaussian",
"slepian",
"exponential",
],
validate=validate.OneOf(
choices=(
"boxcar",
"triang",
"blackman",
"hamming",
"bartlett",
"parzen",
"bohman",
"blackmanharris",
"nuttall",
"barthann",
"kaiser",
"gaussian",
"general_gaussian",
"slepian",
"exponential",
)
),
)
min_periods = fields.Integer(
description="The minimum amount of periods required for a row to be included "
@ -333,7 +339,9 @@ class ChartDataPostProcessingOperationSchema(Schema):
operation = fields.String(
description="Post processing operation type",
required=True,
enum=["aggregate", "pivot", "rolling", "select", "sort"],
validate=validate.OneOf(
choices=("aggregate", "pivot", "rolling", "select", "sort")
),
example="aggregate",
)
options = fields.Nested(
@ -362,7 +370,9 @@ class ChartDataFilterSchema(Schema):
)
op = fields.String( # pylint: disable=invalid-name
description="The comparison operator.",
enum=[filter_op.value for filter_op in utils.FilterOperator],
validate=validate.OneOf(
choices=[filter_op.value for filter_op in utils.FilterOperator]
),
required=True,
example="IN",
)
@ -376,21 +386,23 @@ class ChartDataFilterSchema(Schema):
class ChartDataExtrasSchema(Schema):
time_range_endpoints = fields.List(
fields.String(enum=["INCLUSIVE", "EXCLUSIVE"]),
description="A list with two values, stating if start/end should be "
"inclusive/exclusive.",
required=False,
fields.String(
validate=validate.OneOf(choices=("INCLUSIVE", "EXCLUSIVE")),
description="A list with two values, stating if start/end should be "
"inclusive/exclusive.",
required=False,
)
)
relative_start = fields.String(
description="Start time for relative time deltas. "
'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
enum=["today", "now"],
validate=validate.OneOf(choices=("today", "now")),
required=False,
)
relative_end = fields.String(
description="End time for relative time deltas. "
'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
enum=["today", "now"],
validate=validate.OneOf(choices=("today", "now")),
required=False,
)
where = fields.String(
@ -402,35 +414,54 @@ class ChartDataExtrasSchema(Schema):
"AND operator.",
required=False,
)
having_druid = fields.String(
having_druid = fields.List(
fields.Nested(ChartDataFilterSchema),
description="HAVING filters to be added to legacy Druid datasource queries.",
required=False,
)
time_grain_sqla = fields.String(
description="To what level of granularity should the temporal column be "
"aggregated. Supports "
"[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) durations.",
validate=validate.OneOf(
choices=(
"PT1S",
"PT1M",
"PT5M",
"PT10M",
"PT15M",
"PT0.5H",
"PT1H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
),
),
required=False,
example="P1D",
)
druid_time_origin = fields.String(
description="Starting point for time grain counting on legacy Druid "
"datasources. Used to change e.g. Monday/Sunday first-day-of-week.",
required=False,
)
class ChartDataQueryObjectSchema(Schema):
filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
granularity = fields.String(
description="To what level of granularity should the temporal column be "
"aggregated. Supports "
"[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
"durations.",
enum=[
"PT1S",
"PT1M",
"PT5M",
"PT10M",
"PT15M",
"PT0.5H",
"PT1H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
],
description="Name of temporal column used for time filtering. For legacy Druid "
"datasources this defines the time grain.",
required=False,
example="P1D",
)
granularity_sqla = fields.String(
description="Name of temporal column used for time filtering for SQL "
"datasources. This field is deprecated, use `granularity` "
"instead.",
required=False,
deprecated=True,
)
groupby = fields.List(
fields.String(description="Columns by which to group the query.",),
@ -441,6 +472,7 @@ class ChartDataQueryObjectSchema(Schema):
"references to datasource metrics (strings), or ad-hoc metrics"
"which are defined only within the query object. See "
"`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
required=False,
)
post_processing = fields.List(
fields.Nested(ChartDataPostProcessingOperationSchema),
@ -450,7 +482,8 @@ class ChartDataQueryObjectSchema(Schema):
)
time_range = fields.String(
description="A time rage, either expressed as a colon separated string "
"`since : until`. Valid formats for `since` and `until` are: \n"
"`since : until` or human readable freeform. Valid formats for "
"`since` and `until` are: \n"
"- ISO 8601\n"
"- X days/years/hours/day/year/weeks\n"
"- X days/years/hours/day/year/weeks ago\n"
@ -488,7 +521,7 @@ class ChartDataQueryObjectSchema(Schema):
order_desc = fields.Boolean(
description="Reverse order. Default: `false`", required=False
)
extras = fields.Dict(description=" Default: `{}`", required=False)
extras = fields.Nested(ChartDataExtrasSchema, required=False)
columns = fields.List(fields.String(), description="", required=False,)
orderby = fields.List(
fields.List(fields.Raw()),
@ -499,13 +532,13 @@ class ChartDataQueryObjectSchema(Schema):
)
where = fields.String(
description="WHERE clause to be added to queries using AND operator."
"This field is deprecated, and should be passed to `extras`.",
"This field is deprecated and should be passed to `extras`.",
required=False,
deprecated=True,
)
having = fields.String(
description="HAVING clause to be added to aggregate queries using "
"AND operator. This field is deprecated, and should be passed "
"AND operator. This field is deprecated and should be passed "
"to `extras`.",
required=False,
deprecated=True,
@ -513,7 +546,7 @@ class ChartDataQueryObjectSchema(Schema):
having_filters = fields.List(
fields.Dict(),
description="HAVING filters to be added to legacy Druid datasource queries. "
"This field is deprecated, and should be passed to `extras` "
"This field is deprecated and should be passed to `extras` "
"as `having_druid`.",
required=False,
deprecated=True,
@ -523,7 +556,10 @@ class ChartDataQueryObjectSchema(Schema):
class ChartDataDatasourceSchema(Schema):
description = "Chart datasource"
id = fields.Integer(description="Datasource id", required=True,)
type = fields.String(description="Datasource type", enum=["druid", "sql"])
type = fields.String(
description="Datasource type",
validate=validate.OneOf(choices=("druid", "table")),
)
class ChartDataQueryContextSchema(Schema):
@ -561,15 +597,17 @@ class ChartDataResponseResult(Schema):
)
status = fields.String(
description="Status of the query",
enum=[
"stopped",
"failed",
"pending",
"running",
"scheduled",
"success",
"timed_out",
],
validate=validate.OneOf(
choices=(
"stopped",
"failed",
"pending",
"running",
"scheduled",
"success",
"timed_out",
)
),
allow_none=False,
)
stacktrace = fields.String(
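
The recurring change in this file is the move from `enum=[...]`, which marshmallow only keeps as field metadata (it shows up in the OpenAPI spec but is not enforced), to `validate=validate.OneOf(...)`, which rejects out-of-range values when a payload is deserialized. A minimal standalone sketch of the difference, using a trimmed-down choice list rather than the full set above:

from marshmallow import Schema, ValidationError, fields, validate

class RollingOptionsSketch(Schema):
    # Same idea as ChartDataRollingOptionsSchema.rolling_type, with fewer choices.
    rolling_type = fields.String(
        required=True,
        validate=validate.OneOf(choices=("mean", "sum", "std", "cumsum")),
    )

RollingOptionsSketch().load({"rolling_type": "mean"})  # passes validation
try:
    RollingOptionsSketch().load({"rolling_type": "bogus"})
except ValidationError as err:
    print(err.messages)  # {'rolling_type': ['Must be one of: mean, sum, std, cumsum.']}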

View File

@ -35,15 +35,19 @@ logger = logging.getLogger(__name__)
# https://github.com/python/mypy/issues/5288
class DeprecatedExtrasField(NamedTuple):
name: str
extras_name: str
class DeprecatedField(NamedTuple):
old_name: str
new_name: str
DEPRECATED_FIELDS = (
DeprecatedField(old_name="granularity_sqla", new_name="granularity"),
)
DEPRECATED_EXTRAS_FIELDS = (
DeprecatedExtrasField(name="where", extras_name="where"),
DeprecatedExtrasField(name="having", extras_name="having"),
DeprecatedExtrasField(name="having_filters", extras_name="having_druid"),
DeprecatedField(old_name="where", new_name="where"),
DeprecatedField(old_name="having", new_name="having"),
DeprecatedField(old_name="having_filters", new_name="having_druid"),
)
@ -53,7 +57,7 @@ class QueryObject:
and druid. The query objects are constructed on the client.
"""
granularity: str
granularity: Optional[str]
from_dttm: datetime
to_dttm: datetime
is_timeseries: bool
@ -72,8 +76,8 @@ class QueryObject:
def __init__(
self,
granularity: str,
metrics: List[Union[Dict[str, Any], str]],
granularity: Optional[str] = None,
metrics: Optional[List[Union[Dict[str, Any], str]]] = None,
groupby: Optional[List[str]] = None,
filters: Optional[List[Dict[str, Any]]] = None,
time_range: Optional[str] = None,
@ -89,6 +93,7 @@ class QueryObject:
post_processing: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any,
):
metrics = metrics or []
extras = extras or {}
is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE")
self.granularity = granularity
@ -131,22 +136,44 @@ class QueryObject:
if is_sip_38 and groupby:
self.columns += groupby
logger.warning(
f"The field groupby is deprecated. Viz plugins should "
f"pass all selectables via the columns field"
f"The field `groupby` is deprecated. Viz plugins should "
f"pass all selectables via the `columns` field"
)
self.orderby = orderby or []
# move deprecated fields to extras
for field in DEPRECATED_EXTRAS_FIELDS:
if field.name in kwargs:
# rename deprecated fields
for field in DEPRECATED_FIELDS:
if field.old_name in kwargs:
logger.warning(
f"The field `{field.name} is deprecated, and should be "
f"passed to `extras` via the `{field.extras_name}` property"
f"The field `{field.old_name}` is deprecated, please use "
f"`{field.new_name}` instead."
)
value = kwargs[field.name]
value = kwargs[field.old_name]
if value:
self.extras[field.extras_name] = value
if getattr(self, field.new_name) is not None:
logger.warning(
f"The field `{field.new_name}` is already populated, "
f"replacing value with contents from `{field.old_name}`."
)
setattr(self, field.new_name, value)
# move deprecated extras fields to extras
for field in DEPRECATED_EXTRAS_FIELDS:
if field.old_name in kwargs:
logger.warning(
f"The field `{field.old_name}` is deprecated and should be "
f"passed to `extras` via the `{field.new_name}` property."
)
value = kwargs[field.old_name]
if value:
if field.new_name in self.extras:
logger.warning(
f"The field `{field.new_name}` is already populated in "
f"`extras`, replacing value with contents "
f"from `{field.old_name}`."
)
self.extras[field.new_name] = value
def to_dict(self) -> Dict[str, Any]:
query_object_dict = {
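
The net effect of the two loops above: deprecated top-level fields are renamed (`granularity_sqla` becomes `granularity`) and deprecated extras fields are relocated into `extras` (`having_filters` becomes `extras["having_druid"]`), each emitting a warning. A rough sketch of the intended behaviour, assuming the Superset test app context (the constructor consults feature flags) and the module path `superset.common.query_object`:

from superset.common.query_object import QueryObject

# Hypothetical call that uses only the deprecated names.
qobj = QueryObject(
    granularity_sqla="ds",
    having_filters=[{"col": "gender", "op": "==", "val": "boy"}],
    metrics=[{"label": "sum__num"}],
)
assert qobj.granularity == "ds"       # renamed from granularity_sqla
assert "having_druid" in qobj.extras  # moved from having_filters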

View File

@ -366,7 +366,9 @@ class BaseDatasource(
def default_query(qry) -> Query:
return qry
def get_column(self, column_name: str) -> Optional["BaseColumn"]:
def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
if not column_name:
return None
for col in self.columns:
if col.column_name == column_name:
return col

View File

@ -385,7 +385,7 @@ class RequestAccessTests(SupersetTestCase):
)
self.assertEqual(
"[Superset] Access to the datasource {} was granted".format(
self.get_table(ds_1_id).full_name
self.get_table_by_id(ds_1_id).full_name
),
call_args[2]["Subject"],
)
@ -426,7 +426,7 @@ class RequestAccessTests(SupersetTestCase):
)
self.assertEqual(
"[Superset] Access to the datasource {} was granted".format(
self.get_table(ds_2_id).full_name
self.get_table_by_id(ds_2_id).full_name
),
call_args[2]["Subject"],
)

View File

@ -18,16 +18,18 @@
"""Unit tests for Superset"""
import imp
import json
from typing import Union, Dict
from typing import Dict, Union
from unittest.mock import Mock, patch
import pandas as pd
from flask import Response
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
from sqlalchemy.orm import Session
from tests.test_app import app # isort:skip
from superset import db, security_manager
from superset.connectors.base.models import BaseDatasource
from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
@ -103,7 +105,8 @@ class SupersetTestCase(TestCase):
session.add(druid_datasource2)
session.commit()
def get_table(self, table_id):
@staticmethod
def get_table_by_id(table_id: int) -> SqlaTable:
return db.session.query(SqlaTable).filter_by(id=table_id).one()
@staticmethod
@ -127,21 +130,25 @@ class SupersetTestCase(TestCase):
resp = self.get_resp("/login/", data=dict(username=username, password=password))
self.assertNotIn("User confirmation needed", resp)
def get_slice(self, slice_name, session):
def get_slice(self, slice_name: str, session: Session) -> Slice:
slc = session.query(Slice).filter_by(slice_name=slice_name).one()
session.expunge_all()
return slc
def get_table_by_name(self, name):
@staticmethod
def get_table_by_name(name: str) -> SqlaTable:
return db.session.query(SqlaTable).filter_by(table_name=name).one()
def get_database_by_id(self, db_id):
@staticmethod
def get_database_by_id(db_id: int) -> Database:
return db.session.query(Database).filter_by(id=db_id).one()
def get_druid_ds_by_name(self, name):
@staticmethod
def get_druid_ds_by_name(name: str) -> DruidDatasource:
return db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
def get_datasource_mock(self):
@staticmethod
def get_datasource_mock() -> BaseDatasource:
datasource = Mock()
results = Mock()
results.query = Mock()

View File

@ -16,7 +16,7 @@
# under the License.
"""Unit tests for Superset"""
import json
from typing import Any, Dict, List, Optional
from typing import List, Optional
import prison
from sqlalchemy.sql import func
@ -28,6 +28,7 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.base_tests import SupersetTestCase
from tests.fixtures.query_context import get_query_context
class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
@ -69,32 +70,6 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
db.session.commit()
return slice
def _get_query_context(self) -> Dict[str, Any]:
self.login(username="admin")
slc = self.get_slice("Girl Name Cloud", db.session)
return {
"datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
"queries": [
{
"extras": {"where": ""},
"granularity": "ds",
"groupby": ["name"],
"is_timeseries": False,
"metrics": [{"label": "sum__num"}],
"order_desc": True,
"orderby": [],
"row_limit": 100,
"time_range": "100 years ago : now",
"timeseries_limit": 0,
"timeseries_limit_metric": None,
"filters": [{"col": "gender", "op": "==", "val": "boy"}],
"having": "",
"having_filters": [],
"where": "",
}
],
}
def test_delete_chart(self):
"""
Chart API: Test delete
@ -662,22 +637,37 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
Query API: Test chart data query
"""
self.login(username="admin")
query_context = self._get_query_context()
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
uri = "api/v1/chart/data"
rv = self.post_assert_metric(uri, query_context, "data")
rv = self.post_assert_metric(uri, payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["result"][0]["rowcount"], 100)
def test_invalid_chart_data(self):
"""
Query API: Test chart data query with invalid schema
def test_chart_data_with_invalid_datasource(self):
"""Query API: Test chart data query with invalid schema
"""
self.login(username="admin")
query_context = self._get_query_context()
query_context["datasource"] = "abc"
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
payload["datasource"] = "abc"
uri = "api/v1/chart/data"
rv = self.client.post(uri, json=query_context)
rv = self.post_assert_metric(uri, payload, "data")
self.assertEqual(rv.status_code, 400)
def test_chart_data_with_invalid_enum_value(self):
"""Query API: Test chart data query with invalid enum value
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["extras"]["time_range_endpoints"] = [
"abc",
"EXCLUSIVE",
]
uri = "api/v1/chart/data"
rv = self.client.post(uri, json=payload)
self.assertEqual(rv.status_code, 400)
def test_query_exec_not_allowed(self):
@ -685,9 +675,10 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
Query API: Test chart data query not allowed
"""
self.login(username="gamma")
query_context = self._get_query_context()
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
uri = "api/v1/chart/data"
rv = self.post_assert_metric(uri, query_context, "data")
rv = self.post_assert_metric(uri, payload, "data")
self.assertEqual(rv.status_code, 401)
def test_datasources(self):

View File

@ -28,7 +28,6 @@ import pytz
import random
import re
import string
from typing import Any, Dict
import unittest
from unittest import mock, skipUnless
@ -44,8 +43,6 @@ from superset import (
sql_lab,
is_feature_enabled,
)
from superset.common.query_context import QueryContext
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlaTable
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
@ -111,61 +108,6 @@ class CoreTests(SupersetTestCase):
resp = self.client.get("/superset/slice/-1/")
assert resp.status_code == 404
def _get_query_context(self) -> Dict[str, Any]:
self.login(username="admin")
slc = self.get_slice("Girl Name Cloud", db.session)
return {
"datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
"queries": [
{
"granularity": "ds",
"groupby": ["name"],
"metrics": [{"label": "sum__num"}],
"filters": [],
"row_limit": 100,
}
],
}
def _get_query_context_with_post_processing(self) -> Dict[str, Any]:
self.login(username="admin")
slc = self.get_slice("Girl Name Cloud", db.session)
return {
"datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
"queries": [
{
"granularity": "ds",
"groupby": ["name", "state"],
"metrics": [{"label": "sum__num"}],
"filters": [],
"row_limit": 100,
"post_processing": [
{
"operation": "aggregate",
"options": {
"groupby": ["state"],
"aggregates": {
"q1": {
"operator": "percentile",
"column": "sum__num",
"options": {"q": 25},
},
"median": {
"operator": "median",
"column": "sum__num",
},
},
},
},
{
"operation": "sort",
"options": {"columns": {"q1": False, "state": True},},
},
],
}
],
}
def test_viz_cache_key(self):
self.login(username="admin")
slc = self.get_slice("Girls", db.session)
@ -178,45 +120,6 @@ class CoreTests(SupersetTestCase):
qobj["groupby"] = []
self.assertNotEqual(cache_key, viz.cache_key(qobj))
def test_cache_key_changes_when_datasource_is_updated(self):
qc_dict = self._get_query_context()
# construct baseline cache_key
query_context = QueryContext(**qc_dict)
query_object = query_context.queries[0]
cache_key_original = query_context.cache_key(query_object)
# make temporary change and revert it to refresh the changed_on property
datasource = ConnectorRegistry.get_datasource(
datasource_type=qc_dict["datasource"]["type"],
datasource_id=qc_dict["datasource"]["id"],
session=db.session,
)
description_original = datasource.description
datasource.description = "temporary description"
db.session.commit()
datasource.description = description_original
db.session.commit()
# create new QueryContext with unchanged attributes and extract new cache_key
query_context = QueryContext(**qc_dict)
query_object = query_context.queries[0]
cache_key_new = query_context.cache_key(query_object)
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)
def test_query_context_time_range_endpoints(self):
query_context = QueryContext(**self._get_query_context())
query_object = query_context.queries[0]
extras = query_object.to_dict()["extras"]
self.assertTrue("time_range_endpoints" in extras)
self.assertEquals(
extras["time_range_endpoints"],
(utils.TimeRangeEndpoint.INCLUSIVE, utils.TimeRangeEndpoint.EXCLUSIVE),
)
def test_get_superset_tables_not_allowed(self):
example_db = utils.get_example_database()
schema_name = self.default_schema_backend_map[example_db.backend]
@ -254,20 +157,6 @@ class CoreTests(SupersetTestCase):
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_api_v1_query_endpoint(self):
self.login(username="admin")
qc_dict = self._get_query_context()
data = json.dumps(qc_dict)
resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
self.assertEqual(resp[0]["rowcount"], 100)
def test_api_v1_query_endpoint_with_post_processing(self):
self.login(username="admin")
qc_dict = self._get_query_context_with_post_processing()
data = json.dumps(qc_dict)
resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
self.assertEqual(resp[0]["rowcount"], 6)
def test_old_slice_json_endpoint(self):
self.login(username="admin")
slc = self.get_slice("Girls", db.session)

View File

@ -165,7 +165,7 @@ class DictImportExportTests(SupersetTestCase):
new_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
imported_id = new_table.id
imported = self.get_table(imported_id)
imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
@ -178,7 +178,7 @@ class DictImportExportTests(SupersetTestCase):
)
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
imported = self.get_table(imported_table.id)
imported = self.get_table_by_id(imported_table.id)
self.assert_table_equals(table, imported)
self.assertEqual(
{DBREF: ID_PREFIX + 2, "database_name": "main"}, json.loads(imported.params)
@ -194,7 +194,7 @@ class DictImportExportTests(SupersetTestCase):
)
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
db.session.commit()
imported = self.get_table(imported_table.id)
imported = self.get_table_by_id(imported_table.id)
self.assert_table_equals(table, imported)
self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
@ -213,7 +213,7 @@ class DictImportExportTests(SupersetTestCase):
imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over)
db.session.commit()
imported_over = self.get_table(imported_over_table.id)
imported_over = self.get_table_by_id(imported_over_table.id)
self.assertEqual(imported_table.id, imported_over.id)
expected_table, _ = self.create_table(
"table_override",
@ -243,7 +243,7 @@ class DictImportExportTests(SupersetTestCase):
)
db.session.commit()
imported_over = self.get_table(imported_over_table.id)
imported_over = self.get_table_by_id(imported_over_table.id)
self.assertEqual(imported_table.id, imported_over.id)
expected_table, _ = self.create_table(
"table_override",
@ -274,7 +274,7 @@ class DictImportExportTests(SupersetTestCase):
imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table)
db.session.commit()
self.assertEqual(imported_table.id, imported_copy_table.id)
self.assert_table_equals(copy_table, self.get_table(imported_table.id))
self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id))
self.yaml_compare(
imported_copy_table.export_to_dict(), imported_table.export_to_dict()
)

tests/fixtures/query_context.py (new file, 103 lines)
View File

@ -0,0 +1,103 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import copy
from typing import Any, Dict, List
QUERY_OBJECTS = {
"birth_names": {
"extras": {"where": "", "time_range_endpoints": ["INCLUSIVE", "EXCLUSIVE"],},
"granularity": "ds",
"groupby": ["name"],
"is_timeseries": False,
"metrics": [{"label": "sum__num"}],
"order_desc": True,
"orderby": [],
"row_limit": 100,
"time_range": "100 years ago : now",
"timeseries_limit": 0,
"timeseries_limit_metric": None,
"filters": [{"col": "gender", "op": "==", "val": "boy"}],
"having": "",
"having_filters": [],
"where": "",
}
}
POSTPROCESSING_OPERATIONS = {
"birth_names": [
{
"operation": "aggregate",
"options": {
"groupby": ["gender"],
"aggregates": {
"q1": {
"operator": "percentile",
"column": "sum__num",
"options": {"q": 25},
},
"median": {"operator": "median", "column": "sum__num",},
},
},
},
{"operation": "sort", "options": {"columns": {"q1": False, "gender": True},},},
]
}
def _get_query_object(
datasource_name: str, add_postprocessing_operations: bool
) -> Dict[str, Any]:
if datasource_name not in QUERY_OBJECTS:
raise Exception(
f"QueryObject fixture not defined for datasource: {datasource_name}"
)
query_object = copy.deepcopy(QUERY_OBJECTS[datasource_name])
if add_postprocessing_operations:
query_object["post_processing"] = _get_postprocessing_operation(datasource_name)
return query_object
def _get_postprocessing_operation(datasource_name: str) -> List[Dict[str, Any]]:
if datasource_name not in QUERY_OBJECTS:
raise Exception(
f"Post-processing fixture not defined for datasource: {datasource_name}"
)
return copy.deepcopy(POSTPROCESSING_OPERATIONS[datasource_name])
def get_query_context(
datasource_name: str = "birth_names",
datasource_id: int = 0,
datasource_type: str = "table",
add_postprocessing_operations: bool = False,
) -> Dict[str, Any]:
"""
Create a request payload for retrieving a QueryContext object via the
`api/v1/chart/data` endpoint. By default returns a payload corresponding to one
generated by the "Boy Name Cloud" chart in the examples.
:param datasource_name: name of datasource to query. Different datasources require
different parameters in the QueryContext.
:param datasource_id: id of datasource to query.
:param datasource_type: type of datasource to query.
:param add_postprocessing_operations: Add post-processing operations to QueryObject
:return: Request payload
"""
return {
"datasource": {"id": datasource_id, "type": datasource_type},
"queries": [_get_query_object(datasource_name, add_postprocessing_operations)],
}
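
A sketch of how tests are expected to consume this fixture; in the tests elsewhere in this commit the datasource id and type come from a `SqlaTable` looked up via `get_table_by_name`, while the id here is illustrative:

from tests.fixtures.query_context import get_query_context

# Plain "Boy Name Cloud"-style payload.
payload = get_query_context("birth_names", datasource_id=3, datasource_type="table")

# Same payload with the aggregate/sort post-processing pipeline attached.
payload_pp = get_query_context(
    "birth_names",
    datasource_id=3,
    datasource_type="table",
    add_postprocessing_operations=True,
)
assert payload_pp["queries"][0]["post_processing"][0]["operation"] == "aggregate"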

View File

@ -558,7 +558,7 @@ class ImportExportTests(SupersetTestCase):
def test_import_table_no_metadata(self):
table = self.create_table("pure_table", id=10001)
imported_id = SqlaTable.import_obj(table, import_time=1989)
imported = self.get_table(imported_id)
imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
def test_import_table_1_col_1_met(self):
@ -566,7 +566,7 @@ class ImportExportTests(SupersetTestCase):
"table_1_col_1_met", id=10002, cols_names=["col1"], metric_names=["metric1"]
)
imported_id = SqlaTable.import_obj(table, import_time=1990)
imported = self.get_table(imported_id)
imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
self.assertEqual(
{"remote_id": 10002, "import_time": 1990, "database_name": "examples"},
@ -582,7 +582,7 @@ class ImportExportTests(SupersetTestCase):
)
imported_id = SqlaTable.import_obj(table, import_time=1991)
imported = self.get_table(imported_id)
imported = self.get_table_by_id(imported_id)
self.assert_table_equals(table, imported)
def test_import_table_override(self):
@ -599,7 +599,7 @@ class ImportExportTests(SupersetTestCase):
)
imported_over_id = SqlaTable.import_obj(table_over, import_time=1992)
imported_over = self.get_table(imported_over_id)
imported_over = self.get_table_by_id(imported_over_id)
self.assertEqual(imported_id, imported_over.id)
expected_table = self.create_table(
"table_override",
@ -627,7 +627,7 @@ class ImportExportTests(SupersetTestCase):
imported_id_copy = SqlaTable.import_obj(copy_table, import_time=1994)
self.assertEqual(imported_id, imported_id_copy)
self.assert_table_equals(copy_table, self.get_table(imported_id))
self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))
def test_import_druid_no_metadata(self):
datasource = self.create_druid_datasource("pure_druid", id=10001)

View File

@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional
from superset import db
from superset.common.query_context import QueryContext
from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils.core import TimeRangeEndpoint
from tests.base_tests import SupersetTestCase
from tests.fixtures.query_context import get_query_context
from tests.test_app import app
class QueryContextTests(SupersetTestCase):
def test_cache_key_changes_when_datasource_is_updated(self):
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
# construct baseline cache_key
query_context = QueryContext(**payload)
query_object = query_context.queries[0]
cache_key_original = query_context.cache_key(query_object)
# make temporary change and revert it to refresh the changed_on property
datasource = ConnectorRegistry.get_datasource(
datasource_type=payload["datasource"]["type"],
datasource_id=payload["datasource"]["id"],
session=db.session,
)
description_original = datasource.description
datasource.description = "temporary description"
db.session.commit()
datasource.description = description_original
db.session.commit()
# create new QueryContext with unchanged attributes and extract new cache_key
query_context = QueryContext(**payload)
query_object = query_context.queries[0]
cache_key_new = query_context.cache_key(query_object)
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)
def test_query_context_time_range_endpoints(self):
"""
Ensure that time_range_endpoints are populated automatically when missing
from the payload
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
del payload["queries"][0]["extras"]["time_range_endpoints"]
query_context = QueryContext(**payload)
query_object = query_context.queries[0]
extras = query_object.to_dict()["extras"]
self.assertTrue("time_range_endpoints" in extras)
self.assertEquals(
extras["time_range_endpoints"],
(TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
)
def test_convert_deprecated_fields(self):
"""
Ensure that deprecated fields are converted correctly
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["granularity_sqla"] = "timecol"
payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
query_context = QueryContext(**payload)
self.assertEqual(len(query_context.queries), 1)
query_object = query_context.queries[0]
self.assertEqual(query_object.granularity, "timecol")
self.assertIn("having_druid", query_object.extras)