refactor: postprocessing move to unit test (#18779)
This commit is contained in:
parent
cd381879c0
commit
30a9d14639
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.utils.pandas_postprocessing import aggregate
|
||||
from tests.unit_tests.fixtures.dataframes import categories_df
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
||||
|
||||
def test_aggregate():
    """Aggregating by a constant group yields one row with the expected sums/percentiles."""
    result = aggregate(
        df=categories_df,
        groupby=["constant"],
        aggregates={
            "asc sum": {"column": "asc_idx", "operator": "sum"},
            "asc q2": {
                "column": "asc_idx",
                "operator": "percentile",
                "options": {"q": 75},
            },
            "desc q1": {
                "column": "desc_idx",
                "operator": "percentile",
                "options": {"q": 25},
            },
        },
    )
    # groupby column first, then the aggregate labels in declaration order
    assert result.columns.tolist() == ["constant", "asc sum", "asc q2", "desc q1"]
    assert series_to_list(result["asc sum"])[0] == 5050
    assert series_to_list(result["asc q2"])[0] == 75
    assert series_to_list(result["desc q1"])[0] == 25
|
||||
|
|
@ -0,0 +1,126 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.core import PostProcessingBoxplotWhiskerType
|
||||
from superset.utils.pandas_postprocessing import boxplot
|
||||
from tests.unit_tests.fixtures.dataframes import names_df
|
||||
|
||||
|
||||
def test_boxplot_tukey():
    """Tukey-whisker boxplot emits the full set of per-metric statistic columns."""
    df = boxplot(
        df=names_df,
        groupby=["region"],
        whisker_type=PostProcessingBoxplotWhiskerType.TUKEY,
        metrics=["cars"],
    )
    # set(df.columns) replaces the redundant set comprehension
    assert set(df.columns) == {
        "cars__mean",
        "cars__median",
        "cars__q1",
        "cars__q3",
        "cars__max",
        "cars__min",
        "cars__count",
        "cars__outliers",
        "region",
    }
    # one row per region group
    assert len(df) == 4
|
||||
|
||||
|
||||
def test_boxplot_min_max():
    """Min/max-whisker boxplot emits the full set of per-metric statistic columns."""
    df = boxplot(
        df=names_df,
        groupby=["region"],
        whisker_type=PostProcessingBoxplotWhiskerType.MINMAX,
        metrics=["cars"],
    )
    # set(df.columns) replaces the redundant set comprehension
    assert set(df.columns) == {
        "cars__mean",
        "cars__median",
        "cars__q1",
        "cars__q3",
        "cars__max",
        "cars__min",
        "cars__count",
        "cars__outliers",
        "region",
    }
    # one row per region group
    assert len(df) == 4
|
||||
|
||||
|
||||
def test_boxplot_percentile():
    """Percentile-whisker boxplot emits the full set of per-metric statistic columns."""
    df = boxplot(
        df=names_df,
        groupby=["region"],
        whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
        metrics=["cars"],
        percentiles=[1, 99],
    )
    # set(df.columns) replaces the redundant set comprehension
    assert set(df.columns) == {
        "cars__mean",
        "cars__median",
        "cars__q1",
        "cars__q3",
        "cars__max",
        "cars__min",
        "cars__count",
        "cars__outliers",
        "region",
    }
    # one row per region group
    assert len(df) == 4
|
||||
|
||||
|
||||
def test_boxplot_percentile_incorrect_params():
    """Percentile whiskers require exactly two, ascending percentile values."""
    # missing percentiles entirely
    with pytest.raises(QueryObjectValidationError):
        boxplot(
            df=names_df,
            groupby=["region"],
            whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
            metrics=["cars"],
        )

    # wrong length or wrong ordering must all be rejected
    for bad_percentiles in ([10], [90, 10], [10, 90, 10]):
        with pytest.raises(QueryObjectValidationError):
            boxplot(
                df=names_df,
                groupby=["region"],
                whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE,
                metrics=["cars"],
                percentiles=bad_percentiles,
            )
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from superset.utils.pandas_postprocessing import compare
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df2
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
||||
|
||||
def test_compare():
    """`compare` appends a `<type>__y__z` column for each comparison type."""
    # each comparison type with its expected derived series
    expectations = [
        ("difference", [0.0, -2.0, -8.0, -6.0]),
        ("percentage", [0.0, -0.5, -0.8, -0.75]),
        ("ratio", [1.0, 0.5, 0.2, 0.25]),
    ]
    for compare_type, expected in expectations:
        post_df = compare(
            df=timeseries_df2,
            source_columns=["y"],
            compare_columns=["z"],
            compare_type=compare_type,
        )
        derived = f"{compare_type}__y__z"
        assert post_df.columns.tolist() == ["label", "y", "z", derived]
        assert series_to_list(post_df[derived]) == expected

    # drop_original_columns removes the source/compare columns from the output
    post_df = compare(
        df=timeseries_df2,
        source_columns=["y"],
        compare_columns=["z"],
        compare_type="difference",
        drop_original_columns=True,
    )
    assert post_df.columns.tolist() == ["label", "difference__y__z"]
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
|
||||
from superset.utils.pandas_postprocessing import contribution
|
||||
|
||||
|
||||
def test_contribution():
    """`contribution` normalizes numeric cells by row or column totals."""
    df = DataFrame(
        {
            DTTM_ALIAS: [
                datetime(2020, 7, 16, 14, 49),
                datetime(2020, 7, 16, 14, 50),
            ],
            "a": [1, 3],
            "b": [1, 9],
        }
    )

    # non-numeric (temporal) columns are rejected
    with pytest.raises(QueryObjectValidationError, match="not numeric"):
        contribution(df, columns=[DTTM_ALIAS])

    # columns / rename_columns lengths must match
    with pytest.raises(QueryObjectValidationError, match="same length"):
        contribution(df, columns=["a"], rename_columns=["aa", "bb"])

    # cell contribution across each row
    processed_df = contribution(df, orientation=PostProcessingContributionOrientation.ROW)
    assert processed_df.columns.tolist() == [DTTM_ALIAS, "a", "b"]
    assert processed_df["a"].tolist() == [0.5, 0.25]
    assert processed_df["b"].tolist() == [0.5, 0.75]

    # cell contribution down each column, after removing the temporal column
    df.pop(DTTM_ALIAS)
    processed_df = contribution(
        df, orientation=PostProcessingContributionOrientation.COLUMN
    )
    assert processed_df.columns.tolist() == ["a", "b"]
    assert processed_df["a"].tolist() == [0.25, 0.75]
    assert processed_df["b"].tolist() == [0.1, 0.9]

    # contribution restricted to selected columns, written to renamed outputs
    processed_df = contribution(
        df,
        orientation=PostProcessingContributionOrientation.COLUMN,
        columns=["a"],
        rename_columns=["pct_a"],
    )
    assert processed_df.columns.tolist() == ["a", "b", "pct_a"]
    assert processed_df["a"].tolist() == [1, 3]
    assert processed_df["b"].tolist() == [1, 9]
    assert processed_df["pct_a"].tolist() == [0.25, 0.75]
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pytest
|
||||
from pandas import to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import cum, pivot
|
||||
from tests.unit_tests.fixtures.dataframes import (
|
||||
multiple_metrics_df,
|
||||
single_metric_df,
|
||||
timeseries_df,
|
||||
)
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
||||
|
||||
def test_cum():
    """`cum` applies a cumulative operator, writing to a new or existing column."""
    # cumsum into a new column
    post_df = cum(df=timeseries_df, columns={"y": "y2"}, operator="sum")
    assert post_df.columns.tolist() == ["label", "y", "y2"]
    assert series_to_list(post_df["label"]) == ["x", "y", "z", "q"]
    assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
    assert series_to_list(post_df["y2"]) == [1.0, 3.0, 6.0, 10.0]

    # cumprod overwriting the source column
    post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="prod")
    assert post_df.columns.tolist() == ["label", "y"]
    assert series_to_list(post_df["y"]) == [1.0, 2.0, 6.0, 24.0]

    # cummin overwriting the source column
    post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="min")
    assert post_df.columns.tolist() == ["label", "y"]
    assert series_to_list(post_df["y"]) == [1.0, 1.0, 1.0, 1.0]

    # unknown operators are rejected
    with pytest.raises(QueryObjectValidationError):
        cum(df=timeseries_df, columns={"y": "y"}, operator="abc")
