diff --git a/superset/db_engine_specs/athena.py b/superset/db_engine_specs/athena.py index f4a6efca3..047952402 100644 --- a/superset/db_engine_specs/athena.py +++ b/superset/db_engine_specs/athena.py @@ -19,10 +19,10 @@ from datetime import datetime from typing import Any, Dict, Optional, Pattern, Tuple from flask_babel import gettext as __ +from sqlalchemy import types from superset.db_engine_specs.base import BaseEngineSpec from superset.errors import SupersetErrorType -from superset.utils import core as utils SYNTAX_ERROR_REGEX = re.compile( ": mismatched input '(?P.*?)'. Expecting: " @@ -66,10 +66,11 @@ class AthenaEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"DATE '{dttm.date().isoformat()}'" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): datetime_formatted = dttm.isoformat(sep=" ", timespec="milliseconds") return f"""TIMESTAMP '{datetime_formatted}'""" return None diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 32f184622..631efb461 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -197,7 +197,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} - column_type_mappings: Tuple[ColumnTypeMapping, ...] = ( + _default_column_type_mappings: Tuple[ColumnTypeMapping, ...] = ( ( re.compile(r"^string", re.IGNORECASE), types.String(), @@ -314,6 +314,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods GenericDataType.BOOLEAN, ), ) + # engine-specific type mappings to check prior to the defaults + column_type_mappings: Tuple[ColumnTypeMapping, ...] 
= () # Does database support join-free timeslot grouping time_groupby_inline = False @@ -1389,24 +1391,25 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return label_mutated @classmethod - def get_sqla_column_type( + def get_column_types( cls, column_type: Optional[str], - column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings, ) -> Optional[Tuple[TypeEngine, GenericDataType]]: """ - Return a sqlalchemy native column type that corresponds to the column type - defined in the data source (return None to use default type inferred by - SQLAlchemy). Override `column_type_mappings` for specific needs + Return a sqlalchemy native column type and generic data type that corresponds + to the column type defined in the data source (return None to use default type + inferred by SQLAlchemy). Override `column_type_mappings` for specific needs (see MSSQL for example of NCHAR/NVARCHAR handling). :param column_type: Column type returned by inspector - :param column_type_mappings: Maps from string to SqlAlchemy TypeEngine - :return: SqlAlchemy column type + :return: SQLAlchemy and generic Superset column types """ if not column_type: return None - for regex, sqla_type, generic_type in column_type_mappings: + + for regex, sqla_type, generic_type in ( + cls.column_type_mappings + cls._default_column_type_mappings + ): match = regex.match(column_type) if not match: continue @@ -1569,19 +1572,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods native_type: Optional[str], db_extra: Optional[Dict[str, Any]] = None, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings, ) -> Optional[ColumnSpec]: """ - Converts native database type to sqlalchemy column type. + Get generic type related specs regarding a native column type. 
+ :param native_type: Native database type :param db_extra: The database extra object :param source: Type coming from the database table or cursor description - :param column_type_mappings: Maps from string to SqlAlchemy TypeEngine :return: ColumnSpec object """ - col_types = cls.get_sqla_column_type( - native_type, column_type_mappings=column_type_mappings - ) + col_types = cls.get_column_types(native_type) if col_types: column_type, generic_type = col_types is_dttm = generic_type == GenericDataType.TEMPORAL @@ -1590,6 +1590,28 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ) return None + @classmethod + def get_sqla_column_type( + cls, + native_type: Optional[str], + db_extra: Optional[Dict[str, Any]] = None, + source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, + ) -> Optional[TypeEngine]: + """ + Converts native database type to sqlalchemy column type. + + :param native_type: Native database type + :param db_extra: The database extra object + :param source: Type coming from the database table or cursor description + :return: SQLAlchemy column type, or None if the type is unrecognized + """ + column_spec = cls.get_column_spec( + native_type=native_type, + db_extra=db_extra, + source=source, + ) + return column_spec.sqla_type if column_spec else None + # pylint: disable=unused-argument @classmethod def prepare_cancel_query(cls, query: Query, session: Session) -> None: diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 6672b0b47..171dad473 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -26,7 +26,7 @@ from apispec.ext.marshmallow import MarshmallowPlugin from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.exceptions import ValidationError -from sqlalchemy import column +from sqlalchemy import column, types from sqlalchemy.engine.base import Engine from sqlalchemy.sql import sqltypes from typing_extensions import TypedDict @@ -201,15 
+201,15 @@ class BigQueryEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + if isinstance(sqla_type, types.Date): return f"CAST('{dttm.date().isoformat()}' AS DATE)" - if tt == utils.TemporalType.DATETIME: - return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)""" - if tt == utils.TemporalType.TIME: - return f"""CAST('{dttm.strftime("%H:%M:%S.%f")}' AS TIME)""" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)""" + if isinstance(sqla_type, types.DateTime): + return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)""" + if isinstance(sqla_type, types.Time): + return f"""CAST('{dttm.strftime("%H:%M:%S.%f")}' AS TIME)""" return None @classmethod diff --git a/superset/db_engine_specs/clickhouse.py b/superset/db_engine_specs/clickhouse.py index 4531dca69..930aeee52 100644 --- a/superset/db_engine_specs/clickhouse.py +++ b/superset/db_engine_specs/clickhouse.py @@ -18,12 +18,12 @@ import logging from datetime import datetime from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING +from sqlalchemy import types from urllib3.exceptions import NewConnectionError from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError from superset.extensions import cache_manager -from superset.utils import core as utils if TYPE_CHECKING: # prevent circular imports @@ -77,10 +77,11 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = 
cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"toDate('{dttm.date().isoformat()}')" - if tt == utils.TemporalType.DATETIME: + if isinstance(sqla_type, types.DateTime): return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')""" return None diff --git a/superset/db_engine_specs/crate.py b/superset/db_engine_specs/crate.py index 4d934c448..7cf7bed15 100644 --- a/superset/db_engine_specs/crate.py +++ b/superset/db_engine_specs/crate.py @@ -14,11 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from datetime import datetime from typing import Any, Dict, Optional, TYPE_CHECKING +from sqlalchemy import types + from superset.db_engine_specs.base import BaseEngineSpec -from superset.utils import core as utils if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn @@ -53,12 +56,13 @@ class CrateEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.TIMESTAMP: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.TIMESTAMP): return f"{dttm.timestamp() * 1000}" return None @classmethod - def alter_new_orm_column(cls, orm_col: "TableColumn") -> None: + def alter_new_orm_column(cls, orm_col: TableColumn) -> None: if orm_col.type == "TIMESTAMP": orm_col.python_date_format = "epoch_ms" diff --git a/superset/db_engine_specs/dremio.py b/superset/db_engine_specs/dremio.py index fddba00b5..0c773e709 100644 --- a/superset/db_engine_specs/dremio.py +++ b/superset/db_engine_specs/dremio.py @@ -17,8 +17,9 @@ from datetime import datetime from typing import Any, Dict, Optional +from sqlalchemy import types + from superset.db_engine_specs.base import BaseEngineSpec -from superset.utils import core as 
utils class DremioEngineSpec(BaseEngineSpec): @@ -46,10 +47,11 @@ class DremioEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): dttm_formatted = dttm.isoformat(sep=" ", timespec="milliseconds") return f"""TO_TIMESTAMP('{dttm_formatted}', 'YYYY-MM-DD HH24:MI:SS.FFF')""" return None diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py index b1a928122..756f74e82 100644 --- a/superset/db_engine_specs/drill.py +++ b/superset/db_engine_specs/drill.py @@ -18,11 +18,11 @@ from datetime import datetime from typing import Any, Dict, Optional from urllib import parse +from sqlalchemy import types from sqlalchemy.engine.url import URL from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.exceptions import SupersetDBAPIProgrammingError -from superset.utils import core as utils class DrillEngineSpec(BaseEngineSpec): @@ -59,10 +59,11 @@ class DrillEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"TO_DATE('{dttm.date().isoformat()}', 'yyyy-MM-dd')" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): datetime_formatted = dttm.isoformat(sep=" ", timespec="seconds") return f"""TO_TIMESTAMP('{datetime_formatted}', 'yyyy-MM-dd HH:mm:ss')""" return None diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py index 
6cdc9f85e..83829ec22 100644 --- a/superset/db_engine_specs/druid.py +++ b/superset/db_engine_specs/druid.py @@ -14,11 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + import json import logging from datetime import datetime from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING +from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector from superset import is_feature_enabled @@ -70,12 +74,12 @@ class DruidEngineSpec(BaseEngineSpec): } @classmethod - def alter_new_orm_column(cls, orm_col: "TableColumn") -> None: + def alter_new_orm_column(cls, orm_col: TableColumn) -> None: if orm_col.column_name == "__time": orm_col.is_dttm = True @staticmethod - def get_extra_params(database: "Database") -> Dict[str, Any]: + def get_extra_params(database: Database) -> Dict[str, Any]: """ For Druid, the path to a SSL certificate is placed in `connect_args`. 
@@ -102,10 +106,11 @@ class DruidEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"CAST(TIME_PARSE('{dttm.date().isoformat()}') AS DATE)" - if tt in (utils.TemporalType.DATETIME, utils.TemporalType.TIMESTAMP): + if isinstance(sqla_type, (types.DateTime, types.TIMESTAMP)): return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')""" return None diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py index c9eb287c9..1248287b8 100644 --- a/superset/db_engine_specs/duckdb.py +++ b/superset/db_engine_specs/duckdb.py @@ -21,11 +21,11 @@ from datetime import datetime from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING from flask_babel import gettext as __ +from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector from superset.db_engine_specs.base import BaseEngineSpec from superset.errors import SupersetErrorType -from superset.utils import core as utils if TYPE_CHECKING: # prevent circular imports @@ -67,8 +67,9 @@ class DuckDBEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt in (utils.TemporalType.TEXT, utils.TemporalType.DATETIME): + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, (types.String, types.DateTime)): return f"""'{dttm.isoformat(sep=" ", timespec="microseconds")}'""" return None diff --git a/superset/db_engine_specs/dynamodb.py b/superset/db_engine_specs/dynamodb.py index 06dcafbb5..c398a9c1d 100644 --- a/superset/db_engine_specs/dynamodb.py +++ b/superset/db_engine_specs/dynamodb.py @@ -17,8 +17,9 @@ from datetime import datetime from typing import Any, Dict, 
Optional +from sqlalchemy import types + from superset.db_engine_specs.base import BaseEngineSpec -from superset.utils import core as utils class DynamoDBEngineSpec(BaseEngineSpec): @@ -56,7 +57,9 @@ class DynamoDBEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt in (utils.TemporalType.TEXT, utils.TemporalType.DATETIME): + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, (types.String, types.DateTime)): return f"""'{dttm.isoformat(sep=" ", timespec="seconds")}'""" + return None diff --git a/superset/db_engine_specs/elasticsearch.py b/superset/db_engine_specs/elasticsearch.py index 12a5e21e2..b47a61d0c 100644 --- a/superset/db_engine_specs/elasticsearch.py +++ b/superset/db_engine_specs/elasticsearch.py @@ -19,13 +19,14 @@ from datetime import datetime from distutils.version import StrictVersion from typing import Any, Dict, Optional, Type +from sqlalchemy import types + from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.exceptions import ( SupersetDBAPIDatabaseError, SupersetDBAPIOperationalError, SupersetDBAPIProgrammingError, ) -from superset.utils import core as utils logger = logging.getLogger() @@ -68,7 +69,10 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho ) -> Optional[str]: db_extra = db_extra or {} - if target_type.upper() == utils.TemporalType.DATETIME: + + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.DateTime): es_version = db_extra.get("version") # The elasticsearch CAST function does not take effect for the time zone # setting. 
In elasticsearch7.8 and above, we can use the DATETIME_PARSE @@ -119,7 +123,9 @@ class OpenDistroEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - if target_type.upper() == utils.TemporalType.DATETIME: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.DateTime): return f"""'{dttm.isoformat(timespec="seconds")}'""" return None diff --git a/superset/db_engine_specs/firebird.py b/superset/db_engine_specs/firebird.py index 9254a3f2a..306a642dc 100644 --- a/superset/db_engine_specs/firebird.py +++ b/superset/db_engine_specs/firebird.py @@ -17,8 +17,9 @@ from datetime import datetime from typing import Any, Dict, Optional +from sqlalchemy import types + from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod -from superset.utils import core as utils class FirebirdEngineSpec(BaseEngineSpec): @@ -73,13 +74,14 @@ class FirebirdEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.TIMESTAMP: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): + return f"CAST('{dttm.date().isoformat()}' AS DATE)" + if isinstance(sqla_type, types.DateTime): dttm_formatted = dttm.isoformat(sep=" ") dttm_valid_precision = dttm_formatted[: len("YYYY-MM-DD HH:MM:SS.MMMM")] return f"CAST('{dttm_valid_precision}' AS TIMESTAMP)" - if tt == utils.TemporalType.DATE: - return f"CAST('{dttm.date().isoformat()}' AS DATE)" - if tt == utils.TemporalType.TIME: + if isinstance(sqla_type, types.Time): return f"CAST('{dttm.time().isoformat()}' AS TIME)" return None diff --git a/superset/db_engine_specs/firebolt.py b/superset/db_engine_specs/firebolt.py index 04f48b612..65cd71435 100644 --- a/superset/db_engine_specs/firebolt.py +++ 
b/superset/db_engine_specs/firebolt.py @@ -17,8 +17,9 @@ from datetime import datetime from typing import Any, Dict, Optional +from sqlalchemy import types + from superset.db_engine_specs.base import BaseEngineSpec -from superset.utils import core as utils class FireboltEngineSpec(BaseEngineSpec): @@ -44,13 +45,14 @@ class FireboltEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"CAST('{dttm.date().isoformat()}' AS DATE)" - if tt == utils.TemporalType.DATETIME: - return f"""CAST('{dttm.isoformat(timespec="seconds")}' AS DATETIME)""" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): return f"""CAST('{dttm.isoformat(timespec="seconds")}' AS TIMESTAMP)""" + if isinstance(sqla_type, types.DateTime): + return f"""CAST('{dttm.isoformat(timespec="seconds")}' AS DATETIME)""" return None @classmethod diff --git a/superset/db_engine_specs/hana.py b/superset/db_engine_specs/hana.py index 0cc55d08d..e579550b2 100644 --- a/superset/db_engine_specs/hana.py +++ b/superset/db_engine_specs/hana.py @@ -17,9 +17,10 @@ from datetime import datetime from typing import Any, Dict, Optional +from sqlalchemy import types + from superset.db_engine_specs.base import LimitMethod from superset.db_engine_specs.postgres import PostgresBaseEngineSpec -from superset.utils import core as utils class HanaEngineSpec(PostgresBaseEngineSpec): @@ -46,10 +47,11 @@ class HanaEngineSpec(PostgresBaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return 
f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): return f"""TO_TIMESTAMP('{dttm .isoformat(timespec="microseconds")}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""" return None diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 3c90975fa..c36e9ccba 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -30,7 +30,7 @@ import pandas as pd import pyarrow as pa import pyarrow.parquet as pq from flask import current_app, g -from sqlalchemy import Column, text +from sqlalchemy import Column, text, types from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL @@ -45,7 +45,6 @@ from superset.exceptions import SupersetException from superset.extensions import cache_manager from superset.models.sql_lab import Query from superset.sql_parse import ParsedQuery, Table -from superset.utils import core as utils if TYPE_CHECKING: # prevent circular imports @@ -249,10 +248,11 @@ class HiveEngineSpec(PrestoEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"CAST('{dttm.date().isoformat()}' AS DATE)" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): return f"""CAST('{dttm .isoformat(sep=" ", timespec="microseconds")}' AS TIMESTAMP)""" return None diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py index 84720d56c..5de1e690c 100644 --- a/superset/db_engine_specs/impala.py +++ b/superset/db_engine_specs/impala.py @@ -21,13 +21,13 @@ from datetime import datetime from typing import Any, Dict, List, Optional from flask import current_app +from sqlalchemy import 
types from sqlalchemy.engine.reflection import Inspector from sqlalchemy.orm import Session from superset.constants import QUERY_EARLY_CANCEL_KEY from superset.db_engine_specs.base import BaseEngineSpec from superset.models.sql_lab import Query -from superset.utils import core as utils logger = logging.getLogger(__name__) # Query 5543ffdf692b7d02:f78a944000000000: 3% Complete (17 out of 547) @@ -59,10 +59,11 @@ class ImpalaEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"CAST('{dttm.date().isoformat()}' AS DATE)" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)""" return None diff --git a/superset/db_engine_specs/kusto.py b/superset/db_engine_specs/kusto.py index 77485310e..9fddb23d2 100644 --- a/superset/db_engine_specs/kusto.py +++ b/superset/db_engine_specs/kusto.py @@ -14,9 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import re from datetime import datetime from typing import Any, Dict, List, Optional, Type +from sqlalchemy import types +from sqlalchemy.dialects.mssql.base import SMALLDATETIME + from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod from superset.db_engine_specs.exceptions import ( SupersetDBAPIDatabaseError, @@ -24,7 +28,7 @@ from superset.db_engine_specs.exceptions import ( SupersetDBAPIProgrammingError, ) from superset.sql_parse import ParsedQuery -from superset.utils import core as utils +from superset.utils.core import GenericDataType class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method @@ -59,6 +63,14 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed + column_type_mappings = ( + ( + re.compile(r"^smalldatetime.*", re.IGNORECASE), + SMALLDATETIME(), + GenericDataType.TEMPORAL, + ), + ) + @classmethod def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: # pylint: disable=import-outside-toplevel,import-error @@ -74,18 +86,19 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"CONVERT(DATE, '{dttm.date().isoformat()}', 23)" - if tt == utils.TemporalType.DATETIME: - datetime_formatted = dttm.isoformat(timespec="milliseconds") - return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)""" - if tt == utils.TemporalType.SMALLDATETIME: - datetime_formatted = dttm.isoformat(sep=" ", timespec="seconds") - return f"""CONVERT(SMALLDATETIME, '{datetime_formatted}', 20)""" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): datetime_formatted = 
dttm.isoformat(sep=" ", timespec="seconds") return f"""CONVERT(TIMESTAMP, '{datetime_formatted}', 20)""" + if isinstance(sqla_type, SMALLDATETIME): + datetime_formatted = dttm.isoformat(sep=" ", timespec="seconds") + return f"""CONVERT(SMALLDATETIME, '{datetime_formatted}', 20)""" + if isinstance(sqla_type, types.DateTime): + datetime_formatted = dttm.isoformat(timespec="milliseconds") + return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)""" return None @classmethod @@ -132,13 +145,12 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - if target_type.upper() in [ - utils.TemporalType.DATETIME, - utils.TemporalType.TIMESTAMP, - ]: - return f"""datetime({dttm.isoformat(timespec="microseconds")})""" - if target_type.upper() == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"""datetime({dttm.date().isoformat()})""" + if isinstance(sqla_type, types.DateTime): + return f"""datetime({dttm.isoformat(timespec="microseconds")})""" return None diff --git a/superset/db_engine_specs/kylin.py b/superset/db_engine_specs/kylin.py index dc3836c73..d76811e86 100644 --- a/superset/db_engine_specs/kylin.py +++ b/superset/db_engine_specs/kylin.py @@ -17,8 +17,9 @@ from datetime import datetime from typing import Any, Dict, Optional +from sqlalchemy import types + from superset.db_engine_specs.base import BaseEngineSpec -from superset.utils import core as utils class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method @@ -43,10 +44,11 @@ class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + 
if isinstance(sqla_type, types.Date): return f"CAST('{dttm.date().isoformat()}' AS DATE)" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): datetime_fomatted = dttm.isoformat(sep=" ", timespec="seconds") return f"""CAST('{datetime_fomatted}' AS TIMESTAMP)""" return None diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 158e73ade..8b38ec742 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -20,10 +20,12 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Pattern, Tuple from flask_babel import gettext as __ +from sqlalchemy import types +from sqlalchemy.dialects.mssql.base import SMALLDATETIME from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod from superset.errors import SupersetErrorType -from superset.utils import core as utils +from superset.utils.core import GenericDataType logger = logging.getLogger(__name__) @@ -70,6 +72,13 @@ class MssqlEngineSpec(BaseEngineSpec): "1969-12-29T00:00:00Z/P1W": "DATEADD(WEEK," " DATEDIFF(WEEK, 0, DATEADD(DAY, -1, {col})), 0)", } + column_type_mappings = ( + ( + re.compile(r"^smalldatetime.*", re.IGNORECASE), + SMALLDATETIME(), + GenericDataType.TEMPORAL, + ), + ) custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { CONNECTION_ACCESS_DENIED_REGEX: ( @@ -108,15 +117,16 @@ class MssqlEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"CONVERT(DATE, '{dttm.date().isoformat()}', 23)" - if tt == utils.TemporalType.DATETIME: - datetime_formatted = dttm.isoformat(timespec="milliseconds") - return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)""" - if tt == 
utils.TemporalType.SMALLDATETIME: + if isinstance(sqla_type, SMALLDATETIME): datetime_formatted = dttm.isoformat(sep=" ", timespec="seconds") return f"""CONVERT(SMALLDATETIME, '{datetime_formatted}', 20)""" + if isinstance(sqla_type, types.DateTime): + datetime_formatted = dttm.isoformat(timespec="milliseconds") + return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)""" return None @classmethod diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 1701d1e25..b873daff7 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -20,6 +20,7 @@ from typing import Any, Dict, Optional, Pattern, Tuple from urllib import parse from flask_babel import gettext as __ +from sqlalchemy import types from sqlalchemy.dialects.mysql import ( BIT, DECIMAL, @@ -34,15 +35,10 @@ from sqlalchemy.dialects.mysql import ( ) from sqlalchemy.engine.url import URL -from superset.db_engine_specs.base import ( - BaseEngineSpec, - BasicParametersMixin, - ColumnTypeMapping, -) +from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin from superset.errors import SupersetErrorType from superset.models.sql_lab import Query -from superset.utils import core as utils -from superset.utils.core import ColumnSpec, GenericDataType +from superset.utils.core import GenericDataType # Regular expressions to catch custom errors CONNECTION_ACCESS_DENIED_REGEX = re.compile( @@ -182,10 +178,11 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"STR_TO_DATE('{dttm.date().isoformat()}', '%Y-%m-%d')" - if tt == utils.TemporalType.DATETIME: + if isinstance(sqla_type, types.DateTime): datetime_formatted = dttm.isoformat(sep=" ", 
timespec="microseconds") return f"""STR_TO_DATE('{datetime_formatted}', '%Y-%m-%d %H:%i:%s.%f')""" return None @@ -232,23 +229,6 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): pass return message - @classmethod - def get_column_spec( - cls, - native_type: Optional[str], - db_extra: Optional[Dict[str, Any]] = None, - source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings, - ) -> Optional[ColumnSpec]: - - column_spec = super().get_column_spec(native_type) - if column_spec: - return column_spec - - return super().get_column_spec( - native_type, column_type_mappings=column_type_mappings - ) - @classmethod def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]: """ diff --git a/superset/db_engine_specs/oracle.py b/superset/db_engine_specs/oracle.py index ee04e49ff..4a219919b 100644 --- a/superset/db_engine_specs/oracle.py +++ b/superset/db_engine_specs/oracle.py @@ -17,8 +17,9 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Tuple +from sqlalchemy import types + from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod -from superset.utils import core as utils class OracleEngineSpec(BaseEngineSpec): @@ -44,15 +45,16 @@ class OracleEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')" - if tt == utils.TemporalType.DATETIME: - datetime_formatted = dttm.isoformat(timespec="seconds") - return f"""TO_DATE('{datetime_formatted}', 'YYYY-MM-DD"T"HH24:MI:SS')""" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): return f"""TO_TIMESTAMP('{dttm 
.isoformat(timespec="microseconds")}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""" + if isinstance(sqla_type, types.DateTime): + datetime_formatted = dttm.isoformat(timespec="seconds") + return f"""TO_DATE('{datetime_formatted}', 'YYYY-MM-DD"T"HH24:MI:SS')""" return None @classmethod diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 286b6e80a..cbe00ea58 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -21,20 +21,16 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Pattern, Set, Tuple, TYPE_CHECKING from flask_babel import gettext as __ -from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON +from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON from sqlalchemy.dialects.postgresql.base import PGInspector -from sqlalchemy.types import String +from sqlalchemy.types import Date, DateTime, String -from superset.db_engine_specs.base import ( - BaseEngineSpec, - BasicParametersMixin, - ColumnTypeMapping, -) +from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin from superset.errors import SupersetErrorType from superset.exceptions import SupersetException from superset.models.sql_lab import Query from superset.utils import core as utils -from superset.utils.core import ColumnSpec, GenericDataType +from superset.utils.core import GenericDataType if TYPE_CHECKING: from superset.models.core import Database # pragma: no cover @@ -185,7 +181,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): ), ( re.compile(r"^array.*", re.IGNORECASE), - lambda match: ARRAY(int(match[2])) if match[2] else String(), + String(), GenericDataType.STRING, ), ( @@ -238,10 +234,11 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - 
if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, Date): return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')" - if "TIMESTAMP" in tt or "DATETIME" in tt: + if isinstance(sqla_type, DateTime): dttm_formatted = dttm.isoformat(sep=" ", timespec="microseconds") return f"""TO_TIMESTAMP('{dttm_formatted}', 'YYYY-MM-DD HH24:MI:SS.US')""" return None @@ -270,23 +267,6 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): extra["engine_params"] = engine_params return extra - @classmethod - def get_column_spec( - cls, - native_type: Optional[str], - db_extra: Optional[Dict[str, Any]] = None, - source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings, - ) -> Optional[ColumnSpec]: - - column_spec = super().get_column_spec(native_type) - if column_spec: - return column_spec - - return super().get_column_spec( - native_type, column_type_mappings=column_type_mappings - ) - @classmethod def get_datatype(cls, type_code: Any) -> Optional[str]: # pylint: disable=import-outside-toplevel diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 2e8fc09fd..72931a85b 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -54,7 +54,7 @@ from sqlalchemy.sql.expression import ColumnClause, Select from superset import cache_manager, is_feature_enabled from superset.common.db_query_status import QueryStatus from superset.databases.utils import make_url_safe -from superset.db_engine_specs.base import BaseEngineSpec, ColumnTypeMapping +from superset.db_engine_specs.base import BaseEngineSpec from superset.errors import SupersetErrorType from superset.exceptions import SupersetTemplateException from superset.models.sql_lab import Query @@ -70,7 +70,7 @@ from superset.models.sql_types.presto_sql_types import ( from superset.result_set import 
destringify from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils -from superset.utils.core import ColumnSpec, GenericDataType +from superset.utils.core import GenericDataType if TYPE_CHECKING: # prevent circular imports @@ -165,6 +165,92 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): A base class that share common functions between Presto and Trino """ + column_type_mappings = ( + ( + re.compile(r"^boolean.*", re.IGNORECASE), + types.BOOLEAN(), + GenericDataType.BOOLEAN, + ), + ( + re.compile(r"^tinyint.*", re.IGNORECASE), + TinyInteger(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^smallint.*", re.IGNORECASE), + types.SmallInteger(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^integer.*", re.IGNORECASE), + types.INTEGER(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^bigint.*", re.IGNORECASE), + types.BigInteger(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^real.*", re.IGNORECASE), + types.FLOAT(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^double.*", re.IGNORECASE), + types.FLOAT(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^decimal.*", re.IGNORECASE), + types.DECIMAL(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE), + lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(), + GenericDataType.STRING, + ), + ( + re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE), + lambda match: types.CHAR(int(match[2])) if match[2] else types.String(), + GenericDataType.STRING, + ), + ( + re.compile(r"^varbinary.*", re.IGNORECASE), + types.VARBINARY(), + GenericDataType.STRING, + ), + ( + re.compile(r"^json.*", re.IGNORECASE), + types.JSON(), + GenericDataType.STRING, + ), + ( + re.compile(r"^date.*", re.IGNORECASE), + types.Date(), + GenericDataType.TEMPORAL, + ), + ( + re.compile(r"^timestamp.*", re.IGNORECASE), + types.TIMESTAMP(), + GenericDataType.TEMPORAL, + ), + ( + re.compile(r"^interval.*", 
re.IGNORECASE), + Interval(), + GenericDataType.TEMPORAL, + ), + ( + re.compile(r"^time.*", re.IGNORECASE), + types.Time(), + GenericDataType.TEMPORAL, + ), + (re.compile(r"^array.*", re.IGNORECASE), Array(), GenericDataType.STRING), + (re.compile(r"^map.*", re.IGNORECASE), Map(), GenericDataType.STRING), + (re.compile(r"^row.*", re.IGNORECASE), Row(), GenericDataType.STRING), + ) + # pylint: disable=line-too-long _time_grain_expressions = { None: "{col}", @@ -199,14 +285,13 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): Superset only defines time zone naive `datetime` objects, though this method handles both time zone naive and aware conversions. """ - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"DATE '{dttm.date().isoformat()}'" - if tt in ( - utils.TemporalType.TIMESTAMP, - utils.TemporalType.TIMESTAMP_WITH_TIME_ZONE, - ): + if isinstance(sqla_type, types.TIMESTAMP): return f"""TIMESTAMP '{dttm.isoformat(timespec="microseconds", sep=" ")}'""" + return None @classmethod @@ -827,92 +912,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): full_table = "{}.{}".format(quote(schema), full_table) return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall() - column_type_mappings = ( - ( - re.compile(r"^boolean.*", re.IGNORECASE), - types.BOOLEAN, - GenericDataType.BOOLEAN, - ), - ( - re.compile(r"^tinyint.*", re.IGNORECASE), - TinyInteger(), - GenericDataType.NUMERIC, - ), - ( - re.compile(r"^smallint.*", re.IGNORECASE), - types.SMALLINT(), - GenericDataType.NUMERIC, - ), - ( - re.compile(r"^integer.*", re.IGNORECASE), - types.INTEGER(), - GenericDataType.NUMERIC, - ), - ( - re.compile(r"^bigint.*", re.IGNORECASE), - types.BIGINT(), - GenericDataType.NUMERIC, - ), - ( - re.compile(r"^real.*", re.IGNORECASE), - types.FLOAT(), - GenericDataType.NUMERIC, - ), - ( - re.compile(r"^double.*", re.IGNORECASE), - types.FLOAT(), 
- GenericDataType.NUMERIC, - ), - ( - re.compile(r"^decimal.*", re.IGNORECASE), - types.DECIMAL(), - GenericDataType.NUMERIC, - ), - ( - re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE), - lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(), - GenericDataType.STRING, - ), - ( - re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE), - lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(), - GenericDataType.STRING, - ), - ( - re.compile(r"^varbinary.*", re.IGNORECASE), - types.VARBINARY(), - GenericDataType.STRING, - ), - ( - re.compile(r"^json.*", re.IGNORECASE), - types.JSON(), - GenericDataType.STRING, - ), - ( - re.compile(r"^date.*", re.IGNORECASE), - types.DATE(), - GenericDataType.TEMPORAL, - ), - ( - re.compile(r"^timestamp.*", re.IGNORECASE), - types.TIMESTAMP(), - GenericDataType.TEMPORAL, - ), - ( - re.compile(r"^interval.*", re.IGNORECASE), - Interval(), - GenericDataType.TEMPORAL, - ), - ( - re.compile(r"^time.*", re.IGNORECASE), - types.Time(), - GenericDataType.TEMPORAL, - ), - (re.compile(r"^array.*", re.IGNORECASE), Array(), GenericDataType.STRING), - (re.compile(r"^map.*", re.IGNORECASE), Map(), GenericDataType.STRING), - (re.compile(r"^row.*", re.IGNORECASE), Row(), GenericDataType.STRING), - ) - @classmethod def get_columns( cls, inspector: Inspector, table_name: str, schema: Optional[str] @@ -1282,24 +1281,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): return error_dict.get("message", _("Unknown Presto Error")) return utils.error_msg_from_exception(ex) - @classmethod - def get_column_spec( - cls, - native_type: Optional[str], - db_extra: Optional[Dict[str, Any]] = None, - source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - column_type_mappings: Tuple[ColumnTypeMapping, ...] 
= column_type_mappings, - ) -> Optional[ColumnSpec]: - - column_spec = super().get_column_spec( - native_type, column_type_mappings=column_type_mappings - ) - - if column_spec: - return column_spec - - return super().get_column_spec(native_type) - @classmethod def has_implicit_cancel(cls) -> bool: """ diff --git a/superset/db_engine_specs/rockset.py b/superset/db_engine_specs/rockset.py index 606b860a5..3778c5275 100644 --- a/superset/db_engine_specs/rockset.py +++ b/superset/db_engine_specs/rockset.py @@ -17,8 +17,9 @@ from datetime import datetime from typing import Any, Dict, Optional, TYPE_CHECKING +from sqlalchemy import types + from superset.db_engine_specs.base import BaseEngineSpec -from superset.utils import core as utils if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn @@ -53,15 +54,16 @@ class RocksetEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"DATE '{dttm.date().isoformat()}'" - if tt == utils.TemporalType.DATETIME: - dttm_formatted = dttm.isoformat(sep=" ", timespec="microseconds") - return f"""DATETIME '{dttm_formatted}'""" - if tt == utils.TemporalType.TIMESTAMP: + if isinstance(sqla_type, types.TIMESTAMP): dttm_formatted = dttm.isoformat(timespec="microseconds") return f"""TIMESTAMP '{dttm_formatted}'""" + if isinstance(sqla_type, types.DateTime): + dttm_formatted = dttm.isoformat(sep=" ", timespec="microseconds") + return f"""DATETIME '{dttm_formatted}'""" return None @classmethod diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index cd083a76b..419e0a065 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -28,6 +28,7 @@ from cryptography.hazmat.primitives import 
serialization from flask import current_app from flask_babel import gettext as __ from marshmallow import fields, Schema +from sqlalchemy import types from sqlalchemy.engine.url import URL from typing_extensions import TypedDict @@ -37,7 +38,6 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType from superset.db_engine_specs.postgres import PostgresBaseEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query -from superset.utils import core as utils if TYPE_CHECKING: from superset.models.core import Database @@ -157,13 +157,14 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt == utils.TemporalType.DATE: + sqla_type = cls.get_sqla_column_type(target_type) + + if isinstance(sqla_type, types.Date): return f"TO_DATE('{dttm.date().isoformat()}')" - if tt == utils.TemporalType.DATETIME: - return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)""" - if utils.TemporalType.TIMESTAMP in tt: + if isinstance(sqla_type, types.TIMESTAMP): return f"""TO_TIMESTAMP('{dttm.isoformat(timespec="microseconds")}')""" + if isinstance(sqla_type, types.DateTime): + return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)""" return None @staticmethod diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py index 8bd2d081e..a41414329 100644 --- a/superset/db_engine_specs/sqlite.py +++ b/superset/db_engine_specs/sqlite.py @@ -19,11 +19,11 @@ from datetime import datetime from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING from flask_babel import gettext as __ +from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector from superset.db_engine_specs.base import BaseEngineSpec from superset.errors import SupersetErrorType -from superset.utils 
import core as utils if TYPE_CHECKING: # prevent circular imports @@ -76,12 +76,8 @@ class SqliteEngineSpec(BaseEngineSpec): def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - tt = target_type.upper() - if tt in ( - utils.TemporalType.TEXT, - utils.TemporalType.DATETIME, - utils.TemporalType.TIMESTAMP, - ): + sqla_type = cls.get_sqla_column_type(target_type) + if isinstance(sqla_type, (types.String, types.DateTime)): return f"""'{dttm.isoformat(sep=" ", timespec="seconds")}'""" return None diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 3b23f7987..82c7566f3 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -17,8 +17,6 @@ from __future__ import annotations import logging -import re -from datetime import datetime from typing import Any, Dict, Optional, Type, TYPE_CHECKING import simplejson as json @@ -49,29 +47,6 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): engine = "trino" engine_name = "Trino" - @classmethod - def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: - """ - Convert a Python `datetime` object to a SQL expression. - :param target_type: The target type of expression - :param dttm: The datetime object - :param db_extra: The database extra object - :return: The SQL expression - Superset only defines time zone naive `datetime` objects, though this method - handles both time zone naive and aware conversions. 
- """ - tt = target_type.upper() - if tt == utils.TemporalType.DATE: - return f"DATE '{dttm.date().isoformat()}'" - if re.sub(r"\(\d\)", "", tt) in ( - utils.TemporalType.TIMESTAMP, - utils.TemporalType.TIMESTAMP_WITH_TIME_ZONE, - ): - return f"""TIMESTAMP '{dttm.isoformat(timespec="microseconds", sep=" ")}'""" - return None - @classmethod def extra_table_metadata( cls, diff --git a/superset/utils/core.py b/superset/utils/core.py index 86486ad22..af6a2519e 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -364,21 +364,6 @@ class RowLevelSecurityFilterType(str, Enum): BASE = "Base" -class TemporalType(str, Enum): - """ - Supported temporal types - """ - - DATE = "DATE" - DATETIME = "DATETIME" - SMALLDATETIME = "SMALLDATETIME" - TEXT = "TEXT" - TIME = "TIME" - TIME_WITH_TIME_ZONE = "TIME WITH TIME ZONE" - TIMESTAMP = "TIMESTAMP" - TIMESTAMP_WITH_TIME_ZONE = "TIMESTAMP WITH TIME ZONE" - - class ColumnTypeSource(Enum): GET_TABLE = 1 CURSOR_DESCRIPION = 2 diff --git a/tests/integration_tests/db_engine_specs/base_tests.py b/tests/integration_tests/db_engine_specs/base_tests.py index 6496d4609..e20ea35ae 100644 --- a/tests/integration_tests/db_engine_specs/base_tests.py +++ b/tests/integration_tests/db_engine_specs/base_tests.py @@ -22,7 +22,6 @@ from tests.integration_tests.test_app import app from tests.integration_tests.base_tests import SupersetTestCase from superset.db_engine_specs.base import BaseEngineSpec from superset.models.core import Database -from superset.utils.core import GenericDataType class TestDbEngineSpec(SupersetTestCase): @@ -37,16 +36,3 @@ class TestDbEngineSpec(SupersetTestCase): main = Database(database_name="test_database", sqlalchemy_uri="sqlite://") limited = engine_spec_class.apply_limit_to_sql(sql, limit, main, force) self.assertEqual(expected_sql, limited) - - -def assert_generic_types( - spec: Type[BaseEngineSpec], - type_expectations: Tuple[Tuple[str, GenericDataType], ...], -) -> None: - for type_str, expected_type in 
type_expectations: - column_spec = spec.get_column_spec(type_str) - assert column_spec is not None - actual_type = column_spec.generic_type - assert ( - actual_type == expected_type - ), f"{type_str} should be {expected_type.name} but is {actual_type.name}" diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index 5f0819258..574a2b75e 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -48,23 +48,6 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec): actual = BigQueryEngineSpec.make_label_compatible(column(original).name) self.assertEqual(actual, expected) - def test_convert_dttm(self): - """ - DB Eng Specs (bigquery): Test conversion to date time - """ - dttm = self.get_dttm() - test_cases = { - "DATE": "CAST('2019-01-02' AS DATE)", - "DATETIME": "CAST('2019-01-02T03:04:05.678900' AS DATETIME)", - "TIMESTAMP": "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)", - "TIME": "CAST('03:04:05.678900' AS TIME)", - "UNKNOWNTYPE": None, - } - - for target_type, expected in test_cases.items(): - actual = BigQueryEngineSpec.convert_dttm(target_type, dttm) - self.assertEqual(actual, expected) - def test_timegrain_expressions(self): """ DB Eng Specs (bigquery): Test time grain expressions diff --git a/tests/integration_tests/db_engine_specs/crate_tests.py b/tests/integration_tests/db_engine_specs/crate_tests.py deleted file mode 100644 index 7c86b34b3..000000000 --- a/tests/integration_tests/db_engine_specs/crate_tests.py +++ /dev/null @@ -1,53 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from superset.connectors.sqla.models import SqlaTable, TableColumn -from superset.db_engine_specs.crate import CrateEngineSpec -from superset.models.core import Database -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec - - -class TestCrateDbEngineSpec(TestDbEngineSpec): - def test_convert_dttm(self): - """ - DB Eng Specs (crate): Test conversion to date time - """ - dttm = self.get_dttm() - assert CrateEngineSpec.convert_dttm("TIMESTAMP", dttm) == str( - dttm.timestamp() * 1000 - ) - - def test_epoch_to_dttm(self): - """ - DB Eng Specs (crate): Test epoch to dttm - """ - assert CrateEngineSpec.epoch_to_dttm() == "{col} * 1000" - - def test_epoch_ms_to_dttm(self): - """ - DB Eng Specs (crate): Test epoch ms to dttm - """ - assert CrateEngineSpec.epoch_ms_to_dttm() == "{col}" - - def test_alter_new_orm_column(self): - """ - DB Eng Specs (crate): Test alter orm column - """ - database = Database(database_name="crate", sqlalchemy_uri="crate://db") - tbl = SqlaTable(table_name="druid_tbl", database=database) - col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl) - CrateEngineSpec.alter_new_orm_column(col) - assert col.python_date_format == "epoch_ms" diff --git a/tests/integration_tests/db_engine_specs/databricks_tests.py b/tests/integration_tests/db_engine_specs/databricks_tests.py index c2d57831a..5ff20b734 100644 --- 
a/tests/integration_tests/db_engine_specs/databricks_tests.py +++ b/tests/integration_tests/db_engine_specs/databricks_tests.py @@ -14,18 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from textwrap import dedent from unittest import mock -from sqlalchemy import column, literal_column - -from superset.constants import USER_AGENT from superset.db_engine_specs import get_engine_spec from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import ( - assert_generic_types, - TestDbEngineSpec, -) +from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec from tests.integration_tests.fixtures.certificates import ssl_certificate from tests.integration_tests.fixtures.database import default_db_extra diff --git a/tests/integration_tests/db_engine_specs/druid_tests.py b/tests/integration_tests/db_engine_specs/druid_tests.py deleted file mode 100644 index 232787ba6..000000000 --- a/tests/integration_tests/db_engine_specs/druid_tests.py +++ /dev/null @@ -1,78 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from unittest import mock - -from sqlalchemy import column - -from superset.db_engine_specs.druid import DruidEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec -from tests.integration_tests.fixtures.certificates import ssl_certificate -from tests.integration_tests.fixtures.database import default_db_extra - - -class TestDruidDbEngineSpec(TestDbEngineSpec): - def test_convert_dttm(self): - dttm = self.get_dttm() - - self.assertEqual( - DruidEngineSpec.convert_dttm("DATETIME", dttm), - "TIME_PARSE('2019-01-02T03:04:05')", - ) - - self.assertEqual( - DruidEngineSpec.convert_dttm("TIMESTAMP", dttm), - "TIME_PARSE('2019-01-02T03:04:05')", - ) - - self.assertEqual( - DruidEngineSpec.convert_dttm("DATE", dttm), - "CAST(TIME_PARSE('2019-01-02') AS DATE)", - ) - - def test_timegrain_expressions(self): - """ - DB Eng Specs (druid): Test time grain expressions - """ - col = "__time" - sqla_col = column(col) - test_cases = { - "PT1S": f"TIME_FLOOR(CAST({col} AS TIMESTAMP), 'PT1S')", - "PT5M": f"TIME_FLOOR(CAST({col} AS TIMESTAMP), 'PT5M')", - "P1W/1970-01-03T00:00:00Z": f"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST({col} AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', 5)", - "1969-12-28T00:00:00Z/P1W": f"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST({col} AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', -1)", - } - for grain, expected in test_cases.items(): - actual = DruidEngineSpec.get_timestamp_expr( - col=sqla_col, pdf=None, time_grain=grain - ) - self.assertEqual(str(actual), expected) - - def test_extras_without_ssl(self): - db = mock.Mock() - db.extra = default_db_extra - db.server_cert = None - extras = DruidEngineSpec.get_extra_params(db) - assert "connect_args" not in extras["engine_params"] - - def test_extras_with_ssl(self): - db = mock.Mock() - db.extra = default_db_extra - db.server_cert = ssl_certificate - extras = DruidEngineSpec.get_extra_params(db) - connect_args = extras["engine_params"]["connect_args"] - assert connect_args["scheme"] == 
"https" - assert "ssl_verify_cert" in connect_args diff --git a/tests/integration_tests/db_engine_specs/elasticsearch_tests.py b/tests/integration_tests/db_engine_specs/elasticsearch_tests.py deleted file mode 100644 index 7dd515779..000000000 --- a/tests/integration_tests/db_engine_specs/elasticsearch_tests.py +++ /dev/null @@ -1,104 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from unittest.mock import MagicMock - -import pytest -from sqlalchemy import column - -from superset.db_engine_specs.elasticsearch import ( - ElasticSearchEngineSpec, - OpenDistroEngineSpec, -) -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec - - -class TestElasticSearchDbEngineSpec(TestDbEngineSpec): - @pytest.fixture(autouse=True) - def inject_fixtures(self, caplog): - self._caplog = caplog - - def test_convert_dttm(self): - dttm = self.get_dttm() - - self.assertEqual( - ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=None), - "CAST('2019-01-02T03:04:05' AS DATETIME)", - ) - - def test_convert_dttm2(self): - """ - ES 7.8 and above versions need to use the DATETIME_PARSE function to - solve the time zone problem - """ - dttm = self.get_dttm() - db_extra = {"version": "7.8"} - - self.assertEqual( - ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra), - "DATETIME_PARSE('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')", - ) - - def test_convert_dttm3(self): - dttm = self.get_dttm() - db_extra = {"version": 7.8} - - self.assertEqual( - ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra), - "CAST('2019-01-02T03:04:05' AS DATETIME)", - ) - - self.assertNotEqual( - ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra), - "DATETIME_PARSE('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')", - ) - - self.assertIn("Unexpected error while convert es_version", self._caplog.text) - - def test_opendistro_convert_dttm(self): - """ - DB Eng Specs (opendistro): Test convert_dttm - """ - dttm = self.get_dttm() - - self.assertEqual( - OpenDistroEngineSpec.convert_dttm("DATETIME", dttm, db_extra=None), - "'2019-01-02T03:04:05'", - ) - - def test_opendistro_sqla_column_label(self): - """ - DB Eng Specs (opendistro): Test column label - """ - test_cases = { - "Col": "Col", - "Col.keyword": "Col_keyword", - } - for original, expected in test_cases.items(): - actual = 
OpenDistroEngineSpec.make_label_compatible(column(original).name) - self.assertEqual(actual, expected) - - def test_opendistro_strip_comments(self): - """ - DB Eng Specs (opendistro): Test execute sql strip comments - """ - mock_cursor = MagicMock() - mock_cursor.execute.return_value = [] - - OpenDistroEngineSpec.execute( - mock_cursor, "-- some comment \nSELECT 1\n --other comment" - ) - mock_cursor.execute.assert_called_once_with("SELECT 1\n") diff --git a/tests/integration_tests/db_engine_specs/firebird_tests.py b/tests/integration_tests/db_engine_specs/firebird_tests.py deleted file mode 100644 index 5e00e2ed4..000000000 --- a/tests/integration_tests/db_engine_specs/firebird_tests.py +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from datetime import datetime -from unittest import mock - -import pytest - -from superset.db_engine_specs.firebird import FirebirdEngineSpec - -grain_expressions = { - None: "timestamp_column", - "PT1S": ( - "CAST(CAST(timestamp_column AS DATE) " - "|| ' ' " - "|| EXTRACT(HOUR FROM timestamp_column) " - "|| ':' " - "|| EXTRACT(MINUTE FROM timestamp_column) " - "|| ':' " - "|| FLOOR(EXTRACT(SECOND FROM timestamp_column)) AS TIMESTAMP)" - ), - "PT1M": ( - "CAST(CAST(timestamp_column AS DATE) " - "|| ' ' " - "|| EXTRACT(HOUR FROM timestamp_column) " - "|| ':' " - "|| EXTRACT(MINUTE FROM timestamp_column) " - "|| ':00' AS TIMESTAMP)" - ), - "P1D": "CAST(timestamp_column AS DATE)", - "P1M": ( - "CAST(EXTRACT(YEAR FROM timestamp_column) " - "|| '-' " - "|| EXTRACT(MONTH FROM timestamp_column) " - "|| '-01' AS DATE)" - ), - "P1Y": "CAST(EXTRACT(YEAR FROM timestamp_column) || '-01-01' AS DATE)", -} - - -@pytest.mark.parametrize("grain,expected", grain_expressions.items()) -def test_time_grain_expressions(grain, expected): - assert ( - FirebirdEngineSpec._time_grain_expressions[grain].format(col="timestamp_column") - == expected - ) - - -def test_epoch_to_dttm(): - assert ( - FirebirdEngineSpec.epoch_to_dttm().format(col="timestamp_column") - == "DATEADD(second, timestamp_column, CAST('00:00:00' AS TIMESTAMP))" - ) - - -def test_convert_dttm(): - dttm = datetime(2021, 1, 1) - assert ( - FirebirdEngineSpec.convert_dttm("timestamp", dttm) - == "CAST('2021-01-01 00:00:00' AS TIMESTAMP)" - ) - assert ( - FirebirdEngineSpec.convert_dttm("TIMESTAMP", dttm) - == "CAST('2021-01-01 00:00:00' AS TIMESTAMP)" - ) - assert FirebirdEngineSpec.convert_dttm("TIME", dttm) == "CAST('00:00:00' AS TIME)" - assert FirebirdEngineSpec.convert_dttm("DATE", dttm) == "CAST('2021-01-01' AS DATE)" - assert FirebirdEngineSpec.convert_dttm("STRING", dttm) is None diff --git a/tests/integration_tests/db_engine_specs/firebolt_tests.py b/tests/integration_tests/db_engine_specs/firebolt_tests.py deleted 
file mode 100644 index 793b32970..000000000 --- a/tests/integration_tests/db_engine_specs/firebolt_tests.py +++ /dev/null @@ -1,39 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from superset.db_engine_specs.firebolt import FireboltEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec - - -class TestFireboltDbEngineSpec(TestDbEngineSpec): - def test_convert_dttm(self): - dttm = self.get_dttm() - test_cases = { - "DATE": "CAST('2019-01-02' AS DATE)", - "DATETIME": "CAST('2019-01-02T03:04:05' AS DATETIME)", - "TIMESTAMP": "CAST('2019-01-02T03:04:05' AS TIMESTAMP)", - "UNKNOWNTYPE": None, - } - - for target_type, expected in test_cases.items(): - actual = FireboltEngineSpec.convert_dttm(target_type, dttm) - self.assertEqual(actual, expected) - - def test_epoch_to_dttm(self): - assert ( - FireboltEngineSpec.epoch_to_dttm().format(col="timestamp_column") - == "from_unixtime(timestamp_column)" - ) diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index b39f26589..b63f64ab0 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -150,15 +150,6 @@ 
def test_hive_error_msg(): ) -def test_convert_dttm(): - dttm = datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f") - assert HiveEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)" - assert ( - HiveEngineSpec.convert_dttm("TIMESTAMP", dttm) - == "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)" - ) - - def test_df_to_csv() -> None: with pytest.raises(SupersetException): HiveEngineSpec.df_to_sql( diff --git a/tests/integration_tests/db_engine_specs/kylin_tests.py b/tests/integration_tests/db_engine_specs/kylin_tests.py deleted file mode 100644 index a607565d5..000000000 --- a/tests/integration_tests/db_engine_specs/kylin_tests.py +++ /dev/null @@ -1,32 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from superset.db_engine_specs.kylin import KylinEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec - - -class TestKylinDbEngineSpec(TestDbEngineSpec): - def test_convert_dttm(self): - dttm = self.get_dttm() - - self.assertEqual( - KylinEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)" - ) - - self.assertEqual( - KylinEngineSpec.convert_dttm("TIMESTAMP", dttm), - "CAST('2019-01-02 03:04:05' AS TIMESTAMP)", - ) diff --git a/tests/integration_tests/db_engine_specs/mysql_tests.py b/tests/integration_tests/db_engine_specs/mysql_tests.py index b069bba69..36b41222b 100644 --- a/tests/integration_tests/db_engine_specs/mysql_tests.py +++ b/tests/integration_tests/db_engine_specs/mysql_tests.py @@ -21,12 +21,7 @@ from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.models.sql_lab import Query -from superset.utils.core import GenericDataType -from tests.integration_tests.db_engine_specs.base_tests import ( - assert_generic_types, - TestDbEngineSpec, -) +from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec): @@ -38,19 +33,6 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec): self.assertEqual("TINY", MySQLEngineSpec.get_datatype(1)) self.assertEqual("VARCHAR", MySQLEngineSpec.get_datatype(15)) - def test_convert_dttm(self): - dttm = self.get_dttm() - - self.assertEqual( - MySQLEngineSpec.convert_dttm("DATE", dttm), - "STR_TO_DATE('2019-01-02', '%Y-%m-%d')", - ) - - self.assertEqual( - MySQLEngineSpec.convert_dttm("DATETIME", dttm), - "STR_TO_DATE('2019-01-02 03:04:05.678900', '%Y-%m-%d %H:%i:%s.%f')", - ) - def test_column_datatype_to_string(self): test_cases = ( (DATE(), "DATE"), @@ -69,32 +51,6 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec): ) 
self.assertEqual(actual, expected) - def test_generic_type(self): - type_expectations = ( - # Numeric - ("TINYINT", GenericDataType.NUMERIC), - ("SMALLINT", GenericDataType.NUMERIC), - ("MEDIUMINT", GenericDataType.NUMERIC), - ("INT", GenericDataType.NUMERIC), - ("BIGINT", GenericDataType.NUMERIC), - ("DECIMAL", GenericDataType.NUMERIC), - ("FLOAT", GenericDataType.NUMERIC), - ("DOUBLE", GenericDataType.NUMERIC), - ("BIT", GenericDataType.NUMERIC), - # String - ("CHAR", GenericDataType.STRING), - ("VARCHAR", GenericDataType.STRING), - ("TINYTEXT", GenericDataType.STRING), - ("MEDIUMTEXT", GenericDataType.STRING), - ("LONGTEXT", GenericDataType.STRING), - # Temporal - ("DATE", GenericDataType.TEMPORAL), - ("DATETIME", GenericDataType.TEMPORAL), - ("TIMESTAMP", GenericDataType.TEMPORAL), - ("TIME", GenericDataType.TEMPORAL), - ) - assert_generic_types(MySQLEngineSpec, type_expectations) - def test_extract_error_message(self): from MySQLdb._exceptions import OperationalError @@ -239,22 +195,3 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec): }, ) ] - - @unittest.mock.patch("sqlalchemy.engine.Engine.connect") - def test_get_cancel_query_id(self, engine_mock): - query = Query() - cursor_mock = engine_mock.return_value.__enter__.return_value - cursor_mock.fetchone.return_value = [123] - assert MySQLEngineSpec.get_cancel_query_id(cursor_mock, query) == 123 - - @unittest.mock.patch("sqlalchemy.engine.Engine.connect") - def test_cancel_query(self, engine_mock): - query = Query() - cursor_mock = engine_mock.return_value.__enter__.return_value - assert MySQLEngineSpec.cancel_query(cursor_mock, query, 123) is True - - @unittest.mock.patch("sqlalchemy.engine.Engine.connect") - def test_cancel_query_failed(self, engine_mock): - query = Query() - cursor_mock = engine_mock.raiseError.side_effect = Exception() - assert MySQLEngineSpec.cancel_query(cursor_mock, query, 123) is False diff --git a/tests/integration_tests/db_engine_specs/oracle_tests.py 
b/tests/integration_tests/db_engine_specs/oracle_tests.py deleted file mode 100644 index b2f4f9f23..000000000 --- a/tests/integration_tests/db_engine_specs/oracle_tests.py +++ /dev/null @@ -1,87 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from unittest import mock - -import pytest -from sqlalchemy import column -from sqlalchemy.dialects import oracle -from sqlalchemy.dialects.oracle import DATE, NVARCHAR, VARCHAR - -from superset.db_engine_specs.oracle import OracleEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec - - -class TestOracleDbEngineSpec(TestDbEngineSpec): - def test_oracle_sqla_column_name_length_exceeded(self): - col = column("This_Is_32_Character_Column_Name") - label = OracleEngineSpec.make_label_compatible(col.name) - self.assertEqual(label.quote, True) - label_expected = "3b26974078683be078219674eeb8f5" - self.assertEqual(label, label_expected) - - def test_oracle_time_expression_reserved_keyword_1m_grain(self): - col = column("decimal") - expr = OracleEngineSpec.get_timestamp_expr(col, None, "P1M") - result = str(expr.compile(dialect=oracle.dialect())) - self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')") - dttm = self.get_dttm() - - def 
test_column_datatype_to_string(self): - test_cases = ( - (DATE(), "DATE"), - (VARCHAR(length=255), "VARCHAR(255 CHAR)"), - (VARCHAR(length=255, collation="utf8"), "VARCHAR(255 CHAR)"), - (NVARCHAR(length=128), "NVARCHAR2(128)"), - ) - - for original, expected in test_cases: - actual = OracleEngineSpec.column_datatype_to_string( - original, oracle.dialect() - ) - self.assertEqual(actual, expected) - - def test_fetch_data_no_description(self): - cursor = mock.MagicMock() - cursor.description = [] - assert OracleEngineSpec.fetch_data(cursor) == [] - - def test_fetch_data(self): - cursor = mock.MagicMock() - result = ["a", "b"] - cursor.fetchall.return_value = result - assert OracleEngineSpec.fetch_data(cursor) == result - - -@pytest.mark.parametrize( - "date_format,expected", - [ - ("DATE", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"), - ("DATETIME", """TO_DATE('2019-01-02T03:04:05', 'YYYY-MM-DD"T"HH24:MI:SS')"""), - ( - "TIMESTAMP", - """TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""", - ), - ( - "timestamp", - """TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""", - ), - ("Other", None), - ], -) -def test_convert_dttm(date_format, expected): - dttm = TestOracleDbEngineSpec.get_dttm() - assert OracleEngineSpec.convert_dttm(date_format, dttm) == expected diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index 03b3e5763..a6145432c 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -24,11 +24,7 @@ from superset.db_engine_specs import load_engine_specs from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query -from superset.utils.core import GenericDataType -from tests.integration_tests.db_engine_specs.base_tests import ( - 
assert_generic_types, - TestDbEngineSpec, -) +from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec from tests.integration_tests.fixtures.certificates import ssl_certificate from tests.integration_tests.fixtures.database import default_db_extra @@ -100,29 +96,6 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec): result = str(expr.compile(None, dialect=postgresql.dialect())) self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")") - def test_convert_dttm(self): - """ - DB Eng Specs (postgres): Test conversion to date time - """ - dttm = self.get_dttm() - - self.assertEqual( - PostgresEngineSpec.convert_dttm("DATE", dttm), - "TO_DATE('2019-01-02', 'YYYY-MM-DD')", - ) - - self.assertEqual( - PostgresEngineSpec.convert_dttm("TIMESTAMP", dttm), - "TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')", - ) - - self.assertEqual( - PostgresEngineSpec.convert_dttm("DATETIME", dttm), - "TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')", - ) - - self.assertEqual(PostgresEngineSpec.convert_dttm("TIME", dttm), None) - def test_empty_dbapi_cursor_description(self): """ DB Eng Specs (postgres): Test empty cursor description (no columns) @@ -541,28 +514,3 @@ def test_base_parameters_mixin(): }, "required": ["database", "host", "port", "username"], } - - -def test_generic_type(): - type_expectations = ( - # Numeric - ("SMALLINT", GenericDataType.NUMERIC), - ("INTEGER", GenericDataType.NUMERIC), - ("BIGINT", GenericDataType.NUMERIC), - ("DECIMAL", GenericDataType.NUMERIC), - ("NUMERIC", GenericDataType.NUMERIC), - ("REAL", GenericDataType.NUMERIC), - ("DOUBLE PRECISION", GenericDataType.NUMERIC), - ("MONEY", GenericDataType.NUMERIC), - # String - ("CHAR", GenericDataType.STRING), - ("VARCHAR", GenericDataType.STRING), - ("TEXT", GenericDataType.STRING), - # Temporal - ("DATE", GenericDataType.TEMPORAL), - ("TIMESTAMP", GenericDataType.TEMPORAL), - ("TIME", GenericDataType.TEMPORAL), - # Boolean - ("BOOLEAN", 
GenericDataType.BOOLEAN), - ) - assert_generic_types(PostgresEngineSpec, type_expectations) diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 9099dbb7d..78b552ecb 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -25,7 +25,6 @@ from sqlalchemy.sql import select from superset.db_engine_specs.presto import PrestoEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.sql_parse import ParsedQuery -from superset.utils.core import DatasourceName, GenericDataType from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec @@ -624,42 +623,6 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): self.assertEqual(actual_data, expected_data) self.assertEqual(actual_expanded_cols, expected_expanded_cols) - def test_get_sqla_column_type(self): - column_spec = PrestoEngineSpec.get_column_spec("varchar(255)") - assert isinstance(column_spec.sqla_type, types.VARCHAR) - assert column_spec.sqla_type.length == 255 - self.assertEqual(column_spec.generic_type, GenericDataType.STRING) - - column_spec = PrestoEngineSpec.get_column_spec("varchar") - assert isinstance(column_spec.sqla_type, types.String) - assert column_spec.sqla_type.length is None - self.assertEqual(column_spec.generic_type, GenericDataType.STRING) - - column_spec = PrestoEngineSpec.get_column_spec("char(10)") - assert isinstance(column_spec.sqla_type, types.CHAR) - assert column_spec.sqla_type.length == 10 - self.assertEqual(column_spec.generic_type, GenericDataType.STRING) - - column_spec = PrestoEngineSpec.get_column_spec("char") - assert isinstance(column_spec.sqla_type, types.CHAR) - assert column_spec.sqla_type.length is None - self.assertEqual(column_spec.generic_type, GenericDataType.STRING) - - column_spec = PrestoEngineSpec.get_column_spec("integer") - assert 
isinstance(column_spec.sqla_type, types.Integer) - self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC) - - column_spec = PrestoEngineSpec.get_column_spec("time") - assert isinstance(column_spec.sqla_type, types.Time) - self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL) - - column_spec = PrestoEngineSpec.get_column_spec("timestamp") - assert isinstance(column_spec.sqla_type, types.TIMESTAMP) - self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL) - - sqla_type = PrestoEngineSpec.get_sqla_column_type(None) - assert sqla_type is None - @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names") @mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names") def test_get_table_names( diff --git a/tests/integration_tests/db_engine_specs/trino_tests.py b/tests/integration_tests/db_engine_specs/trino_tests.py deleted file mode 100644 index 6379d013b..000000000 --- a/tests/integration_tests/db_engine_specs/trino_tests.py +++ /dev/null @@ -1,214 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import json -from typing import Any, Dict -from unittest import mock -from unittest.mock import Mock, patch - -import pandas as pd -import pytest -from sqlalchemy import types - -import superset.config -from superset.constants import USER_AGENT -from superset.db_engine_specs.trino import TrinoEngineSpec -from superset.utils.core import GenericDataType -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec - - -class TestTrinoDbEngineSpec(TestDbEngineSpec): - def test_get_extra_params(self): - database = Mock() - - database.extra = json.dumps({}) - database.server_cert = None - extra = TrinoEngineSpec.get_extra_params(database) - expected = {"engine_params": {"connect_args": {"source": USER_AGENT}}} - self.assertEqual(extra, expected) - - expected = { - "first": 1, - "engine_params": { - "second": "two", - "connect_args": {"source": "foobar", "third": "three"}, - }, - } - database.extra = json.dumps(expected) - database.server_cert = None - extra = TrinoEngineSpec.get_extra_params(database) - self.assertEqual(extra, expected) - - @patch("superset.utils.core.create_ssl_cert_file") - def test_get_extra_params_with_server_cert(self, create_ssl_cert_file_func: Mock): - database = Mock() - - database.extra = json.dumps({}) - database.server_cert = "TEST_CERT" - create_ssl_cert_file_func.return_value = "/path/to/tls.crt" - extra = TrinoEngineSpec.get_extra_params(database) - - connect_args = extra.get("engine_params", {}).get("connect_args", {}) - self.assertEqual(connect_args.get("http_scheme"), "https") - self.assertEqual(connect_args.get("verify"), "/path/to/tls.crt") - create_ssl_cert_file_func.assert_called_once_with(database.server_cert) - - @patch("trino.auth.BasicAuthentication") - def test_auth_basic(self, auth: Mock): - database = Mock() - - auth_params = {"username": "username", "password": "password"} - database.encrypted_extra = json.dumps( - {"auth_method": "basic", "auth_params": auth_params} - ) - - params: Dict[str, Any] = {} - 
TrinoEngineSpec.update_params_from_encrypted_extra(database, params) - connect_args = params.setdefault("connect_args", {}) - self.assertEqual(connect_args.get("http_scheme"), "https") - auth.assert_called_once_with(**auth_params) - - @patch("trino.auth.KerberosAuthentication") - def test_auth_kerberos(self, auth: Mock): - database = Mock() - - auth_params = { - "service_name": "superset", - "mutual_authentication": False, - "delegate": True, - } - database.encrypted_extra = json.dumps( - {"auth_method": "kerberos", "auth_params": auth_params} - ) - - params: Dict[str, Any] = {} - TrinoEngineSpec.update_params_from_encrypted_extra(database, params) - connect_args = params.setdefault("connect_args", {}) - self.assertEqual(connect_args.get("http_scheme"), "https") - auth.assert_called_once_with(**auth_params) - - @patch("trino.auth.CertificateAuthentication") - def test_auth_certificate(self, auth: Mock): - database = Mock() - - auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"} - database.encrypted_extra = json.dumps( - {"auth_method": "certificate", "auth_params": auth_params} - ) - - params: Dict[str, Any] = {} - TrinoEngineSpec.update_params_from_encrypted_extra(database, params) - connect_args = params.setdefault("connect_args", {}) - self.assertEqual(connect_args.get("http_scheme"), "https") - auth.assert_called_once_with(**auth_params) - - @patch("trino.auth.JWTAuthentication") - def test_auth_jwt(self, auth: Mock): - database = Mock() - - auth_params = {"token": "jwt-token-string"} - database.encrypted_extra = json.dumps( - {"auth_method": "jwt", "auth_params": auth_params} - ) - - params: Dict[str, Any] = {} - TrinoEngineSpec.update_params_from_encrypted_extra(database, params) - connect_args = params.setdefault("connect_args", {}) - self.assertEqual(connect_args.get("http_scheme"), "https") - auth.assert_called_once_with(**auth_params) - - def test_auth_custom_auth(self): - database = Mock() - auth_class = Mock() - - auth_method = 
"custom_auth" - auth_params = {"params1": "params1", "params2": "params2"} - database.encrypted_extra = json.dumps( - {"auth_method": auth_method, "auth_params": auth_params} - ) - - with patch.dict( - "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS", - {"trino": {"custom_auth": auth_class}}, - clear=True, - ): - params: Dict[str, Any] = {} - TrinoEngineSpec.update_params_from_encrypted_extra(database, params) - - connect_args = params.setdefault("connect_args", {}) - self.assertEqual(connect_args.get("http_scheme"), "https") - - auth_class.assert_called_once_with(**auth_params) - - def test_auth_custom_auth_denied(self): - database = Mock() - auth_method = "my.module:TrinoAuthClass" - auth_params = {"params1": "params1", "params2": "params2"} - database.encrypted_extra = json.dumps( - {"auth_method": auth_method, "auth_params": auth_params} - ) - - superset.config.ALLOWED_EXTRA_AUTHENTICATIONS = {} - - with pytest.raises(ValueError) as excinfo: - TrinoEngineSpec.update_params_from_encrypted_extra(database, {}) - - assert str(excinfo.value) == ( - f"For security reason, custom authentication '{auth_method}' " - f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config" - ) - - def test_convert_dttm(self): - dttm = self.get_dttm() - - self.assertEqual( - TrinoEngineSpec.convert_dttm("TIMESTAMP", dttm), - "TIMESTAMP '2019-01-02 03:04:05.678900'", - ) - - self.assertEqual( - TrinoEngineSpec.convert_dttm("TIMESTAMP(3)", dttm), - "TIMESTAMP '2019-01-02 03:04:05.678900'", - ) - - self.assertEqual( - TrinoEngineSpec.convert_dttm("TIMESTAMP WITH TIME ZONE", dttm), - "TIMESTAMP '2019-01-02 03:04:05.678900'", - ) - - self.assertEqual( - TrinoEngineSpec.convert_dttm("TIMESTAMP(3) WITH TIME ZONE", dttm), - "TIMESTAMP '2019-01-02 03:04:05.678900'", - ) - - self.assertEqual( - TrinoEngineSpec.convert_dttm("DATE", dttm), - "DATE '2019-01-02'", - ) - - def test_extra_table_metadata(self): - db = mock.Mock() - db.get_indexes = mock.Mock( - return_value=[{"column_names": ["ds", 
"hour"], "name": "partition"}] - ) - db.get_extra = mock.Mock(return_value={}) - db.has_view_by_name = mock.Mock(return_value=None) - db.get_df = mock.Mock( - return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]}) - ) - result = TrinoEngineSpec.extra_table_metadata(db, "test_table", "test_schema") - assert result["partitions"]["cols"] == ["ds", "hour"] - assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1} diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index dfba16179..400391351 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -39,7 +39,6 @@ from superset.utils.core import ( AdhocMetricExpressionType, FilterOperator, GenericDataType, - TemporalType, ) from superset.utils.database import get_example_database from tests.integration_tests.fixtures.birth_names_dashboard import ( @@ -805,7 +804,7 @@ def test__normalize_prequery_result_type( def _convert_dttm( target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None ) -> Optional[str]: - if target_type.upper() == TemporalType.TIMESTAMP: + if target_type.upper() == "TIMESTAMP": return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')""" return None diff --git a/tests/unit_tests/db_engine_specs/test_athena.py b/tests/unit_tests/db_engine_specs/test_athena.py index a1243ac09..51ec6656a 100644 --- a/tests/unit_tests/db_engine_specs/test_athena.py +++ b/tests/unit_tests/db_engine_specs/test_athena.py @@ -17,8 +17,12 @@ # pylint: disable=unused-argument, import-outside-toplevel, protected-access import re from datetime import datetime +from typing import Optional + +import pytest from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm SYNTAX_ERROR_REGEX = re.compile( @@ -26,19 +30,20 @@ SYNTAX_ERROR_REGEX = re.compile( ) 
-def test_convert_dttm(dttm: datetime) -> None: - """ - Test that date objects are converted correctly. - """ +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "DATE '2019-01-02'"), + ("TimeStamp", "TIMESTAMP '2019-01-02 03:04:05.678'"), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.athena import AthenaEngineSpec as spec - from superset.db_engine_specs.athena import AthenaEngineSpec - - assert AthenaEngineSpec.convert_dttm("DATE", dttm) == "DATE '2019-01-02'" - - assert ( - AthenaEngineSpec.convert_dttm("TIMESTAMP", dttm) - == "TIMESTAMP '2019-01-02 03:04:05.678'" - ) + assert_convert_dttm(spec, target_type, expected_result, dttm) def test_extract_errors() -> None: diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 79a83c6b0..868a6bbdc 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -17,9 +17,13 @@ # pylint: disable=unused-argument, import-outside-toplevel, protected-access from textwrap import dedent +from typing import Any, Dict, Optional, Type import pytest -from sqlalchemy.types import TypeEngine +from sqlalchemy import types + +from superset.utils.core import GenericDataType +from tests.unit_tests.db_engine_specs.utils import assert_column_spec def test_get_text_clause_with_colon() -> None: @@ -94,8 +98,43 @@ select 'USD' as cur ), ], ) -def test_cte_query_parsing(original: TypeEngine, expected: str) -> None: +def test_cte_query_parsing(original: types.TypeEngine, expected: str) -> None: from superset.db_engine_specs.base import BaseEngineSpec actual = BaseEngineSpec.get_cte_query(original) assert actual == expected + + +@pytest.mark.parametrize( + "native_type,sqla_type,attrs,generic_type,is_dttm", + [ + ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False), + ("INTEGER", 
types.Integer, None, GenericDataType.NUMERIC, False), + ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False), + ("DECIMAL", types.Numeric, None, GenericDataType.NUMERIC, False), + ("NUMERIC", types.Numeric, None, GenericDataType.NUMERIC, False), + ("REAL", types.REAL, None, GenericDataType.NUMERIC, False), + ("DOUBLE PRECISION", types.Float, None, GenericDataType.NUMERIC, False), + ("MONEY", types.Numeric, None, GenericDataType.NUMERIC, False), + # String + ("CHAR", types.String, None, GenericDataType.STRING, False), + ("VARCHAR", types.String, None, GenericDataType.STRING, False), + ("TEXT", types.String, None, GenericDataType.STRING, False), + # Temporal + ("DATE", types.Date, None, GenericDataType.TEMPORAL, True), + ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True), + ("TIME", types.Time, None, GenericDataType.TEMPORAL, True), + # Boolean + ("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False), + ], +) +def test_get_column_spec( + native_type: str, + sqla_type: Type[types.TypeEngine], + attrs: Optional[Dict[str, Any]], + generic_type: GenericDataType, + is_dttm: bool, +) -> None: + from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec + + assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm) diff --git a/tests/unit_tests/db_engine_specs/test_bigquery.py b/tests/unit_tests/db_engine_specs/test_bigquery.py index 362e88804..5b9c6a956 100644 --- a/tests/unit_tests/db_engine_specs/test_bigquery.py +++ b/tests/unit_tests/db_engine_specs/test_bigquery.py @@ -18,12 +18,18 @@ # pylint: disable=line-too-long, import-outside-toplevel, protected-access, invalid-name import json +from datetime import datetime +from typing import Optional +import pytest from pytest_mock import MockFixture from sqlalchemy import select from sqlalchemy.sql import sqltypes from sqlalchemy_bigquery import BigQueryDialect +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from 
tests.unit_tests.fixtures.common import dttm + def test_get_fields() -> None: """ @@ -285,3 +291,24 @@ def test_parse_error_raises_exception() -> None: == expected_result ) assert str(BigQueryEngineSpec.parse_error_exception(Exception(message_2))) == "6" + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "CAST('2019-01-02' AS DATE)"), + ("DateTime", "CAST('2019-01-02T03:04:05.678900' AS DATETIME)"), + ("TimeStamp", "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)"), + ("Time", "CAST('03:04:05.678900' AS TIME)"), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + """ + DB Eng Specs (bigquery): Test conversion to date time + """ + from superset.db_engine_specs.bigquery import BigQueryEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_clickhouse.py b/tests/unit_tests/db_engine_specs/test_clickhouse.py index ca01c304f..9a52b0461 100644 --- a/tests/unit_tests/db_engine_specs/test_clickhouse.py +++ b/tests/unit_tests/db_engine_specs/test_clickhouse.py @@ -14,22 +14,31 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ from datetime import datetime -from unittest import mock +from typing import Optional +from unittest.mock import Mock import pytest +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm -def test_convert_dttm(dttm: datetime) -> None: - from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "toDate('2019-01-02')"), + ("DateTime", "toDateTime('2019-01-02 03:04:05')"), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec as spec - assert ClickHouseEngineSpec.convert_dttm("DATE", dttm) == "toDate('2019-01-02')" - assert ( - ClickHouseEngineSpec.convert_dttm("DATETIME", dttm) - == "toDateTime('2019-01-02 03:04:05')" - ) + assert_convert_dttm(spec, target_type, expected_result, dttm) def test_execute_connection_error() -> None: @@ -38,7 +47,7 @@ def test_execute_connection_error() -> None: from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError - cursor = mock.Mock() + cursor = Mock() cursor.execute.side_effect = NewConnectionError( "Dummypool", "Exception with sensitive data" ) diff --git a/tests/unit_tests/db_engine_specs/test_crate.py b/tests/unit_tests/db_engine_specs/test_crate.py new file mode 100644 index 000000000..2cb1cd789 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_crate.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm + + +def test_epoch_to_dttm() -> None: + """ + DB Eng Specs (crate): Test epoch to dttm + """ + from superset.db_engine_specs.crate import CrateEngineSpec + + assert CrateEngineSpec.epoch_to_dttm() == "{col} * 1000" + + +def test_epoch_ms_to_dttm() -> None: + """ + DB Eng Specs (crate): Test epoch ms to dttm + """ + from superset.db_engine_specs.crate import CrateEngineSpec + + assert CrateEngineSpec.epoch_ms_to_dttm() == "{col}" + + +def test_alter_new_orm_column() -> None: + """ + DB Eng Specs (crate): Test alter orm column + """ + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.db_engine_specs.crate import CrateEngineSpec + from superset.models.core import Database + + database = Database(database_name="crate", sqlalchemy_uri="crate://db") + tbl = SqlaTable(table_name="tbl", database=database) + col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl) + CrateEngineSpec.alter_new_orm_column(col) + assert col.python_date_format == "epoch_ms" + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("TimeStamp", "1546398245678.9"), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: 
datetime +) -> None: + from superset.db_engine_specs.crate import CrateEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py b/tests/unit_tests/db_engine_specs/test_databricks.py index 86ffbc613..49d65b324 100644 --- a/tests/unit_tests/db_engine_specs/test_databricks.py +++ b/tests/unit_tests/db_engine_specs/test_databricks.py @@ -17,13 +17,16 @@ # pylint: disable=unused-argument, import-outside-toplevel, protected-access import json +from datetime import datetime +from typing import Optional import pytest from pytest_mock import MockerFixture from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.utils.core import GenericDataType +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm def test_get_parameters_from_uri() -> None: @@ -109,37 +112,6 @@ def test_parameters_json_schema() -> None: } -def test_generic_type() -> None: - """ - assert that generic types match - """ - from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec - from tests.integration_tests.db_engine_specs.base_tests import assert_generic_types - - type_expectations = ( - # Numeric - ("SMALLINT", GenericDataType.NUMERIC), - ("INTEGER", GenericDataType.NUMERIC), - ("BIGINT", GenericDataType.NUMERIC), - ("DECIMAL", GenericDataType.NUMERIC), - ("NUMERIC", GenericDataType.NUMERIC), - ("REAL", GenericDataType.NUMERIC), - ("DOUBLE PRECISION", GenericDataType.NUMERIC), - ("MONEY", GenericDataType.NUMERIC), - # String - ("CHAR", GenericDataType.STRING), - ("VARCHAR", GenericDataType.STRING), - ("TEXT", GenericDataType.STRING), - # Temporal - ("DATE", GenericDataType.TEMPORAL), - ("TIMESTAMP", GenericDataType.TEMPORAL), - ("TIME", GenericDataType.TEMPORAL), - # Boolean - ("BOOLEAN", GenericDataType.BOOLEAN), - ) - 
assert_generic_types(DatabricksNativeEngineSpec, type_expectations) - - def test_get_extra_params(mocker: MockerFixture) -> None: """ Test the ``get_extra_params`` method. @@ -253,3 +225,22 @@ def test_extract_errors_with_context() -> None: }, ) ] + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "CAST('2019-01-02' AS DATE)"), + ( + "TimeStamp", + "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)", + ), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/integration_tests/db_engine_specs/dremio_tests.py b/tests/unit_tests/db_engine_specs/test_dremio.py similarity index 57% rename from tests/integration_tests/db_engine_specs/dremio_tests.py rename to tests/unit_tests/db_engine_specs/test_dremio.py index 5d678c947..6b1e8203b 100644 --- a/tests/integration_tests/db_engine_specs/dremio_tests.py +++ b/tests/unit_tests/db_engine_specs/test_dremio.py @@ -14,20 +14,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from superset.db_engine_specs.dremio import DremioEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm -class TestDremioDbEngineSpec(TestDbEngineSpec): - def test_convert_dttm(self): - dttm = self.get_dttm() - - self.assertEqual( - DremioEngineSpec.convert_dttm("DATE", dttm), - "TO_DATE('2019-01-02', 'YYYY-MM-DD')", - ) - - self.assertEqual( - DremioEngineSpec.convert_dttm("TIMESTAMP", dttm), +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"), + ( + "TimeStamp", "TO_TIMESTAMP('2019-01-02 03:04:05.678', 'YYYY-MM-DD HH24:MI:SS.FFF')", - ) + ), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.dremio import DremioEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_drill.py b/tests/unit_tests/db_engine_specs/test_drill.py index 195ad8aca..e56df5d47 100644 --- a/tests/unit_tests/db_engine_specs/test_drill.py +++ b/tests/unit_tests/db_engine_specs/test_drill.py @@ -16,7 +16,13 @@ # under the License. 
# pylint: disable=unused-argument, import-outside-toplevel, protected-access -from pytest import raises +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm def test_odbc_impersonation() -> None: @@ -82,5 +88,21 @@ def test_invalid_impersonation() -> None: url = URL("drill+foobar") username = "DoAsUser" - with raises(SupersetDBAPIProgrammingError): + with pytest.raises(SupersetDBAPIProgrammingError): DrillEngineSpec.get_url_for_impersonation(url, True, username) + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "TO_DATE('2019-01-02', 'yyyy-MM-dd')"), + ("TimeStamp", "TO_TIMESTAMP('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')"), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.drill import DrillEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_druid.py b/tests/unit_tests/db_engine_specs/test_druid.py new file mode 100644 index 000000000..d090dffcd --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_druid.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import datetime +from typing import Optional +from unittest import mock + +import pytest +from sqlalchemy import column + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "CAST(TIME_PARSE('2019-01-02') AS DATE)"), + ("DateTime", "TIME_PARSE('2019-01-02T03:04:05')"), + ("TimeStamp", "TIME_PARSE('2019-01-02T03:04:05')"), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.druid import DruidEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) + + +@pytest.mark.parametrize( + "time_grain,expected_result", + [ + ("PT1S", "TIME_FLOOR(CAST(col AS TIMESTAMP), 'PT1S')"), + ("PT5M", "TIME_FLOOR(CAST(col AS TIMESTAMP), 'PT5M')"), + ( + "P1W/1970-01-03T00:00:00Z", + "TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST(col AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', 5)", + ), + ( + "1969-12-28T00:00:00Z/P1W", + "TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST(col AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', -1)", + ), + ], +) +def test_timegrain_expressions(time_grain: str, expected_result: str) -> None: + """ + DB Eng Specs (druid): Test time grain expressions compile to the expected SQL + """ + from superset.db_engine_specs.druid import DruidEngineSpec + + assert expected_result == str( + DruidEngineSpec.get_timestamp_expr( + col=column("col"), pdf=None, time_grain=time_grain + ) + ) + + +def test_extras_without_ssl() -> None: + 
from superset.db_engine_specs.druid import DruidEngineSpec + from tests.integration_tests.fixtures.database import default_db_extra + + db = mock.Mock() + db.extra = default_db_extra + db.server_cert = None + extras = DruidEngineSpec.get_extra_params(db) + assert "connect_args" not in extras["engine_params"] + + +def test_extras_with_ssl() -> None: + from superset.db_engine_specs.druid import DruidEngineSpec + from tests.integration_tests.fixtures.certificates import ssl_certificate + from tests.integration_tests.fixtures.database import default_db_extra + + db = mock.Mock() + db.extra = default_db_extra + db.server_cert = ssl_certificate + extras = DruidEngineSpec.get_extra_params(db) + connect_args = extras["engine_params"]["connect_args"] + assert connect_args["scheme"] == "https" + assert "ssl_verify_cert" in connect_args diff --git a/tests/integration_tests/db_engine_specs/drill_tests.py b/tests/unit_tests/db_engine_specs/test_duckdb.py similarity index 54% rename from tests/integration_tests/db_engine_specs/drill_tests.py rename to tests/unit_tests/db_engine_specs/test_duckdb.py index e89462ee5..72d018f4f 100644 --- a/tests/integration_tests/db_engine_specs/drill_tests.py +++ b/tests/unit_tests/db_engine_specs/test_duckdb.py @@ -14,20 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from superset.db_engine_specs.drill import DrillEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec + +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm -class TestDrillDbEngineSpec(TestDbEngineSpec): - def test_convert_dttm(self): - dttm = self.get_dttm() +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Text", "'2019-01-02 03:04:05.678900'"), + ("DateTime", "'2019-01-02 03:04:05.678900'"), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.duckdb import DuckDBEngineSpec as spec - self.assertEqual( - DrillEngineSpec.convert_dttm("DATE", dttm), - "TO_DATE('2019-01-02', 'yyyy-MM-dd')", - ) - - self.assertEqual( - DrillEngineSpec.convert_dttm("TIMESTAMP", dttm), - "TO_TIMESTAMP('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')", - ) + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_dynamodb.py b/tests/unit_tests/db_engine_specs/test_dynamodb.py index bf2b555ce..26196f5b4 100644 --- a/tests/unit_tests/db_engine_specs/test_dynamodb.py +++ b/tests/unit_tests/db_engine_specs/test_dynamodb.py @@ -14,24 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from datetime import datetime +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm -def test_convert_dttm(dttm: datetime) -> None: - from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("text", "'2019-01-02 03:04:05'"), + ("dateTime", "'2019-01-02 03:04:05'"), + ("unknowntype", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec as spec - assert DynamoDBEngineSpec.convert_dttm("TEXT", dttm) == "'2019-01-02 03:04:05'" - - -def test_convert_dttm_lower(dttm: datetime) -> None: - from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec - - assert DynamoDBEngineSpec.convert_dttm("text", dttm) == "'2019-01-02 03:04:05'" - - -def test_convert_dttm_invalid_type(dttm: datetime) -> None: - from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec - - assert DynamoDBEngineSpec.convert_dttm("other", dttm) is None + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_elasticsearch.py b/tests/unit_tests/db_engine_specs/test_elasticsearch.py new file mode 100644 index 000000000..de55c6342 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_elasticsearch.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import datetime +from typing import Any, Dict, Optional +from unittest.mock import MagicMock + +import pytest +from sqlalchemy import column + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm + + +@pytest.mark.parametrize( + "target_type,db_extra,expected_result", + [ + ("DateTime", None, "CAST('2019-01-02T03:04:05' AS DATETIME)"), + ( + "DateTime", + {"version": "7.7"}, + "CAST('2019-01-02T03:04:05' AS DATETIME)", + ), + ( + "DateTime", + {"version": "7.8"}, + "DATETIME_PARSE('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')", + ), + ( + "DateTime", + {"version": "unparseable semver version"}, + "CAST('2019-01-02T03:04:05' AS DATETIME)", + ), + ("Unknown", None, None), + ], +) +def test_elasticsearch_convert_dttm( + target_type: str, + db_extra: Optional[Dict[str, Any]], + expected_result: Optional[str], + dttm: datetime, +) -> None: + from superset.db_engine_specs.elasticsearch import ElasticSearchEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm, db_extra) + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("DateTime", "'2019-01-02T03:04:05'"), + ("Unknown", None), + ], +) +def test_opendistro_convert_dttm( + target_type: str, + expected_result: Optional[str], + dttm: datetime, +) -> None: + from superset.db_engine_specs.elasticsearch import OpenDistroEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) + + +@pytest.mark.parametrize( + "original,expected", + [ + ("Col", "Col"), + 
("Col.keyword", "Col_keyword"), + ], +) +def test_opendistro_sqla_column_label(original: str, expected: str) -> None: + """ + DB Eng Specs (opendistro): Test column label + """ + from superset.db_engine_specs.elasticsearch import OpenDistroEngineSpec + + assert OpenDistroEngineSpec.make_label_compatible(original) == expected + + +def test_opendistro_strip_comments() -> None: + """ + DB Eng Specs (opendistro): Test execute sql strip comments + """ + from superset.db_engine_specs.elasticsearch import OpenDistroEngineSpec + + mock_cursor = MagicMock() + mock_cursor.execute.return_value = [] + + OpenDistroEngineSpec.execute( + mock_cursor, "-- some comment \nSELECT 1\n --other comment" + ) + mock_cursor.execute.assert_called_once_with("SELECT 1\n") diff --git a/tests/unit_tests/db_engine_specs/test_firebird.py b/tests/unit_tests/db_engine_specs/test_firebird.py new file mode 100644 index 000000000..c1add9167 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_firebird.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm + + +@pytest.mark.parametrize( + "time_grain,expected", + [ + (None, "timestamp_column"), + ( + "PT1S", + ( + "CAST(CAST(timestamp_column AS DATE) " + "|| ' ' " + "|| EXTRACT(HOUR FROM timestamp_column) " + "|| ':' " + "|| EXTRACT(MINUTE FROM timestamp_column) " + "|| ':' " + "|| FLOOR(EXTRACT(SECOND FROM timestamp_column)) AS TIMESTAMP)" + ), + ), + ( + "PT1M", + ( + "CAST(CAST(timestamp_column AS DATE) " + "|| ' ' " + "|| EXTRACT(HOUR FROM timestamp_column) " + "|| ':' " + "|| EXTRACT(MINUTE FROM timestamp_column) " + "|| ':00' AS TIMESTAMP)" + ), + ), + ("P1D", "CAST(timestamp_column AS DATE)"), + ( + "P1M", + ( + "CAST(EXTRACT(YEAR FROM timestamp_column) " + "|| '-' " + "|| EXTRACT(MONTH FROM timestamp_column) " + "|| '-01' AS DATE)" + ), + ), + ("P1Y", "CAST(EXTRACT(YEAR FROM timestamp_column) || '-01-01' AS DATE)"), + ], +) +def test_time_grain_expressions(time_grain: Optional[str], expected: str) -> None: + from superset.db_engine_specs.firebird import FirebirdEngineSpec + + assert ( + FirebirdEngineSpec._time_grain_expressions[time_grain].format( + col="timestamp_column", + ) + == expected + ) + + +def test_epoch_to_dttm() -> None: + from superset.db_engine_specs.firebird import FirebirdEngineSpec + + assert ( + FirebirdEngineSpec.epoch_to_dttm().format(col="timestamp_column") + == "DATEADD(second, timestamp_column, CAST('00:00:00' AS TIMESTAMP))" + ) + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "CAST('2019-01-02' AS DATE)"), + ("DateTime", "CAST('2019-01-02 03:04:05.6789' AS TIMESTAMP)"), + ("TimeStamp", "CAST('2019-01-02 03:04:05.6789' AS TIMESTAMP)"), + ("Time", "CAST('03:04:05.678900' AS TIME)"), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime 
+) -> None: + from superset.db_engine_specs.firebird import FirebirdEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_firebolt.py b/tests/unit_tests/db_engine_specs/test_firebolt.py new file mode 100644 index 000000000..eb84bb14b --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_firebolt.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "CAST('2019-01-02' AS DATE)"), + ( + "DateTime", + "CAST('2019-01-02T03:04:05' AS DATETIME)", + ), + ( + "TimeStamp", + "CAST('2019-01-02T03:04:05' AS TIMESTAMP)", + ), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.firebolt import FireboltEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) + + +def test_epoch_to_dttm() -> None: + from superset.db_engine_specs.firebolt import FireboltEngineSpec + + assert ( + FireboltEngineSpec.epoch_to_dttm().format(col="timestamp_column") + == "from_unixtime(timestamp_column)" + ) diff --git a/tests/integration_tests/db_engine_specs/hana_tests.py b/tests/unit_tests/db_engine_specs/test_hana.py similarity index 57% rename from tests/integration_tests/db_engine_specs/hana_tests.py rename to tests/unit_tests/db_engine_specs/test_hana.py index 06eee032e..1d1ac6390 100644 --- a/tests/integration_tests/db_engine_specs/hana_tests.py +++ b/tests/unit_tests/db_engine_specs/test_hana.py @@ -14,20 +14,30 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from superset.db_engine_specs.hana import HanaEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec + +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm -class TestHanaDbEngineSpec(TestDbEngineSpec): - def test_convert_dttm(self): - dttm = self.get_dttm() - - self.assertEqual( - HanaEngineSpec.convert_dttm("DATE", dttm), - "TO_DATE('2019-01-02', 'YYYY-MM-DD')", - ) - - self.assertEqual( - HanaEngineSpec.convert_dttm("TIMESTAMP", dttm), +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"), + ( + "TimeStamp", "TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD\"T\"HH24:MI:SS.ff6')", - ) + ), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.hana import HanaEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_hive.py b/tests/unit_tests/db_engine_specs/test_hive.py new file mode 100644 index 000000000..3a5cb9140 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_hive.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "CAST('2019-01-02' AS DATE)"), + ( + "TimeStamp", + "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)", + ), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.hive import HiveEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_impala.py b/tests/unit_tests/db_engine_specs/test_impala.py new file mode 100644 index 000000000..8a4244052 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_impala.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "CAST('2019-01-02' AS DATE)"), + ("TimeStamp", "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)"), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.impala import ImpalaEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_kusto.py b/tests/unit_tests/db_engine_specs/test_kusto.py index e556418a8..538eafc6b 100644 --- a/tests/unit_tests/db_engine_specs/test_kusto.py +++ b/tests/unit_tests/db_engine_specs/test_kusto.py @@ -16,9 +16,11 @@ # under the License. 
# pylint: disable=unused-argument, import-outside-toplevel, protected-access from datetime import datetime +from typing import Optional import pytest +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm @@ -108,45 +110,35 @@ def test_kql_parse_sql() -> None: @pytest.mark.parametrize( - "target_type,expected_dttm", + "target_type,expected_result", [ - ("DATETIME", "datetime(2019-01-02T03:04:05.678900)"), - ("TIMESTAMP", "datetime(2019-01-02T03:04:05.678900)"), - ("DATE", "datetime(2019-01-02)"), + ("DateTime", "datetime(2019-01-02T03:04:05.678900)"), + ("TimeStamp", "datetime(2019-01-02T03:04:05.678900)"), + ("Date", "datetime(2019-01-02)"), + ("UnknownType", None), ], ) def test_kql_convert_dttm( - target_type: str, - expected_dttm: str, - dttm: datetime, + target_type: str, expected_result: Optional[str], dttm: datetime ) -> None: - """ - Test that date objects are converted correctly. - """ + from superset.db_engine_specs.kusto import KustoKqlEngineSpec as spec - from superset.db_engine_specs.kusto import KustoKqlEngineSpec - - assert expected_dttm == KustoKqlEngineSpec.convert_dttm(target_type, dttm) + assert_convert_dttm(spec, target_type, expected_result, dttm) @pytest.mark.parametrize( - "target_type,expected_dttm", + "target_type,expected_result", [ - ("DATETIME", "CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)"), - ("DATE", "CONVERT(DATE, '2019-01-02', 23)"), - ("SMALLDATETIME", "CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)"), - ("TIMESTAMP", "CONVERT(TIMESTAMP, '2019-01-02 03:04:05', 20)"), + ("Date", "CONVERT(DATE, '2019-01-02', 23)"), + ("DateTime", "CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)"), + ("SmallDateTime", "CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)"), + ("TimeStamp", "CONVERT(TIMESTAMP, '2019-01-02 03:04:05', 20)"), + ("UnknownType", None), ], ) def test_sql_convert_dttm( - target_type: str, - expected_dttm: str, - dttm: datetime, + target_type: str, 
expected_result: Optional[str], dttm: datetime ) -> None: - """ - Test that date objects are converted correctly. - """ + from superset.db_engine_specs.kusto import KustoSqlEngineSpec as spec - from superset.db_engine_specs.kusto import KustoSqlEngineSpec - - assert expected_dttm == KustoSqlEngineSpec.convert_dttm(target_type, dttm) + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/integration_tests/db_engine_specs/impala_tests.py b/tests/unit_tests/db_engine_specs/test_kylin.py similarity index 54% rename from tests/integration_tests/db_engine_specs/impala_tests.py rename to tests/unit_tests/db_engine_specs/test_kylin.py index 936ace98b..cbc8c9133 100644 --- a/tests/integration_tests/db_engine_specs/impala_tests.py +++ b/tests/unit_tests/db_engine_specs/test_kylin.py @@ -14,19 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from superset.db_engine_specs.impala import ImpalaEngineSpec -from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec + +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm -class TestImpalaDbEngineSpec(TestDbEngineSpec): - def test_convert_dttm(self): - dttm = self.get_dttm() +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "CAST('2019-01-02' AS DATE)"), + ("TimeStamp", "CAST('2019-01-02 03:04:05' AS TIMESTAMP)"), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.kylin import KylinEngineSpec as spec - self.assertEqual( - ImpalaEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)" - ) - - self.assertEqual( - ImpalaEngineSpec.convert_dttm("TIMESTAMP", dttm), - 
"CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)", - ) + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index 0ceee0adf..63a315c14 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -17,9 +17,10 @@ import unittest.mock as mock from datetime import datetime from textwrap import dedent +from typing import Any, Dict, Optional, Type import pytest -from sqlalchemy import column, table +from sqlalchemy import column, table, types from sqlalchemy.dialects import mssql from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR from sqlalchemy.sql import select @@ -27,36 +28,36 @@ from sqlalchemy.types import String, TypeEngine, UnicodeText from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.utils.core import GenericDataType +from tests.unit_tests.db_engine_specs.utils import ( + assert_column_spec, + assert_convert_dttm, +) from tests.unit_tests.fixtures.common import dttm @pytest.mark.parametrize( - "type_string,type_expected,generic_type_expected", + "native_type,sqla_type,attrs,generic_type,is_dttm", [ - ("STRING", String, GenericDataType.STRING), - ("CHAR(10)", String, GenericDataType.STRING), - ("VARCHAR(10)", String, GenericDataType.STRING), - ("TEXT", String, GenericDataType.STRING), - ("NCHAR(10)", UnicodeText, GenericDataType.STRING), - ("NVARCHAR(10)", UnicodeText, GenericDataType.STRING), - ("NTEXT", UnicodeText, GenericDataType.STRING), + ("CHAR", String, None, GenericDataType.STRING, False), + ("CHAR(10)", String, None, GenericDataType.STRING, False), + ("VARCHAR", String, None, GenericDataType.STRING, False), + ("VARCHAR(10)", String, None, GenericDataType.STRING, False), + ("TEXT", String, None, GenericDataType.STRING, False), + ("NCHAR(10)", UnicodeText, None, GenericDataType.STRING, False), + ("NVARCHAR(10)", 
UnicodeText, None, GenericDataType.STRING, False), + ("NTEXT", UnicodeText, None, GenericDataType.STRING, False), ], ) -def test_mssql_column_types( - type_string: str, - type_expected: TypeEngine, - generic_type_expected: GenericDataType, +def test_get_column_spec( + native_type: str, + sqla_type: Type[types.TypeEngine], + attrs: Optional[Dict[str, Any]], + generic_type: GenericDataType, + is_dttm: bool, ) -> None: - from superset.db_engine_specs.mssql import MssqlEngineSpec + from superset.db_engine_specs.mssql import MssqlEngineSpec as spec - if type_expected is None: - type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string) - assert type_assigned is None - else: - column_spec = MssqlEngineSpec.get_column_spec(type_string) - if column_spec is not None: - assert isinstance(column_spec.sqla_type, type_expected) - assert column_spec.generic_type == generic_type_expected + assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm) def test_where_clause_n_prefix() -> None: @@ -65,13 +66,13 @@ def test_where_clause_n_prefix() -> None: dialect = mssql.dialect() # non-unicode col - sqla_column_type = MssqlEngineSpec.get_sqla_column_type("VARCHAR(10)") + sqla_column_type = MssqlEngineSpec.get_column_types("VARCHAR(10)") assert sqla_column_type is not None type_, _ = sqla_column_type str_col = column("col", type_=type_) # unicode col - sqla_column_type = MssqlEngineSpec.get_sqla_column_type("NTEXT") + sqla_column_type = MssqlEngineSpec.get_column_types("NTEXT") assert sqla_column_type is not None type_, _ = sqla_column_type unicode_col = column("unicode_col", type_=type_) @@ -103,30 +104,31 @@ def test_time_exp_mixd_case_col_1y() -> None: @pytest.mark.parametrize( - "actual,expected", + "target_type,expected_result", [ ( - "DATE", + "date", "CONVERT(DATE, '2019-01-02', 23)", ), ( - "DATETIME", + "datetime", "CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)", ), ( - "SMALLDATETIME", + "smalldatetime", "CONVERT(SMALLDATETIME, '2019-01-02 
03:04:05', 20)", ), + ("Other", None), ], ) def test_convert_dttm( - actual: str, - expected: str, + target_type: str, + expected_result: Optional[str], dttm: datetime, ) -> None: - from superset.db_engine_specs.mssql import MssqlEngineSpec + from superset.db_engine_specs.mssql import MssqlEngineSpec as spec - assert MssqlEngineSpec.convert_dttm(actual, dttm) == expected + assert_convert_dttm(spec, target_type, expected_result, dttm) def test_extract_error_message() -> None: diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py new file mode 100644 index 000000000..4562e497c --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from datetime import datetime +from typing import Any, Dict, Optional, Type +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy import types +from sqlalchemy.dialects.mysql import ( + BIT, + DECIMAL, + DOUBLE, + FLOAT, + INTEGER, + LONGTEXT, + MEDIUMINT, + MEDIUMTEXT, + TINYINT, + TINYTEXT, +) + +from superset.utils.core import GenericDataType +from tests.unit_tests.db_engine_specs.utils import ( + assert_column_spec, + assert_convert_dttm, +) +from tests.unit_tests.fixtures.common import dttm + + +@pytest.mark.parametrize( + "native_type,sqla_type,attrs,generic_type,is_dttm", + [ + # Numeric + ("TINYINT", TINYINT, None, GenericDataType.NUMERIC, False), + ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False), + ("MEDIUMINT", MEDIUMINT, None, GenericDataType.NUMERIC, False), + ("INT", INTEGER, None, GenericDataType.NUMERIC, False), + ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False), + ("DECIMAL", DECIMAL, None, GenericDataType.NUMERIC, False), + ("FLOAT", FLOAT, None, GenericDataType.NUMERIC, False), + ("DOUBLE", DOUBLE, None, GenericDataType.NUMERIC, False), + ("BIT", BIT, None, GenericDataType.NUMERIC, False), + # String + ("CHAR", types.String, None, GenericDataType.STRING, False), + ("VARCHAR", types.String, None, GenericDataType.STRING, False), + ("TINYTEXT", TINYTEXT, None, GenericDataType.STRING, False), + ("MEDIUMTEXT", MEDIUMTEXT, None, GenericDataType.STRING, False), + ("LONGTEXT", LONGTEXT, None, GenericDataType.STRING, False), + # Temporal + ("DATE", types.Date, None, GenericDataType.TEMPORAL, True), + ("DATETIME", types.DateTime, None, GenericDataType.TEMPORAL, True), + ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True), + ("TIME", types.Time, None, GenericDataType.TEMPORAL, True), + ], +) +def test_get_column_spec( + native_type: str, + sqla_type: Type[types.TypeEngine], + attrs: Optional[Dict[str, Any]], + generic_type: GenericDataType, + is_dttm: bool, +) -> None: + from 
superset.db_engine_specs.mysql import MySQLEngineSpec as spec + + assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm) + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "STR_TO_DATE('2019-01-02', '%Y-%m-%d')"), + ( + "DateTime", + "STR_TO_DATE('2019-01-02 03:04:05.678900', '%Y-%m-%d %H:%i:%s.%f')", + ), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) + + +@patch("sqlalchemy.engine.Engine.connect") +def test_get_cancel_query_id(engine_mock: Mock) -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + from superset.models.sql_lab import Query + + query = Query() + cursor_mock = engine_mock.return_value.__enter__.return_value + cursor_mock.fetchone.return_value = ["123"] + assert MySQLEngineSpec.get_cancel_query_id(cursor_mock, query) == "123" + + +@patch("sqlalchemy.engine.Engine.connect") +def test_cancel_query(engine_mock: Mock) -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + from superset.models.sql_lab import Query + + query = Query() + cursor_mock = engine_mock.return_value.__enter__.return_value + assert MySQLEngineSpec.cancel_query(cursor_mock, query, "123") is True + + +@patch("sqlalchemy.engine.Engine.connect") +def test_cancel_query_failed(engine_mock: Mock) -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + from superset.models.sql_lab import Query + + query = Query() + cursor_mock = engine_mock.raiseError.side_effect = Exception() + assert MySQLEngineSpec.cancel_query(cursor_mock, query, "123") is False diff --git a/tests/unit_tests/db_engine_specs/test_oracle.py b/tests/unit_tests/db_engine_specs/test_oracle.py new file mode 100644 index 000000000..0dce95697 --- /dev/null +++ 
b/tests/unit_tests/db_engine_specs/test_oracle.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import datetime +from typing import Optional, Union +from unittest import mock + +import pytest +from sqlalchemy import column, types +from sqlalchemy.dialects import oracle +from sqlalchemy.dialects.oracle import DATE, NVARCHAR, VARCHAR +from sqlalchemy.sql import quoted_name + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm + + +@pytest.mark.parametrize( + "column_name,expected_result", + [ + ("This_Is_32_Character_Column_Name", "3b26974078683be078219674eeb8f5"), + ("snake_label", "snake_label"), + ("camelLabel", "camelLabel"), + ], +) +def test_oracle_sqla_column_name_length_exceeded( + column_name: str, expected_result: Union[str, quoted_name] +) -> None: + from superset.db_engine_specs.oracle import OracleEngineSpec + + label = OracleEngineSpec.make_label_compatible(column_name) + assert isinstance(label, quoted_name) + assert label.quote is True + assert label == expected_result + + +def test_oracle_time_expression_reserved_keyword_1m_grain() -> None: + from superset.db_engine_specs.oracle import OracleEngineSpec + + col = 
column("decimal") + expr = OracleEngineSpec.get_timestamp_expr(col, None, "P1M") + result = str(expr.compile(dialect=oracle.dialect())) + assert result == "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')" + + +@pytest.mark.parametrize( + "sqla_type,expected_result", + [ + (DATE(), "DATE"), + (VARCHAR(length=255), "VARCHAR(255 CHAR)"), + (VARCHAR(length=255, collation="utf8"), "VARCHAR(255 CHAR)"), + (NVARCHAR(length=128), "NVARCHAR2(128)"), + ], +) +def test_column_datatype_to_string( + sqla_type: types.TypeEngine, expected_result: str +) -> None: + from superset.db_engine_specs.oracle import OracleEngineSpec + + assert ( + OracleEngineSpec.column_datatype_to_string(sqla_type, oracle.dialect()) + == expected_result + ) + + +def test_fetch_data_no_description() -> None: + from superset.db_engine_specs.oracle import OracleEngineSpec + + cursor = mock.MagicMock() + cursor.description = [] + assert OracleEngineSpec.fetch_data(cursor) == [] + + +def test_fetch_data() -> None: + from superset.db_engine_specs.oracle import OracleEngineSpec + + cursor = mock.MagicMock() + result = ["a", "b"] + cursor.fetchall.return_value = result + assert OracleEngineSpec.fetch_data(cursor) == result + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"), + ("DateTime", """TO_DATE('2019-01-02T03:04:05', 'YYYY-MM-DD"T"HH24:MI:SS')"""), + ( + "TimeStamp", + """TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""", + ), + ("Other", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.oracle import OracleEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py new file mode 100644 index 000000000..088ce2747 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_postgres.py @@ 
-0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from typing import Any, Dict, Optional, Type + +import pytest +from sqlalchemy import types +from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON + +from superset.utils.core import GenericDataType +from tests.unit_tests.db_engine_specs.utils import ( + assert_column_spec, + assert_convert_dttm, +) +from tests.unit_tests.fixtures.common import dttm + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"), + ( + "DateTime", + "TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')", + ), + ( + "TimeStamp", + "TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')", + ), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) + + +@pytest.mark.parametrize( + "native_type,sqla_type,attrs,generic_type,is_dttm", + [ + ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False), + ("INTEGER", types.Integer, 
None, GenericDataType.NUMERIC, False), + ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False), + ("DECIMAL", types.Numeric, None, GenericDataType.NUMERIC, False), + ("NUMERIC", types.Numeric, None, GenericDataType.NUMERIC, False), + ("REAL", types.REAL, None, GenericDataType.NUMERIC, False), + ("DOUBLE PRECISION", DOUBLE_PRECISION, None, GenericDataType.NUMERIC, False), + ("MONEY", types.Numeric, None, GenericDataType.NUMERIC, False), + # String + ("CHAR", types.String, None, GenericDataType.STRING, False), + ("VARCHAR", types.String, None, GenericDataType.STRING, False), + ("TEXT", types.String, None, GenericDataType.STRING, False), + ("ARRAY", types.String, None, GenericDataType.STRING, False), + ("ENUM", ENUM, None, GenericDataType.STRING, False), + ("JSON", JSON, None, GenericDataType.STRING, False), + # Temporal + ("DATE", types.Date, None, GenericDataType.TEMPORAL, True), + ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True), + ("TIME", types.Time, None, GenericDataType.TEMPORAL, True), + # Boolean + ("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False), + ], +) +def test_get_column_spec( + native_type: str, + sqla_type: Type[types.TypeEngine], + attrs: Optional[Dict[str, Any]], + generic_type: GenericDataType, + is_dttm: bool, +) -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec as spec + + assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm) diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index 0f0777d0c..a30fab94c 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -15,14 +15,21 @@ # specific language governing permissions and limitations # under the License. 
from datetime import datetime -from typing import Optional +from typing import Any, Dict, Optional, Type import pytest import pytz +from sqlalchemy import types + +from superset.utils.core import GenericDataType +from tests.unit_tests.db_engine_specs.utils import ( + assert_column_spec, + assert_convert_dttm, +) @pytest.mark.parametrize( - "target_type,dttm,result", + "target_type,dttm,expected_result", [ ("VARCHAR", datetime(2022, 1, 1), None), ("DATE", datetime(2022, 1, 1), "DATE '2022-01-01'"), @@ -46,9 +53,32 @@ import pytz def test_convert_dttm( target_type: str, dttm: datetime, - result: Optional[str], + expected_result: Optional[str], ) -> None: - from superset.db_engine_specs.presto import PrestoEngineSpec + from superset.db_engine_specs.presto import PrestoEngineSpec as spec - for case in (str.lower, str.upper): - assert PrestoEngineSpec.convert_dttm(case(target_type), dttm) == result + assert_convert_dttm(spec, target_type, expected_result, dttm) + + +@pytest.mark.parametrize( + "native_type,sqla_type,attrs,generic_type,is_dttm", + [ + ("varchar(255)", types.VARCHAR, {"length": 255}, GenericDataType.STRING, False), + ("varchar", types.String, None, GenericDataType.STRING, False), + ("char(255)", types.CHAR, {"length": 255}, GenericDataType.STRING, False), + ("char", types.String, None, GenericDataType.STRING, False), + ("integer", types.Integer, None, GenericDataType.NUMERIC, False), + ("time", types.Time, None, GenericDataType.TEMPORAL, True), + ("timestamp", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True), + ], +) +def test_get_column_spec( + native_type: str, + sqla_type: Type[types.TypeEngine], + attrs: Optional[Dict[str, Any]], + generic_type: GenericDataType, + is_dttm: bool, +) -> None: + from superset.db_engine_specs.presto import PrestoEngineSpec as spec + + assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm) diff --git a/tests/unit_tests/db_engine_specs/test_rockset.py 
b/tests/unit_tests/db_engine_specs/test_rockset.py new file mode 100644 index 000000000..c501dccf2 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_rockset.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from typing import Optional + +import pytest + +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm +from tests.unit_tests.fixtures.common import dttm + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Date", "DATE '2019-01-02'"), + ("DateTime", "DATETIME '2019-01-02 03:04:05.678900'"), + ("Timestamp", "TIMESTAMP '2019-01-02T03:04:05.678900'"), + ("UnknownType", None), + ], +) +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.rockset import RocksetEngineSpec as spec + + assert_convert_dttm(spec, target_type, expected_result, dttm) diff --git a/tests/unit_tests/db_engine_specs/test_snowflake.py b/tests/unit_tests/db_engine_specs/test_snowflake.py index 2f1171c33..9689428d2 100644 --- a/tests/unit_tests/db_engine_specs/test_snowflake.py +++ b/tests/unit_tests/db_engine_specs/test_snowflake.py @@ -19,28 +19,43 @@ import json from datetime import 
datetime +from typing import Optional from unittest import mock import pytest from pytest_mock import MockerFixture from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm @pytest.mark.parametrize( - "actual,expected", + "target_type,expected_result", [ - ("DATE", "TO_DATE('2019-01-02')"), - ("DATETIME", "CAST('2019-01-02T03:04:05.678900' AS DATETIME)"), - ("TIMESTAMP", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"), + ("Date", "TO_DATE('2019-01-02')"), + ("DateTime", "CAST('2019-01-02T03:04:05.678900' AS DATETIME)"), + ("TimeStamp", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"), ("TIMESTAMP_NTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"), + ("TIMESTAMP_LTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"), + ("TIMESTAMP_TZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"), + ("TIMESTAMPLTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"), + ("TIMESTAMPNTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"), + ("TIMESTAMPTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"), + ( + "TIMESTAMP WITH LOCAL TIME ZONE", + "TO_TIMESTAMP('2019-01-02T03:04:05.678900')", + ), + ("TIMESTAMP WITHOUT TIME ZONE", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"), + ("UnknownType", None), ], ) -def test_convert_dttm(actual: str, expected: str, dttm: datetime) -> None: - from superset.db_engine_specs.snowflake import SnowflakeEngineSpec +def test_convert_dttm( + target_type: str, expected_result: Optional[str], dttm: datetime +) -> None: + from superset.db_engine_specs.snowflake import SnowflakeEngineSpec as spec - assert SnowflakeEngineSpec.convert_dttm(actual, dttm) == expected + assert_convert_dttm(spec, target_type, expected_result, dttm) def test_database_connection_test_mutator() -> None: diff --git a/tests/unit_tests/db_engine_specs/test_sqlite.py b/tests/unit_tests/db_engine_specs/test_sqlite.py index 76ea4fdff..11ce174c0 100644 --- 
a/tests/unit_tests/db_engine_specs/test_sqlite.py +++ b/tests/unit_tests/db_engine_specs/test_sqlite.py @@ -16,30 +16,32 @@ # under the License. # pylint: disable=invalid-name, unused-argument, import-outside-toplevel, redefined-outer-name from datetime import datetime -from unittest import mock +from typing import Optional import pytest from sqlalchemy.engine import create_engine +from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm -def test_convert_dttm(dttm: datetime) -> None: - from superset.db_engine_specs.sqlite import SqliteEngineSpec +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("Text", "'2019-01-02 03:04:05'"), + ("DateTime", "'2019-01-02 03:04:05'"), + ("TimeStamp", "'2019-01-02 03:04:05'"), + ("Other", None), + ], +) +def test_convert_dttm( + target_type: str, + expected_result: Optional[str], + dttm: datetime, +) -> None: + from superset.db_engine_specs.sqlite import SqliteEngineSpec as spec - assert SqliteEngineSpec.convert_dttm("TEXT", dttm) == "'2019-01-02 03:04:05'" - - -def test_convert_dttm_lower(dttm: datetime) -> None: - from superset.db_engine_specs.sqlite import SqliteEngineSpec - - assert SqliteEngineSpec.convert_dttm("text", dttm) == "'2019-01-02 03:04:05'" - - -def test_convert_dttm_invalid_type(dttm: datetime) -> None: - from superset.db_engine_specs.sqlite import SqliteEngineSpec - - assert SqliteEngineSpec.convert_dttm("other", dttm) is None + assert_convert_dttm(spec, target_type, expected_result, dttm) @pytest.mark.parametrize( diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 382b65ce5..0ea296a07 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -16,17 +16,288 @@ # under the License. 
# pylint: disable=unused-argument, import-outside-toplevel, protected-access import json -from typing import Any, Dict -from unittest import mock +from datetime import datetime +from typing import Any, Dict, Optional, Type +from unittest.mock import Mock, patch +import pandas as pd import pytest from pytest_mock import MockerFixture +from sqlalchemy import types -from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY +import superset.config +from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT +from superset.utils.core import GenericDataType +from tests.unit_tests.db_engine_specs.utils import ( + assert_column_spec, + assert_convert_dttm, +) +from tests.unit_tests.fixtures.common import dttm -@mock.patch("sqlalchemy.engine.Engine.connect") -def test_cancel_query_success(engine_mock: mock.Mock) -> None: +@pytest.mark.parametrize( + "extra,expected", + [ + ({}, {"engine_params": {"connect_args": {"source": USER_AGENT}}}), + ( + { + "first": 1, + "engine_params": { + "second": "two", + "connect_args": {"source": "foobar", "third": "three"}, + }, + }, + { + "first": 1, + "engine_params": { + "second": "two", + "connect_args": {"source": "foobar", "third": "three"}, + }, + }, + ), + ], +) +def test_get_extra_params(extra: Dict[str, Any], expected: Dict[str, Any]) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + + database = Mock() + + database.extra = json.dumps(extra) + database.server_cert = None + assert TrinoEngineSpec.get_extra_params(database) == expected + + +@patch("superset.utils.core.create_ssl_cert_file") +def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + + database = Mock() + + database.extra = json.dumps({}) + database.server_cert = "TEST_CERT" + mock_create_ssl_cert_file.return_value = "/path/to/tls.crt" + extra = TrinoEngineSpec.get_extra_params(database) + + connect_args = 
extra.get("engine_params", {}).get("connect_args", {}) + assert connect_args.get("http_scheme") == "https" + assert connect_args.get("verify") == "/path/to/tls.crt" + mock_create_ssl_cert_file.assert_called_once_with(database.server_cert) + + +@patch("trino.auth.BasicAuthentication") +def test_auth_basic(mock_auth: Mock) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + + database = Mock() + + auth_params = {"username": "username", "password": "password"} + database.encrypted_extra = json.dumps( + {"auth_method": "basic", "auth_params": auth_params} + ) + + params: Dict[str, Any] = {} + TrinoEngineSpec.update_params_from_encrypted_extra(database, params) + connect_args = params.setdefault("connect_args", {}) + assert connect_args.get("http_scheme") == "https" + mock_auth.assert_called_once_with(**auth_params) + + +@patch("trino.auth.KerberosAuthentication") +def test_auth_kerberos(mock_auth: Mock) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + + database = Mock() + + auth_params = { + "service_name": "superset", + "mutual_authentication": False, + "delegate": True, + } + database.encrypted_extra = json.dumps( + {"auth_method": "kerberos", "auth_params": auth_params} + ) + + params: Dict[str, Any] = {} + TrinoEngineSpec.update_params_from_encrypted_extra(database, params) + connect_args = params.setdefault("connect_args", {}) + assert connect_args.get("http_scheme") == "https" + mock_auth.assert_called_once_with(**auth_params) + + +@patch("trino.auth.CertificateAuthentication") +def test_auth_certificate(mock_auth: Mock) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + + database = Mock() + auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"} + database.encrypted_extra = json.dumps( + {"auth_method": "certificate", "auth_params": auth_params} + ) + + params: Dict[str, Any] = {} + TrinoEngineSpec.update_params_from_encrypted_extra(database, params) + connect_args = 
params.setdefault("connect_args", {}) + assert connect_args.get("http_scheme") == "https" + mock_auth.assert_called_once_with(**auth_params) + + +@patch("trino.auth.JWTAuthentication") +def test_auth_jwt(mock_auth: Mock) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + + database = Mock() + + auth_params = {"token": "jwt-token-string"} + database.encrypted_extra = json.dumps( + {"auth_method": "jwt", "auth_params": auth_params} + ) + + params: Dict[str, Any] = {} + TrinoEngineSpec.update_params_from_encrypted_extra(database, params) + connect_args = params.setdefault("connect_args", {}) + assert connect_args.get("http_scheme") == "https" + mock_auth.assert_called_once_with(**auth_params) + + +def test_auth_custom_auth() -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + + database = Mock() + auth_class = Mock() + + auth_method = "custom_auth" + auth_params = {"params1": "params1", "params2": "params2"} + database.encrypted_extra = json.dumps( + {"auth_method": auth_method, "auth_params": auth_params} + ) + + with patch.dict( + "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS", + {"trino": {"custom_auth": auth_class}}, + clear=True, + ): + params: Dict[str, Any] = {} + TrinoEngineSpec.update_params_from_encrypted_extra(database, params) + + connect_args = params.setdefault("connect_args", {}) + assert connect_args.get("http_scheme") == "https" + + auth_class.assert_called_once_with(**auth_params) + + +def test_auth_custom_auth_denied() -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + + database = Mock() + auth_method = "my.module:TrinoAuthClass" + auth_params = {"params1": "params1", "params2": "params2"} + database.encrypted_extra = json.dumps( + {"auth_method": auth_method, "auth_params": auth_params} + ) + + superset.config.ALLOWED_EXTRA_AUTHENTICATIONS = {} + + with pytest.raises(ValueError) as excinfo: + TrinoEngineSpec.update_params_from_encrypted_extra(database, {}) + + assert str(excinfo.value) 
== ( + f"For security reason, custom authentication '{auth_method}' " + f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config" + ) + + +@pytest.mark.parametrize( + "native_type,sqla_type,attrs,generic_type,is_dttm", + [ + ("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False), + ("TINYINT", types.Integer, None, GenericDataType.NUMERIC, False), + ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False), + ("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False), + ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False), + ("REAL", types.FLOAT, None, GenericDataType.NUMERIC, False), + ("DOUBLE", types.FLOAT, None, GenericDataType.NUMERIC, False), + ("DECIMAL", types.DECIMAL, None, GenericDataType.NUMERIC, False), + ("VARCHAR", types.String, None, GenericDataType.STRING, False), + ("VARCHAR(20)", types.VARCHAR, {"length": 20}, GenericDataType.STRING, False), + ("CHAR", types.String, None, GenericDataType.STRING, False), + ("CHAR(2)", types.CHAR, {"length": 2}, GenericDataType.STRING, False), + ("JSON", types.JSON, None, GenericDataType.STRING, False), + ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True), + ("TIMESTAMP(3)", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True), + ( + "TIMESTAMP WITH TIME ZONE", + types.TIMESTAMP, + None, + GenericDataType.TEMPORAL, + True, + ), + ( + "TIMESTAMP(3) WITH TIME ZONE", + types.TIMESTAMP, + None, + GenericDataType.TEMPORAL, + True, + ), + ("DATE", types.Date, None, GenericDataType.TEMPORAL, True), + ], +) +def test_get_column_spec( + native_type: str, + sqla_type: Type[types.TypeEngine], + attrs: Optional[Dict[str, Any]], + generic_type: GenericDataType, + is_dttm: bool, +) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec as spec + + assert_column_spec( + spec, + native_type, + sqla_type, + attrs, + generic_type, + is_dttm, + ) + + +@pytest.mark.parametrize( + "target_type,expected_result", + [ + ("TimeStamp", "TIMESTAMP '2019-01-02 
03:04:05.678900'"), + ("TimeStamp(3)", "TIMESTAMP '2019-01-02 03:04:05.678900'"), + ("TimeStamp With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"), + ("TimeStamp(3) With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"), + ("Date", "DATE '2019-01-02'"), + ("Other", None), + ], +) +def test_convert_dttm( + target_type: str, + expected_result: Optional[str], + dttm: datetime, +) -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + + assert_convert_dttm(TrinoEngineSpec, target_type, expected_result, dttm) + + +def test_extra_table_metadata() -> None: + from superset.db_engine_specs.trino import TrinoEngineSpec + + db_mock = Mock() + db_mock.get_indexes = Mock( + return_value=[{"column_names": ["ds", "hour"], "name": "partition"}] + ) + db_mock.get_extra = Mock(return_value={}) + db_mock.has_view_by_name = Mock(return_value=None) + db_mock.get_df = Mock(return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})) + result = TrinoEngineSpec.extra_table_metadata(db_mock, "test_table", "test_schema") + assert result["partitions"]["cols"] == ["ds", "hour"] + assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1} + + +@patch("sqlalchemy.engine.Engine.connect") +def test_cancel_query_success(engine_mock: Mock) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query @@ -35,8 +306,8 @@ def test_cancel_query_success(engine_mock: mock.Mock) -> None: assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is True -@mock.patch("sqlalchemy.engine.Engine.connect") -def test_cancel_query_failed(engine_mock: mock.Mock) -> None: +@patch("sqlalchemy.engine.Engine.connect") +def test_cancel_query_failed(engine_mock: Mock) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec from superset.models.sql_lab import Query @@ -67,11 +338,11 @@ def test_prepare_cancel_query( @pytest.mark.parametrize("cancel_early", [True, False]) 
-@mock.patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") -@mock.patch("sqlalchemy.engine.Engine.connect") +@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query") +@patch("sqlalchemy.engine.Engine.connect") def test_handle_cursor_early_cancel( - engine_mock: mock.Mock, - cancel_query_mock: mock.Mock, + engine_mock: Mock, + cancel_query_mock: Mock, cancel_early: bool, mocker: MockerFixture, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/utils.py b/tests/unit_tests/db_engine_specs/utils.py new file mode 100644 index 000000000..13ae7a34d --- /dev/null +++ b/tests/unit_tests/db_engine_specs/utils.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from __future__ import annotations

from datetime import datetime
from typing import Any, Dict, Optional, Type, TYPE_CHECKING

from sqlalchemy import types

from superset.utils.core import GenericDataType

if TYPE_CHECKING:
    from superset.db_engine_specs.base import BaseEngineSpec


def assert_convert_dttm(
    db_engine_spec: Type[BaseEngineSpec],
    target_type: str,
    expected_result: Optional[str],
    dttm: datetime,
    db_extra: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Check ``convert_dttm`` of an engine spec against several casings of a type.

    ``convert_dttm`` is called with ``target_type`` exactly as given and with its
    upper-case, lower-case and capitalized forms; every call has to return
    ``expected_result``.

    :param db_engine_spec: Engine spec class whose ``convert_dttm`` is called
    :param target_type: Column type name handed to ``convert_dttm``
    :param expected_result: Value each call is expected to return
    :param dttm: Datetime forwarded to ``convert_dttm``
    :param db_extra: Optional mapping forwarded to ``convert_dttm``
    """
    # the name as given first, then the three re-cased forms, in that order
    casings = [target_type]
    casings.extend(
        getattr(target_type, method)() for method in ("upper", "lower", "capitalize")
    )

    for candidate in casings:
        result = db_engine_spec.convert_dttm(
            target_type=candidate,
            dttm=dttm,
            db_extra=db_extra,
        )
        # the value actually returned doubles as the assertion message
        assert result == expected_result, result


def assert_column_spec(
    db_engine_spec: Type[BaseEngineSpec],
    native_type: str,
    sqla_type: Type[types.TypeEngine],
    attrs: Optional[Dict[str, Any]],
    generic_type: GenericDataType,
    is_dttm: bool,
) -> None:
    """
    Check the column spec an engine spec returns for a native type.

    ``get_column_spec`` must return a non-``None`` spec whose ``sqla_type`` is an
    instance of ``sqla_type``, whose ``sqla_type`` carries every attribute value
    listed in ``attrs``, and whose ``generic_type`` and ``is_dttm`` equal the
    expected values.

    :param db_engine_spec: Engine spec class whose ``get_column_spec`` is called
    :param native_type: Native column type handed to ``get_column_spec``
    :param sqla_type: Class the returned ``sqla_type`` must be an instance of
    :param attrs: Attribute name/value pairs expected on the returned
        ``sqla_type``; ``None`` or an empty mapping skips this check
    :param generic_type: Expected ``generic_type`` of the column spec
    :param is_dttm: Expected ``is_dttm`` flag of the column spec
    """
    column_spec = db_engine_spec.get_column_spec(native_type)
    assert column_spec is not None

    resolved_type = column_spec.sqla_type
    assert isinstance(resolved_type, sqla_type)

    if attrs:
        for attr_name, expected_value in attrs.items():
            assert getattr(resolved_type, attr_name) == expected_value

    assert column_spec.generic_type == generic_type
    assert column_spec.is_dttm == is_dttm