fix(chart): Supporting custom SQL as temporal x-axis column with filter (#25126)

Co-authored-by: Kamil Gabryjelski <kamil.gabryjelski@gmail.com>
This commit is contained in:
Zef Lin 2023-09-18 11:30:52 -07:00 committed by GitHub
parent e11012d426
commit c8c94825ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 90 additions and 16 deletions

View File

@ -22,20 +22,19 @@ import exploreReducer from 'src/explore/reducers/exploreReducer';
import * as actions from 'src/explore/actions/exploreActions';
describe('reducers', () => {
it('sets correct control value given an arbitrary key and value', () => {
it('Does not set a control value if control does not exist', () => {
const newState = exploreReducer(
defaultState,
actions.setControlValue('NEW_FIELD', 'x', []),
);
expect(newState.controls.NEW_FIELD.value).toBe('x');
expect(newState.form_data.NEW_FIELD).toBe('x');
expect(newState.controls.NEW_FIELD).toBeUndefined();
});
it('setControlValue works as expected with a checkbox', () => {
it('setControlValue works as expected with a Select control', () => {
const newState = exploreReducer(
defaultState,
actions.setControlValue('show_legend', true, []),
actions.setControlValue('y_axis_format', '$,.2f', []),
);
expect(newState.controls.show_legend.value).toBe(true);
expect(newState.form_data.show_legend).toBe(true);
expect(newState.controls.y_axis_format.value).toBe('$,.2f');
expect(newState.form_data.y_axis_format).toBe('$,.2f');
});
});

View File

@ -112,7 +112,7 @@ export default function exploreReducer(state = {}, action) {
const vizType = new_form_data.viz_type;
// if the controlName is metrics, and the metric column name is updated,
// need to update column config as well to keep the previou config.
// need to update column config as well to keep the previous config.
if (controlName === 'metrics' && old_metrics_data && new_column_config) {
value.forEach((item, index) => {
if (
@ -129,11 +129,11 @@ export default function exploreReducer(state = {}, action) {
}
// Use the processed control config (with overrides and everything)
// if `controlName` does not existing in current controls,
// if `controlName` does not exist in current controls,
const controlConfig =
state.controls[action.controlName] ||
getControlConfig(action.controlName, vizType) ||
{};
null;
// will call validators again
const control = {
@ -149,7 +149,7 @@ export default function exploreReducer(state = {}, action) {
...state,
controls: {
...state.controls,
[controlName]: control,
...(controlConfig && { [controlName]: control }),
...(controlName === 'metrics' && { column_config }),
},
};
@ -196,10 +196,12 @@ export default function exploreReducer(state = {}, action) {
triggerRender: control.renderTrigger && !hasErrors,
controls: {
...currentControlsState,
[action.controlName]: {
...control,
validationErrors: errors,
},
...(controlConfig && {
[action.controlName]: {
...control,
validationErrors: errors,
},
}),
...rerenderedControls,
},
};

View File

@ -26,7 +26,7 @@ from superset.common.query_object_factory import QueryObjectFactory
from superset.daos.chart import ChartDAO
from superset.daos.datasource import DatasourceDAO
from superset.models.slice import Slice
from superset.utils.core import DatasourceDict, DatasourceType
from superset.utils.core import DatasourceDict, DatasourceType, is_adhoc_column
if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
@ -128,6 +128,8 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
if granularity := query_object.granularity:
filter_to_remove = None
if is_adhoc_column(x_axis): # type: ignore
x_axis = x_axis.get("sqlExpression")
if x_axis and x_axis in temporal_columns:
filter_to_remove = x_axis
x_axis_column = next(
@ -175,6 +177,9 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
# another temporal filter. A new filter based on the value of
# the granularity will be added later in the code.
# In practice, this is replacing the previous default temporal filter.
if is_adhoc_column(filter_to_remove): # type: ignore
filter_to_remove = filter_to_remove.get("sqlExpression")
if filter_to_remove:
query_object.filter = [
filter

View File

@ -1011,6 +1011,8 @@ class SqlaTable(
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
sql = self.database.compile_sqla_query(qry)
col_desc = get_columns_description(self.database, self.schema, sql)
if not col_desc:
raise SupersetGenericDBErrorException("Column not found")
is_dttm = col_desc[0]["is_dttm"] # type: ignore
except SupersetGenericDBErrorException as ex:
raise ColumnNotFoundException(message=str(ex)) from ex

View File

@ -51,6 +51,7 @@ from superset.models.slice import Slice
from superset.superset_typing import AdhocColumn
from superset.utils.core import (
AnnotationType,
backend,
get_example_default_schema,
AdhocMetricExpressionType,
ExtraFiltersReasonType,
@ -943,6 +944,71 @@ class TestGetChartDataApi(BaseTestChartDataApi):
assert data["result"][0]["status"] == "success"
assert data["result"][0]["rowcount"] == 2
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_get_with_x_axis_using_custom_sql(self):
"""
Chart data API: Test GET endpoint
"""
chart = db.session.query(Slice).filter_by(slice_name="Genders").one()
chart.query_context = json.dumps(
{
"datasource": {"id": chart.table.id, "type": "table"},
"force": False,
"queries": [
{
"time_range": "1900-01-01T00:00:00 : 2000-01-01T00:00:00",
"granularity": "ds",
"filters": [
{"col": "ds", "op": "TEMPORAL_RANGE", "val": "No filter"}
],
"extras": {
"having": "",
"where": "",
},
"applied_time_extras": {},
"columns": [
{
"columnType": "BASE_AXIS",
"datasourceWarning": False,
"expressionType": "SQL",
"label": "My column",
"sqlExpression": "ds",
"timeGrain": "P1W",
}
],
"metrics": ["sum__num"],
"orderby": [["sum__num", False]],
"annotation_layers": [],
"row_limit": 50000,
"timeseries_limit": 0,
"order_desc": True,
"url_params": {},
"custom_params": {},
"custom_form_data": {},
}
],
"form_data": {
"x_axis": {
"datasourceWarning": False,
"expressionType": "SQL",
"label": "My column",
"sqlExpression": "ds",
}
},
"result_format": "json",
"result_type": "full",
}
)
rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")
assert rv.mimetype == "application/json"
data = json.loads(rv.data.decode("utf-8"))
assert data["result"][0]["status"] == "success"
if backend() == "presto":
assert data["result"][0]["rowcount"] == 41
else:
assert data["result"][0]["rowcount"] == 40
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_get_forced(self):
"""