fix: Time Offset in SQLite and refine logic in Date Type conversion (#21378)

This commit is contained in:
Yongjie Zhao 2022-09-16 12:02:22 +08:00 committed by GitHub
parent 1c0bff3dfb
commit 2dfcba04b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 259 additions and 49 deletions

View File

@ -49,8 +49,10 @@ from superset.utils import csv
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.utils.core import (
DatasourceType,
DateColumn,
DTTM_ALIAS,
error_msg_from_exception,
get_base_axis_labels,
get_column_names_from_columns,
get_column_names_from_metrics,
get_metric_names,
@ -238,18 +240,57 @@ class QueryContextProcessor:
return result
def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
datasource = self._qc_datasource
timestamp_format = None
if datasource.type == "table":
dttm_col = datasource.get_column(query_object.granularity)
if dttm_col:
timestamp_format = dttm_col.python_date_format
# todo: should support "python_date_format" and "get_column" in each datasource
def _get_timestamp_format(
source: BaseDatasource, column: Optional[str]
) -> Optional[str]:
column_obj = source.get_column(column)
if (
column_obj
# only sqla column was supported
and hasattr(column_obj, "python_date_format")
and (formatter := column_obj.python_date_format)
):
return str(formatter)
return None
datasource = self._qc_datasource
labels = tuple(
label
for label in [
*get_base_axis_labels(query_object.columns),
query_object.granularity,
]
if datasource
# Query datasource didn't support `get_column`
and hasattr(datasource, "get_column")
and (col := datasource.get_column(label))
and col.is_dttm
)
dttm_cols = [
DateColumn(
timestamp_format=_get_timestamp_format(datasource, label),
offset=datasource.offset,
time_shift=query_object.time_shift,
col_label=label,
)
for label in labels
if label
]
if DTTM_ALIAS in df:
dttm_cols.append(
DateColumn.get_legacy_time_column(
timestamp_format=_get_timestamp_format(
datasource, query_object.granularity
),
offset=datasource.offset,
time_shift=query_object.time_shift,
)
)
normalize_dttm_col(
df=df,
timestamp_format=timestamp_format,
offset=datasource.offset,
time_shift=query_object.time_shift,
dttm_cols=tuple(dttm_cols),
)
if self.enforce_numerical_metrics:
@ -344,10 +385,7 @@ class QueryContextProcessor:
offset_metrics_df = offset_metrics_df.rename(columns=metrics_mapping)
# 3. set time offset for index
# TODO: add x-axis to QueryObject, potentially as an array for
# multi-dimensional charts
granularity = query_object.granularity
index = granularity if granularity in df.columns else DTTM_ALIAS
index = (get_base_axis_labels(query_object.columns) or [DTTM_ALIAS])[0]
if not dataframe_utils.is_datetime_series(offset_metrics_df.get(index)):
raise QueryObjectValidationError(
_(

View File

@ -16,6 +16,8 @@
# under the License.
"""Utility functions used across Superset"""
# pylint: disable=too-many-lines
from __future__ import annotations
import _thread
import collections
import decimal
@ -34,6 +36,7 @@ import traceback
import uuid
import zlib
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from distutils.util import strtobool
from email.mime.application import MIMEApplication
@ -1271,15 +1274,13 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]:
return isinstance(column, dict)
def get_base_axis_labels(columns: Optional[List[Column]]) -> Tuple[str, ...]:
    """Return the labels of every BASE_AXIS column (the chart's x-axis).

    :param columns: query columns; ``None`` is treated as an empty list
    :return: tuple of column labels, empty when no BASE_AXIS column exists
    """
    # Only adhoc (dict-shaped) columns can carry a "columnType" marker.
    axis_cols = [
        col
        for col in columns or []
        if is_adhoc_column(col) and col.get("columnType") == "BASE_AXIS"
    ]
    return tuple(get_column_name(col) for col in axis_cols)
def get_column_name(
@ -1301,9 +1302,12 @@ def get_column_name(
expr = column.get("sqlExpression")
if expr:
return expr
raise ValueError("Missing label")
verbose_map = verbose_map or {}
return verbose_map.get(column, column)
if isinstance(column, str):
verbose_map = verbose_map or {}
return verbose_map.get(column, column)
raise ValueError("Missing label")
def get_metric_name(
@ -1845,33 +1849,64 @@ def remove_duplicates(
return result
@dataclass
class DateColumn:
    """Describes how one temporal column of a result set is normalized:
    the source format plus the timezone offset / time shift to apply.

    Identity is intentionally keyed on ``col_label`` only, so a set of
    DateColumns is de-duplicated per column label.
    """

    # Label of the column in the result DataFrame.
    col_label: str
    # Source format: "epoch_s", "epoch_ms", a strftime pattern, or None.
    timestamp_format: Optional[str] = None
    # Database timezone offset in hours, added after parsing.
    offset: Optional[int] = None
    # Extra shift applied on top of the offset.
    time_shift: Optional[timedelta] = None

    def __hash__(self) -> int:
        return hash(self.col_label)

    def __eq__(self, other: object) -> bool:
        # Compare the labels directly instead of comparing hashes: two
        # distinct labels can collide under hash(), which would wrongly
        # report two different columns as equal.
        return isinstance(other, DateColumn) and self.col_label == other.col_label

    @classmethod
    def get_legacy_time_column(
        cls,
        timestamp_format: Optional[str],
        offset: Optional[int],
        time_shift: Optional[timedelta],
    ) -> DateColumn:
        """Build the DateColumn for the legacy ``DTTM_ALIAS`` time column."""
        return cls(
            timestamp_format=timestamp_format,
            offset=offset,
            time_shift=time_shift,
            col_label=DTTM_ALIAS,
        )
def normalize_dttm_col(
    df: pd.DataFrame,
    dttm_cols: Tuple[DateColumn, ...] = tuple(),
) -> None:
    """Convert the given temporal columns of ``df`` to datetimes, in place.

    For each DateColumn whose label is present in ``df``:

    - "epoch_s"/"epoch_ms" numeric columns are parsed as unix timestamps;
      non-numeric columns with those formats are assumed to already hold
      timestamp-like values and are coerced via ``pd.Timestamp``;
    - any other format is handed to ``pd.to_datetime`` (unparseable values
      become NaT because of ``errors="coerce"``);
    - the datasource's hour ``offset`` and optional ``time_shift`` are then
      added to the parsed column.

    Columns absent from ``df`` are silently skipped.
    """
    for _col in dttm_cols:
        if _col.col_label not in df.columns:
            continue
        if _col.timestamp_format in ("epoch_s", "epoch_ms"):
            dttm_series = df[_col.col_label]
            if is_numeric_dtype(dttm_series):
                # Column is formatted as a numeric epoch value.
                unit = _col.timestamp_format.replace("epoch_", "")
                df[_col.col_label] = pd.to_datetime(
                    dttm_series, utc=False, unit=unit, origin="unix", errors="coerce"
                )
            else:
                # Column has already been formatted as a timestamp.
                df[_col.col_label] = dttm_series.apply(pd.Timestamp)
        else:
            df[_col.col_label] = pd.to_datetime(
                df[_col.col_label],
                utc=False,
                format=_col.timestamp_format,
                errors="coerce",
            )
        if _col.offset:
            df[_col.col_label] += timedelta(hours=_col.offset)
        if _col.time_shift is not None:
            df[_col.col_label] += _col.time_shift
def parse_boolean_string(bool_str: Optional[str]) -> bool:

View File

@ -80,6 +80,7 @@ from superset.utils import core as utils, csv
from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
apply_max_row_limit,
DateColumn,
DTTM_ALIAS,
ExtraFiltersReasonType,
get_column_name,
@ -301,9 +302,15 @@ class BaseViz: # pylint: disable=too-many-public-methods
if not df.empty:
utils.normalize_dttm_col(
df=df,
timestamp_format=timestamp_format,
offset=self.datasource.offset,
time_shift=self.time_shift,
dttm_cols=tuple(
[
DateColumn.get_legacy_time_column(
timestamp_format=timestamp_format,
offset=self.datasource.offset,
time_shift=self.time_shift,
)
]
),
)
if self.enforce_numerical_metrics:

View File

@ -402,3 +402,13 @@ only_postgresql = pytest.mark.skipif(
"postgresql" not in os.environ.get("SUPERSET__SQLALCHEMY_DATABASE_URI", ""),
reason="Only run test case in Postgresql",
)
# Skip markers gated on the backend of the configured metadata database URI.
only_mysql = pytest.mark.skipif(
    "mysql" not in os.getenv("SUPERSET__SQLALCHEMY_DATABASE_URI", ""),
    reason="Only run test case in MySQL",
)
only_sqlite = pytest.mark.skipif(
    "sqlite" not in os.getenv("SUPERSET__SQLALCHEMY_DATABASE_URI", ""),
    reason="Only run test case in SQLite",
)

View File

@ -18,6 +18,8 @@ import re
import time
from typing import Any, Dict
import numpy as np
import pandas as pd
import pytest
from pandas import DateOffset
@ -39,7 +41,7 @@ from superset.utils.core import (
)
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.conftest import only_postgresql
from tests.integration_tests.conftest import only_postgresql, only_sqlite
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
@ -910,3 +912,109 @@ def test_non_date_adhoc_column(app_context, physical_dataset):
df = qc.get_df_payload(query_object)["df"]
assert df["ADHOC COLUMN"][0] == 0
assert df["ADHOC COLUMN"][1] == 10
@only_sqlite
def test_time_grain_and_time_offset_with_base_axis(app_context, physical_dataset):
    """A P3M grain on a BASE_AXIS column combined with a 3-month offset."""
    x_axis_column: AdhocColumn = {
        "label": "col6",
        "sqlExpression": "col6",
        "columnType": "BASE_AXIS",
        "timeGrain": "P3M",
    }
    query_context = QueryContextFactory().create(
        datasource={
            "type": physical_dataset.type,
            "id": physical_dataset.id,
        },
        queries=[
            {
                "columns": [x_axis_column],
                "metrics": [
                    {
                        "label": "SUM(col1)",
                        "expressionType": "SQL",
                        "sqlExpression": "SUM(col1)",
                    }
                ],
                "time_offsets": ["3 month ago"],
                "granularity": "col6",
                "time_range": "2002-01 : 2003-01",
            }
        ],
        result_type=ChartDataResultType.FULL,
        force=True,
    )
    df = query_context.get_df_payload(query_context.queries[0])["df"]
    # todo: MySQL returns integer and float column as object type
    # Expected payload:
    #        col6  SUM(col1)  SUM(col1)__3 month ago
    # 0 2002-01-01          3                     NaN
    # 1 2002-04-01         12                     3.0
    # 2 2002-07-01         21                    12.0
    # 3 2002-10-01          9                    21.0
    expected = pd.DataFrame(
        data={
            "col6": pd.to_datetime(
                ["2002-01-01", "2002-04-01", "2002-07-01", "2002-10-01"]
            ),
            "SUM(col1)": [3, 12, 21, 9],
            "SUM(col1)__3 month ago": [np.nan, 3, 12, 21],
        }
    )
    assert df.equals(expected)
@only_sqlite
def test_time_grain_and_time_offset_on_legacy_query(app_context, physical_dataset):
    """A P3M grain with a 3-month offset on a legacy (no BASE_AXIS) query."""
    query_context = QueryContextFactory().create(
        datasource={
            "type": physical_dataset.type,
            "id": physical_dataset.id,
        },
        queries=[
            {
                "columns": [],
                "extras": {
                    "time_grain_sqla": "P3M",
                },
                "metrics": [
                    {
                        "label": "SUM(col1)",
                        "expressionType": "SQL",
                        "sqlExpression": "SUM(col1)",
                    }
                ],
                "time_offsets": ["3 month ago"],
                "granularity": "col6",
                "time_range": "2002-01 : 2003-01",
                "is_timeseries": True,
            }
        ],
        result_type=ChartDataResultType.FULL,
        force=True,
    )
    df = query_context.get_df_payload(query_context.queries[0])["df"]
    # todo: MySQL returns integer and float column as object type
    # Expected payload:
    #   __timestamp  SUM(col1)  SUM(col1)__3 month ago
    # 0  2002-01-01          3                     NaN
    # 1  2002-04-01         12                     3.0
    # 2  2002-07-01         21                    12.0
    # 3  2002-10-01          9                    21.0
    expected = pd.DataFrame(
        data={
            "__timestamp": pd.to_datetime(
                ["2002-01-01", "2002-04-01", "2002-07-01", "2002-10-01"]
            ),
            "SUM(col1)": [3, 12, 21, 9],
            "SUM(col1)__3 month ago": [np.nan, 3, 12, 21],
        }
    )
    assert df.equals(expected)

View File

@ -70,6 +70,7 @@ from superset.utils.core import (
validate_json,
zlib_compress,
zlib_decompress,
DateColumn,
)
from superset.utils.database import get_or_create_db
from superset.utils import schema
@ -1062,7 +1063,18 @@ class TestUtils(SupersetTestCase):
time_shift: Optional[timedelta],
) -> pd.DataFrame:
df = df.copy()
normalize_dttm_col(df, timestamp_format, offset, time_shift)
normalize_dttm_col(
df,
tuple(
[
DateColumn.get_legacy_time_column(
timestamp_format=timestamp_format,
offset=offset,
time_shift=time_shift,
)
]
),
)
return df
ts = pd.Timestamp(2021, 2, 15, 19, 0, 0, 0)