fix: Normalize prequery result type (#17312)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2021-11-03 13:58:40 -07:00 committed by GitHub
parent cb34a22684
commit 36f489eea0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 42 deletions

View File

@ -36,6 +36,7 @@ from typing import (
)
import dateutil.parser
import numpy as np
import pandas as pd
import sqlalchemy as sa
import sqlparse
@ -1455,6 +1456,39 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
return ob
def _normalize_prequery_result_type(
    self, row: pd.Series, dimension: str, columns_by_name: Dict[str, TableColumn],
) -> Union[str, int, float, bool, Text]:
    """
    Turn a single prequery cell into a value usable in a filter comparison.

    NumPy scalars are unwrapped into their native Python counterparts. Some
    databases like Druid return timestamps as strings but do not cast those
    strings automatically when they are compared against a timestamp, so a
    string found in a typed temporal column is parsed and passed through the
    engine spec's ``convert_dttm``. When that produces SQL, the SQL is wrapped
    with ``text``; otherwise the (unwrapped) value is returned as is.

    :param row: A prequery record
    :param dimension: The dimension name
    :param columns_by_name: The mapping of columns by name
    :return: the native Python value, or the SQL text produced for a temporal
        string
    """
    cell = row[dimension]
    if isinstance(cell, np.generic):
        cell = cell.item()

    column = columns_by_name[dimension]
    needs_conversion = column.type and column.is_temporal and isinstance(cell, str)
    if not needs_conversion:
        return cell

    # The engine spec may return nothing for a type it has no transform for;
    # in that case the original string is kept.
    convert_dttm = self.db_engine_spec.convert_dttm
    expression = convert_dttm(column.type, dateutil.parser.parse(cell))
    return text(expression) if expression else cell
def _get_top_groups(
self,
df: pd.DataFrame,
@ -1466,15 +1500,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
for _unused, row in df.iterrows():
group = []
for dimension in dimensions:
value = utils.normalize_prequery_result_type(row[dimension])
# Some databases like Druid will return timestamps as strings, but
# do not perform automatic casting when comparing these strings to
# a timestamp. For cases like this we convert the value from a
# string into a timestamp.
if columns_by_name[dimension].is_temporal and isinstance(value, str):
dttm = dateutil.parser.parse(value)
value = text(self.db_engine_spec.convert_dttm("TIMESTAMP", dttm))
value = self._normalize_prequery_result_type(
row, dimension, columns_by_name,
)
group.append(groupby_exprs[dimension] == value)
groups.append(and_(*group))

View File

@ -1813,35 +1813,3 @@ def escape_sqla_query_binds(sql: str) -> str:
sql = sql.replace(bind, bind.replace(":", "\\:"))
processed_binds.add(bind)
return sql
def normalize_prequery_result_type(
    value: Union[str, int, float, bool, np.generic]
) -> Union[str, int, float, bool]:
    """
    Unwrap a NumPy scalar into the matching built-in Python value.

    Anything that is not a NumPy scalar is handed back untouched.

    :param value: a primitive, given either as a plain Python object or as a
        NumPy scalar
    :return: the same value expressed as a built-in Python primitive

    >>> normalize_prequery_result_type('abc')
    'abc'
    >>> normalize_prequery_result_type(True)
    True
    >>> normalize_prequery_result_type(123)
    123
    >>> normalize_prequery_result_type(np.int16(123))
    123
    >>> normalize_prequery_result_type(np.uint32(123))
    123
    >>> normalize_prequery_result_type(np.int64(123))
    123
    >>> normalize_prequery_result_type(123.456)
    123.456
    >>> normalize_prequery_result_type(np.float32(123.456))
    123.45600128173828
    >>> normalize_prequery_result_type(np.float64(123.456))
    123.456
    """
    return value.item() if isinstance(value, np.generic) else value

View File

@ -16,11 +16,18 @@
# under the License.
# isort:skip_file
import re
from typing import Any, Dict, NamedTuple, List, Pattern, Tuple, Union
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Tuple, Union
from unittest.mock import patch
import pytest
import numpy as np
import pandas as pd
import sqlalchemy as sa
from flask import Flask
from pytest_mock import MockFixture
from sqlalchemy.sql import text
from sqlalchemy.sql.elements import TextClause
from superset import db
from superset.connectors.sqla.models import SqlaTable, TableColumn
@ -33,6 +40,7 @@ from superset.utils.core import (
FilterOperator,
GenericDataType,
get_example_database,
TemporalType,
)
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@ -484,3 +492,70 @@ class TestDatabaseModel(SupersetTestCase):
)
assert None not in without_null
assert len(without_null) == 2
@pytest.mark.parametrize(
    "row,dimension,result",
    [
        # Plain Python primitives pass through unchanged.
        (pd.Series({"foo": "abc"}), "foo", "abc"),
        (pd.Series({"bar": True}), "bar", True),
        (pd.Series({"baz": 123}), "baz", 123),
        # NumPy scalars are unwrapped into native Python numbers.
        (pd.Series({"baz": np.int16(123)}), "baz", 123),
        (pd.Series({"baz": np.uint32(123)}), "baz", 123),
        (pd.Series({"baz": np.int64(123)}), "baz", 123),
        (pd.Series({"qux": 123.456}), "qux", 123.456),
        (pd.Series({"qux": np.float32(123.456)}), "qux", 123.45600128173828),
        (pd.Series({"qux": np.float64(123.456)}), "qux", 123.456),
        # Temporal column typed STRING: the stubbed convert_dttm returns None,
        # so the original string is kept.
        (pd.Series({"quux": "2021-01-01"}), "quux", "2021-01-01"),
        # Temporal column typed TIMESTAMP: the string becomes a SQL text clause.
        (
            pd.Series({"quuz": "2021-01-01T00:00:00"}),
            "quuz",
            text("TIME_PARSE('2021-01-01T00:00:00')"),
        ),
    ],
)
def test__normalize_prequery_result_type(
    app_context: Flask,
    mocker: MockFixture,
    row: pd.Series,
    dimension: str,
    result: Any,
) -> None:
    """
    Check ``SqlaTable._normalize_prequery_result_type`` for each parametrized case.

    The engine spec's ``convert_dttm`` is replaced with a stub that only emits
    SQL for the TIMESTAMP type, which lets the temporal cases cover both the
    converted and the unconverted outcome.

    :param app_context: fixture requested by the test (not used directly here)
    :param mocker: pytest-mock fixture used to patch ``convert_dttm``
    :param row: the prequery record under test
    :param dimension: the name of the dimension to normalize
    :param result: the expected normalized value
    """

    def _convert_dttm(target_type: str, dttm: datetime) -> Optional[str]:
        # Stub: produce SQL for TIMESTAMP only, None for every other type.
        if target_type.upper() == TemporalType.TIMESTAMP:
            return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""

        return None

    table = SqlaTable(table_name="foobar", database=get_example_database())
    mocker.patch.object(table.db_engine_spec, "convert_dttm", new=_convert_dttm)

    # Each key matches the column_name of the column it maps to.
    columns_by_name = {
        "foo": TableColumn(
            column_name="foo", is_dttm=False, table=table, type="STRING",
        ),
        "bar": TableColumn(
            column_name="bar", is_dttm=False, table=table, type="BOOLEAN",
        ),
        "baz": TableColumn(
            column_name="baz", is_dttm=False, table=table, type="INTEGER",
        ),
        "qux": TableColumn(
            column_name="qux", is_dttm=False, table=table, type="FLOAT",
        ),
        "quux": TableColumn(
            column_name="quux", is_dttm=True, table=table, type="STRING",
        ),
        "quuz": TableColumn(
            column_name="quuz", is_dttm=True, table=table, type="TIMESTAMP",
        ),
    }

    normalized = table._normalize_prequery_result_type(row, dimension, columns_by_name)

    # Exact type identity is intended: a bool must not pass as an int, and a
    # NumPy scalar must not pass as its Python counterpart.
    assert type(normalized) is type(result)

    if isinstance(normalized, TextClause):
        # The clauses are compared through their rendered SQL.
        assert str(normalized) == str(result)
    else:
        assert normalized == result