refactor: postprocessing move to unit test (#18779)

This commit is contained in:
Yongjie Zhao 2022-02-17 20:05:41 +08:00 committed by GitHub
parent cd381879c0
commit 30a9d14639
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1324 additions and 1098 deletions

File diff suppressed because it is too large.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.utils.pandas_postprocessing import aggregate
from tests.unit_tests.fixtures.dataframes import categories_df
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
def test_aggregate():
    """Aggregating with sum and percentile operators yields the expected columns and values."""
    agg_spec = {
        "asc sum": {"column": "asc_idx", "operator": "sum"},
        "asc q2": {
            "column": "asc_idx",
            "operator": "percentile",
            "options": {"q": 75},
        },
        "desc q1": {
            "column": "desc_idx",
            "operator": "percentile",
            "options": {"q": 25},
        },
    }
    result = aggregate(df=categories_df, groupby=["constant"], aggregates=agg_spec)
    # one output column per aggregate, after the groupby column
    assert result.columns.tolist() == ["constant", "asc sum", "asc q2", "desc q1"]
    assert series_to_list(result["asc sum"])[0] == 5050
    assert series_to_list(result["asc q2"])[0] == 75
    assert series_to_list(result["desc q1"])[0] == 25

View File

@@ -0,0 +1,126 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.utils.core import PostProcessingBoxplotWhiskerType
from superset.utils.pandas_postprocessing import boxplot
from tests.unit_tests.fixtures.dataframes import names_df
def test_boxplot_tukey():
    """Tukey-whisker boxplot emits all standard stat columns plus the groupby column."""
    df = boxplot(
        df=names_df,
        groupby=["region"],
        whisker_type=PostProcessingBoxplotWhiskerType.TUKEY,
        metrics=["cars"],
    )
    # set(...) instead of the redundant set comprehension (ruff C401)
    assert set(df.columns) == {
        "cars__mean",
        "cars__median",
        "cars__q1",
        "cars__q3",
        "cars__max",
        "cars__min",
        "cars__count",
        "cars__outliers",
        "region",
    }
    assert len(df) == 4
def test_boxplot_min_max():
    """Min/max-whisker boxplot emits all standard stat columns plus the groupby column."""
    df = boxplot(
        df=names_df,
        groupby=["region"],
        whisker_type=PostProcessingBoxplotWhiskerType.MINMAX,
        metrics=["cars"],
    )
    # set(...) instead of the redundant set comprehension (ruff C401)
    assert set(df.columns) == {
        "cars__mean",
        "cars__median",
        "cars__q1",
        "cars__q3",
        "cars__max",
        "cars__min",
        "cars__count",
        "cars__outliers",
        "region",
    }
    assert len(df) == 4
def test_boxplot_percentile():
    """Percentile-whisker boxplot (with valid percentiles) emits the standard columns."""
    df = boxplot(
        df=names_df,
        groupby=["region"],
        whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
        metrics=["cars"],
        percentiles=[1, 99],
    )
    # set(...) instead of the redundant set comprehension (ruff C401)
    assert set(df.columns) == {
        "cars__mean",
        "cars__median",
        "cars__q1",
        "cars__q3",
        "cars__max",
        "cars__min",
        "cars__count",
        "cars__outliers",
        "region",
    }
    assert len(df) == 4
def test_boxplot_percentile_incorrect_params():
    """Percentile whiskers require exactly two ascending percentile values."""
    # None: percentiles omitted entirely; others: wrong length or wrong order
    bad_percentiles = (None, [10], [90, 10], [10, 90, 10])
    for percentiles in bad_percentiles:
        kwargs = {
            "df": names_df,
            "groupby": ["region"],
            "whisker_type": PostProcessingBoxplotWhiskerType.PERCENTILE,
            "metrics": ["cars"],
        }
        if percentiles is not None:
            kwargs["percentiles"] = percentiles
        with pytest.raises(QueryObjectValidationError):
            boxplot(**kwargs)

View File

@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.utils.pandas_postprocessing import compare
from tests.unit_tests.fixtures.dataframes import timeseries_df2
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
def test_compare():
    """`compare` supports difference/percentage/ratio and dropping source columns."""
    expectations = [
        ("difference", [0.0, -2.0, -8.0, -6.0]),
        ("percentage", [0.0, -0.5, -0.8, -0.75]),
        ("ratio", [1.0, 0.5, 0.2, 0.25]),
    ]
    for compare_type, expected in expectations:
        result = compare(
            df=timeseries_df2,
            source_columns=["y"],
            compare_columns=["z"],
            compare_type=compare_type,
        )
        derived_column = f"{compare_type}__y__z"
        assert result.columns.tolist() == ["label", "y", "z", derived_column]
        assert series_to_list(result[derived_column]) == expected
    # the source/compare columns can be dropped from the output
    result = compare(
        df=timeseries_df2,
        source_columns=["y"],
        compare_columns=["z"],
        compare_type="difference",
        drop_original_columns=True,
    )
    assert result.columns.tolist() == ["label", "difference__y__z"]

View File

@@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
import pytest
from pandas import DataFrame
from superset.exceptions import QueryObjectValidationError
from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
from superset.utils.pandas_postprocessing import contribution
def test_contribution():
    """Row/column contribution normalization and its validation errors."""
    source = DataFrame(
        {
            DTTM_ALIAS: [datetime(2020, 7, 16, 14, 49), datetime(2020, 7, 16, 14, 50)],
            "a": [1, 3],
            "b": [1, 9],
        }
    )
    # a non-numeric (temporal) column cannot be part of the calculation
    with pytest.raises(QueryObjectValidationError, match="not numeric"):
        contribution(source, columns=[DTTM_ALIAS])
    # columns and rename_columns must be parallel lists
    with pytest.raises(QueryObjectValidationError, match="same length"):
        contribution(source, columns=["a"], rename_columns=["aa", "bb"])
    # each cell as a share of its row
    row_result = contribution(
        source, orientation=PostProcessingContributionOrientation.ROW,
    )
    assert row_result.columns.tolist() == [DTTM_ALIAS, "a", "b"]
    assert row_result["a"].tolist() == [0.5, 0.25]
    assert row_result["b"].tolist() == [0.5, 0.75]
    # each cell as a share of its column, after removing the temporal column
    source.pop(DTTM_ALIAS)
    col_result = contribution(
        source, orientation=PostProcessingContributionOrientation.COLUMN
    )
    assert col_result.columns.tolist() == ["a", "b"]
    assert col_result["a"].tolist() == [0.25, 0.75]
    assert col_result["b"].tolist() == [0.1, 0.9]
    # restrict the calculation to selected columns, written under new names
    selected = contribution(
        source,
        orientation=PostProcessingContributionOrientation.COLUMN,
        columns=["a"],
        rename_columns=["pct_a"],
    )
    assert selected.columns.tolist() == ["a", "b", "pct_a"]
    assert selected["a"].tolist() == [1, 3]
    assert selected["b"].tolist() == [1, 9]
    assert selected["pct_a"].tolist() == [0.25, 0.75]

View File

@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pandas import to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import cum, pivot
from tests.unit_tests.fixtures.dataframes import (
multiple_metrics_df,
single_metric_df,
timeseries_df,
)
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
def test_cum():
    """Cumulative operators create or overwrite columns; unknown operators raise."""
    # cumulative sum written to a brand-new column
    result = cum(df=timeseries_df, columns={"y": "y2"}, operator="sum")
    assert result.columns.tolist() == ["label", "y", "y2"]
    assert series_to_list(result["label"]) == ["x", "y", "z", "q"]
    assert series_to_list(result["y"]) == [1.0, 2.0, 3.0, 4.0]
    assert series_to_list(result["y2"]) == [1.0, 3.0, 6.0, 10.0]
    # cumulative product written back onto the source column
    result = cum(df=timeseries_df, columns={"y": "y"}, operator="prod")
    assert result.columns.tolist() == ["label", "y"]
    assert series_to_list(result["y"]) == [1.0, 2.0, 6.0, 24.0]
    # cumulative minimum written back onto the source column
    result = cum(df=timeseries_df, columns={"y": "y"}, operator="min")
    assert result.columns.tolist() == ["label", "y"]
    assert series_to_list(result["y"]) == [1.0, 1.0, 1.0, 1.0]
    # unrecognized operators are rejected
    with pytest.raises(QueryObjectValidationError):
        cum(df=timeseries_df, columns={"y": "y"}, operator="abc")
def test_cum_with_pivot_df_and_single_metric():
    """Cumulative sum over a pivoted (non-flattened) frame with one metric."""
    pivoted = pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    result = cum(df=pivoted, operator="sum", is_pivot_df=True)
    # expected output:
    #          dttm  UK  US
    # 0  2019-01-01   5   6
    # 1  2019-01-02  12  14
    assert result["UK"].to_list() == [5.0, 12.0]
    assert result["US"].to_list() == [6.0, 14.0]
    expected_dates = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert result["dttm"].to_list() == expected_dates
def test_cum_with_pivot_df_and_multiple_metrics():
    """Cumulative sum over a pivoted frame with two metrics and two countries."""
    pivoted = pivot(
        df=multiple_metrics_df,
        index=["dttm"],
        columns=["country"],
        aggregates={
            "sum_metric": {"operator": "sum"},
            "count_metric": {"operator": "sum"},
        },
        flatten_columns=False,
        reset_index=False,
    )
    result = cum(df=pivoted, operator="sum", is_pivot_df=True)
    # expected output:
    #          dttm  count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
    # 0  2019-01-01                 1                 2               5               6
    # 1  2019-01-02                 4                 6              12              14
    expected_series = {
        "count_metric, UK": [1.0, 4.0],
        "count_metric, US": [2.0, 6.0],
        "sum_metric, UK": [5.0, 12.0],
        "sum_metric, US": [6.0, 14.0],
    }
    for column, values in expected_series.items():
        assert result[column].to_list() == values
    expected_dates = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert result["dttm"].to_list() == expected_dates

View File

@@ -0,0 +1,50 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import diff
from tests.unit_tests.fixtures.dataframes import timeseries_df, timeseries_df2
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
def test_diff():
    """`diff` along rows and columns, with aliases, lookahead and validation."""
    # write the difference back onto the source column
    result = diff(df=timeseries_df, columns={"y": "y"})
    assert result.columns.tolist() == ["label", "y"]
    assert series_to_list(result["y"]) == [None, 1.0, 1.0, 1.0]
    # keep the source column and add an aliased diff column
    result = diff(df=timeseries_df, columns={"y": "y1"})
    assert result.columns.tolist() == ["label", "y", "y1"]
    assert series_to_list(result["y"]) == [1.0, 2.0, 3.0, 4.0]
    assert series_to_list(result["y1"]) == [None, 1.0, 1.0, 1.0]
    # negative periods diff against the following row (look ahead)
    result = diff(df=timeseries_df, columns={"y": "y1"}, periods=-1)
    assert series_to_list(result["y1"]) == [-1.0, -1.0, -1.0, None]
    # unknown column names fail validation
    with pytest.raises(QueryObjectValidationError):
        diff(df=timeseries_df, columns={"abc": "abc"})
    # axis=1 takes the difference between columns instead of rows
    result = diff(df=timeseries_df2, columns={"y": "y", "z": "z"}, axis=1)
    assert result.columns.tolist() == ["label", "y", "z"]
    assert series_to_list(result["z"]) == [0.0, 2.0, 8.0, 6.0]

View File

@@ -0,0 +1,90 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.utils.pandas_postprocessing import (
geodetic_parse,
geohash_decode,
geohash_encode,
)
from tests.unit_tests.fixtures.dataframes import lonlat_df
from tests.unit_tests.pandas_postprocessing.utils import round_floats, series_to_list
def test_geohash_decode():
    """Decoding a geohash column reproduces the source lon/lat values."""
    result = geohash_decode(
        df=lonlat_df[["city", "geohash"]],
        geohash="geohash",
        latitude="latitude",
        longitude="longitude",
    )
    expected_columns = ["city", "geohash", "latitude", "longitude"]
    assert sorted(result.columns.tolist()) == sorted(expected_columns)
    # geohash encoding is lossy, so compare coordinates at 6 decimal places
    for coordinate in ("longitude", "latitude"):
        assert round_floats(series_to_list(result[coordinate]), 6) == round_floats(
            series_to_list(lonlat_df[coordinate]), 6
        )
def test_geohash_encode():
    """Encoding lon/lat columns reproduces the reference geohash column."""
    result = geohash_encode(
        df=lonlat_df[["city", "latitude", "longitude"]],
        latitude="latitude",
        longitude="longitude",
        geohash="geohash",
    )
    expected_columns = ["city", "geohash", "latitude", "longitude"]
    assert sorted(result.columns.tolist()) == sorted(expected_columns)
    assert series_to_list(result["geohash"]) == series_to_list(lonlat_df["geohash"])
def test_geodetic_parse():
    """Parse a geodetic coordinate string into lon/lat and optional altitude."""
    # with altitude
    post_df = geodetic_parse(
        df=lonlat_df[["city", "geodetic"]],
        geodetic="geodetic",
        latitude="latitude",
        longitude="longitude",
        altitude="altitude",
    )
    assert sorted(post_df.columns.tolist()) == sorted(
        ["city", "geodetic", "latitude", "longitude", "altitude"]
    )
    assert series_to_list(post_df["longitude"]) == series_to_list(
        lonlat_df["longitude"]
    )
    assert series_to_list(post_df["latitude"]) == series_to_list(lonlat_df["latitude"])
    assert series_to_list(post_df["altitude"]) == series_to_list(lonlat_df["altitude"])
    # without altitude
    post_df = geodetic_parse(
        df=lonlat_df[["city", "geodetic"]],
        geodetic="geodetic",
        latitude="latitude",
        longitude="longitude",
    )
    assert sorted(post_df.columns.tolist()) == sorted(
        ["city", "geodetic", "latitude", "longitude"]
    )
    assert series_to_list(post_df["longitude"]) == series_to_list(
        lonlat_df["longitude"]
    )
    # BUG FIX: the original read `assert x, y` — the comma made the second
    # operand an assertion *message*, so no comparison was performed.
    assert series_to_list(post_df["latitude"]) == series_to_list(lonlat_df["latitude"])

View File

@@ -0,0 +1,266 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pytest
from pandas import DataFrame, Timestamp, to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import _flatten_column_after_pivot, pivot
from tests.unit_tests.fixtures.dataframes import categories_df, single_metric_df
from tests.unit_tests.pandas_postprocessing.utils import (
AGGREGATES_MULTIPLE,
AGGREGATES_SINGLE,
)
def test_flatten_column_after_pivot():
    """
    Test pivot column flattening function.

    With a single aggregate the aggregate name is dropped from the flattened
    label; with multiple aggregates it is kept as the leading component.
    """
    # single aggregate cases
    assert (
        _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column="idx_nulls",)
        == "idx_nulls"
    )
    # non-string scalar labels are stringified
    assert (
        _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column=1234,)
        == "1234"
    )
    assert (
        _flatten_column_after_pivot(
            aggregates=AGGREGATES_SINGLE, column=Timestamp("2020-09-29T00:00:00"),
        )
        == "2020-09-29 00:00:00"
    )
    # NOTE: the original repeated the first "idx_nulls" assertion verbatim
    # here; the duplicate has been removed.
    # tuple labels drop the (single) aggregate name and join the rest
    assert (
        _flatten_column_after_pivot(
            aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1"),
        )
        == "col1"
    )
    assert (
        _flatten_column_after_pivot(
            aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1", 1234),
        )
        == "col1, 1234"
    )
    # multiple aggregate cases keep the aggregate name in the label
    assert (
        _flatten_column_after_pivot(
            aggregates=AGGREGATES_MULTIPLE, column=("idx_nulls", "asc_idx", "col1"),
        )
        == "idx_nulls, asc_idx, col1"
    )
    assert (
        _flatten_column_after_pivot(
            aggregates=AGGREGATES_MULTIPLE,
            column=("idx_nulls", "asc_idx", "col1", 1234),
        )
        == "idx_nulls, asc_idx, col1, 1234"
    )
