feat: add rolling window support to 'Big Number with Trendline' viz (#9107)

* Rolling big number

* addressing review comments
This commit is contained in:
Maxime Beauchemin 2020-03-10 10:19:12 -07:00 committed by GitHub
parent 753aeb4829
commit c04d6163e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 146 additions and 45 deletions

View File

@ -17,6 +17,7 @@
* under the License.
*/
import { t } from '@superset-ui/translation';
import React from 'react';
export default {
controlPanelSections: [
@ -43,6 +44,14 @@ export default {
['subheader_font_size'],
],
},
{
label: t('Advanced Analytics'),
expanded: false,
controlSetRows: [
[<h1 className="section-header">{t('Rolling Window')}</h1>],
['rolling_type', 'rolling_periods', 'min_periods'],
],
},
],
controlOverrides: {
y_axis_format: {

View File

@ -75,7 +75,7 @@ export const NVD3TimeSeries = [
'of query results',
),
controlSetRows: [
[<h1 className="section-header">{t('Moving Average')}</h1>],
[<h1 className="section-header">{t('Rolling Window')}</h1>],
['rolling_type', 'rolling_periods', 'min_periods'],
[<h1 className="section-header">{t('Time Comparison')}</h1>],
['time_compare', 'comparison_type'],

View File

@ -1126,7 +1126,7 @@ export const controls = {
rolling_type: {
type: 'SelectControl',
label: t('Rolling'),
label: t('Rolling Function'),
default: 'None',
choices: formatSelectOptions(['None', 'mean', 'sum', 'std', 'cumsum']),
description: t(

View File

@ -106,22 +106,23 @@ def load_birth_names(only_metadata=False, force=False):
obj.fetch_metadata()
tbl = obj
metrics = [
{
"expressionType": "SIMPLE",
"column": {"column_name": "num", "type": "BIGINT"},
"aggregate": "SUM",
"label": "Births",
"optionName": "metric_11",
}
]
metric = "sum__num"
defaults = {
"compare_lag": "10",
"compare_suffix": "o10Y",
"limit": "25",
"granularity_sqla": "ds",
"groupby": [],
"metric": "sum__num",
"metrics": [
{
"expressionType": "SIMPLE",
"column": {"column_name": "num", "type": "BIGINT"},
"aggregate": "SUM",
"label": "Births",
"optionName": "metric_11",
}
],
"row_limit": config["ROW_LIMIT"],
"since": "100 years ago",
"until": "now",
@ -144,6 +145,7 @@ def load_birth_names(only_metadata=False, force=False):
granularity_sqla="ds",
compare_lag="5",
compare_suffix="over 5Y",
metric=metric,
),
),
Slice(
@ -151,7 +153,9 @@ def load_birth_names(only_metadata=False, force=False):
viz_type="pie",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(defaults, viz_type="pie", groupby=["gender"]),
params=get_slice_json(
defaults, viz_type="pie", groupby=["gender"], metric=metric
),
),
Slice(
slice_name="Trends",
@ -165,6 +169,7 @@ def load_birth_names(only_metadata=False, force=False):
granularity_sqla="ds",
rich_tooltip=True,
show_legend=True,
metrics=metrics,
),
),
Slice(
@ -215,6 +220,7 @@ def load_birth_names(only_metadata=False, force=False):
adhoc_filters=[gen_filter("gender", "girl")],
row_limit=50,
timeseries_limit_metric="sum__num",
metrics=metrics,
),
),
Slice(
@ -231,6 +237,7 @@ def load_birth_names(only_metadata=False, force=False):
rotation="square",
limit="100",
adhoc_filters=[gen_filter("gender", "girl")],
metric=metric,
),
),
Slice(
@ -243,6 +250,7 @@ def load_birth_names(only_metadata=False, force=False):
groupby=["name"],
adhoc_filters=[gen_filter("gender", "boy")],
row_limit=50,
metrics=metrics,
),
),
Slice(
@ -259,6 +267,7 @@ def load_birth_names(only_metadata=False, force=False):
rotation="square",
limit="100",
adhoc_filters=[gen_filter("gender", "boy")],
metric=metric,
),
),
Slice(
@ -276,6 +285,7 @@ def load_birth_names(only_metadata=False, force=False):
time_grain_sqla="P1D",
viz_type="area",
x_axis_forma="smart_date",
metrics=metrics,
),
),
Slice(
@ -293,6 +303,7 @@ def load_birth_names(only_metadata=False, force=False):
time_grain_sqla="P1D",
viz_type="area",
x_axis_forma="smart_date",
metrics=metrics,
),
),
]
@ -314,6 +325,7 @@ def load_birth_names(only_metadata=False, force=False):
},
metric_2="sum__num",
granularity_sqla="ds",
metrics=metrics,
),
),
Slice(
@ -321,7 +333,7 @@ def load_birth_names(only_metadata=False, force=False):
viz_type="line",
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(defaults, viz_type="line"),
params=get_slice_json(defaults, viz_type="line", metrics=metrics),
),
Slice(
slice_name="Daily Totals",
@ -335,6 +347,7 @@ def load_birth_names(only_metadata=False, force=False):
since="40 years ago",
until="now",
viz_type="table",
metrics=metrics,
),
),
Slice(
@ -397,6 +410,7 @@ def load_birth_names(only_metadata=False, force=False):
datasource_id=tbl.id,
params=get_slice_json(
defaults,
metrics=metrics,
groupby=["name"],
row_limit=50,
timeseries_limit_metric={
@ -417,6 +431,7 @@ def load_birth_names(only_metadata=False, force=False):
datasource_id=tbl.id,
params=get_slice_json(
defaults,
metric=metric,
viz_type="big_number_total",
granularity_sqla="ds",
adhoc_filters=[gen_filter("gender", "girl")],
@ -429,7 +444,11 @@ def load_birth_names(only_metadata=False, force=False):
datasource_type="table",
datasource_id=tbl.id,
params=get_slice_json(
defaults, viz_type="pivot_table", groupby=["name"], columns=["state"]
defaults,
viz_type="pivot_table",
groupby=["name"],
columns=["state"],
metrics=metrics,
),
),
]

View File

@ -97,31 +97,32 @@ def load_world_bank_health_n_pop(
db.session.commit()
tbl.fetch_metadata()
metric = "sum__SP_POP_TOTL"
metrics = ["sum__SP_POP_TOTL"]
secondary_metric = {
"aggregate": "SUM",
"column": {
"column_name": "SP_RUR_TOTL",
"optionName": "_col_SP_RUR_TOTL",
"type": "DOUBLE",
},
"expressionType": "SIMPLE",
"hasCustomLabel": True,
"label": "Rural Population",
}
defaults = {
"compare_lag": "10",
"compare_suffix": "o10Y",
"limit": "25",
"granularity_sqla": "year",
"groupby": [],
"metric": "sum__SP_POP_TOTL",
"metrics": ["sum__SP_POP_TOTL"],
"row_limit": config["ROW_LIMIT"],
"since": "2014-01-01",
"until": "2014-01-02",
"time_range": "2014-01-01 : 2014-01-02",
"markup_type": "markdown",
"country_fieldtype": "cca3",
"secondary_metric": {
"aggregate": "SUM",
"column": {
"column_name": "SP_RUR_TOTL",
"optionName": "_col_SP_RUR_TOTL",
"type": "DOUBLE",
},
"expressionType": "SIMPLE",
"hasCustomLabel": True,
"label": "Rural Population",
},
"entity": "country_code",
"show_bubbles": True,
}
@ -207,6 +208,7 @@ def load_world_bank_health_n_pop(
viz_type="world_map",
metric="sum__SP_RUR_TOTL_ZS",
num_period_compare="10",
secondary_metric=secondary_metric,
),
),
Slice(
@ -264,6 +266,8 @@ def load_world_bank_health_n_pop(
groupby=["region", "country_name"],
since="2011-01-01",
until="2011-01-01",
metric=metric,
secondary_metric=secondary_metric,
),
),
Slice(
@ -277,6 +281,7 @@ def load_world_bank_health_n_pop(
until="now",
viz_type="area",
groupby=["region"],
metrics=metrics,
),
),
Slice(
@ -292,6 +297,7 @@ def load_world_bank_health_n_pop(
x_ticks_layout="staggered",
viz_type="box_plot",
groupby=["region"],
metrics=metrics,
),
),
Slice(

View File

@ -178,6 +178,26 @@ class BaseViz:
"""
pass
def apply_rolling(self, df):
    """Apply the rolling-window transform selected in the chart's form data.

    Reads ``rolling_type`` ("mean", "std", "sum", or "cumsum"),
    ``rolling_periods`` (window size), and ``min_periods`` from
    ``self.form_data``. Window functions only apply when both a supported
    ``rolling_type`` and a non-zero window are set; ``cumsum`` ignores the
    window entirely. When ``min_periods`` is set, the first ``min_periods``
    rows are dropped from the result regardless of rolling type — this
    mirrors the behavior of the NVD3 time-series implementation.

    :param df: time-indexed DataFrame of metric values
    :returns: transformed DataFrame (possibly the input, untouched)
    """
    form_data = self.form_data
    rolling_type = form_data.get("rolling_type")
    window = int(form_data.get("rolling_periods") or 0)
    min_periods = int(form_data.get("min_periods") or 0)

    if rolling_type in ("mean", "std", "sum") and window:
        # Dispatch on the aggregation name instead of an if/elif chain;
        # pandas exposes mean/std/sum as methods of the Rolling object.
        roller = df.rolling(window=window, min_periods=min_periods)
        df = getattr(roller, rolling_type)()
    elif rolling_type == "cumsum":
        df = df.cumsum()

    if min_periods:
        # Trim the leading rows that lack enough history.
        df = df[min_periods:]
    return df
def get_samples(self):
query_obj = self.query_obj()
query_obj.update(
@ -1101,6 +1121,18 @@ class BigNumberViz(BaseViz):
self.form_data["metric"] = metric
return d
def get_data(self, df: pd.DataFrame) -> VizData:
    """Pivot the result set onto the time axis, apply any configured
    rolling window, then defer to the base implementation.

    The pivot collapses rows to one per timestamp (summing duplicate
    timestamps, filling gaps with 0) so the rolling window operates on a
    regular time series. The time index is copied back into a DTTM_ALIAS
    column because the base ``get_data`` expects it as a column.
    """
    pivoted = df.pivot_table(
        index=DTTM_ALIAS,
        columns=[],
        values=self.metric_labels,
        fill_value=0,
        aggfunc=sum,
    )
    pivoted = self.apply_rolling(pivoted)
    pivoted[DTTM_ALIAS] = pivoted.index
    return super().get_data(pivoted)
class BigNumberTotalViz(BaseViz):
@ -1225,23 +1257,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
dfs.sort_values(ascending=False, inplace=True)
df = df[dfs.index]
rolling_type = fd.get("rolling_type")
rolling_periods = int(fd.get("rolling_periods") or 0)
min_periods = int(fd.get("min_periods") or 0)
if rolling_type in ("mean", "std", "sum") and rolling_periods:
kwargs = dict(window=rolling_periods, min_periods=min_periods)
if rolling_type == "mean":
df = df.rolling(**kwargs).mean()
elif rolling_type == "std":
df = df.rolling(**kwargs).std()
elif rolling_type == "sum":
df = df.rolling(**kwargs).sum()
elif rolling_type == "cumsum":
df = df.cumsum()
if min_periods:
df = df[min_periods:]
df = self.apply_rolling(df)
if fd.get("contribution"):
dft = df.T
df = (dft / dft.sum()).T

View File

@ -1192,3 +1192,54 @@ class TimeSeriesVizTestCase(SupersetTestCase):
.tolist(),
[1.0, 2.0, np.nan, np.nan, 5.0, np.nan, 7.0],
)
def test_apply_rolling(self):
    """BigNumberViz.apply_rolling honors cumsum / sum / mean settings."""
    datasource = self.get_datasource_mock()
    df = pd.DataFrame(
        index=pd.to_datetime(
            ["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]
        ),
        data={"y": [1.0, 2.0, 3.0, 4.0]},
    )

    def rolled(rolling_type, rolling_periods):
        # Build a viz with the given rolling settings and return the
        # transformed "y" series as a plain list for comparison.
        form_data = {
            "metrics": ["y"],
            "rolling_type": rolling_type,
            "rolling_periods": rolling_periods,
            "min_periods": 0,
        }
        big_number = viz.BigNumberViz(datasource, form_data)
        return big_number.apply_rolling(df)["y"].tolist()

    self.assertEqual(rolled("cumsum", 0), [1.0, 3.0, 6.0, 10.0])
    self.assertEqual(rolled("sum", 2), [1.0, 3.0, 5.0, 7.0])
    self.assertEqual(rolled("mean", 10), [1.0, 1.5, 2.0, 2.5])