fix: Handling of column types for Presto, Trino, et al. (#28653)

This commit is contained in:
John Bodley 2024-05-28 08:56:38 -07:00 committed by GitHub
parent a59bad83d4
commit 4ff17409ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 85 additions and 39 deletions

View File

@ -14,9 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
from __future__ import annotations
import contextlib
@ -118,7 +116,7 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
pattern = re.compile(r"(?P<type>\w+)\((?P<children>.*)\)")
if not column["type"]:
raise ValueError
match = pattern.match(column["type"])
match = pattern.match(cast(str, column["type"]))
if not match:
raise Exception( # pylint: disable=broad-exception-raised
f"Unable to parse column type {column['type']}"
@ -538,6 +536,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
for col_name, value in zip(col_names, values):
col_type = column_type_by_name.get(col_name)
if isinstance(col_type, str):
col_type_class = getattr(types, col_type, None)
col_type = col_type_class() if col_type_class else None
if isinstance(col_type, types.DATE):
col_type = Date()
elif isinstance(col_type, types.TIMESTAMP):

View File

@ -18,12 +18,15 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Any, Literal, Optional, TYPE_CHECKING, TypedDict, Union
from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import NotRequired
from werkzeug.wrappers import Response
if TYPE_CHECKING:
from superset.utils.core import GenericDataType
SQLType = Union[TypeEngine, type[TypeEngine]]
class LegacyMetric(TypedDict):
label: Optional[str]
@ -73,7 +76,7 @@ class ResultSetColumnType(TypedDict):
name: str # legacy naming convention keeping this for backwards compatibility
column_name: str
type: Optional[str]
type: Optional[Union[SQLType, str]]
is_dttm: Optional[bool]
type_generic: NotRequired[Optional["GenericDataType"]]

View File

@ -25,7 +25,6 @@ from sqlalchemy import sql, text, types
from sqlalchemy.engine.url import make_url
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
@ -116,45 +115,43 @@ def test_get_schema_from_engine_params() -> None:
@pytest.mark.parametrize(
["column_type", "column_value", "expected_value"],
[
(types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
(types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
(types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
(types.INT(), 1234, "1234"),
("DATE", "2023-05-01", "DATE '2023-05-01'"),
("TIMESTAMP", "2023-05-01", "TIMESTAMP '2023-05-01'"),
("VARCHAR", "2023-05-01", "'2023-05-01'"),
("INT", 1234, "1234"),
],
)
def test_where_latest_partition(
mock_latest_partition, column_type, column_value: Any, expected_value: str
mock_latest_partition,
column_type: str,
column_value: Any,
expected_value: str,
) -> None:
"""
Test the ``where_latest_partition`` method
"""
from superset.db_engine_specs.presto import PrestoEngineSpec as spec
from superset.db_engine_specs.presto import PrestoEngineSpec
mock_latest_partition.return_value = (["partition_key"], [column_value])
query = sql.select(text("* FROM table"))
columns: list[ResultSetColumnType] = [
{
"column_name": "partition_key",
"name": "partition_key",
"type": column_type,
"is_dttm": False,
}
]
expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
result = spec.where_latest_partition(
mock.MagicMock(),
Table("table"),
query,
columns,
assert (
str(
PrestoEngineSpec.where_latest_partition( # type: ignore
database=mock.MagicMock(),
table=Table("table"),
query=sql.select(text("* FROM table")),
columns=[
{
"column_name": "partition_key",
"name": "partition_key",
"type": column_type,
"is_dttm": False,
}
],
).compile(
dialect=PrestoDialect(),
compile_kwargs={"literal_binds": True},
)
)
== f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
)
assert result is not None
actual = result.compile(
dialect=PrestoDialect(), compile_kwargs={"literal_binds": True}
)
assert str(actual) == expected
def test_adjust_engine_params_fully_qualified() -> None:

View File

@ -19,16 +19,17 @@ import copy
import json
from datetime import datetime
from typing import Any, Optional
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pandas as pd
import pytest
from pytest_mock import MockerFixture
from requests.exceptions import ConnectionError as RequestsConnectionError
from sqlalchemy import types
from sqlalchemy import sql, text, types
from sqlalchemy.engine.url import make_url
from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError
from trino.sqlalchemy import datatype
from trino.sqlalchemy.dialect import TrinoDialect
import superset.config
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
@ -39,7 +40,7 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIProgrammingError,
)
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
@ -645,3 +646,46 @@ def test_get_default_catalog() -> None:
sqlalchemy_uri="trino://user:pass@localhost:8080/system/default",
)
assert TrinoEngineSpec.get_default_catalog(database) == "system"
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition")
@pytest.mark.parametrize(
["column_type", "column_value", "expected_value"],
[
(types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
(types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
(types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
(types.INT(), 1234, "1234"),
],
)
def test_where_latest_partition(
mock_latest_partition,
column_type: SQLType,
column_value: Any,
expected_value: str,
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
mock_latest_partition.return_value = (["partition_key"], [column_value])
assert (
str(
TrinoEngineSpec.where_latest_partition( # type: ignore
database=MagicMock(),
table=Table("table"),
query=sql.select(text("* FROM table")),
columns=[
{
"column_name": "partition_key",
"name": "partition_key",
"type": column_type,
"is_dttm": False,
}
],
).compile(
dialect=TrinoDialect(),
compile_kwargs={"literal_binds": True},
)
)
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
)