feat: post-processing for pivot table v2 (#15879)

* feat: add pivot v2 post-processing

* Fix lint
This commit is contained in:
Beto Dealmeida 2021-07-29 11:05:56 -07:00 committed by GitHub
parent 6afa840659
commit f4739f427e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 504 additions and 1 deletions

View File

@ -33,6 +33,13 @@ import pandas as pd
from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name
def sql_like_sum(series: pd.Series) -> pd.Series:
"""
A SUM aggregation function that mimics the behavior from SQL.
"""
return series.sum(min_count=1)
def pivot_table( def pivot_table(
result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None
) -> Dict[Any, Any]: ) -> Dict[Any, Any]:
@ -53,7 +60,7 @@ def pivot_table(
aggfunc = form_data.get("pandas_aggfunc") or "sum" aggfunc = form_data.get("pandas_aggfunc") or "sum"
if pd.api.types.is_numeric_dtype(df[metric]): if pd.api.types.is_numeric_dtype(df[metric]):
if aggfunc == "sum": if aggfunc == "sum":
aggfunc = lambda x: x.sum(min_count=1) aggfunc = sql_like_sum
elif aggfunc not in {"min", "max"}: elif aggfunc not in {"min", "max"}:
aggfunc = "max" aggfunc = "max"
aggfuncs[metric] = aggfunc aggfuncs[metric] = aggfunc
@ -95,6 +102,120 @@ def pivot_table(
return result return result
def list_unique_values(series: pd.Series) -> str:
"""
List unique values in a series.
"""
return ", ".join(set(str(v) for v in pd.Series.unique(series)))
pivot_v2_aggfunc_map = {
"Count": pd.Series.count,
"Count Unique Values": pd.Series.nunique,
"List Unique Values": list_unique_values,
"Sum": pd.Series.sum,
"Average": pd.Series.mean,
"Median": pd.Series.median,
"Sample Variance": lambda series: pd.series.var(series) if len(series) > 1 else 0,
"Sample Standard Deviation": (
lambda series: pd.series.std(series) if len(series) > 1 else 0,
),
"Minimum": pd.Series.min,
"Maximum": pd.Series.max,
"First": lambda series: series[:1],
"Last": lambda series: series[-1:],
"Sum as Fraction of Total": pd.Series.sum,
"Sum as Fraction of Rows": pd.Series.sum,
"Sum as Fraction of Columns": pd.Series.sum,
"Count as Fraction of Total": pd.Series.count,
"Count as Fraction of Rows": pd.Series.count,
"Count as Fraction of Columns": pd.Series.count,
}
def pivot_table_v2( # pylint: disable=too-many-branches
result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
) -> Dict[Any, Any]:
"""
Pivot table v2.
"""
for query in result["queries"]:
data = query["data"]
df = pd.DataFrame(data)
form_data = form_data or {}
if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df:
del df[DTTM_ALIAS]
# TODO (betodealmeida): implement metricsLayout
metrics = [get_metric_name(m) for m in form_data["metrics"]]
aggregate_function = form_data.get("aggregateFunction", "Sum")
groupby = form_data.get("groupbyRows") or []
columns = form_data.get("groupbyColumns") or []
if form_data.get("transposePivot"):
groupby, columns = columns, groupby
df = df.pivot_table(
index=groupby,
columns=columns,
values=metrics,
aggfunc=pivot_v2_aggfunc_map[aggregate_function],
margins=True,
)
# The pandas `pivot_table` method either brings both row/column
# totals, or none at all. We pass `margin=True` to get both, and
# remove any dimension that was not requests.
if not form_data.get("rowTotals"):
df.drop(df.columns[len(df.columns) - 1], axis=1, inplace=True)
if not form_data.get("colTotals"):
df = df[:-1]
# Compute fractions, if needed. If `colTotals` or `rowTotals` are
# present we need to adjust for including them in the sum
if aggregate_function.endswith(" as Fraction of Total"):
total = df.sum().sum()
df = df.astype(total.dtypes) / total
if form_data.get("colTotals"):
df *= 2
if form_data.get("rowTotals"):
df *= 2
elif aggregate_function.endswith(" as Fraction of Columns"):
total = df.sum(axis=0)
df = df.astype(total.dtypes).div(total, axis=1)
if form_data.get("colTotals"):
df *= 2
elif aggregate_function.endswith(" as Fraction of Rows"):
total = df.sum(axis=1)
df = df.astype(total.dtypes).div(total, axis=0)
if form_data.get("rowTotals"):
df *= 2
# Re-order the columns adhering to the metric ordering.
df = df[metrics]
# Display metrics side by side with each column
if form_data.get("combineMetric"):
df = df.stack(0).unstack().reindex(level=-1, columns=metrics)
# flatten column names
df.columns = [" ".join(column) for column in df.columns]
# re-arrange data into a list of dicts
data = []
for i in df.index:
row = {col: df[col][i] for col in df.columns}
row[df.index.name] = i
data.append(row)
query["data"] = data
query["colnames"] = list(df.columns)
query["coltypes"] = extract_dataframe_dtypes(df)
query["rowcount"] = len(df.index)
return result
post_processors = { post_processors = {
"pivot_table": pivot_table, "pivot_table": pivot_table,
"pivot_table_v2": pivot_table_v2,
} }

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,366 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import copy
from typing import Any, Dict
from superset.charts.post_processing import pivot_table, pivot_table_v2
from superset.utils.core import GenericDataType, QueryStatus
RESULT: Dict[str, Any] = {
"query_context": None,
"queries": [
{
"cache_key": "1bd3ab8c01e98a0e349fb61bc76d9b90",
"cached_dttm": None,
"cache_timeout": 86400,
"annotation_data": {},
"error": None,
"is_cached": None,
"query": """SELECT state AS state,
gender AS gender,
sum(num) AS \"Births\"
FROM birth_names
WHERE ds >= TO_TIMESTAMP('1921-07-28 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
AND ds < TO_TIMESTAMP('2021-07-28 10:39:44.000000', 'YYYY-MM-DD HH24:MI:SS.US')
GROUP BY state,
gender
LIMIT 50000;
""",
"status": QueryStatus.SUCCESS,
"stacktrace": None,
"rowcount": 22,
"colnames": ["state", "gender", "Births"],
"coltypes": [
GenericDataType.STRING,
GenericDataType.STRING,
GenericDataType.NUMERIC,
],
"data": [
{"state": "OH", "gender": "boy", "Births": int("2376385")},
{"state": "TX", "gender": "girl", "Births": int("2313186")},
{"state": "MA", "gender": "boy", "Births": int("1285126")},
{"state": "MA", "gender": "girl", "Births": int("842146")},
{"state": "PA", "gender": "boy", "Births": int("2390275")},
{"state": "NY", "gender": "boy", "Births": int("3543961")},
{"state": "FL", "gender": "boy", "Births": int("1968060")},
{"state": "TX", "gender": "boy", "Births": int("3311985")},
{"state": "NJ", "gender": "boy", "Births": int("1486126")},
{"state": "CA", "gender": "girl", "Births": int("3567754")},
{"state": "CA", "gender": "boy", "Births": int("5430796")},
{"state": "IL", "gender": "girl", "Births": int("1614427")},
{"state": "FL", "gender": "girl", "Births": int("1312593")},
{"state": "NY", "gender": "girl", "Births": int("2280733")},
{"state": "NJ", "gender": "girl", "Births": int("992702")},
{"state": "MI", "gender": "girl", "Births": int("1326229")},
{"state": "other", "gender": "girl", "Births": int("15058341")},
{"state": "other", "gender": "boy", "Births": int("22044909")},
{"state": "MI", "gender": "boy", "Births": int("1938321")},
{"state": "IL", "gender": "boy", "Births": int("2357411")},
{"state": "PA", "gender": "girl", "Births": int("1615383")},
{"state": "OH", "gender": "girl", "Births": int("1622814")},
],
"applied_filters": [],
"rejected_filters": [],
}
],
}
def test_pivot_table():
form_data = {
"adhoc_filters": [],
"columns": ["state"],
"datasource": "3__table",
"date_format": "smart_date",
"extra_form_data": {},
"granularity_sqla": "ds",
"groupby": ["gender"],
"metrics": [
{
"aggregate": "SUM",
"column": {"column_name": "num", "type": "BIGINT"},
"expressionType": "SIMPLE",
"label": "Births",
"optionName": "metric_11",
}
],
"number_format": "SMART_NUMBER",
"order_desc": True,
"pandas_aggfunc": "sum",
"pivot_margins": True,
"row_limit": 50000,
"slice_id": 143,
"time_grain_sqla": "P1D",
"time_range": "100 years ago : now",
"time_range_endpoints": ["inclusive", "exclusive"],
"url_params": {},
"viz_type": "pivot_table",
}
result = copy.deepcopy(RESULT)
assert pivot_table(result, form_data) == {
"query_context": None,
"queries": [
{
"cache_key": "1bd3ab8c01e98a0e349fb61bc76d9b90",
"cached_dttm": None,
"cache_timeout": 86400,
"annotation_data": {},
"error": None,
"is_cached": None,
"query": """SELECT state AS state,
gender AS gender,
sum(num) AS \"Births\"
FROM birth_names
WHERE ds >= TO_TIMESTAMP('1921-07-28 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
AND ds < TO_TIMESTAMP('2021-07-28 10:39:44.000000', 'YYYY-MM-DD HH24:MI:SS.US')
GROUP BY state,
gender
LIMIT 50000;
""",
"status": QueryStatus.SUCCESS,
"stacktrace": None,
"rowcount": 3,
"colnames": [
"Births CA",
"Births FL",
"Births IL",
"Births MA",
"Births MI",
"Births NJ",
"Births NY",
"Births OH",
"Births PA",
"Births TX",
"Births other",
"Births All",
],
"coltypes": [
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
],
"data": [
{
"Births CA": 5430796,
"Births FL": 1968060,
"Births IL": 2357411,
"Births MA": 1285126,
"Births MI": 1938321,
"Births NJ": 1486126,
"Births NY": 3543961,
"Births OH": 2376385,
"Births PA": 2390275,
"Births TX": 3311985,
"Births other": 22044909,
"Births All": 48133355,
"gender": "boy",
},
{
"Births CA": 3567754,
"Births FL": 1312593,
"Births IL": 1614427,
"Births MA": 842146,
"Births MI": 1326229,
"Births NJ": 992702,
"Births NY": 2280733,
"Births OH": 1622814,
"Births PA": 1615383,
"Births TX": 2313186,
"Births other": 15058341,
"Births All": 32546308,
"gender": "girl",
},
{
"Births CA": 8998550,
"Births FL": 3280653,
"Births IL": 3971838,
"Births MA": 2127272,
"Births MI": 3264550,
"Births NJ": 2478828,
"Births NY": 5824694,
"Births OH": 3999199,
"Births PA": 4005658,
"Births TX": 5625171,
"Births other": 37103250,
"Births All": 80679663,
"gender": "All",
},
],
"applied_filters": [],
"rejected_filters": [],
}
],
}
def test_pivot_table_v2():
form_data = {
"adhoc_filters": [],
"aggregateFunction": "Sum as Fraction of Rows",
"colOrder": "key_a_to_z",
"colTotals": True,
"combineMetric": True,
"datasource": "3__table",
"date_format": "smart_date",
"extra_form_data": {},
"granularity_sqla": "ds",
"groupbyColumns": ["state"],
"groupbyRows": ["gender"],
"metrics": [
{
"aggregate": "SUM",
"column": {"column_name": "num", "type": "BIGINT"},
"expressionType": "SIMPLE",
"label": "Births",
"optionName": "metric_11",
}
],
"metricsLayout": "ROWS",
"rowOrder": "key_a_to_z",
"rowTotals": True,
"row_limit": 50000,
"slice_id": 72,
"time_grain_sqla": None,
"time_range": "100 years ago : now",
"time_range_endpoints": ["inclusive", "exclusive"],
"transposePivot": True,
"url_params": {},
"valueFormat": "SMART_NUMBER",
"viz_type": "pivot_table_v2",
}
result = copy.deepcopy(RESULT)
assert pivot_table_v2(result, form_data) == {
"query_context": None,
"queries": [
{
"cache_key": "1bd3ab8c01e98a0e349fb61bc76d9b90",
"cached_dttm": None,
"cache_timeout": 86400,
"annotation_data": {},
"error": None,
"is_cached": None,
"query": """SELECT state AS state,
gender AS gender,
sum(num) AS \"Births\"
FROM birth_names
WHERE ds >= TO_TIMESTAMP('1921-07-28 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')
AND ds < TO_TIMESTAMP('2021-07-28 10:39:44.000000', 'YYYY-MM-DD HH24:MI:SS.US')
GROUP BY state,
gender
LIMIT 50000;
""",
"status": QueryStatus.SUCCESS,
"stacktrace": None,
"rowcount": 12,
"colnames": ["All Births", "boy Births", "girl Births"],
"coltypes": [
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
GenericDataType.NUMERIC,
],
"data": [
{
"All Births": 1.0,
"boy Births": 0.5965983645717509,
"girl Births": 0.40340163542824914,
"state": "All",
},
{
"All Births": 1.0,
"boy Births": 0.6035190113962805,
"girl Births": 0.3964809886037195,
"state": "CA",
},
{
"All Births": 1.0,
"boy Births": 0.5998988615985903,
"girl Births": 0.4001011384014097,
"state": "FL",
},
{
"All Births": 1.0,
"boy Births": 0.5935315085862012,
"girl Births": 0.40646849141379887,
"state": "IL",
},
{
"All Births": 1.0,
"boy Births": 0.6041192663655611,
"girl Births": 0.3958807336344389,
"state": "MA",
},
{
"All Births": 1.0,
"boy Births": 0.5937482960898133,
"girl Births": 0.4062517039101867,
"state": "MI",
},
{
"All Births": 1.0,
"boy Births": 0.5995276800165239,
"girl Births": 0.40047231998347604,
"state": "NJ",
},
{
"All Births": 1.0,
"boy Births": 0.6084372844307357,
"girl Births": 0.39156271556926425,
"state": "NY",
},
{
"All Births": 1.0,
"boy Births": 0.5942152416021308,
"girl Births": 0.40578475839786915,
"state": "OH",
},
{
"All Births": 1.0,
"boy Births": 0.596724682935987,
"girl Births": 0.40327531706401293,
"state": "PA",
},
{
"All Births": 1.0,
"boy Births": 0.5887794344385264,
"girl Births": 0.41122056556147357,
"state": "TX",
},
{
"All Births": 1.0,
"boy Births": 0.5941503507105172,
"girl Births": 0.40584964928948275,
"state": "other",
},
],
"applied_filters": [],
"rejected_filters": [],
}
],
}