fix: rolling and cum operator on multiple series (#16945)
* fix: rolling and cum operator on multiple series
* add UT
* updates
parent 6dc00b3e3f
commit fd8461406d
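At a glance: `rolling` and `cum` previously always selected `df[columns.keys()]`, which broke when the incoming DataFrame was a pivot table carrying several series (MultiIndex columns). The new `is_pivot_df` flag lets both operators work on the whole pivoted frame and flatten the column names afterwards. Below is a minimal sketch of the call pattern this enables, mirroring the unit tests added in this commit (the `proc` alias and its import path are assumptions for illustration, not part of this diff):

from pandas import DataFrame, to_datetime
from superset.utils import pandas_postprocessing as proc  # assumed import path

df = DataFrame(
    {
        "dttm": to_datetime(["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02"]),
        "country": ["UK", "US", "UK", "US"],
        "sum_metric": [5, 6, 7, 8],
    }
)
# Keep the MultiIndex columns and the index so rolling() can see every series.
pivot_df = proc.pivot(
    df=df,
    index=["dttm"],
    columns=["country"],
    aggregates={"sum_metric": {"operator": "sum"}},
    flatten_columns=False,
    reset_index=False,
)
# is_pivot_df=True applies the window to all pivoted series at once.
rolling_df = proc.rolling(
    df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True
)
# Expected (per the tests below): rolling_df["UK"].to_list() == [5.0, 12.0]
#                                 rolling_df["US"].to_list() == [6.0, 14.0]
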
@@ -131,6 +131,9 @@ def _flatten_column_after_pivot(
 def validate_column_args(*argnames: str) -> Callable[..., Any]:
     def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
         def wrapped(df: DataFrame, **options: Any) -> Any:
+            if options.get("is_pivot_df"):
+                # skip validation when pivot Dataframe
+                return func(df, **options)
             columns = df.columns.tolist()
             for name in argnames:
                 if name in options and not all(
@@ -223,6 +226,7 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals
     marginal_distributions: Optional[bool] = None,
     marginal_distribution_name: Optional[str] = None,
     flatten_columns: bool = True,
+    reset_index: bool = True,
 ) -> DataFrame:
     """
     Perform a pivot operation on a DataFrame.
@@ -243,6 +247,7 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals
     :param marginal_distribution_name: Name of row/column with marginal distribution.
            Default to 'All'.
     :param flatten_columns: Convert column names to strings
+    :param reset_index: Convert index to column
     :return: A pivot table
     :raises QueryObjectValidationError: If the request in incorrect
     """
@@ -300,7 +305,8 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals
             _flatten_column_after_pivot(col, aggregates) for col in df.columns
         ]
     # return index as regular column
-    df.reset_index(level=0, inplace=True)
+    if reset_index:
+        df.reset_index(level=0, inplace=True)
     return df
 
 
@@ -343,13 +349,14 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
 @validate_column_args("columns")
 def rolling( # pylint: disable=too-many-arguments
     df: DataFrame,
-    columns: Dict[str, str],
     rolling_type: str,
+    columns: Optional[Dict[str, str]] = None,
     window: Optional[int] = None,
     rolling_type_options: Optional[Dict[str, Any]] = None,
     center: bool = False,
     win_type: Optional[str] = None,
     min_periods: Optional[int] = None,
+    is_pivot_df: bool = False,
 ) -> DataFrame:
     """
     Apply a rolling window on the dataset. See the Pandas docs for further details:
@@ -369,11 +376,16 @@ def rolling( # pylint: disable=too-many-arguments
     :param win_type: Type of window function.
     :param min_periods: The minimum amount of periods required for a row to be included
            in the result set.
+    :param is_pivot_df: Dataframe is pivoted or not
     :return: DataFrame with the rolling columns
     :raises QueryObjectValidationError: If the request in incorrect
     """
     rolling_type_options = rolling_type_options or {}
-    df_rolling = df[columns.keys()]
+    columns = columns or {}
+    if is_pivot_df:
+        df_rolling = df
+    else:
+        df_rolling = df[columns.keys()]
     kwargs: Dict[str, Union[str, int]] = {}
     if window is None:
         raise QueryObjectValidationError(_("Undefined window for rolling operation"))
@@ -405,10 +417,20 @@ def rolling( # pylint: disable=too-many-arguments
                 options=rolling_type_options,
             )
         ) from ex
-    df = _append_columns(df, df_rolling, columns)
+
+    if is_pivot_df:
+        agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
+        agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
+        df_rolling.columns = [
+            _flatten_column_after_pivot(col, agg) for col in df_rolling.columns
+        ]
+        df_rolling.reset_index(level=0, inplace=True)
+    else:
+        df_rolling = _append_columns(df, df_rolling, columns)
+
     if min_periods:
-        df = df[min_periods:]
-    return df
+        df_rolling = df_rolling[min_periods:]
+    return df_rolling
 
 
 @validate_column_args("columns", "drop", "rename")
@@ -524,7 +546,12 @@ def compare( # pylint: disable=too-many-arguments
 
 
 @validate_column_args("columns")
-def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
+def cum(
+    df: DataFrame,
+    operator: str,
+    columns: Optional[Dict[str, str]] = None,
+    is_pivot_df: bool = False,
+) -> DataFrame:
     """
     Calculate cumulative sum/product/min/max for select columns.
 
@@ -535,9 +562,14 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
            `y2` based on cumulative values calculated from `y`, leaving the original
            column `y` unchanged.
     :param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
+    :param is_pivot_df: Dataframe is pivoted or not
     :return: DataFrame with cumulated columns
     """
-    df_cum = df[columns.keys()]
+    columns = columns or {}
+    if is_pivot_df:
+        df_cum = df
+    else:
+        df_cum = df[columns.keys()]
     operation = "cum" + operator
     if operation not in ALLOWLIST_CUMULATIVE_FUNCTIONS or not hasattr(
         df_cum, operation
@@ -545,7 +577,17 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
         raise QueryObjectValidationError(
             _("Invalid cumulative operator: %(operator)s", operator=operator)
         )
-    return _append_columns(df, getattr(df_cum, operation)(), columns)
+    if is_pivot_df:
+        df_cum = getattr(df_cum, operation)()
+        agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
+        agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
+        df_cum.columns = [
+            _flatten_column_after_pivot(col, agg) for col in df_cum.columns
+        ]
+        df_cum.reset_index(level=0, inplace=True)
+    else:
+        df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
+    return df_cum
 
 
 def geohash_decode(
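Likewise for `cum`: with `is_pivot_df=True` the cumulative operator runs over every pivoted series at once. A minimal sketch reusing the hypothetical `pivot_df` from the example above (again mirroring the tests added below):

# With a single metric the flattened column names are just the pivot values
# ("UK", "US"); with multiple metrics they become "sum_metric, UK" style names.
cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True)
# Expected: cum_df["UK"].to_list() == [5.0, 12.0] and cum_df["US"].to_list() == [6.0, 14.0]
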
@@ -165,3 +165,19 @@ prophet_df = DataFrame(
         "b": [4, 3, 4.1, 3.95],
     }
 )
+
+single_metric_df = DataFrame(
+    {
+        "dttm": to_datetime(["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02",]),
+        "country": ["UK", "US", "UK", "US"],
+        "sum_metric": [5, 6, 7, 8],
+    }
+)
+multiple_metrics_df = DataFrame(
+    {
+        "dttm": to_datetime(["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02",]),
+        "country": ["UK", "US", "UK", "US"],
+        "sum_metric": [5, 6, 7, 8],
+        "count_metric": [1, 2, 3, 4],
+    }
+)
@@ -35,6 +35,8 @@ from superset.utils.core import (
 from .base_tests import SupersetTestCase
 from .fixtures.dataframes import (
     categories_df,
+    single_metric_df,
+    multiple_metrics_df,
     lonlat_df,
     names_df,
     timeseries_df,
@@ -305,6 +307,23 @@ class TestPostProcessing(SupersetTestCase):
         )
         self.assertTrue(np.isnan(df["metric, 1, 1"][0]))
 
+    def test_pivot_without_flatten_columns_and_reset_index(self):
+        df = proc.pivot(
+            df=single_metric_df,
+            index=["dttm"],
+            columns=["country"],
+            aggregates={"sum_metric": {"operator": "sum"}},
+            flatten_columns=False,
+            reset_index=False,
+        )
+        #             metric
+        # country         UK  US
+        # dttm
+        # 2019-01-01       5   6
+        # 2019-01-02       7   8
+        assert df.columns.to_list() == [("sum_metric", "UK"), ("sum_metric", "US")]
+        assert df.index.to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
+
     def test_aggregate(self):
         aggregates = {
             "asc sum": {"column": "asc_idx", "operator": "sum"},
@@ -405,6 +424,60 @@ class TestPostProcessing(SupersetTestCase):
             window=2,
         )
 
+    def test_rolling_with_pivot_df_and_single_metric(self):
+        pivot_df = proc.pivot(
+            df=single_metric_df,
+            index=["dttm"],
+            columns=["country"],
+            aggregates={"sum_metric": {"operator": "sum"}},
+            flatten_columns=False,
+            reset_index=False,
+        )
+        rolling_df = proc.rolling(
+            df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
+        )
+        #         dttm  UK  US
+        # 0 2019-01-01   5   6
+        # 1 2019-01-02  12  14
+        assert rolling_df["UK"].to_list() == [5.0, 12.0]
+        assert rolling_df["US"].to_list() == [6.0, 14.0]
+        assert (
+            rolling_df["dttm"].to_list()
+            == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
+        )
+
+        rolling_df = proc.rolling(
+            df=pivot_df, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True,
+        )
+        assert rolling_df.empty is True
+
+    def test_rolling_with_pivot_df_and_multiple_metrics(self):
+        pivot_df = proc.pivot(
+            df=multiple_metrics_df,
+            index=["dttm"],
+            columns=["country"],
+            aggregates={
+                "sum_metric": {"operator": "sum"},
+                "count_metric": {"operator": "sum"},
+            },
+            flatten_columns=False,
+            reset_index=False,
+        )
+        rolling_df = proc.rolling(
+            df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
+        )
+        #         dttm  count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
+        # 0 2019-01-01               1.0               2.0             5.0             6.0
+        # 1 2019-01-02               4.0               6.0            12.0            14.0
+        assert rolling_df["count_metric, UK"].to_list() == [1.0, 4.0]
+        assert rolling_df["count_metric, US"].to_list() == [2.0, 6.0]
+        assert rolling_df["sum_metric, UK"].to_list() == [5.0, 12.0]
+        assert rolling_df["sum_metric, US"].to_list() == [6.0, 14.0]
+        assert (
+            rolling_df["dttm"].to_list()
+            == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
+        )
+
     def test_select(self):
         # reorder columns
         post_df = proc.select(df=timeseries_df, columns=["y", "label"])
@@ -557,6 +630,51 @@ class TestPostProcessing(SupersetTestCase):
             operator="abc",
         )
 
+    def test_cum_with_pivot_df_and_single_metric(self):
+        pivot_df = proc.pivot(
+            df=single_metric_df,
+            index=["dttm"],
+            columns=["country"],
+            aggregates={"sum_metric": {"operator": "sum"}},
+            flatten_columns=False,
+            reset_index=False,
+        )
+        cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,)
+        #         dttm  UK  US
+        # 0 2019-01-01   5   6
+        # 1 2019-01-02  12  14
+        assert cum_df["UK"].to_list() == [5.0, 12.0]
+        assert cum_df["US"].to_list() == [6.0, 14.0]
+        assert (
+            cum_df["dttm"].to_list()
+            == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
+        )
+
+    def test_cum_with_pivot_df_and_multiple_metrics(self):
+        pivot_df = proc.pivot(
+            df=multiple_metrics_df,
+            index=["dttm"],
+            columns=["country"],
+            aggregates={
+                "sum_metric": {"operator": "sum"},
+                "count_metric": {"operator": "sum"},
+            },
+            flatten_columns=False,
+            reset_index=False,
+        )
+        cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,)
+        #         dttm  count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
+        # 0 2019-01-01                 1                 2               5               6
+        # 1 2019-01-02                 4                 6              12              14
+        assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0]
+        assert cum_df["count_metric, US"].to_list() == [2.0, 6.0]
+        assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0]
+        assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0]
+        assert (
+            cum_df["dttm"].to_list()
+            == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
+        )
+
     def test_geohash_decode(self):
         # decode lon/lat from geohash
         post_df = proc.geohash_decode(