fix: improve get_db_engine_spec_for_backend (#21171)
* fix: improve get_db_engine_spec_for_backend * Fix tests * Fix docs * fix lint * fix fallback * Fix engine validation * Fix test
This commit is contained in:
parent
710a8ce5c0
commit
8772e2cdb3
|
|
@ -1083,8 +1083,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
"preferred": engine_spec.engine_name in preferred_databases,
|
||||
}
|
||||
|
||||
if hasattr(engine_spec, "default_driver"):
|
||||
payload["default_driver"] = engine_spec.default_driver # type: ignore
|
||||
if engine_spec.default_driver:
|
||||
payload["default_driver"] = engine_spec.default_driver
|
||||
|
||||
# show configuration parameters for DBs that support it
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -29,8 +29,7 @@ from superset.databases.commands.exceptions import (
|
|||
)
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs import get_engine_specs
|
||||
from superset.db_engine_specs.base import BasicParametersMixin
|
||||
from superset.db_engine_specs import get_engine_spec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.extensions import event_logger
|
||||
from superset.models.core import Database
|
||||
|
|
@ -45,25 +44,13 @@ class ValidateDatabaseParametersCommand(BaseCommand):
|
|||
|
||||
def run(self) -> None:
|
||||
engine = self._properties["engine"]
|
||||
engine_specs = get_engine_specs()
|
||||
driver = self._properties.get("driver")
|
||||
|
||||
if engine in BYPASS_VALIDATION_ENGINES:
|
||||
# Skip engines that are only validated onCreate
|
||||
return
|
||||
|
||||
if engine not in engine_specs:
|
||||
raise InvalidEngineError(
|
||||
SupersetError(
|
||||
message=__(
|
||||
'Engine "%(engine)s" is not a valid engine.',
|
||||
engine=engine,
|
||||
),
|
||||
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
extra={"allowed": list(engine_specs), "provided": engine},
|
||||
),
|
||||
)
|
||||
engine_spec = engine_specs[engine]
|
||||
engine_spec = get_engine_spec(engine, driver)
|
||||
if not hasattr(engine_spec, "parameters_schema"):
|
||||
raise InvalidEngineError(
|
||||
SupersetError(
|
||||
|
|
@ -73,14 +60,6 @@ class ValidateDatabaseParametersCommand(BaseCommand):
|
|||
),
|
||||
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
extra={
|
||||
"allowed": [
|
||||
name
|
||||
for name, engine_spec in engine_specs.items()
|
||||
if issubclass(engine_spec, BasicParametersMixin)
|
||||
],
|
||||
"provided": engine,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict
|
||||
|
||||
from flask import current_app
|
||||
from flask_babel import lazy_gettext as _
|
||||
|
|
@ -28,7 +28,7 @@ from sqlalchemy import MetaData
|
|||
from superset import db
|
||||
from superset.databases.commands.exceptions import DatabaseInvalidError
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs import BaseEngineSpec, get_engine_specs
|
||||
from superset.db_engine_specs import get_engine_spec
|
||||
from superset.exceptions import CertificateException, SupersetSecurityException
|
||||
from superset.models.core import ConfigurationMethod, Database, PASSWORD_MASK
|
||||
from superset.security.analytics_db_safety import check_sqlalchemy_uri
|
||||
|
|
@ -150,7 +150,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
|
|||
[
|
||||
_(
|
||||
"Invalid connection string, a valid string usually follows: "
|
||||
"driver://user:password@database-host/database-name"
|
||||
"backend+driver://user:password@database-host/database-name"
|
||||
)
|
||||
]
|
||||
) from ex
|
||||
|
|
@ -231,6 +231,7 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
|
|||
"""
|
||||
|
||||
engine = fields.String(allow_none=True, description="SQLAlchemy engine to use")
|
||||
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
|
||||
parameters = fields.Dict(
|
||||
keys=fields.String(),
|
||||
values=fields.Raw(),
|
||||
|
|
@ -262,10 +263,20 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
|
|||
or parameters.pop("engine", None)
|
||||
or data.pop("backend", None)
|
||||
)
|
||||
driver = data.pop("driver", None)
|
||||
|
||||
configuration_method = data.get("configuration_method")
|
||||
if configuration_method == ConfigurationMethod.DYNAMIC_FORM:
|
||||
engine_spec = get_engine_spec(engine)
|
||||
if not engine:
|
||||
raise ValidationError(
|
||||
[
|
||||
_(
|
||||
"An engine must be specified when passing "
|
||||
"individual parameters to a database."
|
||||
)
|
||||
]
|
||||
)
|
||||
engine_spec = get_engine_spec(engine, driver)
|
||||
|
||||
if not hasattr(engine_spec, "build_sqlalchemy_uri") or not hasattr(
|
||||
engine_spec, "parameters_schema"
|
||||
|
|
@ -295,34 +306,12 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
|
|||
return data
|
||||
|
||||
|
||||
def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]:
|
||||
if not engine:
|
||||
raise ValidationError(
|
||||
[
|
||||
_(
|
||||
"An engine must be specified when passing "
|
||||
"individual parameters to a database."
|
||||
)
|
||||
]
|
||||
)
|
||||
engine_specs = get_engine_specs()
|
||||
if engine not in engine_specs:
|
||||
raise ValidationError(
|
||||
[
|
||||
_(
|
||||
'Engine "%(engine)s" is not a valid engine.',
|
||||
engine=engine,
|
||||
)
|
||||
]
|
||||
)
|
||||
return engine_specs[engine]
|
||||
|
||||
|
||||
class DatabaseValidateParametersSchema(Schema):
|
||||
class Meta: # pylint: disable=too-few-public-methods
|
||||
unknown = EXCLUDE
|
||||
|
||||
engine = fields.String(required=True, description="SQLAlchemy engine to use")
|
||||
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
|
||||
parameters = fields.Dict(
|
||||
keys=fields.String(),
|
||||
values=fields.Raw(allow_none=True),
|
||||
|
|
|
|||
|
|
@ -33,27 +33,34 @@ import pkgutil
|
|||
from collections import defaultdict
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set, Type
|
||||
from typing import Any, Dict, List, Optional, Set, Type
|
||||
|
||||
import sqlalchemy.databases
|
||||
import sqlalchemy.dialects
|
||||
from pkg_resources import iter_entry_points
|
||||
from sqlalchemy.engine.default import DefaultDialect
|
||||
from sqlalchemy.engine.url import URL
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_engine_spec(attr: Any) -> bool:
|
||||
def is_engine_spec(obj: Any) -> bool:
|
||||
"""
|
||||
Return true if a given object is a DB engine spec.
|
||||
"""
|
||||
return (
|
||||
inspect.isclass(attr)
|
||||
and issubclass(attr, BaseEngineSpec)
|
||||
and attr != BaseEngineSpec
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, BaseEngineSpec)
|
||||
and obj != BaseEngineSpec
|
||||
)
|
||||
|
||||
|
||||
def load_engine_specs() -> List[Type[BaseEngineSpec]]:
|
||||
"""
|
||||
Load all engine specs, native and 3rd party.
|
||||
"""
|
||||
engine_specs: List[Type[BaseEngineSpec]] = []
|
||||
|
||||
# load standard engines
|
||||
|
|
@ -78,20 +85,31 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]:
|
|||
return engine_specs
|
||||
|
||||
|
||||
def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
|
||||
def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]:
|
||||
"""
|
||||
Return the DB engine spec associated with a given SQLAlchemy URL.
|
||||
|
||||
Note that if a driver is not specified the function returns the first DB engine spec
|
||||
that supports the backend. Also, if a driver is specified but no DB engine explicitly
|
||||
supporting that driver exists then a backend-only match is done, in order to allow new
|
||||
drivers to work with Superset even if they are not listed in the DB engine spec
|
||||
drivers.
|
||||
"""
|
||||
engine_specs = load_engine_specs()
|
||||
|
||||
# build map from name/alias -> spec
|
||||
engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {}
|
||||
if driver is not None:
|
||||
for engine_spec in engine_specs:
|
||||
if engine_spec.supports_backend(backend, driver):
|
||||
return engine_spec
|
||||
|
||||
# check ignoring the driver, in order to support new drivers; this will return a
|
||||
# random DB engine spec that supports the engine
|
||||
for engine_spec in engine_specs:
|
||||
names = [engine_spec.engine]
|
||||
if engine_spec.engine_aliases:
|
||||
names.extend(engine_spec.engine_aliases)
|
||||
if engine_spec.supports_backend(backend):
|
||||
return engine_spec
|
||||
|
||||
for name in names:
|
||||
engine_specs_map[name] = engine_spec
|
||||
|
||||
return engine_specs_map
|
||||
# default to the generic DB engine spec
|
||||
return BaseEngineSpec
|
||||
|
||||
|
||||
# there's a mismatch between the dialect name reported by the driver in these
|
||||
|
|
|
|||
|
|
@ -183,9 +183,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
having to add the same aggregation in SELECT.
|
||||
"""
|
||||
|
||||
engine_name: Optional[str] = None # for user messages, overridden in child classes
|
||||
|
||||
# These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers;
|
||||
# see the ``supports_url`` and ``supports_backend`` methods below.
|
||||
engine = "base" # str as defined in sqlalchemy.engine.engine
|
||||
engine_aliases: Set[str] = set()
|
||||
engine_name: Optional[str] = None # for user messages, overridden in child classes
|
||||
drivers: Dict[str, str] = {}
|
||||
default_driver: Optional[str] = None
|
||||
|
||||
_date_trunc_functions: Dict[str, str] = {}
|
||||
_time_grain_expressions: Dict[Optional[str], str] = {}
|
||||
column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
|
||||
|
|
@ -355,6 +361,58 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]
|
||||
] = {}
|
||||
|
||||
@classmethod
|
||||
def supports_url(cls, url: URL) -> bool:
|
||||
"""
|
||||
Returns true if the DB engine spec supports a given SQLAlchemy URL.
|
||||
|
||||
As an example, if a given DB engine spec has:
|
||||
|
||||
class PostgresDBEngineSpec:
|
||||
engine = "postgresql"
|
||||
engine_aliases = "postgres"
|
||||
drivers = {
|
||||
"psycopg2": "The default Postgres driver",
|
||||
"asyncpg": "An asynchronous Postgres driver",
|
||||
}
|
||||
|
||||
It would be used for all the following SQLAlchemy URIs:
|
||||
|
||||
- postgres://user:password@host/db
|
||||
- postgresql://user:password@host/db
|
||||
- postgres+asyncpg://user:password@host/db
|
||||
- postgres+psycopg2://user:password@host/db
|
||||
- postgresql+asyncpg://user:password@host/db
|
||||
- postgresql+psycopg2://user:password@host/db
|
||||
|
||||
Note that SQLAlchemy has a default driver even if one is not specified:
|
||||
|
||||
>>> from sqlalchemy.engine.url import make_url
|
||||
>>> make_url('postgres://').get_driver_name()
|
||||
'psycopg2'
|
||||
|
||||
"""
|
||||
backend = url.get_backend_name()
|
||||
driver = url.get_driver_name()
|
||||
return cls.supports_backend(backend, driver)
|
||||
|
||||
@classmethod
|
||||
def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Returns true if the DB engine spec supports a given SQLAlchemy backend/driver.
|
||||
"""
|
||||
# check the backend first
|
||||
if backend != cls.engine and backend not in cls.engine_aliases:
|
||||
return False
|
||||
|
||||
# originally DB engine specs didn't declare any drivers and the check was made
|
||||
# only on the engine; if that's the case, ignore the driver for backwards
|
||||
# compatibility
|
||||
if not cls.drivers or driver is None:
|
||||
return True
|
||||
|
||||
return driver in cls.drivers
|
||||
|
||||
@classmethod
|
||||
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
|
||||
"""
|
||||
|
|
@ -394,7 +452,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
@classmethod
|
||||
def get_text_clause(cls, clause: str) -> TextClause:
|
||||
"""
|
||||
SQLALchemy wrapper to ensure text clauses are escaped properly
|
||||
SQLAlchemy wrapper to ensure text clauses are escaped properly
|
||||
|
||||
:param clause: string clause with potentially unescaped characters
|
||||
:return: text clause with escaped characters
|
||||
|
|
|
|||
|
|
@ -47,18 +47,23 @@ time_grain_expressions = {
|
|||
|
||||
|
||||
class DatabricksHiveEngineSpec(HiveEngineSpec):
|
||||
engine = "databricks"
|
||||
engine_name = "Databricks Interactive Cluster"
|
||||
driver = "pyhive"
|
||||
|
||||
engine = "databricks"
|
||||
drivers = {"pyhive": "Hive driver for Interactive Cluster"}
|
||||
default_driver = "pyhive"
|
||||
|
||||
_show_functions_column = "function"
|
||||
|
||||
_time_grain_expressions = time_grain_expressions
|
||||
|
||||
|
||||
class DatabricksODBCEngineSpec(BaseEngineSpec):
|
||||
engine = "databricks"
|
||||
engine_name = "Databricks SQL Endpoint"
|
||||
driver = "pyodbc"
|
||||
|
||||
engine = "databricks"
|
||||
drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
|
||||
default_driver = "pyodbc"
|
||||
|
||||
_time_grain_expressions = time_grain_expressions
|
||||
|
||||
|
|
@ -74,9 +79,11 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
|
|||
|
||||
|
||||
class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec):
|
||||
engine = "databricks"
|
||||
engine_name = "Databricks Native Connector"
|
||||
driver = "connector"
|
||||
|
||||
engine = "databricks"
|
||||
drivers = {"connector": "Native all-purpose driver"}
|
||||
default_driver = "connector"
|
||||
|
||||
@staticmethod
|
||||
def get_extra_params(database: "Database") -> Dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -20,7 +20,11 @@ from superset.db_engine_specs.sqlite import SqliteEngineSpec
|
|||
class ShillelaghEngineSpec(SqliteEngineSpec):
|
||||
"""Engine for shillelagh"""
|
||||
|
||||
engine = "shillelagh"
|
||||
engine_name = "Shillelagh"
|
||||
engine = "shillelagh"
|
||||
drivers = {"apsw": "SQLite driver"}
|
||||
default_driver = "apsw"
|
||||
sqlalchemy_uri_placeholder = "shillelagh://"
|
||||
|
||||
allows_joins = True
|
||||
allows_subqueries = True
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ from sqlalchemy import (
|
|||
from sqlalchemy.engine import Connection, Dialect, Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import ArgumentError
|
||||
from sqlalchemy.exc import ArgumentError, NoSuchModuleError
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
|
@ -635,15 +635,20 @@ class Database(
|
|||
|
||||
@property
|
||||
def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
|
||||
return self.get_db_engine_spec_for_backend(self.backend)
|
||||
url = make_url_safe(self.sqlalchemy_uri_decrypted)
|
||||
return self.get_db_engine_spec(url)
|
||||
|
||||
@classmethod
|
||||
@memoized
|
||||
def get_db_engine_spec_for_backend(
|
||||
cls, backend: str
|
||||
) -> Type[db_engine_specs.BaseEngineSpec]:
|
||||
engines = db_engine_specs.get_engine_specs()
|
||||
return engines.get(backend, db_engine_specs.BaseEngineSpec)
|
||||
def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]:
|
||||
backend = url.get_backend_name()
|
||||
try:
|
||||
driver = url.get_driver_name()
|
||||
except NoSuchModuleError:
|
||||
# can't load the driver, fallback for backwards compatibility
|
||||
driver = None
|
||||
|
||||
return db_engine_specs.get_engine_spec(backend, driver)
|
||||
|
||||
def grains(self) -> Tuple[TimeGrain, ...]:
|
||||
"""Defines time granularity database-specific expressions.
|
||||
|
|
|
|||
|
|
@ -1425,7 +1425,7 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
expected_response = {
|
||||
"errors": [
|
||||
{
|
||||
"message": "Could not load database driver: AzureSynapseSpec",
|
||||
"message": "Could not load database driver: MssqlEngineSpec",
|
||||
"error_type": "GENERIC_COMMAND_ERROR",
|
||||
"level": "warning",
|
||||
"extra": {
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from unittest import mock
|
|||
import pytest
|
||||
|
||||
from superset.connectors.sqla.models import TableColumn
|
||||
from superset.db_engine_specs import get_engine_specs
|
||||
from superset.db_engine_specs import load_engine_specs
|
||||
from superset.db_engine_specs.base import (
|
||||
BaseEngineSpec,
|
||||
BasicParametersMixin,
|
||||
|
|
@ -195,7 +195,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||
def test_engine_time_grain_validity(self):
|
||||
time_grains = set(builtin_time_grains.keys())
|
||||
# loop over all subclasses of BaseEngineSpec
|
||||
for engine in get_engine_specs().values():
|
||||
for engine in load_engine_specs():
|
||||
if engine is not BaseEngineSpec:
|
||||
# make sure time grain functions have been defined
|
||||
self.assertGreater(len(engine.get_time_grain_expressions()), 0)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from unittest import mock
|
|||
from sqlalchemy import column, literal_column
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from superset.db_engine_specs import get_engine_specs
|
||||
from superset.db_engine_specs import load_engine_specs
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.models.sql_lab import Query
|
||||
|
|
@ -137,7 +137,11 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
|
|||
"""
|
||||
DB Eng Specs (postgres): Test "postgres" in engine spec
|
||||
"""
|
||||
self.assertIn("postgres", get_engine_specs())
|
||||
backends = set()
|
||||
for engine in load_engine_specs():
|
||||
backends.add(engine.engine)
|
||||
backends.update(engine.engine_aliases)
|
||||
assert "postgres" in backends
|
||||
|
||||
def test_extras_without_ssl(self):
|
||||
db = mock.Mock()
|
||||
|
|
|
|||
|
|
@ -15,31 +15,59 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest import mock
|
||||
# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, redefined-outer-name
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from marshmallow import fields, Schema, ValidationError
|
||||
from pytest_mock import MockFixture
|
||||
|
||||
from superset.databases.schemas import DatabaseParametersSchemaMixin
|
||||
from superset.db_engine_specs.base import BasicParametersMixin
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
|
||||
class DummySchema(Schema, DatabaseParametersSchemaMixin):
|
||||
sqlalchemy_uri = fields.String()
|
||||
|
||||
|
||||
class DummyEngine(BasicParametersMixin):
|
||||
engine = "dummy"
|
||||
default_driver = "dummy"
|
||||
if TYPE_CHECKING:
|
||||
from superset.databases.schemas import DatabaseParametersSchemaMixin
|
||||
from superset.db_engine_specs.base import BasicParametersMixin
|
||||
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class InvalidEngine:
|
||||
pass
|
||||
"""
|
||||
An invalid DB engine spec.
|
||||
"""
|
||||
|
||||
|
||||
@mock.patch("superset.databases.schemas.get_engine_specs")
|
||||
def test_database_parameters_schema_mixin(get_engine_specs):
|
||||
get_engine_specs.return_value = {"dummy_engine": DummyEngine}
|
||||
@pytest.fixture
|
||||
def dummy_schema() -> "DatabaseParametersSchemaMixin":
|
||||
"""
|
||||
Fixture providing a dummy schema.
|
||||
"""
|
||||
from superset.databases.schemas import DatabaseParametersSchemaMixin
|
||||
|
||||
class DummySchema(Schema, DatabaseParametersSchemaMixin):
|
||||
sqlalchemy_uri = fields.String()
|
||||
|
||||
return DummySchema()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_engine(mocker: MockFixture) -> None:
|
||||
"""
|
||||
Fixture proving a dummy DB engine spec.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BasicParametersMixin
|
||||
|
||||
class DummyEngine(BasicParametersMixin):
|
||||
engine = "dummy"
|
||||
default_driver = "dummy"
|
||||
|
||||
mocker.patch("superset.databases.schemas.get_engine_spec", return_value=DummyEngine)
|
||||
|
||||
|
||||
def test_database_parameters_schema_mixin(
|
||||
dummy_engine: None,
|
||||
dummy_schema: "Schema",
|
||||
) -> None:
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
payload = {
|
||||
"engine": "dummy_engine",
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
|
|
@ -51,15 +79,18 @@ def test_database_parameters_schema_mixin(get_engine_specs):
|
|||
"database": "dbname",
|
||||
},
|
||||
}
|
||||
schema = DummySchema()
|
||||
result = schema.load(payload)
|
||||
result = dummy_schema.load(payload)
|
||||
assert result == {
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
"sqlalchemy_uri": "dummy+dummy://username:password@localhost:12345/dbname",
|
||||
}
|
||||
|
||||
|
||||
def test_database_parameters_schema_mixin_no_engine():
|
||||
def test_database_parameters_schema_mixin_no_engine(
|
||||
dummy_schema: "Schema",
|
||||
) -> None:
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
payload = {
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
"parameters": {
|
||||
|
|
@ -67,23 +98,28 @@ def test_database_parameters_schema_mixin_no_engine():
|
|||
"password": "password",
|
||||
"host": "localhost",
|
||||
"port": 12345,
|
||||
"dbname": "dbname",
|
||||
"database": "dbname",
|
||||
},
|
||||
}
|
||||
schema = DummySchema()
|
||||
try:
|
||||
schema.load(payload)
|
||||
dummy_schema.load(payload)
|
||||
except ValidationError as err:
|
||||
assert err.messages == {
|
||||
"_schema": [
|
||||
"An engine must be specified when passing individual parameters to a database."
|
||||
(
|
||||
"An engine must be specified when passing individual parameters to "
|
||||
"a database."
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@mock.patch("superset.databases.schemas.get_engine_specs")
|
||||
def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
|
||||
get_engine_specs.return_value = {}
|
||||
def test_database_parameters_schema_mixin_invalid_engine(
|
||||
dummy_engine: None,
|
||||
dummy_schema: "Schema",
|
||||
) -> None:
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
payload = {
|
||||
"engine": "dummy_engine",
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
|
|
@ -92,21 +128,24 @@ def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
|
|||
"password": "password",
|
||||
"host": "localhost",
|
||||
"port": 12345,
|
||||
"dbname": "dbname",
|
||||
"database": "dbname",
|
||||
},
|
||||
}
|
||||
schema = DummySchema()
|
||||
try:
|
||||
schema.load(payload)
|
||||
dummy_schema.load(payload)
|
||||
except ValidationError as err:
|
||||
print(err.messages)
|
||||
assert err.messages == {
|
||||
"_schema": ['Engine "dummy_engine" is not a valid engine.']
|
||||
}
|
||||
|
||||
|
||||
@mock.patch("superset.databases.schemas.get_engine_specs")
|
||||
def test_database_parameters_schema_no_mixin(get_engine_specs):
|
||||
get_engine_specs.return_value = {"invalid_engine": InvalidEngine}
|
||||
def test_database_parameters_schema_no_mixin(
|
||||
dummy_engine: None,
|
||||
dummy_schema: "Schema",
|
||||
) -> None:
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
payload = {
|
||||
"engine": "invalid_engine",
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
|
|
@ -118,9 +157,8 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
|
|||
"database": "dbname",
|
||||
},
|
||||
}
|
||||
schema = DummySchema()
|
||||
try:
|
||||
schema.load(payload)
|
||||
dummy_schema.load(payload)
|
||||
except ValidationError as err:
|
||||
assert err.messages == {
|
||||
"_schema": [
|
||||
|
|
@ -132,9 +170,12 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
|
|||
}
|
||||
|
||||
|
||||
@mock.patch("superset.databases.schemas.get_engine_specs")
|
||||
def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
|
||||
get_engine_specs.return_value = {"dummy_engine": DummyEngine}
|
||||
def test_database_parameters_schema_mixin_invalid_type(
|
||||
dummy_engine: None,
|
||||
dummy_schema: "Schema",
|
||||
) -> None:
|
||||
from superset.models.core import ConfigurationMethod
|
||||
|
||||
payload = {
|
||||
"engine": "dummy_engine",
|
||||
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
|
||||
|
|
@ -146,8 +187,7 @@ def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
|
|||
"database": "dbname",
|
||||
},
|
||||
}
|
||||
schema = DummySchema()
|
||||
try:
|
||||
schema.load(payload)
|
||||
dummy_schema.load(payload)
|
||||
except ValidationError as err:
|
||||
assert err.messages == {"port": ["Not a valid integer."]}
|
||||
|
|
@ -59,7 +59,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
|
|||
},
|
||||
]
|
||||
|
||||
database.get_db_engine_spec_for_backend = mocker.MagicMock( # type: ignore
|
||||
database.get_db_engine_spec = mocker.MagicMock( # type: ignore
|
||||
return_value=CustomSqliteEngineSpec
|
||||
)
|
||||
assert database.get_metrics("table") == [
|
||||
|
|
@ -70,3 +70,78 @@ def test_get_metrics(mocker: MockFixture) -> None:
|
|||
"verbose_name": "COUNT(DISTINCT user_id)",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_get_db_engine_spec(mocker: MockFixture) -> None:
|
||||
"""
|
||||
Tests for ``get_db_engine_spec``.
|
||||
"""
|
||||
from superset.db_engine_specs import BaseEngineSpec
|
||||
from superset.models.core import Database
|
||||
|
||||
# pylint: disable=abstract-method
|
||||
class PostgresDBEngineSpec(BaseEngineSpec):
|
||||
"""
|
||||
A DB engine spec with drivers and a default driver.
|
||||
"""
|
||||
|
||||
engine = "postgresql"
|
||||
engine_aliases = {"postgres"}
|
||||
drivers = {
|
||||
"psycopg2": "The default Postgres driver",
|
||||
"asyncpg": "An async Postgres driver",
|
||||
}
|
||||
default_driver = "psycopg2"
|
||||
|
||||
# pylint: disable=abstract-method
|
||||
class OldDBEngineSpec(BaseEngineSpec):
|
||||
"""
|
||||
And old DB engine spec without drivers nor a default driver.
|
||||
"""
|
||||
|
||||
engine = "mysql"
|
||||
|
||||
load_engine_specs = mocker.patch("superset.db_engine_specs.load_engine_specs")
|
||||
load_engine_specs.return_value = [
|
||||
PostgresDBEngineSpec,
|
||||
OldDBEngineSpec,
|
||||
]
|
||||
|
||||
assert (
|
||||
Database(database_name="db", sqlalchemy_uri="postgresql://").db_engine_spec
|
||||
== PostgresDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(
|
||||
database_name="db", sqlalchemy_uri="postgresql+psycopg2://"
|
||||
).db_engine_spec
|
||||
== PostgresDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(
|
||||
database_name="db", sqlalchemy_uri="postgresql+asyncpg://"
|
||||
).db_engine_spec
|
||||
== PostgresDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(
|
||||
database_name="db", sqlalchemy_uri="postgresql+fancynewdriver://"
|
||||
).db_engine_spec
|
||||
== PostgresDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(database_name="db", sqlalchemy_uri="mysql://").db_engine_spec
|
||||
== OldDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(
|
||||
database_name="db", sqlalchemy_uri="mysql+mysqlconnector://"
|
||||
).db_engine_spec
|
||||
== OldDBEngineSpec
|
||||
)
|
||||
assert (
|
||||
Database(
|
||||
database_name="db", sqlalchemy_uri="mysql+fancynewdriver://"
|
||||
).db_engine_spec
|
||||
== OldDBEngineSpec
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue