chore: Improve chart data API + schemas + tests (#9599)
* Make all fields optional in QueryObject and fix having_druid schema
* fix: datasource type sql to table
* lint
* Add missing fields
* Refactor tests
* Linting
* Refactor query context fixtures
* Add typing to test func
This commit is contained in:
parent 76764acfc1
commit a6cedaaa87
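The recurring schema change in this diff swaps the bare `enum` keyword for `validate.OneOf`. A minimal sketch of the difference, not taken from the commit itself: marshmallow stores unrecognized field kwargs such as `enum` as documentation metadata only, so out-of-range values used to pass validation, whereas `validate.OneOf` enforces membership at load time. The class names below are hypothetical.

from marshmallow import Schema, ValidationError, fields, validate

class OldStyle(Schema):
    # `enum` is swallowed as metadata; nothing is enforced
    expressionType = fields.String(enum=["SIMPLE", "SQL"])

class NewStyle(Schema):
    expressionType = fields.String(validate=validate.OneOf(choices=("SIMPLE", "SQL")))

OldStyle().load({"expressionType": "bogus"})  # accepted silently
try:
    NewStyle().load({"expressionType": "bogus"})
except ValidationError as err:
    print(err.messages)  # {'expressionType': ['Must be one of: SIMPLE, SQL.']}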
@@ -16,7 +16,7 @@
 # under the License.
 from typing import Any, Dict, Union
 
-from marshmallow import fields, post_load, Schema, ValidationError
+from marshmallow import fields, post_load, Schema, validate, ValidationError
 from marshmallow.validate import Length
 
 from superset.common.query_context import QueryContext
@@ -77,13 +77,15 @@ class ChartDataAdhocMetricSchema(Schema):
     expressionType = fields.String(
         description="Simple or SQL metric",
         required=True,
-        enum=["SIMPLE", "SQL"],
+        validate=validate.OneOf(choices=("SIMPLE", "SQL")),
         example="SQL",
     )
     aggregate = fields.String(
         description="Aggregation operator. Only required for simple expression types.",
         required=False,
-        enum=["AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM"],
+        validate=validate.OneOf(
+            choices=("AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM")
+        ),
     )
     column = fields.Nested(ChartDataColumnSchema)
     sqlExpression = fields.String(
@@ -178,28 +180,30 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
     )
     rolling_type = fields.String(
         description="Type of rolling window. Any numpy function will work.",
-        enum=[
-            "average",
-            "argmin",
-            "argmax",
-            "cumsum",
-            "cumprod",
-            "max",
-            "mean",
-            "median",
-            "nansum",
-            "nanmin",
-            "nanmax",
-            "nanmean",
-            "nanmedian",
-            "min",
-            "percentile",
-            "prod",
-            "product",
-            "std",
-            "sum",
-            "var",
-        ],
+        validate=validate.OneOf(
+            choices=(
+                "average",
+                "argmin",
+                "argmax",
+                "cumsum",
+                "cumprod",
+                "max",
+                "mean",
+                "median",
+                "nansum",
+                "nanmin",
+                "nanmax",
+                "nanmean",
+                "nanmedian",
+                "min",
+                "percentile",
+                "prod",
+                "product",
+                "std",
+                "sum",
+                "var",
+            )
+        ),
         required=True,
         example="percentile",
     )
@@ -225,23 +229,25 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem
         "additional parameters to `rolling_type_options`. For instance, "
         "to use `gaussian`, the parameter `std` needs to be provided.",
         required=False,
-        enum=[
-            "boxcar",
-            "triang",
-            "blackman",
-            "hamming",
-            "bartlett",
-            "parzen",
-            "bohman",
-            "blackmanharris",
-            "nuttall",
-            "barthann",
-            "kaiser",
-            "gaussian",
-            "general_gaussian",
-            "slepian",
-            "exponential",
-        ],
+        validate=validate.OneOf(
+            choices=(
+                "boxcar",
+                "triang",
+                "blackman",
+                "hamming",
+                "bartlett",
+                "parzen",
+                "bohman",
+                "blackmanharris",
+                "nuttall",
+                "barthann",
+                "kaiser",
+                "gaussian",
+                "general_gaussian",
+                "slepian",
+                "exponential",
+            )
+        ),
     )
     min_periods = fields.Integer(
         description="The minimum amount of periods required for a row to be included "
@@ -333,7 +339,9 @@ class ChartDataPostProcessingOperationSchema(Schema):
     operation = fields.String(
         description="Post processing operation type",
         required=True,
-        enum=["aggregate", "pivot", "rolling", "select", "sort"],
+        validate=validate.OneOf(
+            choices=("aggregate", "pivot", "rolling", "select", "sort")
+        ),
         example="aggregate",
     )
     options = fields.Nested(
@@ -362,7 +370,9 @@ class ChartDataFilterSchema(Schema):
     )
     op = fields.String(  # pylint: disable=invalid-name
         description="The comparison operator.",
-        enum=[filter_op.value for filter_op in utils.FilterOperator],
+        validate=validate.OneOf(
+            choices=[filter_op.value for filter_op in utils.FilterOperator]
+        ),
         required=True,
         example="IN",
     )
@@ -376,21 +386,23 @@
 class ChartDataExtrasSchema(Schema):
 
     time_range_endpoints = fields.List(
-        fields.String(enum=["INCLUSIVE", "EXCLUSIVE"]),
-        description="A list with two values, stating if start/end should be "
-        "inclusive/exclusive.",
-        required=False,
+        fields.String(
+            validate=validate.OneOf(choices=("INCLUSIVE", "EXCLUSIVE")),
+            description="A list with two values, stating if start/end should be "
+            "inclusive/exclusive.",
+            required=False,
+        )
     )
     relative_start = fields.String(
         description="Start time for relative time deltas. "
         'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
-        enum=["today", "now"],
+        validate=validate.OneOf(choices=("today", "now")),
         required=False,
     )
     relative_end = fields.String(
         description="End time for relative time deltas. "
         'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
-        enum=["today", "now"],
+        validate=validate.OneOf(choices=("today", "now")),
         required=False,
     )
     where = fields.String(
@@ -402,35 +414,54 @@
         "AND operator.",
         required=False,
     )
-    having_druid = fields.String(
+    having_druid = fields.List(
+        fields.Nested(ChartDataFilterSchema),
         description="HAVING filters to be added to legacy Druid datasource queries.",
         required=False,
     )
+    time_grain_sqla = fields.String(
+        description="To what level of granularity should the temporal column be "
+        "aggregated. Supports "
+        "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) durations.",
+        validate=validate.OneOf(
+            choices=(
+                "PT1S",
+                "PT1M",
+                "PT5M",
+                "PT10M",
+                "PT15M",
+                "PT0.5H",
+                "PT1H",
+                "P1D",
+                "P1W",
+                "P1M",
+                "P0.25Y",
+                "P1Y",
+            ),
+        ),
+        required=False,
+        example="P1D",
+    )
+    druid_time_origin = fields.String(
+        description="Starting point for time grain counting on legacy Druid "
+        "datasources. Used to change e.g. Monday/Sunday first-day-of-week.",
+        required=False,
+    )
 
 
 class ChartDataQueryObjectSchema(Schema):
     filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
     granularity = fields.String(
-        description="To what level of granularity should the temporal column be "
-        "aggregated. Supports "
-        "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
-        "durations.",
-        enum=[
-            "PT1S",
-            "PT1M",
-            "PT5M",
-            "PT10M",
-            "PT15M",
-            "PT0.5H",
-            "PT1H",
-            "P1D",
-            "P1W",
-            "P1M",
-            "P0.25Y",
-            "P1Y",
-        ],
+        description="Name of temporal column used for time filtering. For legacy Druid "
+        "datasources this defines the time grain.",
         required=False,
-        example="P1D",
     )
+    granularity_sqla = fields.String(
+        description="Name of temporal column used for time filtering for SQL "
+        "datasources. This field is deprecated, use `granularity` "
+        "instead.",
+        required=False,
+        deprecated=True,
+    )
     groupby = fields.List(
         fields.String(description="Columns by which to group the query.",),
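Since `having_druid` is now a list of filter objects and `time_grain_sqla`/`druid_time_origin` are accepted inside `extras`, a structured extras payload becomes possible. The fragment below is illustrative only; the column name and threshold are invented:

extras = {
    "where": "",
    "time_range_endpoints": ["INCLUSIVE", "EXCLUSIVE"],
    "time_grain_sqla": "P1D",  # validated against the ISO 8601 durations above
    "having_druid": [{"col": "sum__num", "op": ">=", "val": 1000}],  # previously a raw string
}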
@@ -441,6 +472,7 @@ class ChartDataQueryObjectSchema(Schema):
         "references to datasource metrics (strings), or ad-hoc metrics"
         "which are defined only within the query object. See "
         "`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
+        required=False,
     )
     post_processing = fields.List(
         fields.Nested(ChartDataPostProcessingOperationSchema),
@@ -450,7 +482,8 @@ class ChartDataQueryObjectSchema(Schema):
     )
     time_range = fields.String(
         description="A time rage, either expressed as a colon separated string "
-        "`since : until`. Valid formats for `since` and `until` are: \n"
+        "`since : until` or human readable freeform. Valid formats for "
+        "`since` and `until` are: \n"
         "- ISO 8601\n"
         "- X days/years/hours/day/year/weeks\n"
         "- X days/years/hours/day/year/weeks ago\n"
@@ -488,7 +521,7 @@ class ChartDataQueryObjectSchema(Schema):
     order_desc = fields.Boolean(
         description="Reverse order. Default: `false`", required=False
     )
-    extras = fields.Dict(description=" Default: `{}`", required=False)
+    extras = fields.Nested(ChartDataExtrasSchema, required=False)
     columns = fields.List(fields.String(), description="", required=False,)
     orderby = fields.List(
         fields.List(fields.Raw()),
@@ -499,13 +532,13 @@ class ChartDataQueryObjectSchema(Schema):
     )
     where = fields.String(
         description="WHERE clause to be added to queries using AND operator."
-        "This field is deprecated, and should be passed to `extras`.",
+        "This field is deprecated and should be passed to `extras`.",
         required=False,
         deprecated=True,
     )
     having = fields.String(
         description="HAVING clause to be added to aggregate queries using "
-        "AND operator. This field is deprecated, and should be passed "
+        "AND operator. This field is deprecated and should be passed "
         "to `extras`.",
         required=False,
         deprecated=True,
@@ -513,7 +546,7 @@ class ChartDataQueryObjectSchema(Schema):
     having_filters = fields.List(
         fields.Dict(),
         description="HAVING filters to be added to legacy Druid datasource queries. "
-        "This field is deprecated, and should be passed to `extras` "
+        "This field is deprecated and should be passed to `extras` "
         "as `filters_druid`.",
         required=False,
         deprecated=True,
@@ -523,7 +556,10 @@
 class ChartDataDatasourceSchema(Schema):
     description = "Chart datasource"
     id = fields.Integer(description="Datasource id", required=True,)
-    type = fields.String(description="Datasource type", enum=["druid", "sql"])
+    type = fields.String(
+        description="Datasource type",
+        validate=validate.OneOf(choices=("druid", "table")),
+    )
 
 
 class ChartDataQueryContextSchema(Schema):
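SQL-backed datasources identify themselves with the type string "table" (the value of `SqlaTable.type`), which is why the valid choices are now ("druid", "table") rather than ("druid", "sql"). A quick sketch, assuming the schema module shown here is importable as `superset.charts.schemas`:

from marshmallow import ValidationError
from superset.charts.schemas import ChartDataDatasourceSchema

ChartDataDatasourceSchema().load({"id": 1, "type": "table"})  # accepted
try:
    ChartDataDatasourceSchema().load({"id": 1, "type": "sql"})
except ValidationError as err:
    print(err.messages)  # {'type': ['Must be one of: druid, table.']}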
@@ -561,15 +597,17 @@ class ChartDataResponseResult(Schema):
     )
     status = fields.String(
         description="Status of the query",
-        enum=[
-            "stopped",
-            "failed",
-            "pending",
-            "running",
-            "scheduled",
-            "success",
-            "timed_out",
-        ],
+        validate=validate.OneOf(
+            choices=(
+                "stopped",
+                "failed",
+                "pending",
+                "running",
+                "scheduled",
+                "success",
+                "timed_out",
+            )
+        ),
         allow_none=False,
     )
     stacktrace = fields.String(
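Taken together, the schema changes let malformed chart-data payloads be rejected before a query is ever built. A hedged end-to-end sketch (the module path is assumed; `Schema.validate()` runs field validation without the `post_load` hook, so no database session should be required):

from superset.charts.schemas import ChartDataQueryContextSchema

schema = ChartDataQueryContextSchema()
errors = schema.validate(
    {
        "datasource": {"id": 1, "type": "table"},
        "queries": [{"granularity": "ds", "metrics": ["count"]}],
    }
)
assert errors == {}  # a well-formed payload yields no errors

errors = schema.validate({"datasource": {"id": 1, "type": "sql"}, "queries": []})
assert "datasource" in errors  # "sql" is no longer an accepted type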
@@ -35,15 +35,19 @@ logger = logging.getLogger(__name__)
 # https://github.com/python/mypy/issues/5288
 
 
-class DeprecatedExtrasField(NamedTuple):
-    name: str
-    extras_name: str
+class DeprecatedField(NamedTuple):
+    old_name: str
+    new_name: str
 
 
+DEPRECATED_FIELDS = (
+    DeprecatedField(old_name="granularity_sqla", new_name="granularity"),
+)
+
 DEPRECATED_EXTRAS_FIELDS = (
-    DeprecatedExtrasField(name="where", extras_name="where"),
-    DeprecatedExtrasField(name="having", extras_name="having"),
-    DeprecatedExtrasField(name="having_filters", extras_name="having_druid"),
+    DeprecatedField(old_name="where", new_name="where"),
+    DeprecatedField(old_name="having", new_name="having"),
+    DeprecatedField(old_name="having_filters", new_name="having_druid"),
 )
 
 
@@ -53,7 +57,7 @@ class QueryObject:
     and druid. The query objects are constructed on the client.
     """
 
-    granularity: str
+    granularity: Optional[str]
     from_dttm: datetime
     to_dttm: datetime
     is_timeseries: bool
@@ -72,8 +76,8 @@
 
     def __init__(
         self,
-        granularity: str,
-        metrics: List[Union[Dict[str, Any], str]],
+        granularity: Optional[str] = None,
+        metrics: Optional[List[Union[Dict[str, Any], str]]] = None,
         groupby: Optional[List[str]] = None,
         filters: Optional[List[Dict[str, Any]]] = None,
         time_range: Optional[str] = None,
@@ -89,6 +93,7 @@ class QueryObject:
         post_processing: Optional[List[Dict[str, Any]]] = None,
         **kwargs: Any,
     ):
+        metrics = metrics or []
         extras = extras or {}
         is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE")
         self.granularity = granularity
@@ -131,22 +136,44 @@
         if is_sip_38 and groupby:
             self.columns += groupby
             logger.warning(
-                f"The field groupby is deprecated. Viz plugins should "
-                f"pass all selectables via the columns field"
+                f"The field `groupby` is deprecated. Viz plugins should "
+                f"pass all selectables via the `columns` field"
             )
 
         self.orderby = orderby or []
 
-        # move deprecated fields to extras
-        for field in DEPRECATED_EXTRAS_FIELDS:
-            if field.name in kwargs:
+        # rename deprecated fields
+        for field in DEPRECATED_FIELDS:
+            if field.old_name in kwargs:
                 logger.warning(
-                    f"The field `{field.name} is deprecated, and should be "
-                    f"passed to `extras` via the `{field.extras_name}` property"
+                    f"The field `{field.old_name}` is deprecated, please use "
+                    f"`{field.new_name}` instead."
                 )
-                value = kwargs[field.name]
+                value = kwargs[field.old_name]
                 if value:
-                    self.extras[field.extras_name] = value
+                    if hasattr(self, field.new_name):
+                        logger.warning(
+                            f"The field `{field.new_name}` is already populated, "
+                            f"replacing value with contents from `{field.old_name}`."
+                        )
+                    setattr(self, field.new_name, value)
+
+        # move deprecated extras fields to extras
+        for field in DEPRECATED_EXTRAS_FIELDS:
+            if field.old_name in kwargs:
+                logger.warning(
+                    f"The field `{field.old_name}` is deprecated and should be "
+                    f"passed to `extras` via the `{field.new_name}` property."
+                )
+                value = kwargs[field.old_name]
+                if value:
+                    if hasattr(self.extras, field.new_name):
+                        logger.warning(
+                            f"The field `{field.new_name}` is already populated in "
+                            f"`extras`, replacing value with contents "
+                            f"from `{field.old_name}`."
+                        )
+                    self.extras[field.new_name] = value
 
     def to_dict(self) -> Dict[str, Any]:
         query_object_dict = {
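The two loops above give `QueryObject` a uniform migration path: `DEPRECATED_FIELDS` renames a legacy top-level field onto its new attribute, while `DEPRECATED_EXTRAS_FIELDS` relocates legacy fields into `extras`. A usage sketch mirroring `test_convert_deprecated_fields` further down (assumes an initialized app context, since the constructor checks feature flags):

from superset.common.query_object import QueryObject

qobj = QueryObject(
    granularity_sqla="ds",  # deprecated: renamed to `granularity`
    having_filters=[{"col": "a", "op": "==", "val": "b"}],  # deprecated: moved to extras
)
assert qobj.granularity == "ds"
assert qobj.extras["having_druid"] == [{"col": "a", "op": "==", "val": "b"}]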
@@ -366,7 +366,9 @@ class BaseDatasource(
     def default_query(qry) -> Query:
         return qry
 
-    def get_column(self, column_name: str) -> Optional["BaseColumn"]:
+    def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
+        if not column_name:
+            return None
         for col in self.columns:
             if col.column_name == column_name:
                 return col
@@ -385,7 +385,7 @@ class RequestAccessTests(SupersetTestCase):
         )
         self.assertEqual(
             "[Superset] Access to the datasource {} was granted".format(
-                self.get_table(ds_1_id).full_name
+                self.get_table_by_id(ds_1_id).full_name
             ),
             call_args[2]["Subject"],
         )
@@ -426,7 +426,7 @@ class RequestAccessTests(SupersetTestCase):
         )
         self.assertEqual(
             "[Superset] Access to the datasource {} was granted".format(
-                self.get_table(ds_2_id).full_name
+                self.get_table_by_id(ds_2_id).full_name
             ),
             call_args[2]["Subject"],
         )
@@ -18,16 +18,18 @@
 """Unit tests for Superset"""
 import imp
 import json
-from typing import Union, Dict
+from typing import Dict, Union
 from unittest.mock import Mock, patch
 
 import pandas as pd
 from flask import Response
 from flask_appbuilder.security.sqla import models as ab_models
 from flask_testing import TestCase
+from sqlalchemy.orm import Session
 
 from tests.test_app import app  # isort:skip
 from superset import db, security_manager
+from superset.connectors.base.models import BaseDatasource
 from superset.connectors.druid.models import DruidCluster, DruidDatasource
 from superset.connectors.sqla.models import SqlaTable
 from superset.models import core as models
@@ -103,7 +105,8 @@
         session.add(druid_datasource2)
         session.commit()
 
-    def get_table(self, table_id):
+    @staticmethod
+    def get_table_by_id(table_id: int) -> SqlaTable:
         return db.session.query(SqlaTable).filter_by(id=table_id).one()
 
     @staticmethod
@@ -127,21 +130,25 @@
         resp = self.get_resp("/login/", data=dict(username=username, password=password))
         self.assertNotIn("User confirmation needed", resp)
 
-    def get_slice(self, slice_name, session):
+    def get_slice(self, slice_name: str, session: Session) -> Slice:
         slc = session.query(Slice).filter_by(slice_name=slice_name).one()
         session.expunge_all()
         return slc
 
-    def get_table_by_name(self, name):
+    @staticmethod
+    def get_table_by_name(name: str) -> SqlaTable:
         return db.session.query(SqlaTable).filter_by(table_name=name).one()
 
-    def get_database_by_id(self, db_id):
+    @staticmethod
+    def get_database_by_id(db_id: int) -> Database:
         return db.session.query(Database).filter_by(id=db_id).one()
 
-    def get_druid_ds_by_name(self, name):
+    @staticmethod
+    def get_druid_ds_by_name(name: str) -> DruidDatasource:
         return db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
 
-    def get_datasource_mock(self):
+    @staticmethod
+    def get_datasource_mock() -> BaseDatasource:
         datasource = Mock()
         results = Mock()
         results.query = Mock()
@@ -16,7 +16,7 @@
 # under the License.
 """Unit tests for Superset"""
 import json
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 import prison
 from sqlalchemy.sql import func
@@ -28,6 +28,7 @@ from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 from tests.base_api_tests import ApiOwnersTestCaseMixin
 from tests.base_tests import SupersetTestCase
+from tests.fixtures.query_context import get_query_context
 
 
 class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
@@ -69,32 +70,6 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
         db.session.commit()
         return slice
 
-    def _get_query_context(self) -> Dict[str, Any]:
-        self.login(username="admin")
-        slc = self.get_slice("Girl Name Cloud", db.session)
-        return {
-            "datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
-            "queries": [
-                {
-                    "extras": {"where": ""},
-                    "granularity": "ds",
-                    "groupby": ["name"],
-                    "is_timeseries": False,
-                    "metrics": [{"label": "sum__num"}],
-                    "order_desc": True,
-                    "orderby": [],
-                    "row_limit": 100,
-                    "time_range": "100 years ago : now",
-                    "timeseries_limit": 0,
-                    "timeseries_limit_metric": None,
-                    "filters": [{"col": "gender", "op": "==", "val": "boy"}],
-                    "having": "",
-                    "having_filters": [],
-                    "where": "",
-                }
-            ],
-        }
-
     def test_delete_chart(self):
         """
         Chart API: Test delete
@@ -662,22 +637,37 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
         Query API: Test chart data query
         """
         self.login(username="admin")
-        query_context = self._get_query_context()
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
         uri = "api/v1/chart/data"
-        rv = self.post_assert_metric(uri, query_context, "data")
+        rv = self.post_assert_metric(uri, payload, "data")
         self.assertEqual(rv.status_code, 200)
         data = json.loads(rv.data.decode("utf-8"))
         self.assertEqual(data["result"][0]["rowcount"], 100)
 
-    def test_invalid_chart_data(self):
-        """
-        Query API: Test chart data query with invalid schema
+    def test_chart_data_with_invalid_datasource(self):
+        """Query API: Test chart data query with invalid schema
         """
         self.login(username="admin")
-        query_context = self._get_query_context()
-        query_context["datasource"] = "abc"
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["datasource"] = "abc"
         uri = "api/v1/chart/data"
-        rv = self.client.post(uri, json=query_context)
+        rv = self.post_assert_metric(uri, payload, "data")
+        self.assertEqual(rv.status_code, 400)
+
+    def test_chart_data_with_invalid_enum_value(self):
+        """Query API: Test chart data query with invalid enum value
+        """
+        self.login(username="admin")
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["extras"]["time_range_endpoints"] = [
+            "abc",
+            "EXCLUSIVE",
+        ]
+        uri = "api/v1/chart/data"
+        rv = self.client.post(uri, json=payload)
         self.assertEqual(rv.status_code, 400)
 
     def test_query_exec_not_allowed(self):
@@ -685,9 +675,10 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
         Query API: Test chart data query not allowed
         """
         self.login(username="gamma")
-        query_context = self._get_query_context()
+        table = self.get_table_by_name("birth_names")
+        payload = get_query_context(table.name, table.id, table.type)
         uri = "api/v1/chart/data"
-        rv = self.post_assert_metric(uri, query_context, "data")
+        rv = self.post_assert_metric(uri, payload, "data")
         self.assertEqual(rv.status_code, 401)
 
     def test_datasources(self):
@@ -28,7 +28,6 @@ import pytz
 import random
 import re
 import string
-from typing import Any, Dict
 import unittest
 from unittest import mock, skipUnless
 
@@ -44,8 +43,6 @@ from superset import (
     sql_lab,
     is_feature_enabled,
 )
-from superset.common.query_context import QueryContext
-from superset.connectors.connector_registry import ConnectorRegistry
 from superset.connectors.sqla.models import SqlaTable
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.mssql import MssqlEngineSpec
@@ -111,61 +108,6 @@ class CoreTests(SupersetTestCase):
         resp = self.client.get("/superset/slice/-1/")
         assert resp.status_code == 404
 
-    def _get_query_context(self) -> Dict[str, Any]:
-        self.login(username="admin")
-        slc = self.get_slice("Girl Name Cloud", db.session)
-        return {
-            "datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
-            "queries": [
-                {
-                    "granularity": "ds",
-                    "groupby": ["name"],
-                    "metrics": [{"label": "sum__num"}],
-                    "filters": [],
-                    "row_limit": 100,
-                }
-            ],
-        }
-
-    def _get_query_context_with_post_processing(self) -> Dict[str, Any]:
-        self.login(username="admin")
-        slc = self.get_slice("Girl Name Cloud", db.session)
-        return {
-            "datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
-            "queries": [
-                {
-                    "granularity": "ds",
-                    "groupby": ["name", "state"],
-                    "metrics": [{"label": "sum__num"}],
-                    "filters": [],
-                    "row_limit": 100,
-                    "post_processing": [
-                        {
-                            "operation": "aggregate",
-                            "options": {
-                                "groupby": ["state"],
-                                "aggregates": {
-                                    "q1": {
-                                        "operator": "percentile",
-                                        "column": "sum__num",
-                                        "options": {"q": 25},
-                                    },
-                                    "median": {
-                                        "operator": "median",
-                                        "column": "sum__num",
-                                    },
-                                },
-                            },
-                        },
-                        {
-                            "operation": "sort",
-                            "options": {"columns": {"q1": False, "state": True},},
-                        },
-                    ],
-                }
-            ],
-        }
-
     def test_viz_cache_key(self):
         self.login(username="admin")
         slc = self.get_slice("Girls", db.session)
@@ -178,45 +120,6 @@ class CoreTests(SupersetTestCase):
         qobj["groupby"] = []
         self.assertNotEqual(cache_key, viz.cache_key(qobj))
 
-    def test_cache_key_changes_when_datasource_is_updated(self):
-        qc_dict = self._get_query_context()
-
-        # construct baseline cache_key
-        query_context = QueryContext(**qc_dict)
-        query_object = query_context.queries[0]
-        cache_key_original = query_context.cache_key(query_object)
-
-        # make temporary change and revert it to refresh the changed_on property
-        datasource = ConnectorRegistry.get_datasource(
-            datasource_type=qc_dict["datasource"]["type"],
-            datasource_id=qc_dict["datasource"]["id"],
-            session=db.session,
-        )
-        description_original = datasource.description
-        datasource.description = "temporary description"
-        db.session.commit()
-        datasource.description = description_original
-        db.session.commit()
-
-        # create new QueryContext with unchanged attributes and extract new cache_key
-        query_context = QueryContext(**qc_dict)
-        query_object = query_context.queries[0]
-        cache_key_new = query_context.cache_key(query_object)
-
-        # the new cache_key should be different due to updated datasource
-        self.assertNotEqual(cache_key_original, cache_key_new)
-
-    def test_query_context_time_range_endpoints(self):
-        query_context = QueryContext(**self._get_query_context())
-        query_object = query_context.queries[0]
-        extras = query_object.to_dict()["extras"]
-        self.assertTrue("time_range_endpoints" in extras)
-
-        self.assertEquals(
-            extras["time_range_endpoints"],
-            (utils.TimeRangeEndpoint.INCLUSIVE, utils.TimeRangeEndpoint.EXCLUSIVE),
-        )
-
     def test_get_superset_tables_not_allowed(self):
         example_db = utils.get_example_database()
         schema_name = self.default_schema_backend_map[example_db.backend]
@@ -254,20 +157,6 @@ class CoreTests(SupersetTestCase):
         rv = self.client.get(uri)
         self.assertEqual(rv.status_code, 404)
 
-    def test_api_v1_query_endpoint(self):
-        self.login(username="admin")
-        qc_dict = self._get_query_context()
-        data = json.dumps(qc_dict)
-        resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
-        self.assertEqual(resp[0]["rowcount"], 100)
-
-    def test_api_v1_query_endpoint_with_post_processing(self):
-        self.login(username="admin")
-        qc_dict = self._get_query_context_with_post_processing()
-        data = json.dumps(qc_dict)
-        resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
-        self.assertEqual(resp[0]["rowcount"], 6)
-
     def test_old_slice_json_endpoint(self):
         self.login(username="admin")
         slc = self.get_slice("Girls", db.session)
@@ -165,7 +165,7 @@ class DictImportExportTests(SupersetTestCase):
         new_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
         imported_id = new_table.id
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
         self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
 
@@ -178,7 +178,7 @@ class DictImportExportTests(SupersetTestCase):
         )
         imported_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
-        imported = self.get_table(imported_table.id)
+        imported = self.get_table_by_id(imported_table.id)
         self.assert_table_equals(table, imported)
         self.assertEqual(
             {DBREF: ID_PREFIX + 2, "database_name": "main"}, json.loads(imported.params)
@@ -194,7 +194,7 @@ class DictImportExportTests(SupersetTestCase):
         )
         imported_table = SqlaTable.import_from_dict(db.session, dict_table)
         db.session.commit()
-        imported = self.get_table(imported_table.id)
+        imported = self.get_table_by_id(imported_table.id)
         self.assert_table_equals(table, imported)
         self.yaml_compare(table.export_to_dict(), imported.export_to_dict())
 
@@ -213,7 +213,7 @@ class DictImportExportTests(SupersetTestCase):
         imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over)
         db.session.commit()
 
-        imported_over = self.get_table(imported_over_table.id)
+        imported_over = self.get_table_by_id(imported_over_table.id)
         self.assertEqual(imported_table.id, imported_over.id)
         expected_table, _ = self.create_table(
             "table_override",
@@ -243,7 +243,7 @@ class DictImportExportTests(SupersetTestCase):
         )
         db.session.commit()
 
-        imported_over = self.get_table(imported_over_table.id)
+        imported_over = self.get_table_by_id(imported_over_table.id)
         self.assertEqual(imported_table.id, imported_over.id)
         expected_table, _ = self.create_table(
             "table_override",
@@ -274,7 +274,7 @@ class DictImportExportTests(SupersetTestCase):
         imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table)
         db.session.commit()
         self.assertEqual(imported_table.id, imported_copy_table.id)
-        self.assert_table_equals(copy_table, self.get_table(imported_table.id))
+        self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id))
         self.yaml_compare(
             imported_copy_table.export_to_dict(), imported_table.export_to_dict()
         )
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import copy
+from typing import Any, Dict, List
+
+QUERY_OBJECTS = {
+    "birth_names": {
+        "extras": {"where": "", "time_range_endpoints": ["INCLUSIVE", "EXCLUSIVE"],},
+        "granularity": "ds",
+        "groupby": ["name"],
+        "is_timeseries": False,
+        "metrics": [{"label": "sum__num"}],
+        "order_desc": True,
+        "orderby": [],
+        "row_limit": 100,
+        "time_range": "100 years ago : now",
+        "timeseries_limit": 0,
+        "timeseries_limit_metric": None,
+        "filters": [{"col": "gender", "op": "==", "val": "boy"}],
+        "having": "",
+        "having_filters": [],
+        "where": "",
+    }
+}
+
+POSTPROCESSING_OPERATIONS = {
+    "birth_names": [
+        {
+            "operation": "aggregate",
+            "options": {
+                "groupby": ["gender"],
+                "aggregates": {
+                    "q1": {
+                        "operator": "percentile",
+                        "column": "sum__num",
+                        "options": {"q": 25},
+                    },
+                    "median": {"operator": "median", "column": "sum__num",},
+                },
+            },
+        },
+        {"operation": "sort", "options": {"columns": {"q1": False, "gender": True},},},
+    ]
+}
+
+
+def _get_query_object(
+    datasource_name: str, add_postprocessing_operations: bool
+) -> Dict[str, Any]:
+    if datasource_name not in QUERY_OBJECTS:
+        raise Exception(
+            f"QueryObject fixture not defined for datasource: {datasource_name}"
+        )
+    query_object = copy.deepcopy(QUERY_OBJECTS[datasource_name])
+    if add_postprocessing_operations:
+        query_object["post_processing"] = _get_postprocessing_operation(datasource_name)
+    return query_object
+
+
+def _get_postprocessing_operation(datasource_name: str) -> List[Dict[str, Any]]:
+    if datasource_name not in QUERY_OBJECTS:
+        raise Exception(
+            f"Post-processing fixture not defined for datasource: {datasource_name}"
+        )
+    return copy.deepcopy(POSTPROCESSING_OPERATIONS[datasource_name])
+
+
+def get_query_context(
+    datasource_name: str = "birth_names",
+    datasource_id: int = 0,
+    datasource_type: str = "table",
+    add_postprocessing_operations: bool = False,
+) -> Dict[str, Any]:
+    """
+    Create a request payload for retrieving a QueryContext object via the
+    `api/v1/chart/data` endpoint. By default returns a payload corresponding to one
+    generated by the "Boy Name Cloud" chart in the examples.
+
+    :param datasource_name: name of datasource to query. Different datasources require
+        different parameters in the QueryContext.
+    :param datasource_id: id of datasource to query.
+    :param datasource_type: type of datasource to query.
+    :param add_postprocessing_operations: Add post-processing operations to QueryObject
+    :return: Request payload
+    """
+    return {
+        "datasource": {"id": datasource_id, "type": datasource_type},
+        "queries": [_get_query_object(datasource_name, add_postprocessing_operations)],
+    }
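Typical usage of the new fixture in a test; all helpers shown are the ones defined above:

from tests.fixtures.query_context import get_query_context

payload = get_query_context("birth_names", datasource_id=3, datasource_type="table")
assert payload["datasource"] == {"id": 3, "type": "table"}
assert payload["queries"][0]["row_limit"] == 100

# opt in to the post-processing variant used by the aggregation tests
payload = get_query_context("birth_names", add_postprocessing_operations=True)
assert payload["queries"][0]["post_processing"][0]["operation"] == "aggregate"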
@@ -558,7 +558,7 @@ class ImportExportTests(SupersetTestCase):
     def test_import_table_no_metadata(self):
         table = self.create_table("pure_table", id=10001)
         imported_id = SqlaTable.import_obj(table, import_time=1989)
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
 
     def test_import_table_1_col_1_met(self):
@@ -566,7 +566,7 @@ class ImportExportTests(SupersetTestCase):
             "table_1_col_1_met", id=10002, cols_names=["col1"], metric_names=["metric1"]
         )
         imported_id = SqlaTable.import_obj(table, import_time=1990)
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
         self.assertEqual(
             {"remote_id": 10002, "import_time": 1990, "database_name": "examples"},
@@ -582,7 +582,7 @@ class ImportExportTests(SupersetTestCase):
         )
         imported_id = SqlaTable.import_obj(table, import_time=1991)
 
-        imported = self.get_table(imported_id)
+        imported = self.get_table_by_id(imported_id)
         self.assert_table_equals(table, imported)
 
     def test_import_table_override(self):
@@ -599,7 +599,7 @@ class ImportExportTests(SupersetTestCase):
         )
         imported_over_id = SqlaTable.import_obj(table_over, import_time=1992)
 
-        imported_over = self.get_table(imported_over_id)
+        imported_over = self.get_table_by_id(imported_over_id)
         self.assertEqual(imported_id, imported_over.id)
         expected_table = self.create_table(
             "table_override",
@@ -627,7 +627,7 @@ class ImportExportTests(SupersetTestCase):
         imported_id_copy = SqlaTable.import_obj(copy_table, import_time=1994)
 
         self.assertEqual(imported_id, imported_id_copy)
-        self.assert_table_equals(copy_table, self.get_table(imported_id))
+        self.assert_table_equals(copy_table, self.get_table_by_id(imported_id))
 
     def test_import_druid_no_metadata(self):
         datasource = self.create_druid_datasource("pure_druid", id=10001)
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any, Dict, List, Optional
+
+from superset import db
+from superset.common.query_context import QueryContext
+from superset.connectors.connector_registry import ConnectorRegistry
+from superset.utils.core import TimeRangeEndpoint
+from tests.base_tests import SupersetTestCase
+from tests.fixtures.query_context import get_query_context
+from tests.test_app import app
+
+
+class QueryContextTests(SupersetTestCase):
+    def test_cache_key_changes_when_datasource_is_updated(self):
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+
+        # construct baseline cache_key
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_original = query_context.cache_key(query_object)
+
+        # make temporary change and revert it to refresh the changed_on property
+        datasource = ConnectorRegistry.get_datasource(
+            datasource_type=payload["datasource"]["type"],
+            datasource_id=payload["datasource"]["id"],
+            session=db.session,
+        )
+        description_original = datasource.description
+        datasource.description = "temporary description"
+        db.session.commit()
+        datasource.description = description_original
+        db.session.commit()
+
+        # create new QueryContext with unchanged attributes and extract new cache_key
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_new = query_context.cache_key(query_object)
+
+        # the new cache_key should be different due to updated datasource
+        self.assertNotEqual(cache_key_original, cache_key_new)
+
+    def test_query_context_time_range_endpoints(self):
+        """
+        Ensure that time_range_endpoints are populated automatically when missing
+        from the payload
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        del payload["queries"][0]["extras"]["time_range_endpoints"]
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        extras = query_object.to_dict()["extras"]
+        self.assertTrue("time_range_endpoints" in extras)
+
+        self.assertEquals(
+            extras["time_range_endpoints"],
+            (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
+        )
+
+    def test_convert_deprecated_fields(self):
+        """
+        Ensure that deprecated fields are converted correctly
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["granularity_sqla"] = "timecol"
+        payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
+        query_context = QueryContext(**payload)
+        self.assertEqual(len(query_context.queries), 1)
+        query_object = query_context.queries[0]
+        self.assertEqual(query_object.granularity, "timecol")
+        self.assertIn("having_druid", query_object.extras)