chore(db_engine_specs): clean up column spec logic and add tests (#22871)

This commit is contained in:
Ville Brofeldt 2023-01-31 15:54:07 +02:00 committed by GitHub
parent 8466eec228
commit cd6fc35f60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
73 changed files with 1953 additions and 1463 deletions

View File

@ -19,10 +19,10 @@ from datetime import datetime
from typing import Any, Dict, Optional, Pattern, Tuple
from flask_babel import gettext as __
from sqlalchemy import types
from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.utils import core as utils
SYNTAX_ERROR_REGEX = re.compile(
": mismatched input '(?P<syntax_error>.*?)'. Expecting: "
@ -66,10 +66,11 @@ class AthenaEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"DATE '{dttm.date().isoformat()}'"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
datetime_formatted = dttm.isoformat(sep=" ", timespec="milliseconds")
return f"""TIMESTAMP '{datetime_formatted}'"""
return None

View File

@ -197,7 +197,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
_default_column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
@ -314,6 +314,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
GenericDataType.BOOLEAN,
),
)
# engine-specific type mappings to check prior to the defaults
column_type_mappings: Tuple[ColumnTypeMapping, ...] = ()
# Does database support join-free timeslot grouping
time_groupby_inline = False
@ -1389,24 +1391,25 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return label_mutated
@classmethod
def get_sqla_column_type(
def get_column_types(
cls,
column_type: Optional[str],
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[Tuple[TypeEngine, GenericDataType]]:
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
SQLAlchemy). Override `column_type_mappings` for specific needs
Return a sqlalchemy native column type and generic data type that corresponds
to the column type defined in the data source (return None to use default type
inferred by SQLAlchemy). Override `column_type_mappings` for specific needs
(see MSSQL for example of NCHAR/NVARCHAR handling).
:param column_type: Column type returned by inspector
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: SqlAlchemy column type
:return: SQLAlchemy and generic Superset column types
"""
if not column_type:
return None
for regex, sqla_type, generic_type in column_type_mappings:
for regex, sqla_type, generic_type in (
cls.column_type_mappings + cls._default_column_type_mappings
):
match = regex.match(column_type)
if not match:
continue
@ -1569,19 +1572,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
"""
Converts native database type to sqlalchemy column type.
Get generic type related specs regarding a native column type.
:param native_type: Native database type
:param db_extra: The database extra object
:param source: Type coming from the database table or cursor description
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: ColumnSpec object
"""
col_types = cls.get_sqla_column_type(
native_type, column_type_mappings=column_type_mappings
)
col_types = cls.get_column_types(native_type)
if col_types:
column_type, generic_type = col_types
is_dttm = generic_type == GenericDataType.TEMPORAL
@ -1590,6 +1590,28 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
)
return None
@classmethod
def get_sqla_column_type(
    cls,
    native_type: Optional[str],
    db_extra: Optional[Dict[str, Any]] = None,
    source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
) -> Optional[TypeEngine]:
    """
    Converts native database type to sqlalchemy column type.

    Thin convenience wrapper around ``get_column_spec`` that discards the
    generic-type metadata and returns only the SQLAlchemy type.

    :param native_type: Native database type
    :param db_extra: The database extra object
    :param source: Type coming from the database table or cursor description
    :return: SQLAlchemy column type, or None if no mapping matched
    """
    column_spec = cls.get_column_spec(
        native_type=native_type,
        db_extra=db_extra,
        source=source,
    )
    # get_column_spec returns None when the native type is unknown; propagate
    # that so callers can fall back to SQLAlchemy's own type inference.
    return column_spec.sqla_type if column_spec else None
# pylint: disable=unused-argument
@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:

View File

@ -26,7 +26,7 @@ from apispec.ext.marshmallow import MarshmallowPlugin
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
from sqlalchemy import column
from sqlalchemy import column, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql import sqltypes
from typing_extensions import TypedDict
@ -201,15 +201,15 @@ class BigQueryEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CAST('{dttm.date().isoformat()}' AS DATE)"
if tt == utils.TemporalType.DATETIME:
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)"""
if tt == utils.TemporalType.TIME:
return f"""CAST('{dttm.strftime("%H:%M:%S.%f")}' AS TIME)"""
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)"""
if isinstance(sqla_type, types.DateTime):
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)"""
if isinstance(sqla_type, types.Time):
return f"""CAST('{dttm.strftime("%H:%M:%S.%f")}' AS TIME)"""
return None
@classmethod

View File

@ -18,12 +18,12 @@ import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
from sqlalchemy import types
from urllib3.exceptions import NewConnectionError
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.extensions import cache_manager
from superset.utils import core as utils
if TYPE_CHECKING:
# prevent circular imports
@ -77,10 +77,11 @@ class ClickHouseEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"toDate('{dttm.date().isoformat()}')"
if tt == utils.TemporalType.DATETIME:
if isinstance(sqla_type, types.DateTime):
return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')"""
return None

View File

@ -14,11 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, Optional, TYPE_CHECKING
from sqlalchemy import types
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
@ -53,12 +56,13 @@ class CrateEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.TIMESTAMP:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.TIMESTAMP):
return f"{dttm.timestamp() * 1000}"
return None
@classmethod
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
def alter_new_orm_column(cls, orm_col: TableColumn) -> None:
if orm_col.type == "TIMESTAMP":
orm_col.python_date_format = "epoch_ms"

View File

@ -17,8 +17,9 @@
from datetime import datetime
from typing import Any, Dict, Optional
from sqlalchemy import types
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils
class DremioEngineSpec(BaseEngineSpec):
@ -46,10 +47,11 @@ class DremioEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
dttm_formatted = dttm.isoformat(sep=" ", timespec="milliseconds")
return f"""TO_TIMESTAMP('{dttm_formatted}', 'YYYY-MM-DD HH24:MI:SS.FFF')"""
return None

View File

@ -18,11 +18,11 @@ from datetime import datetime
from typing import Any, Dict, Optional
from urllib import parse
from sqlalchemy import types
from sqlalchemy.engine.url import URL
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIProgrammingError
from superset.utils import core as utils
class DrillEngineSpec(BaseEngineSpec):
@ -59,10 +59,11 @@ class DrillEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"TO_DATE('{dttm.date().isoformat()}', 'yyyy-MM-dd')"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
datetime_formatted = dttm.isoformat(sep=" ", timespec="seconds")
return f"""TO_TIMESTAMP('{datetime_formatted}', 'yyyy-MM-dd HH:mm:ss')"""
return None

View File

@ -14,11 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from superset import is_feature_enabled
@ -70,12 +74,12 @@ class DruidEngineSpec(BaseEngineSpec):
}
@classmethod
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
def alter_new_orm_column(cls, orm_col: TableColumn) -> None:
if orm_col.column_name == "__time":
orm_col.is_dttm = True
@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
def get_extra_params(database: Database) -> Dict[str, Any]:
"""
For Druid, the path to a SSL certificate is placed in `connect_args`.
@ -102,10 +106,11 @@ class DruidEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CAST(TIME_PARSE('{dttm.date().isoformat()}') AS DATE)"
if tt in (utils.TemporalType.DATETIME, utils.TemporalType.TIMESTAMP):
if isinstance(sqla_type, (types.DateTime, types.TIMESTAMP)):
return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""
return None

View File

@ -21,11 +21,11 @@ from datetime import datetime
from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.utils import core as utils
if TYPE_CHECKING:
# prevent circular imports
@ -67,8 +67,9 @@ class DuckDBEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt in (utils.TemporalType.TEXT, utils.TemporalType.DATETIME):
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, (types.String, types.DateTime)):
return f"""'{dttm.isoformat(sep=" ", timespec="microseconds")}'"""
return None

View File

@ -17,8 +17,9 @@
from datetime import datetime
from typing import Any, Dict, Optional
from sqlalchemy import types
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils
class DynamoDBEngineSpec(BaseEngineSpec):
@ -56,7 +57,9 @@ class DynamoDBEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt in (utils.TemporalType.TEXT, utils.TemporalType.DATETIME):
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, (types.String, types.DateTime)):
return f"""'{dttm.isoformat(sep=" ", timespec="seconds")}'"""
return None

View File

@ -19,13 +19,14 @@ from datetime import datetime
from distutils.version import StrictVersion
from typing import Any, Dict, Optional, Type
from sqlalchemy import types
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import (
SupersetDBAPIDatabaseError,
SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError,
)
from superset.utils import core as utils
logger = logging.getLogger()
@ -68,7 +69,10 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho
) -> Optional[str]:
db_extra = db_extra or {}
if target_type.upper() == utils.TemporalType.DATETIME:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.DateTime):
es_version = db_extra.get("version")
# The elasticsearch CAST function does not take effect for the time zone
# setting. In elasticsearch7.8 and above, we can use the DATETIME_PARSE
@ -119,7 +123,9 @@ class OpenDistroEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
if target_type.upper() == utils.TemporalType.DATETIME:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.DateTime):
return f"""'{dttm.isoformat(timespec="seconds")}'"""
return None

View File

@ -17,8 +17,9 @@
from datetime import datetime
from typing import Any, Dict, Optional
from sqlalchemy import types
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.utils import core as utils
class FirebirdEngineSpec(BaseEngineSpec):
@ -73,13 +74,14 @@ class FirebirdEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.TIMESTAMP:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CAST('{dttm.date().isoformat()}' AS DATE)"
if isinstance(sqla_type, types.DateTime):
dttm_formatted = dttm.isoformat(sep=" ")
dttm_valid_precision = dttm_formatted[: len("YYYY-MM-DD HH:MM:SS.MMMM")]
return f"CAST('{dttm_valid_precision}' AS TIMESTAMP)"
if tt == utils.TemporalType.DATE:
return f"CAST('{dttm.date().isoformat()}' AS DATE)"
if tt == utils.TemporalType.TIME:
if isinstance(sqla_type, types.Time):
return f"CAST('{dttm.time().isoformat()}' AS TIME)"
return None

View File

@ -17,8 +17,9 @@
from datetime import datetime
from typing import Any, Dict, Optional
from sqlalchemy import types
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils
class FireboltEngineSpec(BaseEngineSpec):
@ -44,13 +45,14 @@ class FireboltEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CAST('{dttm.date().isoformat()}' AS DATE)"
if tt == utils.TemporalType.DATETIME:
return f"""CAST('{dttm.isoformat(timespec="seconds")}' AS DATETIME)"""
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
return f"""CAST('{dttm.isoformat(timespec="seconds")}' AS TIMESTAMP)"""
if isinstance(sqla_type, types.DateTime):
return f"""CAST('{dttm.isoformat(timespec="seconds")}' AS DATETIME)"""
return None
@classmethod

View File

@ -17,9 +17,10 @@
from datetime import datetime
from typing import Any, Dict, Optional
from sqlalchemy import types
from superset.db_engine_specs.base import LimitMethod
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.utils import core as utils
class HanaEngineSpec(PostgresBaseEngineSpec):
@ -46,10 +47,11 @@ class HanaEngineSpec(PostgresBaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
return f"""TO_TIMESTAMP('{dttm
.isoformat(timespec="microseconds")}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
return None

View File

@ -30,7 +30,7 @@ import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from flask import current_app, g
from sqlalchemy import Column, text
from sqlalchemy import Column, text, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
@ -45,7 +45,6 @@ from superset.exceptions import SupersetException
from superset.extensions import cache_manager
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
from superset.utils import core as utils
if TYPE_CHECKING:
# prevent circular imports
@ -249,10 +248,11 @@ class HiveEngineSpec(PrestoEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CAST('{dttm.date().isoformat()}' AS DATE)"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
return f"""CAST('{dttm
.isoformat(sep=" ", timespec="microseconds")}' AS TIMESTAMP)"""
return None

View File

@ -21,13 +21,13 @@ from datetime import datetime
from typing import Any, Dict, List, Optional
from flask import current_app
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import Session
from superset.constants import QUERY_EARLY_CANCEL_KEY
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.sql_lab import Query
from superset.utils import core as utils
logger = logging.getLogger(__name__)
# Query 5543ffdf692b7d02:f78a944000000000: 3% Complete (17 out of 547)
@ -59,10 +59,11 @@ class ImpalaEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CAST('{dttm.date().isoformat()}' AS DATE)"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)"""
return None

View File

@ -14,9 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Type
from sqlalchemy import types
from sqlalchemy.dialects.mssql.base import SMALLDATETIME
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.db_engine_specs.exceptions import (
SupersetDBAPIDatabaseError,
@ -24,7 +28,7 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIProgrammingError,
)
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils
from superset.utils.core import GenericDataType
class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@ -59,6 +63,14 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
column_type_mappings = (
(
re.compile(r"^smalldatetime.*", re.IGNORECASE),
SMALLDATETIME(),
GenericDataType.TEMPORAL,
),
)
@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
# pylint: disable=import-outside-toplevel,import-error
@ -74,18 +86,19 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CONVERT(DATE, '{dttm.date().isoformat()}', 23)"
if tt == utils.TemporalType.DATETIME:
datetime_formatted = dttm.isoformat(timespec="milliseconds")
return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)"""
if tt == utils.TemporalType.SMALLDATETIME:
datetime_formatted = dttm.isoformat(sep=" ", timespec="seconds")
return f"""CONVERT(SMALLDATETIME, '{datetime_formatted}', 20)"""
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
datetime_formatted = dttm.isoformat(sep=" ", timespec="seconds")
return f"""CONVERT(TIMESTAMP, '{datetime_formatted}', 20)"""
if isinstance(sqla_type, SMALLDATETIME):
datetime_formatted = dttm.isoformat(sep=" ", timespec="seconds")
return f"""CONVERT(SMALLDATETIME, '{datetime_formatted}', 20)"""
if isinstance(sqla_type, types.DateTime):
datetime_formatted = dttm.isoformat(timespec="milliseconds")
return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)"""
return None
@classmethod
@ -132,13 +145,12 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
if target_type.upper() in [
utils.TemporalType.DATETIME,
utils.TemporalType.TIMESTAMP,
]:
return f"""datetime({dttm.isoformat(timespec="microseconds")})"""
if target_type.upper() == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"""datetime({dttm.date().isoformat()})"""
if isinstance(sqla_type, types.DateTime):
return f"""datetime({dttm.isoformat(timespec="microseconds")})"""
return None