|
||||
|
||||
|
||||
def test_cum_with_pivot_df_and_single_metric():
    """`cum` on a pivoted single-metric frame accumulates each country column."""
    pivot_df = pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True)
    # Expected frame:
    #         dttm  UK  US
    # 0 2019-01-01   5   6
    # 1 2019-01-02  12  14
    assert cum_df["UK"].to_list() == [5.0, 12.0]
    assert cum_df["US"].to_list() == [6.0, 14.0]
    expected_dttm = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert cum_df["dttm"].to_list() == expected_dttm
|
||||
|
||||
|
||||
def test_cum_with_pivot_df_and_multiple_metrics():
    """`cum` on a pivoted multi-metric frame accumulates every metric/country pair."""
    pivot_df = pivot(
        df=multiple_metrics_df,
        index=["dttm"],
        columns=["country"],
        aggregates={
            "sum_metric": {"operator": "sum"},
            "count_metric": {"operator": "sum"},
        },
        flatten_columns=False,
        reset_index=False,
    )
    cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True)
    # Expected frame:
    #         dttm  count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
    # 0 2019-01-01                 1                 2               5               6
    # 1 2019-01-02                 4                 6              12              14
    expected = {
        "count_metric, UK": [1.0, 4.0],
        "count_metric, US": [2.0, 6.0],
        "sum_metric, UK": [5.0, 12.0],
        "sum_metric, US": [6.0, 14.0],
    }
    for column, values in expected.items():
        assert cum_df[column].to_list() == values
    expected_dttm = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert cum_df["dttm"].to_list() == expected_dttm
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import diff
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df, timeseries_df2
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
||||
|
||||
def test_diff():
    """`diff` computes discrete differences by row or by column."""
    # overwrite the source column in place
    post_df = diff(df=timeseries_df, columns={"y": "y"})
    assert post_df.columns.tolist() == ["label", "y"]
    assert series_to_list(post_df["y"]) == [None, 1.0, 1.0, 1.0]

    # write the diff into a new column, keeping the source intact
    post_df = diff(df=timeseries_df, columns={"y": "y1"})
    assert post_df.columns.tolist() == ["label", "y", "y1"]
    assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
    assert series_to_list(post_df["y1"]) == [None, 1.0, 1.0, 1.0]

    # negative periods diff against the following row (look ahead)
    post_df = diff(df=timeseries_df, columns={"y": "y1"}, periods=-1)
    assert series_to_list(post_df["y1"]) == [-1.0, -1.0, -1.0, None]

    # unknown column references are rejected
    with pytest.raises(QueryObjectValidationError):
        diff(df=timeseries_df, columns={"abc": "abc"})

    # axis=1 diffs across columns instead of rows
    post_df = diff(df=timeseries_df2, columns={"y": "y", "z": "z"}, axis=1)
    assert post_df.columns.tolist() == ["label", "y", "z"]
    assert series_to_list(post_df["z"]) == [0.0, 2.0, 8.0, 6.0]
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.utils.pandas_postprocessing import (
|
||||
geodetic_parse,
|
||||
geohash_decode,
|
||||
geohash_encode,
|
||||
)
|
||||
from tests.unit_tests.fixtures.dataframes import lonlat_df
|
||||
from tests.unit_tests.pandas_postprocessing.utils import round_floats, series_to_list
|
||||
|
||||
|
||||
def test_geohash_decode():
    """Decoding a geohash reproduces the source lon/lat to ~6 decimal places."""
    post_df = geohash_decode(
        df=lonlat_df[["city", "geohash"]],
        geohash="geohash",
        latitude="latitude",
        longitude="longitude",
    )
    expected_columns = ["city", "geohash", "latitude", "longitude"]
    assert sorted(post_df.columns.tolist()) == sorted(expected_columns)
    # geohash decoding is lossy, so compare rounded values
    for coord in ("longitude", "latitude"):
        assert round_floats(series_to_list(post_df[coord]), 6) == round_floats(
            series_to_list(lonlat_df[coord]), 6
        )
