fix: Apply normalization to all dttm columns (#25147)
This commit is contained in:
parent
17792a507c
commit
58fcd292a9
|
|
@ -185,6 +185,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
|
|||
filter
|
||||
for filter in query_object.filter
|
||||
if filter["col"] != filter_to_remove
|
||||
or filter["op"] != "TEMPORAL_RANGE"
|
||||
]
|
||||
|
||||
def _apply_filters(self, query_object: QueryObject) -> None:
|
||||
|
|
|
|||
|
|
@ -282,10 +282,11 @@ class QueryContextProcessor:
|
|||
datasource = self._qc_datasource
|
||||
labels = tuple(
|
||||
label
|
||||
for label in [
|
||||
for label in {
|
||||
*get_base_axis_labels(query_object.columns),
|
||||
*[col for col in query_object.columns or [] if isinstance(col, str)],
|
||||
query_object.granularity,
|
||||
]
|
||||
}
|
||||
if datasource
|
||||
# Query datasource didn't support `get_column`
|
||||
and hasattr(datasource, "get_column")
|
||||
|
|
|
|||
|
|
@ -16,17 +16,24 @@
|
|||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from superset.common.chart_data import ChartDataResultType
|
||||
from superset.common.query_object import QueryObject
|
||||
from superset.common.utils.time_range_utils import get_since_until_from_time_range
|
||||
from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType
|
||||
from superset.utils.core import (
|
||||
apply_max_row_limit,
|
||||
DatasourceDict,
|
||||
DatasourceType,
|
||||
FilterOperator,
|
||||
QueryObjectFilterClause,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
from superset.connectors.base.models import BaseColumn, BaseDatasource
|
||||
from superset.daos.datasource import DatasourceDAO
|
||||
|
||||
|
||||
|
|
@ -66,6 +73,10 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
|
|||
)
|
||||
kwargs["from_dttm"] = from_dttm
|
||||
kwargs["to_dttm"] = to_dttm
|
||||
if datasource_model_instance and kwargs.get("filters", []):
|
||||
kwargs["filters"] = self._process_filters(
|
||||
datasource_model_instance, kwargs["filters"]
|
||||
)
|
||||
return QueryObject(
|
||||
datasource=datasource_model_instance,
|
||||
extras=extras,
|
||||
|
|
@ -102,3 +113,54 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
|
|||
# light version of the view.utils.core
|
||||
# import view.utils require application context
|
||||
# Todo: move it and the view.utils.core to utils package
|
||||
|
||||
def _process_filters(
|
||||
self, datasource: BaseDatasource, query_filters: list[QueryObjectFilterClause]
|
||||
) -> list[QueryObjectFilterClause]:
|
||||
def get_dttm_filter_value(
|
||||
value: Any, col: BaseColumn, date_format: str
|
||||
) -> int | str:
|
||||
if not isinstance(value, int):
|
||||
return value
|
||||
if date_format in {"epoch_ms", "epoch_s"}:
|
||||
if date_format == "epoch_s":
|
||||
value = str(value)
|
||||
else:
|
||||
value = str(value * 1000)
|
||||
else:
|
||||
dttm = datetime.utcfromtimestamp(value / 1000)
|
||||
value = dttm.strftime(date_format)
|
||||
|
||||
if col.type in col.num_types:
|
||||
value = int(value)
|
||||
return value
|
||||
|
||||
for query_filter in query_filters:
|
||||
if query_filter.get("op") == FilterOperator.TEMPORAL_RANGE:
|
||||
continue
|
||||
filter_col = query_filter.get("col")
|
||||
if not isinstance(filter_col, str):
|
||||
continue
|
||||
column = datasource.get_column(filter_col)
|
||||
if not column:
|
||||
continue
|
||||
filter_value = query_filter.get("val")
|
||||
|
||||
date_format = column.python_date_format
|
||||
if not date_format and datasource.db_extra:
|
||||
date_format = datasource.db_extra.get(
|
||||
"python_date_format_by_column_name", {}
|
||||
).get(column.column_name)
|
||||
|
||||
if column.is_dttm and date_format:
|
||||
if isinstance(filter_value, list):
|
||||
query_filter["val"] = [
|
||||
get_dttm_filter_value(value, column, date_format)
|
||||
for value in filter_value
|
||||
]
|
||||
else:
|
||||
query_filter["val"] = get_dttm_filter_value(
|
||||
filter_value, column, date_format
|
||||
)
|
||||
|
||||
return query_filters
|
||||
|
|
|
|||
|
|
@ -836,11 +836,9 @@ def test_special_chars_in_column_name(app_context, physical_dataset):
|
|||
|
||||
query_object = qc.queries[0]
|
||||
df = qc.get_df_payload(query_object)["df"]
|
||||
if query_object.datasource.database.backend == "sqlite":
|
||||
# sqlite returns string as timestamp column
|
||||
assert df["time column with spaces"][0] == "2002-01-03 00:00:00"
|
||||
assert df["I_AM_A_TRUNC_COLUMN"][0] == "2002-01-01 00:00:00"
|
||||
else:
|
||||
|
||||
# sqlite doesn't have timestamp columns
|
||||
if query_object.datasource.database.backend != "sqlite":
|
||||
assert df["time column with spaces"][0].strftime("%Y-%m-%d") == "2002-01-03"
|
||||
assert df["I_AM_A_TRUNC_COLUMN"][0].strftime("%Y-%m-%d") == "2002-01-01"
|
||||
|
||||
|
|
|
|||
|
|
@ -43,9 +43,45 @@ def session_factory() -> Mock:
|
|||
return Mock()
|
||||
|
||||
|
||||
class SimpleDatasetColumn:
    """Minimal stand-in for a dataset column.

    Exposes every entry of the given mapping as a plain instance
    attribute, so tests can fake only the fields they care about.
    """

    def __init__(self, col_params: dict[str, Any]):
        for attr_name, attr_value in col_params.items():
            setattr(self, attr_name, attr_value)
|
||||
|
||||
|
||||
TEMPORAL_COLUMN_NAMES = ["temporal_column", "temporal_column_with_python_date_format"]

# Fake temporal columns keyed by name: the first has no python date
# format, the second is a BIGINT "%Y" column.
TEMPORAL_COLUMNS = {
    name: SimpleDatasetColumn(
        {
            "column_name": name,
            "is_dttm": True,
            "type": col_type,
            "python_date_format": date_format,
            "num_types": ["BIGINT"],
        }
    )
    for name, col_type, date_format in (
        (TEMPORAL_COLUMN_NAMES[0], "string", None),
        (TEMPORAL_COLUMN_NAMES[1], "BIGINT", "%Y"),
    )
}
|
||||
|
||||
|
||||
@fixture
def connector_registry() -> Mock:
    """DatasourceDAO mock for QueryObjectFactory tests.

    The mocked datasource resolves the known temporal column names to
    their ``SimpleDatasetColumn`` stubs and any other column name to a
    fresh ``Mock``; ``db_extra`` is ``None`` so no per-database date
    format fallback applies.
    """
    datasource_dao_mock = Mock(spec=["get_datasource"])
    # ``return_value`` is the same object every ``get_datasource()`` call
    # yields, so configure it once instead of calling the mock repeatedly.
    datasource_mock = datasource_dao_mock.get_datasource.return_value
    datasource_mock.get_column = Mock(
        side_effect=lambda col_name: TEMPORAL_COLUMNS[col_name]
        if col_name in TEMPORAL_COLUMN_NAMES
        else Mock()
    )
    datasource_mock.db_extra = None
    return datasource_dao_mock
|
||||
|
||||
|
||||
def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
|
||||
|
|
@ -112,3 +148,55 @@ class TestQueryObjectFactory:
|
|||
raw_query_context["result_type"], **raw_query_object
|
||||
)
|
||||
assert query_object.post_processing == []
|
||||
|
||||
def test_query_context_no_python_date_format_filters(
    self,
    query_object_factory: QueryObjectFactory,
    raw_query_context: dict[str, Any],
):
    """A temporal column without a python date format keeps its raw
    epoch-millisecond filter value untouched."""
    first_query = raw_query_context["queries"][0]
    epoch_ms_filter = {
        "col": TEMPORAL_COLUMN_NAMES[0],
        "op": "==",
        "val": 315532800000,
    }
    first_query["filters"].append(epoch_ms_filter)
    query_object = query_object_factory.create(
        raw_query_context["result_type"],
        raw_query_context["datasource"],
        **first_query,
    )
    assert query_object.filter[3]["val"] == 315532800000
|
||||
|
||||
def test_query_context_python_date_format_filters(
    self,
    query_object_factory: QueryObjectFactory,
    raw_query_context: dict[str, Any],
):
    """A BIGINT temporal column with a "%Y" format gets its epoch-ms
    filter value normalized to the corresponding year as an int."""
    first_query = raw_query_context["queries"][0]
    epoch_ms_filter = {
        "col": TEMPORAL_COLUMN_NAMES[1],
        "op": "==",
        "val": 315532800000,
    }
    first_query["filters"].append(epoch_ms_filter)
    query_object = query_object_factory.create(
        raw_query_context["result_type"],
        raw_query_context["datasource"],
        **first_query,
    )
    # 315532800000 ms == 1980-01-01T00:00:00Z, formatted as "%Y" -> 1980
    assert query_object.filter[3]["val"] == 1980
|
||||
|
||||
def test_query_context_python_date_format_filters_list_of_values(
    self,
    query_object_factory: QueryObjectFactory,
    raw_query_context: dict[str, Any],
):
    """Normalization applies element-wise when the filter value is a
    list of epoch-millisecond timestamps."""
    first_query = raw_query_context["queries"][0]
    epoch_ms_filter = {
        "col": TEMPORAL_COLUMN_NAMES[1],
        "op": "==",
        "val": [315532800000, 631152000000],
    }
    first_query["filters"].append(epoch_ms_filter)
    query_object = query_object_factory.create(
        raw_query_context["result_type"],
        raw_query_context["datasource"],
        **first_query,
    )
    # 1980-01-01 and 1990-01-01 (UTC), each formatted with "%Y"
    assert query_object.filter[3]["val"] == [1980, 1990]
|
||||
|
|
|
|||
Loading…
Reference in New Issue