diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 01ceaecbc..4d63b118d 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -177,7 +177,7 @@ class QueryObject: # 2. { label: 'label_name' } - legacy format for a predefined metric # 3. { expressionType: 'SIMPLE' | 'SQL', ... } - adhoc metric self.metrics = metrics and [ - x if isinstance(x, str) or is_adhoc_metric(x) else x["label"] + x if isinstance(x, str) or is_adhoc_metric(x) else x["label"] # type: ignore for x in metrics ] diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index af73d2d42..32edb6952 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -60,6 +60,7 @@ from superset.models.core import Database from superset.models.helpers import AuditMixinNullable, ImportExportMixin, QueryResult from superset.typing import ( AdhocMetric, + AdhocMetricColumn, FilterValues, Granularity, Metric, @@ -93,7 +94,13 @@ except ImportError: pass try: - from superset.utils.core import DimSelector, DTTM_ALIAS, FilterOperator, flasher + from superset.utils.core import ( + DimSelector, + DTTM_ALIAS, + FilterOperator, + flasher, + get_metric_name, + ) except ImportError: pass @@ -1021,7 +1028,7 @@ class DruidDatasource(Model, BaseDatasource): @staticmethod def druid_type_from_adhoc_metric(adhoc_metric: AdhocMetric) -> str: - column_type = adhoc_metric["column"]["type"].lower() + column_type = adhoc_metric["column"]["type"].lower() # type: ignore aggregate = adhoc_metric["aggregate"].lower() if aggregate == "count": @@ -1063,11 +1070,13 @@ class DruidDatasource(Model, BaseDatasource): _("Metric(s) {} must be aggregations.").format(invalid_metric_names) ) for adhoc_metric in adhoc_metrics: - aggregations[adhoc_metric["label"]] = { - "fieldName": adhoc_metric["column"]["column_name"], - "fieldNames": [adhoc_metric["column"]["column_name"]], + label = get_metric_name(adhoc_metric) + column = cast(AdhocMetricColumn, adhoc_metric["column"]) + aggregations[label] = { + "fieldName": column["column_name"], + "fieldNames": [column["column_name"]], "type": DruidDatasource.druid_type_from_adhoc_metric(adhoc_metric), - "name": adhoc_metric["label"], + "name": label, } return aggregations diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index a05c63e1b..7a91ac001 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -845,7 +845,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at label = utils.get_metric_name(metric) if expression_type == utils.AdhocMetricExpressionType.SIMPLE: - column_name = cast(str, metric["column"].get("column_name")) + metric_column = metric.get("column") or {} + column_name = cast(str, metric_column.get("column_name")) table_column: Optional[TableColumn] = columns_by_name.get(column_name) if table_column: sqla_column = table_column.get_sqla_col() diff --git a/superset/typing.py b/superset/typing.py index 009c9d946..4273444fe 100644 --- a/superset/typing.py +++ b/superset/typing.py @@ -15,24 +15,50 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) from flask import Flask from flask_caching import Cache from typing_extensions import TypedDict from werkzeug.wrappers import Response +if TYPE_CHECKING: + from superset.utils.core import GenericDataType -class AdhocMetricColumn(TypedDict): + +class LegacyMetric(TypedDict): + label: Optional[str] + + +class AdhocMetricColumn(TypedDict, total=False): column_name: Optional[str] + description: Optional[str] + expression: Optional[str] + filterable: bool + groupby: bool + id: int + is_dttm: bool + python_date_format: Optional[str] type: str + type_generic: "GenericDataType" + verbose_name: Optional[str] -class AdhocMetric(TypedDict): +class AdhocMetric(TypedDict, total=False): aggregate: str - column: AdhocMetricColumn + column: Optional[AdhocMetricColumn] expressionType: str - label: str + label: Optional[str] sqlExpression: Optional[str] diff --git a/superset/utils/core.py b/superset/utils/core.py index 427ebfaa7..d48606ec0 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -96,7 +96,14 @@ from superset.exceptions import ( SupersetException, SupersetTimeoutException, ) -from superset.typing import AdhocMetric, FilterValues, FlaskResponse, FormData, Metric +from superset.typing import ( + AdhocMetric, + AdhocMetricColumn, + FilterValues, + FlaskResponse, + FormData, + Metric, +) from superset.utils.dates import datetime_to_epoch, EPOCH from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str @@ -1273,7 +1280,33 @@ def is_adhoc_metric(metric: Metric) -> bool: def get_metric_name(metric: Metric) -> str: - return metric["label"] if is_adhoc_metric(metric) else metric # type: ignore + """ + Extract label from metric + + :param metric: object to extract label from + :return: String representation of metric + :raises ValueError: if metric object is invalid + """ + if is_adhoc_metric(metric): + metric = cast(AdhocMetric, metric) + label = metric.get("label") + if label: + return label + expression_type = metric.get("expressionType") + if expression_type == "SQL": + sql_expression = metric.get("sqlExpression") + if sql_expression: + return sql_expression + elif expression_type == "SIMPLE": + column: AdhocMetricColumn = metric.get("column") or {} + column_name = column.get("column_name") + aggregate = metric.get("aggregate") + if column and aggregate: + return f"{aggregate}({column_name})" + if column_name: + return column_name + raise ValueError(__("Invalid metric object")) + return cast(str, metric) def get_metric_names(metrics: Sequence[Metric]) -> List[str]: diff --git a/tests/unit_tests/core_tests.py b/tests/unit_tests/core_tests.py new file mode 100644 index 000000000..bb3e50f51 --- /dev/null +++ b/tests/unit_tests/core_tests.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from copy import deepcopy + +import pytest + +from superset.utils.core import ( + AdhocMetric, + GenericDataType, + get_metric_name, + get_metric_names, +) + +STR_METRIC = "my_metric" +SIMPLE_SUM_ADHOC_METRIC: AdhocMetric = { + "aggregate": "SUM", + "column": { + "column_name": "my_col", + "type": "INT", + "type_generic": GenericDataType.NUMERIC, + }, + "expressionType": "SIMPLE", + "label": "my SUM", +} +SQL_ADHOC_METRIC: AdhocMetric = { + "expressionType": "SQL", + "label": "my_sql", + "sqlExpression": "SUM(my_col)", +} + + +def test_get_metric_name_saved_metric(): + assert get_metric_name(STR_METRIC) == "my_metric" + + +def test_get_metric_name_adhoc(): + metric = deepcopy(SIMPLE_SUM_ADHOC_METRIC) + assert get_metric_name(metric) == "my SUM" + del metric["label"] + assert get_metric_name(metric) == "SUM(my_col)" + metric["label"] = "" + assert get_metric_name(metric) == "SUM(my_col)" + del metric["aggregate"] + assert get_metric_name(metric) == "my_col" + metric["aggregate"] = "" + assert get_metric_name(metric) == "my_col" + + metric = deepcopy(SQL_ADHOC_METRIC) + assert get_metric_name(metric) == "my_sql" + del metric["label"] + assert get_metric_name(metric) == "SUM(my_col)" + metric["label"] = "" + assert get_metric_name(metric) == "SUM(my_col)" + + +def test_get_metric_name_invalid_metric(): + metric = deepcopy(SIMPLE_SUM_ADHOC_METRIC) + del metric["label"] + del metric["column"] + with pytest.raises(ValueError): + get_metric_name(metric) + + metric = deepcopy(SIMPLE_SUM_ADHOC_METRIC) + del metric["label"] + metric["expressionType"] = "FOO" + with pytest.raises(ValueError): + get_metric_name(metric) + + metric = deepcopy(SQL_ADHOC_METRIC) + del metric["label"] + metric["expressionType"] = "FOO" + with pytest.raises(ValueError): + get_metric_name(metric) + + +def test_get_metric_names(): + assert get_metric_names( + [STR_METRIC, SIMPLE_SUM_ADHOC_METRIC, SQL_ADHOC_METRIC] + ) == ["my_metric", "my SUM", "my_sql"]