|
||||
|
||||
|
||||
def test_geohash_encode():
    """Encoding lon/lat reproduces the fixture's geohash column exactly."""
    post_df = geohash_encode(
        df=lonlat_df[["city", "latitude", "longitude"]],
        latitude="latitude",
        longitude="longitude",
        geohash="geohash",
    )
    expected_columns = ["city", "geohash", "latitude", "longitude"]
    assert sorted(post_df.columns.tolist()) == sorted(expected_columns)
    assert series_to_list(post_df["geohash"]) == series_to_list(lonlat_df["geohash"])
|
||||
|
||||
|
||||
def test_geodetic_parse():
    """Parsing a geodetic string yields lon/lat (and optionally altitude) columns."""
    # parse geodetic string with altitude into lon/lat/altitude
    post_df = geodetic_parse(
        df=lonlat_df[["city", "geodetic"]],
        geodetic="geodetic",
        latitude="latitude",
        longitude="longitude",
        altitude="altitude",
    )
    assert sorted(post_df.columns.tolist()) == sorted(
        ["city", "geodetic", "latitude", "longitude", "altitude"]
    )
    assert series_to_list(post_df["longitude"]) == series_to_list(
        lonlat_df["longitude"]
    )
    assert series_to_list(post_df["latitude"]) == series_to_list(lonlat_df["latitude"])
    assert series_to_list(post_df["altitude"]) == series_to_list(lonlat_df["altitude"])

    # parse geodetic string into lon/lat only
    post_df = geodetic_parse(
        df=lonlat_df[["city", "geodetic"]],
        geodetic="geodetic",
        latitude="latitude",
        longitude="longitude",
    )
    assert sorted(post_df.columns.tolist()) == sorted(
        ["city", "geodetic", "latitude", "longitude"]
    )
    assert series_to_list(post_df["longitude"]) == series_to_list(
        lonlat_df["longitude"]
    )
    # BUGFIX: original used `assert a, b` (comma), which always passed because the
    # second expression was just the assertion message — compare with == instead.
    assert series_to_list(post_df["latitude"]) == series_to_list(lonlat_df["latitude"])
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from pandas import DataFrame, Timestamp, to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import _flatten_column_after_pivot, pivot
|
||||
from tests.unit_tests.fixtures.dataframes import categories_df, single_metric_df
|
||||
from tests.unit_tests.pandas_postprocessing.utils import (
|
||||
AGGREGATES_MULTIPLE,
|
||||
AGGREGATES_SINGLE,
|
||||
)
|
||||
|
||||
|
||||
def test_flatten_column_after_pivot():
    """
    Test pivot column flattening function.

    Table-driven rewrite: the original repeated the `idx_nulls` single-aggregate
    case twice; the duplicate is removed.
    """
    # single aggregate: the aggregate name is dropped from the flattened label
    single_cases = [
        ("idx_nulls", "idx_nulls"),
        (1234, "1234"),
        (Timestamp("2020-09-29T00:00:00"), "2020-09-29 00:00:00"),
        (("idx_nulls", "col1"), "col1"),
        (("idx_nulls", "col1", 1234), "col1, 1234"),
    ]
    for column, expected in single_cases:
        assert (
            _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column=column)
            == expected
        )

    # multiple aggregates: the aggregate name is kept as the label prefix
    multiple_cases = [
        (("idx_nulls", "asc_idx", "col1"), "idx_nulls, asc_idx, col1"),
        (("idx_nulls", "asc_idx", "col1", 1234), "idx_nulls, asc_idx, col1, 1234"),
    ]
    for column, expected in multiple_cases:
        assert (
            _flatten_column_after_pivot(aggregates=AGGREGATES_MULTIPLE, column=column)
            == expected
        )
|
||||
|
||||
|
||||
def test_pivot_without_columns():
    """
    Make sure pivot without columns returns correct DataFrame
    """
    df = pivot(df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE)
    assert df.columns.tolist() == ["name", "idx_nulls"]
    assert len(df) == 101
    # position 1 in the column-wise sums is the aggregated metric column
    assert df.sum()[1] == 1050
|
||||
|
||||
|
||||
def test_pivot_with_single_column():
    """
    Make sure pivot with single column returns correct DataFrame
    """
    # pivot by category with name as the index
    by_name = pivot(
        df=categories_df,
        index=["name"],
        columns=["category"],
        aggregates=AGGREGATES_SINGLE,
    )
    assert by_name.columns.tolist() == ["name", "cat0", "cat1", "cat2"]
    assert len(by_name) == 101
    # position 1 in the column-wise sums is the first category column
    assert by_name.sum()[1] == 315

    # pivot by category with dept as the index collapses to one row per dept
    by_dept = pivot(
        df=categories_df,
        index=["dept"],
        columns=["category"],
        aggregates=AGGREGATES_SINGLE,
    )
    assert by_dept.columns.tolist() == ["dept", "cat0", "cat1", "cat2"]
    assert len(by_dept) == 5
|
||||
|
||||
|
||||
def test_pivot_with_multiple_columns():
    """
    Make sure pivot with multiple columns returns correct DataFrame
    """
    df = pivot(
        df=categories_df,
        index=["name"],
        columns=["category", "dept"],
        aggregates=AGGREGATES_SINGLE,
    )
    # one index column plus every (category, dept) permutation
    expected_column_count = 1 + 3 * 5
    assert len(df.columns) == expected_column_count
|
||||
|
||||
|
||||
def test_pivot_fill_values():
    """
    Make sure pivot with fill values returns correct DataFrame
    """
    df = pivot(
        df=categories_df,
        index=["name"],
        columns=["category"],
        metric_fill_value=1,
        aggregates={"idx_nulls": {"operator": "sum"}},
    )
    # missing cells are filled with 1, raising the total from the unfilled sum
    assert df.sum()[1] == 382
|
||||
|
||||
|
||||
def test_pivot_fill_column_values():
    """
    Make sure pivot with null column names returns correct DataFrame
    """
    # force every pivot column value to be null
    df_copy = categories_df.copy()
    df_copy["category"] = None
    df = pivot(
        df=df_copy,
        index=["name"],
        columns=["category"],
        aggregates={"idx_nulls": {"operator": "sum"}},
    )
    assert len(df) == 101
    # null column labels are replaced with the "<NULL>" placeholder
    assert df.columns.tolist() == ["name", "<NULL>"]
|
||||
|
||||
|
||||
def test_pivot_exceptions():
    """
    Make sure pivot raises correct Exceptions
    """
    # omitting the required index argument is a TypeError
    with pytest.raises(TypeError):
        pivot(df=categories_df, columns=["dept"], aggregates=AGGREGATES_SINGLE)

    # referencing a nonexistent index column
    with pytest.raises(QueryObjectValidationError):
        pivot(
            df=categories_df,
            index=["abc"],
            columns=["dept"],
            aggregates=AGGREGATES_SINGLE,
        )

    # referencing a nonexistent pivot column
    with pytest.raises(QueryObjectValidationError):
        pivot(
            df=categories_df,
            index=["dept"],
            columns=["abc"],
            aggregates=AGGREGATES_SINGLE,
        )

    # aggregate entries must specify an operator
    with pytest.raises(QueryObjectValidationError):
        pivot(
            df=categories_df,
            index=["name"],
            columns=["category"],
            aggregates={"idx_nulls": {}},
        )
|
||||
|
||||
|
||||
def test_pivot_eliminate_cartesian_product_columns():
    """Pivot keeps only observed column combinations, not the full cartesian product.

    Uses np.nan instead of the deprecated np.NAN alias (removed in NumPy 2.0).
    """
    # single metric
    mock_df = DataFrame(
        {
            "dttm": to_datetime(["2019-01-01", "2019-01-01"]),
            "a": [0, 1],
            "b": [0, 1],
            "metric": [9, np.nan],
        }
    )

    df = pivot(
        df=mock_df,
        index=["dttm"],
        columns=["a", "b"],
        aggregates={"metric": {"operator": "mean"}},
        drop_missing_columns=False,
    )
    # only the observed (a, b) pairs appear — no "0, 1" / "1, 0" columns
    assert list(df.columns) == ["dttm", "0, 0", "1, 1"]
    assert np.isnan(df["1, 1"][0])

    # multiple metrics
    mock_df = DataFrame(
        {
            "dttm": to_datetime(["2019-01-01", "2019-01-01"]),
            "a": [0, 1],
            "b": [0, 1],
            "metric": [9, np.nan],
            "metric2": [10, 11],
        }
    )

    df = pivot(
        df=mock_df,
        index=["dttm"],
        columns=["a", "b"],
        aggregates={"metric": {"operator": "mean"}, "metric2": {"operator": "mean"}},
        drop_missing_columns=False,
    )
    assert list(df.columns) == [
        "dttm",
        "metric, 0, 0",
        "metric, 1, 1",
        "metric2, 0, 0",
        "metric2, 1, 1",
    ]
    assert np.isnan(df["metric, 1, 1"][0])
|
||||
|
||||
|
||||
def test_pivot_without_flatten_columns_and_reset_index():
    """Without flattening/resetting, pivot keeps a MultiIndex header and dttm index."""
    df = pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    # Expected frame:
    #             metric
    # country     UK  US
    # dttm
    # 2019-01-01   5   6
    # 2019-01-02   7   8
    assert df.columns.to_list() == [("sum_metric", "UK"), ("sum_metric", "US")]
    expected_index = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert df.index.to_list() == expected_index
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from datetime import datetime
|
||||
from importlib.util import find_spec
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.core import DTTM_ALIAS
|
||||
from superset.utils.pandas_postprocessing import prophet
|
||||
from tests.unit_tests.fixtures.dataframes import prophet_df
|
||||
|
||||
|
||||
def test_prophet_valid():
    """Prophet forecast appends yhat/upper/lower columns and extends the time range."""
    pytest.importorskip("prophet")

    df = prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
    # set(df.columns) replaces the redundant set comprehension
    assert set(df.columns) == {
        DTTM_ALIAS,
        "a__yhat",
        "a__yhat_upper",
        "a__yhat_lower",
        "a",
        "b__yhat",
        "b__yhat_upper",
        "b__yhat_lower",
        "b",
    }
    # 4 observed months + 3 forecast periods
    assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31)
    assert len(df) == 7

    # 4 observed months + 5 forecast periods
    df = prophet(df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9)
    assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31)
    assert len(df) == 9
|
||||
|
||||
|
||||
def test_prophet_valid_zero_periods():
    """With periods=0 the forecast columns are added but no future rows."""
    pytest.importorskip("prophet")

    df = prophet(df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9)
    # set() directly instead of a redundant set comprehension (ruff C401)
    assert set(df.columns) == {
        DTTM_ALIAS,
        "a__yhat",
        "a__yhat_upper",
        "a__yhat_lower",
        "a",
        "b__yhat",
        "b__yhat_upper",
        "b__yhat_lower",
        "b",
    }
    # the time range is unchanged: no rows beyond the original series
    assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31)
    assert len(df) == 4
|
||||
|
||||
|
||||
def test_prophet_import():
    """When the optional prophet package is not installed, the operation must
    surface a validation error rather than an ImportError."""
    if find_spec("prophet") is None:
        with pytest.raises(QueryObjectValidationError):
            prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
|
||||
|
||||
|
||||
def test_prophet_missing_temporal_column():
    """A frame lacking the temporal column cannot be forecast."""
    df_without_time = prophet_df.drop(DTTM_ALIAS, axis=1)

    with pytest.raises(QueryObjectValidationError):
        prophet(
            df=df_without_time,
            time_grain="P1M",
            periods=3,
            confidence_interval=0.9,
        )
|
||||
|
||||
|
||||
def test_prophet_incorrect_confidence_interval():
    """The confidence interval must lie strictly between 0 and 1."""
    for invalid_interval in (0.0, 1.0):
        with pytest.raises(QueryObjectValidationError):
            prophet(
                df=prophet_df,
                time_grain="P1M",
                periods=3,
                confidence_interval=invalid_interval,
            )
|
||||
|
||||
|
||||
def test_prophet_incorrect_periods():
    """Negative forecast periods are rejected."""
    with pytest.raises(QueryObjectValidationError):
        prophet(
            df=prophet_df,
            time_grain="P1M",
            periods=-1,
            confidence_interval=0.8,
        )
|
||||
|
||||
|
||||
def test_prophet_incorrect_time_grain():
    """Only ISO-8601 duration grains are accepted; free text is rejected."""
    with pytest.raises(QueryObjectValidationError):
        prophet(
            df=prophet_df,
            time_grain="yearly",
            periods=10,
            confidence_interval=0.8,
        )
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pytest
|
||||
from pandas import DataFrame, to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import resample
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df
|
||||
|
||||
|
||||
def test_resample():
    """Upsampling to daily frequency fills gaps via ffill or a constant."""
    df = timeseries_df.copy()
    df.index.name = "time_column"
    df.reset_index(inplace=True)

    # forward fill carries the previous observation across the gaps
    post_df = resample(df=df, rule="1D", method="ffill", time_column="time_column")
    assert post_df["label"].tolist() == ["x", "y", "y", "y", "z", "z", "q"]
    assert post_df["y"].tolist() == [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0]

    # asfreq leaves holes which are then patched with fill_value
    post_df = resample(
        df=df,
        rule="1D",
        method="asfreq",
        time_column="time_column",
        fill_value=0,
    )
    assert post_df["label"].tolist() == ["x", "y", 0, 0, "z", 0, "q"]
    assert post_df["y"].tolist() == [1.0, 2.0, 0, 0, 3.0, 0, 4.0]
|
||||
|
||||
|
||||
def test_resample_with_groupby():
    """
    Resampling with groupby columns upsamples each group independently.

    The DataFrame contains a timestamp column, a string column and a
    numeric column:

       __timestamp       city  val
    0   2022-01-13    Chicago  6.0
    1   2022-01-13         LA  5.0
    2   2022-01-13         NY  4.0
    3   2022-01-11    Chicago  3.0
    4   2022-01-11         LA  2.0
    5   2022-01-11         NY  1.0
    """
    df = DataFrame(
        {
            "__timestamp": to_datetime(["2022-01-13"] * 3 + ["2022-01-11"] * 3),
            "city": ["Chicago", "LA", "NY", "Chicago", "LA", "NY"],
            "val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
        }
    )
    post_df = resample(
        df=df,
        rule="1D",
        method="asfreq",
        fill_value=0,
        time_column="__timestamp",
        groupby_columns=("city",),
    )
    assert list(post_df.columns) == [
        "__timestamp",
        "city",
        "val",
    ]
    # every city now covers the full 2022-01-11 .. 2022-01-13 range,
    # with the missing middle day filled with zeros
    assert [str(dt.date()) for dt in post_df["__timestamp"]] == (
        ["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3
    )
    assert list(post_df["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0]

    # should raise error when given a non-existent column
    # (fixed typo: "unkonw_column" -> "unknown_column"; still non-existent)
    with pytest.raises(QueryObjectValidationError):
        resample(
            df=df,
            rule="1D",
            method="asfreq",
            fill_value=0,
            time_column="__timestamp",
            groupby_columns=("city", "unknown_column"),
        )

    # should raise error when given a None value in the groupby list
    with pytest.raises(QueryObjectValidationError):
        resample(
            df=df,
            rule="1D",
            method="asfreq",
            fill_value=0,
            time_column="__timestamp",
            groupby_columns=("city", None),
        )
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pytest
|
||||
from pandas import to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import pivot, rolling
|
||||
from tests.unit_tests.fixtures.dataframes import (
|
||||
multiple_metrics_df,
|
||||
single_metric_df,
|
||||
timeseries_df,
|
||||
)
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
||||
|
||||
def test_rolling():
    """Exercise each supported rolling window type, plus invalid inputs."""
    # rolling sum over a window of two
    result = rolling(
        df=timeseries_df,
        columns={"y": "y"},
        rolling_type="sum",
        window=2,
        min_periods=0,
    )
    assert result.columns.tolist() == ["label", "y"]
    assert series_to_list(result["y"]) == [1.0, 3.0, 5.0, 7.0]

    # rolling mean written to an aliased column
    result = rolling(
        df=timeseries_df,
        rolling_type="mean",
        columns={"y": "y_mean"},
        window=10,
        min_periods=0,
    )
    assert result.columns.tolist() == ["label", "y", "y_mean"]
    assert series_to_list(result["y_mean"]) == [1.0, 1.5, 2.0, 2.5]

    # rolling count
    result = rolling(
        df=timeseries_df,
        rolling_type="count",
        columns={"y": "y"},
        window=10,
        min_periods=0,
    )
    assert result.columns.tolist() == ["label", "y"]
    assert series_to_list(result["y"]) == [1.0, 2.0, 3.0, 4.0]

    # rolling quantile (first quartile)
    result = rolling(
        df=timeseries_df,
        columns={"y": "q1"},
        rolling_type="quantile",
        rolling_type_options={"quantile": 0.25},
        window=10,
        min_periods=0,
    )
    assert result.columns.tolist() == ["label", "y", "q1"]
    assert series_to_list(result["q1"]) == [1.0, 1.25, 1.5, 1.75]

    # an unknown rolling type is rejected
    with pytest.raises(QueryObjectValidationError):
        rolling(df=timeseries_df, columns={"y": "y"}, rolling_type="abc", window=2)

    # an unknown option for a valid rolling type is rejected
    with pytest.raises(QueryObjectValidationError):
        rolling(
            df=timeseries_df,
            columns={"y": "y"},
            rolling_type="quantile",
            rolling_type_options={"abc": 123},
            window=2,
        )
|
||||
|
||||
|
||||
def test_rolling_with_pivot_df_and_single_metric():
    """Rolling over a non-flattened pivot df restores the index as a column."""
    pivot_df = pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    rolling_df = rolling(
        df=pivot_df,
        rolling_type="sum",
        window=2,
        min_periods=0,
        is_pivot_df=True,
    )
    #         dttm  UK  US
    # 0 2019-01-01   5   6
    # 1 2019-01-02  12  14
    assert rolling_df["UK"].to_list() == [5.0, 12.0]
    assert rolling_df["US"].to_list() == [6.0, 14.0]
    expected_dttm = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert rolling_df["dttm"].to_list() == expected_dttm

    # with min_periods=2, no window is ever complete, so all rows drop out
    rolling_df = rolling(
        df=pivot_df,
        rolling_type="sum",
        window=2,
        min_periods=2,
        is_pivot_df=True,
    )
    assert rolling_df.empty is True
|
||||
|
||||
|
||||
def test_rolling_with_pivot_df_and_multiple_metrics():
    """Rolling over a pivot df with two metrics keeps one column per
    metric/country pair."""
    pivot_df = pivot(
        df=multiple_metrics_df,
        index=["dttm"],
        columns=["country"],
        aggregates={
            "sum_metric": {"operator": "sum"},
            "count_metric": {"operator": "sum"},
        },
        flatten_columns=False,
        reset_index=False,
    )
    rolling_df = rolling(
        df=pivot_df,
        rolling_type="sum",
        window=2,
        min_periods=0,
        is_pivot_df=True,
    )
    #         dttm  count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
    # 0 2019-01-01               1.0               2.0             5.0             6.0
    # 1 2019-01-02               4.0               6.0            12.0            14.0
    expected_columns = {
        "count_metric, UK": [1.0, 4.0],
        "count_metric, US": [2.0, 6.0],
        "sum_metric, UK": [5.0, 12.0],
        "sum_metric, US": [6.0, 14.0],
    }
    for column, expected in expected_columns.items():
        assert rolling_df[column].to_list() == expected
    expected_dttm = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert rolling_df["dttm"].to_list() == expected_dttm
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing.select import select
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df
|
||||
|
||||
|
||||
def test_select():
    """Column selection, renaming and exclusion behaviors of select()."""
    # reorder columns
    assert select(df=timeseries_df, columns=["y", "label"]).columns.tolist() == [
        "y",
        "label",
    ]

    # keep a single column
    assert select(df=timeseries_df, columns=["label"]).columns.tolist() == ["label"]

    # select and rename one column in the same call
    renamed = select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
    assert renamed.columns.tolist() == ["y1"]

    # rename one column and leave the other untouched
    renamed = select(df=timeseries_df, rename={"y": "y1"})
    assert renamed.columns.tolist() == ["label", "y1"]

    # drop a column
    assert select(df=timeseries_df, exclude=["label"]).columns.tolist() == ["y"]

    # rename and drop in a single call
    renamed = select(df=timeseries_df, rename={"y": "y1"}, exclude=["label"])
    assert renamed.columns.tolist() == ["y1"]

    # selecting a column that does not exist is rejected
    with pytest.raises(QueryObjectValidationError):
        select(df=timeseries_df, columns=["abc"], rename={"abc": "qwerty"})

    # a renamed column cannot be selected by its new name
    with pytest.raises(QueryObjectValidationError):
        select(df=timeseries_df, columns=["label_new"], rename={"label": "label_new"})
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import sort
|
||||
from tests.unit_tests.fixtures.dataframes import categories_df
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
||||
|
||||
def test_sort():
    """Sort by two columns with mixed directions; unknown columns raise."""
    sorted_df = sort(df=categories_df, columns={"category": True, "asc_idx": False})
    assert series_to_list(sorted_df["asc_idx"])[1] == 96

    with pytest.raises(QueryObjectValidationError):
        sort(df=sorted_df, columns={"abc": True})
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import math
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pandas import Series
|
||||
|
||||
# Shared aggregate specifications reused by the postprocessing tests.
# Keys are column names of the test fixtures; values follow the
# pandas_postprocessing aggregate-spec shape ({"operator": <name>}).
AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}}
AGGREGATES_MULTIPLE = {
    "idx_nulls": {"operator": "sum"},
    "asc_idx": {"operator": "mean"},
}
|
||||
|
||||
|
||||
def series_to_list(series: Series) -> List[Any]:
    """
    Convert a `Series` to a regular list, replacing non-finite numeric
    values (NaN and +/-inf) with ``None``.

    The previous implementation called ``math.isnan`` on every non-string
    value, which raises ``TypeError`` for ``None`` (and other non-numeric
    objects) in object-dtype series; testing ``isinstance(val, float)``
    first makes the conversion total while keeping the same result for
    strings, ints and floats.

    :param series: Series to convert
    :return: list without nan or inf
    """
    return [
        None if isinstance(val, float) and not math.isfinite(val) else val
        for val in series.tolist()
    ]
|
||||
|
||||
|
||||
def round_floats(
    floats: List[Optional[float]], precision: int
) -> List[Optional[float]]:
    """
    Round a list of optional floats to a given decimal precision.

    ``None`` entries pass through unchanged. The previous truthiness test
    (``if val``) incorrectly turned ``0.0`` into ``None``; an explicit
    ``is None`` check preserves zeros.

    :param floats: floats to round (may contain None)
    :param precision: intended decimal precision
    :return: rounded floats, with None passed through
    """
    return [None if val is None else round(val, precision) for val in floats]
|
||||
Loading…
Reference in New Issue