def test_pivot_without_columns():
    """
    Make sure pivot without columns returns correct DataFrame
    """
    df = pivot(df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE)
    assert df.columns.tolist() == ["name", "idx_nulls"]
    assert len(df) == 101
    # .iloc for positional access: integer keys on a labeled Series via []
    # are deprecated in pandas
    assert df.sum().iloc[1] == 1050
def test_pivot_with_single_column():
    """
    Make sure pivot with single column returns correct DataFrame
    """
    df = pivot(
        df=categories_df,
        index=["name"],
        columns=["category"],
        aggregates=AGGREGATES_SINGLE,
    )
    assert df.columns.tolist() == ["name", "cat0", "cat1", "cat2"]
    assert len(df) == 101
    # .iloc for positional access: integer keys on a labeled Series via []
    # are deprecated in pandas
    assert df.sum().iloc[1] == 315
    df = pivot(
        df=categories_df,
        index=["dept"],
        columns=["category"],
        aggregates=AGGREGATES_SINGLE,
    )
    assert df.columns.tolist() == ["dept", "cat0", "cat1", "cat2"]
    assert len(df) == 5
def test_pivot_with_multiple_columns():
    """
    Make sure pivot with multiple columns returns correct DataFrame
    """
    pivoted = pivot(
        df=categories_df,
        index=["name"],
        columns=["category", "dept"],
        aggregates=AGGREGATES_SINGLE,
    )
    # one index column plus every (category, dept) permutation: 3 * 5
    expected_width = 1 + 3 * 5
    assert len(pivoted.columns) == expected_width
