fix: Apply normalization to all dttm columns (#25147)

This commit is contained in:
Kamil Gabryjelski 2023-10-06 18:47:00 +02:00 committed by GitHub
parent 17792a507c
commit 58fcd292a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 160 additions and 10 deletions

View File

@ -185,6 +185,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
filter
for filter in query_object.filter
if filter["col"] != filter_to_remove
or filter["op"] != "TEMPORAL_RANGE"
]
def _apply_filters(self, query_object: QueryObject) -> None:

View File

@ -282,10 +282,11 @@ class QueryContextProcessor:
datasource = self._qc_datasource
labels = tuple(
label
for label in [
for label in {
*get_base_axis_labels(query_object.columns),
*[col for col in query_object.columns or [] if isinstance(col, str)],
query_object.granularity,
]
}
if datasource
# Query datasource didn't support `get_column`
and hasattr(datasource, "get_column")

View File

@ -16,17 +16,24 @@
# under the License.
from __future__ import annotations
from datetime import datetime
from typing import Any, TYPE_CHECKING
from superset.common.chart_data import ChartDataResultType
from superset.common.query_object import QueryObject
from superset.common.utils.time_range_utils import get_since_until_from_time_range
from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType
from superset.utils.core import (
apply_max_row_limit,
DatasourceDict,
DatasourceType,
FilterOperator,
QueryObjectFilterClause,
)
if TYPE_CHECKING:
from sqlalchemy.orm import sessionmaker
from superset.connectors.base.models import BaseDatasource
from superset.connectors.base.models import BaseColumn, BaseDatasource
from superset.daos.datasource import DatasourceDAO
@ -66,6 +73,10 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
)
kwargs["from_dttm"] = from_dttm
kwargs["to_dttm"] = to_dttm
if datasource_model_instance and kwargs.get("filters", []):
kwargs["filters"] = self._process_filters(
datasource_model_instance, kwargs["filters"]
)
return QueryObject(
datasource=datasource_model_instance,
extras=extras,
@ -102,3 +113,54 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
# light version of the view.utils.core
# import view.utils require application context
# Todo: move it and the view.utils.core to utils package
def _process_filters(
self, datasource: BaseDatasource, query_filters: list[QueryObjectFilterClause]
) -> list[QueryObjectFilterClause]:
def get_dttm_filter_value(
value: Any, col: BaseColumn, date_format: str
) -> int | str:
if not isinstance(value, int):
return value
if date_format in {"epoch_ms", "epoch_s"}:
if date_format == "epoch_s":
value = str(value)
else:
value = str(value * 1000)
else:
dttm = datetime.utcfromtimestamp(value / 1000)
value = dttm.strftime(date_format)
if col.type in col.num_types:
value = int(value)
return value
for query_filter in query_filters:
if query_filter.get("op") == FilterOperator.TEMPORAL_RANGE:
continue
filter_col = query_filter.get("col")
if not isinstance(filter_col, str):
continue
column = datasource.get_column(filter_col)
if not column:
continue
filter_value = query_filter.get("val")
date_format = column.python_date_format
if not date_format and datasource.db_extra:
date_format = datasource.db_extra.get(
"python_date_format_by_column_name", {}
).get(column.column_name)
if column.is_dttm and date_format:
if isinstance(filter_value, list):
query_filter["val"] = [
get_dttm_filter_value(value, column, date_format)
for value in filter_value
]
else:
query_filter["val"] = get_dttm_filter_value(
filter_value, column, date_format
)
return query_filters

View File

@ -836,11 +836,9 @@ def test_special_chars_in_column_name(app_context, physical_dataset):
query_object = qc.queries[0]
df = qc.get_df_payload(query_object)["df"]
if query_object.datasource.database.backend == "sqlite":
# sqlite returns string as timestamp column
assert df["time column with spaces"][0] == "2002-01-03 00:00:00"
assert df["I_AM_A_TRUNC_COLUMN"][0] == "2002-01-01 00:00:00"
else:
# sqlite doesn't have timestamp columns
if query_object.datasource.database.backend != "sqlite":
assert df["time column with spaces"][0].strftime("%Y-%m-%d") == "2002-01-03"
assert df["I_AM_A_TRUNC_COLUMN"][0].strftime("%Y-%m-%d") == "2002-01-01"

View File

@ -43,9 +43,45 @@ def session_factory() -> Mock:
return Mock()
class SimpleDatasetColumn:
def __init__(self, col_params: dict[str, Any]):
self.__dict__.update(col_params)
TEMPORAL_COLUMN_NAMES = ["temporal_column", "temporal_column_with_python_date_format"]
TEMPORAL_COLUMNS = {
TEMPORAL_COLUMN_NAMES[0]: SimpleDatasetColumn(
{
"column_name": TEMPORAL_COLUMN_NAMES[0],
"is_dttm": True,
"python_date_format": None,
"type": "string",
"num_types": ["BIGINT"],
}
),
TEMPORAL_COLUMN_NAMES[1]: SimpleDatasetColumn(
{
"column_name": TEMPORAL_COLUMN_NAMES[1],
"type": "BIGINT",
"is_dttm": True,
"python_date_format": "%Y",
"num_types": ["BIGINT"],
}
),
}
@fixture
def connector_registry() -> Mock:
return Mock(spec=["get_datasource"])
datasource_dao_mock = Mock(spec=["get_datasource"])
datasource_dao_mock.get_datasource.return_value = Mock()
datasource_dao_mock.get_datasource().get_column = Mock(
side_effect=lambda col_name: TEMPORAL_COLUMNS[col_name]
if col_name in TEMPORAL_COLUMN_NAMES
else Mock()
)
datasource_dao_mock.get_datasource().db_extra = None
return datasource_dao_mock
def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
@ -112,3 +148,55 @@ class TestQueryObjectFactory:
raw_query_context["result_type"], **raw_query_object
)
assert query_object.post_processing == []
def test_query_context_no_python_date_format_filters(
self,
query_object_factory: QueryObjectFactory,
raw_query_context: dict[str, Any],
):
raw_query_object = raw_query_context["queries"][0]
raw_query_object["filters"].append(
{"col": TEMPORAL_COLUMN_NAMES[0], "op": "==", "val": 315532800000}
)
query_object = query_object_factory.create(
raw_query_context["result_type"],
raw_query_context["datasource"],
**raw_query_object
)
assert query_object.filter[3]["val"] == 315532800000
def test_query_context_python_date_format_filters(
self,
query_object_factory: QueryObjectFactory,
raw_query_context: dict[str, Any],
):
raw_query_object = raw_query_context["queries"][0]
raw_query_object["filters"].append(
{"col": TEMPORAL_COLUMN_NAMES[1], "op": "==", "val": 315532800000}
)
query_object = query_object_factory.create(
raw_query_context["result_type"],
raw_query_context["datasource"],
**raw_query_object
)
assert query_object.filter[3]["val"] == 1980
def test_query_context_python_date_format_filters_list_of_values(
self,
query_object_factory: QueryObjectFactory,
raw_query_context: dict[str, Any],
):
raw_query_object = raw_query_context["queries"][0]
raw_query_object["filters"].append(
{
"col": TEMPORAL_COLUMN_NAMES[1],
"op": "==",
"val": [315532800000, 631152000000],
}
)
query_object = query_object_factory.create(
raw_query_context["result_type"],
raw_query_context["datasource"],
**raw_query_object
)
assert query_object.filter[3]["val"] == [1980, 1990]