diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index d97264e00..4f3b82af6 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - # pylint: disable=too-many-lines - from __future__ import annotations import contextlib @@ -118,7 +116,7 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]: pattern = re.compile(r"(?P\w+)\((?P.*)\)") if not column["type"]: raise ValueError - match = pattern.match(column["type"]) + match = pattern.match(cast(str, column["type"])) if not match: raise Exception( # pylint: disable=broad-exception-raised f"Unable to parse column type {column['type']}" @@ -538,6 +536,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): for col_name, value in zip(col_names, values): col_type = column_type_by_name.get(col_name) + if isinstance(col_type, str): + col_type_class = getattr(types, col_type, None) + col_type = col_type_class() if col_type_class else None + if isinstance(col_type, types.DATE): col_type = Date() elif isinstance(col_type, types.TIMESTAMP): diff --git a/superset/superset_typing.py b/superset/superset_typing.py index ba623f581..3a850e0ac 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -18,12 +18,15 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, Literal, Optional, TYPE_CHECKING, TypedDict, Union +from sqlalchemy.sql.type_api import TypeEngine from typing_extensions import NotRequired from werkzeug.wrappers import Response if TYPE_CHECKING: from superset.utils.core import GenericDataType +SQLType = Union[TypeEngine, type[TypeEngine]] + class LegacyMetric(TypedDict): label: Optional[str] @@ -73,7 +76,7 @@ class ResultSetColumnType(TypedDict): name: str # legacy naming convention keeping this for backwards compatibility column_name: str - type: Optional[str] + type: Optional[Union[SQLType, str]] is_dttm: Optional[bool] type_generic: NotRequired[Optional["GenericDataType"]] diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index e00210222..7631ed9ad 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -25,7 +25,6 @@ from sqlalchemy import sql, text, types from sqlalchemy.engine.url import make_url from superset.sql_parse import Table -from superset.superset_typing import ResultSetColumnType from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( assert_column_spec, @@ -116,45 +115,43 @@ def test_get_schema_from_engine_params() -> None: @pytest.mark.parametrize( ["column_type", "column_value", "expected_value"], [ - (types.DATE(), "2023-05-01", "DATE '2023-05-01'"), - (types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"), - (types.VARCHAR(), "2023-05-01", "'2023-05-01'"), - (types.INT(), 1234, "1234"), + ("DATE", "2023-05-01", "DATE '2023-05-01'"), + ("TIMESTAMP", "2023-05-01", "TIMESTAMP '2023-05-01'"), + ("VARCHAR", "2023-05-01", "'2023-05-01'"), + ("INT", 1234, "1234"), ], ) def test_where_latest_partition( - mock_latest_partition, column_type, column_value: Any, expected_value: str + mock_latest_partition, + column_type: str, + column_value: Any, + expected_value: str, ) -> None: - """ - Test the ``where_latest_partition`` method - """ - from superset.db_engine_specs.presto import PrestoEngineSpec as spec + from superset.db_engine_specs.presto import PrestoEngineSpec mock_latest_partition.return_value = (["partition_key"], [column_value]) - query = sql.select(text("* FROM table")) - columns: list[ResultSetColumnType] = [ - { - "column_name": "partition_key", - "name": "partition_key", - "type": column_type, - "is_dttm": False, - } - ] - - expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}""" - result = spec.where_latest_partition( - mock.MagicMock(), - Table("table"), - query, - columns, + assert ( + str( + PrestoEngineSpec.where_latest_partition( # type: ignore + database=mock.MagicMock(), + table=Table("table"), + query=sql.select(text("* FROM table")), + columns=[ + { + "column_name": "partition_key", + "name": "partition_key", + "type": column_type, + "is_dttm": False, + } + ], + ).compile( + dialect=PrestoDialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + == f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}""" ) - assert result is not None - actual = result.compile( - dialect=PrestoDialect(), compile_kwargs={"literal_binds": True} - ) - - assert str(actual) == expected def test_adjust_engine_params_fully_qualified() -> None: diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index e35615f57..c783780a4 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -19,16 +19,17 @@ import copy import json from datetime import datetime from typing import Any, Optional -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pandas as pd import pytest from pytest_mock import MockerFixture from requests.exceptions import ConnectionError as RequestsConnectionError -from sqlalchemy import types +from sqlalchemy import sql, text, types from sqlalchemy.engine.url import make_url from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError from trino.sqlalchemy import datatype +from trino.sqlalchemy.dialect import TrinoDialect import superset.config from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT @@ -39,7 +40,7 @@ from superset.db_engine_specs.exceptions import ( SupersetDBAPIProgrammingError, ) from superset.sql_parse import Table -from superset.superset_typing import ResultSetColumnType, SQLAColumnType +from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( assert_column_spec, @@ -645,3 +646,46 @@ def test_get_default_catalog() -> None: sqlalchemy_uri="trino://user:pass@localhost:8080/system/default", ) assert TrinoEngineSpec.get_default_catalog(database) == "system" + + +@patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition") +@pytest.mark.parametrize( + ["column_type", "column_value", "expected_value"], + [ + (types.DATE(), "2023-05-01", "DATE '2023-05-01'"), + (types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"), + (types.VARCHAR(), "2023-05-01", "'2023-05-01'"), + (types.INT(), 1234, "1234"), + ], +) +def test_where_latest_partition( + mock_latest_partition, + column_type: SQLType, + column_value: Any, + expected_value: str, +) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + + mock_latest_partition.return_value = (["partition_key"], [column_value]) + + assert ( + str( + TrinoEngineSpec.where_latest_partition( # type: ignore + database=MagicMock(), + table=Table("table"), + query=sql.select(text("* FROM table")), + columns=[ + { + "column_name": "partition_key", + "name": "partition_key", + "type": column_type, + "is_dttm": False, + } + ], + ).compile( + dialect=TrinoDialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + == f"""SELECT * FROM table \nWHERE partition_key = {expected_value}""" + )