def test_pivot_fill_values():
    """
    Make sure pivot with fill values returns correct DataFrame
    """
    df = pivot(
        df=categories_df,
        index=["name"],
        columns=["category"],
        metric_fill_value=1,
        aggregates={"idx_nulls": {"operator": "sum"}},
    )
    # .iloc for positional access: integer keys on a labeled Series via []
    # are deprecated in pandas
    assert df.sum().iloc[1] == 382
def test_pivot_fill_column_values():
    """
    Make sure pivot with null column names returns correct DataFrame
    """
    # replace every category value with None to exercise null handling
    df_copy = categories_df.copy()
    df_copy["category"] = None
    df = pivot(
        df=df_copy,
        index=["name"],
        columns=["category"],
        aggregates={"idx_nulls": {"operator": "sum"}},
    )
    assert len(df) == 101
    # null column labels are rendered as the "<NULL>" placeholder
    assert df.columns.tolist() == ["name", "<NULL>"]
def test_pivot_exceptions():
    """
    Make sure pivot raises correct Exceptions
    """
    # omitting the mandatory index argument is a plain TypeError
    with pytest.raises(TypeError):
        pivot(df=categories_df, columns=["dept"], aggregates=AGGREGATES_SINGLE)
    # unknown index/column references and empty aggregate specs fail validation
    invalid_calls = [
        {"index": ["abc"], "columns": ["dept"], "aggregates": AGGREGATES_SINGLE},
        {"index": ["dept"], "columns": ["abc"], "aggregates": AGGREGATES_SINGLE},
        {"index": ["name"], "columns": ["category"], "aggregates": {"idx_nulls": {}}},
    ]
    for kwargs in invalid_calls:
        with pytest.raises(QueryObjectValidationError):
            pivot(df=categories_df, **kwargs)
