diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index ae3202981..98aaecebd 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -49,8 +49,10 @@ from superset.utils import csv from superset.utils.cache import generate_cache_key, set_and_log_cache from superset.utils.core import ( DatasourceType, + DateColumn, DTTM_ALIAS, error_msg_from_exception, + get_base_axis_labels, get_column_names_from_columns, get_column_names_from_metrics, get_metric_names, @@ -238,18 +240,57 @@ class QueryContextProcessor: return result def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame: - datasource = self._qc_datasource - timestamp_format = None - if datasource.type == "table": - dttm_col = datasource.get_column(query_object.granularity) - if dttm_col: - timestamp_format = dttm_col.python_date_format + # todo: should support "python_date_format" and "get_column" in each datasource + def _get_timestamp_format( + source: BaseDatasource, column: Optional[str] + ) -> Optional[str]: + column_obj = source.get_column(column) + if ( + column_obj + # only sqla column was supported + and hasattr(column_obj, "python_date_format") + and (formatter := column_obj.python_date_format) + ): + return str(formatter) + return None + + datasource = self._qc_datasource + labels = tuple( + label + for label in [ + *get_base_axis_labels(query_object.columns), + query_object.granularity, + ] + if datasource + # Query datasource didn't support `get_column` + and hasattr(datasource, "get_column") + and (col := datasource.get_column(label)) + and col.is_dttm + ) + dttm_cols = [ + DateColumn( + timestamp_format=_get_timestamp_format(datasource, label), + offset=datasource.offset, + time_shift=query_object.time_shift, + col_label=label, + ) + for label in labels + if label + ] + if DTTM_ALIAS in df: + dttm_cols.append( + DateColumn.get_legacy_time_column( + timestamp_format=_get_timestamp_format( + datasource, query_object.granularity + ), + offset=datasource.offset, + time_shift=query_object.time_shift, + ) + ) normalize_dttm_col( df=df, - timestamp_format=timestamp_format, - offset=datasource.offset, - time_shift=query_object.time_shift, + dttm_cols=tuple(dttm_cols), ) if self.enforce_numerical_metrics: @@ -344,10 +385,7 @@ class QueryContextProcessor: offset_metrics_df = offset_metrics_df.rename(columns=metrics_mapping) # 3. set time offset for index - # TODO: add x-axis to QueryObject, potentially as an array for - # multi-dimensional charts - granularity = query_object.granularity - index = granularity if granularity in df.columns else DTTM_ALIAS + index = (get_base_axis_labels(query_object.columns) or [DTTM_ALIAS])[0] if not dataframe_utils.is_datetime_series(offset_metrics_df.get(index)): raise QueryObjectValidationError( _( diff --git a/superset/utils/core.py b/superset/utils/core.py index 4131ede78..2786cfa27 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -16,6 +16,8 @@ # under the License. """Utility functions used across Superset""" # pylint: disable=too-many-lines +from __future__ import annotations + import _thread import collections import decimal @@ -34,6 +36,7 @@ import traceback import uuid import zlib from contextlib import contextmanager +from dataclasses import dataclass from datetime import date, datetime, time, timedelta from distutils.util import strtobool from email.mime.application import MIMEApplication @@ -1271,15 +1274,13 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]: return isinstance(column, dict) -def get_base_axis_column(columns: Optional[List[Column]]) -> Optional[AdhocColumn]: - if columns is None: - return None +def get_base_axis_labels(columns: Optional[List[Column]]) -> Tuple[str, ...]: axis_cols = [ col - for col in columns + for col in columns or [] if is_adhoc_column(col) and col.get("columnType") == "BASE_AXIS" ] - return axis_cols[0] if axis_cols else None + return tuple(get_column_name(col) for col in axis_cols) def get_column_name( @@ -1301,9 +1302,12 @@ def get_column_name( expr = column.get("sqlExpression") if expr: return expr - raise ValueError("Missing label") - verbose_map = verbose_map or {} - return verbose_map.get(column, column) + + if isinstance(column, str): + verbose_map = verbose_map or {} + return verbose_map.get(column, column) + + raise ValueError("Missing label") def get_metric_name( @@ -1845,33 +1849,64 @@ def remove_duplicates( return result +@dataclass +class DateColumn: + col_label: str + timestamp_format: Optional[str] = None + offset: Optional[int] = None + time_shift: Optional[timedelta] = None + + def __hash__(self) -> int: + return hash(self.col_label) + + def __eq__(self, other: object) -> bool: + return isinstance(other, DateColumn) and hash(self) == hash(other) + + @classmethod + def get_legacy_time_column( + cls, + timestamp_format: Optional[str], + offset: Optional[int], + time_shift: Optional[timedelta], + ) -> DateColumn: + return cls( + timestamp_format=timestamp_format, + offset=offset, + time_shift=time_shift, + col_label=DTTM_ALIAS, + ) + + def normalize_dttm_col( df: pd.DataFrame, - timestamp_format: Optional[str], - offset: int, - time_shift: Optional[timedelta], + dttm_cols: Tuple[DateColumn, ...] = tuple(), ) -> None: - if DTTM_ALIAS not in df.columns: - return - if timestamp_format in ("epoch_s", "epoch_ms"): - dttm_col = df[DTTM_ALIAS] - if is_numeric_dtype(dttm_col): - # Column is formatted as a numeric value - unit = timestamp_format.replace("epoch_", "") - df[DTTM_ALIAS] = pd.to_datetime( - dttm_col, utc=False, unit=unit, origin="unix", errors="coerce" - ) + for _col in dttm_cols: + if _col.col_label not in df.columns: + continue + + if _col.timestamp_format in ("epoch_s", "epoch_ms"): + dttm_series = df[_col.col_label] + if is_numeric_dtype(dttm_series): + # Column is formatted as a numeric value + unit = _col.timestamp_format.replace("epoch_", "") + df[_col.col_label] = pd.to_datetime( + dttm_series, utc=False, unit=unit, origin="unix", errors="coerce" + ) + else: + # Column has already been formatted as a timestamp. + df[_col.col_label] = dttm_series.apply(pd.Timestamp) else: - # Column has already been formatted as a timestamp. - df[DTTM_ALIAS] = dttm_col.apply(pd.Timestamp) - else: - df[DTTM_ALIAS] = pd.to_datetime( - df[DTTM_ALIAS], utc=False, format=timestamp_format, errors="coerce" - ) - if offset: - df[DTTM_ALIAS] += timedelta(hours=offset) - if time_shift is not None: - df[DTTM_ALIAS] += time_shift + df[_col.col_label] = pd.to_datetime( + df[_col.col_label], + utc=False, + format=_col.timestamp_format, + errors="coerce", + ) + if _col.offset: + df[_col.col_label] += timedelta(hours=_col.offset) + if _col.time_shift is not None: + df[_col.col_label] += _col.time_shift def parse_boolean_string(bool_str: Optional[str]) -> bool: diff --git a/superset/viz.py b/superset/viz.py index 42de86497..43e71b533 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -80,6 +80,7 @@ from superset.utils import core as utils, csv from superset.utils.cache import set_and_log_cache from superset.utils.core import ( apply_max_row_limit, + DateColumn, DTTM_ALIAS, ExtraFiltersReasonType, get_column_name, @@ -301,9 +302,15 @@ class BaseViz: # pylint: disable=too-many-public-methods if not df.empty: utils.normalize_dttm_col( df=df, - timestamp_format=timestamp_format, - offset=self.datasource.offset, - time_shift=self.time_shift, + dttm_cols=tuple( + [ + DateColumn.get_legacy_time_column( + timestamp_format=timestamp_format, + offset=self.datasource.offset, + time_shift=self.time_shift, + ) + ] + ), ) if self.enforce_numerical_metrics: diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 606bfe437..c117422ad 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -402,3 +402,13 @@ only_postgresql = pytest.mark.skipif( "postgresql" not in os.environ.get("SUPERSET__SQLALCHEMY_DATABASE_URI", ""), reason="Only run test case in Postgresql", ) + +only_sqlite = pytest.mark.skipif( + "sqlite" not in os.environ.get("SUPERSET__SQLALCHEMY_DATABASE_URI", ""), + reason="Only run test case in SQLite", +) + +only_mysql = pytest.mark.skipif( + "mysql" not in os.environ.get("SUPERSET__SQLALCHEMY_DATABASE_URI", ""), + reason="Only run test case in MySQL", +) diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index cb2b24e9c..2306b0a1e 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -18,6 +18,8 @@ import re import time from typing import Any, Dict +import numpy as np +import pandas as pd import pytest from pandas import DateOffset @@ -39,7 +41,7 @@ from superset.utils.core import ( ) from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR from tests.integration_tests.base_tests import SupersetTestCase -from tests.integration_tests.conftest import only_postgresql +from tests.integration_tests.conftest import only_postgresql, only_sqlite from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, load_birth_names_data, @@ -910,3 +912,109 @@ def test_non_date_adhoc_column(app_context, physical_dataset): df = qc.get_df_payload(query_object)["df"] assert df["ADHOC COLUMN"][0] == 0 assert df["ADHOC COLUMN"][1] == 10 + + +@only_sqlite +def test_time_grain_and_time_offset_with_base_axis(app_context, physical_dataset): + column_on_axis: AdhocColumn = { + "label": "col6", + "sqlExpression": "col6", + "columnType": "BASE_AXIS", + "timeGrain": "P3M", + } + qc = QueryContextFactory().create( + datasource={ + "type": physical_dataset.type, + "id": physical_dataset.id, + }, + queries=[ + { + "columns": [column_on_axis], + "metrics": [ + { + "label": "SUM(col1)", + "expressionType": "SQL", + "sqlExpression": "SUM(col1)", + } + ], + "time_offsets": ["3 month ago"], + "granularity": "col6", + "time_range": "2002-01 : 2003-01", + } + ], + result_type=ChartDataResultType.FULL, + force=True, + ) + query_object = qc.queries[0] + df = qc.get_df_payload(query_object)["df"] + # todo: MySQL returns integer and float column as object type + """ + col6 SUM(col1) SUM(col1)__3 month ago +0 2002-01-01 3 NaN +1 2002-04-01 12 3.0 +2 2002-07-01 21 12.0 +3 2002-10-01 9 21.0 + """ + assert df.equals( + pd.DataFrame( + data={ + "col6": pd.to_datetime( + ["2002-01-01", "2002-04-01", "2002-07-01", "2002-10-01"] + ), + "SUM(col1)": [3, 12, 21, 9], + "SUM(col1)__3 month ago": [np.nan, 3, 12, 21], + } + ) + ) + + +@only_sqlite +def test_time_grain_and_time_offset_on_legacy_query(app_context, physical_dataset): + qc = QueryContextFactory().create( + datasource={ + "type": physical_dataset.type, + "id": physical_dataset.id, + }, + queries=[ + { + "columns": [], + "extras": { + "time_grain_sqla": "P3M", + }, + "metrics": [ + { + "label": "SUM(col1)", + "expressionType": "SQL", + "sqlExpression": "SUM(col1)", + } + ], + "time_offsets": ["3 month ago"], + "granularity": "col6", + "time_range": "2002-01 : 2003-01", + "is_timeseries": True, + } + ], + result_type=ChartDataResultType.FULL, + force=True, + ) + query_object = qc.queries[0] + df = qc.get_df_payload(query_object)["df"] + # todo: MySQL returns integer and float column as object type + """ + __timestamp SUM(col1) SUM(col1)__3 month ago +0 2002-01-01 3 NaN +1 2002-04-01 12 3.0 +2 2002-07-01 21 12.0 +3 2002-10-01 9 21.0 + """ + assert df.equals( + pd.DataFrame( + data={ + "__timestamp": pd.to_datetime( + ["2002-01-01", "2002-04-01", "2002-07-01", "2002-10-01"] + ), + "SUM(col1)": [3, 12, 21, 9], + "SUM(col1)__3 month ago": [np.nan, 3, 12, 21], + } + ) + ) diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index da1567cab..7df9cd82f 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -70,6 +70,7 @@ from superset.utils.core import ( validate_json, zlib_compress, zlib_decompress, + DateColumn, ) from superset.utils.database import get_or_create_db from superset.utils import schema @@ -1062,7 +1063,18 @@ class TestUtils(SupersetTestCase): time_shift: Optional[timedelta], ) -> pd.DataFrame: df = df.copy() - normalize_dttm_col(df, timestamp_format, offset, time_shift) + normalize_dttm_col( + df, + tuple( + [ + DateColumn.get_legacy_time_column( + timestamp_format=timestamp_format, + offset=offset, + time_shift=time_shift, + ) + ] + ), + ) return df ts = pd.Timestamp(2021, 2, 15, 19, 0, 0, 0)