From fdbcbb5c84f998666fd325ac14bc10d2cbdb2288 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Wed, 9 Feb 2022 14:01:57 +0200 Subject: [PATCH] fix(csv-export): pivot v2 with verbose names (#18633) * fix(csv-export): pivot v2 with verbose names * refine logic + add test * add missing verbose_map --- superset/charts/data/api.py | 4 ++- superset/charts/post_processing.py | 41 ++++++++++++++++------- superset/utils/core.py | 47 +++++++++++++++++++------- tests/unit_tests/core_tests.py | 53 ++++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 23 deletions(-) diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index d6490421c..dc92d9745 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -241,7 +241,9 @@ class ChartDataRestApi(ChartRestApi): return self._run_async(json_body, command) form_data = json_body.get("form_data") - return self._get_data_response(command, form_data=form_data) + return self._get_data_response( + command, form_data=form_data, datasource=query_context.datasource + ) @expose("/data/", methods=["GET"]) @protect() diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py index d3b8d47d3..7b2129039 100644 --- a/superset/charts/post_processing.py +++ b/superset/charts/post_processing.py @@ -32,7 +32,12 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING import pandas as pd from superset.common.chart_data import ChartDataResultFormat -from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name +from superset.utils.core import ( + DTTM_ALIAS, + extract_dataframe_dtypes, + get_column_names, + get_metric_names, +) if TYPE_CHECKING: from superset.connectors.base.models import BaseDatasource @@ -214,18 +219,23 @@ pivot_v2_aggfunc_map = { } -def pivot_table_v2(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame: +def pivot_table_v2( + df: pd.DataFrame, + form_data: Dict[str, Any], + datasource: Optional["BaseDatasource"] = None, +) -> pd.DataFrame: """ Pivot table v2. """ + verbose_map = datasource.data["verbose_map"] if datasource else None if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df: del df[DTTM_ALIAS] return pivot_df( df, - rows=form_data.get("groupbyRows") or [], - columns=form_data.get("groupbyColumns") or [], - metrics=[get_metric_name(m) for m in form_data["metrics"]], + rows=get_column_names(form_data.get("groupbyRows"), verbose_map), + columns=get_column_names(form_data.get("groupbyColumns"), verbose_map), + metrics=get_metric_names(form_data["metrics"], verbose_map), aggfunc=form_data.get("aggregateFunction", "Sum"), transpose_pivot=bool(form_data.get("transposePivot")), combine_metrics=bool(form_data.get("combineMetric")), @@ -235,10 +245,15 @@ def pivot_table_v2(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame: ) -def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame: +def pivot_table( + df: pd.DataFrame, + form_data: Dict[str, Any], + datasource: Optional["BaseDatasource"] = None, +) -> pd.DataFrame: """ Pivot table (v1). """ + verbose_map = datasource.data["verbose_map"] if datasource else None if form_data.get("granularity") == "all" and DTTM_ALIAS in df: del df[DTTM_ALIAS] @@ -254,9 +269,9 @@ def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame: return pivot_df( df, - rows=form_data.get("groupby") or [], - columns=form_data.get("columns") or [], - metrics=[get_metric_name(m) for m in form_data["metrics"]], + rows=get_column_names(form_data.get("groupby"), verbose_map), + columns=get_column_names(form_data.get("columns"), verbose_map), + metrics=get_metric_names(form_data["metrics"], verbose_map), aggfunc=func_map.get(form_data.get("pandas_aggfunc", "sum"), "Sum"), transpose_pivot=bool(form_data.get("transpose_pivot")), combine_metrics=bool(form_data.get("combine_metric")), @@ -266,7 +281,11 @@ def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame: ) -def table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame: +def table( + df: pd.DataFrame, + form_data: Dict[str, Any], + datasource: Optional["BaseDatasource"] = None, # pylint: disable=unused-argument +) -> pd.DataFrame: """ Table. """ @@ -312,7 +331,7 @@ def apply_post_process( else: raise Exception(f"Result format {query['result_format']} not supported") - processed_df = post_processor(df, form_data) + processed_df = post_processor(df, form_data, datasource) query["colnames"] = list(processed_df.columns) query["indexnames"] = list(processed_df.index) diff --git a/superset/utils/core.py b/superset/utils/core.py index da69a89a8..ddb725623 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1228,11 +1228,15 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]: return isinstance(column, dict) -def get_column_name(column: Column) -> str: +def get_column_name( + column: Column, verbose_map: Optional[Dict[str, Any]] = None +) -> str: """ Extract label from column :param column: object to extract label from + :param verbose_map: verbose_map from dataset for optional mapping from + raw name to verbose name :return: String representation of column :raises ValueError: if metric object is invalid """ @@ -1243,15 +1247,20 @@ def get_column_name(column: Column) -> str: expr = column.get("sqlExpression") if expr: return expr - raise Exception("Missing label") - return column + raise ValueError("Missing label") + verbose_map = verbose_map or {} + return verbose_map.get(column, column) -def get_metric_name(metric: Metric) -> str: +def get_metric_name( + metric: Metric, verbose_map: Optional[Dict[str, Any]] = None +) -> str: """ Extract label from metric :param metric: object to extract label from + :param verbose_map: verbose_map from dataset for optional mapping from + raw name to verbose name :return: String representation of metric :raises ValueError: if metric object is invalid """ @@ -1273,19 +1282,35 @@ def get_metric_name(metric: Metric) -> str: if column_name: return column_name raise ValueError(__("Invalid metric object")) - return metric # type: ignore + + verbose_map = verbose_map or {} + return verbose_map.get(metric, metric) # type: ignore -def get_column_names(columns: Optional[Sequence[Column]]) -> List[str]: - return [column for column in map(get_column_name, columns or []) if column] +def get_column_names( + columns: Optional[Sequence[Column]], verbose_map: Optional[Dict[str, Any]] = None, +) -> List[str]: + return [ + column + for column in [get_column_name(column, verbose_map) for column in columns or []] + if column + ] -def get_metric_names(metrics: Optional[Sequence[Metric]]) -> List[str]: - return [metric for metric in map(get_metric_name, metrics or []) if metric] +def get_metric_names( + metrics: Optional[Sequence[Metric]], verbose_map: Optional[Dict[str, Any]] = None, +) -> List[str]: + return [ + metric + for metric in [get_metric_name(metric, verbose_map) for metric in metrics or []] + if metric + ] -def get_first_metric_name(metrics: Optional[Sequence[Metric]]) -> Optional[str]: - metric_labels = get_metric_names(metrics) +def get_first_metric_name( + metrics: Optional[Sequence[Metric]], verbose_map: Optional[Dict[str, Any]] = None, +) -> Optional[str]: + metric_labels = get_metric_names(metrics, verbose_map) return metric_labels[0] if metric_labels else None diff --git a/tests/unit_tests/core_tests.py b/tests/unit_tests/core_tests.py index 87e0abcef..c9f96204c 100644 --- a/tests/unit_tests/core_tests.py +++ b/tests/unit_tests/core_tests.py @@ -19,10 +19,13 @@ from copy import deepcopy import pytest from superset.utils.core import ( + AdhocColumn, AdhocMetric, ExtraFiltersReasonType, ExtraFiltersTimeColumnType, GenericDataType, + get_column_name, + get_column_names, get_metric_name, get_metric_names, get_time_filter_status, @@ -47,15 +50,23 @@ SQL_ADHOC_METRIC: AdhocMetric = { "label": "my_sql", "sqlExpression": "SUM(my_col)", } +STR_COLUMN = "my_column" +SQL_ADHOC_COLUMN: AdhocColumn = { + "hasCustomLabel": True, + "label": "My Adhoc Column", + "sqlExpression": "case when foo = 1 then 'foo' else 'bar' end", +} def test_get_metric_name_saved_metric(): assert get_metric_name(STR_METRIC) == "my_metric" + assert get_metric_name(STR_METRIC, {STR_METRIC: "My Metric"}) == "My Metric" def test_get_metric_name_adhoc(): metric = deepcopy(SIMPLE_SUM_ADHOC_METRIC) assert get_metric_name(metric) == "my SUM" + assert get_metric_name(metric, {"my SUM": "My Irrelevant Mapping"}) == "my SUM" del metric["label"] assert get_metric_name(metric) == "SUM(my_col)" metric["label"] = "" @@ -64,9 +75,11 @@ def test_get_metric_name_adhoc(): assert get_metric_name(metric) == "my_col" metric["aggregate"] = "" assert get_metric_name(metric) == "my_col" + assert get_metric_name(metric, {"my_col": "My Irrelevant Mapping"}) == "my_col" metric = deepcopy(SQL_ADHOC_METRIC) assert get_metric_name(metric) == "my_sql" + assert get_metric_name(metric, {"my_sql": "My Irrelevant Mapping"}) == "my_sql" del metric["label"] assert get_metric_name(metric) == "SUM(my_col)" metric["label"] = "" @@ -97,6 +110,46 @@ def test_get_metric_names(): assert get_metric_names( [STR_METRIC, SIMPLE_SUM_ADHOC_METRIC, SQL_ADHOC_METRIC] ) == ["my_metric", "my SUM", "my_sql"] + assert get_metric_names( + [STR_METRIC, SIMPLE_SUM_ADHOC_METRIC, SQL_ADHOC_METRIC], + {STR_METRIC: "My Metric"}, + ) == ["My Metric", "my SUM", "my_sql"] + + +def test_get_column_name_physical_column(): + assert get_column_name(STR_COLUMN) == "my_column" + assert get_metric_name(STR_COLUMN, {STR_COLUMN: "My Column"}) == "My Column" + + +def test_get_column_name_adhoc(): + column = deepcopy(SQL_ADHOC_COLUMN) + assert get_column_name(column) == "My Adhoc Column" + assert ( + get_column_name(column, {"My Adhoc Column": "My Irrelevant Mapping"}) + == "My Adhoc Column" + ) + del column["label"] + assert get_column_name(column) == "case when foo = 1 then 'foo' else 'bar' end" + column["label"] = "" + assert get_column_name(column) == "case when foo = 1 then 'foo' else 'bar' end" + + +def test_get_column_names(): + assert get_column_names([STR_COLUMN, SQL_ADHOC_COLUMN]) == [ + "my_column", + "My Adhoc Column", + ] + assert get_column_names( + [STR_COLUMN, SQL_ADHOC_COLUMN], {"my_column": "My Column"}, + ) == ["My Column", "My Adhoc Column"] + + +def test_get_column_name_invalid_metric(): + column = deepcopy(SQL_ADHOC_COLUMN) + del column["label"] + del column["sqlExpression"] + with pytest.raises(ValueError): + get_column_name(column) def test_is_adhoc_metric():