From d466383df26bcfd7bad15fa4ae88ebbbde0aa94a Mon Sep 17 00:00:00 2001 From: Elizabeth Thompson Date: Fri, 1 Nov 2024 15:08:42 -0700 Subject: [PATCH] fix: warning emits an error (#28524) --- .../utils/pandas_postprocessing/compare.py | 7 +- .../pandas_postprocessing/test_compare.py | 67 +++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/superset/utils/pandas_postprocessing/compare.py b/superset/utils/pandas_postprocessing/compare.py index b20682027..64442280b 100644 --- a/superset/utils/pandas_postprocessing/compare.py +++ b/superset/utils/pandas_postprocessing/compare.py @@ -81,5 +81,10 @@ def compare( # pylint: disable=too-many-arguments df = pd.concat([df, diff_df], axis=1) if drop_original_columns: - df = df.drop(source_columns + compare_columns, axis=1) + level = ( + 0 + if isinstance(df.columns, pd.MultiIndex) and df.columns.nlevels > 1 + else None + ) + df = df.drop(source_columns + compare_columns, axis=1, level=level) return df diff --git a/tests/unit_tests/pandas_postprocessing/test_compare.py b/tests/unit_tests/pandas_postprocessing/test_compare.py index 9da8a3153..a26aa11d2 100644 --- a/tests/unit_tests/pandas_postprocessing/test_compare.py +++ b/tests/unit_tests/pandas_postprocessing/test_compare.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import io +import sys + import pandas as pd from superset.constants import PandasPostprocessingCompare as PPC @@ -179,6 +182,70 @@ def test_compare_multi_index_column(): ) +def test_compare_multi_index_column_non_lex_sorted(): + index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"]) + index.name = "__timestamp" + + iterables = [["m1", "m2"], ["a", "b"], ["x", "y"]] + columns = pd.MultiIndex.from_product(iterables, names=[None, "level1", "level2"]) + + df = pd.DataFrame(index=index, columns=columns, data=1) + + # Define a non-lexicographical column order + # arrange them as m1, m2 instead of m2, m1 + new_columns_order = [ + ("m1", "a", "x"), + ("m1", "a", "y"), + ("m1", "b", "x"), + ("m1", "b", "y"), + ("m2", "a", "x"), + ("m2", "a", "y"), + ("m2", "b", "x"), + ("m2", "b", "y"), + ] + + df.columns = pd.MultiIndex.from_tuples( + new_columns_order, names=["level1", "level2", None] + ) + + # to capture stderr + stderr = sys.stderr + sys.stderr = io.StringIO() + + try: + post_df = pp.compare( + df, + source_columns=["m1"], + compare_columns=["m2"], + compare_type=PPC.DIFF, + drop_original_columns=True, + ) + assert sys.stderr.getvalue() == "" + finally: + sys.stderr = stderr + + flat_df = pp.flatten(post_df) + """ + __timestamp difference__m1__m2, a, x difference__m1__m2, a, y difference__m1__m2, b, x difference__m1__m2, b, y + 0 2021-01-01 0 0 0 0 + 1 2021-01-02 0 0 0 0 + 2 2021-01-03 0 0 0 0 + """ + assert flat_df.equals( + pd.DataFrame( + data={ + "__timestamp": pd.to_datetime( + ["2021-01-01", "2021-01-02", "2021-01-03"] + ), + "difference__m1__m2, a, x": [0, 0, 0], + "difference__m1__m2, a, y": [0, 0, 0], + "difference__m1__m2, b, x": [0, 0, 0], + "difference__m1__m2, b, y": [0, 0, 0], + } + ) + ) + + def test_compare_after_pivot(): pivot_df = pp.pivot( df=multiple_metrics_df,