View File

@ -17,8 +17,9 @@
from datetime import datetime
from typing import Any, Dict, Optional
from sqlalchemy import types
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils
class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@ -43,10 +44,11 @@ class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CAST('{dttm.date().isoformat()}' AS DATE)"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
datetime_fomatted = dttm.isoformat(sep=" ", timespec="seconds")
return f"""CAST('{datetime_fomatted}' AS TIMESTAMP)"""
return None

View File

@ -20,10 +20,12 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Tuple
from flask_babel import gettext as __
from sqlalchemy import types
from sqlalchemy.dialects.mssql.base import SMALLDATETIME
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.errors import SupersetErrorType
from superset.utils import core as utils
from superset.utils.core import GenericDataType
logger = logging.getLogger(__name__)
@ -70,6 +72,13 @@ class MssqlEngineSpec(BaseEngineSpec):
"1969-12-29T00:00:00Z/P1W": "DATEADD(WEEK,"
" DATEDIFF(WEEK, 0, DATEADD(DAY, -1, {col})), 0)",
}
column_type_mappings = (
(
re.compile(r"^smalldatetime.*", re.IGNORECASE),
SMALLDATETIME(),
GenericDataType.TEMPORAL,
),
)
custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
CONNECTION_ACCESS_DENIED_REGEX: (
@ -108,15 +117,16 @@ class MssqlEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CONVERT(DATE, '{dttm.date().isoformat()}', 23)"
if tt == utils.TemporalType.DATETIME:
datetime_formatted = dttm.isoformat(timespec="milliseconds")
return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)"""
if tt == utils.TemporalType.SMALLDATETIME:
if isinstance(sqla_type, SMALLDATETIME):
datetime_formatted = dttm.isoformat(sep=" ", timespec="seconds")
return f"""CONVERT(SMALLDATETIME, '{datetime_formatted}', 20)"""
if isinstance(sqla_type, types.DateTime):
datetime_formatted = dttm.isoformat(timespec="milliseconds")
return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)"""
return None
@classmethod

View File

@ -20,6 +20,7 @@ from typing import Any, Dict, Optional, Pattern, Tuple
from urllib import parse
from flask_babel import gettext as __
from sqlalchemy import types
from sqlalchemy.dialects.mysql import (
BIT,
DECIMAL,
@ -34,15 +35,10 @@ from sqlalchemy.dialects.mysql import (
)
from sqlalchemy.engine.url import URL
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
ColumnTypeMapping,
)
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.errors import SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
from superset.utils.core import GenericDataType
# Regular expressions to catch custom errors
CONNECTION_ACCESS_DENIED_REGEX = re.compile(
@ -182,10 +178,11 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"STR_TO_DATE('{dttm.date().isoformat()}', '%Y-%m-%d')"
if tt == utils.TemporalType.DATETIME:
if isinstance(sqla_type, types.DateTime):
datetime_formatted = dttm.isoformat(sep=" ", timespec="microseconds")
return f"""STR_TO_DATE('{datetime_formatted}', '%Y-%m-%d %H:%i:%s.%f')"""
return None
@ -232,23 +229,6 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
pass
return message
@classmethod
def get_column_spec(
cls,
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
column_spec = super().get_column_spec(native_type)
if column_spec:
return column_spec
return super().get_column_spec(
native_type, column_type_mappings=column_type_mappings
)
@classmethod
def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
"""

View File

@ -17,8 +17,9 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import types
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.utils import core as utils
class OracleEngineSpec(BaseEngineSpec):
@ -44,15 +45,16 @@ class OracleEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')"
if tt == utils.TemporalType.DATETIME:
datetime_formatted = dttm.isoformat(timespec="seconds")
return f"""TO_DATE('{datetime_formatted}', 'YYYY-MM-DD"T"HH24:MI:SS')"""
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
return f"""TO_TIMESTAMP('{dttm
.isoformat(timespec="microseconds")}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
if isinstance(sqla_type, types.DateTime):
datetime_formatted = dttm.isoformat(timespec="seconds")
return f"""TO_DATE('{datetime_formatted}', 'YYYY-MM-DD"T"HH24:MI:SS')"""
return None
@classmethod

View File

@ -21,20 +21,16 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Set, Tuple, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
from sqlalchemy.types import String
from sqlalchemy.types import Date, DateTime, String
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
ColumnTypeMapping,
)
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.errors import SupersetErrorType
from superset.exceptions import SupersetException
from superset.models.sql_lab import Query
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
from superset.utils.core import GenericDataType
if TYPE_CHECKING:
from superset.models.core import Database # pragma: no cover
@ -185,7 +181,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
),
(
re.compile(r"^array.*", re.IGNORECASE),
lambda match: ARRAY(int(match[2])) if match[2] else String(),
String(),
GenericDataType.STRING,
),
(
@ -238,10 +234,11 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, Date):
return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')"
if "TIMESTAMP" in tt or "DATETIME" in tt:
if isinstance(sqla_type, DateTime):
dttm_formatted = dttm.isoformat(sep=" ", timespec="microseconds")
return f"""TO_TIMESTAMP('{dttm_formatted}', 'YYYY-MM-DD HH24:MI:SS.US')"""
return None
@ -270,23 +267,6 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
extra["engine_params"] = engine_params
return extra
@classmethod
def get_column_spec(
cls,
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
column_spec = super().get_column_spec(native_type)
if column_spec:
return column_spec
return super().get_column_spec(
native_type, column_type_mappings=column_type_mappings
)
@classmethod
def get_datatype(cls, type_code: Any) -> Optional[str]:
# pylint: disable=import-outside-toplevel

View File

@ -54,7 +54,7 @@ from sqlalchemy.sql.expression import ColumnClause, Select
from superset import cache_manager, is_feature_enabled
from superset.common.db_query_status import QueryStatus
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec, ColumnTypeMapping
from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
@ -70,7 +70,7 @@ from superset.models.sql_types.presto_sql_types import (
from superset.result_set import destringify
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
from superset.utils.core import GenericDataType
if TYPE_CHECKING:
# prevent circular imports
@ -165,6 +165,92 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
A base class that share common functions between Presto and Trino
"""
column_type_mappings = (
(
re.compile(r"^boolean.*", re.IGNORECASE),
types.BOOLEAN(),
GenericDataType.BOOLEAN,
),
(
re.compile(r"^tinyint.*", re.IGNORECASE),
TinyInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^smallint.*", re.IGNORECASE),
types.SmallInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^integer.*", re.IGNORECASE),
types.INTEGER(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^bigint.*", re.IGNORECASE),
types.BigInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^real.*", re.IGNORECASE),
types.FLOAT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^double.*", re.IGNORECASE),
types.FLOAT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^decimal.*", re.IGNORECASE),
types.DECIMAL(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.CHAR(int(match[2])) if match[2] else types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^varbinary.*", re.IGNORECASE),
types.VARBINARY(),
GenericDataType.STRING,
),
(
re.compile(r"^json.*", re.IGNORECASE),
types.JSON(),
GenericDataType.STRING,
),
(
re.compile(r"^date.*", re.IGNORECASE),
types.Date(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^timestamp.*", re.IGNORECASE),
types.TIMESTAMP(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^interval.*", re.IGNORECASE),
Interval(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^time.*", re.IGNORECASE),
types.Time(),
GenericDataType.TEMPORAL,
),
(re.compile(r"^array.*", re.IGNORECASE), Array(), GenericDataType.STRING),
(re.compile(r"^map.*", re.IGNORECASE), Map(), GenericDataType.STRING),
(re.compile(r"^row.*", re.IGNORECASE), Row(), GenericDataType.STRING),
)
# pylint: disable=line-too-long
_time_grain_expressions = {
None: "{col}",
@ -199,14 +285,13 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
Superset only defines time zone naive `datetime` objects, though this method
handles both time zone naive and aware conversions.
"""
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"DATE '{dttm.date().isoformat()}'"
if tt in (
utils.TemporalType.TIMESTAMP,
utils.TemporalType.TIMESTAMP_WITH_TIME_ZONE,
):
if isinstance(sqla_type, types.TIMESTAMP):
return f"""TIMESTAMP '{dttm.isoformat(timespec="microseconds", sep=" ")}'"""
return None
@classmethod
@ -827,92 +912,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
full_table = "{}.{}".format(quote(schema), full_table)
return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
column_type_mappings = (
(
re.compile(r"^boolean.*", re.IGNORECASE),
types.BOOLEAN,
GenericDataType.BOOLEAN,
),
(
re.compile(r"^tinyint.*", re.IGNORECASE),
TinyInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^smallint.*", re.IGNORECASE),
types.SMALLINT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^integer.*", re.IGNORECASE),
types.INTEGER(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^bigint.*", re.IGNORECASE),
types.BIGINT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^real.*", re.IGNORECASE),
types.FLOAT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^double.*", re.IGNORECASE),
types.FLOAT(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^decimal.*", re.IGNORECASE),
types.DECIMAL(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
GenericDataType.STRING,
),
(
re.compile(r"^varbinary.*", re.IGNORECASE),
types.VARBINARY(),
GenericDataType.STRING,
),
(
re.compile(r"^json.*", re.IGNORECASE),
types.JSON(),
GenericDataType.STRING,
),
(
re.compile(r"^date.*", re.IGNORECASE),
types.DATE(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^timestamp.*", re.IGNORECASE),
types.TIMESTAMP(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^interval.*", re.IGNORECASE),
Interval(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^time.*", re.IGNORECASE),
types.Time(),
GenericDataType.TEMPORAL,
),
(re.compile(r"^array.*", re.IGNORECASE), Array(), GenericDataType.STRING),
(re.compile(r"^map.*", re.IGNORECASE), Map(), GenericDataType.STRING),
(re.compile(r"^row.*", re.IGNORECASE), Row(), GenericDataType.STRING),
)
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
@ -1282,24 +1281,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return error_dict.get("message", _("Unknown Presto Error"))
return utils.error_msg_from_exception(ex)
@classmethod
def get_column_spec(
cls,
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
column_spec = super().get_column_spec(
native_type, column_type_mappings=column_type_mappings
)
if column_spec:
return column_spec
return super().get_column_spec(native_type)
@classmethod
def has_implicit_cancel(cls) -> bool:
"""

View File

@ -17,8 +17,9 @@
from datetime import datetime
from typing import Any, Dict, Optional, TYPE_CHECKING
from sqlalchemy import types
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
@ -53,15 +54,16 @@ class RocksetEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"DATE '{dttm.date().isoformat()}'"
if tt == utils.TemporalType.DATETIME:
dttm_formatted = dttm.isoformat(sep=" ", timespec="microseconds")
return f"""DATETIME '{dttm_formatted}'"""
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
dttm_formatted = dttm.isoformat(timespec="microseconds")
return f"""TIMESTAMP '{dttm_formatted}'"""
if isinstance(sqla_type, types.DateTime):
dttm_formatted = dttm.isoformat(sep=" ", timespec="microseconds")
return f"""DATETIME '{dttm_formatted}'"""
return None
@classmethod

View File

@ -28,6 +28,7 @@ from cryptography.hazmat.primitives import serialization
from flask import current_app
from flask_babel import gettext as __
from marshmallow import fields, Schema
from sqlalchemy import types
from sqlalchemy.engine.url import URL
from typing_extensions import TypedDict
@ -37,7 +38,6 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils import core as utils
if TYPE_CHECKING:
from superset.models.core import Database
@ -157,13 +157,14 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"TO_DATE('{dttm.date().isoformat()}')"
if tt == utils.TemporalType.DATETIME:
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)"""
if utils.TemporalType.TIMESTAMP in tt:
if isinstance(sqla_type, types.TIMESTAMP):
return f"""TO_TIMESTAMP('{dttm.isoformat(timespec="microseconds")}')"""
if isinstance(sqla_type, types.DateTime):
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)"""
return None
@staticmethod

View File

@ -19,11 +19,11 @@ from datetime import datetime
from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.utils import core as utils
if TYPE_CHECKING:
# prevent circular imports
@ -76,12 +76,8 @@ class SqliteEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt in (
utils.TemporalType.TEXT,
utils.TemporalType.DATETIME,
utils.TemporalType.TIMESTAMP,
):
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, (types.String, types.DateTime)):
return f"""'{dttm.isoformat(sep=" ", timespec="seconds")}'"""
return None

View File

@ -17,8 +17,6 @@
from __future__ import annotations
import logging
import re
from datetime import datetime
from typing import Any, Dict, Optional, Type, TYPE_CHECKING
import simplejson as json
@ -49,29 +47,6 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
engine = "trino"
engine_name = "Trino"
@classmethod
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
"""
Convert a Python `datetime` object to a SQL expression.
:param target_type: The target type of expression
:param dttm: The datetime object
:param db_extra: The database extra object
:return: The SQL expression
Superset only defines time zone naive `datetime` objects, though this method
handles both time zone naive and aware conversions.
"""
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
return f"DATE '{dttm.date().isoformat()}'"
if re.sub(r"\(\d\)", "", tt) in (
utils.TemporalType.TIMESTAMP,
utils.TemporalType.TIMESTAMP_WITH_TIME_ZONE,
):
return f"""TIMESTAMP '{dttm.isoformat(timespec="microseconds", sep=" ")}'"""
return None
@classmethod
def extra_table_metadata(
cls,

View File

@ -364,21 +364,6 @@ class RowLevelSecurityFilterType(str, Enum):
BASE = "Base"
class TemporalType(str, Enum):
"""
Supported temporal types
"""
DATE = "DATE"
DATETIME = "DATETIME"
SMALLDATETIME = "SMALLDATETIME"
TEXT = "TEXT"
TIME = "TIME"
TIME_WITH_TIME_ZONE = "TIME WITH TIME ZONE"
TIMESTAMP = "TIMESTAMP"
TIMESTAMP_WITH_TIME_ZONE = "TIMESTAMP WITH TIME ZONE"
class ColumnTypeSource(Enum):
GET_TABLE = 1
CURSOR_DESCRIPION = 2

View File

@ -22,7 +22,6 @@ from tests.integration_tests.test_app import app
from tests.integration_tests.base_tests import SupersetTestCase
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.core import Database
from superset.utils.core import GenericDataType
class TestDbEngineSpec(SupersetTestCase):
@ -37,16 +36,3 @@ class TestDbEngineSpec(SupersetTestCase):
main = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
limited = engine_spec_class.apply_limit_to_sql(sql, limit, main, force)
self.assertEqual(expected_sql, limited)
def assert_generic_types(
spec: Type[BaseEngineSpec],
type_expectations: Tuple[Tuple[str, GenericDataType], ...],
) -> None:
for type_str, expected_type in type_expectations:
column_spec = spec.get_column_spec(type_str)
assert column_spec is not None
actual_type = column_spec.generic_type
assert (
actual_type == expected_type
), f"{type_str} should be {expected_type.name} but is {actual_type.name}"

View File

@ -48,23 +48,6 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
actual = BigQueryEngineSpec.make_label_compatible(column(original).name)
self.assertEqual(actual, expected)
def test_convert_dttm(self):
"""
DB Eng Specs (bigquery): Test conversion to date time
"""
dttm = self.get_dttm()
test_cases = {
"DATE": "CAST('2019-01-02' AS DATE)",
"DATETIME": "CAST('2019-01-02T03:04:05.678900' AS DATETIME)",
"TIMESTAMP": "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
"TIME": "CAST('03:04:05.678900' AS TIME)",
"UNKNOWNTYPE": None,
}
for target_type, expected in test_cases.items():
actual = BigQueryEngineSpec.convert_dttm(target_type, dttm)
self.assertEqual(actual, expected)
def test_timegrain_expressions(self):
"""
DB Eng Specs (bigquery): Test time grain expressions

View File

@ -1,53 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.crate import CrateEngineSpec
from superset.models.core import Database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestCrateDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
"""
DB Eng Specs (crate): Test conversion to date time
"""
dttm = self.get_dttm()
assert CrateEngineSpec.convert_dttm("TIMESTAMP", dttm) == str(
dttm.timestamp() * 1000
)
def test_epoch_to_dttm(self):
"""
DB Eng Specs (crate): Test epoch to dttm
"""
assert CrateEngineSpec.epoch_to_dttm() == "{col} * 1000"
def test_epoch_ms_to_dttm(self):
"""
DB Eng Specs (crate): Test epoch ms to dttm
"""
assert CrateEngineSpec.epoch_ms_to_dttm() == "{col}"
def test_alter_new_orm_column(self):
"""
DB Eng Specs (crate): Test alter orm column
"""
database = Database(database_name="crate", sqlalchemy_uri="crate://db")
tbl = SqlaTable(table_name="druid_tbl", database=database)
col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl)
CrateEngineSpec.alter_new_orm_column(col)
assert col.python_date_format == "epoch_ms"

View File

@ -14,18 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from textwrap import dedent
from unittest import mock
from sqlalchemy import column, literal_column
from superset.constants import USER_AGENT
from superset.db_engine_specs import get_engine_spec
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import (
assert_generic_types,
TestDbEngineSpec,
)
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra

View File

@ -1,78 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
from sqlalchemy import column
from superset.db_engine_specs.druid import DruidEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
class TestDruidDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
DruidEngineSpec.convert_dttm("DATETIME", dttm),
"TIME_PARSE('2019-01-02T03:04:05')",
)
self.assertEqual(
DruidEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TIME_PARSE('2019-01-02T03:04:05')",
)
self.assertEqual(
DruidEngineSpec.convert_dttm("DATE", dttm),
"CAST(TIME_PARSE('2019-01-02') AS DATE)",
)
def test_timegrain_expressions(self):
"""
DB Eng Specs (druid): Test time grain expressions
"""
col = "__time"
sqla_col = column(col)
test_cases = {
"PT1S": f"TIME_FLOOR(CAST({col} AS TIMESTAMP), 'PT1S')",
"PT5M": f"TIME_FLOOR(CAST({col} AS TIMESTAMP), 'PT5M')",
"P1W/1970-01-03T00:00:00Z": f"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST({col} AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', 5)",
"1969-12-28T00:00:00Z/P1W": f"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST({col} AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', -1)",
}
for grain, expected in test_cases.items():
actual = DruidEngineSpec.get_timestamp_expr(
col=sqla_col, pdf=None, time_grain=grain
)
self.assertEqual(str(actual), expected)
def test_extras_without_ssl(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = DruidEngineSpec.get_extra_params(db)
assert "connect_args" not in extras["engine_params"]
def test_extras_with_ssl(self):
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = ssl_certificate
extras = DruidEngineSpec.get_extra_params(db)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["scheme"] == "https"
assert "ssl_verify_cert" in connect_args

View File

@ -1,104 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import MagicMock
import pytest
from sqlalchemy import column
from superset.db_engine_specs.elasticsearch import (
ElasticSearchEngineSpec,
OpenDistroEngineSpec,
)
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestElasticSearchDbEngineSpec(TestDbEngineSpec):
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=None),
"CAST('2019-01-02T03:04:05' AS DATETIME)",
)
def test_convert_dttm2(self):
"""
ES 7.8 and above versions need to use the DATETIME_PARSE function to
solve the time zone problem
"""
dttm = self.get_dttm()
db_extra = {"version": "7.8"}
self.assertEqual(
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra),
"DATETIME_PARSE('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
)
def test_convert_dttm3(self):
dttm = self.get_dttm()
db_extra = {"version": 7.8}
self.assertEqual(
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra),
"CAST('2019-01-02T03:04:05' AS DATETIME)",
)
self.assertNotEqual(
ElasticSearchEngineSpec.convert_dttm("DATETIME", dttm, db_extra=db_extra),
"DATETIME_PARSE('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
)
self.assertIn("Unexpected error while convert es_version", self._caplog.text)
def test_opendistro_convert_dttm(self):
"""
DB Eng Specs (opendistro): Test convert_dttm
"""
dttm = self.get_dttm()
self.assertEqual(
OpenDistroEngineSpec.convert_dttm("DATETIME", dttm, db_extra=None),
"'2019-01-02T03:04:05'",
)
def test_opendistro_sqla_column_label(self):
"""
DB Eng Specs (opendistro): Test column label
"""
test_cases = {
"Col": "Col",
"Col.keyword": "Col_keyword",
}
for original, expected in test_cases.items():
actual = OpenDistroEngineSpec.make_label_compatible(column(original).name)
self.assertEqual(actual, expected)
def test_opendistro_strip_comments(self):
"""
DB Eng Specs (opendistro): Test execute sql strip comments
"""
mock_cursor = MagicMock()
mock_cursor.execute.return_value = []
OpenDistroEngineSpec.execute(
mock_cursor, "-- some comment \nSELECT 1\n --other comment"
)
mock_cursor.execute.assert_called_once_with("SELECT 1\n")

View File

@ -1,81 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from unittest import mock
import pytest
from superset.db_engine_specs.firebird import FirebirdEngineSpec
grain_expressions = {
None: "timestamp_column",
"PT1S": (
"CAST(CAST(timestamp_column AS DATE) "
"|| ' ' "
"|| EXTRACT(HOUR FROM timestamp_column) "
"|| ':' "
"|| EXTRACT(MINUTE FROM timestamp_column) "
"|| ':' "
"|| FLOOR(EXTRACT(SECOND FROM timestamp_column)) AS TIMESTAMP)"
),
"PT1M": (
"CAST(CAST(timestamp_column AS DATE) "
"|| ' ' "
"|| EXTRACT(HOUR FROM timestamp_column) "
"|| ':' "
"|| EXTRACT(MINUTE FROM timestamp_column) "
"|| ':00' AS TIMESTAMP)"
),
"P1D": "CAST(timestamp_column AS DATE)",
"P1M": (
"CAST(EXTRACT(YEAR FROM timestamp_column) "
"|| '-' "
"|| EXTRACT(MONTH FROM timestamp_column) "
"|| '-01' AS DATE)"
),
"P1Y": "CAST(EXTRACT(YEAR FROM timestamp_column) || '-01-01' AS DATE)",
}
@pytest.mark.parametrize("grain,expected", grain_expressions.items())
def test_time_grain_expressions(grain, expected):
assert (
FirebirdEngineSpec._time_grain_expressions[grain].format(col="timestamp_column")
== expected
)
def test_epoch_to_dttm():
assert (
FirebirdEngineSpec.epoch_to_dttm().format(col="timestamp_column")
== "DATEADD(second, timestamp_column, CAST('00:00:00' AS TIMESTAMP))"
)
def test_convert_dttm():
dttm = datetime(2021, 1, 1)
assert (
FirebirdEngineSpec.convert_dttm("timestamp", dttm)
== "CAST('2021-01-01 00:00:00' AS TIMESTAMP)"
)
assert (
FirebirdEngineSpec.convert_dttm("TIMESTAMP", dttm)
== "CAST('2021-01-01 00:00:00' AS TIMESTAMP)"
)
assert FirebirdEngineSpec.convert_dttm("TIME", dttm) == "CAST('00:00:00' AS TIME)"
assert FirebirdEngineSpec.convert_dttm("DATE", dttm) == "CAST('2021-01-01' AS DATE)"
assert FirebirdEngineSpec.convert_dttm("STRING", dttm) is None

View File

@ -1,39 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.firebolt import FireboltEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestFireboltDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
test_cases = {
"DATE": "CAST('2019-01-02' AS DATE)",
"DATETIME": "CAST('2019-01-02T03:04:05' AS DATETIME)",
"TIMESTAMP": "CAST('2019-01-02T03:04:05' AS TIMESTAMP)",
"UNKNOWNTYPE": None,
}
for target_type, expected in test_cases.items():
actual = FireboltEngineSpec.convert_dttm(target_type, dttm)
self.assertEqual(actual, expected)
def test_epoch_to_dttm(self):
assert (
FireboltEngineSpec.epoch_to_dttm().format(col="timestamp_column")
== "from_unixtime(timestamp_column)"
)

View File

@ -150,15 +150,6 @@ def test_hive_error_msg():
)
def test_convert_dttm():
dttm = datetime.strptime("2019-01-02 03:04:05.678900", "%Y-%m-%d %H:%M:%S.%f")
assert HiveEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"
assert (
HiveEngineSpec.convert_dttm("TIMESTAMP", dttm)
== "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)"
)
def test_df_to_csv() -> None:
with pytest.raises(SupersetException):
HiveEngineSpec.df_to_sql(

View File

@ -1,32 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.kylin import KylinEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestKylinDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
KylinEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
)
self.assertEqual(
KylinEngineSpec.convert_dttm("TIMESTAMP", dttm),
"CAST('2019-01-02 03:04:05' AS TIMESTAMP)",
)

View File

@ -21,12 +21,7 @@ from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils.core import GenericDataType
from tests.integration_tests.db_engine_specs.base_tests import (
assert_generic_types,
TestDbEngineSpec,
)
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
@ -38,19 +33,6 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
self.assertEqual("TINY", MySQLEngineSpec.get_datatype(1))
self.assertEqual("VARCHAR", MySQLEngineSpec.get_datatype(15))
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
MySQLEngineSpec.convert_dttm("DATE", dttm),
"STR_TO_DATE('2019-01-02', '%Y-%m-%d')",
)
self.assertEqual(
MySQLEngineSpec.convert_dttm("DATETIME", dttm),
"STR_TO_DATE('2019-01-02 03:04:05.678900', '%Y-%m-%d %H:%i:%s.%f')",
)
def test_column_datatype_to_string(self):
test_cases = (
(DATE(), "DATE"),
@ -69,32 +51,6 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
)
self.assertEqual(actual, expected)
def test_generic_type(self):
type_expectations = (
# Numeric
("TINYINT", GenericDataType.NUMERIC),
("SMALLINT", GenericDataType.NUMERIC),
("MEDIUMINT", GenericDataType.NUMERIC),
("INT", GenericDataType.NUMERIC),
("BIGINT", GenericDataType.NUMERIC),
("DECIMAL", GenericDataType.NUMERIC),
("FLOAT", GenericDataType.NUMERIC),
("DOUBLE", GenericDataType.NUMERIC),
("BIT", GenericDataType.NUMERIC),
# String
("CHAR", GenericDataType.STRING),
("VARCHAR", GenericDataType.STRING),
("TINYTEXT", GenericDataType.STRING),
("MEDIUMTEXT", GenericDataType.STRING),
("LONGTEXT", GenericDataType.STRING),
# Temporal
("DATE", GenericDataType.TEMPORAL),
("DATETIME", GenericDataType.TEMPORAL),
("TIMESTAMP", GenericDataType.TEMPORAL),
("TIME", GenericDataType.TEMPORAL),
)
assert_generic_types(MySQLEngineSpec, type_expectations)
def test_extract_error_message(self):
from MySQLdb._exceptions import OperationalError
@ -239,22 +195,3 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
},
)
]
@unittest.mock.patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_id(self, engine_mock):
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.fetchone.return_value = [123]
assert MySQLEngineSpec.get_cancel_query_id(cursor_mock, query) == 123
@unittest.mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query(self, engine_mock):
query = Query()
cursor_mock = engine_mock.return_value.__enter__.return_value
assert MySQLEngineSpec.cancel_query(cursor_mock, query, 123) is True
@unittest.mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(self, engine_mock):
query = Query()
cursor_mock = engine_mock.raiseError.side_effect = Exception()
assert MySQLEngineSpec.cancel_query(cursor_mock, query, 123) is False

