fix: Normalize prequery result type (#17312)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
cb34a22684
commit
36f489eea0
|
|
@ -36,6 +36,7 @@ from typing import (
|
|||
)
|
||||
|
||||
import dateutil.parser
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import sqlalchemy as sa
|
||||
import sqlparse
|
||||
|
|
@ -1455,6 +1456,39 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
)
|
||||
return ob
|
||||
|
||||
def _normalize_prequery_result_type(
    self, row: pd.Series, dimension: str, columns_by_name: Dict[str, TableColumn],
) -> Union[str, int, float, bool, Text]:
    """
    Coerce a single prequery value into a form suitable for SQL comparison.

    Numpy scalars are unboxed into their native Python equivalents. Engines
    such as Druid hand temporal values back as plain strings yet perform no
    implicit cast when those strings are compared against a timestamp column,
    so temporal strings are rewritten via the engine-specific SQL transform.

    :param row: A prequery record
    :param dimension: The dimension name
    :param columns_by_name: The mapping of columns by name
    :return: equivalent primitive python type
    """
    raw = row[dimension]

    # Unbox numpy scalar types (np.int64, np.float32, ...) to Python natives.
    result = raw.item() if isinstance(raw, np.generic) else raw

    column = columns_by_name[dimension]
    needs_temporal_cast = (
        column.type and column.is_temporal and isinstance(result, str)
    )

    if needs_temporal_cast:
        transformed = self.db_engine_spec.convert_dttm(
            column.type, dateutil.parser.parse(result),
        )
        # convert_dttm may return None when the engine has no transform for
        # this column type; in that case keep the original string value.
        if transformed:
            result = text(transformed)

    return result
