fix: improve get_db_engine_spec_for_backend (#21171)

* fix: improve get_db_engine_spec_for_backend

* Fix tests

* Fix docs

* fix lint

* fix fallback

* Fix engine validation

* Fix test
This commit is contained in:
Beto Dealmeida 2022-08-29 13:42:42 -05:00 committed by GitHub
parent 710a8ce5c0
commit 8772e2cdb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 309 additions and 130 deletions

View File

@ -1083,8 +1083,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"preferred": engine_spec.engine_name in preferred_databases,
}
if hasattr(engine_spec, "default_driver"):
payload["default_driver"] = engine_spec.default_driver # type: ignore
if engine_spec.default_driver:
payload["default_driver"] = engine_spec.default_driver
# show configuration parameters for DBs that support it
if (

View File

@ -29,8 +29,7 @@ from superset.databases.commands.exceptions import (
)
from superset.databases.dao import DatabaseDAO
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs.base import BasicParametersMixin
from superset.db_engine_specs import get_engine_spec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import event_logger
from superset.models.core import Database
@ -45,25 +44,13 @@ class ValidateDatabaseParametersCommand(BaseCommand):
def run(self) -> None:
engine = self._properties["engine"]
engine_specs = get_engine_specs()
driver = self._properties.get("driver")
if engine in BYPASS_VALIDATION_ENGINES:
# Skip engines that are only validated onCreate
return
if engine not in engine_specs:
raise InvalidEngineError(
SupersetError(
message=__(
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
extra={"allowed": list(engine_specs), "provided": engine},
),
)
engine_spec = engine_specs[engine]
engine_spec = get_engine_spec(engine, driver)
if not hasattr(engine_spec, "parameters_schema"):
raise InvalidEngineError(
SupersetError(
@ -73,14 +60,6 @@ class ValidateDatabaseParametersCommand(BaseCommand):
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
extra={
"allowed": [
name
for name, engine_spec in engine_specs.items()
if issubclass(engine_spec, BasicParametersMixin)
],
"provided": engine,
},
),
)

View File

@ -16,7 +16,7 @@
# under the License.
import inspect
import json
from typing import Any, Dict, Optional, Type
from typing import Any, Dict
from flask import current_app
from flask_babel import lazy_gettext as _
@ -28,7 +28,7 @@ from sqlalchemy import MetaData
from superset import db
from superset.databases.commands.exceptions import DatabaseInvalidError
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import BaseEngineSpec, get_engine_specs
from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
from superset.models.core import ConfigurationMethod, Database, PASSWORD_MASK
from superset.security.analytics_db_safety import check_sqlalchemy_uri
@ -150,7 +150,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
[
_(
"Invalid connection string, a valid string usually follows: "
"driver://user:password@database-host/database-name"
"backend+driver://user:password@database-host/database-name"
)
]
) from ex
@ -231,6 +231,7 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
"""
engine = fields.String(allow_none=True, description="SQLAlchemy engine to use")
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(),
@ -262,10 +263,20 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
or parameters.pop("engine", None)
or data.pop("backend", None)
)
driver = data.pop("driver", None)
configuration_method = data.get("configuration_method")
if configuration_method == ConfigurationMethod.DYNAMIC_FORM:
engine_spec = get_engine_spec(engine)
if not engine:
raise ValidationError(
[
_(
"An engine must be specified when passing "
"individual parameters to a database."
)
]
)
engine_spec = get_engine_spec(engine, driver)
if not hasattr(engine_spec, "build_sqlalchemy_uri") or not hasattr(
engine_spec, "parameters_schema"
@ -295,34 +306,12 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
return data
def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]:
if not engine:
raise ValidationError(
[
_(
"An engine must be specified when passing "
"individual parameters to a database."
)
]
)
engine_specs = get_engine_specs()
if engine not in engine_specs:
raise ValidationError(
[
_(
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
)
]
)
return engine_specs[engine]
class DatabaseValidateParametersSchema(Schema):
class Meta: # pylint: disable=too-few-public-methods
unknown = EXCLUDE
engine = fields.String(required=True, description="SQLAlchemy engine to use")
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(allow_none=True),

View File

@ -33,27 +33,34 @@ import pkgutil
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Set, Type
from typing import Any, Dict, List, Optional, Set, Type
import sqlalchemy.databases
import sqlalchemy.dialects
from pkg_resources import iter_entry_points
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.url import URL
from superset.db_engine_specs.base import BaseEngineSpec
logger = logging.getLogger(__name__)
def is_engine_spec(attr: Any) -> bool:
def is_engine_spec(obj: Any) -> bool:
"""
Return true if a given object is a DB engine spec.
"""
return (
inspect.isclass(attr)
and issubclass(attr, BaseEngineSpec)
and attr != BaseEngineSpec
inspect.isclass(obj)
and issubclass(obj, BaseEngineSpec)
and obj != BaseEngineSpec
)
def load_engine_specs() -> List[Type[BaseEngineSpec]]:
"""
Load all engine specs, native and 3rd party.
"""
engine_specs: List[Type[BaseEngineSpec]] = []
# load standard engines
@ -78,20 +85,31 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]:
return engine_specs
def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]:
"""
Return the DB engine spec associated with a given SQLAlchemy backend and driver.
Note that if a driver is not specified the function returns the first DB engine spec
that supports the backend. Also, if a driver is specified but no DB engine explicitly
supporting that driver exists then a backend-only match is done, in order to allow new
drivers to work with Superset even if they are not listed in the DB engine spec
drivers.
"""
engine_specs = load_engine_specs()
# build map from name/alias -> spec
engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {}
if driver is not None:
for engine_spec in engine_specs:
if engine_spec.supports_backend(backend, driver):
return engine_spec
# check ignoring the driver, in order to support new drivers; this returns the
# first loaded DB engine spec that supports the backend
for engine_spec in engine_specs:
names = [engine_spec.engine]
if engine_spec.engine_aliases:
names.extend(engine_spec.engine_aliases)
if engine_spec.supports_backend(backend):
return engine_spec
for name in names:
engine_specs_map[name] = engine_spec
return engine_specs_map
# default to the generic DB engine spec
return BaseEngineSpec
# there's a mismatch between the dialect name reported by the driver in these

View File

@ -183,9 +183,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
having to add the same aggregation in SELECT.
"""
engine_name: Optional[str] = None # for user messages, overridden in child classes
# These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers;
# see the ``supports_url`` and ``supports_backend`` methods below.
engine = "base" # str as defined in sqlalchemy.engine.engine
engine_aliases: Set[str] = set()
engine_name: Optional[str] = None # for user messages, overridden in child classes
drivers: Dict[str, str] = {}
default_driver: Optional[str] = None
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
@ -355,6 +361,58 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]
] = {}
@classmethod
def supports_url(cls, url: URL) -> bool:
    """
    Tell whether this DB engine spec can handle a given SQLAlchemy URL.

    The backend and driver names are read from the URL and the decision is
    delegated to ``supports_backend``.

    For example, a DB engine spec declared as:

        class PostgresDBEngineSpec:
            engine = "postgresql"
            engine_aliases = {"postgres"}
            drivers = {
                "psycopg2": "The default Postgres driver",
                "asyncpg": "An asynchronous Postgres driver",
            }

    would be selected for all of the following SQLAlchemy URIs:

        - postgres://user:password@host/db
        - postgresql://user:password@host/db
        - postgres+asyncpg://user:password@host/db
        - postgres+psycopg2://user:password@host/db
        - postgresql+asyncpg://user:password@host/db
        - postgresql+psycopg2://user:password@host/db

    Keep in mind that SQLAlchemy resolves a default driver even when the URL does
    not name one:

        >>> from sqlalchemy.engine.url import make_url
        >>> make_url('postgres://').get_driver_name()
        'psycopg2'

    :param url: the SQLAlchemy URL to check
    :return: whether the URL's backend/driver pair is supported by this spec
    """
    # the backend name is resolved before the driver name, as the arguments are
    # evaluated left to right
    return cls.supports_backend(url.get_backend_name(), url.get_driver_name())
@classmethod
def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool:
    """
    Tell whether this DB engine spec handles a SQLAlchemy backend/driver pair.

    The backend must equal ``cls.engine`` or be listed in ``cls.engine_aliases``.
    The driver is only taken into account when the spec declares ``drivers`` and
    a driver is actually passed.

    :param backend: the SQLAlchemy backend name, e.g. ``postgresql``
    :param driver: the SQLAlchemy driver name, or ``None`` to match on backend only
    :return: whether the backend (and driver, where applicable) is supported
    """
    # a backend mismatch rules the spec out regardless of the driver
    if not (backend == cls.engine or backend in cls.engine_aliases):
        return False

    # DB engine specs written before drivers were introduced declare none; for
    # backwards compatibility the driver is ignored for them, as it is when the
    # caller does not pass one
    driver_must_match = bool(cls.drivers) and driver is not None
    return driver in cls.drivers if driver_must_match else True
@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
"""
@ -394,7 +452,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_text_clause(cls, clause: str) -> TextClause:
"""
SQLALchemy wrapper to ensure text clauses are escaped properly
SQLAlchemy wrapper to ensure text clauses are escaped properly
:param clause: string clause with potentially unescaped characters
:return: text clause with escaped characters

View File

@ -47,18 +47,23 @@ time_grain_expressions = {
class DatabricksHiveEngineSpec(HiveEngineSpec):
engine = "databricks"
engine_name = "Databricks Interactive Cluster"
driver = "pyhive"
engine = "databricks"
drivers = {"pyhive": "Hive driver for Interactive Cluster"}
default_driver = "pyhive"
_show_functions_column = "function"
_time_grain_expressions = time_grain_expressions
class DatabricksODBCEngineSpec(BaseEngineSpec):
engine = "databricks"
engine_name = "Databricks SQL Endpoint"
driver = "pyodbc"
engine = "databricks"
drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
default_driver = "pyodbc"
_time_grain_expressions = time_grain_expressions
@ -74,9 +79,11 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec):
engine = "databricks"
engine_name = "Databricks Native Connector"
driver = "connector"
engine = "databricks"
drivers = {"connector": "Native all-purpose driver"}
default_driver = "connector"
@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:

View File

@ -20,7 +20,11 @@ from superset.db_engine_specs.sqlite import SqliteEngineSpec
class ShillelaghEngineSpec(SqliteEngineSpec):
"""Engine for shillelagh"""
engine = "shillelagh"
engine_name = "Shillelagh"
engine = "shillelagh"
drivers = {"apsw": "SQLite driver"}
default_driver = "apsw"
sqlalchemy_uri_placeholder = "shillelagh://"
allows_joins = True
allows_subqueries = True

View File

@ -46,7 +46,7 @@ from sqlalchemy import (
from sqlalchemy.engine import Connection, Dialect, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import ArgumentError
from sqlalchemy.exc import ArgumentError, NoSuchModuleError
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship
from sqlalchemy.pool import NullPool
@ -635,15 +635,20 @@ class Database(
@property
def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
return self.get_db_engine_spec_for_backend(self.backend)
url = make_url_safe(self.sqlalchemy_uri_decrypted)
return self.get_db_engine_spec(url)
@classmethod
@memoized
def get_db_engine_spec_for_backend(
cls, backend: str
) -> Type[db_engine_specs.BaseEngineSpec]:
engines = db_engine_specs.get_engine_specs()
return engines.get(backend, db_engine_specs.BaseEngineSpec)
def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]:
backend = url.get_backend_name()
try:
driver = url.get_driver_name()
except NoSuchModuleError:
# can't load the driver, fallback for backwards compatibility
driver = None
return db_engine_specs.get_engine_spec(backend, driver)
def grains(self) -> Tuple[TimeGrain, ...]:
"""Defines time granularity database-specific expressions.

View File

@ -1425,7 +1425,7 @@ class TestDatabaseApi(SupersetTestCase):
expected_response = {
"errors": [
{
"message": "Could not load database driver: AzureSynapseSpec",
"message": "Could not load database driver: MssqlEngineSpec",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {

View File

@ -20,7 +20,7 @@ from unittest import mock
import pytest
from superset.connectors.sqla.models import TableColumn
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
@ -195,7 +195,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
def test_engine_time_grain_validity(self):
time_grains = set(builtin_time_grains.keys())
# loop over all subclasses of BaseEngineSpec
for engine in get_engine_specs().values():
for engine in load_engine_specs():
if engine is not BaseEngineSpec:
# make sure time grain functions have been defined
self.assertGreater(len(engine.get_time_grain_expressions()), 0)

View File

@ -20,7 +20,7 @@ from unittest import mock
from sqlalchemy import column, literal_column
from sqlalchemy.dialects import postgresql
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
@ -137,7 +137,11 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
"""
DB Eng Specs (postgres): Test "postgres" in engine spec
"""
self.assertIn("postgres", get_engine_specs())
backends = set()
for engine in load_engine_specs():
backends.add(engine.engine)
backends.update(engine.engine_aliases)
assert "postgres" in backends
def test_extras_without_ssl(self):
db = mock.Mock()

View File

@ -15,31 +15,59 @@
# specific language governing permissions and limitations
# under the License.
from unittest import mock
# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, redefined-outer-name
from typing import TYPE_CHECKING
import pytest
from marshmallow import fields, Schema, ValidationError
from pytest_mock import MockFixture
from superset.databases.schemas import DatabaseParametersSchemaMixin
from superset.db_engine_specs.base import BasicParametersMixin
from superset.models.core import ConfigurationMethod
class DummySchema(Schema, DatabaseParametersSchemaMixin):
sqlalchemy_uri = fields.String()
class DummyEngine(BasicParametersMixin):
engine = "dummy"
default_driver = "dummy"
if TYPE_CHECKING:
from superset.databases.schemas import DatabaseParametersSchemaMixin
from superset.db_engine_specs.base import BasicParametersMixin
# pylint: disable=too-few-public-methods
class InvalidEngine:
pass
"""
An invalid DB engine spec.
"""
@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_mixin(get_engine_specs):
get_engine_specs.return_value = {"dummy_engine": DummyEngine}
@pytest.fixture
def dummy_schema() -> "DatabaseParametersSchemaMixin":
    """
    Fixture returning an instance of a minimal schema that uses the mixin.

    The schema declares a single ``sqlalchemy_uri`` string field on top of what
    ``DatabaseParametersSchemaMixin`` contributes.
    """
    # imported inside the fixture rather than at module level; see the
    # ``import-outside-toplevel`` pylint exemption for this module
    from superset.databases.schemas import DatabaseParametersSchemaMixin

    class DummySchema(Schema, DatabaseParametersSchemaMixin):
        sqlalchemy_uri = fields.String()

    schema_instance = DummySchema()
    return schema_instance
@pytest.fixture
def dummy_engine(mocker: MockFixture) -> None:
    """
    Fixture providing a dummy DB engine spec.

    ``get_engine_spec``, as referenced from ``superset.databases.schemas``, is
    patched so that it returns a minimal ``BasicParametersMixin`` subclass.
    """
    # imported inside the fixture rather than at module level; see the
    # ``import-outside-toplevel`` pylint exemption for this module
    from superset.db_engine_specs.base import BasicParametersMixin

    class DummyEngine(BasicParametersMixin):
        engine = "dummy"
        default_driver = "dummy"

    patch_target = "superset.databases.schemas.get_engine_spec"
    mocker.patch(patch_target, return_value=DummyEngine)
def test_database_parameters_schema_mixin(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod
payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@ -51,15 +79,18 @@ def test_database_parameters_schema_mixin(get_engine_specs):
"database": "dbname",
},
}
schema = DummySchema()
result = schema.load(payload)
result = dummy_schema.load(payload)
assert result == {
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
"sqlalchemy_uri": "dummy+dummy://username:password@localhost:12345/dbname",
}
def test_database_parameters_schema_mixin_no_engine():
def test_database_parameters_schema_mixin_no_engine(
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod
payload = {
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
"parameters": {
@ -67,23 +98,28 @@ def test_database_parameters_schema_mixin_no_engine():
"password": "password",
"host": "localhost",
"port": 12345,
"dbname": "dbname",
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {
"_schema": [
"An engine must be specified when passing individual parameters to a database."
(
"An engine must be specified when passing individual parameters to "
"a database."
),
]
}
@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
get_engine_specs.return_value = {}
def test_database_parameters_schema_mixin_invalid_engine(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod
payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@ -92,21 +128,24 @@ def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
"password": "password",
"host": "localhost",
"port": 12345,
"dbname": "dbname",
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
print(err.messages)
assert err.messages == {
"_schema": ['Engine "dummy_engine" is not a valid engine.']
}
@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_no_mixin(get_engine_specs):
get_engine_specs.return_value = {"invalid_engine": InvalidEngine}
def test_database_parameters_schema_no_mixin(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod
payload = {
"engine": "invalid_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@ -118,9 +157,8 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {
"_schema": [
@ -132,9 +170,12 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
}
@mock.patch("superset.databases.schemas.get_engine_specs")
def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
get_engine_specs.return_value = {"dummy_engine": DummyEngine}
def test_database_parameters_schema_mixin_invalid_type(
dummy_engine: None,
dummy_schema: "Schema",
) -> None:
from superset.models.core import ConfigurationMethod
payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@ -146,8 +187,7 @@ def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
"database": "dbname",
},
}
schema = DummySchema()
try:
schema.load(payload)
dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {"port": ["Not a valid integer."]}

View File

@ -59,7 +59,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
},
]
database.get_db_engine_spec_for_backend = mocker.MagicMock( # type: ignore
database.get_db_engine_spec = mocker.MagicMock( # type: ignore
return_value=CustomSqliteEngineSpec
)
assert database.get_metrics("table") == [
@ -70,3 +70,78 @@ def test_get_metrics(mocker: MockFixture) -> None:
"verbose_name": "COUNT(DISTINCT user_id)",
},
]
def test_get_db_engine_spec(mocker: MockFixture) -> None:
    """
    Tests for ``get_db_engine_spec``.

    Two DB engine specs are registered, one declaring drivers and one not, and
    ``Database.db_engine_spec`` is checked for URIs with no driver, a declared
    driver, and a driver neither spec declares.
    """
    from superset.db_engine_specs import BaseEngineSpec
    from superset.models.core import Database

    # pylint: disable=abstract-method
    class PostgresDBEngineSpec(BaseEngineSpec):
        """
        A DB engine spec with drivers and a default driver.
        """

        engine = "postgresql"
        engine_aliases = {"postgres"}
        drivers = {
            "psycopg2": "The default Postgres driver",
            "asyncpg": "An async Postgres driver",
        }
        default_driver = "psycopg2"

    # pylint: disable=abstract-method
    class OldDBEngineSpec(BaseEngineSpec):
        """
        An old DB engine spec with neither drivers nor a default driver.
        """

        engine = "mysql"

    # only the two specs above are visible to the lookup
    mocker.patch(
        "superset.db_engine_specs.load_engine_specs",
        return_value=[PostgresDBEngineSpec, OldDBEngineSpec],
    )

    # (SQLAlchemy URI, DB engine spec expected for it), checked in order
    expectations = [
        ("postgresql://", PostgresDBEngineSpec),
        ("postgresql+psycopg2://", PostgresDBEngineSpec),
        ("postgresql+asyncpg://", PostgresDBEngineSpec),
        ("postgresql+fancynewdriver://", PostgresDBEngineSpec),
        ("mysql://", OldDBEngineSpec),
        ("mysql+mysqlconnector://", OldDBEngineSpec),
        ("mysql+fancynewdriver://", OldDBEngineSpec),
    ]
    for sqlalchemy_uri, expected_spec in expectations:
        database = Database(database_name="db", sqlalchemy_uri=sqlalchemy_uri)
        assert database.db_engine_spec == expected_spec