View File

@ -1,87 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
import pytest
from sqlalchemy import column
from sqlalchemy.dialects import oracle
from sqlalchemy.dialects.oracle import DATE, NVARCHAR, VARCHAR
from superset.db_engine_specs.oracle import OracleEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestOracleDbEngineSpec(TestDbEngineSpec):
def test_oracle_sqla_column_name_length_exceeded(self):
col = column("This_Is_32_Character_Column_Name")
label = OracleEngineSpec.make_label_compatible(col.name)
self.assertEqual(label.quote, True)
label_expected = "3b26974078683be078219674eeb8f5"
self.assertEqual(label, label_expected)
def test_oracle_time_expression_reserved_keyword_1m_grain(self):
col = column("decimal")
expr = OracleEngineSpec.get_timestamp_expr(col, None, "P1M")
result = str(expr.compile(dialect=oracle.dialect()))
self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
dttm = self.get_dttm()
def test_column_datatype_to_string(self):
test_cases = (
(DATE(), "DATE"),
(VARCHAR(length=255), "VARCHAR(255 CHAR)"),
(VARCHAR(length=255, collation="utf8"), "VARCHAR(255 CHAR)"),
(NVARCHAR(length=128), "NVARCHAR2(128)"),
)
for original, expected in test_cases:
actual = OracleEngineSpec.column_datatype_to_string(
original, oracle.dialect()
)
self.assertEqual(actual, expected)
def test_fetch_data_no_description(self):
cursor = mock.MagicMock()
cursor.description = []
assert OracleEngineSpec.fetch_data(cursor) == []
def test_fetch_data(self):
cursor = mock.MagicMock()
result = ["a", "b"]
cursor.fetchall.return_value = result
assert OracleEngineSpec.fetch_data(cursor) == result
@pytest.mark.parametrize(
"date_format,expected",
[
("DATE", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
("DATETIME", """TO_DATE('2019-01-02T03:04:05', 'YYYY-MM-DD"T"HH24:MI:SS')"""),
(
"TIMESTAMP",
"""TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
),
(
"timestamp",
"""TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
),
("Other", None),
],
)
def test_convert_dttm(date_format, expected):
dttm = TestOracleDbEngineSpec.get_dttm()
assert OracleEngineSpec.convert_dttm(date_format, dttm) == expected

View File

@ -24,11 +24,7 @@ from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils.core import GenericDataType
from tests.integration_tests.db_engine_specs.base_tests import (
assert_generic_types,
TestDbEngineSpec,
)
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
@ -100,29 +96,6 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
result = str(expr.compile(None, dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
def test_convert_dttm(self):
"""
DB Eng Specs (postgres): Test conversion to date time
"""
dttm = self.get_dttm()
self.assertEqual(
PostgresEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
)
self.assertEqual(
PostgresEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
)
self.assertEqual(
PostgresEngineSpec.convert_dttm("DATETIME", dttm),
"TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
)
self.assertEqual(PostgresEngineSpec.convert_dttm("TIME", dttm), None)
def test_empty_dbapi_cursor_description(self):
"""
DB Eng Specs (postgres): Test empty cursor description (no columns)
@ -541,28 +514,3 @@ def test_base_parameters_mixin():
},
"required": ["database", "host", "port", "username"],
}
def test_generic_type():
type_expectations = (
# Numeric
("SMALLINT", GenericDataType.NUMERIC),
("INTEGER", GenericDataType.NUMERIC),
("BIGINT", GenericDataType.NUMERIC),
("DECIMAL", GenericDataType.NUMERIC),
("NUMERIC", GenericDataType.NUMERIC),
("REAL", GenericDataType.NUMERIC),
("DOUBLE PRECISION", GenericDataType.NUMERIC),
("MONEY", GenericDataType.NUMERIC),
# String
("CHAR", GenericDataType.STRING),
("VARCHAR", GenericDataType.STRING),
("TEXT", GenericDataType.STRING),
# Temporal
("DATE", GenericDataType.TEMPORAL),
("TIMESTAMP", GenericDataType.TEMPORAL),
("TIME", GenericDataType.TEMPORAL),
# Boolean
("BOOLEAN", GenericDataType.BOOLEAN),
)
assert_generic_types(PostgresEngineSpec, type_expectations)

View File

@ -25,7 +25,6 @@ from sqlalchemy.sql import select
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery
from superset.utils.core import DatasourceName, GenericDataType
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
@ -624,42 +623,6 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
def test_get_sqla_column_type(self):
column_spec = PrestoEngineSpec.get_column_spec("varchar(255)")
assert isinstance(column_spec.sqla_type, types.VARCHAR)
assert column_spec.sqla_type.length == 255
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
column_spec = PrestoEngineSpec.get_column_spec("varchar")
assert isinstance(column_spec.sqla_type, types.String)
assert column_spec.sqla_type.length is None
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
column_spec = PrestoEngineSpec.get_column_spec("char(10)")
assert isinstance(column_spec.sqla_type, types.CHAR)
assert column_spec.sqla_type.length == 10
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
column_spec = PrestoEngineSpec.get_column_spec("char")
assert isinstance(column_spec.sqla_type, types.CHAR)
assert column_spec.sqla_type.length is None
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
column_spec = PrestoEngineSpec.get_column_spec("integer")
assert isinstance(column_spec.sqla_type, types.Integer)
self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC)
column_spec = PrestoEngineSpec.get_column_spec("time")
assert isinstance(column_spec.sqla_type, types.Time)
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
column_spec = PrestoEngineSpec.get_column_spec("timestamp")
assert isinstance(column_spec.sqla_type, types.TIMESTAMP)
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
assert sqla_type is None
@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names")
@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names")
def test_get_table_names(

View File

@ -1,214 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from typing import Any, Dict
from unittest import mock
from unittest.mock import Mock, patch
import pandas as pd
import pytest
from sqlalchemy import types
import superset.config
from superset.constants import USER_AGENT
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.utils.core import GenericDataType
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestTrinoDbEngineSpec(TestDbEngineSpec):
def test_get_extra_params(self):
database = Mock()
database.extra = json.dumps({})
database.server_cert = None
extra = TrinoEngineSpec.get_extra_params(database)
expected = {"engine_params": {"connect_args": {"source": USER_AGENT}}}
self.assertEqual(extra, expected)
expected = {
"first": 1,
"engine_params": {
"second": "two",
"connect_args": {"source": "foobar", "third": "three"},
},
}
database.extra = json.dumps(expected)
database.server_cert = None
extra = TrinoEngineSpec.get_extra_params(database)
self.assertEqual(extra, expected)
@patch("superset.utils.core.create_ssl_cert_file")
def test_get_extra_params_with_server_cert(self, create_ssl_cert_file_func: Mock):
database = Mock()
database.extra = json.dumps({})
database.server_cert = "TEST_CERT"
create_ssl_cert_file_func.return_value = "/path/to/tls.crt"
extra = TrinoEngineSpec.get_extra_params(database)
connect_args = extra.get("engine_params", {}).get("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
self.assertEqual(connect_args.get("verify"), "/path/to/tls.crt")
create_ssl_cert_file_func.assert_called_once_with(database.server_cert)
@patch("trino.auth.BasicAuthentication")
def test_auth_basic(self, auth: Mock):
database = Mock()
auth_params = {"username": "username", "password": "password"}
database.encrypted_extra = json.dumps(
{"auth_method": "basic", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
auth.assert_called_once_with(**auth_params)
@patch("trino.auth.KerberosAuthentication")
def test_auth_kerberos(self, auth: Mock):
database = Mock()
auth_params = {
"service_name": "superset",
"mutual_authentication": False,
"delegate": True,
}
database.encrypted_extra = json.dumps(
{"auth_method": "kerberos", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
auth.assert_called_once_with(**auth_params)
@patch("trino.auth.CertificateAuthentication")
def test_auth_certificate(self, auth: Mock):
database = Mock()
auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"}
database.encrypted_extra = json.dumps(
{"auth_method": "certificate", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
auth.assert_called_once_with(**auth_params)
@patch("trino.auth.JWTAuthentication")
def test_auth_jwt(self, auth: Mock):
database = Mock()
auth_params = {"token": "jwt-token-string"}
database.encrypted_extra = json.dumps(
{"auth_method": "jwt", "auth_params": auth_params}
)
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
auth.assert_called_once_with(**auth_params)
def test_auth_custom_auth(self):
database = Mock()
auth_class = Mock()
auth_method = "custom_auth"
auth_params = {"params1": "params1", "params2": "params2"}
database.encrypted_extra = json.dumps(
{"auth_method": auth_method, "auth_params": auth_params}
)
with patch.dict(
"superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
{"trino": {"custom_auth": auth_class}},
clear=True,
):
params: Dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
auth_class.assert_called_once_with(**auth_params)
def test_auth_custom_auth_denied(self):
database = Mock()
auth_method = "my.module:TrinoAuthClass"
auth_params = {"params1": "params1", "params2": "params2"}
database.encrypted_extra = json.dumps(
{"auth_method": auth_method, "auth_params": auth_params}
)
superset.config.ALLOWED_EXTRA_AUTHENTICATIONS = {}
with pytest.raises(ValueError) as excinfo:
TrinoEngineSpec.update_params_from_encrypted_extra(database, {})
assert str(excinfo.value) == (
f"For security reason, custom authentication '{auth_method}' "
f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
)
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
TrinoEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TIMESTAMP '2019-01-02 03:04:05.678900'",
)
self.assertEqual(
TrinoEngineSpec.convert_dttm("TIMESTAMP(3)", dttm),
"TIMESTAMP '2019-01-02 03:04:05.678900'",
)
self.assertEqual(
TrinoEngineSpec.convert_dttm("TIMESTAMP WITH TIME ZONE", dttm),
"TIMESTAMP '2019-01-02 03:04:05.678900'",
)
self.assertEqual(
TrinoEngineSpec.convert_dttm("TIMESTAMP(3) WITH TIME ZONE", dttm),
"TIMESTAMP '2019-01-02 03:04:05.678900'",
)
self.assertEqual(
TrinoEngineSpec.convert_dttm("DATE", dttm),
"DATE '2019-01-02'",
)
def test_extra_table_metadata(self):
db = mock.Mock()
db.get_indexes = mock.Mock(
return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
)
db.get_extra = mock.Mock(return_value={})
db.has_view_by_name = mock.Mock(return_value=None)
db.get_df = mock.Mock(
return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
)
result = TrinoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
assert result["partitions"]["cols"] == ["ds", "hour"]
assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}

View File

@ -39,7 +39,6 @@ from superset.utils.core import (
AdhocMetricExpressionType,
FilterOperator,
GenericDataType,
TemporalType,
)
from superset.utils.database import get_example_database
from tests.integration_tests.fixtures.birth_names_dashboard import (
@ -805,7 +804,7 @@ def test__normalize_prequery_result_type(
def _convert_dttm(
target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
if target_type.upper() == TemporalType.TIMESTAMP:
if target_type.upper() == "TIMESTAMP":
return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""
return None

View File

@ -17,8 +17,12 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import re
from datetime import datetime
from typing import Optional
import pytest
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
SYNTAX_ERROR_REGEX = re.compile(
@ -26,19 +30,20 @@ SYNTAX_ERROR_REGEX = re.compile(
)
def test_convert_dttm(dttm: datetime) -> None:
"""
Test that date objects are converted correctly.
"""
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "DATE '2019-01-02'"),
("TimeStamp", "TIMESTAMP '2019-01-02 03:04:05.678'"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.athena import AthenaEngineSpec as spec
from superset.db_engine_specs.athena import AthenaEngineSpec
assert AthenaEngineSpec.convert_dttm("DATE", dttm) == "DATE '2019-01-02'"
assert (
AthenaEngineSpec.convert_dttm("TIMESTAMP", dttm)
== "TIMESTAMP '2019-01-02 03:04:05.678'"
)
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_extract_errors() -> None:

View File

@ -17,9 +17,13 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from textwrap import dedent
from typing import Any, Dict, Optional, Type
import pytest
from sqlalchemy.types import TypeEngine
from sqlalchemy import types
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import assert_column_spec
def test_get_text_clause_with_colon() -> None:
@ -94,8 +98,43 @@ select 'USD' as cur
),
],
)
def test_cte_query_parsing(original: TypeEngine, expected: str) -> None:
def test_cte_query_parsing(original: types.TypeEngine, expected: str) -> None:
from superset.db_engine_specs.base import BaseEngineSpec
actual = BaseEngineSpec.get_cte_query(original)
assert actual == expected
@pytest.mark.parametrize(
"native_type,sqla_type,attrs,generic_type,is_dttm",
[
("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False),
("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
("DECIMAL", types.Numeric, None, GenericDataType.NUMERIC, False),
("NUMERIC", types.Numeric, None, GenericDataType.NUMERIC, False),
("REAL", types.REAL, None, GenericDataType.NUMERIC, False),
("DOUBLE PRECISION", types.Float, None, GenericDataType.NUMERIC, False),
("MONEY", types.Numeric, None, GenericDataType.NUMERIC, False),
# String
("CHAR", types.String, None, GenericDataType.STRING, False),
("VARCHAR", types.String, None, GenericDataType.STRING, False),
("TEXT", types.String, None, GenericDataType.STRING, False),
# Temporal
("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
("TIME", types.Time, None, GenericDataType.TEMPORAL, True),
# Boolean
("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False),
],
)
def test_get_column_spec(
native_type: str,
sqla_type: Type[types.TypeEngine],
attrs: Optional[Dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec
assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)

View File

@ -18,12 +18,18 @@
# pylint: disable=line-too-long, import-outside-toplevel, protected-access, invalid-name
import json
from datetime import datetime
from typing import Optional
import pytest
from pytest_mock import MockFixture
from sqlalchemy import select
from sqlalchemy.sql import sqltypes
from sqlalchemy_bigquery import BigQueryDialect
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_get_fields() -> None:
"""
@ -285,3 +291,24 @@ def test_parse_error_raises_exception() -> None:
== expected_result
)
assert str(BigQueryEngineSpec.parse_error_exception(Exception(message_2))) == "6"
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST('2019-01-02' AS DATE)"),
("DateTime", "CAST('2019-01-02T03:04:05.678900' AS DATETIME)"),
("TimeStamp", "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)"),
("Time", "CAST('03:04:05.678900' AS TIME)"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
"""
DB Eng Specs (bigquery): Test conversion to date time
"""
from superset.db_engine_specs.bigquery import BigQueryEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@ -14,22 +14,31 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from unittest import mock
from typing import Optional
from unittest.mock import Mock
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_convert_dttm(dttm: datetime) -> None:
from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "toDate('2019-01-02')"),
("DateTime", "toDateTime('2019-01-02 03:04:05')"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec as spec
assert ClickHouseEngineSpec.convert_dttm("DATE", dttm) == "toDate('2019-01-02')"
assert (
ClickHouseEngineSpec.convert_dttm("DATETIME", dttm)
== "toDateTime('2019-01-02 03:04:05')"
)
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_execute_connection_error() -> None:
@ -38,7 +47,7 @@ def test_execute_connection_error() -> None:
from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
cursor = mock.Mock()
cursor = Mock()
cursor.execute.side_effect = NewConnectionError(
"Dummypool", "Exception with sensitive data"
)

View File

@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_epoch_to_dttm() -> None:
"""
DB Eng Specs (crate): Test epoch to dttm
"""
from superset.db_engine_specs.crate import CrateEngineSpec
assert CrateEngineSpec.epoch_to_dttm() == "{col} * 1000"
def test_epoch_ms_to_dttm() -> None:
"""
DB Eng Specs (crate): Test epoch ms to dttm
"""
from superset.db_engine_specs.crate import CrateEngineSpec
assert CrateEngineSpec.epoch_ms_to_dttm() == "{col}"
def test_alter_new_orm_column() -> None:
"""
DB Eng Specs (crate): Test alter orm column
"""
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.crate import CrateEngineSpec
from superset.models.core import Database
database = Database(database_name="crate", sqlalchemy_uri="crate://db")
tbl = SqlaTable(table_name="tbl", database=database)
col = TableColumn(column_name="ts", type="TIMESTAMP", table=tbl)
CrateEngineSpec.alter_new_orm_column(col)
assert col.python_date_format == "epoch_ms"
@pytest.mark.parametrize(
"target_type,expected_result",
[
("TimeStamp", "1546398245678.9"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.crate import CrateEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@ -17,13 +17,16 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import json
from datetime import datetime
from typing import Optional
import pytest
from pytest_mock import MockerFixture
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_get_parameters_from_uri() -> None:
@ -109,37 +112,6 @@ def test_parameters_json_schema() -> None:
}
def test_generic_type() -> None:
"""
assert that generic types match
"""
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import assert_generic_types
type_expectations = (
# Numeric
("SMALLINT", GenericDataType.NUMERIC),
("INTEGER", GenericDataType.NUMERIC),
("BIGINT", GenericDataType.NUMERIC),
("DECIMAL", GenericDataType.NUMERIC),
("NUMERIC", GenericDataType.NUMERIC),
("REAL", GenericDataType.NUMERIC),
("DOUBLE PRECISION", GenericDataType.NUMERIC),
("MONEY", GenericDataType.NUMERIC),
# String
("CHAR", GenericDataType.STRING),
("VARCHAR", GenericDataType.STRING),
("TEXT", GenericDataType.STRING),
# Temporal
("DATE", GenericDataType.TEMPORAL),
("TIMESTAMP", GenericDataType.TEMPORAL),
("TIME", GenericDataType.TEMPORAL),
# Boolean
("BOOLEAN", GenericDataType.BOOLEAN),
)
assert_generic_types(DatabricksNativeEngineSpec, type_expectations)
def test_get_extra_params(mocker: MockerFixture) -> None:
"""
Test the ``get_extra_params`` method.
@ -253,3 +225,22 @@ def test_extract_errors_with_context() -> None:
},
)
]
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST('2019-01-02' AS DATE)"),
(
"TimeStamp",
"CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)",
),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@ -14,20 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.dremio import DremioEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
class TestDremioDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
DremioEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
)
self.assertEqual(
DremioEngineSpec.convert_dttm("TIMESTAMP", dttm),
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
(
"TimeStamp",
"TO_TIMESTAMP('2019-01-02 03:04:05.678', 'YYYY-MM-DD HH24:MI:SS.FFF')",
)
),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.dremio import DremioEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@ -16,7 +16,13 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from pytest import raises
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_odbc_impersonation() -> None:
@ -82,5 +88,21 @@ def test_invalid_impersonation() -> None:
url = URL("drill+foobar")
username = "DoAsUser"
with raises(SupersetDBAPIProgrammingError):
with pytest.raises(SupersetDBAPIProgrammingError):
DrillEngineSpec.get_url_for_impersonation(url, True, username)
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "TO_DATE('2019-01-02', 'yyyy-MM-dd')"),
("TimeStamp", "TO_TIMESTAMP('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.drill import DrillEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@ -0,0 +1,95 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
from unittest import mock
import pytest
from sqlalchemy import column
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST(TIME_PARSE('2019-01-02') AS DATE)"),
("DateTime", "TIME_PARSE('2019-01-02T03:04:05')"),
("TimeStamp", "TIME_PARSE('2019-01-02T03:04:05')"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.druid import DruidEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)
@pytest.mark.parametrize(
"time_grain,expected_result",
[
("PT1S", "TIME_FLOOR(CAST(col AS TIMESTAMP), 'PT1S')"),
("PT5M", "TIME_FLOOR(CAST({col} AS TIMESTAMP), 'PT5M')"),
(
"P1W/1970-01-03T00:00:00Z",
"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST(col AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', 5)",
),
(
"1969-12-28T00:00:00Z/P1W",
"TIME_SHIFT(TIME_FLOOR(TIME_SHIFT(CAST(col AS TIMESTAMP), 'P1D', 1), 'P1W'), 'P1D', -1)",
),
],
)
def test_timegrain_expressions(time_grain: str, expected_result: str) -> None:
"""
DB Eng Specs (druid): Test time grain expressions
"""
from superset.db_engine_specs.druid import DruidEngineSpec
assert str(
DruidEngineSpec.get_timestamp_expr(
col=column("col"), pdf=None, time_grain=time_grain
)
)
def test_extras_without_ssl() -> None:
from superset.db_engine_specs.druid import DruidEngineSpec
from tests.integration_tests.fixtures.database import default_db_extra
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = None
extras = DruidEngineSpec.get_extra_params(db)
assert "connect_args" not in extras["engine_params"]
def test_extras_with_ssl() -> None:
from superset.db_engine_specs.druid import DruidEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
db = mock.Mock()
db.extra = default_db_extra
db.server_cert = ssl_certificate
extras = DruidEngineSpec.get_extra_params(db)
connect_args = extras["engine_params"]["connect_args"]
assert connect_args["scheme"] == "https"
assert "ssl_verify_cert" in connect_args

View File

@ -14,20 +14,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.drill import DrillEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
class TestDrillDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Text", "'2019-01-02 03:04:05.678900'"),
("DateTime", "'2019-01-02 03:04:05.678900'"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.duckdb import DuckDBEngineSpec as spec
self.assertEqual(
DrillEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'yyyy-MM-dd')",
)
self.assertEqual(
DrillEngineSpec.convert_dttm("TIMESTAMP", dttm),
"TO_TIMESTAMP('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
)
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@ -14,24 +14,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_convert_dttm(dttm: datetime) -> None:
from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec
@pytest.mark.parametrize(
"target_type,expected_result",
[
("text", "'2019-01-02 03:04:05'"),
("dateTime", "'2019-01-02 03:04:05'"),
("unknowntype", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec as spec
assert DynamoDBEngineSpec.convert_dttm("TEXT", dttm) == "'2019-01-02 03:04:05'"
def test_convert_dttm_lower(dttm: datetime) -> None:
from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec
assert DynamoDBEngineSpec.convert_dttm("text", dttm) == "'2019-01-02 03:04:05'"
def test_convert_dttm_invalid_type(dttm: datetime) -> None:
from superset.db_engine_specs.dynamodb import DynamoDBEngineSpec
assert DynamoDBEngineSpec.convert_dttm("other", dttm) is None
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional
from unittest.mock import MagicMock
import pytest
from sqlalchemy import column
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
    "target_type,db_extra,expected_result",
    [
        # Without version info the spec emits the CAST-based expression.
        ("DateTime", None, "CAST('2019-01-02T03:04:05' AS DATETIME)"),
        # Version 7.7 still yields the CAST form.
        (
            "DateTime",
            {"version": "7.7"},
            "CAST('2019-01-02T03:04:05' AS DATETIME)",
        ),
        # From version 7.8 the DATETIME_PARSE form is emitted.
        (
            "DateTime",
            {"version": "7.8"},
            "DATETIME_PARSE('2019-01-02 03:04:05', 'yyyy-MM-dd HH:mm:ss')",
        ),
        # A version string that can't be parsed falls back to the CAST form.
        (
            "DateTime",
            {"version": "unparseable semver version"},
            "CAST('2019-01-02T03:04:05' AS DATETIME)",
        ),
        # Unknown target types are not converted.
        ("Unknown", None, None),
    ],
)
def test_elasticsearch_convert_dttm(
    target_type: str,
    db_extra: Optional[Dict[str, Any]],
    expected_result: Optional[str],
    dttm: datetime,
) -> None:
    """Check Elasticsearch ``convert_dttm`` output across ``version`` extras."""
    from superset.db_engine_specs.elasticsearch import ElasticSearchEngineSpec

    assert_convert_dttm(
        ElasticSearchEngineSpec, target_type, expected_result, dttm, db_extra
    )
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        # Temporal types become a quoted ISO timestamp literal.
        ("DateTime", "'2019-01-02T03:04:05'"),
        # Unknown target types are not converted.
        ("Unknown", None),
    ],
)
def test_opendistro_convert_dttm(
    target_type: str,
    expected_result: Optional[str],
    dttm: datetime,
) -> None:
    """Check ``convert_dttm`` for the OpenDistro flavour of Elasticsearch."""
    from superset.db_engine_specs.elasticsearch import OpenDistroEngineSpec

    assert_convert_dttm(OpenDistroEngineSpec, target_type, expected_result, dttm)
@pytest.mark.parametrize(
    "original,expected",
    [
        # Plain labels pass through untouched.
        ("Col", "Col"),
        # Dots in the label are replaced with underscores.
        ("Col.keyword", "Col_keyword"),
    ],
)
def test_opendistro_sqla_column_label(original: str, expected: str) -> None:
    """DB Eng Specs (opendistro): Test column label"""
    from superset.db_engine_specs.elasticsearch import OpenDistroEngineSpec

    label = OpenDistroEngineSpec.make_label_compatible(original)
    assert label == expected
def test_opendistro_strip_comments() -> None:
    """DB Eng Specs (opendistro): Test execute sql strip comments"""
    from superset.db_engine_specs.elasticsearch import OpenDistroEngineSpec

    cursor = MagicMock()
    cursor.execute.return_value = []
    sql = "-- some comment \nSELECT 1\n --other comment"
    OpenDistroEngineSpec.execute(cursor, sql)
    # Only the bare statement should reach the cursor; comments are stripped.
    cursor.execute.assert_called_once_with("SELECT 1\n")

View File

@ -0,0 +1,102 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
    "time_grain,expected",
    [
        # No grain: the column is used as-is.
        (None, "timestamp_column"),
        # Truncate to the second.
        (
            "PT1S",
            (
                "CAST(CAST(timestamp_column AS DATE) "
                "|| ' ' "
                "|| EXTRACT(HOUR FROM timestamp_column) "
                "|| ':' "
                "|| EXTRACT(MINUTE FROM timestamp_column) "
                "|| ':' "
                "|| FLOOR(EXTRACT(SECOND FROM timestamp_column)) AS TIMESTAMP)"
            ),
        ),
        # Truncate to the minute.
        (
            "PT1M",
            (
                "CAST(CAST(timestamp_column AS DATE) "
                "|| ' ' "
                "|| EXTRACT(HOUR FROM timestamp_column) "
                "|| ':' "
                "|| EXTRACT(MINUTE FROM timestamp_column) "
                "|| ':00' AS TIMESTAMP)"
            ),
        ),
        # Truncate to the day.
        ("P1D", "CAST(timestamp_column AS DATE)"),
        # Truncate to the month.
        (
            "P1M",
            (
                "CAST(EXTRACT(YEAR FROM timestamp_column) "
                "|| '-' "
                "|| EXTRACT(MONTH FROM timestamp_column) "
                "|| '-01' AS DATE)"
            ),
        ),
        # Truncate to the year.
        ("P1Y", "CAST(EXTRACT(YEAR FROM timestamp_column) || '-01-01' AS DATE)"),
    ],
)
def test_time_grain_expressions(time_grain: Optional[str], expected: str) -> None:
    """Each supported time grain renders the expected Firebird SQL."""
    from superset.db_engine_specs.firebird import FirebirdEngineSpec

    template = FirebirdEngineSpec._time_grain_expressions[time_grain]
    assert template.format(col="timestamp_column") == expected
def test_epoch_to_dttm() -> None:
    """Epoch seconds are added to the zero timestamp via DATEADD."""
    from superset.db_engine_specs.firebird import FirebirdEngineSpec

    expression = FirebirdEngineSpec.epoch_to_dttm().format(col="timestamp_column")
    assert expression == (
        "DATEADD(second, timestamp_column, CAST('00:00:00' AS TIMESTAMP))"
    )
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("Date", "CAST('2019-01-02' AS DATE)"),
        ("DateTime", "CAST('2019-01-02 03:04:05.6789' AS TIMESTAMP)"),
        ("TimeStamp", "CAST('2019-01-02 03:04:05.6789' AS TIMESTAMP)"),
        ("Time", "CAST('03:04:05.678900' AS TIME)"),
        # Unknown target types are not converted.
        ("UnknownType", None),
    ],
)
def test_convert_dttm(
    target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
    """Check ``convert_dttm`` for Firebird temporal types."""
    from superset.db_engine_specs.firebird import FirebirdEngineSpec

    assert_convert_dttm(FirebirdEngineSpec, target_type, expected_result, dttm)

View File

@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("Date", "CAST('2019-01-02' AS DATE)"),
        ("DateTime", "CAST('2019-01-02T03:04:05' AS DATETIME)"),
        ("TimeStamp", "CAST('2019-01-02T03:04:05' AS TIMESTAMP)"),
        # Unknown target types are not converted.
        ("UnknownType", None),
    ],
)
def test_convert_dttm(
    target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
    """Check ``convert_dttm`` for Firebolt temporal types."""
    from superset.db_engine_specs.firebolt import FireboltEngineSpec

    assert_convert_dttm(FireboltEngineSpec, target_type, expected_result, dttm)
def test_epoch_to_dttm() -> None:
    """Firebolt converts epoch seconds with ``from_unixtime``."""
    from superset.db_engine_specs.firebolt import FireboltEngineSpec

    expression = FireboltEngineSpec.epoch_to_dttm().format(col="timestamp_column")
    assert expression == "from_unixtime(timestamp_column)"

View File

@ -14,20 +14,30 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.hana import HanaEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
class TestHanaDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
self.assertEqual(
HanaEngineSpec.convert_dttm("DATE", dttm),
"TO_DATE('2019-01-02', 'YYYY-MM-DD')",
)
self.assertEqual(
HanaEngineSpec.convert_dttm("TIMESTAMP", dttm),
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
(
"TimeStamp",
"TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD\"T\"HH24:MI:SS.ff6')",
)
),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.hana import HanaEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("Date", "CAST('2019-01-02' AS DATE)"),
        ("TimeStamp", "CAST('2019-01-02 03:04:05.678900' AS TIMESTAMP)"),
        # Unknown target types are not converted.
        ("UnknownType", None),
    ],
)
def test_convert_dttm(
    target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
    """Check ``convert_dttm`` for Hive temporal types."""
    from superset.db_engine_specs.hive import HiveEngineSpec

    assert_convert_dttm(HiveEngineSpec, target_type, expected_result, dttm)

View File

@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("Date", "CAST('2019-01-02' AS DATE)"),
        ("TimeStamp", "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)"),
        # Unknown target types are not converted.
        ("UnknownType", None),
    ],
)
def test_convert_dttm(
    target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
    """Check ``convert_dttm`` for Impala temporal types."""
    from superset.db_engine_specs.impala import ImpalaEngineSpec

    assert_convert_dttm(ImpalaEngineSpec, target_type, expected_result, dttm)

View File

@ -16,9 +16,11 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@ -108,45 +110,35 @@ def test_kql_parse_sql() -> None:
@pytest.mark.parametrize(
"target_type,expected_dttm",
"target_type,expected_result",
[
("DATETIME", "datetime(2019-01-02T03:04:05.678900)"),
("TIMESTAMP", "datetime(2019-01-02T03:04:05.678900)"),
("DATE", "datetime(2019-01-02)"),
("DateTime", "datetime(2019-01-02T03:04:05.678900)"),
("TimeStamp", "datetime(2019-01-02T03:04:05.678900)"),
("Date", "datetime(2019-01-02)"),
("UnknownType", None),
],
)
def test_kql_convert_dttm(
target_type: str,
expected_dttm: str,
dttm: datetime,
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
"""
Test that date objects are converted correctly.
"""
from superset.db_engine_specs.kusto import KustoKqlEngineSpec as spec
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
assert expected_dttm == KustoKqlEngineSpec.convert_dttm(target_type, dttm)
assert_convert_dttm(spec, target_type, expected_result, dttm)
@pytest.mark.parametrize(
"target_type,expected_dttm",
"target_type,expected_result",
[
("DATETIME", "CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)"),
("DATE", "CONVERT(DATE, '2019-01-02', 23)"),
("SMALLDATETIME", "CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)"),
("TIMESTAMP", "CONVERT(TIMESTAMP, '2019-01-02 03:04:05', 20)"),
("Date", "CONVERT(DATE, '2019-01-02', 23)"),
("DateTime", "CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)"),
("SmallDateTime", "CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)"),
("TimeStamp", "CONVERT(TIMESTAMP, '2019-01-02 03:04:05', 20)"),
("UnknownType", None),
],
)
def test_sql_convert_dttm(
target_type: str,
expected_dttm: str,
dttm: datetime,
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
"""
Test that date objects are converted correctly.
"""
from superset.db_engine_specs.kusto import KustoSqlEngineSpec as spec
from superset.db_engine_specs.kusto import KustoSqlEngineSpec
assert expected_dttm == KustoSqlEngineSpec.convert_dttm(target_type, dttm)
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@ -14,19 +14,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.db_engine_specs.impala import ImpalaEngineSpec
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
class TestImpalaDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Date", "CAST('2019-01-02' AS DATE)"),
("TimeStamp", "CAST('2019-01-02 03:04:05' AS TIMESTAMP)"),
("UnknownType", None),
],
)
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.kylin import KylinEngineSpec as spec
self.assertEqual(
ImpalaEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
)
self.assertEqual(
ImpalaEngineSpec.convert_dttm("TIMESTAMP", dttm),
"CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
)
assert_convert_dttm(spec, target_type, expected_result, dttm)

View File

@ -17,9 +17,10 @@
import unittest.mock as mock
from datetime import datetime
from textwrap import dedent
from typing import Any, Dict, Optional, Type
import pytest
from sqlalchemy import column, table
from sqlalchemy import column, table, types
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
from sqlalchemy.sql import select
@ -27,36 +28,36 @@ from sqlalchemy.types import String, TypeEngine, UnicodeText
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"type_string,type_expected,generic_type_expected",
"native_type,sqla_type,attrs,generic_type,is_dttm",
[
("STRING", String, GenericDataType.STRING),
("CHAR(10)", String, GenericDataType.STRING),
("VARCHAR(10)", String, GenericDataType.STRING),
("TEXT", String, GenericDataType.STRING),
("NCHAR(10)", UnicodeText, GenericDataType.STRING),
("NVARCHAR(10)", UnicodeText, GenericDataType.STRING),
("NTEXT", UnicodeText, GenericDataType.STRING),
("CHAR", String, None, GenericDataType.STRING, False),
("CHAR(10)", String, None, GenericDataType.STRING, False),
("VARCHAR", String, None, GenericDataType.STRING, False),
("VARCHAR(10)", String, None, GenericDataType.STRING, False),
("TEXT", String, None, GenericDataType.STRING, False),
("NCHAR(10)", UnicodeText, None, GenericDataType.STRING, False),
("NVARCHAR(10)", UnicodeText, None, GenericDataType.STRING, False),
("NTEXT", UnicodeText, None, GenericDataType.STRING, False),
],
)
def test_mssql_column_types(
type_string: str,
type_expected: TypeEngine,
generic_type_expected: GenericDataType,
def test_get_column_spec(
native_type: str,
sqla_type: Type[types.TypeEngine],
attrs: Optional[Dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec as spec
if type_expected is None:
type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
assert type_assigned is None
else:
column_spec = MssqlEngineSpec.get_column_spec(type_string)
if column_spec is not None:
assert isinstance(column_spec.sqla_type, type_expected)
assert column_spec.generic_type == generic_type_expected
assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)
def test_where_clause_n_prefix() -> None:
@ -65,13 +66,13 @@ def test_where_clause_n_prefix() -> None:
dialect = mssql.dialect()
# non-unicode col
sqla_column_type = MssqlEngineSpec.get_sqla_column_type("VARCHAR(10)")
sqla_column_type = MssqlEngineSpec.get_column_types("VARCHAR(10)")
assert sqla_column_type is not None
type_, _ = sqla_column_type
str_col = column("col", type_=type_)
# unicode col
sqla_column_type = MssqlEngineSpec.get_sqla_column_type("NTEXT")
sqla_column_type = MssqlEngineSpec.get_column_types("NTEXT")
assert sqla_column_type is not None
type_, _ = sqla_column_type
unicode_col = column("unicode_col", type_=type_)
@ -103,30 +104,31 @@ def test_time_exp_mixd_case_col_1y() -> None:
@pytest.mark.parametrize(
"actual,expected",
"target_type,expected_result",
[
(
"DATE",
"date",
"CONVERT(DATE, '2019-01-02', 23)",
),
(
"DATETIME",
"datetime",
"CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)",
),
(
"SMALLDATETIME",
"smalldatetime",
"CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)",
),
("Other", None),
],
)
def test_convert_dttm(
actual: str,
expected: str,
target_type: str,
expected_result: Optional[str],
dttm: datetime,
) -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec as spec
assert MssqlEngineSpec.convert_dttm(actual, dttm) == expected
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_extract_error_message() -> None:

View File

@ -0,0 +1,130 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional, Type
from unittest.mock import Mock, patch
import pytest
from sqlalchemy import types
from sqlalchemy.dialects.mysql import (
BIT,
DECIMAL,
DOUBLE,
FLOAT,
INTEGER,
LONGTEXT,
MEDIUMINT,
MEDIUMTEXT,
TINYINT,
TINYTEXT,
)
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
    "native_type,sqla_type,attrs,generic_type,is_dttm",
    [
        # Numeric
        ("TINYINT", TINYINT, None, GenericDataType.NUMERIC, False),
        ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
        ("MEDIUMINT", MEDIUMINT, None, GenericDataType.NUMERIC, False),
        ("INT", INTEGER, None, GenericDataType.NUMERIC, False),
        ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
        ("DECIMAL", DECIMAL, None, GenericDataType.NUMERIC, False),
        ("FLOAT", FLOAT, None, GenericDataType.NUMERIC, False),
        ("DOUBLE", DOUBLE, None, GenericDataType.NUMERIC, False),
        ("BIT", BIT, None, GenericDataType.NUMERIC, False),
        # String
        ("CHAR", types.String, None, GenericDataType.STRING, False),
        ("VARCHAR", types.String, None, GenericDataType.STRING, False),
        ("TINYTEXT", TINYTEXT, None, GenericDataType.STRING, False),
        ("MEDIUMTEXT", MEDIUMTEXT, None, GenericDataType.STRING, False),
        ("LONGTEXT", LONGTEXT, None, GenericDataType.STRING, False),
        # Temporal
        ("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
        ("DATETIME", types.DateTime, None, GenericDataType.TEMPORAL, True),
        ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        ("TIME", types.Time, None, GenericDataType.TEMPORAL, True),
    ],
)
def test_get_column_spec(
    native_type: str,
    sqla_type: Type[types.TypeEngine],
    attrs: Optional[Dict[str, Any]],
    generic_type: GenericDataType,
    is_dttm: bool,
) -> None:
    """Native MySQL types map to the expected SQLAlchemy and generic types."""
    from superset.db_engine_specs.mysql import MySQLEngineSpec

    assert_column_spec(
        MySQLEngineSpec, native_type, sqla_type, attrs, generic_type, is_dttm
    )
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("Date", "STR_TO_DATE('2019-01-02', '%Y-%m-%d')"),
        (
            "DateTime",
            "STR_TO_DATE('2019-01-02 03:04:05.678900', '%Y-%m-%d %H:%i:%s.%f')",
        ),
        # Unknown target types are not converted.
        ("UnknownType", None),
    ],
)
def test_convert_dttm(
    target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
    """Check ``convert_dttm`` for MySQL temporal types."""
    from superset.db_engine_specs.mysql import MySQLEngineSpec

    assert_convert_dttm(MySQLEngineSpec, target_type, expected_result, dttm)
@patch("sqlalchemy.engine.Engine.connect")
def test_get_cancel_query_id(engine_mock: Mock) -> None:
    """The cancel id is taken from the first fetched row."""
    from superset.db_engine_specs.mysql import MySQLEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    cursor = engine_mock.return_value.__enter__.return_value
    cursor.fetchone.return_value = ["123"]
    assert MySQLEngineSpec.get_cancel_query_id(cursor, query) == "123"
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query(engine_mock: Mock) -> None:
    """Cancelling succeeds when using the cursor raises no error."""
    from superset.db_engine_specs.mysql import MySQLEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    cursor = engine_mock.return_value.__enter__.return_value
    assert MySQLEngineSpec.cancel_query(cursor, query, "123") is True
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(engine_mock: Mock) -> None:
    """Cancelling reports failure when using the cursor raises.

    The previous chained assignment
    ``cursor_mock = engine_mock.raiseError.side_effect = Exception()`` set a
    ``side_effect`` on a never-invoked auto-created mock attribute
    (``engine_mock.raiseError``) and obscured the fact that the "cursor"
    passed to ``cancel_query`` is a bare ``Exception`` instance.
    """
    from superset.db_engine_specs.mysql import MySQLEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    # An Exception instance has no cursor methods, so whatever method the
    # spec invokes on it raises AttributeError, which should yield False.
    cursor_mock = Exception()
    assert MySQLEngineSpec.cancel_query(cursor_mock, query, "123") is False

View File

@ -0,0 +1,113 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional, Union
from unittest import mock
import pytest
from sqlalchemy import column, types
from sqlalchemy.dialects import oracle
from sqlalchemy.dialects.oracle import DATE, NVARCHAR, VARCHAR
from sqlalchemy.sql import quoted_name
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
    "column_name,expected_result",
    [
        # A 32-character name is replaced by a shorter digest-like label.
        ("This_Is_32_Character_Column_Name", "3b26974078683be078219674eeb8f5"),
        ("snake_label", "snake_label"),
        ("camelLabel", "camelLabel"),
    ],
)
def test_oracle_sqla_column_name_length_exceeded(
    column_name: str, expected_result: Union[str, quoted_name]
) -> None:
    """Labels come back as quoted names; over-long ones are shortened."""
    from superset.db_engine_specs.oracle import OracleEngineSpec

    label = OracleEngineSpec.make_label_compatible(column_name)
    assert isinstance(label, quoted_name)
    assert label.quote is True
    assert label == expected_result
def test_oracle_time_expression_reserved_keyword_1m_grain() -> None:
    """A reserved-word column name stays quoted in the P1M grain expression."""
    from superset.db_engine_specs.oracle import OracleEngineSpec

    expr = OracleEngineSpec.get_timestamp_expr(column("decimal"), None, "P1M")
    compiled = str(expr.compile(dialect=oracle.dialect()))
    assert compiled == "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')"
@pytest.mark.parametrize(
    "sqla_type,expected_result",
    [
        (DATE(), "DATE"),
        (VARCHAR(length=255), "VARCHAR(255 CHAR)"),
        # Collation is dropped from the rendered type.
        (VARCHAR(length=255, collation="utf8"), "VARCHAR(255 CHAR)"),
        (NVARCHAR(length=128), "NVARCHAR2(128)"),
    ],
)
def test_column_datatype_to_string(
    sqla_type: types.TypeEngine, expected_result: str
) -> None:
    """SQLAlchemy types render as their Oracle DDL string."""
    from superset.db_engine_specs.oracle import OracleEngineSpec

    rendered = OracleEngineSpec.column_datatype_to_string(
        sqla_type, oracle.dialect()
    )
    assert rendered == expected_result
def test_fetch_data_no_description() -> None:
    """With an empty cursor description, no rows are returned."""
    from superset.db_engine_specs.oracle import OracleEngineSpec

    cursor = mock.MagicMock()
    cursor.description = []
    assert OracleEngineSpec.fetch_data(cursor) == []
def test_fetch_data() -> None:
    """Rows returned by ``fetchall`` are passed through unchanged."""
    from superset.db_engine_specs.oracle import OracleEngineSpec

    cursor = mock.MagicMock()
    rows = ["a", "b"]
    cursor.fetchall.return_value = rows
    assert OracleEngineSpec.fetch_data(cursor) == rows
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
        (
            "DateTime",
            """TO_DATE('2019-01-02T03:04:05', 'YYYY-MM-DD"T"HH24:MI:SS')""",
        ),
        (
            "TimeStamp",
            """TO_TIMESTAMP('2019-01-02T03:04:05.678900', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')""",
        ),
        # Unknown target types are not converted.
        ("Other", None),
    ],
)
def test_convert_dttm(
    target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
    """Check ``convert_dttm`` for Oracle temporal types."""
    from superset.db_engine_specs.oracle import OracleEngineSpec

    assert_convert_dttm(OracleEngineSpec, target_type, expected_result, dttm)

View File

@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional, Type
import pytest
from sqlalchemy import types
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("Date", "TO_DATE('2019-01-02', 'YYYY-MM-DD')"),
        (
            "DateTime",
            "TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
        ),
        (
            "TimeStamp",
            "TO_TIMESTAMP('2019-01-02 03:04:05.678900', 'YYYY-MM-DD HH24:MI:SS.US')",
        ),
        # Unknown target types are not converted.
        ("UnknownType", None),
    ],
)
def test_convert_dttm(
    target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
    """Check ``convert_dttm`` for Postgres temporal types."""
    from superset.db_engine_specs.postgres import PostgresEngineSpec

    assert_convert_dttm(PostgresEngineSpec, target_type, expected_result, dttm)
@pytest.mark.parametrize(
    "native_type,sqla_type,attrs,generic_type,is_dttm",
    [
        # Numeric
        ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
        ("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False),
        ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
        ("DECIMAL", types.Numeric, None, GenericDataType.NUMERIC, False),
        ("NUMERIC", types.Numeric, None, GenericDataType.NUMERIC, False),
        ("REAL", types.REAL, None, GenericDataType.NUMERIC, False),
        ("DOUBLE PRECISION", DOUBLE_PRECISION, None, GenericDataType.NUMERIC, False),
        ("MONEY", types.Numeric, None, GenericDataType.NUMERIC, False),
        # String
        ("CHAR", types.String, None, GenericDataType.STRING, False),
        ("VARCHAR", types.String, None, GenericDataType.STRING, False),
        ("TEXT", types.String, None, GenericDataType.STRING, False),
        ("ARRAY", types.String, None, GenericDataType.STRING, False),
        ("ENUM", ENUM, None, GenericDataType.STRING, False),
        ("JSON", JSON, None, GenericDataType.STRING, False),
        # Temporal
        ("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
        ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        ("TIME", types.Time, None, GenericDataType.TEMPORAL, True),
        # Boolean
        ("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False),
    ],
)
def test_get_column_spec(
    native_type: str,
    sqla_type: Type[types.TypeEngine],
    attrs: Optional[Dict[str, Any]],
    generic_type: GenericDataType,
    is_dttm: bool,
) -> None:
    """Native Postgres types map to the expected SQLAlchemy and generic types."""
    from superset.db_engine_specs.postgres import PostgresEngineSpec

    assert_column_spec(
        PostgresEngineSpec, native_type, sqla_type, attrs, generic_type, is_dttm
    )

View File

@ -15,14 +15,21 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
from typing import Any, Dict, Optional, Type
import pytest
import pytz
from sqlalchemy import types
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
@pytest.mark.parametrize(
"target_type,dttm,result",
"target_type,dttm,expected_result",
[
("VARCHAR", datetime(2022, 1, 1), None),
("DATE", datetime(2022, 1, 1), "DATE '2022-01-01'"),
@ -46,9 +53,32 @@ import pytz
def test_convert_dttm(
target_type: str,
dttm: datetime,
result: Optional[str],
expected_result: Optional[str],
) -> None:
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec as spec
for case in (str.lower, str.upper):
assert PrestoEngineSpec.convert_dttm(case(target_type), dttm) == result
assert_convert_dttm(spec, target_type, expected_result, dttm)
@pytest.mark.parametrize(
    "native_type,sqla_type,attrs,generic_type,is_dttm",
    [
        ("varchar(255)", types.VARCHAR, {"length": 255}, GenericDataType.STRING, False),
        ("varchar", types.String, None, GenericDataType.STRING, False),
        ("char(255)", types.CHAR, {"length": 255}, GenericDataType.STRING, False),
        ("char", types.String, None, GenericDataType.STRING, False),
        ("integer", types.Integer, None, GenericDataType.NUMERIC, False),
        ("time", types.Time, None, GenericDataType.TEMPORAL, True),
        ("timestamp", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
    ],
)
def test_get_column_spec(
    native_type: str,
    sqla_type: Type[types.TypeEngine],
    attrs: Optional[Dict[str, Any]],
    generic_type: GenericDataType,
    is_dttm: bool,
) -> None:
    """Verify Presto native types map to the expected SQLAlchemy column spec."""
    from superset.db_engine_specs.presto import PrestoEngineSpec

    assert_column_spec(
        PrestoEngineSpec,
        native_type,
        sqla_type,
        attrs,
        generic_type,
        is_dttm,
    )

View File

@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Optional
import pytest
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("Date", "DATE '2019-01-02'"),
        ("DateTime", "DATETIME '2019-01-02 03:04:05.678900'"),
        ("Timestamp", "TIMESTAMP '2019-01-02T03:04:05.678900'"),
        ("UnknownType", None),
    ],
)
def test_convert_dttm(
    target_type: str,
    expected_result: Optional[str],
    dttm: datetime,
) -> None:
    """Verify Rockset renders datetime literals per target type (None when unsupported)."""
    from superset.db_engine_specs.rockset import RocksetEngineSpec

    assert_convert_dttm(RocksetEngineSpec, target_type, expected_result, dttm)

View File

@ -19,28 +19,43 @@
import json
from datetime import datetime
from typing import Optional
from unittest import mock
import pytest
from pytest_mock import MockerFixture
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@pytest.mark.parametrize(
"actual,expected",
"target_type,expected_result",
[
("DATE", "TO_DATE('2019-01-02')"),
("DATETIME", "CAST('2019-01-02T03:04:05.678900' AS DATETIME)"),
("TIMESTAMP", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("Date", "TO_DATE('2019-01-02')"),
("DateTime", "CAST('2019-01-02T03:04:05.678900' AS DATETIME)"),
("TimeStamp", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMP_NTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMP_LTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMP_TZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMPLTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMPNTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("TIMESTAMPTZ", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
(
"TIMESTAMP WITH LOCAL TIME ZONE",
"TO_TIMESTAMP('2019-01-02T03:04:05.678900')",
),
("TIMESTAMP WITHOUT TIME ZONE", "TO_TIMESTAMP('2019-01-02T03:04:05.678900')"),
("UnknownType", None),
],
)
def test_convert_dttm(actual: str, expected: str, dttm: datetime) -> None:
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec
def test_convert_dttm(
target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
from superset.db_engine_specs.snowflake import SnowflakeEngineSpec as spec
assert SnowflakeEngineSpec.convert_dttm(actual, dttm) == expected
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_database_connection_test_mutator() -> None:

View File

@ -16,30 +16,32 @@
# under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel, redefined-outer-name
from datetime import datetime
from unittest import mock
from typing import Optional
import pytest
from sqlalchemy.engine import create_engine
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
def test_convert_dttm(dttm: datetime) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
@pytest.mark.parametrize(
"target_type,expected_result",
[
("Text", "'2019-01-02 03:04:05'"),
("DateTime", "'2019-01-02 03:04:05'"),
("TimeStamp", "'2019-01-02 03:04:05'"),
("Other", None),
],
)
def test_convert_dttm(
target_type: str,
expected_result: Optional[str],
dttm: datetime,
) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec as spec
assert SqliteEngineSpec.convert_dttm("TEXT", dttm) == "'2019-01-02 03:04:05'"
def test_convert_dttm_lower(dttm: datetime) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
assert SqliteEngineSpec.convert_dttm("text", dttm) == "'2019-01-02 03:04:05'"
def test_convert_dttm_invalid_type(dttm: datetime) -> None:
from superset.db_engine_specs.sqlite import SqliteEngineSpec
assert SqliteEngineSpec.convert_dttm("other", dttm) is None
assert_convert_dttm(spec, target_type, expected_result, dttm)
@pytest.mark.parametrize(

View File

@ -16,17 +16,288 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import json
from typing import Any, Dict
from unittest import mock
from datetime import datetime
from typing import Any, Dict, Optional, Type
from unittest.mock import Mock, patch
import pandas as pd
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import types
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
import superset.config
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm
@mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: mock.Mock) -> None:
@pytest.mark.parametrize(
    "extra,expected",
    [
        ({}, {"engine_params": {"connect_args": {"source": USER_AGENT}}}),
        (
            {
                "first": 1,
                "engine_params": {
                    "second": "two",
                    "connect_args": {"source": "foobar", "third": "three"},
                },
            },
            {
                "first": 1,
                "engine_params": {
                    "second": "two",
                    "connect_args": {"source": "foobar", "third": "three"},
                },
            },
        ),
    ],
)
def test_get_extra_params(extra: Dict[str, Any], expected: Dict[str, Any]) -> None:
    """A default ``source`` is injected only when the extra doesn't set one."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database_mock = Mock()
    database_mock.extra = json.dumps(extra)
    database_mock.server_cert = None

    assert TrinoEngineSpec.get_extra_params(database_mock) == expected
@patch("superset.utils.core.create_ssl_cert_file")
def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> None:
    """A configured server cert forces HTTPS and points ``verify`` at the cert file."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database_mock = Mock()
    database_mock.extra = json.dumps({})
    database_mock.server_cert = "TEST_CERT"
    mock_create_ssl_cert_file.return_value = "/path/to/tls.crt"

    extra = TrinoEngineSpec.get_extra_params(database_mock)
    connect_args = extra.get("engine_params", {}).get("connect_args", {})

    assert connect_args.get("http_scheme") == "https"
    assert connect_args.get("verify") == "/path/to/tls.crt"
    mock_create_ssl_cert_file.assert_called_once_with(database_mock.server_cert)
@patch("trino.auth.BasicAuthentication")
def test_auth_basic(mock_auth: Mock) -> None:
    """``basic`` auth forwards its params to BasicAuthentication and forces HTTPS."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    auth_params = {"username": "username", "password": "password"}
    database_mock = Mock()
    database_mock.encrypted_extra = json.dumps(
        {"auth_method": "basic", "auth_params": auth_params}
    )

    params: Dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database_mock, params)

    assert params.get("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)
@patch("trino.auth.KerberosAuthentication")
def test_auth_kerberos(mock_auth: Mock) -> None:
    """``kerberos`` auth forwards its params to KerberosAuthentication and forces HTTPS."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    auth_params = {
        "service_name": "superset",
        "mutual_authentication": False,
        "delegate": True,
    }
    database_mock = Mock()
    database_mock.encrypted_extra = json.dumps(
        {"auth_method": "kerberos", "auth_params": auth_params}
    )

    params: Dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database_mock, params)

    assert params.get("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)
@patch("trino.auth.CertificateAuthentication")
def test_auth_certificate(mock_auth: Mock) -> None:
    """``certificate`` auth forwards cert/key paths to CertificateAuthentication."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"}
    database_mock = Mock()
    database_mock.encrypted_extra = json.dumps(
        {"auth_method": "certificate", "auth_params": auth_params}
    )

    params: Dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database_mock, params)

    assert params.get("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)
@patch("trino.auth.JWTAuthentication")
def test_auth_jwt(mock_auth: Mock) -> None:
    """``jwt`` auth forwards the token to JWTAuthentication and forces HTTPS."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    auth_params = {"token": "jwt-token-string"}
    database_mock = Mock()
    database_mock.encrypted_extra = json.dumps(
        {"auth_method": "jwt", "auth_params": auth_params}
    )

    params: Dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database_mock, params)

    assert params.get("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)
def test_auth_custom_auth() -> None:
    """A custom auth class listed in ``ALLOWED_EXTRA_AUTHENTICATIONS`` is invoked."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    auth_class = Mock()
    auth_method = "custom_auth"
    auth_params = {"params1": "params1", "params2": "params2"}

    database_mock = Mock()
    database_mock.encrypted_extra = json.dumps(
        {"auth_method": auth_method, "auth_params": auth_params}
    )

    # Allow-list the custom class only for the duration of this test.
    with patch.dict(
        "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
        {"trino": {"custom_auth": auth_class}},
        clear=True,
    ):
        params: Dict[str, Any] = {}
        TrinoEngineSpec.update_params_from_encrypted_extra(database_mock, params)

        assert params.get("connect_args", {}).get("http_scheme") == "https"
        auth_class.assert_called_once_with(**auth_params)
def test_auth_custom_auth_denied() -> None:
    """A custom auth class absent from ``ALLOWED_EXTRA_AUTHENTICATIONS`` must
    be rejected with a ``ValueError``."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database = Mock()
    auth_method = "my.module:TrinoAuthClass"
    auth_params = {"params1": "params1", "params2": "params2"}
    database.encrypted_extra = json.dumps(
        {"auth_method": auth_method, "auth_params": auth_params}
    )

    # Patch the allow-list instead of assigning to superset.config directly:
    # a bare assignment would leave the config empty and leak into every test
    # that runs after this one in the same session.
    with patch.dict(
        "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS", {}, clear=True
    ), pytest.raises(ValueError) as excinfo:
        TrinoEngineSpec.update_params_from_encrypted_extra(database, {})

    assert str(excinfo.value) == (
        f"For security reason, custom authentication '{auth_method}' "
        f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
    )
@pytest.mark.parametrize(
    "native_type,sqla_type,attrs,generic_type,is_dttm",
    [
        ("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False),
        ("TINYINT", types.Integer, None, GenericDataType.NUMERIC, False),
        ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
        ("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False),
        ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
        ("REAL", types.FLOAT, None, GenericDataType.NUMERIC, False),
        ("DOUBLE", types.FLOAT, None, GenericDataType.NUMERIC, False),
        ("DECIMAL", types.DECIMAL, None, GenericDataType.NUMERIC, False),
        ("VARCHAR", types.String, None, GenericDataType.STRING, False),
        ("VARCHAR(20)", types.VARCHAR, {"length": 20}, GenericDataType.STRING, False),
        ("CHAR", types.String, None, GenericDataType.STRING, False),
        ("CHAR(2)", types.CHAR, {"length": 2}, GenericDataType.STRING, False),
        ("JSON", types.JSON, None, GenericDataType.STRING, False),
        ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        ("TIMESTAMP(3)", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        ("TIMESTAMP WITH TIME ZONE", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        ("TIMESTAMP(3) WITH TIME ZONE", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        ("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
    ],
)
def test_get_column_spec(
    native_type: str,
    sqla_type: Type[types.TypeEngine],
    attrs: Optional[Dict[str, Any]],
    generic_type: GenericDataType,
    is_dttm: bool,
) -> None:
    """Verify Trino native types map to the expected SQLAlchemy column spec."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    assert_column_spec(
        TrinoEngineSpec, native_type, sqla_type, attrs, generic_type, is_dttm
    )
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("TimeStamp", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp(3)", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp(3) With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("Date", "DATE '2019-01-02'"),
        ("Other", None),
    ],
)
def test_convert_dttm(
    target_type: str, expected_result: Optional[str], dttm: datetime
) -> None:
    """Verify Trino renders datetime literals per target type (None when unsupported)."""
    from superset.db_engine_specs.trino import TrinoEngineSpec as spec

    assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_extra_table_metadata() -> None:
    """Partition columns and the latest partition values appear in the metadata."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    db_mock = Mock()
    db_mock.get_indexes.return_value = [
        {"column_names": ["ds", "hour"], "name": "partition"}
    ]
    db_mock.get_extra.return_value = {}
    db_mock.has_view_by_name.return_value = None
    db_mock.get_df.return_value = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})

    result = TrinoEngineSpec.extra_table_metadata(db_mock, "test_table", "test_schema")

    assert result["partitions"]["cols"] == ["ds", "hour"]
    assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query
@ -35,8 +306,8 @@ def test_cancel_query_success(engine_mock: mock.Mock) -> None:
assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is True
@mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(engine_mock: mock.Mock) -> None:
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(engine_mock: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query
@ -67,11 +338,11 @@ def test_prepare_cancel_query(
@pytest.mark.parametrize("cancel_early", [True, False])
@mock.patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@mock.patch("sqlalchemy.engine.Engine.connect")
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("sqlalchemy.engine.Engine.connect")
def test_handle_cursor_early_cancel(
engine_mock: mock.Mock,
cancel_query_mock: mock.Mock,
engine_mock: Mock,
cancel_query_mock: Mock,
cancel_early: bool,
mocker: MockerFixture,
) -> None:

View File

@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, Optional, Type, TYPE_CHECKING
from sqlalchemy import types
from superset.utils.core import GenericDataType
if TYPE_CHECKING:
from superset.db_engine_specs.base import BaseEngineSpec
def assert_convert_dttm(
    db_engine_spec: Type[BaseEngineSpec],
    target_type: str,
    expected_result: Optional[str],
    dttm: datetime,
    db_extra: Optional[Dict[str, Any]] = None,
) -> None:
    """Assert that ``convert_dttm`` yields ``expected_result`` for every casing
    variant of ``target_type`` (as given, upper, lower, capitalized)."""
    casings = (
        target_type,
        target_type.upper(),
        target_type.lower(),
        target_type.capitalize(),
    )
    for variant in casings:
        result = db_engine_spec.convert_dttm(
            target_type=variant, dttm=dttm, db_extra=db_extra
        )
        assert result == expected_result, result
def assert_column_spec(
    db_engine_spec: Type[BaseEngineSpec],
    native_type: str,
    sqla_type: Type[types.TypeEngine],
    attrs: Optional[Dict[str, Any]],
    generic_type: GenericDataType,
    is_dttm: bool,
) -> None:
    """Assert that ``get_column_spec`` maps ``native_type`` to the expected
    SQLAlchemy type (with ``attrs`` attribute values), generic type and
    temporal flag."""
    column_spec = db_engine_spec.get_column_spec(native_type)
    assert column_spec is not None
    assert isinstance(column_spec.sqla_type, sqla_type)
    for attr_name, expected_value in (attrs or {}).items():
        assert getattr(column_spec.sqla_type, attr_name) == expected_value
    assert column_spec.generic_type == generic_type
    assert column_spec.is_dttm == is_dttm