fix: Handle python_date_format in ExploreMixin (#24068)

This commit is contained in:
John Bodley 2023-05-16 06:54:12 -07:00 committed by GitHub
parent 78bc0693d4
commit 2f0caf8a0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 57 deletions

View File

@ -330,21 +330,6 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
def datasource(self) -> RelationshipProperty:
return self.table
def get_time_filter(
self,
start_dttm: Optional[DateTime] = None,
end_dttm: Optional[DateTime] = None,
label: Optional[str] = "__time",
template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
col = self.get_sqla_col(label=label, template_processor=template_processor)
l = []
if start_dttm:
l.append(col >= self.table.text(self.dttm_sql_literal(start_dttm)))
if end_dttm:
l.append(col < self.table.text(self.dttm_sql_literal(end_dttm)))
return and_(*l)
def get_timestamp_expression(
self,
time_grain: Optional[str],
@ -379,36 +364,6 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
time_expr = self.db_engine_spec.get_timestamp_expr(col, pdf, time_grain)
return self.table.make_sqla_column_compatible(time_expr, label)
def dttm_sql_literal(self, dttm: DateTime) -> str:
"""Convert datetime object to a SQL expression string"""
sql = (
self.db_engine_spec.convert_dttm(self.type, dttm, db_extra=self.db_extra)
if self.type
else None
)
if sql:
return sql
tf = self.python_date_format
# Fallback to the default format (if defined).
if not tf:
tf = self.db_extra.get("python_date_format_by_column_name", {}).get(
self.column_name
)
if tf:
if tf in ["epoch_ms", "epoch_s"]:
seconds_since_epoch = int(dttm.timestamp())
if tf == "epoch_s":
return str(seconds_since_epoch)
return str(seconds_since_epoch * 1000)
return f"'{dttm.strftime(tf)}'"
# TODO(john-bodley): SIP-15 will explicitly require a type conversion.
return f"""'{dttm.strftime("%Y-%m-%d %H:%M:%S.%f")}'"""
@property
def data(self) -> Dict[str, Any]:
attrs = (

View File

@ -1269,17 +1269,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return or_(*groups)
def dttm_sql_literal(
self,
col: "TableColumn",
dttm: sa.DateTime,
col_type: Optional[str],
) -> str:
def dttm_sql_literal(self, dttm: datetime, col: "TableColumn") -> str:
"""Convert datetime object to a SQL expression string"""
sql = (
self.db_engine_spec.convert_dttm(col_type, dttm, db_extra=None)
if col_type
self.db_engine_spec.convert_dttm(col.type, dttm, db_extra=self.db_extra)
if col.type
else None
)
@ -1330,14 +1325,14 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
l.append(
col
>= self.db_engine_spec.get_text_clause(
self.dttm_sql_literal(time_col, start_dttm, time_col.type)
self.dttm_sql_literal(start_dttm, time_col)
)
)
if end_dttm:
l.append(
col
< self.db_engine_spec.get_text_clause(
self.dttm_sql_literal(time_col, end_dttm, time_col.type)
self.dttm_sql_literal(end_dttm, time_col)
)
)
return and_(*l)

View File

@ -16,12 +16,17 @@
# under the License.
# pylint: disable=import-outside-toplevel
import json
from datetime import datetime
from typing import List, Optional
import pytest
from pytest_mock import MockFixture
from sqlalchemy.engine.reflection import Inspector
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.models.core import Database
def test_get_metrics(mocker: MockFixture) -> None:
"""
@ -143,3 +148,62 @@ def test_get_db_engine_spec(mocker: MockFixture) -> None:
).db_engine_spec
== OldDBEngineSpec
)
@pytest.mark.parametrize(
"dttm,col,database,result",
[
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(python_date_format="epoch_s"),
Database(),
"1672536225",
),
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(python_date_format="epoch_ms"),
Database(),
"1672536225000",
),
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(python_date_format="%Y-%m-%d"),
Database(),
"'2023-01-01'",
),
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(column_name="ds"),
Database(
extra=json.dumps(
{
"python_date_format_by_column_name": {
"ds": "%Y-%m-%d",
},
},
),
sqlalchemy_uri="foo://",
),
"'2023-01-01'",
),
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(),
Database(sqlalchemy_uri="foo://"),
"'2023-01-01 01:23:45.600000'",
),
(
datetime(2023, 1, 1, 1, 23, 45, 600000),
TableColumn(type="TimeStamp"),
Database(sqlalchemy_uri="trino://"),
"TIMESTAMP '2023-01-01 01:23:45.600000'",
),
],
)
def test_dttm_sql_literal(
dttm: datetime,
col: TableColumn,
database: Database,
result: str,
) -> None:
assert SqlaTable(database=database).dttm_sql_literal(dttm, col) == result

View File

@ -18,7 +18,7 @@
# Remember to start celery workers to run celery tests, e.g.
# celery --app=superset.tasks.celery_app:app worker -Ofair -c 2
[testenv]
basepython = python3.8
basepython = python3.10
ignore_basepython_conflict = true
commands =
superset db upgrade