def test_pivot_eliminate_cartesian_product_columns():
    """
    Pivoting on multiple columns with drop_missing_columns=False keeps only
    the (a, b) combinations present in the data, not the full cartesian product.
    """
    # single metric
    mock_df = DataFrame(
        {
            "dttm": to_datetime(["2019-01-01", "2019-01-01"]),
            "a": [0, 1],
            "b": [0, 1],
            # np.nan: the np.NAN alias was removed in NumPy 2.0
            "metric": [9, np.nan],
        }
    )
    df = pivot(
        df=mock_df,
        index=["dttm"],
        columns=["a", "b"],
        aggregates={"metric": {"operator": "mean"}},
        drop_missing_columns=False,
    )
    # only the observed combinations (0,0) and (1,1) survive
    assert list(df.columns) == ["dttm", "0, 0", "1, 1"]
    assert np.isnan(df["1, 1"][0])
    # multiple metrics
    mock_df = DataFrame(
        {
            "dttm": to_datetime(["2019-01-01", "2019-01-01"]),
            "a": [0, 1],
            "b": [0, 1],
            "metric": [9, np.nan],
            "metric2": [10, 11],
        }
    )
    df = pivot(
        df=mock_df,
        index=["dttm"],
        columns=["a", "b"],
        aggregates={"metric": {"operator": "mean"}, "metric2": {"operator": "mean"}},
        drop_missing_columns=False,
    )
    assert list(df.columns) == [
        "dttm",
        "metric, 0, 0",
        "metric, 1, 1",
        "metric2, 0, 0",
        "metric2, 1, 1",
    ]
    assert np.isnan(df["metric, 1, 1"][0])
