fix(dashboard): Charts crashing when cross filter on adhoc column is applied (#23238)

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
This commit is contained in:
Kamil Gabryjelski 2023-03-04 07:57:35 +01:00 committed by GitHub
parent 006f3dd88c
commit 42980a69a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 143 additions and 44 deletions

View File

@ -23,6 +23,7 @@ import {
FeatureFlag,
Filters,
FilterState,
getColumnLabel,
isFeatureEnabled,
NativeFilterType,
NO_TIME_RANGE,
@ -146,8 +147,8 @@ const getAppliedColumns = (chart: any): Set<string> =>
const getRejectedColumns = (chart: any): Set<string> =>
new Set(
(chart?.queriesResponse?.[0]?.rejected_filters || []).map(
(filter: any) => filter.column,
(chart?.queriesResponse?.[0]?.rejected_filters || []).map((filter: any) =>
getColumnLabel(filter.column),
),
);

View File

@ -17,7 +17,7 @@
from __future__ import annotations
import copy
from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
from flask_babel import _
@ -32,7 +32,6 @@ from superset.utils.core import (
ExtraFiltersReasonType,
get_column_name,
get_time_filter_status,
is_adhoc_column,
)
if TYPE_CHECKING:
@ -102,7 +101,6 @@ def _get_full(
datasource = _get_datasource(query_context, query_obj)
result_type = query_obj.result_type or query_context.result_type
payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
applied_template_filters = payload.get("applied_template_filters", [])
df = payload["df"]
status = payload["status"]
if status != QueryStatus.FAILED:
@ -113,23 +111,23 @@ def _get_full(
payload["result_format"] = query_context.result_format
del payload["df"]
filters = query_obj.filter
filter_columns = cast(List[str], [flt.get("col") for flt in filters])
columns = set(datasource.column_names)
applied_time_columns, rejected_time_columns = get_time_filter_status(
datasource, query_obj.applied_time_extras
)
applied_filter_columns = payload.get("applied_filter_columns", [])
rejected_filter_columns = payload.get("rejected_filter_columns", [])
del payload["applied_filter_columns"]
del payload["rejected_filter_columns"]
payload["applied_filters"] = [
{"column": get_column_name(col)}
for col in filter_columns
if is_adhoc_column(col) or col in columns or col in applied_template_filters
{"column": get_column_name(col)} for col in applied_filter_columns
] + applied_time_columns
payload["rejected_filters"] = [
{"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col}
for col in filter_columns
if not is_adhoc_column(col)
and col not in columns
and col not in applied_template_filters
{
"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE,
"column": get_column_name(col),
}
for col in rejected_filter_columns
] + rejected_time_columns
if result_type == ChartDataResultType.RESULTS and status != QueryStatus.FAILED:

View File

@ -165,6 +165,8 @@ class QueryContextProcessor:
"cache_timeout": self.get_cache_timeout(),
"df": cache.df,
"applied_template_filters": cache.applied_template_filters,
"applied_filter_columns": cache.applied_filter_columns,
"rejected_filter_columns": cache.rejected_filter_columns,
"annotation_data": cache.annotation_data,
"error": cache.error_message,
"is_cached": cache.is_cached,

View File

@ -29,6 +29,7 @@ from superset.exceptions import CacheLoadError
from superset.extensions import cache_manager
from superset.models.helpers import QueryResult
from superset.stats_logger import BaseStatsLogger
from superset.superset_typing import Column
from superset.utils.cache import set_and_log_cache
from superset.utils.core import error_msg_from_exception, get_stacktrace
@ -54,6 +55,8 @@ class QueryCacheManager:
query: str = "",
annotation_data: Optional[Dict[str, Any]] = None,
applied_template_filters: Optional[List[str]] = None,
applied_filter_columns: Optional[List[Column]] = None,
rejected_filter_columns: Optional[List[Column]] = None,
status: Optional[str] = None,
error_message: Optional[str] = None,
is_loaded: bool = False,
@ -66,6 +69,8 @@ class QueryCacheManager:
self.query = query
self.annotation_data = {} if annotation_data is None else annotation_data
self.applied_template_filters = applied_template_filters or []
self.applied_filter_columns = applied_filter_columns or []
self.rejected_filter_columns = rejected_filter_columns or []
self.status = status
self.error_message = error_message
@ -93,6 +98,8 @@ class QueryCacheManager:
self.status = query_result.status
self.query = query_result.query
self.applied_template_filters = query_result.applied_template_filters
self.applied_filter_columns = query_result.applied_filter_columns
self.rejected_filter_columns = query_result.rejected_filter_columns
self.error_message = query_result.error_message
self.df = query_result.df
self.annotation_data = {} if annotation_data is None else annotation_data
@ -107,6 +114,8 @@ class QueryCacheManager:
"df": self.df,
"query": self.query,
"applied_template_filters": self.applied_template_filters,
"applied_filter_columns": self.applied_filter_columns,
"rejected_filter_columns": self.rejected_filter_columns,
"annotation_data": self.annotation_data,
}
if self.is_loaded and key and self.status != QueryStatus.FAILED:
@ -150,6 +159,12 @@ class QueryCacheManager:
query_cache.applied_template_filters = cache_value.get(
"applied_template_filters", []
)
query_cache.applied_filter_columns = cache_value.get(
"applied_filter_columns", []
)
query_cache.rejected_filter_columns = cache_value.get(
"rejected_filter_columns", []
)
query_cache.status = QueryStatus.SUCCESS
query_cache.is_loaded = True
query_cache.is_cached = cache_value is not None

View File

@ -99,9 +99,11 @@ from superset.datasets.models import Dataset as NewDataset
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.exceptions import (
AdvancedDataTypeResponseError,
ColumnNotFoundException,
DatasetInvalidPermissionEvaluationException,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetGenericDBErrorException,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
@ -150,6 +152,8 @@ ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES}
class SqlaQuery(NamedTuple):
applied_template_filters: List[str]
applied_filter_columns: List[ColumnTyping]
rejected_filter_columns: List[ColumnTyping]
cte: Optional[str]
extra_cache_keys: List[Any]
labels_expected: List[str]
@ -159,6 +163,8 @@ class SqlaQuery(NamedTuple):
class QueryStringExtended(NamedTuple):
applied_template_filters: Optional[List[str]]
applied_filter_columns: List[ColumnTyping]
rejected_filter_columns: List[ColumnTyping]
labels_expected: List[str]
prequeries: List[str]
sql: str
@ -878,6 +884,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
applied_template_filters=sqlaq.applied_template_filters,
applied_filter_columns=sqlaq.applied_filter_columns,
rejected_filter_columns=sqlaq.rejected_filter_columns,
labels_expected=sqlaq.labels_expected,
prequeries=sqlaq.prequeries,
sql=sql,
@ -1020,13 +1028,16 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
is_dttm = col_in_metadata.is_temporal
else:
sqla_column = literal_column(expression)
# probe adhoc column type
tbl, _ = self.get_from_clause(template_processor)
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
sql = self.database.compile_sqla_query(qry)
col_desc = get_columns_description(self.database, sql)
is_dttm = col_desc[0]["is_dttm"]
try:
sqla_column = literal_column(expression)
# probe adhoc column type
tbl, _ = self.get_from_clause(template_processor)
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
sql = self.database.compile_sqla_query(qry)
col_desc = get_columns_description(self.database, sql)
is_dttm = col_desc[0]["is_dttm"]
except SupersetGenericDBErrorException as ex:
raise ColumnNotFoundException(message=str(ex)) from ex
if (
is_dttm
@ -1181,6 +1192,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
}
columns = columns or []
groupby = groupby or []
rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = []
applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = []
series_column_names = utils.get_column_names(series_columns or [])
# deprecated, to be removed in 2.0
if is_timeseries and timeseries_limit:
@ -1439,9 +1452,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col:
col_obj = dttm_col
elif is_adhoc_column(flt_col):
sqla_col = self.adhoc_column_to_sqla(flt_col)
try:
sqla_col = self.adhoc_column_to_sqla(flt_col)
applied_adhoc_filters_columns.append(flt_col)
except ColumnNotFoundException:
rejected_adhoc_filters_columns.append(flt_col)
continue
else:
col_obj = columns_by_name.get(flt_col)
col_obj = columns_by_name.get(cast(str, flt_col))
filter_grain = flt.get("grain")
if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"):
@ -1766,8 +1784,27 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
qry = select([col]).select_from(qry.alias("rowcount_qry"))
labels_expected = [label]
filter_columns = [flt.get("col") for flt in filter] if filter else []
rejected_filter_columns = [
col
for col in filter_columns
if col
and not is_adhoc_column(col)
and col not in self.column_names
and col not in applied_template_filters
] + rejected_adhoc_filters_columns
applied_filter_columns = [
col
for col in filter_columns
if col
and not is_adhoc_column(col)
and (col in self.column_names or col in applied_template_filters)
] + applied_adhoc_filters_columns
return SqlaQuery(
applied_template_filters=applied_template_filters,
rejected_filter_columns=rejected_filter_columns,
applied_filter_columns=applied_filter_columns,
cte=cte,
extra_cache_keys=extra_cache_keys,
labels_expected=labels_expected,
@ -1906,6 +1943,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
return QueryResult(
applied_template_filters=query_str_ext.applied_template_filters,
applied_filter_columns=query_str_ext.applied_filter_columns,
rejected_filter_columns=query_str_ext.rejected_filter_columns,
status=status,
df=df,
duration=datetime.now() - qry_start_dttm,

View File

@ -270,3 +270,7 @@ class SupersetCancelQueryException(SupersetException):
class QueryNotFoundException(SupersetException):
    """Raised when a requested query cannot be found; maps to HTTP 404."""

    status = 404
class ColumnNotFoundException(SupersetException):
    """Raised when a referenced column cannot be resolved against the
    datasource (e.g. an adhoc column expression that fails to probe);
    maps to HTTP 404."""

    status = 404

View File

@ -80,6 +80,7 @@ from superset.jinja_context import BaseTemplateProcessor
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, sanitize_clause
from superset.superset_typing import (
AdhocMetric,
Column as ColumnTyping,
FilterValue,
FilterValues,
Metric,
@ -545,6 +546,8 @@ class QueryResult: # pylint: disable=too-few-public-methods
query: str,
duration: timedelta,
applied_template_filters: Optional[List[str]] = None,
applied_filter_columns: Optional[List[ColumnTyping]] = None,
rejected_filter_columns: Optional[List[ColumnTyping]] = None,
status: str = QueryStatus.SUCCESS,
error_message: Optional[str] = None,
errors: Optional[List[Dict[str, Any]]] = None,
@ -555,6 +558,8 @@ class QueryResult: # pylint: disable=too-few-public-methods
self.query = query
self.duration = duration
self.applied_template_filters = applied_template_filters or []
self.applied_filter_columns = applied_filter_columns or []
self.rejected_filter_columns = rejected_filter_columns or []
self.status = status
self.error_message = error_message
self.errors = errors or []
@ -1646,7 +1651,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
elif utils.is_adhoc_column(flt_col):
sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore
else:
col_obj = columns_by_name.get(flt_col)
col_obj = columns_by_name.get(cast(str, flt_col))
filter_grain = flt.get("grain")
if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"):

View File

@ -221,7 +221,7 @@ class AdhocFilterClause(TypedDict, total=False):
class QueryObjectFilterClause(TypedDict, total=False):
col: str
col: Column
op: str # pylint: disable=invalid-name
val: Optional[FilterValues]
grain: Optional[str]
@ -1089,7 +1089,7 @@ def simple_filter_to_adhoc(
"expressionType": "SIMPLE",
"comparator": filter_clause.get("val"),
"operator": filter_clause["op"],
"subject": filter_clause["col"],
"subject": cast(str, filter_clause["col"]),
}
if filter_clause.get("isExtra"):
result["isExtra"] = True

View File

@ -154,7 +154,8 @@ class BaseViz: # pylint: disable=too-many-public-methods
self.status: Optional[str] = None
self.error_msg = ""
self.results: Optional[QueryResult] = None
self.applied_template_filters: List[str] = []
self.applied_filter_columns: List[Column] = []
self.rejected_filter_columns: List[Column] = []
self.errors: List[Dict[str, Any]] = []
self.force = force
self._force_cached = force_cached
@ -288,7 +289,8 @@ class BaseViz: # pylint: disable=too-many-public-methods
# The datasource here can be different backend but the interface is common
self.results = self.datasource.query(query_obj)
self.applied_template_filters = self.results.applied_template_filters or []
self.applied_filter_columns = self.results.applied_filter_columns or []
self.rejected_filter_columns = self.results.rejected_filter_columns or []
self.query = self.results.query
self.status = self.results.status
self.errors = self.results.errors
@ -492,25 +494,21 @@ class BaseViz: # pylint: disable=too-many-public-methods
if "df" in payload:
del payload["df"]
filters = self.form_data.get("filters", [])
filter_columns = [flt.get("col") for flt in filters]
columns = set(self.datasource.column_names)
applied_template_filters = self.applied_template_filters or []
applied_filter_columns = self.applied_filter_columns or []
rejected_filter_columns = self.rejected_filter_columns or []
applied_time_extras = self.form_data.get("applied_time_extras", {})
applied_time_columns, rejected_time_columns = utils.get_time_filter_status(
self.datasource, applied_time_extras
)
payload["applied_filters"] = [
{"column": get_column_name(col)}
for col in filter_columns
if is_adhoc_column(col) or col in columns or col in applied_template_filters
{"column": get_column_name(col)} for col in applied_filter_columns
] + applied_time_columns
payload["rejected_filters"] = [
{"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col}
for col in filter_columns
if not is_adhoc_column(col)
and col not in columns
and col not in applied_template_filters
{
"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE,
"column": get_column_name(col),
}
for col in rejected_filter_columns
] + rejected_time_columns
if df is not None:
payload["colnames"] = list(df.columns)
@ -535,8 +533,11 @@ class BaseViz: # pylint: disable=too-many-public-methods
try:
df = cache_value["df"]
self.query = cache_value["query"]
self.applied_template_filters = cache_value.get(
"applied_template_filters", []
self.applied_filter_columns = cache_value.get(
"applied_filter_columns", []
)
self.rejected_filter_columns = cache_value.get(
"rejected_filter_columns", []
)
self.status = QueryStatus.SUCCESS
is_loaded = True

View File

@ -56,6 +56,7 @@ from superset.utils.core import (
AnnotationType,
get_example_default_schema,
AdhocMetricExpressionType,
ExtraFiltersReasonType,
)
from superset.utils.database import get_example_database, get_main_database
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
@ -73,6 +74,12 @@ ADHOC_COLUMN_FIXTURE: AdhocColumn = {
"when gender = 'girl' then 'female' else 'other' end",
}
# Adhoc column fixture whose SQL expression is expected NOT to run against the
# dataset under test — presumably "genre" is not a column of the birth_names
# dataset (TODO confirm against the fixture schema) — so filters using it
# should land in the response's rejected_filters rather than crash the chart.
INCOMPATIBLE_ADHOC_COLUMN_FIXTURE: AdhocColumn = {
    "hasCustomLabel": True,
    "label": "exciting_or_boring",
    "sqlExpression": "case when genre = 'Action' then 'Exciting' else 'Boring' end",
}
class BaseTestChartDataApi(SupersetTestCase):
query_context_payload_template = None
@ -1059,6 +1066,33 @@ class TestGetChartDataApi(BaseTestChartDataApi):
assert unique_genders == {"male", "female"}
assert result["applied_filters"] == [{"column": "male_or_female"}]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_with_incompatible_adhoc_column(self):
    """
    Chart data API: Test query with adhoc column that fails to run on this dataset

    One filter uses an adhoc column that cannot be resolved against the
    datasource; it must be reported under ``rejected_filters`` (with the
    column's label, not the adhoc dict) while the compatible adhoc filter
    is still applied and the query succeeds.
    """
    self.login(username="admin")
    request_payload = get_query_context("birth_names")
    request_payload["queries"][0]["columns"] = [ADHOC_COLUMN_FIXTURE]
    request_payload["queries"][0]["filters"] = [
        # incompatible adhoc column -> expected to be rejected, not to crash
        {"col": INCOMPATIBLE_ADHOC_COLUMN_FIXTURE, "op": "IN", "val": ["Exciting"]},
        # valid adhoc column -> expected to be applied
        {"col": ADHOC_COLUMN_FIXTURE, "op": "IN", "val": ["male", "female"]},
    ]
    rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
    response_payload = json.loads(rv.data.decode("utf-8"))
    result = response_payload["result"][0]
    data = result["data"]
    # set(mapping) iterates keys directly; no need for a comprehension
    assert set(data[0]) == {"male_or_female", "sum__num"}
    unique_genders = {row["male_or_female"] for row in data}
    assert unique_genders == {"male", "female"}
    # only the compatible filter is applied...
    assert result["applied_filters"] == [{"column": "male_or_female"}]
    # ...and the incompatible one is surfaced by its label with a reason
    assert result["rejected_filters"] == [
        {
            "column": "exciting_or_boring",
            "reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE,
        }
    ]
@pytest.fixture()
def physical_query_context(physical_dataset) -> Dict[str, Any]: