fix: Time Offset in SQLite and refine logic in Date Type conversion (#21378)
commit 2dfcba04b0
parent 1c0bff3dfb
@@ -49,8 +49,10 @@ from superset.utils import csv
 from superset.utils.cache import generate_cache_key, set_and_log_cache
 from superset.utils.core import (
     DatasourceType,
+    DateColumn,
     DTTM_ALIAS,
     error_msg_from_exception,
+    get_base_axis_labels,
     get_column_names_from_columns,
     get_column_names_from_metrics,
     get_metric_names,
@@ -238,18 +240,57 @@ class QueryContextProcessor:
         return result
 
     def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
-        datasource = self._qc_datasource
-        timestamp_format = None
-        if datasource.type == "table":
-            dttm_col = datasource.get_column(query_object.granularity)
-            if dttm_col:
-                timestamp_format = dttm_col.python_date_format
-
+        # todo: should support "python_date_format" and "get_column" in each datasource
+        def _get_timestamp_format(
+            source: BaseDatasource, column: Optional[str]
+        ) -> Optional[str]:
+            column_obj = source.get_column(column)
+            if (
+                column_obj
+                # only sqla columns are supported
+                and hasattr(column_obj, "python_date_format")
+                and (formatter := column_obj.python_date_format)
+            ):
+                return str(formatter)
+
+            return None
+
+        datasource = self._qc_datasource
+        labels = tuple(
+            label
+            for label in [
+                *get_base_axis_labels(query_object.columns),
+                query_object.granularity,
+            ]
+            if datasource
+            # query datasources don't support `get_column`
+            and hasattr(datasource, "get_column")
+            and (col := datasource.get_column(label))
+            and col.is_dttm
+        )
+        dttm_cols = [
+            DateColumn(
+                timestamp_format=_get_timestamp_format(datasource, label),
+                offset=datasource.offset,
+                time_shift=query_object.time_shift,
+                col_label=label,
+            )
+            for label in labels
+            if label
+        ]
+        if DTTM_ALIAS in df:
+            dttm_cols.append(
+                DateColumn.get_legacy_time_column(
+                    timestamp_format=_get_timestamp_format(
+                        datasource, query_object.granularity
+                    ),
+                    offset=datasource.offset,
+                    time_shift=query_object.time_shift,
+                )
+            )
         normalize_dttm_col(
             df=df,
-            timestamp_format=timestamp_format,
-            offset=datasource.offset,
-            time_shift=query_object.time_shift,
+            dttm_cols=tuple(dttm_cols),
         )
 
         if self.enforce_numerical_metrics:
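For illustration, here is a self-contained sketch of the duck-typing guard that `_get_timestamp_format` relies on. The stub classes are hypothetical stand-ins, not Superset's real column models: only objects that actually expose a non-empty `python_date_format` yield a format string.

```python
from typing import Optional


class SqlaColumnStub:
    """Stand-in for a SQLA table column that knows its date format."""

    python_date_format: Optional[str] = "%Y-%m-%d"


class QueryColumnStub:
    """Stand-in for a query-backed column without that attribute."""


def get_timestamp_format(column_obj: object) -> Optional[str]:
    # Same guard as above: the attribute must exist and be non-empty,
    # otherwise fall through to None.
    if (
        column_obj
        and hasattr(column_obj, "python_date_format")
        and (formatter := column_obj.python_date_format)
    ):
        return str(formatter)
    return None


assert get_timestamp_format(SqlaColumnStub()) == "%Y-%m-%d"
assert get_timestamp_format(QueryColumnStub()) is None
```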
@@ -344,10 +385,7 @@ class QueryContextProcessor:
         offset_metrics_df = offset_metrics_df.rename(columns=metrics_mapping)
 
         # 3. set time offset for index
-        # TODO: add x-axis to QueryObject, potentially as an array for
-        # multi-dimensional charts
-        granularity = query_object.granularity
-        index = granularity if granularity in df.columns else DTTM_ALIAS
+        index = (get_base_axis_labels(query_object.columns) or [DTTM_ALIAS])[0]
         if not dataframe_utils.is_datetime_series(offset_metrics_df.get(index)):
             raise QueryObjectValidationError(
                 _(
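The rewritten index line leans on tuple falsiness: `get_base_axis_labels` returns an empty tuple when no BASE_AXIS column exists, so the expression falls back to the legacy alias. A minimal illustration (`DTTM_ALIAS` is `"__timestamp"` in superset.utils.core):

```python
DTTM_ALIAS = "__timestamp"

# No base-axis labels: fall back to the legacy timestamp alias.
assert (() or [DTTM_ALIAS])[0] == "__timestamp"
# A base-axis label wins when present.
assert (("col6",) or [DTTM_ALIAS])[0] == "col6"
```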
@@ -16,6 +16,8 @@
 # under the License.
 """Utility functions used across Superset"""
+# pylint: disable=too-many-lines
+from __future__ import annotations
 
 import _thread
 import collections
 import decimal
@@ -34,6 +36,7 @@ import traceback
 import uuid
 import zlib
 from contextlib import contextmanager
+from dataclasses import dataclass
 from datetime import date, datetime, time, timedelta
 from distutils.util import strtobool
 from email.mime.application import MIMEApplication
@@ -1271,15 +1274,13 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]:
     return isinstance(column, dict)
 
 
-def get_base_axis_column(columns: Optional[List[Column]]) -> Optional[AdhocColumn]:
-    if columns is None:
-        return None
+def get_base_axis_labels(columns: Optional[List[Column]]) -> Tuple[str, ...]:
     axis_cols = [
         col
-        for col in columns
+        for col in columns or []
         if is_adhoc_column(col) and col.get("columnType") == "BASE_AXIS"
     ]
-    return axis_cols[0] if axis_cols else None
+    return tuple(get_column_name(col) for col in axis_cols)
 
 
 def get_column_name(
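A usage sketch for the renamed helper, assuming the post-commit import path and Superset's AdhocColumn dict shape: BASE_AXIS adhoc columns are mapped to their labels, anything else is ignored, and None yields an empty tuple.

```python
from superset.utils.core import get_base_axis_labels

axis_col = {
    "label": "col6",
    "sqlExpression": "col6",
    "columnType": "BASE_AXIS",
}
# Plain string columns are filtered out; only BASE_AXIS dicts survive.
assert get_base_axis_labels([axis_col, "gender"]) == ("col6",)
assert get_base_axis_labels(None) == ()
```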
@@ -1301,9 +1302,12 @@ def get_column_name(
         expr = column.get("sqlExpression")
         if expr:
             return expr
-        raise ValueError("Missing label")
-    verbose_map = verbose_map or {}
-    return verbose_map.get(column, column)
+
+    if isinstance(column, str):
+        verbose_map = verbose_map or {}
+        return verbose_map.get(column, column)
+
+    raise ValueError("Missing label")
 
 
 def get_metric_name(
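A hedged sketch of the refined lookup order (the label-first branch for adhoc columns sits above this hunk; the inputs below are illustrative): strings resolve through the verbose map, adhoc dicts prefer their label and then their SQL expression, and anything unlabeled now raises.

```python
from superset.utils.core import get_column_name

# Plain string columns resolve through the verbose map, falling back
# to themselves.
assert get_column_name("num_girls", verbose_map={"num_girls": "Girls"}) == "Girls"
assert get_column_name("num_boys") == "num_boys"

# Adhoc columns prefer their label, then their SQL expression.
assert get_column_name({"label": "My Col", "sqlExpression": "col1"}) == "My Col"
assert get_column_name({"sqlExpression": "col1 + 1"}) == "col1 + 1"
```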
@@ -1845,33 +1849,64 @@ def remove_duplicates(
     return result
 
 
+@dataclass
+class DateColumn:
+    col_label: str
+    timestamp_format: Optional[str] = None
+    offset: Optional[int] = None
+    time_shift: Optional[timedelta] = None
+
+    def __hash__(self) -> int:
+        return hash(self.col_label)
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, DateColumn) and hash(self) == hash(other)
+
+    @classmethod
+    def get_legacy_time_column(
+        cls,
+        timestamp_format: Optional[str],
+        offset: Optional[int],
+        time_shift: Optional[timedelta],
+    ) -> DateColumn:
+        return cls(
+            timestamp_format=timestamp_format,
+            offset=offset,
+            time_shift=time_shift,
+            col_label=DTTM_ALIAS,
+        )
+
+
 def normalize_dttm_col(
     df: pd.DataFrame,
-    timestamp_format: Optional[str],
-    offset: int,
-    time_shift: Optional[timedelta],
+    dttm_cols: Tuple[DateColumn, ...] = tuple(),
 ) -> None:
-    if DTTM_ALIAS not in df.columns:
-        return
-    if timestamp_format in ("epoch_s", "epoch_ms"):
-        dttm_col = df[DTTM_ALIAS]
-        if is_numeric_dtype(dttm_col):
-            # Column is formatted as a numeric value
-            unit = timestamp_format.replace("epoch_", "")
-            df[DTTM_ALIAS] = pd.to_datetime(
-                dttm_col, utc=False, unit=unit, origin="unix", errors="coerce"
-            )
-        else:
-            # Column has already been formatted as a timestamp.
-            df[DTTM_ALIAS] = dttm_col.apply(pd.Timestamp)
-    else:
-        df[DTTM_ALIAS] = pd.to_datetime(
-            df[DTTM_ALIAS], utc=False, format=timestamp_format, errors="coerce"
-        )
-    if offset:
-        df[DTTM_ALIAS] += timedelta(hours=offset)
-    if time_shift is not None:
-        df[DTTM_ALIAS] += time_shift
+    for _col in dttm_cols:
+        if _col.col_label not in df.columns:
+            continue
+
+        if _col.timestamp_format in ("epoch_s", "epoch_ms"):
+            dttm_series = df[_col.col_label]
+            if is_numeric_dtype(dttm_series):
+                # Column is formatted as a numeric value
+                unit = _col.timestamp_format.replace("epoch_", "")
+                df[_col.col_label] = pd.to_datetime(
+                    dttm_series, utc=False, unit=unit, origin="unix", errors="coerce"
+                )
+            else:
+                # Column has already been formatted as a timestamp.
+                df[_col.col_label] = dttm_series.apply(pd.Timestamp)
+        else:
+            df[_col.col_label] = pd.to_datetime(
+                df[_col.col_label],
+                utc=False,
+                format=_col.timestamp_format,
+                errors="coerce",
+            )
+        if _col.offset:
+            df[_col.col_label] += timedelta(hours=_col.offset)
+        if _col.time_shift is not None:
+            df[_col.col_label] += _col.time_shift
 
 
 def parse_boolean_string(bool_str: Optional[str]) -> bool:
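A runnable sketch of the new per-column API with made-up data: each DateColumn carries its own format, hour offset, and time shift, so heterogeneous datetime columns in one frame are normalized independently.

```python
from datetime import timedelta

import pandas as pd

from superset.utils.core import DateColumn, normalize_dttm_col

df = pd.DataFrame(
    {
        "ds": ["2002-01-01", "2002-04-01"],  # formatted strings
        "created": [1009843200, 1017619200],  # epoch seconds
    }
)
normalize_dttm_col(
    df=df,
    dttm_cols=(
        # Parsed with the given format, then shifted +1 hour.
        DateColumn(col_label="ds", timestamp_format="%Y-%m-%d", offset=1),
        # Decoded from epoch seconds, then shifted back 90 days.
        DateColumn(
            col_label="created",
            timestamp_format="epoch_s",
            time_shift=timedelta(days=-90),
        ),
    ),
)
assert str(df["ds"][0]) == "2002-01-01 01:00:00"
assert str(df["created"][1]) == "2002-01-01 00:00:00"
```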
@@ -80,6 +80,7 @@ from superset.utils import core as utils, csv
 from superset.utils.cache import set_and_log_cache
 from superset.utils.core import (
     apply_max_row_limit,
+    DateColumn,
     DTTM_ALIAS,
     ExtraFiltersReasonType,
     get_column_name,
@@ -301,9 +302,15 @@ class BaseViz:  # pylint: disable=too-many-public-methods
         if not df.empty:
             utils.normalize_dttm_col(
                 df=df,
-                timestamp_format=timestamp_format,
-                offset=self.datasource.offset,
-                time_shift=self.time_shift,
+                dttm_cols=tuple(
+                    [
+                        DateColumn.get_legacy_time_column(
+                            timestamp_format=timestamp_format,
+                            offset=self.datasource.offset,
+                            time_shift=self.time_shift,
+                        )
+                    ]
+                ),
             )
 
             if self.enforce_numerical_metrics:
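`get_legacy_time_column` is just a convenience constructor that pins the label to the legacy `__timestamp` alias, which keeps old visualizations on the single-timestamp path. A sketch of the equivalence; note that equality is keyed on `col_label` alone:

```python
from superset.utils.core import DTTM_ALIAS, DateColumn

legacy = DateColumn.get_legacy_time_column(
    timestamp_format="%Y-%m-%d", offset=0, time_shift=None
)
explicit = DateColumn(
    col_label=DTTM_ALIAS, timestamp_format="%Y-%m-%d", offset=0, time_shift=None
)
# Hash and equality only consider the column label.
assert legacy == explicit
assert hash(legacy) == hash(DTTM_ALIAS)
```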
@@ -402,3 +402,13 @@ only_postgresql = pytest.mark.skipif(
     "postgresql" not in os.environ.get("SUPERSET__SQLALCHEMY_DATABASE_URI", ""),
     reason="Only run test case in Postgresql",
 )
+
+only_sqlite = pytest.mark.skipif(
+    "sqlite" not in os.environ.get("SUPERSET__SQLALCHEMY_DATABASE_URI", ""),
+    reason="Only run test case in SQLite",
+)
+
+only_mysql = pytest.mark.skipif(
+    "mysql" not in os.environ.get("SUPERSET__SQLALCHEMY_DATABASE_URI", ""),
+    reason="Only run test case in MySQL",
+)
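A usage sketch for the new markers: like the existing only_postgresql, they are plain pytest.mark.skipif aliases, so a tagged test is skipped unless the metadata database URI contains the matching backend.

```python
from tests.integration_tests.conftest import only_sqlite


@only_sqlite
def test_sqlite_only_behavior():
    # Runs only when SUPERSET__SQLALCHEMY_DATABASE_URI points at SQLite.
    assert True
```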
@@ -18,6 +18,8 @@ import re
 import time
 from typing import Any, Dict
 
+import numpy as np
+import pandas as pd
 import pytest
 from pandas import DateOffset
 
@@ -39,7 +41,7 @@ from superset.utils.core import (
 )
 from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
 from tests.integration_tests.base_tests import SupersetTestCase
-from tests.integration_tests.conftest import only_postgresql
+from tests.integration_tests.conftest import only_postgresql, only_sqlite
 from tests.integration_tests.fixtures.birth_names_dashboard import (
     load_birth_names_dashboard_with_slices,
     load_birth_names_data,
@@ -910,3 +912,109 @@ def test_non_date_adhoc_column(app_context, physical_dataset):
     df = qc.get_df_payload(query_object)["df"]
     assert df["ADHOC COLUMN"][0] == 0
     assert df["ADHOC COLUMN"][1] == 10
+
+
+@only_sqlite
+def test_time_grain_and_time_offset_with_base_axis(app_context, physical_dataset):
+    column_on_axis: AdhocColumn = {
+        "label": "col6",
+        "sqlExpression": "col6",
+        "columnType": "BASE_AXIS",
+        "timeGrain": "P3M",
+    }
+    qc = QueryContextFactory().create(
+        datasource={
+            "type": physical_dataset.type,
+            "id": physical_dataset.id,
+        },
+        queries=[
+            {
+                "columns": [column_on_axis],
+                "metrics": [
+                    {
+                        "label": "SUM(col1)",
+                        "expressionType": "SQL",
+                        "sqlExpression": "SUM(col1)",
+                    }
+                ],
+                "time_offsets": ["3 month ago"],
+                "granularity": "col6",
+                "time_range": "2002-01 : 2003-01",
+            }
+        ],
+        result_type=ChartDataResultType.FULL,
+        force=True,
+    )
+    query_object = qc.queries[0]
+    df = qc.get_df_payload(query_object)["df"]
+    # todo: MySQL returns integer and float columns as object type
+    """
+            col6  SUM(col1)  SUM(col1)__3 month ago
+    0 2002-01-01          3                     NaN
+    1 2002-04-01         12                     3.0
+    2 2002-07-01         21                    12.0
+    3 2002-10-01          9                    21.0
+    """
+    assert df.equals(
+        pd.DataFrame(
+            data={
+                "col6": pd.to_datetime(
+                    ["2002-01-01", "2002-04-01", "2002-07-01", "2002-10-01"]
+                ),
+                "SUM(col1)": [3, 12, 21, 9],
+                "SUM(col1)__3 month ago": [np.nan, 3, 12, 21],
+            }
+        )
+    )
+
+
+@only_sqlite
+def test_time_grain_and_time_offset_on_legacy_query(app_context, physical_dataset):
+    qc = QueryContextFactory().create(
+        datasource={
+            "type": physical_dataset.type,
+            "id": physical_dataset.id,
+        },
+        queries=[
+            {
+                "columns": [],
+                "extras": {
+                    "time_grain_sqla": "P3M",
+                },
+                "metrics": [
+                    {
+                        "label": "SUM(col1)",
+                        "expressionType": "SQL",
+                        "sqlExpression": "SUM(col1)",
+                    }
+                ],
+                "time_offsets": ["3 month ago"],
+                "granularity": "col6",
+                "time_range": "2002-01 : 2003-01",
+                "is_timeseries": True,
+            }
+        ],
+        result_type=ChartDataResultType.FULL,
+        force=True,
+    )
+    query_object = qc.queries[0]
+    df = qc.get_df_payload(query_object)["df"]
+    # todo: MySQL returns integer and float columns as object type
+    """
+      __timestamp  SUM(col1)  SUM(col1)__3 month ago
+    0  2002-01-01          3                     NaN
+    1  2002-04-01         12                     3.0
+    2  2002-07-01         21                    12.0
+    3  2002-10-01          9                    21.0
+    """
+    assert df.equals(
+        pd.DataFrame(
+            data={
+                "__timestamp": pd.to_datetime(
+                    ["2002-01-01", "2002-04-01", "2002-07-01", "2002-10-01"]
+                ),
+                "SUM(col1)": [3, 12, 21, 9],
+                "SUM(col1)__3 month ago": [np.nan, 3, 12, 21],
+            }
+        )
+    )
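What both expected frames encode, reduced to a hypothetical pandas one-liner: the "__3 month ago" series is the metric re-joined one P3M bucket earlier, so the first bucket has no prior value.

```python
import numpy as np
import pandas as pd

base = pd.Series([3, 12, 21, 9])
offset = base.shift(1)  # one P3M bucket earlier
assert offset.equals(pd.Series([np.nan, 3.0, 12.0, 21.0]))
```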
@@ -70,6 +70,7 @@ from superset.utils.core import (
     validate_json,
     zlib_compress,
     zlib_decompress,
+    DateColumn,
 )
 from superset.utils.database import get_or_create_db
 from superset.utils import schema
@@ -1062,7 +1063,18 @@ class TestUtils(SupersetTestCase):
             time_shift: Optional[timedelta],
         ) -> pd.DataFrame:
             df = df.copy()
-            normalize_dttm_col(df, timestamp_format, offset, time_shift)
+            normalize_dttm_col(
+                df,
+                tuple(
+                    [
+                        DateColumn.get_legacy_time_column(
+                            timestamp_format=timestamp_format,
+                            offset=offset,
+                            time_shift=time_shift,
+                        )
+                    ]
+                ),
+            )
             return df
 
         ts = pd.Timestamp(2021, 2, 15, 19, 0, 0, 0)