def test_pivot_without_flatten_columns_and_reset_index():
    """Pivot keeps MultiIndex column tuples and a temporal index when not flattened."""
    result = pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    # expected layout:
    #             sum_metric
    # country     UK  US
    # dttm
    # 2019-01-01   5   6
    # 2019-01-02   7   8
    expected_columns = [("sum_metric", "UK"), ("sum_metric", "US")]
    assert result.columns.to_list() == expected_columns
    expected_index = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert result.index.to_list() == expected_index

View File

@@ -0,0 +1,114 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from importlib.util import find_spec
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.utils.core import DTTM_ALIAS
from superset.utils.pandas_postprocessing import prophet
from tests.unit_tests.fixtures.dataframes import prophet_df
def test_prophet_valid():
    """Forecast adds yhat/upper/lower per metric and extends the time range."""
    pytest.importorskip("prophet")
    expected_columns = {
        DTTM_ALIAS,
        "a__yhat",
        "a__yhat_upper",
        "a__yhat_lower",
        "a",
        "b__yhat",
        "b__yhat_upper",
        "b__yhat_lower",
        "b",
    }
    result = prophet(
        df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9
    )
    assert set(result.columns) == expected_columns
    assert result[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert result[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31)
    assert len(result) == 7
    # more forecast periods push the horizon further out
    result = prophet(
        df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9
    )
    assert result[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert result[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31)
    assert len(result) == 9
def test_prophet_valid_zero_periods():
    """With periods=0 the forecast covers only the observed time range."""
    pytest.importorskip("prophet")
    result = prophet(
        df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9
    )
    assert set(result.columns) == {
        DTTM_ALIAS,
        "a__yhat",
        "a__yhat_upper",
        "a__yhat_lower",
        "a",
        "b__yhat",
        "b__yhat_upper",
        "b__yhat_lower",
        "b",
    }
    assert result[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert result[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31)
    assert len(result) == 4
def test_prophet_import():
    """When the optional prophet package is absent, the operation must fail validation."""
    if find_spec("prophet") is None:
        with pytest.raises(QueryObjectValidationError):
            prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
def test_prophet_missing_temporal_column():
    """A frame without the temporal column must fail validation."""
    without_dttm = prophet_df.drop(DTTM_ALIAS, axis=1)
    with pytest.raises(QueryObjectValidationError):
        prophet(
            df=without_dttm,
            time_grain="P1M",
            periods=3,
            confidence_interval=0.9,
        )
def test_prophet_incorrect_confidence_interval():
    """Boundary confidence intervals (0.0 and 1.0) are rejected."""
    for bad_interval in (0.0, 1.0):
        with pytest.raises(QueryObjectValidationError):
            prophet(
                df=prophet_df,
                time_grain="P1M",
                periods=3,
                confidence_interval=bad_interval,
            )
def test_prophet_incorrect_periods():
    """A negative forecast horizon is rejected."""
    with pytest.raises(QueryObjectValidationError):
        prophet(
            df=prophet_df,
            time_grain="P1M",
            periods=-1,
            confidence_interval=0.8,
        )
def test_prophet_incorrect_time_grain():
    """An unrecognized time grain string is rejected."""
    with pytest.raises(QueryObjectValidationError):
        prophet(
            df=prophet_df,
            time_grain="yearly",
            periods=10,
            confidence_interval=0.8,
        )

View File

@@ -0,0 +1,107 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pandas import DataFrame, to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import resample
from tests.unit_tests.fixtures.dataframes import timeseries_df
def test_resample():
    """Upsampling to daily frequency via ffill and via asfreq with a fill value."""
    source = timeseries_df.copy()
    source.index.name = "time_column"
    source.reset_index(inplace=True)
    # forward-fill carries the last observation into the gaps
    result = resample(df=source, rule="1D", method="ffill", time_column="time_column")
    assert result["label"].tolist() == ["x", "y", "y", "y", "z", "z", "q"]
    assert result["y"].tolist() == [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0]
    # asfreq leaves gaps, which are filled with the provided fill_value
    result = resample(
        df=source,
        rule="1D",
        method="asfreq",
        time_column="time_column",
        fill_value=0,
    )
    assert result["label"].tolist() == ["x", "y", 0, 0, "z", 0, "q"]
    assert result["y"].tolist() == [1.0, 2.0, 0, 0, 3.0, 0, 4.0]
def test_resample_with_groupby():
    """
    Resampling with `groupby_columns` fills missing dates independently per group.

    Source frame (timestamp, city, value):
        __timestamp     city  val
        2022-01-13   Chicago  6.0
        2022-01-13        LA  5.0
        2022-01-13        NY  4.0
        2022-01-11   Chicago  3.0
        2022-01-11        LA  2.0
        2022-01-11        NY  1.0
    """
    cities = ["Chicago", "LA", "NY"]
    source = DataFrame(
        {
            "__timestamp": to_datetime(["2022-01-13"] * 3 + ["2022-01-11"] * 3),
            "city": cities * 2,
            "val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
        }
    )
    result = resample(
        df=source,
        rule="1D",
        method="asfreq",
        fill_value=0,
        time_column="__timestamp",
        groupby_columns=("city",),
    )
    assert list(result.columns) == ["__timestamp", "city", "val"]
    # the missing 2022-01-12 row is inserted once per city
    assert [str(ts.date()) for ts in result["__timestamp"]] == (
        ["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3
    )
    assert list(result["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0]
    # nonexistent group columns and None entries must fail validation
    for bad_groupby in (("city", "unkonw_column"), ("city", None)):
        with pytest.raises(QueryObjectValidationError):
            resample(
                df=source,
                rule="1D",
                method="asfreq",
                fill_value=0,
                time_column="__timestamp",
                groupby_columns=bad_groupby,
            )

View File

@@ -0,0 +1,147 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pandas import to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import pivot, rolling
from tests.unit_tests.fixtures.dataframes import (
multiple_metrics_df,
single_metric_df,
timeseries_df,
)
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
def test_rolling():
    """Verify rolling-window aggregations over the timeseries fixture."""
    # sum rolling type
    summed = rolling(
        df=timeseries_df,
        rolling_type="sum",
        columns={"y": "y"},
        window=2,
        min_periods=0,
    )
    assert list(summed.columns) == ["label", "y"]
    assert series_to_list(summed["y"]) == [1.0, 3.0, 5.0, 7.0]

    # mean rolling type with alias
    averaged = rolling(
        df=timeseries_df,
        rolling_type="mean",
        columns={"y": "y_mean"},
        window=10,
        min_periods=0,
    )
    assert list(averaged.columns) == ["label", "y", "y_mean"]
    assert series_to_list(averaged["y_mean"]) == [1.0, 1.5, 2.0, 2.5]

    # count rolling type
    counted = rolling(
        df=timeseries_df,
        rolling_type="count",
        columns={"y": "y"},
        window=10,
        min_periods=0,
    )
    assert list(counted.columns) == ["label", "y"]
    assert series_to_list(counted["y"]) == [1.0, 2.0, 3.0, 4.0]

    # quantile rolling type
    quantiled = rolling(
        df=timeseries_df,
        rolling_type="quantile",
        rolling_type_options={"quantile": 0.25},
        columns={"y": "q1"},
        window=10,
        min_periods=0,
    )
    assert list(quantiled.columns) == ["label", "y", "q1"]
    assert series_to_list(quantiled["q1"]) == [1.0, 1.25, 1.5, 1.75]

    # incorrect rolling type
    with pytest.raises(QueryObjectValidationError):
        rolling(df=timeseries_df, columns={"y": "y"}, rolling_type="abc", window=2)

    # incorrect rolling type options
    with pytest.raises(QueryObjectValidationError):
        rolling(
            df=timeseries_df,
            columns={"y": "y"},
            rolling_type="quantile",
            rolling_type_options={"abc": 123},
            window=2,
        )
def test_rolling_with_pivot_df_and_single_metric():
    """Rolling sum applied to a pivoted frame holding a single metric."""
    pivoted = pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    result = rolling(
        df=pivoted, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
    )
    # Expected cumulative sums:
    #    dttm        UK  US
    # 0  2019-01-01   5   6
    # 1  2019-01-02  12  14
    assert result["UK"].to_list() == [5.0, 12.0]
    assert result["US"].to_list() == [6.0, 14.0]
    assert (
        result["dttm"].to_list()
        == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    )

    # with min_periods=2 this two-row input yields an empty frame
    result = rolling(
        df=pivoted, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True,
    )
    assert result.empty is True
def test_rolling_with_pivot_df_and_multiple_metrics():
    """Rolling sum applied to a pivoted frame carrying two metrics."""
    pivoted = pivot(
        df=multiple_metrics_df,
        index=["dttm"],
        columns=["country"],
        aggregates={
            "sum_metric": {"operator": "sum"},
            "count_metric": {"operator": "sum"},
        },
        flatten_columns=False,
        reset_index=False,
    )
    result = rolling(
        df=pivoted, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
    )
    # Expected cumulative sums per flattened "metric, country" column:
    #    dttm        count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
    # 0  2019-01-01               1.0               2.0             5.0             6.0
    # 1  2019-01-02               4.0               6.0            12.0            14.0
    expectations = {
        "count_metric, UK": [1.0, 4.0],
        "count_metric, US": [2.0, 6.0],
        "sum_metric, UK": [5.0, 12.0],
        "sum_metric, US": [6.0, 14.0],
    }
    for column, expected in expectations.items():
        assert result[column].to_list() == expected
    assert (
        result["dttm"].to_list()
        == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
    )

View File

@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing.select import select
from tests.unit_tests.fixtures.dataframes import timeseries_df
def test_select():
    """Exercise column selection, renaming and exclusion via `select`."""
    # reorder columns
    result = select(df=timeseries_df, columns=["y", "label"])
    assert list(result.columns) == ["y", "label"]

    # one column
    result = select(df=timeseries_df, columns=["label"])
    assert list(result.columns) == ["label"]

    # rename and select one column
    result = select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
    assert list(result.columns) == ["y1"]

    # rename one and leave one unchanged
    result = select(df=timeseries_df, rename={"y": "y1"})
    assert list(result.columns) == ["label", "y1"]

    # drop one column
    result = select(df=timeseries_df, exclude=["label"])
    assert list(result.columns) == ["y"]

    # rename and drop one column
    result = select(df=timeseries_df, rename={"y": "y1"}, exclude=["label"])
    assert list(result.columns) == ["y1"]

    # invalid columns
    with pytest.raises(QueryObjectValidationError):
        select(df=timeseries_df, columns=["abc"], rename={"abc": "qwerty"})

    # select renamed column by new name
    with pytest.raises(QueryObjectValidationError):
        select(df=timeseries_df, columns=["label_new"], rename={"label": "label_new"})

View File

@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import sort
from tests.unit_tests.fixtures.dataframes import categories_df
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
def test_sort():
    """Sort by category ascending, then asc_idx descending; bad column raises."""
    sorted_df = sort(df=categories_df, columns={"category": True, "asc_idx": False})
    assert series_to_list(sorted_df["asc_idx"])[1] == 96
    # sorting on a column that does not exist must be rejected
    with pytest.raises(QueryObjectValidationError):
        sort(df=sorted_df, columns={"abc": True})

View File

@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import math
from typing import Any, List, Optional
from pandas import Series
# Shared aggregate specs for postprocessing tests: single-metric sum over
# `idx_nulls`, and a two-metric variant adding a mean over `asc_idx`.
AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}}
AGGREGATES_MULTIPLE = {
    "idx_nulls": {"operator": "sum"},
    "asc_idx": {"operator": "mean"},
}
def series_to_list(series: Series) -> List[Any]:
    """
    Convert a `Series` to a plain list, mapping NaN/inf entries to None.

    :param series: Series to convert
    :return: list without nan or inf
    """
    result: List[Any] = []
    for val in series.tolist():
        # strings pass through untouched; math.isnan/isinf reject non-numerics
        if not isinstance(val, str) and (math.isnan(val) or math.isinf(val)):
            result.append(None)
        else:
            result.append(val)
    return result
def round_floats(
    floats: List[Optional[float]], precision: int
) -> List[Optional[float]]:
    """
    Round list of floats to certain precision

    :param floats: floats to round; `None` entries are preserved as `None`
    :param precision: intended decimal precision
    :return: rounded floats
    """
    # Compare against None explicitly: the previous truthiness check
    # (`if val`) incorrectly turned a legitimate 0.0 into None.
    return [round(val, precision) if val is not None else None for val in floats]