feat: support non-numeric columns in pivot table (#10389)
* fix: support non-numeric columns in pivot table * bump package and add unit tests * mypy
This commit is contained in:
parent
5e93f00a53
commit
fc28c92f57
|
|
@ -29,7 +29,18 @@ import uuid
|
|||
from collections import defaultdict, OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from itertools import product
|
||||
from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
|
||||
import dataclasses
|
||||
import geohash
|
||||
|
|
@ -734,6 +745,7 @@ class PivotTableViz(BaseViz):
|
|||
verbose_name = _("Pivot Table")
|
||||
credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
|
||||
is_timeseries = False
|
||||
enforce_numerical_metrics = False
|
||||
|
||||
def query_obj(self) -> QueryObjectDict:
|
||||
d = super().query_obj()
|
||||
|
|
@ -764,6 +776,18 @@ class PivotTableViz(BaseViz):
|
|||
raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap"))
|
||||
return d
|
||||
|
||||
@staticmethod
|
||||
def get_aggfunc(
|
||||
metric: str, df: pd.DataFrame, form_data: Dict[str, Any]
|
||||
) -> Union[str, Callable[[Any], Any]]:
|
||||
aggfunc = form_data.get("pandas_aggfunc") or "sum"
|
||||
if pd.api.types.is_numeric_dtype(df[metric]):
|
||||
# Ensure that Pandas's sum function mimics that of SQL.
|
||||
if aggfunc == "sum":
|
||||
return lambda x: x.sum(min_count=1)
|
||||
# only min and max work properly for non-numerics
|
||||
return aggfunc if aggfunc in ("min", "max") else "max"
|
||||
|
||||
def get_data(self, df: pd.DataFrame) -> VizData:
|
||||
if df.empty:
|
||||
return None
|
||||
|
|
@ -771,22 +795,21 @@ class PivotTableViz(BaseViz):
|
|||
if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df:
|
||||
del df[DTTM_ALIAS]
|
||||
|
||||
aggfunc = self.form_data.get("pandas_aggfunc") or "sum"
|
||||
|
||||
# Ensure that Pandas's sum function mimics that of SQL.
|
||||
if aggfunc == "sum":
|
||||
aggfunc = lambda x: x.sum(min_count=1)
|
||||
metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
|
||||
aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
|
||||
for metric in metrics:
|
||||
aggfuncs[metric] = self.get_aggfunc(metric, df, self.form_data)
|
||||
|
||||
groupby = self.form_data.get("groupby")
|
||||
columns = self.form_data.get("columns")
|
||||
if self.form_data.get("transpose_pivot"):
|
||||
groupby, columns = columns, groupby
|
||||
metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
|
||||
|
||||
df = df.pivot_table(
|
||||
index=groupby,
|
||||
columns=columns,
|
||||
values=metrics,
|
||||
aggfunc=aggfunc,
|
||||
aggfunc=aggfuncs,
|
||||
margins=self.form_data.get("pivot_margins"),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1292,3 +1292,41 @@ class TestBigNumberViz(SupersetTestCase):
|
|||
)
|
||||
data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df)
|
||||
assert np.isnan(data[2]["y"])
|
||||
|
||||
|
||||
class TestPivotTableViz(SupersetTestCase):
|
||||
df = pd.DataFrame(
|
||||
data={
|
||||
"intcol": [1, 2, 3, None],
|
||||
"floatcol": [0.1, 0.2, 0.3, None],
|
||||
"strcol": ["a", "b", "c", None],
|
||||
}
|
||||
)
|
||||
|
||||
def test_get_aggfunc_numeric(self):
|
||||
# is a sum function
|
||||
func = viz.PivotTableViz.get_aggfunc("intcol", self.df, {})
|
||||
assert hasattr(func, "__call__")
|
||||
assert func(self.df["intcol"]) == 6
|
||||
|
||||
assert (
|
||||
viz.PivotTableViz.get_aggfunc("intcol", self.df, {"pandas_aggfunc": "min"})
|
||||
== "min"
|
||||
)
|
||||
assert (
|
||||
viz.PivotTableViz.get_aggfunc(
|
||||
"floatcol", self.df, {"pandas_aggfunc": "max"}
|
||||
)
|
||||
== "max"
|
||||
)
|
||||
|
||||
def test_get_aggfunc_non_numeric(self):
|
||||
assert viz.PivotTableViz.get_aggfunc("strcol", self.df, {}) == "max"
|
||||
assert (
|
||||
viz.PivotTableViz.get_aggfunc("strcol", self.df, {"pandas_aggfunc": "sum"})
|
||||
== "max"
|
||||
)
|
||||
assert (
|
||||
viz.PivotTableViz.get_aggfunc("strcol", self.df, {"pandas_aggfunc": "min"})
|
||||
== "min"
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue