feat: Add post processing to QueryObject (#9427)

* Add post processing to QueryObject

* Simplify sort signature and require explicit sort order

* Add new operations and unit tests

* linting

* Address comments

* Simplify test method names

* Address comments

* Linting

* remove unnecessary logic

* Apply strict whitelisting to all getattr calls

* Add checking of rolling_type_options and add/improve docs
This commit is contained in:
Ville Brofeldt 2020-04-10 20:50:11 +03:00 committed by GitHub
parent 5ec0192bcc
commit a8ce3bccdf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 899 additions and 12 deletions

View File

@ -52,7 +52,7 @@ marshmallow==2.19.5 # via flask-appbuilder, marshmallow-enum, marshmallow-
more-itertools==8.1.0 # via zipp
msgpack==0.6.2 # via apache-superset (setup.py)
numpy==1.18.1 # via pandas, pyarrow
pandas==0.25.3 # via apache-superset (setup.py)
pandas==1.0.3 # via apache-superset (setup.py)
parsedatetime==2.5 # via apache-superset (setup.py)
pathlib2==2.3.5 # via apache-superset (setup.py)
polyline==1.4.0 # via apache-superset (setup.py)

View File

@ -88,7 +88,7 @@ setup(
"isodate",
"markdown>=3.0",
"msgpack>=0.6.1, <0.7.0",
"pandas>=0.25.3, <1.0",
"pandas>=1.0.3, <1.1",
"parsedatetime",
"pathlib2",
"polyline",

View File

@ -51,7 +51,7 @@ class QueryContext:
custom_cache_timeout: Optional[int]
# TODO: Type datasource and query_object dictionary with TypedDict when it becomes
# a vanilla python type https://github.com/python/mypy/issues/5288
# a vanilla python type https://github.com/python/mypy/issues/5288
def __init__(
self,
datasource: Dict[str, Any],
@ -70,8 +70,8 @@ class QueryContext:
"""Returns a pandas dataframe based on the query object"""
# Here, we assume that all the queries will use the same datasource, which is
# is a valid assumption for current setting. In a long term, we may or maynot
# support multiple queries from different data source.
# a valid assumption for current setting. In the long term, we may
# support multiple queries from different data sources.
timestamp_format = None
if self.datasource.type == "table":
@ -105,6 +105,9 @@ class QueryContext:
self.df_metrics_to_num(df, query_object)
df.replace([np.inf, -np.inf], np.nan)
df = query_object.exec_post_processing(df)
return {
"query": result.query,
"status": result.status,

View File

@ -20,13 +20,16 @@ from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Union
import simplejson as json
from flask_babel import gettext as _
from pandas import DataFrame
from superset import app
from superset.utils import core as utils
from superset.exceptions import QueryObjectValidationError
from superset.utils import core as utils, pandas_postprocessing
from superset.views.utils import get_time_range_endpoints
# TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
# https://github.com/python/mypy/issues/5288
# https://github.com/python/mypy/issues/5288
class QueryObject:
@ -50,6 +53,7 @@ class QueryObject:
extras: Dict
columns: List[str]
orderby: List[List]
post_processing: List[Dict[str, Any]]
def __init__(
self,
@ -67,6 +71,7 @@ class QueryObject:
extras: Optional[Dict] = None,
columns: Optional[List[str]] = None,
orderby: Optional[List[List]] = None,
post_processing: Optional[List[Dict[str, Any]]] = None,
relative_start: str = app.config["DEFAULT_RELATIVE_START_TIME"],
relative_end: str = app.config["DEFAULT_RELATIVE_END_TIME"],
):
@ -81,8 +86,9 @@ class QueryObject:
self.time_range = time_range
self.time_shift = utils.parse_human_timedelta(time_shift)
self.groupby = groupby or []
self.post_processing = post_processing or []
# Temporal solution for backward compatability issue due the new format of
# Temporary solution for backward compatibility issue due the new format of
# non-ad-hoc metric which needs to adhere to superset-ui per
# https://git.io/Jvm7P.
self.metrics = [
@ -138,9 +144,37 @@ class QueryObject:
if self.time_range:
cache_dict["time_range"] = self.time_range
json_data = self.json_dumps(cache_dict, sort_keys=True)
if self.post_processing:
cache_dict["post_processing"] = self.post_processing
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
return json.dumps(
obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
)
def exec_post_processing(self, df: DataFrame) -> DataFrame:
"""
Perform post processing operations on DataFrame.
:param df: DataFrame returned from database model.
:return: new DataFrame to which all post processing operations have been
applied
:raises ChartDataValidationError: If the post processing operation in incorrect
"""
for post_process in self.post_processing:
operation = post_process.get("operation")
if not operation:
raise QueryObjectValidationError(
_("`operation` property of post processing object undefined")
)
if not hasattr(pandas_postprocessing, operation):
raise QueryObjectValidationError(
_(
"Unsupported post processing operation: %(operation)s",
type=operation,
)
)
options = post_process.get("options", {})
df = getattr(pandas_postprocessing, operation)(df, **options)
return df

View File

@ -68,3 +68,7 @@ class CertificateException(SupersetException):
class DatabaseNotFound(SupersetException):
status = 400
class QueryObjectValidationError(SupersetException):
status = 400

View File

@ -0,0 +1,389 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
from flask_babel import gettext as _
from pandas import DataFrame, NamedAgg
from superset.exceptions import QueryObjectValidationError
WHITELIST_NUMPY_FUNCTIONS = (
"average",
"argmin",
"argmax",
"cumsum",
"cumprod",
"max",
"mean",
"median",
"nansum",
"nanmin",
"nanmax",
"nanmean",
"nanmedian",
"min",
"percentile",
"prod",
"product",
"std",
"sum",
"var",
)
WHITELIST_ROLLING_FUNCTIONS = (
"count",
"corr",
"cov",
"kurt",
"max",
"mean",
"median",
"min",
"std",
"skew",
"sum",
"var",
"quantile",
)
WHITELIST_CUMULATIVE_FUNCTIONS = (
"cummax",
"cummin",
"cumprod",
"cumsum",
)
def validate_column_args(*argnames: str) -> Callable:
def wrapper(func):
def wrapped(df, **options):
columns = df.columns.tolist()
for name in argnames:
if name in options and not all(
elem in columns for elem in options[name]
):
raise QueryObjectValidationError(
_("Referenced columns not available in DataFrame.")
)
return func(df, **options)
return wrapped
return wrapper
def _get_aggregate_funcs(
df: DataFrame, aggregates: Dict[str, Dict[str, Any]],
) -> Dict[str, NamedAgg]:
"""
Converts a set of aggregate config objects into functions that pandas can use as
aggregators. Currently only numpy aggregators are supported.
:param df: DataFrame on which to perform aggregate operation.
:param aggregates: Mapping from column name to aggregat config.
:return: Mapping from metric name to function that takes a single input argument.
"""
agg_funcs: Dict[str, NamedAgg] = {}
for name, agg_obj in aggregates.items():
column = agg_obj.get("column", name)
if column not in df:
raise QueryObjectValidationError(
_(
"Column referenced by aggregate is undefined: %(column)s",
column=column,
)
)
if "operator" not in agg_obj:
raise QueryObjectValidationError(
_("Operator undefined for aggregator: %(name)s", name=name,)
)
operator = agg_obj["operator"]
if operator not in WHITELIST_NUMPY_FUNCTIONS or not hasattr(np, operator):
raise QueryObjectValidationError(
_("Invalid numpy function: %(operator)s", operator=operator,)
)
func = getattr(np, operator)
options = agg_obj.get("options", {})
agg_funcs[name] = NamedAgg(column=column, aggfunc=partial(func, **options))
return agg_funcs
def _append_columns(
base_df: DataFrame, append_df: DataFrame, columns: Dict[str, str]
) -> DataFrame:
"""
Function for adding columns from one DataFrame to another DataFrame. Calls the
assign method, which overwrites the original column in `base_df` if the column
already exists, and appends the column if the name is not defined.
:param base_df: DataFrame which to use as the base
:param append_df: DataFrame from which to select data.
:param columns: columns on which to append, mapping source column to
target column. For instance, `{'y': 'y'}` will replace the values in
column `y` in `base_df` with the values in `y` in `append_df`,
while `{'y': 'y2'}` will add a column `y2` to `base_df` based
on values in column `y` in `append_df`, leaving the original column `y`
in `base_df` unchanged.
:return: new DataFrame with combined data from `base_df` and `append_df`
"""
return base_df.assign(
**{
target: append_df[append_df.columns[idx]]
for idx, target in enumerate(columns.values())
}
)
@validate_column_args("index", "columns")
def pivot( # pylint: disable=too-many-arguments
df: DataFrame,
index: List[str],
columns: List[str],
aggregates: Dict[str, Dict[str, Any]],
metric_fill_value: Optional[Any] = None,
column_fill_value: Optional[str] = None,
drop_missing_columns: Optional[bool] = True,
combine_value_with_metric=False,
marginal_distributions: Optional[bool] = None,
marginal_distribution_name: Optional[str] = None,
) -> DataFrame:
"""
Perform a pivot operation on a DataFrame.
:param df: Object on which pivot operation will be performed
:param index: Columns to group by on the table index (=rows)
:param columns: Columns to group by on the table columns
:param metric_fill_value: Value to replace missing values with
:param column_fill_value: Value to replace missing pivot columns with
:param drop_missing_columns: Do not include columns whose entries are all missing
:param combine_value_with_metric: Display metrics side by side within each column,
as opposed to each column being displayed side by side for each metric.
:param aggregates: A mapping from aggregate column name to the the aggregate
config.
:param marginal_distributions: Add totals for row/column. Default to False
:param marginal_distribution_name: Name of row/column with marginal distribution.
Default to 'All'.
:return: A pivot table
:raises ChartDataValidationError: If the request in incorrect
"""
if not index:
raise QueryObjectValidationError(
_("Pivot operation requires at least one index")
)
if not columns:
raise QueryObjectValidationError(
_("Pivot operation requires at least one column")
)
if not aggregates:
raise QueryObjectValidationError(
_("Pivot operation must include at least one aggregate")
)
if column_fill_value:
df[columns] = df[columns].fillna(value=column_fill_value)
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
# TODO (villebro): Pandas 1.0.3 doesn't yet support NamedAgg in pivot_table.
# Remove once/if support is added.
aggfunc = {na.column: na.aggfunc for na in aggregate_funcs.values()}
df = df.pivot_table(
values=aggfunc.keys(),
index=index,
columns=columns,
aggfunc=aggfunc,
fill_value=metric_fill_value,
dropna=drop_missing_columns,
margins=marginal_distributions,
margins_name=marginal_distribution_name,
)
if combine_value_with_metric:
df = df.stack(0).unstack()
return df
@validate_column_args("groupby")
def aggregate(
df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]]
) -> DataFrame:
"""
Apply aggregations to a DataFrame.
:param df: Object to aggregate.
:param groupby: columns to aggregate
:param aggregates: A mapping from metric column to the function used to
aggregate values.
:raises ChartDataValidationError: If the request in incorrect
"""
aggregates = aggregates or {}
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
return df.groupby(by=groupby).agg(**aggregate_funcs).reset_index()
@validate_column_args("columns")
def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
"""
Sort a DataFrame.
:param df: DataFrame to sort.
:param columns: columns by by which to sort. The key specifies the column name,
value specifies if sorting in ascending order.
:return: Sorted DataFrame
:raises ChartDataValidationError: If the request in incorrect
"""
return df.sort_values(by=list(columns.keys()), ascending=list(columns.values()))
@validate_column_args("columns")
def rolling( # pylint: disable=too-many-arguments
df: DataFrame,
columns: Dict[str, str],
rolling_type: str,
window: int,
rolling_type_options: Optional[Dict[str, Any]] = None,
center: bool = False,
win_type: Optional[str] = None,
min_periods: Optional[int] = None,
) -> DataFrame:
"""
Apply a rolling window on the dataset. See the Pandas docs for further details:
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.rolling.html
:param df: DataFrame on which the rolling period will be based.
:param columns: columns on which to perform rolling, mapping source column to
target column. For instance, `{'y': 'y'}` will replace the column `y` with
the rolling value in `y`, while `{'y': 'y2'}` will add a column `y2` based
on rolling values calculated from `y`, leaving the original column `y`
unchanged.
:param rolling_type: Type of rolling window. Any numpy function will work.
:param rolling_type_options: Optional options to pass to rolling method. Needed
for e.g. quantile operation.
:param center: Should the label be at the center of the window.
:param win_type: Type of window function.
:param window: Size of the window.
:param min_periods:
:return: DataFrame with the rolling columns
:raises ChartDataValidationError: If the request in incorrect
"""
rolling_type_options = rolling_type_options or {}
df_rolling = df[columns.keys()]
kwargs: Dict[str, Union[str, int]] = {}
if not window:
raise QueryObjectValidationError(_("Undefined window for rolling operation"))
kwargs["window"] = window
if min_periods is not None:
kwargs["min_periods"] = min_periods
if center is not None:
kwargs["center"] = center
if win_type is not None:
kwargs["win_type"] = win_type
df_rolling = df_rolling.rolling(**kwargs)
if rolling_type not in WHITELIST_ROLLING_FUNCTIONS or not hasattr(
df_rolling, rolling_type
):
raise QueryObjectValidationError(
_("Invalid rolling_type: %(type)s", type=rolling_type)
)
try:
df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options)
except TypeError:
raise QueryObjectValidationError(
_(
"Invalid options for %(rolling_type)s: %(options)s",
rolling_type=rolling_type,
options=rolling_type_options,
)
)
df = _append_columns(df, df_rolling, columns)
if min_periods:
df = df[min_periods:]
return df
@validate_column_args("columns", "rename")
def select(
df: DataFrame, columns: List[str], rename: Optional[Dict[str, str]] = None
) -> DataFrame:
"""
Only select a subset of columns in the original dataset. Can be useful for
removing unnecessary intermediate results, renaming and reordering columns.
:param df: DataFrame on which the rolling period will be based.
:param columns: Columns which to select from the DataFrame, in the desired order.
If columns are renamed, the new column name should be referenced
here.
:param rename: columns which to rename, mapping source column to target column.
For instance, `{'y': 'y2'}` will rename the column `y` to
`y2`.
:return: Subset of columns in original DataFrame
:raises ChartDataValidationError: If the request in incorrect
"""
df_select = df[columns]
if rename is not None:
df_select = df_select.rename(columns=rename)
return df_select
@validate_column_args("columns")
def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame:
"""
:param df: DataFrame on which the diff will be based.
:param columns: columns on which to perform diff, mapping source column to
target column. For instance, `{'y': 'y'}` will replace the column `y` with
the diff value in `y`, while `{'y': 'y2'}` will add a column `y2` based
on diff values calculated from `y`, leaving the original column `y`
unchanged.
:param periods: periods to shift for calculating difference.
:return: DataFrame with diffed columns
:raises ChartDataValidationError: If the request in incorrect
"""
df_diff = df[columns.keys()]
df_diff = df_diff.diff(periods=periods)
return _append_columns(df, df_diff, columns)
@validate_column_args("columns")
def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
"""
:param df: DataFrame on which the cumulative operation will be based.
:param columns: columns on which to perform a cumulative operation, mapping source
column to target column. For instance, `{'y': 'y'}` will replace the column
`y` with the cumulative value in `y`, while `{'y': 'y2'}` will add a column
`y2` based on cumulative values calculated from `y`, leaving the original
column `y` unchanged.
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
:return:
"""
df_cum = df[columns.keys()]
operation = "cum" + operator
if operation not in WHITELIST_CUMULATIVE_FUNCTIONS or not hasattr(
df_cum, operation
):
raise QueryObjectValidationError(
_("Invalid cumulative operator: %(operator)s", operator=operator)
)
return _append_columns(df, getattr(df_cum, operation)(), columns)

View File

@ -111,7 +111,7 @@ class CoreTests(SupersetTestCase):
resp = self.client.get("/superset/slice/-1/")
assert resp.status_code == 404
def _get_query_context_dict(self) -> Dict[str, Any]:
def _get_query_context(self) -> Dict[str, Any]:
self.login(username="admin")
slc = self.get_slice("Girl Name Cloud", db.session)
return {
@ -127,6 +127,45 @@ class CoreTests(SupersetTestCase):
],
}
def _get_query_context_with_post_processing(self) -> Dict[str, Any]:
self.login(username="admin")
slc = self.get_slice("Girl Name Cloud", db.session)
return {
"datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
"queries": [
{
"granularity": "ds",
"groupby": ["name", "state"],
"metrics": [{"label": "sum__num"}],
"filters": [],
"row_limit": 100,
"post_processing": [
{
"operation": "aggregate",
"options": {
"groupby": ["state"],
"aggregates": {
"q1": {
"operator": "percentile",
"column": "sum__num",
"options": {"q": 25},
},
"median": {
"operator": "median",
"column": "sum__num",
},
},
},
},
{
"operation": "sort",
"options": {"columns": {"q1": False, "state": True},},
},
],
}
],
}
def test_viz_cache_key(self):
self.login(username="admin")
slc = self.get_slice("Girls", db.session)
@ -140,7 +179,7 @@ class CoreTests(SupersetTestCase):
self.assertNotEqual(cache_key, viz.cache_key(qobj))
def test_cache_key_changes_when_datasource_is_updated(self):
qc_dict = self._get_query_context_dict()
qc_dict = self._get_query_context()
# construct baseline cache_key
query_context = QueryContext(**qc_dict)
@ -168,7 +207,7 @@ class CoreTests(SupersetTestCase):
self.assertNotEqual(cache_key_original, cache_key_new)
def test_query_context_time_range_endpoints(self):
query_context = QueryContext(**self._get_query_context_dict())
query_context = QueryContext(**self._get_query_context())
query_object = query_context.queries[0]
extras = query_object.to_dict()["extras"]
self.assertTrue("time_range_endpoints" in extras)
@ -217,11 +256,18 @@ class CoreTests(SupersetTestCase):
def test_api_v1_query_endpoint(self):
self.login(username="admin")
qc_dict = self._get_query_context_dict()
qc_dict = self._get_query_context()
data = json.dumps(qc_dict)
resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
self.assertEqual(resp[0]["rowcount"], 100)
def test_api_v1_query_endpoint_with_post_processing(self):
self.login(username="admin")
qc_dict = self._get_query_context_with_post_processing()
data = json.dumps(qc_dict)
resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
self.assertEqual(resp[0]["rowcount"], 6)
def test_old_slice_json_endpoint(self):
self.login(username="admin")
slc = self.get_slice("Girls", db.session)

121
tests/fixtures/dataframes.py vendored Normal file
View File

@ -0,0 +1,121 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import date
from pandas import DataFrame, to_datetime
names_df = DataFrame(
[
{
"dt": date(2020, 1, 2),
"name": "John",
"country": "United Kingdom",
"cars": 3,
"bikes": 1,
"seconds": 30,
},
{
"dt": date(2020, 1, 2),
"name": "Peter",
"country": "Sweden",
"cars": 4,
"bikes": 2,
"seconds": 1,
},
{
"dt": date(2020, 1, 3),
"name": "Mary",
"country": "Finland",
"cars": 5,
"bikes": 3,
"seconds": None,
},
{
"dt": date(2020, 1, 3),
"name": "Peter",
"country": "India",
"cars": 6,
"bikes": 4,
"seconds": 12,
},
{
"dt": date(2020, 1, 4),
"name": "John",
"country": "Portugal",
"cars": 7,
"bikes": None,
"seconds": 75,
},
{
"dt": date(2020, 1, 4),
"name": "Peter",
"country": "Italy",
"cars": None,
"bikes": 5,
"seconds": 600,
},
{
"dt": date(2020, 1, 4),
"name": "Mary",
"country": None,
"cars": 9,
"bikes": 6,
"seconds": 2,
},
{
"dt": date(2020, 1, 4),
"name": None,
"country": "Australia",
"cars": 10,
"bikes": 7,
"seconds": 99,
},
{
"dt": date(2020, 1, 1),
"name": "John",
"country": "USA",
"cars": 1,
"bikes": 8,
"seconds": None,
},
{
"dt": date(2020, 1, 1),
"name": "Mary",
"country": "Fiji",
"cars": 2,
"bikes": 9,
"seconds": 50,
},
]
)
categories_df = DataFrame(
{
"constant": ["dummy" for _ in range(0, 101)],
"category": [f"cat{i%3}" for i in range(0, 101)],
"dept": [f"dept{i%5}" for i in range(0, 101)],
"name": [f"person{i}" for i in range(0, 101)],
"asc_idx": [i for i in range(0, 101)],
"desc_idx": [i for i in range(100, -1, -1)],
"idx_nulls": [i if i % 5 == 0 else None for i in range(0, 101)],
}
)
timeseries_df = DataFrame(
index=to_datetime(["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]),
data={"label": ["x", "y", "z", "q"], "y": [1.0, 2.0, 3.0, 4.0]},
)

View File

@ -0,0 +1,290 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import math
from typing import Any, List
from pandas import Series
from superset.exceptions import QueryObjectValidationError
from superset.utils import pandas_postprocessing as proc
from .base_tests import SupersetTestCase
from .fixtures.dataframes import categories_df, timeseries_df
def series_to_list(series: Series) -> List[Any]:
"""
Converts a `Series` to a regular list, and replaces non-numeric values to
Nones.
:param series: Series to convert
:return: list without nan or inf
"""
return [
None
if not isinstance(val, str) and (math.isnan(val) or math.isinf(val))
else val
for val in series.tolist()
]
class PostProcessingTestCase(SupersetTestCase):
def test_pivot(self):
aggregates = {"idx_nulls": {"operator": "sum"}}
# regular pivot
df = proc.pivot(
df=categories_df,
index=["name"],
columns=["category"],
aggregates=aggregates,
)
self.assertListEqual(
df.columns.tolist(),
[("idx_nulls", "cat0"), ("idx_nulls", "cat1"), ("idx_nulls", "cat2")],
)
self.assertEqual(len(df), 101)
self.assertEqual(df.sum()[0], 315)
# regular pivot
df = proc.pivot(
df=categories_df,
index=["dept"],
columns=["category"],
aggregates=aggregates,
)
self.assertEqual(len(df), 5)
# fill value
df = proc.pivot(
df=categories_df,
index=["name"],
columns=["category"],
metric_fill_value=1,
aggregates={"idx_nulls": {"operator": "sum"}},
)
self.assertEqual(df.sum()[0], 382)
# invalid index reference
self.assertRaises(
QueryObjectValidationError,
proc.pivot,
df=categories_df,
index=["abc"],
columns=["dept"],
aggregates=aggregates,
)
# invalid column reference
self.assertRaises(
QueryObjectValidationError,
proc.pivot,
df=categories_df,
index=["dept"],
columns=["abc"],
aggregates=aggregates,
)
# invalid aggregate options
self.assertRaises(
QueryObjectValidationError,
proc.pivot,
df=categories_df,
index=["name"],
columns=["category"],
aggregates={"idx_nulls": {}},
)
def test_aggregate(self):
aggregates = {
"asc sum": {"column": "asc_idx", "operator": "sum"},
"asc q2": {
"column": "asc_idx",
"operator": "percentile",
"options": {"q": 75},
},
"desc q1": {
"column": "desc_idx",
"operator": "percentile",
"options": {"q": 25},
},
}
df = proc.aggregate(
df=categories_df, groupby=["constant"], aggregates=aggregates
)
self.assertListEqual(
df.columns.tolist(), ["constant", "asc sum", "asc q2", "desc q1"]
)
self.assertEqual(series_to_list(df["asc sum"])[0], 5050)
self.assertEqual(series_to_list(df["asc q2"])[0], 75)
self.assertEqual(series_to_list(df["desc q1"])[0], 25)
def test_sort(self):
df = proc.sort(df=categories_df, columns={"category": True, "asc_idx": False})
self.assertEqual(96, series_to_list(df["asc_idx"])[1])
self.assertRaises(
QueryObjectValidationError, proc.sort, df=df, columns={"abc": True}
)
def test_rolling(self):
# sum rolling type
post_df = proc.rolling(
df=timeseries_df,
columns={"y": "y"},
rolling_type="sum",
window=2,
min_periods=0,
)
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 3.0, 5.0, 7.0])
# mean rolling type with alias
post_df = proc.rolling(
df=timeseries_df,
rolling_type="mean",
columns={"y": "y_mean"},
window=10,
min_periods=0,
)
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y_mean"])
self.assertListEqual(series_to_list(post_df["y_mean"]), [1.0, 1.5, 2.0, 2.5])
# count rolling type
post_df = proc.rolling(
df=timeseries_df,
rolling_type="count",
columns={"y": "y"},
window=10,
min_periods=0,
)
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
# quantile rolling type
post_df = proc.rolling(
df=timeseries_df,
columns={"y": "q1"},
rolling_type="quantile",
rolling_type_options={"quantile": 0.25},
window=10,
min_periods=0,
)
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "q1"])
self.assertListEqual(series_to_list(post_df["q1"]), [1.0, 1.25, 1.5, 1.75])
# incorrect rolling type
self.assertRaises(
QueryObjectValidationError,
proc.rolling,
df=timeseries_df,
columns={"y": "y"},
rolling_type="abc",
window=2,
)
# incorrect rolling type options
self.assertRaises(
QueryObjectValidationError,
proc.rolling,
df=timeseries_df,
columns={"y": "y"},
rolling_type="quantile",
rolling_type_options={"abc": 123},
window=2,
)
def test_select(self):
# reorder columns
post_df = proc.select(df=timeseries_df, columns=["y", "label"])
self.assertListEqual(post_df.columns.tolist(), ["y", "label"])
# one column
post_df = proc.select(df=timeseries_df, columns=["label"])
self.assertListEqual(post_df.columns.tolist(), ["label"])
# rename one column
post_df = proc.select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
self.assertListEqual(post_df.columns.tolist(), ["y1"])
# rename one and leave one unchanged
post_df = proc.select(
df=timeseries_df, columns=["label", "y"], rename={"y": "y1"}
)
self.assertListEqual(post_df.columns.tolist(), ["label", "y1"])
# invalid columns
self.assertRaises(
QueryObjectValidationError,
proc.select,
df=timeseries_df,
columns=["qwerty"],
rename={"abc": "qwerty"},
)
def test_diff(self):
# overwrite column
post_df = proc.diff(df=timeseries_df, columns={"y": "y"})
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
self.assertListEqual(series_to_list(post_df["y"]), [None, 1.0, 1.0, 1.0])
# add column
post_df = proc.diff(df=timeseries_df, columns={"y": "y1"})
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y1"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
self.assertListEqual(series_to_list(post_df["y1"]), [None, 1.0, 1.0, 1.0])
# look ahead
post_df = proc.diff(df=timeseries_df, columns={"y": "y1"}, periods=-1)
self.assertListEqual(series_to_list(post_df["y1"]), [-1.0, -1.0, -1.0, None])
# invalid column reference
self.assertRaises(
QueryObjectValidationError,
proc.diff,
df=timeseries_df,
columns={"abc": "abc"},
)
def test_cum(self):
# create new column (cumsum)
post_df = proc.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y2"])
self.assertListEqual(series_to_list(post_df["label"]), ["x", "y", "z", "q"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
self.assertListEqual(series_to_list(post_df["y2"]), [1.0, 3.0, 6.0, 10.0])
# overwrite column (cumprod)
post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 6.0, 24.0])
# overwrite column (cummin)
post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 1.0, 1.0, 1.0])
# invalid operator
self.assertRaises(
QueryObjectValidationError,
proc.cum,
df=timeseries_df,
columns={"y": "y"},
operator="abc",
)