refactor: remove queryFields in QueryObject and update chart control configs (#12091)

* Clean up queryFields

* Clean up unused vars

* Bump chart plugins

* Bring in changes from #12147
This commit is contained in:
Jesse Yang 2020-12-22 17:10:19 -08:00 committed by GitHub
parent de61859e98
commit d2da25a621
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 328 additions and 2544 deletions

View File

@ -36,7 +36,7 @@ describe('Dashboard form data', () => {
});
});
it('should apply url params and queryFields to slice requests', () => {
it('should apply url params to slice requests', () => {
const aliases = getChartAliases(dashboard.slices);
// wait and verify one-by-one
cy.wait(aliases).then(requests => {
@ -48,7 +48,6 @@ describe('Dashboard form data', () => {
if (isLegacyResponse(responseBody)) {
const requestFormData = xhr.request.body;
const requestParams = JSON.parse(requestFormData.get('form_data'));
expect(requestParams).to.have.property('queryFields');
expect(requestParams.url_params).deep.eq(urlParams);
} else {
xhr.request.body.queries.forEach(query => {

File diff suppressed because it is too large Load Diff

View File

@ -65,35 +65,35 @@
"@babel/runtime-corejs3": "^7.8.4",
"@data-ui/sparkline": "^0.0.84",
"@emotion/core": "^10.0.35",
"@superset-ui/chart-controls": "^0.15.18",
"@superset-ui/core": "^0.15.19",
"@superset-ui/legacy-plugin-chart-calendar": "^0.15.18",
"@superset-ui/legacy-plugin-chart-chord": "^0.15.18",
"@superset-ui/legacy-plugin-chart-country-map": "^0.15.18",
"@superset-ui/legacy-plugin-chart-event-flow": "^0.15.18",
"@superset-ui/legacy-plugin-chart-force-directed": "^0.15.18",
"@superset-ui/legacy-plugin-chart-heatmap": "^0.15.18",
"@superset-ui/legacy-plugin-chart-histogram": "^0.15.18",
"@superset-ui/legacy-plugin-chart-horizon": "^0.15.18",
"@superset-ui/legacy-plugin-chart-map-box": "^0.15.18",
"@superset-ui/legacy-plugin-chart-paired-t-test": "^0.15.18",
"@superset-ui/legacy-plugin-chart-parallel-coordinates": "^0.15.18",
"@superset-ui/legacy-plugin-chart-partition": "^0.15.18",
"@superset-ui/legacy-plugin-chart-pivot-table": "^0.15.18",
"@superset-ui/legacy-plugin-chart-rose": "^0.15.18",
"@superset-ui/legacy-plugin-chart-sankey": "^0.15.18",
"@superset-ui/legacy-plugin-chart-sankey-loop": "^0.15.18",
"@superset-ui/legacy-plugin-chart-sunburst": "^0.15.18",
"@superset-ui/legacy-plugin-chart-treemap": "^0.15.18",
"@superset-ui/legacy-plugin-chart-world-map": "^0.15.18",
"@superset-ui/legacy-preset-chart-big-number": "^0.15.18",
"@superset-ui/chart-controls": "^0.16.1",
"@superset-ui/core": "^0.16.1",
"@superset-ui/legacy-plugin-chart-calendar": "^0.16.2",
"@superset-ui/legacy-plugin-chart-chord": "^0.16.2",
"@superset-ui/legacy-plugin-chart-country-map": "^0.16.1",
"@superset-ui/legacy-plugin-chart-event-flow": "^0.16.1",
"@superset-ui/legacy-plugin-chart-force-directed": "^0.16.2",
"@superset-ui/legacy-plugin-chart-heatmap": "^0.16.2",
"@superset-ui/legacy-plugin-chart-histogram": "^0.16.2",
"@superset-ui/legacy-plugin-chart-horizon": "^0.16.2",
"@superset-ui/legacy-plugin-chart-map-box": "^0.16.2",
"@superset-ui/legacy-plugin-chart-paired-t-test": "^0.16.1",
"@superset-ui/legacy-plugin-chart-parallel-coordinates": "^0.16.1",
"@superset-ui/legacy-plugin-chart-partition": "^0.16.2",
"@superset-ui/legacy-plugin-chart-pivot-table": "^0.16.2",
"@superset-ui/legacy-plugin-chart-rose": "^0.16.2",
"@superset-ui/legacy-plugin-chart-sankey": "^0.16.1",
"@superset-ui/legacy-plugin-chart-sankey-loop": "^0.16.2",
"@superset-ui/legacy-plugin-chart-sunburst": "^0.16.1",
"@superset-ui/legacy-plugin-chart-treemap": "^0.16.2",
"@superset-ui/legacy-plugin-chart-world-map": "^0.16.2",
"@superset-ui/legacy-preset-chart-big-number": "^0.16.1",
"@superset-ui/legacy-preset-chart-deckgl": "^0.3.2",
"@superset-ui/legacy-preset-chart-nvd3": "^0.15.18",
"@superset-ui/plugin-chart-echarts": "^0.15.18",
"@superset-ui/plugin-chart-table": "^0.15.18",
"@superset-ui/plugin-chart-word-cloud": "^0.15.18",
"@superset-ui/plugin-filter-antd": "^0.15.18",
"@superset-ui/preset-chart-xy": "^0.15.18",
"@superset-ui/legacy-preset-chart-nvd3": "^0.16.1",
"@superset-ui/plugin-chart-echarts": "^0.16.1",
"@superset-ui/plugin-chart-table": "^0.16.1",
"@superset-ui/plugin-chart-word-cloud": "^0.16.1",
"@superset-ui/plugin-filter-antd": "^0.16.1",
"@superset-ui/preset-chart-xy": "^0.16.1",
"@vx/responsive": "^0.0.195",
"abortcontroller-polyfill": "^1.1.9",
"antd": "^4.9.4",

View File

@ -91,6 +91,6 @@ describe('ControlPanelsContainer', () => {
it('renders ControlPanelSections', () => {
wrapper = shallow(<ControlPanelsContainer {...getDefaultProps()} />);
expect(wrapper.find(ControlPanelSection)).toHaveLength(6);
expect(wrapper.find(ControlPanelSection)).toHaveLength(5);
});
});

View File

@ -21,9 +21,7 @@ import { getChartControlPanelRegistry, t } from '@superset-ui/core';
import {
getControlConfig,
getControlState,
getFormDataFromControls,
applyMapStateToPropsToControl,
getAllControlsState,
findControlItem,
} from 'src/explore/controlUtils';
import {
@ -198,18 +196,6 @@ describe('controlUtils', () => {
});
});
describe('queryFields', () => {
it('in formData', () => {
const controlsState = getAllControlsState('table', 'table', {}, {});
const formData = getFormDataFromControls(controlsState);
expect(formData.queryFields).toEqual({
all_columns: 'columns',
metric: 'metrics',
metrics: 'metrics',
});
});
});
describe('findControlItem', () => {
it('find control as a string', () => {
const controlItem = findControlItem(

View File

@ -83,7 +83,6 @@ export const controlPanelSectionsChartOptionsTable = [
name: 'all_columns',
config: {
type: 'SelectControl',
queryField: 'columns',
multi: true,
label: t('Columns'),
default: [],

View File

@ -22,13 +22,10 @@ import { expandControlConfig } from '@superset-ui/chart-controls';
import * as SECTIONS from './controlPanels/sections';
export function getFormDataFromControls(controlsState) {
const formData = { queryFields: {} };
const formData = {};
Object.keys(controlsState).forEach(controlName => {
const control = controlsState[controlName];
formData[controlName] = control.value;
if (control.hasOwnProperty('queryField')) {
formData.queryFields[controlName] = control.queryField;
}
});
return formData;
}
@ -193,12 +190,15 @@ const getMemoizedSectionsToRender = memoizeOne(
}
});
const { datasourceAndVizType, sqlaTimeSeries, druidTimeSeries } = sections;
const timeSection =
datasourceType === 'table' ? sqlaTimeSeries : druidTimeSeries;
const { datasourceAndVizType } = sections;
// list of datasource-specific controls that should be removed
const invalidControls =
datasourceType === 'table'
? ['granularity', 'druid_time_origin']
: ['granularity_sqla', 'time_grain_sqla'];
return []
.concat(datasourceAndVizType, timeSection, controlPanelSections)
.concat(datasourceAndVizType, controlPanelSections)
.filter(section => !!section)
.map(section => {
const { controlSetRows } = section;
@ -206,7 +206,9 @@ const getMemoizedSectionsToRender = memoizeOne(
...section,
controlSetRows:
controlSetRows?.map(row =>
row.map(item => expandControlConfig(item, controlOverrides)),
row
.filter(control => !invalidControls.includes(control))
.map(item => expandControlConfig(item, controlOverrides)),
) || [],
};
});

View File

@ -118,7 +118,6 @@ const timeColumnOption = {
const groupByControl = {
type: 'SelectControl',
queryField: 'groupby',
multi: true,
freeForm: true,
label: t('Group by'),
@ -150,7 +149,6 @@ const groupByControl = {
const metrics = {
type: 'MetricsControl',
queryField: 'metrics',
multi: true,
label: t('Metrics'),
validators: [validateNonEmpty],

View File

@ -79,8 +79,6 @@ export function applyDefaultFormData(inputFormData) {
}
});
// always use dynamically generated queryFields
formData.queryFields = controlFormData.queryFields;
return formData;
}

View File

@ -109,6 +109,7 @@ const plugins = [
new ForkTsCheckerWebpackPlugin({
eslint: true,
checkSyntacticErrors: true,
memoryLimit: 4096,
}),
new CopyPlugin({

View File

@ -27,7 +27,13 @@ from pandas import DataFrame
from superset import app, is_feature_enabled
from superset.exceptions import QueryObjectValidationError
from superset.typing import Metric
from superset.utils import core as utils, pandas_postprocessing
from superset.utils import pandas_postprocessing
from superset.utils.core import (
DTTM_ALIAS,
get_since_until,
json_int_dttm_ser,
parse_human_timedelta,
)
from superset.views.utils import get_time_range_endpoints
config = app.config
@ -90,7 +96,7 @@ class QueryObject:
filters: Optional[List[Dict[str, Any]]] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
is_timeseries: bool = False,
is_timeseries: Optional[bool] = None,
timeseries_limit: int = 0,
row_limit: Optional[int] = None,
row_offset: Optional[int] = None,
@ -114,7 +120,7 @@ class QueryObject:
]
self.applied_time_extras = applied_time_extras or {}
self.granularity = granularity
self.from_dttm, self.to_dttm = utils.get_since_until(
self.from_dttm, self.to_dttm = get_since_until(
relative_start=extras.get(
"relative_start", config["DEFAULT_RELATIVE_START_TIME"]
),
@ -124,20 +130,28 @@ class QueryObject:
time_range=time_range,
time_shift=time_shift,
)
self.is_timeseries = is_timeseries
# is_timeseries is True if time column is in groupby
self.is_timeseries = (
is_timeseries
if is_timeseries is not None
else (DTTM_ALIAS in groupby if groupby else False)
)
self.time_range = time_range
self.time_shift = utils.parse_human_timedelta(time_shift)
self.time_shift = parse_human_timedelta(time_shift)
self.post_processing = [
post_proc for post_proc in post_processing or [] if post_proc
]
if not is_sip_38:
self.groupby = groupby or []
# Temporary solution for backward compatibility issue due the new format of
# non-ad-hoc metric which needs to adhere to superset-ui per
# https://git.io/Jvm7P.
# Support metric reference/definition in the format of
# 1. 'metric_name' - name of predefined metric
# 2. { label: 'label_name' } - legacy format for a predefined metric
# 3. { expressionType: 'SIMPLE' | 'SQL', ... } - adhoc metric
self.metrics = [
metric if "expressionType" in metric else metric["label"] # type: ignore
metric
if isinstance(metric, str) or "expressionType" in metric
else metric["label"] # type: ignore
for metric in metrics
]
@ -267,7 +281,7 @@ class QueryObject:
@staticmethod
def json_dumps(obj: Any, sort_keys: bool = False) -> str:
return json.dumps(
obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
obj, default=json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
)
def exec_post_processing(self, df: DataFrame) -> DataFrame:

View File

@ -95,6 +95,10 @@ def post_assert_metric(
return rv
def get_table_by_name(name: str) -> SqlaTable:
return db.session.query(SqlaTable).filter_by(table_name=name).one()
@pytest.fixture
def logged_in_admin():
"""Fixture with app context and logged in admin user."""
@ -228,7 +232,7 @@ class SupersetTestCase(TestCase):
@staticmethod
def get_table_by_name(name: str) -> SqlaTable:
return db.session.query(SqlaTable).filter_by(table_name=name).one()
return get_table_by_name(name)
@staticmethod
def get_database_by_id(db_id: int) -> Database:

View File

@ -1033,8 +1033,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data query
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
@ -1045,8 +1044,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data query with applied time extras
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["applied_time_extras"] = {
"__time_range": "100 years ago : now",
"__time_origin": "now",
@ -1069,8 +1067,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data query with limit and offset
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["row_limit"] = 5
request_payload["queries"][0]["row_offset"] = 0
request_payload["queries"][0]["orderby"] = [["name", True]]
@ -1101,8 +1098,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Ensure row count doesn't exceed default limit
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
del request_payload["queries"][0]["row_limit"]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
@ -1117,8 +1113,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Ensure sample response row count doesn't exceed default limit
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
request_payload["queries"][0]["row_limit"] = 10
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
@ -1131,8 +1126,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data with unsupported result type
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["result_type"] = "qwerty"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 400)
@ -1142,8 +1136,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data with unsupported result format
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["result_format"] = "qwerty"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 400)
@ -1153,8 +1146,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data with query result format
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.QUERY
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
@ -1164,8 +1156,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data with CSV result format
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["result_format"] = "csv"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
@ -1175,8 +1166,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Ensure mixed case filter operator generates valid result
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["filters"][0]["op"] = "In"
request_payload["queries"][0]["row_limit"] = 10
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
@ -1190,8 +1180,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
"""
pytest.importorskip("fbprophet")
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
time_grain = "P1Y"
request_payload["queries"][0]["is_timeseries"] = True
request_payload["queries"][0]["groupby"] = []
@ -1224,8 +1213,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Ensure filter referencing missing column is ignored
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["filters"] = [
{"col": "non_existent_filter", "op": "==", "val": "foo"},
]
@ -1240,8 +1228,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data with empty result
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["filters"] = [
{"col": "gender", "op": "==", "val": "foo"}
]
@ -1257,8 +1244,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data with invalid SQL
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["filters"] = []
# erroneous WHERE-clause
request_payload["queries"][0]["extras"]["where"] = "(gender abc def)"
@ -1270,8 +1256,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data query with invalid schema
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["datasource"] = "abc"
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
self.assertEqual(rv.status_code, 400)
@ -1281,8 +1266,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data query with invalid enum value
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["queries"][0]["extras"]["time_range_endpoints"] = [
"abc",
"EXCLUSIVE",
@ -1295,8 +1279,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data query not allowed
"""
self.login(username="gamma")
table = self.get_table_by_name("birth_names")
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
rv = self.post_assert_metric(CHART_DATA_URI, payload, "data")
self.assertEqual(rv.status_code, 401)
@ -1305,8 +1288,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Ensure request referencing filters via jinja renders a correct query
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["result_type"] = utils.ChartDataResultType.QUERY
request_payload["queries"][0]["filters"] = [
{"col": "gender", "op": "==", "val": "boy"}
@ -1330,8 +1312,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
"""
async_query_manager.init_app(app)
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 202)
data = json.loads(rv.data.decode("utf-8"))
@ -1350,8 +1331,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
"""
async_query_manager.init_app(app)
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
request_payload["result_type"] = "results"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
@ -1366,8 +1346,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
"""
async_query_manager.init_app(app)
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
test_client.set_cookie(
"localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo"
)
@ -1385,8 +1364,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
"""
async_query_manager.init_app(app)
self.login(username="admin")
table = self.get_table_by_name("birth_names")
query_context = get_query_context(table.name, table.id, table.type)
query_context = get_query_context("birth_names")
load_qc_mock.return_value = query_context
orig_run = ChartDataCommand.run
@ -1415,8 +1393,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
"""
async_query_manager.init_app(app)
self.login(username="admin")
table = self.get_table_by_name("birth_names")
query_context = get_query_context(table.name, table.id, table.type)
query_context = get_query_context("birth_names")
load_qc_mock.return_value = query_context
rv = self.get_assert_metric(
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
@ -1436,8 +1413,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data cache API: Test chart data async cache request (no login)
"""
async_query_manager.init_app(app)
table = self.get_table_by_name("birth_names")
query_context = get_query_context(table.name, table.id, table.type)
query_context = get_query_context("birth_names")
load_qc_mock.return_value = query_context
orig_run = ChartDataCommand.run
@ -1639,8 +1615,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
Chart data API: Test chart data query
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload = get_query_context("birth_names")
annotation_layers = []
request_payload["queries"][0]["annotation_layers"] = annotation_layers

View File

@ -29,9 +29,7 @@ from tests.fixtures.query_context import get_query_context
class TestSchema(SupersetTestCase):
def test_query_context_limit_and_offset(self):
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
# Use defaults
payload["queries"][0].pop("row_limit", None)
@ -59,17 +57,13 @@ class TestSchema(SupersetTestCase):
def test_query_context_null_timegrain(self):
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["queries"][0]["extras"]["time_grain_sqla"] = None
_ = ChartDataQueryContextSchema().load(payload)
def test_query_context_series_limit(self):
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["queries"][0]["timeseries_limit"] = 2
payload["queries"][0]["timeseries_limit_metric"] = {
@ -90,9 +84,7 @@ class TestSchema(SupersetTestCase):
def test_query_context_null_post_processing_op(self):
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["queries"][0]["post_processing"] = [None]
query_context = ChartDataQueryContextSchema().load(payload)

View File

@ -848,11 +848,6 @@ class TestCore(SupersetTestCase):
def test_explore_json(self):
tbl_id = self.table_ids.get("birth_names")
form_data = {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{tbl_id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],
@ -879,11 +874,6 @@ class TestCore(SupersetTestCase):
def test_explore_json_async(self):
tbl_id = self.table_ids.get("birth_names")
form_data = {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{tbl_id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],
@ -914,11 +904,6 @@ class TestCore(SupersetTestCase):
def test_explore_json_async_results_format(self):
tbl_id = self.table_ids.get("birth_names")
form_data = {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{tbl_id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],
@ -947,11 +932,6 @@ class TestCore(SupersetTestCase):
form_data = dict(
{
"form_data": {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{tbl_id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],
@ -991,11 +971,6 @@ class TestCore(SupersetTestCase):
form_data = dict(
{
"form_data": {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{tbl_id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],

View File

@ -77,7 +77,7 @@ dashboard_export: Dict[str, Any] = {
"datasource_name": "birth_names_2",
"datasource_type": "table",
"id": 83,
"params": '{"adhoc_filters": [], "datasource": "3__table", "granularity_sqla": "ds", "header_font_size": 0.4, "metric": {"aggregate": "SUM", "column": {"column_name": "num_california", "expression": "CASE WHEN state = \'CA\' THEN num ELSE 0 END"}, "expressionType": "SIMPLE", "label": "SUM(num_california)"}, "queryFields": {"metric": "metrics"}, "slice_id": 83, "subheader_font_size": 0.15, "time_range": "100 years ago : now", "time_range_endpoints": ["unknown", "inclusive"], "url_params": {}, "viz_type": "big_number_total", "y_axis_format": "SMART_NUMBER", "remote_id": 83, "datasource_name": "birth_names_2", "schema": null, "database_name": "examples"}',
"params": '{"adhoc_filters": [], "datasource": "3__table", "granularity_sqla": "ds", "header_font_size": 0.4, "metric": {"aggregate": "SUM", "column": {"column_name": "num_california", "expression": "CASE WHEN state = \'CA\' THEN num ELSE 0 END"}, "expressionType": "SIMPLE", "label": "SUM(num_california)"}, "slice_id": 83, "subheader_font_size": 0.15, "time_range": "100 years ago : now", "time_range_endpoints": ["unknown", "inclusive"], "url_params": {}, "viz_type": "big_number_total", "y_axis_format": "SMART_NUMBER", "remote_id": 83, "datasource_name": "birth_names_2", "schema": null, "database_name": "examples"}',
"slice_name": "Number of California Births",
"viz_type": "big_number_total",
}

View File

@ -17,26 +17,29 @@
import copy
from typing import Any, Dict, List
from superset.utils.core import AnnotationType
from superset.utils.core import AnnotationType, DTTM_ALIAS
from tests.base_tests import get_table_by_name
QUERY_OBJECTS = {
"birth_names": {
"extras": {"where": "", "time_range_endpoints": ["inclusive", "exclusive"]},
"granularity": "ds",
"groupby": ["name"],
"is_timeseries": False,
"metrics": [{"label": "sum__num"}],
"order_desc": True,
"orderby": [["sum__num", False]],
"row_limit": 100,
"time_range": "100 years ago : now",
"timeseries_limit": 0,
"timeseries_limit_metric": None,
"filters": [{"col": "gender", "op": "==", "val": "boy"}],
"having": "",
"having_filters": [],
"where": "",
}
query_birth_names = {
"extras": {"where": "", "time_range_endpoints": ["inclusive", "exclusive"]},
"granularity": "ds",
"groupby": ["name"],
"metrics": [{"label": "sum__num"}],
"order_desc": True,
"orderby": [["sum__num", False]],
"row_limit": 100,
"time_range": "100 years ago : now",
"timeseries_limit": 0,
"timeseries_limit_metric": None,
"filters": [{"col": "gender", "op": "==", "val": "boy"}],
"having": "",
"having_filters": [],
"where": "",
}
QUERY_OBJECTS: Dict[str, Dict[str, object]] = {
"birth_names": {**query_birth_names, "is_timeseries": False,},
"birth_names:include_time": {**query_birth_names, "groupby": [DTTM_ALIAS, "name"],},
}
ANNOTATION_LAYERS = {
@ -131,46 +134,43 @@ POSTPROCESSING_OPERATIONS = {
}
def _get_query_object(
datasource_name: str, add_postprocessing_operations: bool
def get_query_object(
query_name: str, add_postprocessing_operations: bool
) -> Dict[str, Any]:
if datasource_name not in QUERY_OBJECTS:
raise Exception(
f"QueryObject fixture not defined for datasource: {datasource_name}"
)
query_object = copy.deepcopy(QUERY_OBJECTS[datasource_name])
if query_name not in QUERY_OBJECTS:
raise Exception(f"QueryObject fixture not defined for datasource: {query_name}")
query_object = copy.deepcopy(QUERY_OBJECTS[query_name])
if add_postprocessing_operations:
query_object["post_processing"] = _get_postprocessing_operation(datasource_name)
query_object["post_processing"] = _get_postprocessing_operation(query_name)
return query_object
def _get_postprocessing_operation(datasource_name: str) -> List[Dict[str, Any]]:
if datasource_name not in QUERY_OBJECTS:
def _get_postprocessing_operation(query_name: str) -> List[Dict[str, Any]]:
if query_name not in QUERY_OBJECTS:
raise Exception(
f"Post-processing fixture not defined for datasource: {datasource_name}"
f"Post-processing fixture not defined for datasource: {query_name}"
)
return copy.deepcopy(POSTPROCESSING_OPERATIONS[datasource_name])
return copy.deepcopy(POSTPROCESSING_OPERATIONS[query_name])
def get_query_context(
datasource_name: str = "birth_names",
datasource_id: int = 0,
datasource_type: str = "table",
add_postprocessing_operations: bool = False,
query_name: str, add_postprocessing_operations: bool = False,
) -> Dict[str, Any]:
"""
Create a request payload for retrieving a QueryContext object via the
`api/v1/chart/data` endpoint. By default returns a payload corresponding to one
generated by the "Boy Name Cloud" chart in the examples.
:param datasource_name: name of datasource to query. Different datasources require
different parameters in the QueryContext.
:param query_name: name of an example query, which is always in the format
of `datasource_name[:test_case_name]`, where `:test_case_name` is optional.
:param datasource_id: id of datasource to query.
:param datasource_type: type of datasource to query.
:param add_postprocessing_operations: Add post-processing operations to QueryObject
:return: Request payload
"""
table_name = query_name.split(":")[0]
table = get_table_by_name(table_name)
return {
"datasource": {"id": datasource_id, "type": datasource_type},
"queries": [_get_query_object(datasource_name, add_postprocessing_operations)],
"datasource": {"id": table.id, "type": table.type},
"queries": [get_query_object(query_name, add_postprocessing_operations)],
}

View File

@ -14,13 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tests.test_app
from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.connectors.connector_registry import ConnectorRegistry
from superset.models.cache import CacheKey
from superset.utils.core import (
AdhocMetricExpressionType,
ChartDataResultFormat,
@ -29,7 +25,6 @@ from superset.utils.core import (
TimeRangeEndpoint,
)
from tests.base_tests import SupersetTestCase
from tests.fixtures.energy_dashboard import load_energy_table_with_slice
from tests.fixtures.query_context import get_query_context
@ -39,13 +34,10 @@ class TestQueryContext(SupersetTestCase):
Ensure that the deserialized QueryContext contains all required fields.
"""
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(
table.name, table.id, table.type, add_postprocessing_operations=True
)
payload = get_query_context("birth_names", add_postprocessing_operations=True)
query_context = ChartDataQueryContextSchema().load(payload)
self.assertEqual(len(query_context.queries), len(payload["queries"]))
for query_idx, query in enumerate(query_context.queries):
payload_query = payload["queries"][query_idx]
@ -75,9 +67,7 @@ class TestQueryContext(SupersetTestCase):
def test_cache_key_changes_when_datasource_is_updated(self):
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
# construct baseline cache_key
query_context = ChartDataQueryContextSchema().load(payload)
@ -106,11 +96,7 @@ class TestQueryContext(SupersetTestCase):
def test_cache_key_changes_when_post_processing_is_updated(self):
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(
table.name, table.id, table.type, add_postprocessing_operations=True
)
payload = get_query_context("birth_names", add_postprocessing_operations=True)
# construct baseline cache_key from query_context with post processing operation
query_context = ChartDataQueryContextSchema().load(payload)
@ -121,43 +107,57 @@ class TestQueryContext(SupersetTestCase):
payload["queries"][0]["post_processing"].append(None)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_with_null = query_context.query_cache_key(query_object)
self.assertEqual(cache_key_original, cache_key_with_null)
cache_key = query_context.query_cache_key(query_object)
self.assertEqual(cache_key_original, cache_key)
# ensure query without post processing operation is different
payload["queries"][0].pop("post_processing")
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_without_post_processing = query_context.query_cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
cache_key = query_context.query_cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key)
def test_query_context_time_range_endpoints(self):
"""
Ensure that time_range_endpoints are populated automatically when missing
from the payload
from the payload.
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
del payload["queries"][0]["extras"]["time_range_endpoints"]
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
extras = query_object.to_dict()["extras"]
self.assertTrue("time_range_endpoints" in extras)
assert "time_range_endpoints" in extras
self.assertEqual(
extras["time_range_endpoints"],
(TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
)
def test_handle_metrics_field(self):
"""
Should support both predefined and adhoc metrics.
"""
self.login(username="admin")
adhoc_metric = {
"expressionType": "SIMPLE",
"column": {"column_name": "sum_boys", "type": "BIGINT(20)"},
"aggregate": "SUM",
"label": "Boys",
"optionName": "metric_11",
}
payload = get_query_context("birth_names")
payload["queries"][0]["metrics"] = ["sum__num", {"label": "abc"}, adhoc_metric]
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
self.assertEqual(query_object.metrics, ["sum__num", "abc", adhoc_metric])
def test_convert_deprecated_fields(self):
"""
Ensure that deprecated fields are converted correctly
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["queries"][0]["granularity_sqla"] = "timecol"
payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
query_context = ChartDataQueryContextSchema().load(payload)
@ -171,9 +171,7 @@ class TestQueryContext(SupersetTestCase):
Ensure that CSV result format works
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["result_format"] = ChartDataResultFormat.CSV.value
payload["queries"][0]["row_limit"] = 10
query_context = ChartDataQueryContextSchema().load(payload)
@ -188,9 +186,7 @@ class TestQueryContext(SupersetTestCase):
Ensure that calling invalid columns names in groupby are caught
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["queries"][0]["groupby"] = ["currentDatabase()"]
query_context = ChartDataQueryContextSchema().load(payload)
query_payload = query_context.get_payload()
@ -201,9 +197,7 @@ class TestQueryContext(SupersetTestCase):
Ensure that calling invalid column names in columns are caught
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["queries"][0]["groupby"] = []
payload["queries"][0]["metrics"] = []
payload["queries"][0]["columns"] = ["*, 'extra'"]
@ -216,9 +210,7 @@ class TestQueryContext(SupersetTestCase):
Ensure that calling invalid column names in filters are caught
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["queries"][0]["groupby"] = ["name"]
payload["queries"][0]["metrics"] = [
{
@ -237,9 +229,7 @@ class TestQueryContext(SupersetTestCase):
Ensure that samples result type works
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["result_type"] = ChartDataResultType.SAMPLES.value
payload["queries"][0]["row_limit"] = 5
query_context = ChartDataQueryContextSchema().load(payload)
@ -255,9 +245,7 @@ class TestQueryContext(SupersetTestCase):
Ensure that query result type works
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
payload["result_type"] = ChartDataResultType.QUERY.value
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
@ -274,9 +262,7 @@ class TestQueryContext(SupersetTestCase):
"""
self.maxDiff = None
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload = get_query_context("birth_names")
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
orig_cache_key = responses["queries"][0]["cache_key"]

View File

@ -45,8 +45,7 @@ class TestAsyncQueries(SupersetTestCase):
@mock.patch.object(async_query_manager, "update_job")
def test_load_chart_data_into_cache(self, mock_update_job):
async_query_manager.init_app(app)
table = get_table_by_name("birth_names")
form_data = get_query_context(table.name, table.id, table.type)
query_context = get_query_context("birth_names")
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
@ -55,7 +54,7 @@ class TestAsyncQueries(SupersetTestCase):
"errors": [],
}
load_chart_data_into_cache(job_metadata, form_data)
load_chart_data_into_cache(job_metadata, query_context)
mock_update_job.assert_called_with(job_metadata, "done", result_url=mock.ANY)
@ -65,8 +64,7 @@ class TestAsyncQueries(SupersetTestCase):
@mock.patch.object(async_query_manager, "update_job")
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
async_query_manager.init_app(app)
table = get_table_by_name("birth_names")
form_data = get_query_context(table.name, table.id, table.type)
query_context = get_query_context("birth_names")
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
@ -75,7 +73,7 @@ class TestAsyncQueries(SupersetTestCase):
"errors": [],
}
with pytest.raises(ChartDataQueryFailedError):
load_chart_data_into_cache(job_metadata, form_data)
load_chart_data_into_cache(job_metadata, query_context)
mock_run_command.assert_called_with(cache=True)
errors = [{"message": "Error: foo"}]
@ -86,11 +84,6 @@ class TestAsyncQueries(SupersetTestCase):
async_query_manager.init_app(app)
table = get_table_by_name("birth_names")
form_data = {
"queryFields": {
"metrics": "metrics",
"groupby": "groupby",
"columns": "groupby",
},
"datasource": f"{table.id}__table",
"viz_type": "dist_bar",
"time_range_endpoints": ["inclusive", "exclusive"],