feat: Add post processing to QueryObject (#9427)
* Add post processing to QueryObject
* Simplify sort signature and require explicit sort order
* Add new operations and unit tests
* Linting
* Address comments
* Simplify test method names
* Address comments
* Linting
* Remove unnecessary logic
* Apply strict whitelisting to all getattr calls
* Add checking of rolling_type_options and add/improve docs
parent 5ec0192bcc
commit a8ce3bccdf
requirements.txt

@@ -52,7 +52,7 @@ marshmallow==2.19.5 # via flask-appbuilder, marshmallow-enum, marshmallow-
 more-itertools==8.1.0 # via zipp
 msgpack==0.6.2 # via apache-superset (setup.py)
 numpy==1.18.1 # via pandas, pyarrow
-pandas==0.25.3 # via apache-superset (setup.py)
+pandas==1.0.3 # via apache-superset (setup.py)
 parsedatetime==2.5 # via apache-superset (setup.py)
 pathlib2==2.3.5 # via apache-superset (setup.py)
 polyline==1.4.0 # via apache-superset (setup.py)
setup.py
@@ -88,7 +88,7 @@ setup(
         "isodate",
         "markdown>=3.0",
         "msgpack>=0.6.1, <0.7.0",
-        "pandas>=0.25.3, <1.0",
+        "pandas>=1.0.3, <1.1",
         "parsedatetime",
         "pathlib2",
         "polyline",
superset/common/query_context.py
@@ -51,7 +51,7 @@ class QueryContext:
     custom_cache_timeout: Optional[int]

     # TODO: Type datasource and query_object dictionary with TypedDict when it becomes
     # a vanilla python type https://github.com/python/mypy/issues/5288
     def __init__(
         self,
         datasource: Dict[str, Any],
@@ -70,8 +70,8 @@ class QueryContext:
         """Returns a pandas dataframe based on the query object"""

         # Here, we assume that all the queries will use the same datasource, which is
-        # is a valid assumption for current setting. In a long term, we may or maynot
-        # support multiple queries from different data source.
+        # a valid assumption for current setting. In the long term, we may
+        # support multiple queries from different data sources.

         timestamp_format = None
         if self.datasource.type == "table":
@@ -105,6 +105,9 @@ class QueryContext:
         self.df_metrics_to_num(df, query_object)

         df.replace([np.inf, -np.inf], np.nan)
+
+        df = query_object.exec_post_processing(df)
+
         return {
             "query": result.query,
             "status": result.status,
superset/common/query_object.py
@@ -20,13 +20,16 @@ from datetime import datetime, timedelta
 from typing import Any, Dict, List, Optional, Union

 import simplejson as json
 from flask_babel import gettext as _
+from pandas import DataFrame

 from superset import app
-from superset.utils import core as utils
+from superset.exceptions import QueryObjectValidationError
+from superset.utils import core as utils, pandas_postprocessing
 from superset.views.utils import get_time_range_endpoints

 # TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
 # https://github.com/python/mypy/issues/5288


 class QueryObject:
@@ -50,6 +53,7 @@ class QueryObject:
     extras: Dict
     columns: List[str]
     orderby: List[List]
+    post_processing: List[Dict[str, Any]]

     def __init__(
         self,
@@ -67,6 +71,7 @@ class QueryObject:
         extras: Optional[Dict] = None,
         columns: Optional[List[str]] = None,
         orderby: Optional[List[List]] = None,
+        post_processing: Optional[List[Dict[str, Any]]] = None,
         relative_start: str = app.config["DEFAULT_RELATIVE_START_TIME"],
         relative_end: str = app.config["DEFAULT_RELATIVE_END_TIME"],
     ):
@@ -81,8 +86,9 @@ class QueryObject:
         self.time_range = time_range
         self.time_shift = utils.parse_human_timedelta(time_shift)
         self.groupby = groupby or []
+        self.post_processing = post_processing or []

-        # Temporal solution for backward compatability issue due the new format of
+        # Temporary solution for backward compatibility issue due the new format of
         # non-ad-hoc metric which needs to adhere to superset-ui per
         # https://git.io/Jvm7P.
         self.metrics = [
@@ -138,9 +144,37 @@ class QueryObject:
         if self.time_range:
             cache_dict["time_range"] = self.time_range
+        if self.post_processing:
+            cache_dict["post_processing"] = self.post_processing
         json_data = self.json_dumps(cache_dict, sort_keys=True)
         return hashlib.md5(json_data.encode("utf-8")).hexdigest()

     def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
         return json.dumps(
             obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
         )

+    def exec_post_processing(self, df: DataFrame) -> DataFrame:
+        """
+        Perform post processing operations on DataFrame.
+
+        :param df: DataFrame returned from database model.
+        :return: new DataFrame to which all post processing operations have been
+                 applied
+        :raises QueryObjectValidationError: If a post processing operation is
+                 incorrect
+        """
+        for post_process in self.post_processing:
+            operation = post_process.get("operation")
+            if not operation:
+                raise QueryObjectValidationError(
+                    _("`operation` property of post processing object undefined")
+                )
+            if not hasattr(pandas_postprocessing, operation):
+                raise QueryObjectValidationError(
+                    _(
+                        "Unsupported post processing operation: %(operation)s",
+                        operation=operation,
+                    )
+                )
+            options = post_process.get("options", {})
+            df = getattr(pandas_postprocessing, operation)(df, **options)
+        return df
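For orientation, each `post_processing` entry names a function in `superset.utils.pandas_postprocessing` plus its `options`. A minimal payload sketch, mirroring the test fixture added further down in this diff:

```python
# Minimal sketch of a post_processing payload as consumed by
# exec_post_processing; mirrors the test fixture later in this diff.
post_processing = [
    {
        "operation": "aggregate",
        "options": {
            "groupby": ["state"],
            "aggregates": {
                "q1": {
                    "operator": "percentile",
                    "column": "sum__num",
                    "options": {"q": 25},
                },
            },
        },
    },
    {"operation": "sort", "options": {"columns": {"q1": False}}},
]
```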
superset/exceptions.py
@@ -68,3 +68,7 @@ class CertificateException(SupersetException):

 class DatabaseNotFound(SupersetException):
     status = 400
+
+
+class QueryObjectValidationError(SupersetException):
+    status = 400
superset/utils/pandas_postprocessing.py
@@ -0,0 +1,389 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
from flask_babel import gettext as _
from pandas import DataFrame, NamedAgg

from superset.exceptions import QueryObjectValidationError

WHITELIST_NUMPY_FUNCTIONS = (
    "average",
    "argmin",
    "argmax",
    "cumsum",
    "cumprod",
    "max",
    "mean",
    "median",
    "nansum",
    "nanmin",
    "nanmax",
    "nanmean",
    "nanmedian",
    "min",
    "percentile",
    "prod",
    "product",
    "std",
    "sum",
    "var",
)

WHITELIST_ROLLING_FUNCTIONS = (
    "count",
    "corr",
    "cov",
    "kurt",
    "max",
    "mean",
    "median",
    "min",
    "std",
    "skew",
    "sum",
    "var",
    "quantile",
)

WHITELIST_CUMULATIVE_FUNCTIONS = (
    "cummax",
    "cummin",
    "cumprod",
    "cumsum",
)


def validate_column_args(*argnames: str) -> Callable:
    def wrapper(func):
        def wrapped(df, **options):
            columns = df.columns.tolist()
            for name in argnames:
                if name in options and not all(
                    elem in columns for elem in options[name]
                ):
                    raise QueryObjectValidationError(
                        _("Referenced columns not available in DataFrame.")
                    )
            return func(df, **options)

        return wrapped

    return wrapper

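A minimal sketch of the decorator's contract (hypothetical data): any option named in `argnames` may only reference columns that exist in the DataFrame.

```python
# Minimal sketch of validate_column_args (hypothetical data): any option named
# in argnames must reference only columns present in the DataFrame.
from pandas import DataFrame

from superset.utils.pandas_postprocessing import validate_column_args

@validate_column_args("columns")
def noop(df: DataFrame, columns=None) -> DataFrame:
    return df

df = DataFrame({"y": [1, 2, 3]})
noop(df, columns=["y"])    # passes validation
noop(df, columns=["abc"])  # raises QueryObjectValidationError
```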
def _get_aggregate_funcs(
    df: DataFrame, aggregates: Dict[str, Dict[str, Any]],
) -> Dict[str, NamedAgg]:
    """
    Converts a set of aggregate config objects into functions that pandas can use as
    aggregators. Currently only numpy aggregators are supported.

    :param df: DataFrame on which to perform aggregate operation.
    :param aggregates: Mapping from column name to aggregate config.
    :return: Mapping from metric name to function that takes a single input argument.
    """
    agg_funcs: Dict[str, NamedAgg] = {}
    for name, agg_obj in aggregates.items():
        column = agg_obj.get("column", name)
        if column not in df:
            raise QueryObjectValidationError(
                _(
                    "Column referenced by aggregate is undefined: %(column)s",
                    column=column,
                )
            )
        if "operator" not in agg_obj:
            raise QueryObjectValidationError(
                _("Operator undefined for aggregator: %(name)s", name=name,)
            )
        operator = agg_obj["operator"]
        if operator not in WHITELIST_NUMPY_FUNCTIONS or not hasattr(np, operator):
            raise QueryObjectValidationError(
                _("Invalid numpy function: %(operator)s", operator=operator,)
            )
        func = getattr(np, operator)
        options = agg_obj.get("options", {})
        agg_funcs[name] = NamedAgg(column=column, aggfunc=partial(func, **options))

    return agg_funcs


def _append_columns(
    base_df: DataFrame, append_df: DataFrame, columns: Dict[str, str]
) -> DataFrame:
    """
    Function for adding columns from one DataFrame to another DataFrame. Calls the
    assign method, which overwrites the original column in `base_df` if the column
    already exists, and appends the column if the name is not defined.

    :param base_df: DataFrame which to use as the base
    :param append_df: DataFrame from which to select data.
    :param columns: columns on which to append, mapping source column to
           target column. For instance, `{'y': 'y'}` will replace the values in
           column `y` in `base_df` with the values in `y` in `append_df`,
           while `{'y': 'y2'}` will add a column `y2` to `base_df` based
           on values in column `y` in `append_df`, leaving the original column `y`
           in `base_df` unchanged.
    :return: new DataFrame with combined data from `base_df` and `append_df`
    """
    return base_df.assign(
        **{
            target: append_df[append_df.columns[idx]]
            for idx, target in enumerate(columns.values())
        }
    )

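Illustratively, the helper turns each aggregate config into a `NamedAgg` whose callable is a numpy function with its options pre-bound; roughly:

```python
# Roughly what _get_aggregate_funcs produces for a percentile aggregate
# (illustrative values): the numpy callable carries its options via partial.
from functools import partial

import numpy as np
from pandas import NamedAgg

agg_funcs = {
    "q1": NamedAgg(column="sum__num", aggfunc=partial(np.percentile, q=25)),
}
```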
@validate_column_args("index", "columns")
|
||||
def pivot( # pylint: disable=too-many-arguments
|
||||
df: DataFrame,
|
||||
index: List[str],
|
||||
columns: List[str],
|
||||
aggregates: Dict[str, Dict[str, Any]],
|
||||
metric_fill_value: Optional[Any] = None,
|
||||
column_fill_value: Optional[str] = None,
|
||||
drop_missing_columns: Optional[bool] = True,
|
||||
combine_value_with_metric=False,
|
||||
marginal_distributions: Optional[bool] = None,
|
||||
marginal_distribution_name: Optional[str] = None,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Perform a pivot operation on a DataFrame.
|
||||
|
||||
:param df: Object on which pivot operation will be performed
|
||||
:param index: Columns to group by on the table index (=rows)
|
||||
:param columns: Columns to group by on the table columns
|
||||
:param metric_fill_value: Value to replace missing values with
|
||||
:param column_fill_value: Value to replace missing pivot columns with
|
||||
:param drop_missing_columns: Do not include columns whose entries are all missing
|
||||
:param combine_value_with_metric: Display metrics side by side within each column,
|
||||
as opposed to each column being displayed side by side for each metric.
|
||||
:param aggregates: A mapping from aggregate column name to the the aggregate
|
||||
config.
|
||||
:param marginal_distributions: Add totals for row/column. Default to False
|
||||
:param marginal_distribution_name: Name of row/column with marginal distribution.
|
||||
Default to 'All'.
|
||||
:return: A pivot table
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
if not index:
|
||||
raise QueryObjectValidationError(
|
||||
_("Pivot operation requires at least one index")
|
||||
)
|
||||
if not columns:
|
||||
raise QueryObjectValidationError(
|
||||
_("Pivot operation requires at least one column")
|
||||
)
|
||||
if not aggregates:
|
||||
raise QueryObjectValidationError(
|
||||
_("Pivot operation must include at least one aggregate")
|
||||
)
|
||||
|
||||
if column_fill_value:
|
||||
df[columns] = df[columns].fillna(value=column_fill_value)
|
||||
|
||||
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
|
||||
|
||||
# TODO (villebro): Pandas 1.0.3 doesn't yet support NamedAgg in pivot_table.
|
||||
# Remove once/if support is added.
|
||||
aggfunc = {na.column: na.aggfunc for na in aggregate_funcs.values()}
|
||||
|
||||
df = df.pivot_table(
|
||||
values=aggfunc.keys(),
|
||||
index=index,
|
||||
columns=columns,
|
||||
aggfunc=aggfunc,
|
||||
fill_value=metric_fill_value,
|
||||
dropna=drop_missing_columns,
|
||||
margins=marginal_distributions,
|
||||
margins_name=marginal_distribution_name,
|
||||
)
|
||||
|
||||
if combine_value_with_metric:
|
||||
df = df.stack(0).unstack()
|
||||
|
||||
return df
|
||||
|
||||
|
||||
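A usage sketch, mirroring the unit tests added later in this diff (assuming the `categories_df` fixture is importable from the tests package defined below):

```python
# Usage sketch mirroring the unit tests below: one pivoted column per
# category, summing idx_nulls per name. Assumes the fixtures module is
# importable as tests.fixtures.dataframes.
from superset.utils import pandas_postprocessing as proc
from tests.fixtures.dataframes import categories_df

pivoted = proc.pivot(
    df=categories_df,
    index=["name"],
    columns=["category"],
    aggregates={"idx_nulls": {"operator": "sum"}},
)
# columns: [("idx_nulls", "cat0"), ("idx_nulls", "cat1"), ("idx_nulls", "cat2")]
```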
@validate_column_args("groupby")
|
||||
def aggregate(
|
||||
df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]]
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Apply aggregations to a DataFrame.
|
||||
|
||||
:param df: Object to aggregate.
|
||||
:param groupby: columns to aggregate
|
||||
:param aggregates: A mapping from metric column to the function used to
|
||||
aggregate values.
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
aggregates = aggregates or {}
|
||||
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
|
||||
return df.groupby(by=groupby).agg(**aggregate_funcs).reset_index()
|
||||
|
||||
|
||||
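A usage sketch, mirroring the unit tests below (same fixture import assumption as above):

```python
# Usage sketch mirroring the unit tests below: group by the constant column
# and compute a sum plus a percentile in one pass.
from superset.utils import pandas_postprocessing as proc
from tests.fixtures.dataframes import categories_df

aggregated = proc.aggregate(
    df=categories_df,
    groupby=["constant"],
    aggregates={
        "asc sum": {"column": "asc_idx", "operator": "sum"},
        "asc q2": {"column": "asc_idx", "operator": "percentile", "options": {"q": 75}},
    },
)
# columns: ["constant", "asc sum", "asc q2"]; "asc sum" equals 5050
```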
@validate_column_args("columns")
|
||||
def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
|
||||
"""
|
||||
Sort a DataFrame.
|
||||
|
||||
:param df: DataFrame to sort.
|
||||
:param columns: columns by by which to sort. The key specifies the column name,
|
||||
value specifies if sorting in ascending order.
|
||||
:return: Sorted DataFrame
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
return df.sort_values(by=list(columns.keys()), ascending=list(columns.values()))
|
||||
|
||||
|
||||
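A usage sketch, mirroring the unit tests below:

```python
# Usage sketch mirroring the unit tests below: ascending by category, then
# descending by asc_idx within each category.
from superset.utils import pandas_postprocessing as proc
from tests.fixtures.dataframes import categories_df

sorted_df = proc.sort(df=categories_df, columns={"category": True, "asc_idx": False})
```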
@validate_column_args("columns")
|
||||
def rolling( # pylint: disable=too-many-arguments
|
||||
df: DataFrame,
|
||||
columns: Dict[str, str],
|
||||
rolling_type: str,
|
||||
window: int,
|
||||
rolling_type_options: Optional[Dict[str, Any]] = None,
|
||||
center: bool = False,
|
||||
win_type: Optional[str] = None,
|
||||
min_periods: Optional[int] = None,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Apply a rolling window on the dataset. See the Pandas docs for further details:
|
||||
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.rolling.html
|
||||
|
||||
:param df: DataFrame on which the rolling period will be based.
|
||||
:param columns: columns on which to perform rolling, mapping source column to
|
||||
target column. For instance, `{'y': 'y'}` will replace the column `y` with
|
||||
the rolling value in `y`, while `{'y': 'y2'}` will add a column `y2` based
|
||||
on rolling values calculated from `y`, leaving the original column `y`
|
||||
unchanged.
|
||||
:param rolling_type: Type of rolling window. Any numpy function will work.
|
||||
:param rolling_type_options: Optional options to pass to rolling method. Needed
|
||||
for e.g. quantile operation.
|
||||
:param center: Should the label be at the center of the window.
|
||||
:param win_type: Type of window function.
|
||||
:param window: Size of the window.
|
||||
:param min_periods:
|
||||
:return: DataFrame with the rolling columns
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
rolling_type_options = rolling_type_options or {}
|
||||
df_rolling = df[columns.keys()]
|
||||
kwargs: Dict[str, Union[str, int]] = {}
|
||||
if not window:
|
||||
raise QueryObjectValidationError(_("Undefined window for rolling operation"))
|
||||
|
||||
kwargs["window"] = window
|
||||
if min_periods is not None:
|
||||
kwargs["min_periods"] = min_periods
|
||||
if center is not None:
|
||||
kwargs["center"] = center
|
||||
if win_type is not None:
|
||||
kwargs["win_type"] = win_type
|
||||
|
||||
df_rolling = df_rolling.rolling(**kwargs)
|
||||
if rolling_type not in WHITELIST_ROLLING_FUNCTIONS or not hasattr(
|
||||
df_rolling, rolling_type
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
_("Invalid rolling_type: %(type)s", type=rolling_type)
|
||||
)
|
||||
try:
|
||||
df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options)
|
||||
except TypeError:
|
||||
raise QueryObjectValidationError(
|
||||
_(
|
||||
"Invalid options for %(rolling_type)s: %(options)s",
|
||||
rolling_type=rolling_type,
|
||||
options=rolling_type_options,
|
||||
)
|
||||
)
|
||||
df = _append_columns(df, df_rolling, columns)
|
||||
if min_periods:
|
||||
df = df[min_periods:]
|
||||
return df
|
||||
|
||||
|
||||
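A usage sketch, mirroring the unit tests below:

```python
# Usage sketch mirroring the unit tests below: a two-row rolling sum written
# back into column y; min_periods=0 keeps the first, partial window.
from superset.utils import pandas_postprocessing as proc
from tests.fixtures.dataframes import timeseries_df

rolled = proc.rolling(
    df=timeseries_df,
    columns={"y": "y"},
    rolling_type="sum",
    window=2,
    min_periods=0,
)
# y: [1.0, 2.0, 3.0, 4.0] -> [1.0, 3.0, 5.0, 7.0]
```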
@validate_column_args("columns", "rename")
|
||||
def select(
|
||||
df: DataFrame, columns: List[str], rename: Optional[Dict[str, str]] = None
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Only select a subset of columns in the original dataset. Can be useful for
|
||||
removing unnecessary intermediate results, renaming and reordering columns.
|
||||
|
||||
:param df: DataFrame on which the rolling period will be based.
|
||||
:param columns: Columns which to select from the DataFrame, in the desired order.
|
||||
If columns are renamed, the new column name should be referenced
|
||||
here.
|
||||
:param rename: columns which to rename, mapping source column to target column.
|
||||
For instance, `{'y': 'y2'}` will rename the column `y` to
|
||||
`y2`.
|
||||
:return: Subset of columns in original DataFrame
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
df_select = df[columns]
|
||||
if rename is not None:
|
||||
df_select = df_select.rename(columns=rename)
|
||||
return df_select
|
||||
|
||||
|
||||
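A usage sketch, mirroring the unit tests below:

```python
# Usage sketch mirroring the unit tests below: keep label and y, renaming y
# to y1 in the output.
from superset.utils import pandas_postprocessing as proc
from tests.fixtures.dataframes import timeseries_df

selected = proc.select(df=timeseries_df, columns=["label", "y"], rename={"y": "y1"})
# columns: ["label", "y1"]
```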
@validate_column_args("columns")
|
||||
def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame:
|
||||
"""
|
||||
|
||||
:param df: DataFrame on which the diff will be based.
|
||||
:param columns: columns on which to perform diff, mapping source column to
|
||||
target column. For instance, `{'y': 'y'}` will replace the column `y` with
|
||||
the diff value in `y`, while `{'y': 'y2'}` will add a column `y2` based
|
||||
on diff values calculated from `y`, leaving the original column `y`
|
||||
unchanged.
|
||||
:param periods: periods to shift for calculating difference.
|
||||
:return: DataFrame with diffed columns
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
df_diff = df[columns.keys()]
|
||||
df_diff = df_diff.diff(periods=periods)
|
||||
return _append_columns(df, df_diff, columns)
|
||||
|
||||
|
||||
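A usage sketch, mirroring the unit tests below:

```python
# Usage sketch mirroring the unit tests below: a one-period difference added
# as a new column y1, leaving y untouched.
from superset.utils import pandas_postprocessing as proc
from tests.fixtures.dataframes import timeseries_df

diffed = proc.diff(df=timeseries_df, columns={"y": "y1"}, periods=1)
# y1: [None, 1.0, 1.0, 1.0]
```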
@validate_column_args("columns")
|
||||
def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
|
||||
"""
|
||||
|
||||
:param df: DataFrame on which the cumulative operation will be based.
|
||||
:param columns: columns on which to perform a cumulative operation, mapping source
|
||||
column to target column. For instance, `{'y': 'y'}` will replace the column
|
||||
`y` with the cumulative value in `y`, while `{'y': 'y2'}` will add a column
|
||||
`y2` based on cumulative values calculated from `y`, leaving the original
|
||||
column `y` unchanged.
|
||||
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
|
||||
:return:
|
||||
"""
|
||||
df_cum = df[columns.keys()]
|
||||
operation = "cum" + operator
|
||||
if operation not in WHITELIST_CUMULATIVE_FUNCTIONS or not hasattr(
|
||||
df_cum, operation
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
_("Invalid cumulative operator: %(operator)s", operator=operator)
|
||||
)
|
||||
return _append_columns(df, getattr(df_cum, operation)(), columns)
|
||||
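A usage sketch, mirroring the unit tests below:

```python
# Usage sketch mirroring the unit tests below: cumulative sum of y into a new
# column y2 ("sum" maps to the whitelisted cumsum method).
from superset.utils import pandas_postprocessing as proc
from tests.fixtures.dataframes import timeseries_df

cum_df = proc.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum")
# y2: [1.0, 3.0, 6.0, 10.0]
```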
tests/core_tests.py
@@ -111,7 +111,7 @@ class CoreTests(SupersetTestCase):
         resp = self.client.get("/superset/slice/-1/")
         assert resp.status_code == 404

-    def _get_query_context_dict(self) -> Dict[str, Any]:
+    def _get_query_context(self) -> Dict[str, Any]:
         self.login(username="admin")
         slc = self.get_slice("Girl Name Cloud", db.session)
         return {
@@ -127,6 +127,45 @@ class CoreTests(SupersetTestCase):
             ],
         }

+    def _get_query_context_with_post_processing(self) -> Dict[str, Any]:
+        self.login(username="admin")
+        slc = self.get_slice("Girl Name Cloud", db.session)
+        return {
+            "datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
+            "queries": [
+                {
+                    "granularity": "ds",
+                    "groupby": ["name", "state"],
+                    "metrics": [{"label": "sum__num"}],
+                    "filters": [],
+                    "row_limit": 100,
+                    "post_processing": [
+                        {
+                            "operation": "aggregate",
+                            "options": {
+                                "groupby": ["state"],
+                                "aggregates": {
+                                    "q1": {
+                                        "operator": "percentile",
+                                        "column": "sum__num",
+                                        "options": {"q": 25},
+                                    },
+                                    "median": {
+                                        "operator": "median",
+                                        "column": "sum__num",
+                                    },
+                                },
+                            },
+                        },
+                        {
+                            "operation": "sort",
+                            "options": {"columns": {"q1": False, "state": True}},
+                        },
+                    ],
+                }
+            ],
+        }
+
     def test_viz_cache_key(self):
         self.login(username="admin")
         slc = self.get_slice("Girls", db.session)
@@ -140,7 +179,7 @@ class CoreTests(SupersetTestCase):
         self.assertNotEqual(cache_key, viz.cache_key(qobj))

     def test_cache_key_changes_when_datasource_is_updated(self):
-        qc_dict = self._get_query_context_dict()
+        qc_dict = self._get_query_context()

         # construct baseline cache_key
         query_context = QueryContext(**qc_dict)
@@ -168,7 +207,7 @@ class CoreTests(SupersetTestCase):
         self.assertNotEqual(cache_key_original, cache_key_new)

     def test_query_context_time_range_endpoints(self):
-        query_context = QueryContext(**self._get_query_context_dict())
+        query_context = QueryContext(**self._get_query_context())
         query_object = query_context.queries[0]
         extras = query_object.to_dict()["extras"]
         self.assertTrue("time_range_endpoints" in extras)
@@ -217,11 +256,18 @@ class CoreTests(SupersetTestCase):

     def test_api_v1_query_endpoint(self):
         self.login(username="admin")
-        qc_dict = self._get_query_context_dict()
+        qc_dict = self._get_query_context()
         data = json.dumps(qc_dict)
         resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
         self.assertEqual(resp[0]["rowcount"], 100)

+    def test_api_v1_query_endpoint_with_post_processing(self):
+        self.login(username="admin")
+        qc_dict = self._get_query_context_with_post_processing()
+        data = json.dumps(qc_dict)
+        resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
+        self.assertEqual(resp[0]["rowcount"], 6)
+
     def test_old_slice_json_endpoint(self):
         self.login(username="admin")
         slc = self.get_slice("Girls", db.session)
tests/fixtures/dataframes.py
@@ -0,0 +1,121 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import date

from pandas import DataFrame, to_datetime

names_df = DataFrame(
    [
        {
            "dt": date(2020, 1, 2),
            "name": "John",
            "country": "United Kingdom",
            "cars": 3,
            "bikes": 1,
            "seconds": 30,
        },
        {
            "dt": date(2020, 1, 2),
            "name": "Peter",
            "country": "Sweden",
            "cars": 4,
            "bikes": 2,
            "seconds": 1,
        },
        {
            "dt": date(2020, 1, 3),
            "name": "Mary",
            "country": "Finland",
            "cars": 5,
            "bikes": 3,
            "seconds": None,
        },
        {
            "dt": date(2020, 1, 3),
            "name": "Peter",
            "country": "India",
            "cars": 6,
            "bikes": 4,
            "seconds": 12,
        },
        {
            "dt": date(2020, 1, 4),
            "name": "John",
            "country": "Portugal",
            "cars": 7,
            "bikes": None,
            "seconds": 75,
        },
        {
            "dt": date(2020, 1, 4),
            "name": "Peter",
            "country": "Italy",
            "cars": None,
            "bikes": 5,
            "seconds": 600,
        },
        {
            "dt": date(2020, 1, 4),
            "name": "Mary",
            "country": None,
            "cars": 9,
            "bikes": 6,
            "seconds": 2,
        },
        {
            "dt": date(2020, 1, 4),
            "name": None,
            "country": "Australia",
            "cars": 10,
            "bikes": 7,
            "seconds": 99,
        },
        {
            "dt": date(2020, 1, 1),
            "name": "John",
            "country": "USA",
            "cars": 1,
            "bikes": 8,
            "seconds": None,
        },
        {
            "dt": date(2020, 1, 1),
            "name": "Mary",
            "country": "Fiji",
            "cars": 2,
            "bikes": 9,
            "seconds": 50,
        },
    ]
)

categories_df = DataFrame(
    {
        "constant": ["dummy" for _ in range(0, 101)],
        "category": [f"cat{i%3}" for i in range(0, 101)],
        "dept": [f"dept{i%5}" for i in range(0, 101)],
        "name": [f"person{i}" for i in range(0, 101)],
        "asc_idx": [i for i in range(0, 101)],
        "desc_idx": [i for i in range(100, -1, -1)],
        "idx_nulls": [i if i % 5 == 0 else None for i in range(0, 101)],
    }
)

timeseries_df = DataFrame(
    index=to_datetime(["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]),
    data={"label": ["x", "y", "z", "q"], "y": [1.0, 2.0, 3.0, 4.0]},
)
tests/pandas_postprocessing_tests.py
@@ -0,0 +1,290 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import math
from typing import Any, List

from pandas import Series

from superset.exceptions import QueryObjectValidationError
from superset.utils import pandas_postprocessing as proc

from .base_tests import SupersetTestCase
from .fixtures.dataframes import categories_df, timeseries_df


def series_to_list(series: Series) -> List[Any]:
    """
    Converts a `Series` to a regular list, and replaces non-numeric values with
    `None`.

    :param series: Series to convert
    :return: list without nan or inf
    """
    return [
        None
        if not isinstance(val, str) and (math.isnan(val) or math.isinf(val))
        else val
        for val in series.tolist()
    ]


class PostProcessingTestCase(SupersetTestCase):
    def test_pivot(self):
        aggregates = {"idx_nulls": {"operator": "sum"}}

        # regular pivot
        df = proc.pivot(
            df=categories_df,
            index=["name"],
            columns=["category"],
            aggregates=aggregates,
        )
        self.assertListEqual(
            df.columns.tolist(),
            [("idx_nulls", "cat0"), ("idx_nulls", "cat1"), ("idx_nulls", "cat2")],
        )
        self.assertEqual(len(df), 101)
        self.assertEqual(df.sum()[0], 315)

        # regular pivot on a different index
        df = proc.pivot(
            df=categories_df,
            index=["dept"],
            columns=["category"],
            aggregates=aggregates,
        )
        self.assertEqual(len(df), 5)

        # fill value
        df = proc.pivot(
            df=categories_df,
            index=["name"],
            columns=["category"],
            metric_fill_value=1,
            aggregates={"idx_nulls": {"operator": "sum"}},
        )
        self.assertEqual(df.sum()[0], 382)

        # invalid index reference
        self.assertRaises(
            QueryObjectValidationError,
            proc.pivot,
            df=categories_df,
            index=["abc"],
            columns=["dept"],
            aggregates=aggregates,
        )

        # invalid column reference
        self.assertRaises(
            QueryObjectValidationError,
            proc.pivot,
            df=categories_df,
            index=["dept"],
            columns=["abc"],
            aggregates=aggregates,
        )

        # invalid aggregate options
        self.assertRaises(
            QueryObjectValidationError,
            proc.pivot,
            df=categories_df,
            index=["name"],
            columns=["category"],
            aggregates={"idx_nulls": {}},
        )

    def test_aggregate(self):
        aggregates = {
            "asc sum": {"column": "asc_idx", "operator": "sum"},
            "asc q2": {
                "column": "asc_idx",
                "operator": "percentile",
                "options": {"q": 75},
            },
            "desc q1": {
                "column": "desc_idx",
                "operator": "percentile",
                "options": {"q": 25},
            },
        }
        df = proc.aggregate(
            df=categories_df, groupby=["constant"], aggregates=aggregates
        )
        self.assertListEqual(
            df.columns.tolist(), ["constant", "asc sum", "asc q2", "desc q1"]
        )
        self.assertEqual(series_to_list(df["asc sum"])[0], 5050)
        self.assertEqual(series_to_list(df["asc q2"])[0], 75)
        self.assertEqual(series_to_list(df["desc q1"])[0], 25)

    def test_sort(self):
        df = proc.sort(df=categories_df, columns={"category": True, "asc_idx": False})
        self.assertEqual(96, series_to_list(df["asc_idx"])[1])

        self.assertRaises(
            QueryObjectValidationError, proc.sort, df=df, columns={"abc": True}
        )

    def test_rolling(self):
        # sum rolling type
        post_df = proc.rolling(
            df=timeseries_df,
            columns={"y": "y"},
            rolling_type="sum",
            window=2,
            min_periods=0,
        )

        self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
        self.assertListEqual(series_to_list(post_df["y"]), [1.0, 3.0, 5.0, 7.0])

        # mean rolling type with alias
        post_df = proc.rolling(
            df=timeseries_df,
            rolling_type="mean",
            columns={"y": "y_mean"},
            window=10,
            min_periods=0,
        )
        self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y_mean"])
        self.assertListEqual(series_to_list(post_df["y_mean"]), [1.0, 1.5, 2.0, 2.5])

        # count rolling type
        post_df = proc.rolling(
            df=timeseries_df,
            rolling_type="count",
            columns={"y": "y"},
            window=10,
            min_periods=0,
        )
        self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
        self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])

        # quantile rolling type
        post_df = proc.rolling(
            df=timeseries_df,
            columns={"y": "q1"},
            rolling_type="quantile",
            rolling_type_options={"quantile": 0.25},
            window=10,
            min_periods=0,
        )
        self.assertListEqual(post_df.columns.tolist(), ["label", "y", "q1"])
        self.assertListEqual(series_to_list(post_df["q1"]), [1.0, 1.25, 1.5, 1.75])

        # incorrect rolling type
        self.assertRaises(
            QueryObjectValidationError,
            proc.rolling,
            df=timeseries_df,
            columns={"y": "y"},
            rolling_type="abc",
            window=2,
        )

        # incorrect rolling type options
        self.assertRaises(
            QueryObjectValidationError,
            proc.rolling,
            df=timeseries_df,
            columns={"y": "y"},
            rolling_type="quantile",
            rolling_type_options={"abc": 123},
            window=2,
        )

    def test_select(self):
        # reorder columns
        post_df = proc.select(df=timeseries_df, columns=["y", "label"])
        self.assertListEqual(post_df.columns.tolist(), ["y", "label"])

        # one column
        post_df = proc.select(df=timeseries_df, columns=["label"])
        self.assertListEqual(post_df.columns.tolist(), ["label"])

        # rename one column
        post_df = proc.select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
        self.assertListEqual(post_df.columns.tolist(), ["y1"])

        # rename one and leave one unchanged
        post_df = proc.select(
            df=timeseries_df, columns=["label", "y"], rename={"y": "y1"}
        )
        self.assertListEqual(post_df.columns.tolist(), ["label", "y1"])

        # invalid columns
        self.assertRaises(
            QueryObjectValidationError,
            proc.select,
            df=timeseries_df,
            columns=["qwerty"],
            rename={"abc": "qwerty"},
        )

    def test_diff(self):
        # overwrite column
        post_df = proc.diff(df=timeseries_df, columns={"y": "y"})
        self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
        self.assertListEqual(series_to_list(post_df["y"]), [None, 1.0, 1.0, 1.0])

        # add column
        post_df = proc.diff(df=timeseries_df, columns={"y": "y1"})
        self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y1"])
        self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
        self.assertListEqual(series_to_list(post_df["y1"]), [None, 1.0, 1.0, 1.0])

        # look ahead
        post_df = proc.diff(df=timeseries_df, columns={"y": "y1"}, periods=-1)
        self.assertListEqual(series_to_list(post_df["y1"]), [-1.0, -1.0, -1.0, None])

        # invalid column reference
        self.assertRaises(
            QueryObjectValidationError,
            proc.diff,
            df=timeseries_df,
            columns={"abc": "abc"},
        )

    def test_cum(self):
        # create new column (cumsum)
        post_df = proc.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum")
        self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y2"])
        self.assertListEqual(series_to_list(post_df["label"]), ["x", "y", "z", "q"])
        self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
        self.assertListEqual(series_to_list(post_df["y2"]), [1.0, 3.0, 6.0, 10.0])

        # overwrite column (cumprod)
        post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="prod")
        self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
        self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 6.0, 24.0])

        # overwrite column (cummin)
        post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="min")
        self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
        self.assertListEqual(series_to_list(post_df["y"]), [1.0, 1.0, 1.0, 1.0])

        # invalid operator
        self.assertRaises(
            QueryObjectValidationError,
            proc.cum,
            df=timeseries_df,
            columns={"y": "y"},
            operator="abc",
        )