|
||||
|
||||
def _get_top_groups(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
|
|
@ -1466,15 +1500,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
for _unused, row in df.iterrows():
|
||||
group = []
|
||||
for dimension in dimensions:
|
||||
value = utils.normalize_prequery_result_type(row[dimension])
|
||||
|
||||
# Some databases like Druid will return timestamps as strings, but
|
||||
# do not perform automatic casting when comparing these strings to
|
||||
# a timestamp. For cases like this we convert the value from a
|
||||
# string into a timestamp.
|
||||
if columns_by_name[dimension].is_temporal and isinstance(value, str):
|
||||
dttm = dateutil.parser.parse(value)
|
||||
value = text(self.db_engine_spec.convert_dttm("TIMESTAMP", dttm))
|
||||
value = self._normalize_prequery_result_type(
|
||||
row, dimension, columns_by_name,
|
||||
)
|
||||
|
||||
group.append(groupby_exprs[dimension] == value)
|
||||
groups.append(and_(*group))
|
||||
|
|
|
|||
|
|
@ -1813,35 +1813,3 @@ def escape_sqla_query_binds(sql: str) -> str:
|
|||
sql = sql.replace(bind, bind.replace(":", "\\:"))
|
||||
processed_binds.add(bind)
|
||||
return sql
|
||||
|
||||
|
||||
def normalize_prequery_result_type(
    value: Union[str, int, float, bool, np.generic]
) -> Union[str, int, float, bool]:
    """
    Unbox a possibly-numpy scalar into the matching native Python scalar.

    Plain Python primitives pass through untouched.

    :param value: primitive datatype in either numpy or python format
    :return: equivalent primitive python type
    >>> normalize_prequery_result_type('abc')
    'abc'
    >>> normalize_prequery_result_type(True)
    True
    >>> normalize_prequery_result_type(123)
    123
    >>> normalize_prequery_result_type(np.int16(123))
    123
    >>> normalize_prequery_result_type(np.uint32(123))
    123
    >>> normalize_prequery_result_type(np.int64(123))
    123
    >>> normalize_prequery_result_type(123.456)
    123.456
    >>> normalize_prequery_result_type(np.float32(123.456))
    123.45600128173828
    >>> normalize_prequery_result_type(np.float64(123.456))
    123.456
    """
    return value.item() if isinstance(value, np.generic) else value
|
||||
|
|
|
|||
|
|
@ -16,11 +16,18 @@
|
|||
# under the License.
|
||||
# isort:skip_file
|
||||
import re
|
||||
from typing import Any, Dict, NamedTuple, List, Pattern, Tuple, Union
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import sqlalchemy as sa
|
||||
from flask import Flask
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.elements import TextClause
|
||||
|
||||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
|
|
@ -33,6 +40,7 @@ from superset.utils.core import (
|
|||
FilterOperator,
|
||||
GenericDataType,
|
||||
get_example_database,
|
||||
TemporalType,
|
||||
)
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
|
|
@ -484,3 +492,70 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
)
|
||||
assert None not in without_null
|
||||
assert len(without_null) == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "row,dimension,result",
    [
        (pd.Series({"foo": "abc"}), "foo", "abc"),
        (pd.Series({"bar": True}), "bar", True),
        (pd.Series({"baz": 123}), "baz", 123),
        (pd.Series({"baz": np.int16(123)}), "baz", 123),
        (pd.Series({"baz": np.uint32(123)}), "baz", 123),
        (pd.Series({"baz": np.int64(123)}), "baz", 123),
        (pd.Series({"qux": 123.456}), "qux", 123.456),
        (pd.Series({"qux": np.float32(123.456)}), "qux", 123.45600128173828),
        (pd.Series({"qux": np.float64(123.456)}), "qux", 123.456),
        (pd.Series({"quux": "2021-01-01"}), "quux", "2021-01-01"),
        (
            pd.Series({"quuz": "2021-01-01T00:00:00"}),
            "quuz",
            text("TIME_PARSE('2021-01-01T00:00:00')"),
        ),
    ],
)
def test__normalize_prequery_result_type(
    app_context: Flask,
    mocker: MockFixture,
    row: pd.Series,
    dimension: str,
    result: Any,
) -> None:
    """
    Verify that ``SqlaTable._normalize_prequery_result_type`` unboxes numpy
    scalars and rewrites temporal strings as engine-specific SQL transforms.

    ``convert_dttm`` is stubbed so only TIMESTAMP columns yield a transform;
    temporal STRING columns ("quux") keep their raw string value.
    """

    def _convert_dttm(target_type: str, dttm: datetime) -> Optional[str]:
        if target_type.upper() == TemporalType.TIMESTAMP:
            return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""

        return None

    table = SqlaTable(table_name="foobar", database=get_example_database())
    mocker.patch.object(table.db_engine_spec, "convert_dttm", new=_convert_dttm)

    # Fixture columns keyed by name; each key matches its column_name.
    # (Previously "quux"/"quuz" had their column_name values swapped —
    # harmless at runtime, since lookup keys off the dict, but misleading.)
    columns_by_name = {
        "foo": TableColumn(
            column_name="foo", is_dttm=False, table=table, type="STRING",
        ),
        "bar": TableColumn(
            column_name="bar", is_dttm=False, table=table, type="BOOLEAN",
        ),
        "baz": TableColumn(
            column_name="baz", is_dttm=False, table=table, type="INTEGER",
        ),
        "qux": TableColumn(
            column_name="qux", is_dttm=False, table=table, type="FLOAT",
        ),
        "quux": TableColumn(
            column_name="quux", is_dttm=True, table=table, type="STRING",
        ),
        "quuz": TableColumn(
            column_name="quuz", is_dttm=True, table=table, type="TIMESTAMP",
        ),
    }

    normalized = table._normalize_prequery_result_type(row, dimension, columns_by_name,)

    assert type(normalized) == type(result)

    if isinstance(normalized, TextClause):
        # TextClause does not define value equality; compare rendered SQL.
        assert str(normalized) == str(result)
    else:
        assert normalized == result
|
||||
|
|
|
|||
Loading…
Reference in New Issue