fix: Handling of column types for Presto, Trino, et al. (#28653)
This commit is contained in:
parent
a59bad83d4
commit
4ff17409ab
|
|
@ -14,9 +14,7 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
|
|
@ -118,7 +116,7 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
|
|||
pattern = re.compile(r"(?P<type>\w+)\((?P<children>.*)\)")
|
||||
if not column["type"]:
|
||||
raise ValueError
|
||||
match = pattern.match(column["type"])
|
||||
match = pattern.match(cast(str, column["type"]))
|
||||
if not match:
|
||||
raise Exception( # pylint: disable=broad-exception-raised
|
||||
f"Unable to parse column type {column['type']}"
|
||||
|
|
@ -538,6 +536,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
|
|||
for col_name, value in zip(col_names, values):
|
||||
col_type = column_type_by_name.get(col_name)
|
||||
|
||||
if isinstance(col_type, str):
|
||||
col_type_class = getattr(types, col_type, None)
|
||||
col_type = col_type_class() if col_type_class else None
|
||||
|
||||
if isinstance(col_type, types.DATE):
|
||||
col_type = Date()
|
||||
elif isinstance(col_type, types.TIMESTAMP):
|
||||
|
|
|
|||
|
|
@ -18,12 +18,15 @@ from collections.abc import Sequence
|
|||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional, TYPE_CHECKING, TypedDict, Union
|
||||
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
from typing_extensions import NotRequired
|
||||
from werkzeug.wrappers import Response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.utils.core import GenericDataType
|
||||
|
||||
SQLType = Union[TypeEngine, type[TypeEngine]]
|
||||
|
||||
|
||||
class LegacyMetric(TypedDict):
|
||||
label: Optional[str]
|
||||
|
|
@ -73,7 +76,7 @@ class ResultSetColumnType(TypedDict):
|
|||
|
||||
name: str # legacy naming convention keeping this for backwards compatibility
|
||||
column_name: str
|
||||
type: Optional[str]
|
||||
type: Optional[Union[SQLType, str]]
|
||||
is_dttm: Optional[bool]
|
||||
type_generic: NotRequired[Optional["GenericDataType"]]
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ from sqlalchemy import sql, text, types
|
|||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.sql_parse import Table
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.unit_tests.db_engine_specs.utils import (
|
||||
assert_column_spec,
|
||||
|
|
@ -116,45 +115,43 @@ def test_get_schema_from_engine_params() -> None:
|
|||
@pytest.mark.parametrize(
|
||||
["column_type", "column_value", "expected_value"],
|
||||
[
|
||||
(types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
|
||||
(types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
|
||||
(types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
|
||||
(types.INT(), 1234, "1234"),
|
||||
("DATE", "2023-05-01", "DATE '2023-05-01'"),
|
||||
("TIMESTAMP", "2023-05-01", "TIMESTAMP '2023-05-01'"),
|
||||
("VARCHAR", "2023-05-01", "'2023-05-01'"),
|
||||
("INT", 1234, "1234"),
|
||||
],
|
||||
)
|
||||
def test_where_latest_partition(
|
||||
mock_latest_partition, column_type, column_value: Any, expected_value: str
|
||||
mock_latest_partition,
|
||||
column_type: str,
|
||||
column_value: Any,
|
||||
expected_value: str,
|
||||
) -> None:
|
||||
"""
|
||||
Test the ``where_latest_partition`` method
|
||||
"""
|
||||
from superset.db_engine_specs.presto import PrestoEngineSpec as spec
|
||||
from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||
|
||||
mock_latest_partition.return_value = (["partition_key"], [column_value])
|
||||
|
||||
query = sql.select(text("* FROM table"))
|
||||
columns: list[ResultSetColumnType] = [
|
||||
{
|
||||
"column_name": "partition_key",
|
||||
"name": "partition_key",
|
||||
"type": column_type,
|
||||
"is_dttm": False,
|
||||
}
|
||||
]
|
||||
|
||||
expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
|
||||
result = spec.where_latest_partition(
|
||||
mock.MagicMock(),
|
||||
Table("table"),
|
||||
query,
|
||||
columns,
|
||||
assert (
|
||||
str(
|
||||
PrestoEngineSpec.where_latest_partition( # type: ignore
|
||||
database=mock.MagicMock(),
|
||||
table=Table("table"),
|
||||
query=sql.select(text("* FROM table")),
|
||||
columns=[
|
||||
{
|
||||
"column_name": "partition_key",
|
||||
"name": "partition_key",
|
||||
"type": column_type,
|
||||
"is_dttm": False,
|
||||
}
|
||||
],
|
||||
).compile(
|
||||
dialect=PrestoDialect(),
|
||||
compile_kwargs={"literal_binds": True},
|
||||
)
|
||||
)
|
||||
== f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
|
||||
)
|
||||
assert result is not None
|
||||
actual = result.compile(
|
||||
dialect=PrestoDialect(), compile_kwargs={"literal_binds": True}
|
||||
)
|
||||
|
||||
assert str(actual) == expected
|
||||
|
||||
|
||||
def test_adjust_engine_params_fully_qualified() -> None:
|
||||
|
|
|
|||
|
|
@ -19,16 +19,17 @@ import copy
|
|||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from requests.exceptions import ConnectionError as RequestsConnectionError
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy import sql, text, types
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError
|
||||
from trino.sqlalchemy import datatype
|
||||
from trino.sqlalchemy.dialect import TrinoDialect
|
||||
|
||||
import superset.config
|
||||
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
|
||||
|
|
@ -39,7 +40,7 @@ from superset.db_engine_specs.exceptions import (
|
|||
SupersetDBAPIProgrammingError,
|
||||
)
|
||||
from superset.sql_parse import Table
|
||||
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
|
||||
from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.unit_tests.db_engine_specs.utils import (
|
||||
assert_column_spec,
|
||||
|
|
@ -645,3 +646,46 @@ def test_get_default_catalog() -> None:
|
|||
sqlalchemy_uri="trino://user:pass@localhost:8080/system/default",
|
||||
)
|
||||
assert TrinoEngineSpec.get_default_catalog(database) == "system"
|
||||
|
||||
|
||||
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.latest_partition")
|
||||
@pytest.mark.parametrize(
|
||||
["column_type", "column_value", "expected_value"],
|
||||
[
|
||||
(types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
|
||||
(types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
|
||||
(types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
|
||||
(types.INT(), 1234, "1234"),
|
||||
],
|
||||
)
|
||||
def test_where_latest_partition(
|
||||
mock_latest_partition,
|
||||
column_type: SQLType,
|
||||
column_value: Any,
|
||||
expected_value: str,
|
||||
) -> None:
|
||||
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||
|
||||
mock_latest_partition.return_value = (["partition_key"], [column_value])
|
||||
|
||||
assert (
|
||||
str(
|
||||
TrinoEngineSpec.where_latest_partition( # type: ignore
|
||||
database=MagicMock(),
|
||||
table=Table("table"),
|
||||
query=sql.select(text("* FROM table")),
|
||||
columns=[
|
||||
{
|
||||
"column_name": "partition_key",
|
||||
"name": "partition_key",
|
||||
"type": column_type,
|
||||
"is_dttm": False,
|
||||
}
|
||||
],
|
||||
).compile(
|
||||
dialect=TrinoDialect(),
|
||||
compile_kwargs={"literal_binds": True},
|
||||
)
|
||||
)
|
||||
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue